[FEAT][kernels] Fix CUDA extension build on non-Hopper SM>=90 (Blackwell SM120) #91
…kwell SM120)

`pip install -e .` force-built the Hopper-only TMA fused-logp kernel (csrc/cuda/fused_logp_sm90.cu) on every device with compute capability >= 9, including Blackwell (SM120) and SM100. Its hardcoded `-gencode=arch=compute_90a,code=sm_90a` also suppressed PyTorch's automatic native-arch gencode, so the entire extension, including the generic and attention kernels, was compiled for sm_90a only and could not load on the actual device.

The TMA kernel is additionally non-functional on all architectures (TMA box width exceeds the 256-element cuTensorMapEncodeTiled limit; its warp-specialized layout deadlocks cub::BlockReduce across a partial block), so it should not be built by default.

setup.py:
- Build the experimental TMA kernel only via KERNEL_ALIGN_FORCE_SM90=1 (off by default), so the default build compiles the generic fused kernel for the detected native architecture and runs on SM120 + CUDA 13.
- When opted in, emit the detected device's architecture-specific gencode (SM90->90a, SM120->120a) instead of a hardcoded compute_90a.

registry.py:
- Prioritize the TMA logp op only when its symbol is compiled into _C, and drop the misleading "Failed to instantiate CUDA_FUSED_LOGP_SM90" ERROR that fired on every non-Hopper SM>=9 run before falling back.

Verified on RTX PRO 6000 (SM120) + CUDA 13: build succeeds, the example selects FusedLogpGenericOp, --require-fused-logp passes (kernel_max_abs_error 4.77e-07).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
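The setup.py behaviour described in the commit message can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: the helper names (`tma_kernel_requested`, `arch_specific_gencode`, `extra_nvcc_flags`) are hypothetical, and the capability tuple is passed in explicitly so the sketch runs without a GPU. In a real build it would presumably come from `torch.cuda.get_device_capability()`.

```python
import os


def tma_kernel_requested(env=None):
    """Return True only when the opt-in variable is set to "1" (off by default)."""
    env = os.environ if env is None else env
    return env.get("KERNEL_ALIGN_FORCE_SM90", "0") == "1"


def arch_specific_gencode(capability):
    """Map a (major, minor) compute capability to an arch-specific gencode flag.

    (9, 0) -> compute_90a / sm_90a, (12, 0) -> compute_120a / sm_120a.
    """
    major, minor = capability
    arch = f"{major}{minor}a"
    return f"-gencode=arch=compute_{arch},code=sm_{arch}"


def extra_nvcc_flags(capability, env=None):
    """Hypothetical helper: extra nvcc flags for the extension build.

    Assumption: when the TMA kernel is not requested, no gencode flag is added,
    which leaves the native-arch default in place.
    """
    if not tma_kernel_requested(env):
        return []
    return [arch_specific_gencode(capability)]


print(extra_nvcc_flags((12, 0), env={}))
print(extra_nvcc_flags((12, 0), env={"KERNEL_ALIGN_FORCE_SM90": "1"}))
```

Returning an empty list in the default case reflects the point made in the commit message: a hardcoded gencode flag was what prevented the rest of the extension from being compiled for the device actually present.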
No actionable comments were generated in the recent review.
Walkthrough
This PR refines SM90 CUDA kernel availability detection and compilation. The build system now gates SM90 extension compilation behind an environment variable and derives gencode targets from detected device capability. Runtime kernel selection validates extension presence and SM major version before prioritizing the fused TMA LogP backend.
Changes
SM90 CUDA Kernel Build and Runtime Selection
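The runtime selection rule summarized in the walkthrough (check that the extension symbol exists and check the SM major version before preferring the TMA backend) might look something like the sketch below. The symbol name `fused_logp_sm90`, the function names, and the `allowed_majors=(9,)` default are assumptions made for illustration; the PR text does not state them. The returned strings reuse the labels `CUDA_FUSED_LOGP_SM90` and `FusedLogpGenericOp` from the PR purely as names.

```python
from types import SimpleNamespace

# Hypothetical name of the TMA op symbol inside the compiled extension (_C).
TMA_SYMBOL = "fused_logp_sm90"


def should_prioritize_tma(ext_module, sm_major, allowed_majors=(9,)):
    """Prefer the TMA op only if it was compiled in and the device major matches.

    allowed_majors is an assumption; the PR only says the SM major version
    is validated, not which values are accepted.
    """
    return sm_major in allowed_majors and hasattr(ext_module, TMA_SYMBOL)


def select_logp_backend(ext_module, sm_major):
    """Return a label for the chosen backend, falling back silently to generic."""
    if should_prioritize_tma(ext_module, sm_major):
        return "CUDA_FUSED_LOGP_SM90"
    return "FusedLogpGenericOp"


# Stand-ins for the compiled extension, with and without the TMA symbol.
ext_with_tma = SimpleNamespace(fused_logp_sm90=object())
ext_without_tma = SimpleNamespace()

print(select_logp_backend(ext_without_tma, 12))
print(select_logp_backend(ext_with_tma, 9))
```

Checking for the symbol up front, rather than attempting instantiation and catching the failure, is what would avoid logging an error on devices where the fallback is the expected path.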
Estimated code review effort: 3 (Moderate) | ~20 minutes
Pre-merge checks: 4 passed, 1 failed (1 warning)
LGTM. The default build path looks good and the registry fallback behavior is cleaner. I only noticed a minor non-blocking edge case around |
Flink-ddd left a comment
LGTM, I think we can merge this PR. Thanks.
#87
Summary
`pip install -e .` force-built the Hopper-only TMA fused-logp kernel (csrc/cuda/fused_logp_sm90.cu) on every device with compute capability >= 9, including Blackwell (SM120) and SM100. Its hardcoded `-gencode=arch=compute_90a,code=sm_90a` also suppressed PyTorch's automatic native-arch gencode, so the entire extension, including the generic and attention kernels, was compiled for sm_90a only and could not load on the actual device.
Change
setup.py: When opted in, emit the detected device's architecture-specific gencode (SM90->90a, SM120->120a) instead of a hardcoded compute_90a.
registry.py: Prioritize the TMA logp op only when its symbol is compiled into _C, and drop the misleading "Failed to instantiate CUDA_FUSED_LOGP_SM90" ERROR that fired on every non-Hopper SM>=9 run before falling back.
Tests on Blackwell and Hopper
Summary by CodeRabbit
KERNEL_ALIGN_FORCE_SM90="1").