Skip to content

Commit

Permalink
Merge pull request #395 from dlangbe/develop
Browse files Browse the repository at this point in the history
Fixed compiler warnings
  • Loading branch information
dlangbe committed Apr 30, 2024
2 parents ccecf5b + 8da918d commit f358adf
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 37 deletions.
6 changes: 1 addition & 5 deletions library/include/rocwmma/internal/coop_load.hpp
Expand Up @@ -63,10 +63,7 @@ namespace rocwmma

// Outer loop = index 0,
// Inner loop = index N-1
template <size_t Depth = 0,
typename Iterator,
typename StrideSpace,
typename Strides2d>
template <size_t Depth = 0, typename Iterator, typename StrideSpace, typename Strides2d>
ROCWMMA_DEVICE static inline auto unroll_right(Iterator& out,
DataT const* dataPtr,
uint32_t ldm,
Expand All @@ -93,7 +90,6 @@ namespace rocwmma
// Recurse to the next nested layer
else
{
#pragma unroll
for(int i = 0; i < strideCount; i++)
{
unroll_right<Depth + 1>(out, dataPtr, ldm, strideSpace, strides2d);
Expand Down
6 changes: 1 addition & 5 deletions library/include/rocwmma/internal/coop_store.hpp
Expand Up @@ -64,10 +64,7 @@ namespace rocwmma

// Outer loop = index 0,
// Inner loop = index N-1
template <size_t Depth = 0,
typename Iterator,
typename StrideSpace,
typename Strides2d>
template <size_t Depth = 0, typename Iterator, typename StrideSpace, typename Strides2d>
ROCWMMA_DEVICE static inline auto unroll_right(DataT* dataPtr,
Iterator& in,
uint32_t ldm,
Expand All @@ -94,7 +91,6 @@ namespace rocwmma
// Recurse to the next nested layer
else
{
#pragma unroll
for(int i = 0; i < strideCount; i++)
{
unroll_right<Depth + 1>(dataPtr, in, ldm, strideCounts, strides2d);
Expand Down
35 changes: 8 additions & 27 deletions test/dlrm/lds_mapping_util.hpp
Expand Up @@ -239,13 +239,6 @@ namespace rocwmma
auto waveIndex = get<1>(waveCoord);
auto waveCount = get<1>(workgroupDim);

constexpr auto splitCount = std::min((uint32_t)IOTraits<GlobalReadFragA>::IOCount,
(uint32_t)IOTraits<LocalWriteFragA>::IOCount);

static_assert(((uint32_t)IOTraits<GlobalReadFragA>::IOCount % splitCount == 0)
&& ((uint32_t)IOTraits<LocalWriteFragA>::IOCount % splitCount == 0),
"splitCount is not common divisor of GlobalRead and LocalWrite IOCounts");

for(int32_t i = 0; i < BlocksX; ++i)
{
// Issue global read
Expand All @@ -255,16 +248,14 @@ namespace rocwmma
baseA + GlobalAOffsets::dataOffset(make_coord2d(BlockM * i, 0), lda),
lda,
waveIndex,
waveCount,
splitCount);
waveCount);

// Issue local store
store_matrix_coop_sync(baseLds + baseOffsetA() + waveOffsetA() + blockOffsetA(i),
reinterpret_cast<LocalWriteFragA&>(fetchA),
ld(),
waveIndex,
waveCount,
splitCount);
waveCount);
}
}

Expand All @@ -275,16 +266,10 @@ namespace rocwmma
// we need to ensure that splitCounts are the same on both sides of
// global fetch and local writes; otherwise, the waves don't have the
// same data responsibility.
auto workgroupDim = GlobalBOffsets::workgroupDim();
auto waveCoord = GlobalBOffsets::waveCoord();
auto waveIndex = get<0>(waveCoord);
auto waveCount = get<0>(workgroupDim);
constexpr auto splitCount = std::min((uint32_t)IOTraits<GlobalReadFragB>::IOCount,
(uint32_t)IOTraits<LocalWriteFragB>::IOCount);

static_assert(((uint32_t)IOTraits<GlobalReadFragB>::IOCount % splitCount == 0)
&& ((uint32_t)IOTraits<LocalWriteFragB>::IOCount % splitCount == 0),
"splitCount is not common divisor of GlobalRead and LocalWrite IOCounts");
auto workgroupDim = GlobalBOffsets::workgroupDim();
auto waveCoord = GlobalBOffsets::waveCoord();
auto waveIndex = get<0>(waveCoord);
auto waveCount = get<0>(workgroupDim);

for(int32_t i = 0; i < BlocksY; ++i)
{
Expand All @@ -295,16 +280,14 @@ namespace rocwmma
baseB + GlobalBOffsets::dataOffset(make_coord2d(0, BlockN * i), ldb),
ldb,
waveIndex,
waveCount,
splitCount);
waveCount);

// Issue local store
store_matrix_coop_sync(baseLds + baseOffsetB() + waveOffsetB() + blockOffsetB(i),
reinterpret_cast<LocalWriteFragB&>(fetchB),
ld(),
waveIndex,
waveCount,
splitCount);
waveCount);
}
}

Expand Down Expand Up @@ -364,7 +347,6 @@ namespace rocwmma

__device__ static inline void prefetchLocalA(FragA* fragsA, DataT const* baseLds)
{
#pragma unroll
for(int i = 0; i < BlocksX; i++)
{
prefetchLocalA(fragsA[i], baseLds, i);
Expand All @@ -373,7 +355,6 @@ namespace rocwmma

__device__ static inline void prefetchLocalB(FragB* fragsB, DataT const* baseLds)
{
#pragma unroll
for(int i = 0; i < BlocksY; i++)
{
prefetchLocalB(fragsB[i], baseLds, i);
Expand Down

0 comments on commit f358adf

Please sign in to comment.