This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fixes superfluous kernel template instantiations in the prefix scan #312

Merged: 1 commit, Jun 8, 2021
10 changes: 6 additions & 4 deletions cub/device/dispatch/dispatch_scan.cuh
@@ -95,14 +95,14 @@ __global__ void DeviceCompactInitKernel(
* Scan kernel entry point (multi-block)
*/
template <
- typename ScanPolicyT, ///< Parameterized ScanPolicyT tuning policy type
+ typename ChainedPolicyT, ///< Chained tuning policy
typename InputIteratorT, ///< Random-access input iterator type for reading scan inputs \iterator
typename OutputIteratorT, ///< Random-access output iterator type for writing scan outputs \iterator
typename ScanTileStateT, ///< Tile status interface type
typename ScanOpT, ///< Binary scan functor type having member <tt>T operator()(const T &a, const T &b)</tt>
typename InitValueT, ///< Initial value to seed the exclusive scan (cub::NullType for inclusive scans)
typename OffsetT> ///< Signed integer type for global offsets
- __launch_bounds__ (int(ScanPolicyT::BLOCK_THREADS))
+ __launch_bounds__ (int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
__global__ void DeviceScanKernel(
InputIteratorT d_in, ///< Input data
OutputIteratorT d_out, ///< Output data
@@ -112,6 +112,8 @@ __global__ void DeviceScanKernel(
InitValueT init_value, ///< Initial value to seed the exclusive scan
OffsetT num_items) ///< Total number of scan items for the entire problem
{
+ typedef typename ChainedPolicyT::ActivePolicy::ScanPolicyT ScanPolicyT;

// Thread block type for scanning input tiles
typedef AgentScan<
ScanPolicyT,
@@ -387,12 +389,12 @@ struct DispatchScan:
CUB_RUNTIME_FUNCTION __host__ __forceinline__
cudaError_t Invoke()

Review comment from the author (collaborator):

The host-side Invoke function template is instantiated for every tuning policy in the chain. Consequently, the DeviceScanKernel<Policy, ...> kernel template named below is also instantiated once per policy, which produces superfluous kernel instantiations.

Instead, the idea is to always instantiate the DeviceScanKernel kernel template with the policy at the top of the tuning policy chain. The device code then identifies and applies the tuning policy that matches the GPU architecture of the current compilation pass.

{
- typedef typename ActivePolicyT::ScanPolicyT Policy;
+ typedef typename DispatchScan::MaxPolicy MaxPolicyT;
typedef typename cub::ScanTileState<OutputT> ScanTileStateT;
// Ensure kernels are instantiated.
return Invoke<ActivePolicyT>(
DeviceScanInitKernel<ScanTileStateT>,
- DeviceScanKernel<Policy, InputIteratorT, OutputIteratorT, ScanTileStateT, ScanOpT, InitValueT, OffsetT>
+ DeviceScanKernel<MaxPolicyT, InputIteratorT, OutputIteratorT, ScanTileStateT, ScanOpT, InitValueT, OffsetT>
);
}
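
The selection mechanism described in the review comment can be sketched in host-only C++, without CUDA. This is a simplified illustration, not the library's code: SKETCH_TARGET_ARCH is a hypothetical macro standing in for the architecture of the current device compilation pass, the Policy350/Policy600/Policy800 chain and its thread counts are invented, and scan_kernel_block_threads is a hypothetical stand-in for the kernel template. The point it tries to show is that a template parameterized on the top of the chain resolves to a single architecture-specific policy inside its body, so one instantiation is enough.

```cpp
#include <type_traits>

// Hypothetical stand-in for the architecture of the current device
// compilation pass. Assumed value for this sketch: 600.
#ifndef SKETCH_TARGET_ARCH
#define SKETCH_TARGET_ARCH 600
#endif

// Illustrative tuning parameters (the thread counts are made up).
template <int Threads>
struct ScanPolicy
{
    static constexpr int BLOCK_THREADS = Threads;
};

// One link in the policy chain: if the target architecture is older than
// Arch, defer to the previous link; otherwise this link is the active policy.
template <int Arch, typename PolicyT, typename PrevPolicyT>
struct ChainedPolicy
{
    using ActivePolicy = typename std::conditional<
        (SKETCH_TARGET_ARCH < Arch),
        typename PrevPolicyT::ActivePolicy,
        PolicyT>::type;
};

// End of the chain: the oldest policy names itself as its predecessor.
template <int Arch, typename PolicyT>
struct ChainedPolicy<Arch, PolicyT, PolicyT>
{
    using ActivePolicy = PolicyT;
};

// Hypothetical three-link chain, oldest architecture first.
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
{
    using ScanPolicyT = ScanPolicy<128>;
};
struct Policy600 : ChainedPolicy<600, Policy600, Policy350>
{
    using ScanPolicyT = ScanPolicy<256>;
};
struct Policy800 : ChainedPolicy<800, Policy800, Policy600>
{
    using ScanPolicyT = ScanPolicy<512>;
};
using MaxPolicy = Policy800;

// Stand-in for the kernel template: it is parameterized on the top of the
// chain and resolves the architecture-specific policy inside its body.
template <typename ChainedPolicyT>
constexpr int scan_kernel_block_threads()
{
    using ScanPolicyT = typename ChainedPolicyT::ActivePolicy::ScanPolicyT;
    return ScanPolicyT::BLOCK_THREADS;
}
```

With the assumed target of 600, scan_kernel_block_threads<MaxPolicy>() resolves to Policy600's settings. Only the MaxPolicy instantiation of the stand-in is ever requested, whichever link of the chain the host-side dispatch is currently visiting.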
