Conversation
Adds a custom AIR TopK implementation (header-only, vendored into transformer_engine/common/util/) exposed as a JAX FFI custom call via the TE JAX extension.

Key changes:
- transformer_engine/common/util/air_topk.cu: AIR TopK CUDA kernel
- transformer_engine/common/util/standalone_air_topk.cuh: vendored header
- transformer_engine/common/include/transformer_engine/air_topk.h: C API
- transformer_engine/jax/csrc/extensions/air_topk.cpp: JAX FFI binding
- transformer_engine/jax/cpp_extensions/air_topk.py: Python wrapper
- CMakeLists.txt: compile new kernel; use CCCL from CUDA toolkit
- CMakeLists.txt: fix SM100 arch handling when all arches are special-cased

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR vendors the AIR radix-selection top-K algorithm as a new CUDA kernel.

Confidence Score: 5/5. Safe to merge; all findings are P2 style/hygiene issues that do not affect correctness or runtime behaviour. All P0/P1 concerns from prior review rounds have been addressed. The four remaining comments are P2: two style issues in the vendored header (global-namespace helpers, dead code with magic constants), one unresolved UB TODO in a union that nvcc handles correctly in practice, and one missing Python-level guard for k > seq_len that the kernel already handles gracefully.

Important Files Changed: transformer_engine/common/util/standalone_topk.cuh (three minor P2 issues); transformer_engine/jax/cpp_extensions/topk.py (missing k <= seq_len guard)
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["topk(x, k_value)\nPython API"] --> B{"x.ndim == 1?"}
B -->|yes| C["unsqueeze → (1, seq_len)"]
B -->|no| D["(batch_size, seq_len)"]
C --> D
D --> E["TopKPrimitive.outer_primitive.bind()\nlengths = full(batch_size, seq_len, int32)"]
E --> F["TopkFFI (C++)\nJAX FFI handler"]
F --> G["nvte_topk (C API)\ntopk.cu"]
G --> H{"len ≤ 32768?"}
H -->|yes – one-block| I["radix_topk_one_block_kernel\n<<<batch_size, 1024>>>"]
H -->|no – multi-block| J["calc_grid_dim → grid_dim\n(cached sm_cnt)"]
J --> K["radix_kernel loop\n<<<grid_dim × batch, 256>>>"]
K --> L["last_filter_kernel"]
I --> M["out_keys (batch, k)\nout_indices (batch, k)"]
L --> M
M --> N{"squeezed?"}
N -->|yes| O["squeeze → (k,)"]
N -->|no| P["return (values, indices)"]
O --> P
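The radix kernels in the flowchart implement radix selection: each pass histograms one digit group of the keys' bit patterns, keeps whole buckets that are guaranteed to lie inside the top-k, and narrows the candidate set to the single bucket that must contain the k-th largest element. Below is a simplified, single-threaded NumPy sketch of that selection loop for unsigned 32-bit keys only; the real AIR kernels parallelize the histogram and filtering steps across thread blocks and handle floating-point keys via order-preserving bit transforms, so treat this purely as an illustration of the algorithmic idea:

```python
import numpy as np

def radix_topk(values, k, digit_bits=8):
    """Return the k largest uint32 values via radix selection.

    Scans digits from most to least significant. In each pass, buckets
    strictly above the k-th largest are kept outright; only the bucket
    containing the k-th largest survives into the next pass.
    """
    candidates = np.asarray(values, dtype=np.uint32)
    selected = []          # elements already known to be in the top-k
    remaining = k          # how many slots are still unfilled
    for shift in range(32 - digit_bits, -1, -digit_bits):
        digits = (candidates >> shift) & ((1 << digit_bits) - 1)
        hist = np.bincount(digits, minlength=1 << digit_bits)
        for d in range((1 << digit_bits) - 1, -1, -1):
            if hist[d] < remaining:
                # Whole bucket fits inside the top-k: keep it all.
                selected.extend(candidates[digits == d].tolist())
                remaining -= int(hist[d])
            else:
                # The k-th largest lies in this bucket: recurse into it.
                candidates = candidates[digits == d]
                break
        if len(candidates) == remaining:
            break          # bucket is exactly the rest of the top-k
    selected.extend(candidates.tolist()[:remaining])  # resolve ties
    return sorted(selected, reverse=True)
```

The early-exit when the surviving bucket size equals the number of unfilled slots mirrors the kernel's `last_filter` step, which emits the final candidates without further digit passes.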
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
…ing export, cache sm_cnt

- Move WARP_SIZE/WARP_BITS/FULL_WARP_MASK/VECTORIZED_READ_SIZE into namespace nv
- Remove unused keys_element_bytes variable in AirTopkFFI; collapse switch to dtype validation
- Add missing `from .air_topk import *` export in jax/cpp_extensions/__init__.py
- Cache sm_cnt per device with static vars to avoid repeated cudaGetDevice/cudaDeviceGetAttribute calls
- Add CMAKE_BUILD_WITH_INSTALL_RPATH=ON to build_ext.py

Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
36e8405 to 1e6c976
for more information, see https://pre-commit.ci
Remove the `air_` prefix from all TopK-related identifiers: file names, C API functions (nvte_air_topk -> nvte_topk), FFI handler/primitive names (te_air_topk_ffi -> te_topk_ffi), Python symbols, and the internal `air_topk` namespace in standalone_topk.cuh. No functional changes. Signed-off-by: Diego Campora <dcampora@nvidia.com> Signed-off-by: dcampora <961215+dcampora@users.noreply.github.com>
for more information, see https://pre-commit.ci
 * \param[in] k Top-K count.
 * \return Required workspace size in bytes.
 */
size_t nvte_get_topk_workspace_bytes(int batch_size, int seq_len, int k);
In the other parts of TE we follow the convention of running the main function with empty workspace to get the size, rather than a specialized function, see e.g. the layernorm functions. Could we make that consistent?
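The convention the reviewer refers to uses a single entry point for both the size query and the actual run: calling it with an empty workspace makes it report the bytes it needs instead of computing, and the caller then allocates and calls again. A hedged Python sketch of the two-phase pattern (all names hypothetical; TE's real layernorm-style API operates on NVTETensor objects in C, and the size requirement here is made up):

```python
import numpy as np

def topk_with_workspace(x, k, workspace):
    """Hypothetical entry point following the TE workspace convention:
    with an empty workspace, record the required size and return
    without doing work; otherwise run using the provided buffer."""
    needed = x.shape[0] * 2 * np.dtype(np.int32).itemsize  # made-up requirement
    if workspace["data"] is None:
        workspace["size"] = needed            # phase 1: size query only
        return None
    assert workspace["size"] >= needed
    idx = np.argsort(x)[::-1][:k]             # stand-in for the kernel
    return x[idx], idx

# Two-phase call: query, allocate, run.
x = np.array([0.3, 0.9, 0.1, 0.7])
ws = {"data": None, "size": 0}
topk_with_workspace(x, 2, ws)                 # query required bytes
ws["data"] = np.empty(ws["size"], np.uint8)   # allocate
values, indices = topk_with_workspace(x, 2, ws)
```

The benefit over a dedicated `nvte_get_*_workspace_bytes` function is that size computation and kernel launch can never drift apart, since both live in the same code path.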
// Helper: convert a float literal to type T without relying on implicit
// conversions (needed when __CUDA_NO_BFLOAT16_CONVERSIONS__ is defined).
@@ -0,0 +1,1281 @@
/*************************************************************************
A general comment - there is some duplication here with the rest of the codebase. My assumption is though that this is mostly temporary and we will want to switch to cub once it has this implementation, so I'm fine with merging this file as is.
Yes, this is temporary and we will switch to cub the moment the optimizations to top-k land there.
# If all architectures were special-cased and removed, disable CMake's automatic
# CUDA_ARCHITECTURES management — compilation flags are set via COMPILE_OPTIONS below.
if(NOT CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES OFF)
endif()
This change is not needed for this PR.
if squeezed:
    x = x[jnp.newaxis, :]  # (1, seq_len)

batch_size, seq_len = x.shape
nit self-resolve: Can we add this assert before this line?
assert x.ndim == 2, f"topk expected 2D input tensor 'x' but {x.shape=}"
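Together with the missing `k <= seq_len` guard noted in the Greptile summary, the suggested assert slots naturally into the wrapper's shape handling. A hedged NumPy sketch of the 1-D/2-D logic from the flowchart, using `np.argpartition` as a stand-in for the CUDA kernel (the real wrapper is JAX-based and binds `TopKPrimitive` instead):

```python
import numpy as np

def topk(x, k):
    # Accept 1-D input by temporarily promoting it to a single-row batch.
    squeezed = x.ndim == 1
    if squeezed:
        x = x[np.newaxis, :]                     # (1, seq_len)
    assert x.ndim == 2, f"topk expected 2D input tensor 'x' but {x.shape=}"
    batch_size, seq_len = x.shape
    assert k <= seq_len, f"k={k} exceeds seq_len={seq_len}"
    # Partial-select the k largest per row, then sort just those k.
    part = np.argpartition(x, seq_len - k, axis=-1)[:, seq_len - k:]
    part_vals = np.take_along_axis(x, part, axis=-1)
    order = np.argsort(-part_vals, axis=-1)      # descending within top-k
    indices = np.take_along_axis(part, order, axis=-1)
    values = np.take_along_axis(part_vals, order, axis=-1)
    if squeezed:
        values, indices = values[0], indices[0]  # back to (k,)
    return values, indices
```

Asserting before unpacking `x.shape` keeps the error message about dimensionality rather than a tuple-unpacking failure, which is the point of the reviewer's nit.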
/te-ci
Description
Adds a custom AIR TopK implementation (header-only, vendored into transformer_engine/common/util/) exposed as a JAX FFI custom call via the TE JAX extension.
Type of change
Changes
transformer_engine/common/util/air_topk.cu: AIR TopK CUDA kerneltransformer_engine/common/util/standalone_air_topk.cuh: vendored header (AIR TopK, header-only)transformer_engine/common/include/transformer_engine/air_topk.h: C APItransformer_engine/jax/csrc/extensions/air_topk.cpp: JAX FFI bindingtransformer_engine/jax/cpp_extensions/air_topk.py: Python wrappertransformer_engine/common/CMakeLists.txt: compile new kernel; use CCCL from CUDA toolkit; fix SM100 arch handling when all arches are special-casedChecklist: