In [None]:
import torch

from triton.tools.tensor_descriptor import TensorDescriptor
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor

from kernel_round_fp4 import round_to_fp4, quantize_to_fp4

In [3]:
import torch

x = torch.randn(128 ,128)


print(x.is_contiguous())
print(x.T.is_contiguous())



True
False


In [None]:
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
FP4_MAX = 6.0
NVFP4_GROUP_SIZE = 16


def group_tensor(x, gs, col_wise=False):
    assert len(x.shape)==2
    if col_wise:
        return x.T.reshape(x.shape[1], x.shape[0] // gs, gs)
    return x.reshape(x.shape[0], x.shape[1] // gs, gs)


torch.compile()
def get_nvfp_scales(x, col_wise):
    x_shape = x.shape

    s_enc = (FP8_MAX * FP4_MAX) / x.abs().max().float()
    s_dec = 1 / s_enc

    x = group_tensor(x, NVFP4_GROUP_SIZE, col_wise=col_wise)

    s_dec_b = x.abs().amax(dim=-1) / FP4_MAX
    s_dec_b_fp8 = (s_dec_b * s_enc).to(torch.float8_e4m3fn)

    s_enc_b = 1 / (s_dec_b_fp8.float() * s_dec)

    x = (x * s_enc_b.unsqueeze(-1)).reshape(x_shape)
    return x_scaled, 


def quantize_NVFP4(x, col_wise=False):
    x_shape = x.shape

    s_enc = (FP8_MAX * FP4_MAX) / x.abs().max().float()
    s_dec = 1 / s_enc

    x = group_tensor(x, NVFP4_GROUP_SIZE, col_wise=col_wise)

    s_dec_b = x.abs().amax(dim=-1) / FP4_MAX
    s_dec_b_fp8 = (s_dec_b * s_enc).to(torch.float8_e4m3fn)

    s_enc_b = 1 / (s_dec_b_fp8.float() * s_dec)

    x = (x * s_enc_b.unsqueeze(-1)).reshape(x_shape)

    x_q = quantize_to_fp4(x)

    return x_q, s_dec_b_fp8, s_dec


def reconstruct_weight(x_packed, group_scales, tensor_scale, col_wise=False):
    x_scaled = MXFP4Tensor(size=(x_packed.shape[0], x_packed.shape[1] * 2), device=x_packed.device)
    x_scaled.data = x_scaled.unpack_packed_tensor(x_packed, dim=1, original_shape=x_scaled.size)
    x_scaled_grouped = group_tensor(x_scaled.to(torch.float32), NVFP4_GROUP_SIZE, col_wise=False)
    x_unpacked = (x_scaled_grouped * group_scales.float().unsqueeze(-1) * tensor_scale).reshape(x_packed.shape[0], x_packed.shape[1] * 2)
    return x_unpacked

In [19]:
torch.manual_seed(0)

x = torch.randn(4096, 4096, dtype=torch.float32, device="cuda")

In [20]:
x_scaled_packed, s_dec_b_fp8, s_dec = quantize_NVFP4(x, col_wise=False)

In [18]:
x_reco = reconstruct_weight(x_scaled_packed, s_dec_b_fp8, s_dec)

torch.Size([4096, 4096])


In [None]:
def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference=False):
    BLOCK_M = 128
    BLOCK_N = 256
    BLOCK_K = 256 if "fp4" in block_scale_type else 128
    VEC_SIZE = 16 if block_scale_type == "nvfp4" else 32
    assert block_scale_type in ["nvfp4", "mxfp4", "mxfp8", "mixed"], f"Invalid block scale type: {block_scale_type}"
    ELEM_PER_BYTE_A = 2 if "fp4" in block_scale_type else 1
    ELEM_PER_BYTE_B = 1 if block_scale_type == "mxfp8" else 2

    device = "cuda"
    a_ref = MXFP4Tensor(size=(M, K), device=device).random()
    # Similar to Hopper's wgmma symmetric fp8 instruction, the RHS is expected
    # to be in col-major layout for Blackwell's tcgen05.mma when using fp4 operands.
    # To conform to the expected semantics of tl.dot_scaled, (M, K) x (K, N),
    # the data is generated in col-major layout, packed along K for fp4, and then
    # logically transposed. Note that if one operand is of fp8 precision, unlike Hopper,
    # Blackwell supports both row-major and col-major layouts for the RHS matrix.
    # For the mixed-precision case, the fp4 RHS can be either in row or col-major layout.
    # But for performance reason, it is recommended to use col-major layout. If TMA is used
    # for the fp4 RHS operand load in mixed-precision dot, as in this tutorial, it must be
    # in col-major layout.
    b_ref = MXFP4Tensor(size=(N, K), device=device).random()
    if block_scale_type in ["mxfp8", "mixed"]:
        a_ref = a_ref.to(torch.float32)
        a = a_ref.to(torch.float8_e4m3fn)
    else:
        # Pack two fp4 elements per byte along K
        a = a_ref.to_packed_tensor(dim=1)

    if block_scale_type == "mxfp8":
        b_ref = b_ref.to(torch.float32)
        b = b_ref.to(torch.float8_e4m3fn)
    else:
        b = b_ref.to_packed_tensor(dim=1)

    b_ref = b_ref.to(torch.float32).T

    a_desc = TensorDescriptor.from_tensor(a, [BLOCK_M, BLOCK_K // ELEM_PER_BYTE_A])
    b_desc = TensorDescriptor.from_tensor(b, [BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B])

    a_scale_shape = [M // 128, K // VEC_SIZE // 4, 32, 16]
    b_scale_shape = [N // 128, K // VEC_SIZE // 4, 32, 16]
    epsilon = 1e-8
    a_scale = torch.rand(a_scale_shape, device=device) + epsilon
    b_scale = torch.rand(b_scale_shape, device=device) + epsilon
    if block_scale_type == "nvfp4":
        a_scale = a_scale.to(torch.float8_e4m3fn)
        b_scale = b_scale.to(torch.float8_e4m3fn)
        a_scale_ref = a_scale
        b_scale_ref = b_scale
    elif block_scale_type in ["mxfp4", "mxfp8", "mixed"]:
        a_scale_ref = MXScaleTensor(a_scale)
        b_scale_ref = MXScaleTensor(b_scale)
        a_scale = a_scale_ref.data
        b_scale = b_scale_ref.data

    rep_m = BLOCK_M // 128
    rep_n = BLOCK_N // 128
    rep_k = BLOCK_K // VEC_SIZE // 4

    # Use 5D TMA descriptor [1, rep_m, rep_k, 2, 256] with uint8 elements.
    # With 256 elements we better utilize the L2 and don't require the TMA
    # engine to emit many small messages (16B) messages as with 32x16xu8.
    a_scale_block_shape = [1, rep_m, rep_k, 2, 256]
    b_scale_block_shape = [1, rep_n, rep_k, 2, 256]
    a_scale = a_scale.reshape(1, a_scale_shape[0], a_scale.shape[1], 2, 256)
    b_scale = b_scale.reshape(1, b_scale_shape[0], b_scale.shape[1], 2, 256)
    a_scale_desc = TensorDescriptor.from_tensor(a_scale, block_shape=a_scale_block_shape)
    b_scale_desc = TensorDescriptor.from_tensor(b_scale, block_shape=b_scale_block_shape)

    reference = None
    if compute_reference:
        a_scale_ref = a_scale_ref.to(torch.float32)
        b_scale_ref = b_scale_ref.to(torch.float32)

        def unpack_scale(packed):
            packed = packed.reshape(*packed.shape[:-2], 32, 4, 4)
            num_chunk_m, num_chunk_k, _, _, _ = packed.shape
            return packed.permute(0, 3, 2, 1, 4).reshape(num_chunk_m * 128, num_chunk_k * 4).contiguous()

        a_scale_ref = unpack_scale(a_scale_ref).repeat_interleave(VEC_SIZE, dim=1)[:M, :K]
        b_scale_ref = unpack_scale(b_scale_ref).repeat_interleave(VEC_SIZE, dim=1).T.contiguous()[:K, :N]
        
        # print(a_ref.to(torch.float32).shape, a_scale_ref.shape, b_ref.shape, b_scale_ref.shape)
        
        reference = torch.matmul(a_ref.to(torch.float32) * a_scale_ref, b_ref * b_scale_ref)

    configs = {
        "BLOCK_SIZE_M": BLOCK_M,
        "BLOCK_SIZE_N": BLOCK_N,
        "BLOCK_SIZE_K": BLOCK_K,
        "num_stages": 4,
        "ELEM_PER_BYTE_A": ELEM_PER_BYTE_A,
        "ELEM_PER_BYTE_B": ELEM_PER_BYTE_B,
        "VEC_SIZE": VEC_SIZE,
    }
    return a_desc, a_scale_desc, b_desc, b_scale_desc, rep_m, rep_n, rep_k, configs, reference


In [None]:
M, N, K = (128, 128, 128)
block_scale_type="nvfp4"
a_desc, a_scale, b_desc, b_scale, rep_m, rep_n, rep_k, configs, reference = initialize_block_scaled(M, N, K, block_scale_type, compute_reference=True)


RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.

In [None]:
a_desc

TensorDescriptor(base=tensor([[246, 249, 160,  ..., 178, 137, 155],
        [150, 213, 249,  ..., 185, 109,  66],
        [ 24,  59,  92,  ..., 225,   3, 220],
        ...,
        [225,  54,  76,  ..., 103, 222, 192],
        [123, 104, 202,  ..., 150, 108, 205],
        [188,  42, 221,  ...,  25, 233, 155]], device='cuda:0',
       dtype=torch.uint8), shape=torch.Size([128, 64]), strides=(64, 1), block_shape=[128, 128])

In [None]:
a_ref = MXFP4Tensor(size=(M, K), device="cuda").random()
a_ref.data, a_ref.to_packed_tensor(dim=1), a_ref.to(torch.float32)

(tensor([[ 0,  9,  6,  ..., 14,  6,  3],
         [12,  9, 14,  ..., 12,  3, 14],
         [ 5, 11,  8,  ..., 13,  5, 10],
         ...,
         [ 9,  9,  9,  ..., 15,  0,  7],
         [ 5,  6,  1,  ...,  0, 10,  7],
         [ 9,  1, 12,  ...,  8,  5,  5]], device='cuda:0', dtype=torch.uint8),
 tensor([[144,  86,  30,  ...,  48, 227,  54],
         [156, 142, 221,  ..., 228, 199, 227],
         [181,  72,  57,  ..., 192, 216, 165],
         ...,
         [153, 153, 235,  ...,  34, 254, 112],
         [101,  65,  27,  ..., 149,   4, 122],
         [ 25, 220,  52,  ..., 115, 135,  85]], device='cuda:0',
        dtype=torch.uint8),
 tensor([[ 0.0000, -0.5000,  4.0000,  ..., -4.0000,  4.0000,  1.5000],
         [-2.0000, -0.5000, -4.0000,  ..., -2.0000,  1.5000, -4.0000],
         [ 3.0000, -1.5000, -0.0000,  ..., -3.0000,  3.0000, -1.0000],
         ...,
         [-0.5000, -0.5000, -0.5000,  ..., -6.0000,  0.0000,  6.0000],
         [ 3.0000,  4.0000,  0.5000,  ...,  0.0000, -1.0000,  

In [None]:
x_fp = torch.randn(128, 128).cuda()

In [None]:
a_ref._from_float(x_fp).to(torch.float32)

tensor([[ 3.,  9., 10.,  ..., 10., 10.,  0.],
        [ 3.,  8.,  9.,  ...,  9.,  9.,  2.],
        [ 1.,  3.,  2.,  ...,  1.,  9.,  3.],
        ...,
        [12.,  2.,  2.,  ..., 11.,  1., 10.],
        [ 1., 10.,  3.,  ...,  9.,  9., 10.],
        [11.,  9.,  2.,  ...,  8., 11.,  9.]], device='cuda:0')

In [None]:
a_scale, a_scale.base

(TensorDescriptor(base=tensor([[[[[0.2344, 0.8125, 0.6875,  ..., 1.0000, 0.3750, 0.4062],
            [0.0391, 0.5000, 0.9375,  ..., 0.4375, 0.5625, 0.1406]],
 
           [[1.0000, 0.7500, 0.0625,  ..., 0.6875, 0.3125, 0.5625],
            [0.4062, 0.5625, 0.2188,  ..., 0.5625, 0.2188, 0.1719]]]]],
        device='cuda:0', dtype=torch.float8_e4m3fn), shape=torch.Size([1, 1, 2, 2, 256]), strides=(1024, 1024, 512, 256, 1), block_shape=[1, 1, 4, 2, 256]),
 tensor([[[[[0.2344, 0.8125, 0.6875,  ..., 1.0000, 0.3750, 0.4062],
            [0.0391, 0.5000, 0.9375,  ..., 0.4375, 0.5625, 0.1406]],
 
           [[1.0000, 0.7500, 0.0625,  ..., 0.6875, 0.3125, 0.5625],
            [0.4062, 0.5625, 0.2188,  ..., 0.5625, 0.2188, 0.1719]]]]],
        device='cuda:0', dtype=torch.float8_e4m3fn))