Skip to content

Commit

Permalink
implement constraining Accessors
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Mar 3, 2021
1 parent 745cc1b commit a567f2b
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 11 deletions.
102 changes: 92 additions & 10 deletions include/alpaka/mem/view/Accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,17 @@ namespace alpaka
{
using ReturnType = AccessReturnType<Pointer, Elem, AccessModes>;

ALPAKA_FN_ACC Accessor(Pointer p_, Vec<DimInt<1>, BufferIdx> extents_) : p(p_), extents(extents_)
{
}

template<typename OtherAccessModes>
ALPAKA_FN_ACC Accessor(const Accessor<Pointer, Elem, BufferIdx, 1, OtherAccessModes>& other)
: p(other.p)
, extents(other.extents)
{
}

ALPAKA_FN_ACC auto operator[](Vec<DimInt<1>, BufferIdx> i) const -> ReturnType
{
return (*this)(i[0]);
Expand All @@ -127,6 +138,21 @@ namespace alpaka
{
using ReturnType = AccessReturnType<Pointer, Elem, AccessModes>;

ALPAKA_FN_ACC Accessor(Pointer p_, BufferIdx rowPitchInBytes_, Vec<DimInt<2>, BufferIdx> extents_)
: p(p_)
, rowPitchInBytes(rowPitchInBytes_)
, extents(extents_)
{
}

template<typename OtherAccessModes>
ALPAKA_FN_ACC Accessor(const Accessor<Pointer, Elem, BufferIdx, 2, OtherAccessModes>& other)
: p(other.p)
, rowPitchInBytes(other.rowPitchInBytes)
, extents(other.extents)
{
}

ALPAKA_FN_ACC auto operator[](Vec<DimInt<2>, BufferIdx> i) const -> ReturnType
{
return (*this)(i[0], i[1]);
Expand Down Expand Up @@ -154,6 +180,27 @@ namespace alpaka
{
using ReturnType = AccessReturnType<Pointer, Elem, AccessModes>;

ALPAKA_FN_ACC Accessor(
Pointer p_,
BufferIdx slicePitchInBytes_,
BufferIdx rowPitchInBytes_,
Vec<DimInt<3>, BufferIdx> extents_)
: p(p_)
, slicePitchInBytes(slicePitchInBytes_)
, rowPitchInBytes(rowPitchInBytes_)
, extents(extents_)
{
}

template<typename OtherAccessModes>
ALPAKA_FN_ACC Accessor(const Accessor<Pointer, Elem, BufferIdx, 3, OtherAccessModes>& other)
: p(other.p)
, slicePitchInBytes(other.slicePitchInBytes)
, rowPitchInBytes(other.rowPitchInBytes)
, extents(other.extents)
{
}

ALPAKA_FN_ACC auto operator[](Vec<DimInt<3>, BufferIdx> i) const -> ReturnType
{
return (*this)(i[0], i[1], i[2]);
Expand Down Expand Up @@ -286,9 +333,18 @@ namespace alpaka
getPitchBytes<PitchIs + 1>(buffer)...,
{extent::getExtent<ExtentIs>(buffer)...}};
}

template<typename T>
constexpr bool isAccessor = false;

template<typename MemoryHandle, typename Elem, typename BufferIdx, std::size_t Dim, typename AccessModes>
constexpr bool isAccessor<Accessor<MemoryHandle, Elem, BufferIdx, Dim, AccessModes>> = true;
} // namespace internal

template<typename... AccessModes, typename Buf>
template<
typename... AccessModes,
typename Buf,
typename = std::enable_if_t<!internal::isAccessor<std::decay_t<Buf>>>>
auto accessWith(Buf&& buffer)
{
using Dim = Dim<std::decay_t<Buf>>;
Expand All @@ -298,21 +354,47 @@ namespace alpaka
std::make_index_sequence<Dim::value>{});
}

template<typename Buf>
auto access(Buf&& buffer)
// TODO: currently only allows constraining down to 1 access mode
template<
typename NewAccessMode,
typename MemoryHandle,
typename Elem,
typename BufferIdx,
std::size_t Dim,
typename... PrevAccessModesBefore,
typename... PrevAccessModesAfter>
auto accessWith(Accessor<
MemoryHandle,
Elem,
BufferIdx,
Dim,
std::tuple<PrevAccessModesBefore..., NewAccessMode, PrevAccessModesAfter...>>&& acc)
{
return Accessor<MemoryHandle, Elem, BufferIdx, Dim, NewAccessMode>{acc};
}

// constraining accessor to the same access mode again just passes through
template<typename AccessMode, typename MemoryHandle, typename Elem, typename BufferIdx, std::size_t Dim>
auto accessWith(Accessor<MemoryHandle, Elem, BufferIdx, Dim, AccessMode>&& acc)
{
return acc;
}

template<typename BufOrAcc>
auto access(BufOrAcc&& bufOrAcc)
{
return accessWith<WriteAccess, ReadAccess>(std::forward<Buf>(buffer));
return accessWith<ReadWriteAccess>(std::forward<BufOrAcc>(bufOrAcc));
}

template<typename Buf>
auto readAccess(Buf&& buffer)
template<typename BufOrAcc>
auto readAccess(BufOrAcc&& bufOrAcc)
{
return accessWith<ReadAccess>(std::forward<Buf>(buffer));
return accessWith<ReadAccess>(std::forward<BufOrAcc>(bufOrAcc));
}

template<typename Buf>
auto writeAccess(Buf&& buffer)
template<typename BufOrAcc>
auto writeAccess(BufOrAcc&& bufOrAcc)
{
return accessWith<WriteAccess>(std::forward<Buf>(buffer));
return accessWith<WriteAccess>(std::forward<BufOrAcc>(bufOrAcc));
}
} // namespace alpaka
22 changes: 21 additions & 1 deletion test/unit/mem/view/src/Accessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,4 +267,24 @@ TEST_CASE("projection", "[accessor]")
alpaka::memcpy(queue, host, dstBuffer, 1);

REQUIRE(host[0] == 84);
}
}

TEST_CASE("constraining", "[accessor]")
{
using Dim = alpaka::DimInt<1>;
using Size = std::size_t;
using Acc = alpaka::ExampleDefaultAcc<Dim, Size>;

auto const devAcc = alpaka::getDevByIdx<Acc>(0u);
auto buffer = alpaka::allocBuf<int, Size>(devAcc, Size{1});

alpaka::Accessor<int*, int, Size, 1, std::tuple<alpaka::ReadAccess, alpaka::WriteAccess, alpaka::ReadWriteAccess>>
acc = alpaka::accessWith<alpaka::ReadAccess, alpaka::WriteAccess, alpaka::ReadWriteAccess>(buffer);

alpaka::Accessor<int*, int, Size, 1, alpaka::ReadAccess> readAcc = alpaka::readAccess(acc);
alpaka::Accessor<int*, int, Size, 1, alpaka::WriteAccess> writeAcc = alpaka::writeAccess(acc);
alpaka::Accessor<int*, int, Size, 1, alpaka::ReadWriteAccess> readWriteAcc = alpaka::access(acc);
(void) readAcc;
(void) writeAcc;
(void) readWriteAcc;
}

0 comments on commit a567f2b

Please sign in to comment.