Pull request overview
This pull request integrates rocir (ROCm Intermediate Representation) coordinate operations into three GPU kernel test files for softmax, layernorm, and rmsnorm operators. The changes replace direct index calculations with rocir's coordinate transformation system, demonstrating the usage of rocir layout algebra for shared memory indexing.
- Adds rocir dialect import and coordinate transformation infrastructure
- Replaces direct shared memory indexing with rocir.crd2idx transformations
- Implements consistent get_smem_idx helper pattern across all three operators
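For readers unfamiliar with layout algebra, the mapping that a coordinate-to-index operation performs can be modeled in plain Python. The sketch below is not the rocir API: the rocir ops in this PR build IR values, while this model works on Python integers. It assumes the usual layout-algebra definition, in which the linear index is the sum of each coordinate times its stride. The body of `get_smem_idx` and the value of `N` are assumptions for illustration, since the PR's helper body is not shown here.

```python
from typing import Sequence


def crd2idx(coord: Sequence[int], shape: Sequence[int], stride: Sequence[int]) -> int:
    """Pure-Python model: linear index = sum(coord[i] * stride[i])."""
    if not (len(coord) == len(shape) == len(stride)):
        raise ValueError("coord, shape, and stride must have the same rank")
    for c, extent in zip(coord, shape):
        if not 0 <= c < extent:
            raise IndexError(f"coordinate {c} out of bounds for extent {extent}")
    return sum(c * d for c, d in zip(coord, stride))


N = 8  # hypothetical shared-memory length, chosen only for this example
smem_shape = (N,)   # models rocir.make_shape(n_idx)
smem_stride = (1,)  # models rocir.make_stride(one_idx)


def get_smem_idx(i: int) -> int:
    # Assumed shape of the helper: wrap a 1D coordinate and apply the layout.
    return crd2idx((i,), smem_shape, smem_stride)
```

With shape `(N)` and stride `(1)`, the mapping is the identity, so routing the 1D shared memory accesses through the layout leaves the addresses unchanged. The same model also covers multi-dimensional layouts, for example a row-major 4x8 layout with strides `(8, 1)`.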
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/python/gpu/test_softmax.py | Adds rocir layout definitions and coordinate transformations for 1D shared memory access in softmax kernel, replacing direct indexing with get_smem_idx helper |
| tests/python/gpu/test_rmsnorm.py | Integrates rocir coordinate operations for RMSNorm shared memory indexing, using layout-based index computation |
| tests/python/gpu/test_layernorm.py | Implements rocir-based coordinate transformations for LayerNorm shared memory accesses across mean and variance reduction stages |
    n_idx_c = arith.constant(T.index(), N)
    one_idx_c = arith.constant(T.index(), 1)
These constants (n_idx_c and one_idx_c) duplicate values that are created again at lines 108-109 and 146-147 within the if blocks. Consider defining zero_idx, one_idx, and n_idx once at the beginning of the function (like in test_rmsnorm.py and test_layernorm.py lines 68-70), and reusing them for both the rocir layout creation and the ForOp loops. This would improve code maintainability and reduce redundant constant creation.
Suggested change:

    - n_idx_c = arith.constant(T.index(), N)
    - one_idx_c = arith.constant(T.index(), 1)
    + zero_idx_c = arith.constant(T.index(), 0)
    + one_idx_c = arith.constant(T.index(), 1)
    + n_idx_c = arith.constant(T.index(), N)
    if_op = scf.IfOp(is_thread_0.value)
    with ir.InsertionPoint(if_op.then_block):
        init_sum = memref.load(smem, [zero_idx.value])
        zero_smem_idx = get_smem_idx(zero_idx.value)
Computing zero_smem_idx here inside the if block means it needs to be recomputed at line 129 when broadcasting. Consider computing zero_smem_idx = get_smem_idx(zero_idx.value) before the if block (similar to test_softmax.py line 93) to avoid redundant coordinate transformation operations.
    if_op = scf.IfOp(is_thread_0.value)
    with ir.InsertionPoint(if_op.then_block):
        init_sum = memref.load(smem, [zero_idx.value])
        zero_smem_idx = get_smem_idx(zero_idx.value)
Computing zero_smem_idx here inside the first if block means it needs to be recomputed at line 126 when broadcasting. Consider computing zero_smem_idx = get_smem_idx(zero_idx.value) before the first if block (similar to test_softmax.py line 93) to avoid redundant coordinate transformation operations.
    if_op_var = scf.IfOp(is_thread_0.value)
    with ir.InsertionPoint(if_op_var.then_block):
        init_var_sum = memref.load(smem, [zero_idx.value])
        zero_smem_idx = get_smem_idx(zero_idx.value)
Computing zero_smem_idx again inside the second if block means it needs to be recomputed at line 164 when broadcasting. If zero_smem_idx were computed before the first if block, as suggested for line 103, it could be reused here as well, avoiding multiple redundant coordinate transformations throughout the function.
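The effect of the hoisting suggested in these comments can be illustrated with a toy model. The sketch below does not use MLIR or rocir: `OpCounter` is a hypothetical stand-in for an IR builder that only records which ops would be emitted, and the three use sites loosely mirror the reduction and broadcast stages discussed above. It relies on one fact about SSA IR: a value defined before an if block is visible inside it, while a value defined inside the block is not visible afterwards.

```python
class OpCounter:
    """Hypothetical stand-in for an IR builder: records each emitted op."""

    def __init__(self):
        self.ops = []

    def crd2idx(self, coord):
        # Models one coordinate transformation being emitted.
        self.ops.append(("crd2idx", coord))
        return f"%idx{len(self.ops)}"

    def load(self, idx):
        # Models one shared memory load being emitted.
        self.ops.append(("load", idx))
        return f"%val{len(self.ops)}"


def recompute_each_time(b):
    # Reviewed pattern: the index is rebuilt at every use site.
    b.load(b.crd2idx(0))  # inside the first if block
    b.load(b.crd2idx(0))  # broadcast after the if block
    b.load(b.crd2idx(0))  # inside the second if block


def hoisted(b):
    # Suggested pattern: build the index once, before any if block.
    zero_smem_idx = b.crd2idx(0)
    for _ in range(3):
        b.load(zero_smem_idx)


def count(fn, name):
    """Run fn against a fresh builder and count emitted ops called name."""
    b = OpCounter()
    fn(b)
    return sum(1 for op, _ in b.ops if op == name)
```

In this model the hoisted version emits one transformation instead of three while issuing the same three loads. Whether later compiler passes would remove the duplicates anyway is not stated in the review, so the benefit may be mostly readability and consistency with test_softmax.py.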
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.
      # Broadcast Mean
    - mean = memref.load(smem, [zero_idx.value])
    + mean = memref.load(smem, [get_smem_idx(zero_idx.value)])
Recalculating get_smem_idx(zero_idx.value) inline could be avoided by defining zero_smem_idx once before all if blocks. This would improve efficiency and consistency with test_softmax.py.
    if_op_var = scf.IfOp(is_thread_0.value)
    with ir.InsertionPoint(if_op_var.then_block):
        init_var_sum = memref.load(smem, [zero_idx.value])
        zero_smem_idx = get_smem_idx(zero_idx.value)
Defining zero_smem_idx again inside this if block is redundant. If defined once before all if blocks (as suggested for line 103), this line could be removed.
Suggested change:

    - zero_smem_idx = get_smem_idx(zero_idx.value)
      # Broadcast Variance
    - variance = memref.load(smem, [zero_idx.value])
    + variance = memref.load(smem, [get_smem_idx(zero_idx.value)])
Recalculating get_smem_idx(zero_idx.value) inline could be avoided by defining zero_smem_idx once before all if blocks. This is the third recalculation in this function.
Suggested change:

    - variance = memref.load(smem, [get_smem_idx(zero_idx.value)])
    + variance = memref.load(smem, [zero_smem_idx])
    n_idx_c = arith.constant(T.index(), N)
    one_idx_c = arith.constant(T.index(), 1)

    smem_shape = rocir.make_shape(n_idx_c)
    smem_stride = rocir.make_stride(one_idx_c)
    smem_layout = rocir.make_layout(smem_shape, smem_stride)
The constants n_idx_c and one_idx_c could be named n_idx and one_idx (without the _c suffix) for consistency with test_rmsnorm.py and test_layernorm.py. Additionally, these variables could be reused on lines 108-109 and 146-147 instead of recreating the same constants, reducing code duplication.
Suggested change:

    - n_idx_c = arith.constant(T.index(), N)
    - one_idx_c = arith.constant(T.index(), 1)
    - smem_shape = rocir.make_shape(n_idx_c)
    - smem_stride = rocir.make_stride(one_idx_c)
    - smem_layout = rocir.make_layout(smem_shape, smem_stride)
    + n_idx = arith.constant(T.index(), N)
    + one_idx = arith.constant(T.index(), 1)
    + smem_stride = rocir.make_stride(one_idx)
    one_idx = arith.constant(T.index(), 1)
    n_idx = arith.constant(T.index(), N)
These constants are redundant. Consider reusing the one_idx_c and n_idx_c variables defined on lines 65-66 (or rename them to one_idx and n_idx and reuse them here and on lines 146-147). This would improve consistency with test_rmsnorm.py and test_layernorm.py.
    # Use zero_smem_idx again (re-calculate since it's outside the if block scope, or just use get_smem_idx)
    mean_sq = memref.load(smem, [get_smem_idx(zero_idx.value)])
The recalculation of get_smem_idx(zero_idx.value) can be avoided by defining zero_smem_idx once before all if blocks (similar to test_softmax.py line 93). The comment on line 128 acknowledges the scoping issue, but hoisting the definition above the if blocks would eliminate the recalculation.
    one_idx = arith.constant(T.index(), 1)
    n_idx = arith.constant(T.index(), N)
These constants are redundant duplicates of lines 108-109. Consider defining one_idx and n_idx once before all if blocks (similar to test_rmsnorm.py and test_layernorm.py) and reusing them throughout the function.
Motivation
Technical Details
Test Plan
Test Result
========================================================================
Test Summary
MLIR IR Tests (Lowering): 22/22 passed
Python IR Tests (Generation): 15/15 passed
Example Tests (ROCDL): 0/0 passed
GPU Execution Tests: 13/13 passed
Benchmark Tests: 2/2 passed
Verified Capabilities:
✓ Rocir IR generation and lowering
✓ Coordinate operations (crd2idx, layouts)
✓ ROCDL dialect operations (381 ops exposed)
✓ GPU kernel compilation (MLIR → HSACO)
✓ GPU kernel execution (HIP runtime)
✓ Shared memory optimizations (LDS)
✓ MFMA operations (Pure Python API)
✓ Performance benchmarking (bandwidth tests)
Submission Checklist