Description
Intro
This PR adds FP8 WGMMA support based on the async pipeline design of FlashMLA. The TransV part (transposing V tiles in shared memory) draws on the SmemTranspose64x64 implementation in FA3.
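For illustration only, here is a much-simplified CUDA sketch of the idea behind TransV: bouncing a 64x64 e4m3 tile through padded shared memory to produce the transposed layout that the P·V WGMMA consumes. This is not the PR's code; the kernel name and launch shape are hypothetical, and the actual SmemTranspose64x64 is considerably more optimized and runs inside the async pipeline.

```cuda
#include <cuda_fp8.h>

// Hypothetical, simplified sketch of the TransV idea: transpose one
// 64x64 e4m3 tile through padded shared memory.
__global__ void transv_64x64_sketch(const __nv_fp8_e4m3* __restrict__ in,
                                    __nv_fp8_e4m3* __restrict__ out) {
    // +4 padding on the inner dimension mitigates shared-memory bank conflicts.
    __shared__ __nv_fp8_e4m3 tile[64][64 + 4];
    const int col = threadIdx.x;                 // launched with dim3(64, 8)
    for (int row = threadIdx.y; row < 64; row += 8)
        tile[row][col] = in[row * 64 + col];     // coalesced row-major load
    __syncthreads();
    for (int row = threadIdx.y; row < 64; row += 8)
        out[row * 64 + col] = tile[col][row];    // write the transposed tile
}
```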
Currently, Q/K/V support only symmetric per-tensor quantization. Since the maximum value of P never exceeds 1, P is quantized with a direct f32tofp8_cast.
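A hedged sketch of what the device side of this could look like; the kernel names and the host-side scale convention (inv_scale = 448 / amax, 448 being the largest finite e4m3 value) are assumptions for illustration, not the PR's actual code.

```cuda
#include <cuda_fp8.h>

// Sketch (assumed names, not the PR's code): symmetric per-tensor quantization.
// The host would compute inv_scale = 448.0f / amax(x) and keep
// scale = amax / 448 for dequantization.
__global__ void quantize_per_tensor_e4m3(const float* __restrict__ x,
                                         __nv_fp8_e4m3* __restrict__ y,
                                         float inv_scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = __nv_fp8_e4m3(x[i] * inv_scale);
}

// P needs no scale: softmax outputs lie in [0, 1], well inside the e4m3
// range, so a direct f32 -> fp8 cast suffices.
__global__ void cast_p_e4m3(const float* __restrict__ p,
                            __nv_fp8_e4m3* __restrict__ p8, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) p8[i] = __nv_fp8_e4m3(p[i]);
}
```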
Performance
CUDA driver version: 535.183.06
NVCC version: 12.8
PyTorch version: 2.6
On the H20, MLA kernels typically exhibit a high degree of arithmetic intensity and are compute-bound. Consequently, Model FLOPs Utilization (MFU) is employed as the performance metric.
On the H800, MLA kernels are typically memory-bound. Consequently, Memory Bandwidth Utilization (MBU) is adopted to evaluate kernel performance. There is still substantial room for optimization on the H800; we look forward to working on it together.
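For reference, a small host-side sketch of how the two metrics are computed. The workload and peak figures below are placeholders (assumptions), not measurements from this PR; substitute the official numbers for your part.

```cuda
#include <cstdio>

// Sketch of the metric definitions: achieved throughput over peak throughput.
static double utilization(double achieved_per_s, double peak_per_s) {
    return achieved_per_s / peak_per_s;
}

int main() {
    // Hypothetical kernel: 2.0e12 FLOPs and 1.0e10 bytes of HBM traffic in 5 ms.
    const double seconds    = 5e-3;
    const double peak_flops = 1.0e15;  // placeholder FP8 peak (FLOP/s); check the datasheet
    const double peak_bw    = 3.35e12; // H800's commonly quoted HBM3 bandwidth, ~3.35 TB/s
    printf("MFU = %.1f%%\n", 100.0 * utilization(2.0e12 / seconds, peak_flops));
    printf("MBU = %.1f%%\n", 100.0 * utilization(1.0e10 / seconds, peak_bw));
    return 0;
}
```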
Reproduction
python3 ./tests/test_flash_mla.py --dtype e4m3