Skip to content

add multi_wave_cached rms norm#113

Merged
xjmxyt merged 4 commits into
NVIDIA:mainfrom
liqiangxl:rmsnorm-mode-parameter
Apr 29, 2026
Merged

add multi_wave_cached rms norm#113
xjmxyt merged 4 commits into
NVIDIA:mainfrom
liqiangxl:rmsnorm-mode-parameter

Conversation

@liqiangxl
Copy link
Copy Markdown
Collaborator

Add rms_norm_kernel_multi_wave_cached, a single-tile RMSNorm kernel that caches inputs in registers to avoid reloading from memory.

Replace the boolean static_persistent parameter with a mode parameter for explicit kernel selection:

  • None: heuristic selection based on tensor shape (default)
  • "static_persistent": rms_norm_kernel_static_persistent
  • "multi_wave_reload": rms_norm_kernel_multi_wave_reload
  • "multi_wave_cached": rms_norm_kernel_multi_wave_cached

Rename kernels for consistency:

  • rms_norm_kernel_gather -> rms_norm_kernel_multi_wave_reload
  • rms_norm_kernel_gather_regs_cached -> rms_norm_kernel_multi_wave_cached

Update benchmark to compare all kernel modes side-by-side per dtype.

Performance on GB200 with M = 4096

Current do_bench_cudagraph based performance is not reliable see #82, so I doubled checked with torch.profiler, see 3c8de12

dtype N Reload do_bench_cudagraph (GB/s) Cached do_bench_cudagraph (GB/s) Speedup do_bench_cudagraph Reload profiler (GB/s) Cached profiler (GB/s) Speedup profiler
float16 1024 4314.1 4428.8 1.03x 3084.4 3139.8 1.02x
float16 2048 6253.6 6252.2 1.00x 4128.8 4128.8 1.00x
float16 4096 8152.3 8241.5 1.01x 5128.1 5204.1 1.01x
float16 8192 5916.0 6317.6 1.07x 5645.5 5949.8 1.05x
float16 16384 6197.5 6404.7 1.03x 5967.0 6365.4 1.07x
bfloat16 1024 4364.3 4374.6 1.00x 3066.4 3102.7 1.01x
bfloat16 2048 6319.6 6328.8 1.00x 4049.5 4033.5 1.00x
bfloat16 4096 8027.3 7560.4 0.94x 4958.8 5103.2 1.03x
bfloat16 8192 5984.2 6400.0 1.07x 5563.6 5908.2 1.06x
bfloat16 16384 6017.0 6385.3 1.06x 5742.3 6341.7 1.10x

Description

CI Configuration

config:
  build: true
  # valid options are "ops", "benchmark", and "sanity"
  test: ["ops", "benchmark"]

Checklist

  • Code formatted and imports sorted via repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed

…ection

Add rms_norm_kernel_multi_wave_cached, a single-tile RMSNorm kernel that
caches inputs in registers to avoid reloading from memory.

Replace the boolean static_persistent parameter with a mode parameter
for explicit kernel selection:
- None: heuristic selection based on tensor shape (default)
- "static_persistent": rms_norm_kernel_static_persistent
- "multi_wave_reload": rms_norm_kernel_multi_wave_reload
- "multi_wave_cached": rms_norm_kernel_multi_wave_cached

Rename kernels for consistency:
- rms_norm_kernel_gather -> rms_norm_kernel_multi_wave_reload
- rms_norm_kernel_gather_regs_cached -> rms_norm_kernel_multi_wave_cached

Update benchmark to compare all kernel modes side-by-side per dtype.
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Apr 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def reference_rms_norm(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should keep the same interface as src/tilegym/ops/ops.py

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revised

@liqiangxl liqiangxl requested a review from xjmxyt April 27, 2026 13:12
@xjmxyt
Copy link
Copy Markdown
Collaborator

xjmxyt commented Apr 29, 2026

/ok to test 99c6f79

@xjmxyt xjmxyt merged commit d9bf003 into NVIDIA:main Apr 29, 2026
18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants