Skip to content

[CUB] Adds benchmarks for batched indexed top-k (aka batched arg top-k)#9288

Merged
elstehle merged 2 commits into
NVIDIA:mainfrom
elstehle:enh/topk-bench-batched-var-len-pairs
Jun 8, 2026
Merged

[CUB] Adds benchmarks for batched indexed top-k (aka batched arg top-k)#9288
elstehle merged 2 commits into
NVIDIA:mainfrom
elstehle:enh/topk-bench-batched-var-len-pairs

Conversation

@elstehle

@elstehle elstehle commented Jun 7, 2026

Copy link
Copy Markdown
Contributor

Closes #9287

@elstehle elstehle requested a review from a team as a code owner June 7, 2026 15:58
@elstehle elstehle requested a review from oleksandr-pavlyk June 7, 2026 15:58
@github-project-automation github-project-automation Bot moved this to Todo in CCCL Jun 7, 2026
@elstehle elstehle changed the title Adds benchmarks for batched indexed top-k (aka batched arg top-k) [CUB] Adds benchmarks for batched indexed top-k (aka batched arg top-k) Jun 7, 2026
@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Review in CCCL Jun 7, 2026
@elstehle elstehle requested review from gevtushenko and pauleonix June 7, 2026 15:59
@coderabbitai

coderabbitai Bot commented Jun 7, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: bb54e904-7ea3-48c8-845e-f6e5497f04cb

📥 Commits

Reviewing files that changed from the base of the PR and between 34276b2 and f639e25.

📒 Files selected for processing (2)
  • cub/benchmarks/bench/segmented_topk/variable/indexed.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu
🚧 Files skipped from review as they are similar to previous changes (1)
  • cub/benchmarks/bench/segmented_topk/variable/indexed.cu

Note: CodeRabbit is enabled on this repository as a convenience for maintainers
and contributors. Use your best judgment when considering its review comments and
suggestions — a suggested change may be inadequate, unnecessary, or safe to ignore.
Contributors are not expected to address every comment. Human reviews are what
ultimately matter for merging.

Summary

This PR adds nvbench benchmarks for indexed (arg) batched top-k that output both keys and per-segment indices, closing issue #9287. It introduces shared key/pattern utilities and registers both keys-only and indexed benchmarks with nvbench.

Changes

New File

cub/benchmarks/bench/segmented_topk/variable/common.cuh

  • Adds shared benchmark utilities in an unnamed namespace:
    • enum class pattern_kind with patterns: random, quantized_random, relu_quantized, tie_heavy, pivot_tie.
    • string_to_pattern(const std::string&) which throws on invalid pattern strings.
    • Templated gen_data<MaxSegmentSize, K>(int num_segments, pattern_kind, const int64_t* d_seg_sizes) returning thrust::device_vector<float>, filling keys with pattern-specific generators (including quantization, ReLU-clamped quantization, tie-heavy, and pivot-tail behavior that reads device segment sizes). Uses a computed gt_count for pivot tails.
  • Exposes benchmark axes/config: valid_patterns, key_type_list (float), max_segment_size_list (512..8192), and k_list (512, 1024, 2048).

New Benchmark Translation Unit

cub/benchmarks/bench/segmented_topk/variable/indexed.cu

  • Adds indexed (arg-top-k) benchmark template:
    • decode_style_variable_topk_indexed<KeyT, IndexT, MaxSegmentSize, K>(nvbench::state&, ...).
    • Skips when K > MaxSegmentSize.
    • Generates per-segment sizes with generate(...) and computes element counts.
    • Produces keys via gen_data and supplies per-segment input indices via a counting iterator that restarts at 0 for each segment (implemented by providing a constant iterator of a counting iterator), avoiding pre-materialized global index arrays.
    • Allocates device output buffers for keys and indices, constructs CUB dispatch parameters, and invokes cub::detail::batched_topk::dispatch_with_env.
    • Records nvbench metrics including element counts and global memory reads/writes (including reads of SegmentSizes).
  • Restricts index type axis to 32-bit (index_type_list = nvbench::type_list<cuda::std::int32_t>).
  • Registers the benchmark as "decode_style_variable_topk_indexed" with axes: KeyT, IndexT (i32), MaxSegmentSize, K, NumSegments {1,2,4,8,16,32}, and Pattern (all valid patterns).

Modified Benchmark Translation Unit

cub/benchmarks/bench/segmented_topk/variable/keys.cu

  • Refactored to use common.cuh for pattern definitions and key generation (removed duplicate helpers and local axis typedefs).
  • Declares the keys-only benchmark entry as decode_style_variable_topk_keys (templated on KeyT, MaxSegmentSize, K).
  • Tracks global memory reads for SegmentSizes.
  • Registers "decode_style_variable_topk_keys" with axes: KeyT, MaxSegmentSize, K, NumSegments {1,2,4,8,16,32}, and Pattern.

Implementation notes

  • The indexed benchmark implements per-segment local indices by providing an iterator that yields segment-local counting indices (no global index buffer), satisfying the issue requirement to avoid pre-materializing indices.
  • Index type is limited to 32-bit integers for the initial benchmark; the code structure allows future extension to other index widths.

Walkthrough

important: Adds a shared header with pattern enums and a device key generator, a new indexed segmented top-k nvbench benchmark that emits per-segment indices, and refactors the keys-only benchmark to use the shared utilities and record segment-size memory reads.

Changes

Indexed segmented top-k benchmark

Layer / File(s) Summary
Shared benchmark utilities and configuration
cub/benchmarks/bench/segmented_topk/variable/common.cuh
Introduces pattern_kind enum, string_to_pattern, templated gen_data<MaxSegmentSize, K> using thrust::tabulate with per-pattern quantization/tie/pivot logic, and benchmark config lists/types (valid_patterns, key_type_list, max_segment_size_list, k_list).
Indexed top-k benchmark implementation and registration
cub/benchmarks/bench/segmented_topk/variable/indexed.cu
Implements decode_style_variable_topk_indexed template: skips invalid K, generates device segment sizes, constructs strided key iterators from gen_data, constructs per-segment counting_iterator indices, allocates outputs, registers nvbench memory/element metrics, launches cub::detail::batched_topk::dispatch_with_env, and registers the benchmark with int32_t index type and axes NumSegments / Pattern (name: "decode_style_variable_topk_indexed").
Refactor keys.cu to use shared utilities
cub/benchmarks/bench/segmented_topk/variable/keys.cu
Replaces local pattern/helpers with common.cuh, renames entry to decode_style_variable_topk_keys, adds SegmentSizes global memory read metric, and removes inline axis typedefs to rely on shared lists.

Assessment against linked issues

Objective Addressed Explanation
Implement indexed top-k benchmark with counting_iterator per segment [#9287]
Output indices (within segment) along with keys [#9287]
Restrict indices to int32_t [#9287]

Suggested reviewers

  • pauleonix

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (4)
cub/benchmarks/bench/segmented_topk/variable/common.cuh (3)

152-153: ⚡ Quick win

suggestion: Mark valid_patterns as const

This global is never modified and should be const for clarity and potential compiler optimization.

Proposed fix
-const std::vector<std::string> valid_patterns = {
+const std::vector<std::string> const valid_patterns = {
   "random", "quantized_random", "relu_quantized", "tie_heavy", "pivot_tie"};

Or better, use constexpr with std::array:

-const std::vector<std::string> valid_patterns = {
-  "random", "quantized_random", "relu_quantized", "tie_heavy", "pivot_tie"};
+constexpr std::array<const char*, 5> valid_patterns = {
+  "random", "quantized_random", "relu_quantized", "tie_heavy", "pivot_tie"};

33-56: 💤 Low value

suggestion: Add function annotations and noexcept specification

Per guidelines, functions should be marked with _CCCL_HOST_API and use noexcept(false) if they can throw. While benchmark utilities may have relaxed requirements, annotations improve clarity.

Proposed fix
-[[nodiscard]] pattern_kind string_to_pattern(const std::string& pattern)
+[[nodiscard]] inline _CCCL_HOST_API pattern_kind string_to_pattern(const std::string& pattern) noexcept(false)

Source: Coding guidelines


58-149: 💤 Low value

suggestion: Add host API annotation to gen_data

This template is host-only (returns thrust::device_vector). Per guidelines it should be annotated, though benchmark code may have relaxed rules.

Proposed fix
 template <int MaxSegmentSize, int K>
-[[nodiscard]] thrust::device_vector<float>
+[[nodiscard]] inline _CCCL_HOST_API thrust::device_vector<float>
 gen_data(int num_segments, pattern_kind pattern, const cuda::std::int64_t* d_seg_sizes)

Source: Coding guidelines

cub/benchmarks/bench/segmented_topk/variable/indexed.cu (1)

59-63: suggestion: The constant_iterator + counting_iterator pattern at line 60 correctly produces segment-local indices [0, segment_size) for each segment. The agent dereferences d_value_segments_it[segment_id] to extract the segment-specific counting_iterator (line 251 in agent_batched_topk.cuh), then BlockLoad reads from it, generating the expected per-segment index sequence. Pattern is sound but worth documenting more explicitly in comments since the iterator-of-iterators chain is unconventional.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 2061071e-4ea9-4569-b0e7-0b56cd3365f4

📥 Commits

Reviewing files that changed from the base of the PR and between cea7dcd and 07ad36d.

📒 Files selected for processing (3)
  • cub/benchmarks/bench/segmented_topk/variable/common.cuh
  • cub/benchmarks/bench/segmented_topk/variable/indexed.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu

Comment thread cub/benchmarks/bench/segmented_topk/variable/common.cuh
Comment thread cub/benchmarks/bench/segmented_topk/variable/common.cuh
@github-actions

This comment has been minimized.

Comment thread cub/benchmarks/bench/segmented_topk/variable/indexed.cu Outdated
Comment thread cub/benchmarks/bench/segmented_topk/variable/indexed.cu
@elstehle elstehle enabled auto-merge (squash) June 8, 2026 07:58

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 617ce461-cf61-45a1-b55c-6f5c437e4d50

📥 Commits

Reviewing files that changed from the base of the PR and between 07ad36d and 34276b2.

📒 Files selected for processing (2)
  • cub/benchmarks/bench/segmented_topk/variable/indexed.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu
🚧 Files skipped from review as they are similar to previous changes (1)
  • cub/benchmarks/bench/segmented_topk/variable/indexed.cu

Comment thread cub/benchmarks/bench/segmented_topk/variable/keys.cu Outdated
@elstehle elstehle force-pushed the enh/topk-bench-batched-var-len-pairs branch from 34276b2 to f639e25 Compare June 8, 2026 08:23
@github-actions

github-actions Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

🥳 CI Workflow Results

🟩 Finished in 36m 15s: Pass: 100%/242 | Total: 1d 09h | Max: 33m 56s | Hits: 99%/147326

See results here.

@elstehle elstehle merged commit d67bf7f into NVIDIA:main Jun 8, 2026
265 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

Add benchmarks for indexed batched top-k

2 participants