Pre v0.1 gemm fix#153

Merged
coderfeli merged 15 commits into pre_v0.1 from pre_v0.1_gemm_fix
Feb 27, 2026

@coderfeli coderfeli commented Feb 27, 2026

Motivation

Port the preshuffle GEMM kernel infrastructure from the internal development branch to the pre_v0.1 public branch, adding FP4/INT4/BF16 dtype support, async DMA copy, CUDA graph capture, and a streamlined build/test workflow.

Technical Details

Kernel Enhancements (kernels/)

  • FP4 (MXFP4) GEMM support: Added compile_preshuffle_gemm_w4 for gfx950 FP4 preshuffle GEMM with block-scaled MXFP4 quantization, including per-block scale loading and mfma_scale_f32_16x16x128_f8f6f4 with cbsz/blgp/opsel parameters.
  • INT4 and BF16 dtype support: Extended compile_preshuffle_gemm_a8 to handle int4 (with nibble unpacking) and bf16 element types.
  • CUDA graph capture support: Added mgpuSetCaptureStream TLS mechanism in FlirRocmRuntimeWrappers.cpp to redirect kernel launches into a capture stream; integrated with test_common.py graph capture flow.
  • Arith-optimized layout utilities (kernels/layout_utils.py): Pure-arith crd2idx/idx2crd/get that parse static layout type strings and emit plain arith ops, avoiding fly dialect round-trips in the hot path.
  • buffer_load vec_width=1 fix: Handle scalar buffer_load returning a single value (not a vector) by wrapping with vector.from_elements before bitcast.
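
The pure-arith layout idea above can be illustrated with a minimal Python sketch. It is not the code in kernels/layout_utils.py: the layout string format `(shape):(stride)`, the helper name `parse_layout`, and the restriction to flat (non-nested) layouts with nonzero strides are all assumptions made for illustration. It only shows that, once shape and stride are known statically, coordinate/index conversion reduces to multiplies, adds, divides, and remainders.

```python
import re


def parse_layout(text):
    """Parse a flat layout string such as "(4,8):(8,1)" into (shape, stride).

    The string format is hypothetical; it stands in for a static layout type string.
    """
    match = re.fullmatch(r"\(([\d,\s]+)\):\(([\d,\s]+)\)", text.strip())
    if match is None:
        raise ValueError(f"unsupported layout string: {text!r}")
    shape = tuple(int(x) for x in match.group(1).split(","))
    stride = tuple(int(x) for x in match.group(2).split(","))
    if len(shape) != len(stride):
        raise ValueError("shape and stride must have the same rank")
    return shape, stride


def crd2idx(coord, shape, stride):
    """Map a coordinate to a linear index: sum of coord[i] * stride[i]."""
    for c, s in zip(coord, shape):
        if not 0 <= c < s:
            raise IndexError(f"coordinate {coord} out of bounds for shape {shape}")
    return sum(c * d for c, d in zip(coord, stride))


def idx2crd(idx, shape, stride):
    """Invert crd2idx for a compact layout: (idx // stride[i]) % shape[i] per mode."""
    return tuple((idx // d) % s for s, d in zip(shape, stride))


shape, stride = parse_layout("(4,8):(8,1)")  # 4x8 tile, row-major
linear = crd2idx((2, 3), shape, stride)
coord = idx2crd(linear, shape, stride)
```

Because every operation is integer arithmetic on statically known constants, a code generator can emit the equivalent arith ops directly, which is the motivation the bullet gives for avoiding dialect round-trips.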

DSL / Python Layer (python/flydsl/)

  • primitive.py cleanup: Renamed arith import to _arith to prevent import * namespace collision with flydsl.expr.arith wrapper module. Changed range_constexpr to return range(*args) for direct kernel usage.
  • arith.py function-level API: Added constant, index, index_cast, select, constant_vector, sitofp, trunc_f, andi, xori, shli, unwrap, _to_raw as thin wrappers around MLIR arith ops with ArithValue support. Added _safe_register for idempotent value caster registration. Fixed index-type division (divui for index types that lack .width).
  • rocdl.py MFMA scale op: Restructured mfma_scale_f32_16x16x128_f8f6f4 to explicitly unpack cbsz, blgp, opselA, scaleA, opselB, scaleB from the operand list.
  • buffer_ops.py: Replaced unrealized_conversion_cast with fly.extract_aligned_pointer_as_index for memref-to-pointer extraction.
  • SmemAllocator: Added global_sym_name parameter for multiple independent shared memory allocations.
  • JIT runtime: Temporary workaround (a hack) to fix CUDA graph mode; to be replaced with a cleaner approach later. Prefers libfly_jit_runtime.so (with graph capture) over the upstream libmlir_rocm_runtime.so. Temporarily removed the redundant convert-vector-to-llvm pass from the pipeline.
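
The reason for the `arith` to `_arith` rename in primitive.py can be shown with a small standalone sketch. The module names `demo_primitive_before` and `demo_primitive_after` and the `make_layout` attribute are hypothetical stand-ins, not flydsl modules. The sketch relies only on standard Python behavior: without `__all__`, `from module import *` binds every name that does not start with an underscore.

```python
import sys
import types


def star_import_names(module_name):
    """Return the names that `from <module_name> import *` would bind."""
    namespace = {}
    exec(f"from {module_name} import *", namespace)
    return {name for name in namespace if name != "__builtins__"}


# Hypothetical module that keeps the raw dialect import under a public name.
before = types.ModuleType("demo_primitive_before")
before.arith = object()            # public: leaks through `import *`
before.make_layout = lambda: None  # a legitimate public helper
sys.modules[before.__name__] = before

# Same module after renaming the import to a private name.
after = types.ModuleType("demo_primitive_after")
after._arith = object()            # private: skipped by `import *`
after.make_layout = lambda: None
sys.modules[after.__name__] = after

names_before = star_import_names("demo_primitive_before")
names_after = star_import_names("demo_primitive_after")
```

In the first case a star import would shadow any `arith` name the importing scope already had, such as a wrapper module; in the second only `make_layout` is exported.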

IR / Dialect (include/, lib/)

  • Fly_ExtractAlignedPointerAsIndexOp: New op to extract raw pointer from fly.memref as an index value, with ROCDL lowering via AddrSpaceCastOp.

Test Result

16 passed, 2 skipped, 86 deselected in 13.52s

Benchmark results (gfx942):
  fp8  16x40960x5120:  106.41 TFLOPS, 3.350 TB/s
  fp8  16x77824x5120:  105.53 TFLOPS, 3.322 TB/s
  fp8  5120x5120x8320: 419.46 TFLOPS, 0.132 TB/s
  fp8  9728x8192x8320: 421.83 TFLOPS, 0.098 TB/s
  int8 9728x8192x8320: 420.19 TFLOPS, 0.098 TB/s
  int4 9728x8192x8320: 248.44 TFLOPS, 0.051 TB/s
  bf16 5120x5120x8320: 211.10 TFLOPS, 0.108 TB/s
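
As a rough cross-check of how the TFLOPS and TB/s columns relate, the following sketch derives the implied bandwidth from a reported TFLOPS figure. It rests on assumptions not stated in the PR: shapes are read as MxNxK, a GEMM costs 2*M*N*K floating-point operations, fp8 inputs occupy 1 byte per element, the output occupies 2 bytes per element, and each operand is moved exactly once. The function name and parameters are illustrative.

```python
def implied_bandwidth_tbps(m, n, k, tflops, in_bytes=1.0, out_bytes=2.0):
    """Estimate TB/s from a reported TFLOPS figure under the stated assumptions."""
    flops = 2 * m * n * k                      # multiply-add counted as 2 ops
    seconds = flops / (tflops * 1e12)          # kernel time implied by TFLOPS
    bytes_moved = (m * k + n * k) * in_bytes + m * n * out_bytes  # A + B + C
    return bytes_moved / seconds / 1e12


# fp8 rows from the table above, read as MxNxK.
small_m = implied_bandwidth_tbps(16, 40960, 5120, 106.41)
square = implied_bandwidth_tbps(5120, 5120, 8320, 419.46)
```

Under these assumptions the estimates land close to the reported 3.350 TB/s and 0.132 TB/s for those two rows, which suggests the small-M shapes are memory-bound while the large shapes are compute-bound.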

@coderfeli coderfeli merged commit 78ddb69 into pre_v0.1 Feb 27, 2026
coderfeli added a commit that referenced this pull request Feb 28, 2026
* fix run error

* port all gemm from main

* fix cudagraph hack

* add int4 version

* change flymemref convert

* test ok

* add build script

* fix graph2

* add files

* fix flops

* fix path

* fix local test

* fix

* clean

* update readme
coderfeli added a commit that referenced this pull request Mar 2, 2026
coderfeli added a commit that referenced this pull request Mar 2, 2026
@coderfeli coderfeli deleted the pre_v0.1_gemm_fix branch March 3, 2026 06:15
coderfeli added a commit that referenced this pull request Mar 3, 2026
jli-melchior pushed a commit that referenced this pull request Mar 18, 2026
jli-melchior pushed a commit that referenced this pull request Mar 18, 2026
jli-melchior pushed a commit that referenced this pull request Mar 19, 2026