Skip to content

Commit

Permalink
Merge pull request #47 from STEllAR-GROUP/fixing_44_1
Browse files Browse the repository at this point in the history
More fixes for 4D reductions
  • Loading branch information
hkaiser committed Sep 13, 2019
2 parents 9848ed0 + c0bffc7 commit 9bfebb6
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
6 changes: 4 additions & 2 deletions blaze_tensor/math/expressions/DArrReduceExpr.h
Expand Up @@ -962,10 +962,12 @@ inline ElementType_t<MT> darrayreduce( const DenseArray<MT>& dm, OP op )

ET redux = ( ~dm )( dims ); // start with first element

ArrayForEachGrouped( ( ~dm ).dimensions(),
std::array< size_t, N > starts_with{1};
ArrayForEachGroupedStartsWith( ( ~dm ).dimensions(),
[&]( std::array< size_t, N > const& dims ) {
redux = op( redux, tmp( dims ) );
} );
},
starts_with );

return redux;
}
Expand Down
44 changes: 44 additions & 0 deletions blaze_tensor/util/ArrayForEach.h
Expand Up @@ -231,6 +231,50 @@ void ArrayForEachGrouped(
}
//*************************************************************************************************

//*************************************************************************************************
/*!\brief ArrayForEachGroupedStartsWith function to iterate over arbitrary dimension data. Start with given index.
// \ingroup util
*/
template< typename F, size_t M >
void ArrayForEachGroupedStartsWith(
size_t dim0, F const& f, std::array< size_t, M >& currdims, size_t starts_with )
{
for( currdims[0] = starts_with; currdims[0] != dim0; ++currdims[0] ) {
f( currdims );
}
}

template< typename F, size_t M, size_t OrgN >
void ArrayForEachGroupedStartsWith( std::array< size_t, 2 > const& dims,
F const& f, std::array< size_t, M >& currdims, std::array< size_t, OrgN >& starts_with )
{
for( currdims[1] = starts_with[1]; currdims[1] != dims[1]; ++currdims[1] ) {
ArrayForEachGroupedStartsWith( dims[0], f, currdims, starts_with[0] );
starts_with[0] = 0;
}
}

template< size_t N, typename F, size_t M, size_t OrgN >
void ArrayForEachGroupedStartsWith( std::array< size_t, N > const& dims,
F const& f, std::array< size_t, M >& currdims, std::array< size_t, OrgN >& starts_with )
{
BLAZE_STATIC_ASSERT( N > 2 );
std::array< size_t, N - 1 > shifted_dims = shiftDims( dims );
for( currdims[N - 1] = starts_with[N - 1]; currdims[N - 1] != dims[N - 1]; ++currdims[N - 1] ) {
ArrayForEachGroupedStartsWith( shifted_dims, f, currdims, starts_with );
starts_with[N - 2] = 0;
}
}

template< size_t N, typename F >
void ArrayForEachGroupedStartsWith(
std::array< size_t, N > const& dims, F const& f, std::array< size_t, N >& starts_with )
{
std::array< size_t, N > currdims{};
ArrayForEachGroupedStartsWith( dims, f, currdims, starts_with );
}
//*************************************************************************************************

//*************************************************************************************************
/*!\brief ArrayForEachGrouped function to iterate over arbitrary dimension data.
// \ingroup util
Expand Down

0 comments on commit 9bfebb6

Please sign in to comment.