Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #712 from senior-zero/enh-main/github/scan_sm90_tu…
Browse files Browse the repository at this point in the history
…ning

Introduce SM90 tuning policy into scan
  • Loading branch information
gevtushenko committed Jun 8, 2023
2 parents b87c356 + 8a6b822 commit aaf5498
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 4 deletions.
2 changes: 1 addition & 1 deletion benchmarks/scripts/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def iterate_case_dfs(args, callable):

for gpu in ctk_cub_df['gpu'].unique():
target_df = ctk_cub_df[ctk_cub_df['gpu'] == gpu]
target_df.drop(columns=['ctk', 'cub', 'gpu'], inplace=True)
target_df = target_df.drop(columns=['ctk', 'cub', 'gpu'])
target_df = compute_speedup(target_df)

for ct_point in ct_space(target_df):
Expand Down
65 changes: 62 additions & 3 deletions cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,51 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
* Policy
******************************************************************************/

template <typename AccumT> ///< Data type
namespace detail
{
namespace scan
{

template <int Threads, int Items, int L2B, int L2W>
struct tuning
{
static constexpr int threads = Threads;
static constexpr int items = Items;

using delay_constructor = detail::fixed_delay_constructor_t<L2B, L2W>;
};

template <class AccumT,
bool PrimitiveOp,
bool PrimitiveAccumulator = Traits<AccumT>::PRIMITIVE,
std::size_t AccumSize = sizeof(AccumT)>
struct sm90_tuning
{
static constexpr int threads = 128;
static constexpr int items = 15;

using delay_constructor = detail::default_delay_constructor_t<AccumT>;
};

// clang-format off
template <class T> struct sm90_tuning<T, true, true, 1> : tuning<192, 22, 168, 1140> {};
template <class T> struct sm90_tuning<T, true, true, 2> : tuning<512, 12, 376, 1125> {};
template <class T> struct sm90_tuning<T, true, true, 4> : tuning<128, 24, 648, 1245> {};
template <class T> struct sm90_tuning<T, true, true, 8> : tuning<224, 24, 632, 1290> {};

template <> struct sm90_tuning<float, true, true, sizeof(float)> : tuning<128, 24, 688, 1140> {};
template <> struct sm90_tuning<double, true, true, sizeof(double)> : tuning<224, 24, 576, 1215> {};

#if CUB_IS_INT128_ENABLED
template <> struct sm90_tuning< __int128_t, true, false, sizeof(__int128_t)> : tuning<576, 21, 860, 630> {};
template <> struct sm90_tuning<__uint128_t, true, false, sizeof(__uint128_t)> : tuning<576, 21, 860, 630> {};
#endif
// clang-format on

} // namespace scan
} // namespace detail

template <typename AccumT, typename ScanOpT = Sum>
struct DeviceScanPolicy
{
// For large values, use timesliced loads/stores to fit shared memory.
Expand Down Expand Up @@ -271,7 +315,22 @@ struct DeviceScanPolicy
detail::default_delay_constructor_t<AccumT>>;
};

using MaxPolicy = Policy600;
/// SM900
struct Policy900 : ChainedPolicy<900, Policy900, Policy600>
{
using tuning = detail::scan::sm90_tuning<AccumT, detail::basic_binary_op_t<ScanOpT>::value>;

using ScanPolicyT = policy_t<tuning::threads,
tuning::items,
AccumT,
ScanTransposedLoad,
LOAD_DEFAULT,
ScanTransposedStore,
BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
};

using MaxPolicy = Policy900;
};

/******************************************************************************
Expand Down Expand Up @@ -312,7 +371,7 @@ template <typename InputIteratorT,
cub::detail::value_t<InputIteratorT>,
typename InitValueT::value_type>,
cub::detail::value_t<InputIteratorT>>,
typename SelectedPolicy = DeviceScanPolicy<AccumT>>
typename SelectedPolicy = DeviceScanPolicy<AccumT, ScanOpT>>
struct DispatchScan : SelectedPolicy
{
//---------------------------------------------------------------------
Expand Down
27 changes: 27 additions & 0 deletions cub/thread/thread_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,33 @@ struct ArgMin
}
};

namespace detail
{
template <class OpT>
struct basic_binary_op_t
{
static constexpr bool value = false;
};

template <>
struct basic_binary_op_t<Sum>
{
static constexpr bool value = true;
};

template <>
struct basic_binary_op_t<Min>
{
static constexpr bool value = true;
};

template <>
struct basic_binary_op_t<Max>
{
static constexpr bool value = true;
};
} // namespace detail

/// @brief Default cast functor
template <typename B>
struct CastOp
Expand Down

0 comments on commit aaf5498

Please sign in to comment.