[Codegen] Avoid local memory spill for barrier register tensors#93
Merged
yaoyaoding merged 1 commit intomainfrom Mar 12, 2026
Merged
[Codegen] Avoid local memory spill for barrier register tensors#93yaoyaoding merged 1 commit intomainfrom
yaoyaoding merged 1 commit intomainfrom
Conversation
Barrier addresses stored in register arrays (e.g., `uint32_t regs[6]`) were being spilled to local memory by nvcc when indexed with runtime variables (e.g., pipeline stage counters in persistent kernels). This change introduces a ConstRegTensorEmitContext that tracks CTA-invariant register tensors and their element expressions. When SliceRegisterInst accesses a tracked tensor, the emitter computes the value via arithmetic (e.g., `barriers + stage * 8`) instead of array indexing, keeping everything in registers. Changes: - Add ConstRegTensorEmitContext for tracking CTA-invariant tensors - Update BarrierAllocContext to return contiguous base address - Update AllocBarrierInst emitter to register barrier tensors - Update SliceRegisterInst emitter to use arithmetic for tracked tensors - Suppress nvcc warning 550 (set but unused variable) for fallback arrays Result: eliminates 64 bytes stack frame in blackwell matmul v7/v8 kernels. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Barrier addresses stored in register arrays (e.g.,
uint32_t regs[6]) were being spilled to local memory by nvcc when indexed with runtime variables (e.g., pipeline stage counters in persistent kernels).This change introduces a ConstRegTensorEmitContext that tracks CTA-invariant register tensors and their element expressions. When SliceRegisterInst accesses a tracked tensor, the emitter computes the value via arithmetic (e.g.,
barriers + stage * 8) instead of array indexing, keeping everything in registers.Changes:
Result: eliminates 64 bytes stack frame in blackwell matmul v7/v8 kernels.