In [1]:
%env CUDA_VISIBLE_DEVICES=8

env: CUDA_VISIBLE_DEVICES=8


In [2]:
import torch
from torch import nn
from torch.nn import functional as F
import time

from fast_hadamard_transform import hadamard_transform


# Working dtype for the constants and test tensors created in this notebook.
DTYPE = torch.bfloat16
# DTYPE = torch.float32
# The 16 quantization levels consumed by rtn_fp4 / stochastic_round_fp4, in
# ascending order. 0.0 is listed twice (presumably the -0 and +0 codes of the
# 4-bit format — TODO confirm); the rounding helpers clamp indices to 0..15.
GRID = torch.tensor(
    [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0,
    0.0,  0.5,  1.0,  1.5,  2.0,  3.0,  4.0, 6.0],
    device="cuda", dtype=DTYPE,
)
# Subtracted from floor(log2(absmax)) when quantize_tseng picks its shared exponent.
EMAX = 2
# quantize_tseng divides its power-of-two scale by this factor.
SCALE = 3/4
# Multiplies the per-group standard deviation in quantize_quest before the
# floor(log2(.)) that selects the shared exponent.
GAUSSIAL_SCALE = 2.92247856 / 6.0


### FORWARD QUANTIZATION

def rtn_fp4(x, grid):
    """Round every element of ``x`` to its nearest value in ``grid``.

    ``grid`` must be a sorted 1-D tensor with 16 entries. Values outside the
    grid range snap to the closest end point. When an element is exactly
    half-way between two neighbouring levels the upper level is chosen.

    Returns a tensor shaped like ``x`` that only contains grid values.
    """
    # bucketize yields, per element, the first index whose level is >= x.
    upper_idx = torch.bucketize(x, grid)

    below = grid[(upper_idx - 1).clamp(min=0, max=15)]
    above = grid[upper_idx.clamp(min=0, max=15)]

    # Upper neighbour wins whenever it is at least as close as the lower one.
    take_above = (above - x) <= (x - below)
    return torch.where(take_above, above, below)


@torch.compile
def quantize_quest(x):
    """Pseudo-quantize ``x`` in groups of 32 with a std-derived power-of-two scale.

    ``x`` is viewed as rows of 32 elements. Each row gets the scale
    ``2 ** floor(log2(GAUSSIAL_SCALE * std + 1e-8))``, is divided by it,
    rounded to the nearest ``GRID`` level and multiplied back.

    Returns:
        A tuple ``(dequantized, mask)``: ``dequantized`` has the shape of ``x``;
        ``mask`` is a boolean tensor of shape ``(-1, 32)`` that is True where
        the scaled value has magnitude at most 6.0.
    """
    groups = x.view(-1, 32)

    # Population std (correction=0) per group; the epsilon keeps log2 finite.
    group_std = torch.std(groups, dim=-1, correction=0, keepdim=True)
    exponent = torch.floor(torch.log2(GAUSSIAL_SCALE * group_std + 1e-8))
    scales = 2 ** exponent

    normalized = groups / scales
    dequantized = rtn_fp4(normalized, GRID) * scales
    in_range = torch.abs(normalized) <= 6.0

    return dequantized.reshape_as(x), in_range


### BACKWARD QUANTIZATION

def stochastic_round_fp4(x, grid):
    """Stochastically round every element of ``x`` to a neighbouring ``grid`` value.

    ``grid`` must be a sorted 1-D tensor with 16 entries. An element lying
    between two levels is rounded up with probability proportional to its
    distance from the lower level, so the rounding is unbiased inside the grid
    range. Elements outside the range snap to the closest end point.

    Returns a tensor shaped like ``x`` that only contains grid values.
    """
    upper_idx = torch.bucketize(x, grid)

    floor_val = grid[(upper_idx - 1).clamp(min=0, max=15)]
    ceil_val = grid[upper_idx.clamp(min=0, max=15)]

    gap = ceil_val - floor_val
    # When both neighbours coincide (out-of-range input) the probability is
    # irrelevant; 0.5 merely avoids a division by zero.
    prob_up = torch.where(gap > 0, (x - floor_val) / gap, torch.full_like(x, 0.5))

    round_up = torch.rand_like(x) < prob_up
    return torch.where(round_up, ceil_val, floor_val)


@torch.compile
def quantize_tseng(x):
    """Pseudo-quantize ``x`` in groups of 32 with an absmax-derived scale.

    ``x`` is viewed as rows of 32 elements. Each row gets the scale
    ``2 ** (floor(log2(absmax)) - EMAX) / SCALE``, is divided by it, rounded to
    the nearest ``GRID`` level and multiplied back.

    A row that is entirely zero is returned as zeros; without the guard below
    its scale would be 0 and the division would produce NaN.

    Returns a tensor with the shape of ``x``.
    """
    x_grouped = x.view(-1, 32)

    absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0]
    # log2(0) = -inf -> scale 0 -> 0 / 0 = NaN. Any positive stand-in works for
    # an all-zero group because every element still quantizes to exactly 0.
    safe_absmax = torch.where(absmax > 0, absmax, torch.ones_like(absmax))

    shared_exps = torch.floor(torch.log2(safe_absmax)) - EMAX
    scales = 2 ** shared_exps / SCALE

    scaled_x = x_grouped / scales

    # Round-to-nearest; stochastic_round_fp4(scaled_x, GRID) is the stochastic
    # alternative.
    x_fp4 = rtn_fp4(scaled_x, GRID)

    return (x_fp4 * scales).reshape_as(x)


In [3]:
class HadamardGemm(torch.autograd.Function):
    """Linear layer whose operands are rotated by 32x32 Hadamard blocks.

    No quantization happens here. Both GEMM operands are multiplied by the same
    orthonormal block matrix, so the products match a plain ``F.linear`` up to
    rounding error.
    """
    # Normalized 32x32 Hadamard matrices, built by transforming the identity.
    forward_hadamard_matrix = hadamard_transform(torch.eye(32, dtype=DTYPE, device="cuda"), scale=32**(-1/2))
    # Accumulates random column sign flips across non-deterministic backward calls.
    backward_hadamard_matrix = hadamard_transform(torch.eye(32, dtype=DTYPE, device="cuda"), scale=32**(-1/2))

    @staticmethod
    def forward(ctx, input, weight, deterministic=False):
        """Rotate ``input`` and ``weight`` in blocks of 32, then apply ``F.linear``.

        ``input`` is indexed as (batch, seq, in_dim) and ``weight`` as
        (out_dim, in_dim). ``deterministic`` is stored for ``backward``.
        """
        ctx.batch, ctx.seq = input.shape[0], input.shape[1]
        ctx.out_dim, ctx.in_dim = weight.shape[0], weight.shape[1]
        ctx.deterministic = deterministic

        h_fwd = HadamardGemm.forward_hadamard_matrix
        input_hf = torch.matmul(input.reshape(-1, 32), h_fwd).view(ctx.batch, ctx.seq, ctx.in_dim)
        weight_hf = torch.matmul(weight.reshape(-1, 32), h_fwd).view(ctx.out_dim, ctx.in_dim)

        ctx.save_for_backward(input_hf, weight_hf)
        return F.linear(input_hf, weight_hf)

    @staticmethod
    @torch.compile()
    def backward(ctx, grad_output):
        """Return ``(grad_input, grad_weight, None)`` using rotated GEMM operands."""
        input_hf, weight_hf = ctx.saved_tensors

        if not ctx.deterministic:
            # Flip the sign of a random subset of columns of the backward rotation.
            previous = HadamardGemm.backward_hadamard_matrix
            signs = torch.randint(0, 2, (32,), device=previous.device, dtype=previous.dtype) * 2 - 1
            HadamardGemm.backward_hadamard_matrix = previous @ torch.diag(signs)

        h_bwd = HadamardGemm.backward_hadamard_matrix
        h_fwd_t = HadamardGemm.forward_hadamard_matrix.T

        # grad_input: rotate along out_dim, contract, then undo the forward rotation.
        grad_output_hb = torch.matmul(grad_output.reshape(-1, 32), h_bwd).view(ctx.batch, ctx.seq, ctx.out_dim)
        hft_weightt_hb = torch.matmul(weight_hf.T.reshape(-1, 32), h_bwd).view(ctx.in_dim, ctx.out_dim)
        grad_input_hf = F.linear(grad_output_hb, hft_weightt_hb)
        grad_input = torch.matmul(grad_input_hf.reshape(-1, 32), h_fwd_t).view(ctx.batch, ctx.seq, ctx.in_dim)

        # grad_weight: rotate along the flattened batch*seq axis, contract, undo.
        grad_outputt_hb = torch.matmul(
            grad_output.view(-1, ctx.out_dim).T.reshape(-1, 32), h_bwd
        ).view(ctx.out_dim, -1)
        hft_inputt_hb = torch.matmul(
            input_hf.view(-1, ctx.in_dim).T.reshape(-1, 32), h_bwd
        ).view(ctx.in_dim, -1)
        grad_weight_hf = F.linear(grad_outputt_hb, hft_inputt_hb)
        grad_weight = torch.matmul(grad_weight_hf.reshape(-1, 32), h_fwd_t).view(ctx.out_dim, ctx.in_dim)

        return grad_input, grad_weight, None


class MXFP4Gemm(torch.autograd.Function):
    """Linear layer with Hadamard-rotated, pseudo-quantized GEMM operands.

    Forward operands go through ``quantize_quest``; backward operands go
    through ``quantize_tseng``. The masks returned by ``quantize_quest`` zero
    the gradient of elements whose scaled magnitude exceeded the grid range.
    """
    # Normalized 32x32 Hadamard matrices, built by transforming the identity.
    forward_hadamard_matrix = hadamard_transform(torch.eye(32, dtype=DTYPE, device="cuda"), scale=32**(-1/2))
    # Accumulates random column sign flips across non-deterministic backward calls.
    backward_hadamard_matrix = hadamard_transform(torch.eye(32, dtype=DTYPE, device="cuda"), scale=32**(-1/2))

    @staticmethod
    def forward(ctx, input, weight, deterministic=False):
        """Rotate and quantize ``input`` and ``weight``, then apply ``F.linear``.

        ``input`` is indexed as (batch, seq, in_dim) and ``weight`` as
        (out_dim, in_dim). ``deterministic`` is stored for ``backward``.
        """
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]

        ctx.deterministic = deterministic

        # reshape (not view) so non-contiguous operands are accepted, matching HadamardGemm.
        input_hf, input_mask_hf = quantize_quest(
            input.reshape(-1, 32) @ MXFP4Gemm.forward_hadamard_matrix
        )
        input_hf, input_mask_hf = input_hf.view_as(input), input_mask_hf.view_as(input)
        weight_hf, weight_mask_hf = quantize_quest(
            weight.reshape(-1, 32) @ MXFP4Gemm.forward_hadamard_matrix
        )
        weight_hf, weight_mask_hf = weight_hf.view_as(weight), weight_mask_hf.view_as(weight)

        ctx.save_for_backward(input_hf, weight_hf, input_mask_hf, weight_mask_hf)
        return F.linear(input_hf, weight_hf)

    @staticmethod
    @torch.compile()
    def backward(ctx, grad_output):
        """Return ``(grad_input, grad_weight, None)`` from quantized rotated operands."""
        input_hf, weight_hf, input_mask_hf, weight_mask_hf = ctx.saved_tensors

        if not ctx.deterministic:
            # Flip the sign of a random subset of columns of the backward rotation.
            MXFP4Gemm.backward_hadamard_matrix = MXFP4Gemm.backward_hadamard_matrix @ torch.diag(
                torch.randint(
                    0, 2, (32,),
                    device=MXFP4Gemm.backward_hadamard_matrix.device,
                    dtype=MXFP4Gemm.backward_hadamard_matrix.dtype
                ) * 2 - 1
            )

        # Both GEMM operands must be rotated by the SAME matrix so that the
        # rotation cancels in the product. The transposed matrix used here
        # before no longer matches once random sign flips have been applied.
        grad_output_hb = quantize_tseng(
            grad_output.reshape(-1, 32) @ MXFP4Gemm.backward_hadamard_matrix
        ).view(ctx.batch, ctx.seq, ctx.out_dim)
        hft_weightt_hb = quantize_tseng(
            weight_hf.T.reshape(-1, 32) @ MXFP4Gemm.backward_hadamard_matrix
        ).view(ctx.in_dim, ctx.out_dim)
        grad_input_hf = F.linear(grad_output_hb, hft_weightt_hb)
        # Mask clipped elements, then undo the forward rotation.
        grad_input = (
            (grad_input_hf.view(-1, 32) * input_mask_hf.view(-1, 32).to(grad_input_hf.dtype))
            @ MXFP4Gemm.forward_hadamard_matrix.T
        ).view(ctx.batch, ctx.seq, ctx.in_dim)

        # grad_weight contracts over the flattened batch*seq axis, so both
        # operands are transposed before the block rotation.
        grad_outputt_hb = quantize_tseng(
            grad_output.reshape(-1, ctx.out_dim).T.reshape(-1, 32) @ MXFP4Gemm.backward_hadamard_matrix
        ).view(ctx.out_dim, -1)
        hft_inputt_hb = quantize_tseng(
            input_hf.view(-1, ctx.in_dim).T.reshape(-1, 32) @ MXFP4Gemm.backward_hadamard_matrix
        ).view(ctx.in_dim, -1)
        grad_weight_hf = F.linear(grad_outputt_hb, hft_inputt_hb)
        grad_weight = (
            (grad_weight_hf.view(-1, 32) * weight_mask_hf.view(-1, 32).to(grad_weight_hf.dtype))
            @ MXFP4Gemm.forward_hadamard_matrix.T
        ).view(ctx.out_dim, ctx.in_dim)
        return grad_input, grad_weight, None

In [4]:
from random import randint

import torch
import triton
import triton.language as tl
from fast_hadamard_transform import hadamard_transform


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 32 * 32}),
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def mxfp4_forward_kernel(
    x_ptr,
    hadamard_matrix_ptr,
    output_ptr,
    clip_mask_ptr,
    n_elements: tl.constexpr,
    hadamard_dim: tl.constexpr,
    group_size: tl.constexpr,
    seed: int,
    quest: tl.constexpr,
    stochastic_round: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused block-Hadamard transform + grouped FP4 pseudo-quantization.

    Each program handles BLOCK_SIZE consecutive elements of the flat input:
    it multiplies rows of `hadamard_dim` elements by the Hadamard matrix,
    regroups the result into rows of `group_size`, scales each group by a
    power of two, rounds to the FP4 levels and writes the dequantized values.

    * quest=True: scale from the group standard deviation; also writes a
      boolean mask (scaled magnitude below 6) to `clip_mask_ptr`.
    * quest=False: scale from the group absolute maximum with an extra 3/4.
    * stochastic_round=True: unbiased stochastic rounding driven by
      `tl.rand(seed, offsets)`; otherwise round-to-nearest.
    """
    offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim)
    hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim)

    # load x; masked-out tail lanes get a defined value instead of garbage
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # hadamard transform
    x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_had = tl.dot(x, hadamard_matrix)

    # group
    x_had_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size))

    # scale
    if quest:
        mean_squared = tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size
        mean = tl.sum(x_had_grouped, axis=-1, keep_dims=True) / group_size
        # E[x^2] - E[x]^2 can round to a tiny negative number for (near-)constant
        # groups; clamp so sqrt never returns NaN.
        variance = tl.maximum(mean_squared - mean * mean, 0.0)
        std = tl.sqrt(variance)
        scales = (2.92247856 / 6.0) * std + 1e-8
        shared_exps = tl.exp2(tl.floor(tl.log2(scales)))
        x_had_scaled = x_had_grouped / shared_exps
    else:
        scales = tl.max(tl.abs(x_had_grouped), axis=-1, keep_dims=True)
        # An all-zero group would give log2(0) = -inf -> scale 0 -> 0/0 = NaN.
        # Any positive stand-in works: every element still quantizes to 0.
        scales = tl.where(scales > 0, scales, 1.0)
        shared_exps = tl.exp2(tl.floor(tl.log2(scales)) - 2) 
        x_had_scaled = x_had_grouped / shared_exps * (3/4) # 3/4 is constant. In CUDA, scale the GEMM output by 16/9

    # quantize: work on magnitudes and re-apply the sign afterwards
    x_had_scaled_abs = tl.abs(x_had_scaled)
    x_had_scaled_sign = tl.where(
        x_had_scaled > 0,
        1,
        -1,
    )
    if stochastic_round:
        # Upper neighbouring level of each magnitude.
        x_fp4_high = tl.where(
            x_had_scaled_abs > 4,
            6,
            tl.where(
                x_had_scaled_abs > 3,
                4,
                tl.where(
                    x_had_scaled_abs > 2,
                    3,
                    tl.where(
                        x_had_scaled_abs > 1.5,
                        2,
                        tl.where(
                            x_had_scaled_abs > 1.0,
                            1.5,
                            tl.where(
                                x_had_scaled_abs > 0.5,
                                1,
                                0.5,
                            )
                        )
                    )
                )
            )
        )

        # Lower neighbouring level of each magnitude.
        x_fp4_low = tl.where(
            x_had_scaled_abs > 4,
            4,
            tl.where(
                x_had_scaled_abs > 3,
                3,
                tl.where(
                    x_had_scaled_abs > 2,
                    2,
                    tl.where(
                        x_had_scaled_abs > 1.5,
                        1.5,
                        tl.where(
                            x_had_scaled_abs > 1.0,
                            1.0,
                            tl.where(
                                x_had_scaled_abs > 0.5,
                                0.5,
                                0.0,
                            )
                        )
                    )
                )
            )
        )

        prob_up = (x_had_scaled_abs - x_fp4_low) / (x_fp4_high - x_fp4_low)
        # The random tensor must match the grouped layout of prob_up, so it is
        # shaped by group_size (not hadamard_dim).
        sampled_prob = tl.reshape(
            tl.rand(seed, offsets),
            (BLOCK_SIZE // group_size, group_size),
        )
        x_fp4 = tl.where(
            sampled_prob < prob_up,
            x_fp4_high,
            x_fp4_low,
        ) * x_had_scaled_sign
    else:    
        # Round-to-nearest: thresholds are the midpoints between adjacent levels.
        x_fp4 = tl.where(
            x_had_scaled_abs > 5,
            6,
            tl.where(
                x_had_scaled_abs > 3.5,
                4,
                tl.where(
                    x_had_scaled_abs > 2.5,
                    3,
                    tl.where(
                        x_had_scaled_abs > 1.75,
                        2,
                        tl.where(
                            x_had_scaled_abs > 1.25,
                            1.5,
                            tl.where(
                                x_had_scaled_abs > 0.75,
                                1,
                                tl.where(
                                    x_had_scaled_abs > 0.25,
                                    0.5,
                                    0,
                                )
                            )
                        )
                    )
                )
            )
        ) * x_had_scaled_sign


    # dequantize
    if quest:
        x_dequantized = x_fp4 * shared_exps
        # NOTE(review): strict `< 6` here, while quantize_quest uses `<= 6.0`
        # for the same mask — confirm which comparison is intended.
        tl.store(
            clip_mask_ptr + offsets,
            tl.reshape(x_had_scaled_abs < 6, (BLOCK_SIZE,)),
            mask=mask
        )
    else:
        x_dequantized = x_fp4 * shared_exps * (4/3) # 3/4 is constant. In CUDA, scale the GEMM output by 16/9

    # Reshape back to flat form for storage
    x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,))

    # store
    tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask)


def mxfp4_forward_kernel_wrapper(
    x,
    hadamard_matrix,
    stochastic_round=False,
    quest=True,
    seed=None,
):    
    """Launch ``mxfp4_forward_kernel`` over a contiguous copy of ``x``.

    Args:
        x: tensor to transform and pseudo-quantize; made contiguous first.
        hadamard_matrix: square matrix; its last dimension is passed to the
            kernel as ``hadamard_dim``.
        stochastic_round: use stochastic rounding instead of round-to-nearest.
        quest: use the std-based scale and also produce a clip mask.
        seed: seed for the kernel's random numbers. ``None`` (default) draws a
            fresh seed for every stochastic call; previously every call used
            the constant 42, so all stochastic quantizations shared the same
            rounding noise. Pass an int for reproducible results.

    Returns:
        ``(output, clip_mask)``: ``output`` has the shape and dtype of ``x``;
        ``clip_mask`` is a bool tensor of the same shape when ``quest`` is
        True, otherwise ``None``.
    """
    # Make sure inputs are contiguous
    x = x.contiguous()

    # Create output tensor
    output = torch.empty_like(x)
    clip_mask = torch.empty_like(x, dtype=torch.bool) if quest else None

    if seed is None:
        # The seed is only consumed on the stochastic path; keep the historical
        # constant otherwise so deterministic launches are unchanged.
        seed = randint(0, 2**31 - 1) if stochastic_round else 42

    # Get total number of elements and calculate grid for launching the kernel
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch optimized kernel
    mxfp4_forward_kernel[grid](
        x_ptr=x,
        hadamard_matrix_ptr=hadamard_matrix,
        output_ptr=output,
        clip_mask_ptr=clip_mask,
        n_elements=n_elements,
        hadamard_dim=hadamard_matrix.shape[-1],
        group_size=32,
        seed=seed,
        quest=quest,
        stochastic_round=stochastic_round,
    )

    return output, clip_mask


class TritonGemm(torch.autograd.Function):
    """Linear layer whose operand rotation + pseudo-quantization runs in Triton.

    Forward operands use ``mxfp4_forward_kernel_wrapper`` with ``quest=True``;
    backward operands use it with ``quest=False`` and stochastic rounding
    enabled unless ``deterministic`` was requested in ``forward``.
    """
    # Normalized 32x32 Hadamard matrices, built by transforming the identity.
    forward_hadamard_matrix = hadamard_transform(torch.eye(32, dtype=DTYPE, device="cuda"), scale=32**(-1/2))
    # Accumulates random column sign flips across non-deterministic backward calls.
    backward_hadamard_matrix = hadamard_transform(torch.eye(32, dtype=DTYPE, device="cuda"), scale=32**(-1/2))

    # forward/backward are static methods, as in HadamardGemm and as the
    # torch.autograd.Function contract expects.
    @staticmethod
    def forward(ctx, input, weight, deterministic=False):
        """Rotate and quantize ``input`` and ``weight``, then apply ``F.linear``.

        ``input`` is indexed as (batch, seq, in_dim) and ``weight`` as
        (out_dim, in_dim). ``deterministic`` is stored for ``backward``.
        """
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]

        ctx.deterministic = deterministic

        input_hf, input_mask_hf = mxfp4_forward_kernel_wrapper(
            input,
            TritonGemm.forward_hadamard_matrix,
            stochastic_round=False,
            quest=True,
        )
        weight_hf, weight_mask_hf = mxfp4_forward_kernel_wrapper(
            weight,
            TritonGemm.forward_hadamard_matrix,
            stochastic_round=False,
            quest=True,
        )

        ctx.save_for_backward(input_hf, weight_hf, input_mask_hf, weight_mask_hf)
        return F.linear(input_hf, weight_hf)

    @staticmethod
    @torch.compile()
    def backward(ctx, grad_output):
        """Return ``(grad_input, grad_weight, None)`` from quantized rotated operands."""
        input_hf, weight_hf, input_mask_hf, weight_mask_hf = ctx.saved_tensors

        if not ctx.deterministic:
            # Flip the sign of a random subset of columns of the backward rotation.
            TritonGemm.backward_hadamard_matrix = TritonGemm.backward_hadamard_matrix @ torch.diag(
                torch.randint(
                    0, 2, (32,),
                    device=TritonGemm.backward_hadamard_matrix.device,
                    dtype=TritonGemm.backward_hadamard_matrix.dtype
                ) * 2 - 1
            )

        # grad_input: both operands rotated along out_dim with the same matrix.
        grad_output_hb, _ = mxfp4_forward_kernel_wrapper(
            grad_output,
            TritonGemm.backward_hadamard_matrix,
            stochastic_round=not ctx.deterministic,
            quest=False,
        )
        hft_weightt_hb, _ = mxfp4_forward_kernel_wrapper(
            weight_hf.T,
            TritonGemm.backward_hadamard_matrix,
            stochastic_round=not ctx.deterministic,
            quest=False,
        )
        grad_input_hf = F.linear(grad_output_hb, hft_weightt_hb)
        # Mask clipped elements, then undo the forward rotation.
        grad_input = (
            (grad_input_hf.view(-1, 32) * input_mask_hf.view(-1, 32).to(grad_input_hf.dtype))
            @ TritonGemm.forward_hadamard_matrix.T
        ).view(ctx.batch, ctx.seq, ctx.in_dim)

        # grad_weight contracts over the flattened batch*seq axis, so both
        # operands are transposed first. reshape (not view) tolerates a
        # non-contiguous grad_output.
        grad_outputt_hb, _ = mxfp4_forward_kernel_wrapper(
            grad_output.reshape(-1, grad_output.size(-1)).T,
            TritonGemm.backward_hadamard_matrix,
            stochastic_round=not ctx.deterministic,
            quest=False,
        )
        hft_inputt_hb, _ = mxfp4_forward_kernel_wrapper(
            input_hf.view(-1, ctx.in_dim).T,
            TritonGemm.backward_hadamard_matrix,
            stochastic_round=not ctx.deterministic,
            quest=False,
        )
        grad_weight_hf = F.linear(grad_outputt_hb, hft_inputt_hb)
        grad_weight = (
            (grad_weight_hf.view(-1, 32) * weight_mask_hf.view(-1, 32).to(grad_weight_hf.dtype))
            @ TritonGemm.forward_hadamard_matrix.T
        ).view(ctx.out_dim, ctx.in_dim)
        return grad_input, grad_weight, None

In [5]:
# Disable the random sign flips / stochastic rounding so the implementations
# can be compared against each other.
DETERMINISTIC_FOR_TESTS = True

x = torch.randn(1, 32, 4096, device="cuda", dtype=DTYPE, requires_grad=True)
w = torch.randn(128, 4096, device="cuda", dtype=DTYPE, requires_grad=True)


def _pop_weight_grad():
    """Return a copy of ``w.grad`` and clear it so the next backward starts fresh."""
    snapshot = w.grad.clone()
    w.grad = None
    return snapshot


# Reference: plain linear layer and its weight gradient.
y = F.linear(x, w)
y_grad = torch.randn_like(y)
y.backward(y_grad)
grad = _pop_weight_grad()

# Hadamard rotation only.
y_had = HadamardGemm.apply(x, w, DETERMINISTIC_FOR_TESTS)
y_had.backward(y_grad)
y_had_grad = _pop_weight_grad()

# Rotation + pseudo-quantization in PyTorch.
y_fp4 = MXFP4Gemm.apply(x, w, DETERMINISTIC_FOR_TESTS)
y_fp4.backward(y_grad)
y_fp4_grad = _pop_weight_grad()

# Rotation + pseudo-quantization in Triton.
y_triton = TritonGemm.apply(x, w, DETERMINISTIC_FOR_TESTS)
y_triton.backward(y_grad)
y_triton_grad = _pop_weight_grad()

In [6]:
def _relative_sq_error(a, b, reference):
    """Squared ratio of ``||a - b||`` to ``||reference||`` as a Python float."""
    ratio = torch.linalg.norm(a - b) / torch.linalg.norm(reference)
    return ratio.pow(2).detach().item()


# Forward outputs: error against the plain linear layer, and Triton vs. PyTorch.
had_l2_error = _relative_sq_error(y, y_had, y)
fp4_l2_error = _relative_sq_error(y, y_fp4, y)
triton_l2_discrepancy = _relative_sq_error(y_fp4, y_triton, y)

print(f"Hadamard L2 error: {had_l2_error:.1e}")
print(f"FP4 L2 error: {fp4_l2_error:.1e}")
print(f"Triton L2 discrepancy: {triton_l2_discrepancy:.1e}")
assert had_l2_error < 1e-4
assert 2e-2 < fp4_l2_error < 6e-2
assert triton_l2_discrepancy < fp4_l2_error / 10

# Weight gradients: same three comparisons.
had_grad_l2_error = _relative_sq_error(grad, y_had_grad, grad)
fp4_grad_l2_error = _relative_sq_error(grad, y_fp4_grad, grad)
triton_grad_l2_discrepancy = _relative_sq_error(y_fp4_grad, y_triton_grad, grad)

print(f"Hadamard grad L2 error: {had_grad_l2_error:.1e}")
print(f"FP4 grad L2 error: {fp4_grad_l2_error:.1e}")
print(f"Triton grad L2 discrepancy: {triton_grad_l2_discrepancy:.1e}")

assert had_grad_l2_error < 1e-4
assert 6e-2 < fp4_grad_l2_error < 15e-2
assert triton_grad_l2_discrepancy < fp4_grad_l2_error / 10

Hadamard L2 error: 1.5e-05
FP4 L2 error: 5.8e-02
Triton L2 discrepancy: 1.4e-03
Hadamard grad L2 error: 1.6e-05
FP4 grad L2 error: 1.2e-01
Triton grad L2 discrepancy: 4.4e-03


In [7]:
# Eager GEMM implementations, keyed by the label shown in the benchmark table.
gemm_fns = dict([
    ("baseline", F.linear),
    ("+hadamard", HadamardGemm.apply),
    ("+mxfp4", MXFP4Gemm.apply),
    ("+triton", TritonGemm.apply),
])

# The same callables wrapped with torch.compile, under the same labels.
gemm_compiled_fns = dict(zip(gemm_fns, map(torch.compile, gemm_fns.values())))


def benchmark_gpu(fn, input_size, weight_size, num_iterations=100, warmup=10,
                  device="cuda", dtype=None):
    """Time the forward and backward pass of ``fn(input, weight)``.

    Args:
        fn: callable taking ``(input, weight)`` and returning a tensor.
        input_size: shape of the random input tensor.
        weight_size: shape of the random weight tensor.
        num_iterations: number of timed iterations per phase (at least 1).
        warmup: number of untimed iterations before each phase.
        device: device the tensors are created on (default ``"cuda"``).
        dtype: tensor dtype; ``None`` uses the notebook-wide ``DTYPE``.

    Returns:
        Dict with ``forward_ms``, ``backward_ms`` and ``total_ms`` (mean
        per-iteration wall time). The backward time is the time of a
        forward+backward iteration minus the measured forward time.

    Raises:
        ValueError: if ``num_iterations`` is smaller than 1.
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")

    dtype = DTYPE if dtype is None else dtype
    on_cuda = torch.device(device).type == "cuda"

    def sync():
        # CUDA launches are asynchronous; wait for queued work before reading the clock.
        if on_cuda:
            torch.cuda.synchronize()

    input = torch.randn(*input_size, device=device, dtype=dtype, requires_grad=True)
    weight = torch.randn(*weight_size, device=device, dtype=dtype, requires_grad=True)

    # Warmup
    for _ in range(warmup):
        result = fn(input, weight)
    sync()

    # Measure forward pass (perf_counter is monotonic and high resolution)
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        result = fn(input, weight)
        sync()
    forward_time = (time.perf_counter() - start_time) / num_iterations

    # Warmup
    grad = torch.randn_like(result)
    for _ in range(warmup):
        result.backward(grad, retain_graph=True)
    # Without this, still-queued warmup work is charged to the timed loop below.
    sync()

    # Measure backward pass
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        result = fn(input, weight)
        result.backward(grad, retain_graph=True)
        sync()
    backward_time = (time.perf_counter() - start_time) / num_iterations - forward_time

    return {
        "forward_ms": forward_time * 1000,
        "backward_ms": backward_time * 1000,
        "total_ms": (forward_time + backward_time) * 1000
    }


def run_gpu_benchmarks(batch_size=64, seq_len=512, hidden_size=1024):
    """Benchmark every entry of ``gemm_compiled_fns`` and print a timing table.

    The input has shape (batch_size, seq_len, hidden_size) and the weight is a
    square (hidden_size, hidden_size) matrix.

    Returns:
        Dict mapping each implementation label to the timing dict produced by
        ``benchmark_gpu``.
    """
    input_size = (batch_size, seq_len, hidden_size)
    weight_size = (hidden_size, hidden_size)

    results = {}
    for name, fn in gemm_compiled_fns.items():
        print(f"Benchmarking {name}...")
        results[name] = benchmark_gpu(fn, input_size, weight_size)

    # Table header
    print("\nGPU Benchmark Results (ms):")
    print(f"{'Method':<15} {'Forward':<10} {'Backward':<10} {'Total':<10}")
    print("-" * 45)

    # One row per implementation, in benchmark order
    for name, timings in results.items():
        backward_ms = timings['backward_ms']
        forward, backward, total = (
            f"{timings['forward_ms']:.2f}",
            "N/A" if backward_ms is None else f"{backward_ms:.2f}",
            f"{timings['total_ms']:.2f}",
        )
        print(f"{name:<15} {forward:<10} {backward:<10} {total:<10}")

    return results

In [8]:
# Benchmark all implementations with hidden_size=4096 (default batch_size and
# seq_len); the table is printed, so the returned dict is discarded.
_ = run_gpu_benchmarks(hidden_size=4096)

Benchmarking baseline...
Benchmarking +hadamard...
Benchmarking +mxfp4...
Benchmarking +triton...


W0506 16:08:09.656000 2094825 site-packages/torch/_dynamo/convert_frame.py:990] [5/8] torch._dynamo hit config.recompile_limit (8)
W0506 16:08:09.656000 2094825 site-packages/torch/_dynamo/convert_frame.py:990] [5/8]    function: 'mxfp4_forward_kernel_wrapper' (/tmp/ipykernel_2094825/748202876.py:180)
W0506 16:08:09.656000 2094825 site-packages/torch/_dynamo/convert_frame.py:990] [5/8]    last reason: 5/7: x._base.stride()[0] == x._base.size()[1]  # (unknown source x._base.stride()[0], please file a bug)
W0506 16:08:09.656000 2094825 site-packages/torch/_dynamo/convert_frame.py:990] [5/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0506 16:08:09.656000 2094825 site-packages/torch/_dynamo/convert_frame.py:990] [5/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.



GPU Benchmark Results (ms):
Method          Forward    Backward   Total     
---------------------------------------------
baseline        5.13       12.55      17.68     
+hadamard       5.83       18.73      24.56     
+mxfp4          5.86       28.64      34.50     
+triton         5.00       20.40      25.40     


# CUDA Kernels Needed

---

## Forward HT+Quant:

 * **Inputs**: 
    * `x`: torch.Tensor; shape=[M, K]; dtype=float32; already contiguous.
    * `hadamard_matrix`: torch.Tensor; shape=[32, 32]; dtype=float32.
 * **Outputs**:
    * `q`: torch.Tensor; shape=[M, K]; dtype=MXFP4 with scales along K;
    * `mask`: torch.Tensor; shape=[M, K]; dtype=bool;
 * **Reference Impl**:
    * Triton impl up to pseudoquant:
        ```
        q, mask = mxfp4_forward_kernel_wrapper(
            x,
            hadamard_matrix,
            stochastic_round=False,
            quest=True,
        )
        ```
 * **Features**:
    * Assume `M,K` divisible by 128.
    * RTN projection.
    * Scales based on STD.


---
    
## Backward HT+Quant:

 * **Inputs**: 
    * `x`: torch.Tensor; shape=[M, K]; dtype=float32; already contiguous.
    * `hadamard_matrix`: torch.Tensor; shape=[32, 32]; dtype=float32.
 * **Outputs**:
    * `q`: torch.Tensor; shape=[M, K]; dtype=MXFP4 with scales along K;
 * **Reference Impl**:
    * Triton impl up to pseudoquant:
        ```
        q, _ = mxfp4_forward_kernel_wrapper(
            x,
            hadamard_matrix,
            stochastic_round=True,
            quest=False,
        )
        ```
 * **Features**:
    * Assume `M,K` divisible by 128.
    * Stochastic Rounding.
    * Scales based on absmax.

---

## Backward Transpose+HT+Quant:

 * **Inputs**: 
    * `x`: torch.Tensor; shape=[M, K]; dtype=float32; already contiguous.
    * `hadamard_matrix`: torch.Tensor; shape=[32, 32]; dtype=float32.
 * **Outputs**:
    * `q`: torch.Tensor; shape=[K, M]; dtype=MXFP4 with scales along M;
 * **Reference Impl**:
    * Triton impl up to pseudoquant:
        ```
        q, _ = mxfp4_forward_kernel_wrapper(
            x.T,
            hadamard_matrix,
            stochastic_round=True,
            quest=False,
        )
        ```
 * **Features**:
    * Assume `M,K` divisible by 128.
    * Stochastic Rounding.
    * Scales based on absmax.

---
    
## Backward Dequant+Transpose+HT+Quant:

 * **Inputs**:
    * `q`: torch.Tensor; shape=[M, K]; dtype=MXFP4 with scales along K;
    * `hadamard_matrix`: torch.Tensor; shape=[32, 32]; dtype=float32.
 * **Outputs**:
    * `qq`: torch.Tensor; shape=[K, M]; dtype=MXFP4 with scales along M;
 * **Reference Impl**:
    * Triton impl up to pseudoquant:
        ```
        qq, _ = mxfp4_forward_kernel_wrapper(
            q.T,
            hadamard_matrix,
            stochastic_round=True,
            quest=False,
        )
        ```
 * **Features**:
    * Assume `M,K` divisible by 128.
    * Stochastic Rounding.
    * Scales based on absmax.