Skip to content

fix stride check error in fused_qk_norm_group_quant#2637

Merged
yzhou103 merged 1 commit intomainfrom
fix_fuse_qk_norm_group_quant_with_stride
Apr 8, 2026
Merged

fix stride check error in fused_qk_norm_group_quant#2637
yzhou103 merged 1 commit intomainfrom
fix_fuse_qk_norm_group_quant_with_stride

Conversation

@yzhou103
Copy link
Copy Markdown
Contributor

@yzhou103 yzhou103 commented Apr 7, 2026

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@yzhou103 yzhou103 requested review from a team and Copilot April 7, 2026 09:01
@yzhou103 yzhou103 changed the title fix stride error check in fused_qk_norm_group_quant fix stride check error in fused_qk_norm_group_quant Apr 7, 2026
@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 7, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-355 Run Triton tests on MI355 in addition to MI325
ci:sglang SGLang integration tests
ci:atom ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm vLLM benchmark
ci:all All of the above

Add labels via the sidebar or gh pr edit 2637 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates fused_qk_rmsnorm_group_quant to correctly support non-contiguous (but last-dimension-contiguous) Q/K/residual inputs by replacing overly strict contiguity checks with stride-based validation, and it fixes scale indexing for transpose_scale by passing explicit row/column strides into the kernel.

Changes:

  • Relax input/output layout requirements from fully contiguous to stride(1)==1 with sane leading-dimension stride checks.
  • Make q_out_scale writes use explicit (row_stride, col_stride) so transpose_scale is handled via correct addressing rather than shape-only assumptions.
  • Update the op test to generate realistic strided Q/K/residual views via torch.split, exercising the new stride checks.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
csrc/kernels/fused_qk_rmsnorm_group_quant.cu Replaces .is_contiguous() gating with stride-based checks and threads scale strides into the kernel to fix transpose_scale addressing.
op_tests/test_fused_qk_rmsnorm_group_quant.py Constructs strided Q/K/residual inputs (via split-views) and adjusts scale buffer allocation for transpose_scale testing.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@yzhou103 yzhou103 merged commit c9c63d5 into main Apr 8, 2026
29 checks passed
@yzhou103 yzhou103 deleted the fix_fuse_qk_norm_group_quant_with_stride branch April 8, 2026 02:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants