
FP8 Support #56

Open

@endurehero

Description

PR

#54

Intro

Adds FP8 WGMMA support based on the async pipeline design of FlashMLA. The TransV part draws on the SmemTranspose64x64 implementation in FA3.
Currently, Q/K/V only support symmetric per-tensor quantization. Since the maximum value of P does not exceed 1, P is quantized by a direct f32tofp8_cast with no additional scale.
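For reference, below is a minimal host-side sketch of this quantization scheme in PyTorch. The helper name, tensor shapes, and clamping details are illustrative assumptions; the in-kernel f32tofp8_cast path itself is not reproduced here.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

def quantize_per_tensor_e4m3(x: torch.Tensor):
    """Symmetric per-tensor quantization: a single scale for the whole tensor."""
    scale = x.abs().amax().float().clamp(min=1e-12) / FP8_E4M3_MAX
    x_fp8 = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale  # dequantize as x_fp8.float() * scale

# Q/K/V each get a per-tensor scale (shapes here are placeholders).
q = torch.randn(128, 576, dtype=torch.bfloat16, device="cuda")
q_fp8, q_scale = quantize_per_tensor_e4m3(q)

# P = softmax(Q K^T) lies in [0, 1], which e4m3 represents directly,
# so a plain cast (f32tofp8_cast in the kernel) needs no extra scale.
p = torch.rand(128, 128, device="cuda")
p_fp8 = p.to(torch.float8_e4m3fn)
```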

Performance

CUDA driver version: 535.183.06
NVCC version: 12.8
PyTorch version: 2.6

On the H20, MLA typically exhibits high arithmetic intensity and is compute-bound. Consequently, FLOPs utilization (MFU) is used as the performance metric.
(Figure: FP8 kernel MFU on H20.)

On the H800, MLA is typically memory-bound. Consequently, memory bandwidth utilization (MBU) is used to evaluate the kernel. There is still considerable room for optimization on the H800; we look forward to working on it together.
(Figure: FP8 kernel MBU on H800.)
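As a reference for how the two metrics above can be computed, here is a rough sketch. The FLOP count, the simplified byte model (one pass over the shared latent KV cache), and the parameter names are assumptions, not the script behind the charts.

```python
def mla_mfu_mbu(batch, heads, seqlen_q, seqlen_kv, d_qk, d_v,
                kernel_time_s, bytes_per_kv_elem,
                peak_tflops, peak_gb_per_s):
    """Rough MFU/MBU estimate for an MLA decode kernel.

    FLOPs: Q @ K^T plus P @ V, counting 2 FLOPs per multiply-accumulate.
    Bytes: dominated by reading the latent KV cache once (shared across heads).
    Feed in the measured kernel time and the datasheet peak numbers of the GPU.
    """
    flops = 2.0 * batch * heads * seqlen_q * seqlen_kv * (d_qk + d_v)
    bytes_moved = batch * seqlen_kv * d_qk * bytes_per_kv_elem
    mfu = (flops / kernel_time_s) / (peak_tflops * 1e12)
    mbu = (bytes_moved / kernel_time_s) / (peak_gb_per_s * 1e9)
    return mfu, mbu
```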

Reproduction

python3 ./tests/test_flash_mla.py --dtype e4m3
