Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Add support for FutureValue as the initial value of reduce #460

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions cub/device/device_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,45 @@ struct DeviceReduce
// Signed integer type for global offsets
typedef int OffsetT;

return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, ReductionOpT>::Dispatch(
return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, ReductionOpT, detail::InputValue<T>>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
reduction_op,
init,
detail::InputValue<T>(init),
stream,
debug_synchronous);
}

template <
    typename            InputIteratorT,
    typename            OutputIteratorT,
    typename            ReductionOpT,
    typename            InitValueT,
    typename            InitValueIterT = InitValueT *>
CUB_RUNTIME_FUNCTION static cudaError_t Reduce(
    void                *d_temp_storage,        ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
    size_t              &temp_storage_bytes,    ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
    InputIteratorT      d_in,                   ///< [in] Pointer to the input sequence of data items
    OutputIteratorT     d_out,                  ///< [out] Pointer to the output aggregate
    int                 num_items,              ///< [in] Total number of input items (i.e., length of \p d_in)
    ReductionOpT        reduction_op,           ///< [in] Binary reduction functor
    FutureValue<InitValueT> init,               ///< [in] Initial value of the reduction, read from device memory at kernel-launch time
    cudaStream_t        stream = 0,             ///< [in] <b>[optional]</b> CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
    bool                debug_synchronous = false) ///< [in] <b>[optional]</b> Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
    // Signed integer type for global offsets
    using OffsetT = int;

    // Wrap the future initial value so the dispatch layer carries it uniformly
    // alongside the immediate-value overload.
    using WrappedInitT = detail::InputValue<InitValueT>;
    using DispatchT = DispatchReduce<InputIteratorT,
                                     OutputIteratorT,
                                     OffsetT,
                                     ReductionOpT,
                                     WrappedInitT>;

    return DispatchT::Dispatch(d_temp_storage,
                               temp_storage_bytes,
                               d_in,
                               d_out,
                               num_items,
                               reduction_op,
                               WrappedInitT(init),
                               stream,
                               debug_synchronous);
}
Expand Down Expand Up @@ -239,14 +270,14 @@ struct DeviceReduce
cub::detail::non_void_value_t<OutputIteratorT,
cub::detail::value_t<InputIteratorT>>;

return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Sum>::Dispatch(
return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Sum, detail::InputValue<OutputT>>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Sum(),
OutputT(), // zero-initialize
detail::InputValue<OutputT>(OutputT{}), // zero-initialize
stream,
debug_synchronous);
}
Expand Down Expand Up @@ -314,14 +345,15 @@ struct DeviceReduce
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Min>::Dispatch(
auto init_val = Traits<InputT>::Max();
return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Min, detail::InputValue<InputT>>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Min(),
Traits<InputT>::Max(), // replace with std::numeric_limits<T>::max() when C++11 support is more prevalent
detail::InputValue<InputT>(init_val), // replace with std::numeric_limits<T>::max() when C++11 support is more prevalent
stream,
debug_synchronous);
}
Expand Down Expand Up @@ -407,15 +439,15 @@ struct DeviceReduce

// Initial value
OutputTupleT initial_value(1, Traits<InputValueT>::Max()); // replace with std::numeric_limits<T>::max() when C++11 support is more prevalent

return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMin>::Dispatch(
return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMin, detail::InputValue<OutputTupleT>>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_indexed_in,
d_out,
num_items,
cub::ArgMin(),
initial_value,
detail::InputValue<OutputTupleT>(initial_value),
stream,
debug_synchronous);
}
Expand Down Expand Up @@ -483,14 +515,15 @@ struct DeviceReduce
// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Max>::Dispatch(
auto init_val = Traits<InputT>::Lowest();
return DispatchReduce<InputIteratorT, OutputIteratorT, OffsetT, cub::Max, detail::InputValue<InputT>>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_in,
d_out,
num_items,
cub::Max(),
Traits<InputT>::Lowest(), // replace with std::numeric_limits<T>::lowest() when C++11 support is more prevalent
detail::InputValue<InputT>(init_val), // replace with std::numeric_limits<T>::lowest() when C++11 support is more prevalent
stream,
debug_synchronous);
}
Expand Down Expand Up @@ -577,14 +610,14 @@ struct DeviceReduce
// Initial value
OutputTupleT initial_value(1, Traits<InputValueT>::Lowest()); // replace with std::numeric_limits<T>::lowest() when C++11 support is more prevalent

return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMax>::Dispatch(
return DispatchReduce<ArgIndexInputIteratorT, OutputIteratorT, OffsetT, cub::ArgMax, detail::InputValue<OutputTupleT>>::Dispatch(
d_temp_storage,
temp_storage_bytes,
d_indexed_in,
d_out,
num_items,
cub::ArgMax(),
initial_value,
detail::InputValue<OutputTupleT>(initial_value),
stream,
debug_synchronous);
}
Expand Down
23 changes: 13 additions & 10 deletions cub/device/dispatch/dispatch_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,17 @@ template <
typename OutputIteratorT, ///< Output iterator type for recording the reduced aggregate \iterator
typename OffsetT, ///< Signed integer type for global offsets
typename ReductionOpT, ///< Binary reduction functor type having member <tt>T operator()(const T &a, const T &b)</tt>
typename OutputT> ///< Data element type that is convertible to the \p value type of \p OutputIteratorT
typename InitValT> ///< Data element type that is convertible to the \p value type of \p OutputIteratorT
__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1)
__global__ void DeviceReduceSingleTileKernel(
InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items
OutputIteratorT d_out, ///< [out] Pointer to the output aggregate
OffsetT num_items, ///< [in] Total number of input data items
ReductionOpT reduction_op, ///< [in] Binary reduction functor
OutputT init) ///< [in] The initial value of the reduction
InitValT init) ///< [in] The initial value of the reduction
{
using RealInitValT = typename InitValT::value_type;
RealInitValT real_init = init;
// Thread block type for reducing input tiles
typedef AgentReduce<
typename ChainedPolicyT::ActivePolicy::SingleTilePolicy,
Expand All @@ -130,18 +132,18 @@ __global__ void DeviceReduceSingleTileKernel(
if (num_items == 0)
{
if (threadIdx.x == 0)
*d_out = init;
*d_out = real_init;
return;
}

// Consume input tiles
OutputT block_aggregate = AgentReduceT(temp_storage, d_in, reduction_op).ConsumeRange(
RealInitValT block_aggregate = AgentReduceT(temp_storage, d_in, reduction_op).ConsumeRange(
OffsetT(0),
num_items);

// Output result
if (threadIdx.x == 0)
*d_out = reduction_op(init, block_aggregate);
*d_out = reduction_op(real_init, block_aggregate);
}


Expand Down Expand Up @@ -317,6 +319,7 @@ template <
typename OutputIteratorT, ///< Output iterator type for recording the reduced aggregate \iterator
typename OffsetT, ///< Signed integer type for global offsets
typename ReductionOpT, ///< Binary reduction functor type having member <tt>T operator()(const T &a, const T &b)</tt>
typename InitValT,
typename OutputT = ///< Data type of the output iterator
cub::detail::non_void_value_t<
OutputIteratorT,
Expand All @@ -339,7 +342,7 @@ struct DispatchReduce :
OutputIteratorT d_out; ///< [out] Pointer to the output aggregate
OffsetT num_items; ///< [in] Total number of input items (i.e., length of \p d_in)
ReductionOpT reduction_op; ///< [in] Binary reduction functor
OutputT init; ///< [in] The initial value of the reduction
InitValT init; ///< [in] The initial value of the reduction
cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
int ptx_version; ///< [in] PTX version
Expand All @@ -357,7 +360,7 @@ struct DispatchReduce :
OutputIteratorT d_out,
OffsetT num_items,
ReductionOpT reduction_op,
OutputT init,
InitValT init,
cudaStream_t stream,
bool debug_synchronous,
int ptx_version)
Expand Down Expand Up @@ -570,14 +573,14 @@ struct DispatchReduce :
{
// Small, single tile size
return InvokeSingleTile<ActivePolicyT>(
DeviceReduceSingleTileKernel<MaxPolicyT, InputIteratorT, OutputIteratorT, OffsetT, ReductionOpT, OutputT>);
DeviceReduceSingleTileKernel<MaxPolicyT, InputIteratorT, OutputIteratorT, OffsetT, ReductionOpT, InitValT>);
}
else
{
// Regular size
return InvokePasses<ActivePolicyT>(
DeviceReduceKernel<typename DispatchReduce::MaxPolicy, InputIteratorT, OutputT*, OffsetT, ReductionOpT>,
DeviceReduceSingleTileKernel<MaxPolicyT, OutputT*, OutputIteratorT, OffsetT, ReductionOpT, OutputT>);
DeviceReduceSingleTileKernel<MaxPolicyT, OutputT*, OutputIteratorT, OffsetT, ReductionOpT, InitValT>);
}
}

Expand All @@ -597,7 +600,7 @@ struct DispatchReduce :
OutputIteratorT d_out, ///< [out] Pointer to the output aggregate
OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in)
ReductionOpT reduction_op, ///< [in] Binary reduction functor
OutputT init, ///< [in] The initial value of the reduction
InitValT init, ///< [in] The initial value of the reduction
cudaStream_t stream, ///< [in] <b>[optional]</b> CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous) ///< [in] <b>[optional]</b> Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
{
Expand Down