Skip to content

Commit

Permalink
Fix Vec generator constructor for non-integral types.
Browse files Browse the repository at this point in the history
  • Loading branch information
sliwowitz authored and psychocoderHPC committed Mar 8, 2024
1 parent 1c06aa8 commit 7dce4ec
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
14 changes: 9 additions & 5 deletions include/alpaka/vec/Vec.hpp
Expand Up @@ -87,18 +87,22 @@ namespace alpaka
ALPAKA_FN_HOST_ACC constexpr explicit Vec(
F&& generator,
std::void_t<decltype(generator(std::integral_constant<std::size_t, 0>{}))>* ignore = nullptr)
: Vec(std::forward<F>(generator), std::make_index_sequence<TDim::value>{})
{
static_cast<void>(ignore);
}
#else
template<typename F, std::enable_if_t<std::is_invocable_v<F, std::integral_constant<std::size_t, 0>>, int> = 0>
ALPAKA_FN_HOST_ACC constexpr explicit Vec(F&& generator)
#endif
: Vec(std::forward<F>(generator), std::make_integer_sequence<TVal, TDim::value>{})
: Vec(std::forward<F>(generator), std::make_index_sequence<TDim::value>{})
{
}
#endif

private:
template<typename F, TVal... Is>
ALPAKA_FN_HOST_ACC constexpr explicit Vec(F&& generator, std::integer_sequence<TVal, Is...>)
: m_data{generator(std::integral_constant<TVal, Is>{})...}
template<typename F, std::size_t... Is>
ALPAKA_FN_HOST_ACC constexpr explicit Vec(F&& generator, std::index_sequence<Is...>)
: m_data{generator(std::integral_constant<std::size_t, Is>{})...}
{
}

Expand Down
24 changes: 24 additions & 0 deletions test/unit/vec/src/VecTest.cpp
Expand Up @@ -468,3 +468,27 @@ TEST_CASE("accessByNameConstexpr", "[vec]")
STATIC_REQUIRE(v4.z() == 3);
STATIC_REQUIRE(v4.w() == 4);
}

TEMPLATE_TEST_CASE("Vec generator constructor", "[vec]", std::size_t, int, unsigned, float, double)
{
// Define a generator function
auto generator = [](auto index) { return static_cast<TestType>(index.value + 1); };

// Create a Vec object using the generator function
alpaka::Vec<alpaka::DimInt<5>, TestType> vec(generator);

// Check that the values in the Vec object are as expected
for(std::size_t i = 0; i < 5; ++i)
{
// Floating point types require a precision check instead of an exact == match
if constexpr(std::is_floating_point<TestType>::value)
{
auto const precision = std::numeric_limits<TestType>::epsilon() * 5; // Arbitrary precision requirement
CHECK(std::abs(vec[i] - static_cast<TestType>(i + 1)) < precision);
}
else
{
CHECK(vec[i] == static_cast<TestType>(i + 1));
}
}
}

0 comments on commit 7dce4ec

Please sign in to comment.