Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions cub/cub/device/dispatch/dispatch_three_way_partition.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,16 @@ DeviceThreeWayPartitionInitKernel(ScanTileStateT tile_state, int num_tiles, NumS
* Dispatch
******************************************************************************/

template <typename InputIteratorT,
typename FirstOutputIteratorT,
typename SecondOutputIteratorT,
typename UnselectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectFirstPartOp,
typename SelectSecondPartOp,
typename OffsetT,
typename SelectedPolicy =
detail::device_three_way_partition_policy_hub<cub::detail::value_t<InputIteratorT>, OffsetT>>
template <
typename InputIteratorT,
typename FirstOutputIteratorT,
typename SecondOutputIteratorT,
typename UnselectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectFirstPartOp,
typename SelectSecondPartOp,
typename OffsetT,
typename SelectedPolicy = detail::three_way_partition::policy_hub<cub::detail::value_t<InputIteratorT>, OffsetT>>
struct DispatchThreeWayPartitionIf
{
/*****************************************************************************
Expand Down
249 changes: 97 additions & 152 deletions cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ CUB_NAMESPACE_BEGIN

namespace detail
{

namespace three_way_partition
{

enum class input_size
{
_1,
Expand Down Expand Up @@ -92,246 +90,193 @@ template <class InputT,
class OffsetT,
input_size InputSize = classify_input_size<InputT>(),
offset_size OffsetSize = classify_offset_size<OffsetT>()>
struct sm90_tuning
{
static constexpr int threads = 256;
static constexpr int items = Nominal4BItemsToItems<InputT>(9);

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t<OffsetT>;
using AccumPackT = typename AccumPackHelperT::pack_t;
using delay_constructor = detail::default_delay_constructor_t<AccumPackT>;
};
struct sm80_tuning;

// Partial specialization of sm90_tuning chosen when the input type is classified as
// input_size::_1 and the offset type as offset_size::_4. It only carries compile-time
// parameters: block size, items per thread, block-load algorithm and delay constructor.
template <class InputT, class OffsetT>
struct sm90_tuning<InputT, OffsetT, input_size::_1, offset_size::_4>
{
  // Delay-constructor type exposed to users of this tuning.
  using delay_constructor = detail::no_delay_constructor_t<445>;

  // Block-load algorithm selected for this configuration.
  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  // Items per thread and threads per block.
  static constexpr int items   = 12;
  static constexpr int threads = 256;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
struct sm80_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<104, 512>;
static constexpr int threads = 256;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = no_delay_constructor_t<910>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
struct sm80_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
{
static constexpr int threads = 320;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<1105>;
static constexpr int threads = 256;
static constexpr int items = 11;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
using delay_constructor = no_delay_constructor_t<1120>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
struct sm80_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
{
static constexpr int threads = 384;
static constexpr int items = 7;

static constexpr int threads = 224;
static constexpr int items = 11;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<464, 1165>;
using delay_constructor = fixed_delay_constructor_t<264, 1080>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
struct sm80_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
{
static constexpr int threads = 128;
static constexpr int items = 7;

static constexpr int threads = 128;
static constexpr int items = 10;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1040>;
using delay_constructor = fixed_delay_constructor_t<672, 1120>;
};

template <class InputT,
class OffsetT,
input_size InputSize = classify_input_size<InputT>(),
offset_size OffsetSize = classify_offset_size<OffsetT>()>
struct sm90_tuning;

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 24;

static constexpr int threads = 256;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<4, 285>;
using delay_constructor = no_delay_constructor_t<445>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
{
static constexpr int threads = 640;
static constexpr int items = 24;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<245>;
static constexpr int threads = 256;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
using delay_constructor = fixed_delay_constructor_t<104, 512>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 23;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<910>;
static constexpr int threads = 320;
static constexpr int items = 12;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
using delay_constructor = no_delay_constructor_t<1105>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 18;

static constexpr int threads = 384;
static constexpr int items = 7;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1145>;
using delay_constructor = fixed_delay_constructor_t<464, 1165>;
};

template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_8>
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 11;

static constexpr int threads = 128;
static constexpr int items = 7;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1050>;
using delay_constructor = no_delay_constructor_t<1040>;
};

template <class InputT,
class OffsetT,
input_size InputSize = classify_input_size<InputT>(),
offset_size OffsetSize = classify_offset_size<OffsetT>()>
struct sm80_tuning
template <class Input, class OffsetT>
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_8>
{
static constexpr int threads = 256;
static constexpr int items = Nominal4BItemsToItems<InputT>(9);

static constexpr int threads = 256;
static constexpr int items = 24;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t<OffsetT>;
using AccumPackT = typename AccumPackHelperT::pack_t;
using delay_constructor = detail::default_delay_constructor_t<AccumPackT>;
using delay_constructor = fixed_delay_constructor_t<4, 285>;
};

template <class Input, class OffsetT>
struct sm80_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_8>
{
static constexpr int threads = 256;
static constexpr int items = 12;

static constexpr int threads = 640;
static constexpr int items = 24;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<910>;
using delay_constructor = no_delay_constructor_t<245>;
};

template <class Input, class OffsetT>
struct sm80_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_8>
{
static constexpr int threads = 256;
static constexpr int items = 11;

static constexpr int threads = 256;
static constexpr int items = 23;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1120>;
using delay_constructor = no_delay_constructor_t<910>;
};

template <class Input, class OffsetT>
struct sm80_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_8>
{
static constexpr int threads = 224;
static constexpr int items = 11;

static constexpr int threads = 256;
static constexpr int items = 18;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<264, 1080>;
using delay_constructor = no_delay_constructor_t<1145>;
};

template <class Input, class OffsetT>
struct sm80_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_8>
{
static constexpr int threads = 128;
static constexpr int items = 10;

static constexpr int threads = 256;
static constexpr int items = 11;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<672, 1120>;
using delay_constructor = no_delay_constructor_t<1050>;
};

} // namespace three_way_partition

template <class InputT, class OffsetT>
struct device_three_way_partition_policy_hub
struct policy_hub
{
struct DefaultTuning
template <typename DelayConstructor>
struct DefaultPolicy
{
static constexpr int ITEMS_PER_THREAD = Nominal4BItemsToItems<InputT>(9);

using ThreeWayPartitionPolicy =
cub::AgentThreeWayPartitionPolicy<256,
ITEMS_PER_THREAD,
cub::BLOCK_LOAD_DIRECT,
cub::LOAD_DEFAULT,
cub::BLOCK_SCAN_WARP_SCANS>;
AgentThreeWayPartitionPolicy<256,
Nominal4BItemsToItems<InputT>(9),
BLOCK_LOAD_DIRECT,
LOAD_DEFAULT,
BLOCK_SCAN_WARP_SCANS,
DelayConstructor>;
};

/// SM35
struct Policy350
: DefaultTuning
: DefaultPolicy<fixed_delay_constructor_t<350, 450>>
, ChainedPolicy<350, Policy350, Policy350>
{};

// Compile-time selection of the agent policy for a given Tuning type: use the values from
// Tuning if a usable specialization exists, otherwise pick DefaultPolicy.
//
// Both overloads are declarations only; they are evaluated inside decltype(...) with the
// argument 0 (see Policy800 and Policy900). Because the literal 0 is an int, the (int)
// overload is the better match whenever its return type can be formed, i.e. whenever
// Tuning::threads, Tuning::items, Tuning::load_algorithm and Tuning::delay_constructor can
// all be named. If naming any of them fails during template argument substitution, this
// overload is discarded and overload resolution falls back to the (long) overload below.
template <typename Tuning>
static auto select_agent_policy(int)
-> AgentThreeWayPartitionPolicy<Tuning::threads,
Tuning::items,
Tuning::load_algorithm,
LOAD_DEFAULT,
BLOCK_SCAN_WARP_SCANS,
typename Tuning::delay_constructor>;

// Fallback overload: reached through the int -> long conversion of the argument 0, so it is
// only selected when the (int) overload above is not viable. Its return type does not depend
// on Tuning's members; it yields the ThreeWayPartitionPolicy of DefaultPolicy instantiated
// with default_delay_constructor_t over accumulator_pack_t<OffsetT>::pack_t.
template <typename Tuning>
static auto select_agent_policy(long) ->
typename DefaultPolicy<
default_delay_constructor_t<typename accumulator_pack_t<OffsetT>::pack_t>>::ThreeWayPartitionPolicy;

struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
{
using tuning = detail::three_way_partition::sm80_tuning<InputT, OffsetT>;

using ThreeWayPartitionPolicy =
AgentThreeWayPartitionPolicy<tuning::threads,
tuning::items,
tuning::load_algorithm,
cub::LOAD_DEFAULT,
cub::BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
using ThreeWayPartitionPolicy = decltype(select_agent_policy<sm80_tuning<InputT, OffsetT>>(0));
};

struct Policy860
: DefaultTuning
: DefaultPolicy<fixed_delay_constructor_t<350, 450>>
, ChainedPolicy<860, Policy860, Policy800>
{};

/// SM90
struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::three_way_partition::sm90_tuning<InputT, OffsetT>;

using ThreeWayPartitionPolicy =
AgentThreeWayPartitionPolicy<tuning::threads,
tuning::items,
tuning::load_algorithm,
cub::LOAD_DEFAULT,
cub::BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
using ThreeWayPartitionPolicy = decltype(select_agent_policy<sm90_tuning<InputT, OffsetT>>(0));
};

using MaxPolicy = Policy900;
};

} // namespace three_way_partition
} // namespace detail

CUB_NAMESPACE_END
Loading