This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix Arg{Min,Max} for infinite use case
gevtushenko committed Mar 15, 2023
1 parent 9003e88 commit b577f8a
Showing 4 changed files with 222 additions and 60 deletions.
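Editor's note (not part of the commit page): the "infinite use case" in the title is NVIDIA/cub#651, the issue referenced by the TODO comments in the diff below. Before this commit, ArgMin/ArgMax seeded the whole reduction with {1, Traits<T>::Max()} or {1, Traits<T>::Lowest()}, so for inputs made of infinities the seed itself could win the reduction and surface in the output. A minimal host-side sketch of the failure, assuming cub::ArgMax's tie-breaking (greater value wins, ties go to the smaller key); KeyValuePair and arg_max here are simplified stand-ins for the CUB types:

```cpp
#include <cstdio>
#include <limits>

struct KeyValuePair { int key; float value; };

// Simplified stand-in for cub::ArgMax: greater value wins, ties go to the smaller key.
KeyValuePair arg_max(KeyValuePair a, KeyValuePair b)
{
  if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key)))
  {
    return b;
  }
  return a;
}

int main()
{
  const float inf = std::numeric_limits<float>::infinity();

  // Pre-fix scheme: the initial value {1, Lowest()} participates in the reduction.
  KeyValuePair result{1, std::numeric_limits<float>::lowest()};
  KeyValuePair input[2] = {{0, -inf}, {1, -inf}};

  for (KeyValuePair kv : input)
  {
    result = arg_max(result, kv);
  }

  // Prints key=1 value=-3.402823e+38: neither field matches any input element.
  // The commit fixes this by keeping the initial value out of non-empty reductions.
  std::printf("key=%d value=%e\n", result.key, result.value);
  return 0;
}
```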
57 changes: 29 additions & 28 deletions cub/device/device_reduce.cuh
@@ -612,10 +612,11 @@ struct DeviceReduce
 
 // The output tuple type
 using OutputTupleT =
-  cub::detail::non_void_value_t<OutputIteratorT,
-                                KeyValuePair<OffsetT, InputValueT>>;
+  cub::detail::non_void_value_t<OutputIteratorT, KeyValuePair<OffsetT, InputValueT>>;
 
-using InitT = OutputTupleT;
+using AccumT = OutputTupleT;
+
+using InitT = detail::reduce::empty_problem_init_t<AccumT>;
 
 // The output value type
 using OutputValueT = typename OutputTupleT::Value;
@@ -627,23 +628,22 @@ struct DeviceReduce
 ArgIndexInputIteratorT d_indexed_in(d_in);
 
 // Initial value
-
-// replace with std::numeric_limits<T>::max() when C++11 support is
-// more prevalent
-InitT initial_value(1, Traits<InputValueT>::Max());
+// TODO Address https://github.com/NVIDIA/cub/issues/651
+InitT initial_value{AccumT(1, Traits<InputValueT>::Max())};
 
 return DispatchReduce<ArgIndexInputIteratorT,
                       OutputIteratorT,
                       OffsetT,
                       cub::ArgMin,
-                      InitT>::Dispatch(d_temp_storage,
-                                       temp_storage_bytes,
-                                       d_indexed_in,
-                                       d_out,
-                                       num_items,
-                                       cub::ArgMin(),
-                                       initial_value,
-                                       stream);
+                      InitT,
+                      AccumT>::Dispatch(d_temp_storage,
+                                        temp_storage_bytes,
+                                        d_indexed_in,
+                                        d_out,
+                                        num_items,
+                                        cub::ArgMin(),
+                                        initial_value,
+                                        stream);
 }

template <typename InputIteratorT, typename OutputIteratorT>
@@ -900,10 +900,12 @@ struct DeviceReduce
   cub::detail::non_void_value_t<OutputIteratorT,
                                 KeyValuePair<OffsetT, InputValueT>>;
 
+using AccumT = OutputTupleT;
+
 // The output value type
 using OutputValueT = typename OutputTupleT::Value;
 
-using InitT = OutputTupleT;
+using InitT = detail::reduce::empty_problem_init_t<AccumT>;
 
 // Wrapped input iterator to produce index-value <OffsetT, InputT> tuples
 using ArgIndexInputIteratorT =
@@ -912,23 +914,22 @@ struct DeviceReduce
 ArgIndexInputIteratorT d_indexed_in(d_in);
 
 // Initial value
-
-// replace with std::numeric_limits<T>::lowest() when C++11 support is
-// more prevalent
-InitT initial_value(1, Traits<InputValueT>::Lowest());
+// TODO Address https://github.com/NVIDIA/cub/issues/651
+InitT initial_value{AccumT(1, Traits<InputValueT>::Lowest())};
 
 return DispatchReduce<ArgIndexInputIteratorT,
                       OutputIteratorT,
                       OffsetT,
                       cub::ArgMax,
-                      InitT>::Dispatch(d_temp_storage,
-                                       temp_storage_bytes,
-                                       d_indexed_in,
-                                       d_out,
-                                       num_items,
-                                       cub::ArgMax(),
-                                       initial_value,
-                                       stream);
+                      InitT,
+                      AccumT>::Dispatch(d_temp_storage,
+                                        temp_storage_bytes,
+                                        d_indexed_in,
+                                        d_out,
+                                        num_items,
+                                        cub::ArgMax(),
+                                        initial_value,
+                                        stream);
 }

template <typename InputIteratorT, typename OutputIteratorT>
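Illustrative usage (not from the diff) of the DeviceReduce::ArgMax path changed above. The expected output assumes the fixed behavior, where an all negative-infinity input reports a real element instead of the old {1, Lowest()} artifact; error checking is omitted for brevity:

```cpp
#include <cub/device/device_reduce.cuh>
#include <limits>
#include <cstdio>

int main()
{
  const float inf = std::numeric_limits<float>::infinity();
  float h_in[3] = {-inf, -inf, -inf};

  float *d_in = nullptr;
  cub::KeyValuePair<int, float> *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(*d_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  // Standard CUB two-call pattern: the first call only sizes the temporary storage.
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_out, 3);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_out, 3);

  cub::KeyValuePair<int, float> h_out;
  cudaMemcpy(&h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  // With the fix: key=0 value=-inf (ties among equal values go to the smaller index).
  std::printf("key=%d value=%e\n", h_out.key, h_out.value);

  cudaFree(d_temp_storage);
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}
```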
47 changes: 28 additions & 19 deletions cub/device/device_segmented_reduce.cuh
@@ -757,34 +757,37 @@ struct DeviceSegmentedReduce
 // The output value type
 using OutputValueT = typename OutputTupleT::Value;
 
+using AccumT = OutputTupleT;
+
+using InitT = detail::reduce::empty_problem_init_t<AccumT>;
+
 // Wrapped input iterator to produce index-value <OffsetT, InputT> tuples
 using ArgIndexInputIteratorT =
   ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;
 
 ArgIndexInputIteratorT d_indexed_in(d_in);
 
 // Initial value
-OutputTupleT initial_value(1, Traits<InputValueT>::Max()); // replace with
-                                                           // std::numeric_limits<T>::max()
-                                                           // when C++11
-                                                           // support is
-                                                           // more prevalent
+// TODO Address https://github.com/NVIDIA/cub/issues/651
+InitT initial_value{AccumT(1, Traits<InputValueT>::Max())};
 
 return DispatchSegmentedReduce<ArgIndexInputIteratorT,
                                OutputIteratorT,
                                BeginOffsetIteratorT,
                                EndOffsetIteratorT,
                                OffsetT,
-                               cub::ArgMin>::Dispatch(d_temp_storage,
-                                                      temp_storage_bytes,
-                                                      d_indexed_in,
-                                                      d_out,
-                                                      num_segments,
-                                                      d_begin_offsets,
-                                                      d_end_offsets,
-                                                      cub::ArgMin(),
-                                                      initial_value,
-                                                      stream);
+                               cub::ArgMin,
+                               InitT,
+                               AccumT>::Dispatch(d_temp_storage,
+                                                 temp_storage_bytes,
+                                                 d_indexed_in,
+                                                 d_out,
+                                                 num_segments,
+                                                 d_begin_offsets,
+                                                 d_end_offsets,
+                                                 cub::ArgMin(),
+                                                 initial_value,
+                                                 stream);
 }

template <typename InputIteratorT,
@@ -1127,6 +1130,10 @@ struct DeviceSegmentedReduce
   cub::detail::non_void_value_t<OutputIteratorT,
                                 KeyValuePair<OffsetT, InputValueT>>;
 
+using AccumT = OutputTupleT;
+
+using InitT = detail::reduce::empty_problem_init_t<AccumT>;
+
 // The output value type
 using OutputValueT = typename OutputTupleT::Value;
 
@@ -1136,16 +1143,18 @@ struct DeviceSegmentedReduce
 
 ArgIndexInputIteratorT d_indexed_in(d_in);
 
-// Initial value, replace with std::numeric_limits<T>::lowest() when C++11
-// support is more prevalent
-OutputTupleT initial_value(1, Traits<InputValueT>::Lowest());
+// Initial value
+// TODO Address https://github.com/NVIDIA/cub/issues/651
+InitT initial_value{AccumT(1, Traits<InputValueT>::Lowest())};
 
 return DispatchSegmentedReduce<ArgIndexInputIteratorT,
                                OutputIteratorT,
                                BeginOffsetIteratorT,
                                EndOffsetIteratorT,
                                OffsetT,
-                               cub::ArgMax>::Dispatch(d_temp_storage,
+                               cub::ArgMax,
+                               InitT,
+                               AccumT>::Dispatch(d_temp_storage,
                                                  temp_storage_bytes,
                                                  d_indexed_in,
                                                  d_out,
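Illustrative usage (not from the diff) of the segmented path changed above. With empty_problem_init_t, the sentinel initial value is written only for empty segments; the expected outputs below assume that behavior and that Traits<float>::Lowest() equals -FLT_MAX. The non-empty segment starts at offset 0, so its reported key unambiguously names the first element:

```cpp
#include <cub/device/device_segmented_reduce.cuh>
#include <limits>
#include <cstdio>

int main()
{
  const float inf = std::numeric_limits<float>::infinity();
  float h_in[4] = {-inf, -inf, -inf, -inf};
  // Segment 0 = [0, 4), segment 1 = [4, 4) (empty).
  int h_offsets[3] = {0, 4, 4};

  float *d_in = nullptr;
  int *d_offsets = nullptr;
  cub::KeyValuePair<int, float> *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_offsets, sizeof(h_offsets));
  cudaMalloc(&d_out, 2 * sizeof(*d_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  cudaMemcpy(d_offsets, h_offsets, sizeof(h_offsets), cudaMemcpyHostToDevice);

  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  cub::DeviceSegmentedReduce::ArgMax(d_temp_storage, temp_storage_bytes,
                                     d_in, d_out, 2, d_offsets, d_offsets + 1);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  cub::DeviceSegmentedReduce::ArgMax(d_temp_storage, temp_storage_bytes,
                                     d_in, d_out, 2, d_offsets, d_offsets + 1);

  cub::KeyValuePair<int, float> h_out[2];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  // Segment 0: a real element, {0, -inf}. Segment 1 (empty): the sentinel
  // initial value, {1, -3.402823e+38}.
  std::printf("segment 0: key=%d value=%e\n", h_out[0].key, h_out[0].value);
  std::printf("segment 1: key=%d value=%e\n", h_out[1].key, h_out[1].value);
  return 0;
}
```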
62 changes: 59 additions & 3 deletions cub/device/dispatch/dispatch_reduce.cuh
@@ -51,6 +51,59 @@
 
 CUB_NAMESPACE_BEGIN
 
+namespace detail
+{
+namespace reduce
+{
+
+/**
+ * All cub::DeviceReduce::* algorithms use the same implementation. Some of them, however,
+ * should use the initial value only for empty problems. If this struct is used as the
+ * initial value with one of the `DeviceReduce` algorithms, the `init` value wrapped by
+ * this struct will only be used for empty problems; it will not be incorporated into the
+ * aggregate of non-empty problems.
+ */
+template <class T>
+struct empty_problem_init_t
+{
+  T init;
+
+  __host__ __device__ operator T() const { return init; }
+};
+
+/**
+ * @brief Applies the initial value to the block aggregate and stores the result to the
+ *        output iterator.
+ *
+ * @param d_out Iterator to the output aggregate
+ * @param reduction_op Binary reduction functor
+ * @param init Initial value
+ * @param block_aggregate Aggregate value computed by the block
+ */
+template <class OutputIteratorT, class ReductionOpT, class InitT, class AccumT>
+__host__ __device__ void finalize_and_store_aggregate(OutputIteratorT d_out,
+                                                      ReductionOpT reduction_op,
+                                                      InitT init,
+                                                      AccumT block_aggregate)
+{
+  *d_out = reduction_op(init, block_aggregate);
+}
+
+/**
+ * @brief Ignores the initial value and stores the block aggregate to the output iterator.
+ *
+ * @param d_out Iterator to the output aggregate
+ * @param block_aggregate Aggregate value computed by the block
+ */
+template <class OutputIteratorT, class ReductionOpT, class InitT, class AccumT>
+__host__ __device__ void finalize_and_store_aggregate(OutputIteratorT d_out,
+                                                      ReductionOpT,
+                                                      empty_problem_init_t<InitT>,
+                                                      AccumT block_aggregate)
+{
+  *d_out = block_aggregate;
+}
+} // namespace reduce
+} // namespace detail
+
 /******************************************************************************
  * Kernel entry points
  *****************************************************************************/
@@ -215,7 +268,7 @@ __global__ void DeviceReduceSingleTileKernel(InputIteratorT d_in,
 // Output result
 if (threadIdx.x == 0)
 {
-  *d_out = reduction_op(init, block_aggregate);
+  detail::reduce::finalize_and_store_aggregate(d_out, reduction_op, init, block_aggregate);
 }
}

@@ -334,7 +387,7 @@ __global__ void DeviceSegmentedReduceKernel(
 {
   if (threadIdx.x == 0)
   {
-    d_out[blockIdx.x] = init;
+    *(d_out + blockIdx.x) = init;
   }
   return;
 }
@@ -348,7 +401,10 @@ __global__ void DeviceSegmentedReduceKernel(
 
 if (threadIdx.x == 0)
 {
-  d_out[blockIdx.x] = reduction_op(init, block_aggregate);
+  detail::reduce::finalize_and_store_aggregate(d_out + blockIdx.x,
+                                               reduction_op,
+                                               init,
+                                               block_aggregate);
 }
}

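To make the mechanism concrete, a host-only sketch (not from the commit) of the overload resolution the kernels above now rely on: the second finalize_and_store_aggregate overload is more specialized, so it wins exactly when the initial value arrives wrapped in empty_problem_init_t:

```cpp
#include <cstdio>

template <class T>
struct empty_problem_init_t
{
  T init;
  operator T() const { return init; }
};

struct max_op
{
  int operator()(int a, int b) const { return a > b ? a : b; }
};

// Plain initial value: folded into the final result via the reduction op.
template <class OutputIteratorT, class ReductionOpT, class InitT, class AccumT>
void finalize_and_store_aggregate(OutputIteratorT d_out, ReductionOpT op,
                                  InitT init, AccumT block_aggregate)
{
  *d_out = op(init, block_aggregate);
}

// Wrapped initial value: the more specialized overload wins, init is ignored.
template <class OutputIteratorT, class ReductionOpT, class InitT, class AccumT>
void finalize_and_store_aggregate(OutputIteratorT d_out, ReductionOpT,
                                  empty_problem_init_t<InitT>, AccumT block_aggregate)
{
  *d_out = block_aggregate;
}

int main()
{
  int out = 0;
  finalize_and_store_aggregate(&out, max_op{}, 100, 42);
  std::printf("%d\n", out); // 100: a plain init participates in the result
  finalize_and_store_aggregate(&out, max_op{}, empty_problem_init_t<int>{100}, 42);
  std::printf("%d\n", out); // 42: a wrapped init is ignored for non-empty problems
  return 0;
}
```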