Conversation
Add an algebraic simplification pass that recognizes the canonical
three-instruction swizzle sequence (andi/shrui/xori) emitted by
`applySwizzle` and peels period-aligned addends out of the swizzle:
swizzle(base + d) → swizzle(base) + d when d % period == 0
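Why the peel is sound can be checked with a few lines of Python. The mask/shift constants below follow the usual S&lt;B,M,S&gt; convention (period = 2^(B+M+S)) and are assumptions for illustration, not the pass's actual code:

```python
# Model of the canonical three-instruction swizzle (andi/shrui/xori).
# B, M, S and the derived MASK are illustrative assumptions for S<3,3,3>.
B, M, S = 3, 3, 3
MASK = ((1 << B) - 1) << (M + S)   # 0x1C0: bits 6..8
PERIOD = 1 << (B + M + S)          # 512

def swizzle(x):
    masked = x & MASK      # andi
    shifted = masked >> S  # shrui
    return x ^ shifted     # xori

# An addend d with d % PERIOD == 0 has zero low bits, so it changes neither
# the masked bits (6..8) nor the xor-target bits (3..5): it commutes with
# the swizzle, i.e. swizzle(base + d) == swizzle(base) + d.
for base in range(1024):
    for d in (0, 512, 16384, 3 * 16384):
        assert swizzle(base + d) == swizzle(base) + d
```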
In example04 (preshuffle_gemm) with SwizzleType S<3,3,3> (period=512),
the LDS address for each pipeline stage is `swizzle(thread_offset +
stage * lds_stride)`. Since lds_stride (BLOCK_M × BLOCK_K × sizeof(f16)
= 16384) is a multiple of 512, the pass peels the stage offset out:
Before: each stage recomputes the full swizzle (andi+shrui+xori) on
(thread_offset + stage_offset), duplicating 3 ALU ops per stage.
After: swizzle is computed once on thread_offset; stage offsets become
plain arith.addi with an immediate, which then fold/CSE across
stages.
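The before/after shapes can be sketched as below; the swizzle is modeled with assumed S&lt;3,3,3&gt; constants (mask `0x1C0`, shift 3), not the pass's actual code:

```python
def swizzle(x):
    # Assumed S<3,3,3> swizzle: x ^ ((x & 0x1C0) >> 3)
    return x ^ ((x & 0x1C0) >> 3)

LDS_STRIDE = 16384  # multiple of the 512-byte period

def addr_before(thread_offset, stage):
    # Swizzle recomputed per stage: andi+shrui+xori every time.
    return swizzle(thread_offset + stage * LDS_STRIDE)

def addr_after(thread_offset, stage):
    # Swizzle hoisted onto thread_offset; the stage offset becomes a
    # plain add with an immediate, which folds/CSEs across stages.
    return swizzle(thread_offset) + stage * LDS_STRIDE

# The rewrite is exact because LDS_STRIDE % 512 == 0.
assert all(addr_before(t, s) == addr_after(t, s)
           for t in range(512) for s in range(3))
```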
Concrete effects on the ds_read/ds_write address computation in the
GEMM hot loop:
- Instruction count: eliminates redundant andi+shrui+xori triplets per
pipeline stage (−6 SALU/VALU ops for 2-stage, −9 for 3-stage).
- Immediate offsets: the peeled constant addends enable ds_read_b128
with immediate byte offsets instead of a full v_add + swizzle, which
the ISA can encode directly in the ds instruction offset field.
- Register pressure: fewer live intermediates (%masked, %shifted) per
stage reduces VGPR demand in the address-computation section.
The pass also includes a divisibility lattice that propagates alignment
info both forward (through arith add/mul/shl) and backward (from
IntAttr.divisibility annotations on Make*Op results), enabling the
optimization even for dynamic strides with known alignment.
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
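A minimal sketch of the forward direction of such a divisibility lattice, assuming the usual gcd-based transfer functions (the names here are illustrative, not the pass's API):

```python
from math import gcd

# Lattice value: the largest known divisor of an SSA value (1 = unknown).
def div_add(a, b):
    # x + y is divisible by gcd(a, b)
    return gcd(a, b)

def div_mul(a, b):
    # x * y is divisible by a * b
    return a * b

def div_shl(a, k):
    # x << k (k a constant) is divisible by a * 2^k
    return a << k

# Example: stage * lds_stride with lds_stride = 16384 is divisible by the
# 512-byte period even when `stage` itself has unknown alignment.
PERIOD = 512
stage_offset_div = div_mul(1, 16384)
assert stage_offset_div % PERIOD == 0    # period-aligned: peel applies
assert div_add(16384, 1) % PERIOD != 0   # unknown addend blocks the peel
```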

Pull request overview
Adds an MLIR optimization pass (fly-int-swizzle-simplify) to algebraically simplify the canonical swizzle sequence emitted by applySwizzle, peeling period-aligned addends out of the swizzle input to reduce redundant ALU ops and improve downstream folding/CSE, and wires it into the ROCm backend pipeline.
Changes:
- Introduce the `fly-int-swizzle-simplify` pass with a divisibility lattice to detect period-aligned addends and rewrite swizzle-shaped `andi`/`shrui`/`xori` patterns.
- Register/build the new transform (CMake + Passes.td) and enable it in the ROCm compilation pipeline.
- Add MLIR tests covering peel/no-peel behavior for i32 and i64.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `lib/Dialect/Fly/Transforms/IntSwizzleSimplify.cpp` | Implements the new swizzle simplification pass and divisibility propagation. |
| `include/flydsl/Dialect/Fly/Transforms/Passes.td` | Registers the new pass and documents its intent/pattern. |
| `python/flydsl/compiler/backends/rocm.py` | Inserts the pass into the ROCm lowering pipeline. |
| `tests/mlir/Transforms/int_swizzle_simplify.mlir` | Adds FileCheck coverage for the optimization and non-optimization cases. |
| `lib/Dialect/Fly/CMakeLists.txt` | Builds the new transform source file into the Fly dialect library. |
| `include/flydsl/Dialect/Fly/Utils/IntUtils.h` / `lib/Dialect/Fly/Utils/IntUtils.cpp` | Adds an `isDivisibleBy` helper for `IntAttr`. |
| `include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td` | Adds a `SwizzleAttr::period()` convenience method. |
Address review comments on PR #427:
- i64 mask truncation: change mask/low/plus from unsigned/int to uint64_t so i64 swizzle masks are not silently truncated to 32 bits before the ring-of-1s check and period computation.
- Constant divisibility overflow: saturate large i64 constants to INT32_MAX instead of casting directly to int, preventing values like 0x1_0000_0000 from truncating to 0 and falsely marking everything as period-aligned.
- period <= 0 guard: bail out conservatively in peelByPeriod when period is non-positive, avoiding divide-by-zero UB in the modulo.
- dependentDialects: add gpu::GPUDialect so the pass can be run in isolation without relying on upstream dialect loading.
- Passes.td description: fix arith.xori to use valid two-operand MLIR SSA form; remove mention of subi (subtraction peeling not implemented).
- Pipeline: insert canonicalize after fly-int-swizzle-simplify to eagerly clean up dead andi/shrui ops left behind after the xori rewrite.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
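The first two fixes can be illustrated with a small model of 32-bit truncation versus saturation (a Python model of the C++ behavior; the helper names are hypothetical):

```python
# Hypothetical model of the overflow fix: narrowing a large i64 constant
# to 32 bits truncates 0x1_0000_0000 to 0, which would (falsely) look
# divisible by every period. Saturating to INT32_MAX avoids that.
INT32_MAX = 2**31 - 1
INT32_MIN = -2**31

def truncate_to_i32(v):
    # Models a direct C-style cast to int: keep the low 32 bits.
    v &= 0xFFFFFFFF
    return v - 2**32 if v >= 2**31 else v

def saturate_to_i32(v):
    # Models the fixed behavior: clamp instead of wrapping.
    return min(max(v, INT32_MIN), INT32_MAX)

c = 0x1_0000_0000
assert truncate_to_i32(c) == 0             # bad: 0 % period == 0 for any period
assert saturate_to_i32(c) == INT32_MAX     # safe: not period-aligned
assert saturate_to_i32(c) % 512 != 0
```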