In [1]:
%env CUDA_VISIBLE_DEVICES=0
# %env TORCH_LOGS=recompiles

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import torch
torch._dynamo.config.compiled_autograd = True
torch._dynamo.config.recompile_limit = 2048
from torch import nn
from torch.nn import functional as F
import time
from scipy.linalg import hadamard


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return torch.tensor(
        hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
    )


DTYPE = torch.bfloat16
# DTYPE = torch.float32
GRID = torch.tensor(
    [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0,
    0.0,  0.5,  1.0,  1.5,  2.0,  3.0,  4.0, 6.0],
    device="cuda", dtype=DTYPE,
)
EMAX = 2
SCALE = 3/4
GAUSSIAL_SCALE = 2.92247856 / 6.0


### FORWARD QUANTIZATION

def rtn_fp4(x, grid):
    inds = torch.bucketize(x, grid)

    lo = torch.clamp(inds - 1, min=0, max=15)
    hi = torch.clamp(inds,     min=0, max=15)

    g_lo = grid[lo]
    g_hi = grid[hi]

    pick_hi = (g_hi - x) <= (x - g_lo)
    return torch.where(pick_hi, g_hi, g_lo)


@torch.compile
def quantize_quest(x):
    x_grouped = x.view(-1, 32)
    shared_exps = torch.floor(torch.log2(
        GAUSSIAL_SCALE * torch.std(x_grouped, dim=-1, correction=0, keepdim=True) + 1e-8
    ))
    scales = 2 ** shared_exps

    scaled_x = x_grouped / scales

    x_fp4 = rtn_fp4(scaled_x, GRID)

    return (x_fp4 * scales).reshape_as(x), torch.abs(scaled_x) <= 6.0


### BACKWARD QUANTIZATION

def stochastic_round_fp4(x, grid):
    inds = torch.bucketize(x, grid)

    lo = torch.clamp(inds - 1, min=0, max=15)
    hi = torch.clamp(inds,     min=0, max=15)

    g_lo = grid[lo]
    g_hi = grid[hi]

    delta = g_hi - g_lo
    p = torch.where(
        delta > 0,
        (x - g_lo) / delta,
        torch.full_like(x, 0.5)
    )

    u = torch.rand_like(x)
    pick_hi = u < p
    return torch.where(pick_hi, g_hi, g_lo)


@torch.compile
def quantize_tseng(x):
    x_grouped = x.view(-1, 32)
    shared_exps = torch.floor(torch.log2(x_grouped.abs().max(dim=-1, keepdim=True)[0])) - EMAX
    scales = 2 ** shared_exps / SCALE

    scaled_x = x_grouped / scales

    # x_fp4 = stochastic_round_fp4(scaled_x, GRID)
    x_fp4 = rtn_fp4(scaled_x, GRID)

    return (x_fp4 * scales).reshape_as(x)


In [3]:
FORWARD_HADAMARD_MATRIX = get_hadamard_matrix(32, dtype=DTYPE, device="cuda")
BACKWARD_HADAMARD_MATRIX = get_hadamard_matrix(32, dtype=DTYPE, device="cuda")

class HadamardGemm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, deterministic=True):
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]

        ctx.deterministic = deterministic

        input_hf = (
            input.reshape(-1, 32) @ FORWARD_HADAMARD_MATRIX
        ).view(ctx.batch, ctx.seq, ctx.in_dim)
        weight_hf = (
            weight.reshape(-1, 32) @ FORWARD_HADAMARD_MATRIX
        ).view(ctx.out_dim, ctx.in_dim)

        ctx.save_for_backward(input_hf, weight_hf)
        return F.linear(input_hf, weight_hf)

    @staticmethod
    def backward(ctx, grad_output):
        global BACKWARD_HADAMARD_MATRIX

        input_hf, weight_hf = ctx.saved_tensors

        if not ctx.deterministic:
            BACKWARD_HADAMARD_MATRIX = BACKWARD_HADAMARD_MATRIX @ torch.diag(
                torch.randint(
                    0, 2, (32,),
                    device=BACKWARD_HADAMARD_MATRIX.device,
                    dtype=BACKWARD_HADAMARD_MATRIX.dtype
                ) * 2 - 1
            )

        grad_output_hb = (
            grad_output.reshape(-1, 32) @ BACKWARD_HADAMARD_MATRIX
        ).view(ctx.batch, ctx.seq, ctx.out_dim)
        hft_weightt_hb = (
            weight_hf.T.reshape(-1, 32) @ BACKWARD_HADAMARD_MATRIX
        ).view(ctx.in_dim, ctx.out_dim)
        grad_input_hf = F.linear(grad_output_hb, hft_weightt_hb)
        grad_input = (
            grad_input_hf.reshape(-1, 32) @ FORWARD_HADAMARD_MATRIX.T
        ).view(ctx.batch, ctx.seq, ctx.in_dim)

        grad_outputt_hb = (
            grad_output.view(-1, ctx.out_dim).T.reshape(-1, 32) @ BACKWARD_HADAMARD_MATRIX
        ).view(ctx.out_dim, -1)
        hft_inputt_hb = (
            input_hf.view(-1, ctx.in_dim).T.reshape(-1, 32) @ BACKWARD_HADAMARD_MATRIX
        ).view(ctx.in_dim, -1)
        grad_weight_hf = F.linear(grad_outputt_hb, hft_inputt_hb)
        grad_weight = (
            grad_weight_hf.reshape(-1, 32) @ FORWARD_HADAMARD_MATRIX.T
        ).view(ctx.out_dim, ctx.in_dim)
        return grad_input, grad_weight, None


class MXFP4Gemm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, deterministic=True):
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]

        ctx.deterministic = deterministic

        input_hf, input_mask_hf = quantize_quest(
            input.view(-1, 32) @ FORWARD_HADAMARD_MATRIX
        )
        input_hf, input_mask_hf = input_hf.view_as(input), input_mask_hf.view_as(input)
        weight_hf, weight_mask_hf = quantize_quest(
            weight.view(-1, 32) @ FORWARD_HADAMARD_MATRIX
        )
        weight_hf, weight_mask_hf = weight_hf.view_as(weight), weight_mask_hf.view_as(weight)

        ctx.save_for_backward(input_hf, weight_hf, input_mask_hf, weight_mask_hf)
        return F.linear(input_hf, weight_hf)

    @staticmethod
    def backward(ctx, grad_output):
        global BACKWARD_HADAMARD_MATRIX
        input_hf, weight_hf, input_mask_hf, weight_mask_hf = ctx.saved_tensors

        if not ctx.deterministic:
            BACKWARD_HADAMARD_MATRIX = BACKWARD_HADAMARD_MATRIX @ torch.diag(
                torch.randint(
                    0, 2, (32,),
                    device=BACKWARD_HADAMARD_MATRIX.device,
                    dtype=BACKWARD_HADAMARD_MATRIX.dtype
                ) * 2 - 1
            )

        grad_output_hb = quantize_tseng(
            grad_output.view(-1, 32) @ BACKWARD_HADAMARD_MATRIX.T
        ).view(ctx.batch, ctx.seq, ctx.out_dim)
        hft_weightt_hb = quantize_tseng(
            weight_hf.T.reshape(-1, 32) @ BACKWARD_HADAMARD_MATRIX
        ).view(ctx.in_dim, ctx.out_dim)
        grad_input_hf = F.linear(grad_output_hb, hft_weightt_hb)
        grad_input = (
            (grad_input_hf.view(-1, 32) * input_mask_hf.view(-1, 32).to(grad_input_hf.dtype))
            @ FORWARD_HADAMARD_MATRIX.T
        ).view(ctx.batch, ctx.seq, ctx.in_dim)

        grad_outputt_hb = quantize_tseng(
            grad_output.view(-1, ctx.out_dim).T.reshape(-1, 32) @ BACKWARD_HADAMARD_MATRIX
        ).view(ctx.out_dim, -1)
        hft_inputt_hb = quantize_tseng(
            input_hf.view(-1, ctx.in_dim).T.reshape(-1, 32) @ BACKWARD_HADAMARD_MATRIX
        ).view(ctx.in_dim, -1)
        grad_weight_hf = F.linear(grad_outputt_hb, hft_inputt_hb)
        grad_weight = (
            (grad_weight_hf.view(-1, 32) * weight_mask_hf.view(-1, 32).to(grad_weight_hf.dtype))
            @ FORWARD_HADAMARD_MATRIX.T
        ).view(ctx.out_dim, ctx.in_dim)
        return grad_input, grad_weight, None

In [4]:
from random import randint

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 32 * 32}),
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def mxfp4_forward_kernel(
    x_ptr,
    hadamard_matrix_ptr,
    output_ptr,
    clip_mask_ptr,
    n_elements: tl.constexpr,
    hadamard_dim: tl.constexpr,
    group_size: tl.constexpr,
    seed: int,
    quest: tl.constexpr,
    stochastic_round: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim)
    hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim)

    # load x
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask)

    # hadamard transform
    x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_had = tl.dot(x, hadamard_matrix)

    # group
    x_had_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size))

    # scale
    if quest:
        mean_squared = tl.sum(x_had_grouped * x_had_grouped, axis=-1, keep_dims=True) / group_size
        mean = tl.sum(x_had_grouped, axis=-1, keep_dims=True) / group_size
        std = tl.sqrt(mean_squared - mean * mean)
        scales = (2.92247856 / 6.0) * std + 1e-8
        shared_exps = tl.exp2(tl.floor(tl.log2(scales)))
        x_had_scaled = x_had_grouped / shared_exps
    else:
        scales = tl.max(tl.abs(x_had_grouped), axis=-1, keep_dims=True)
        shared_exps = tl.exp2(tl.floor(tl.log2(scales)) - 2)
        x_had_scaled = x_had_grouped / shared_exps * (3/4) # 3/4 is constant. In CUDA, scale the GEMM output by 16/9

    # quantize
    x_had_scaled_abs = tl.abs(x_had_scaled)
    x_had_scaled_sign = tl.where(
        x_had_scaled > 0,
        1,
        -1,
    )
    if stochastic_round:
        x_fp4_high = tl.where(
            x_had_scaled_abs > 4,
            6,
            tl.where(
                x_had_scaled_abs > 3,
                4,
                tl.where(
                    x_had_scaled_abs > 2,
                    3,
                    tl.where(
                        x_had_scaled_abs > 1.5,
                        2,
                        tl.where(
                            x_had_scaled_abs > 1.0,
                            1.5,
                            tl.where(
                                x_had_scaled_abs > 0.5,
                                1,
                                0.5,
                            )
                        )
                    )
                )
            )
        )

        x_fp4_low = tl.where(
            x_had_scaled_abs > 4,
            4,
            tl.where(
                x_had_scaled_abs > 3,
                3,
                tl.where(
                    x_had_scaled_abs > 2,
                    2,
                    tl.where(
                        x_had_scaled_abs > 1.5,
                        1.5,
                        tl.where(
                            x_had_scaled_abs > 1.0,
                            1.0,
                            tl.where(
                                x_had_scaled_abs > 0.5,
                                0.5,
                                0.0,
                            )
                        )
                    )
                )
            )
        )

        prob_up = (x_had_scaled_abs - x_fp4_low) / (x_fp4_high - x_fp4_low)
        sampled_prob = tl.rand(seed, offsets).reshape(BLOCK_SIZE // hadamard_dim, hadamard_dim)
        x_fp4 = tl.where(
            sampled_prob < prob_up,
            x_fp4_high,
            x_fp4_low,
        ) * x_had_scaled_sign
    else:
        x_fp4 = tl.where(
            x_had_scaled_abs > 5,
            6,
            tl.where(
                x_had_scaled_abs > 3.5,
                4,
                tl.where(
                    x_had_scaled_abs > 2.5,
                    3,
                    tl.where(
                        x_had_scaled_abs > 1.75,
                        2,
                        tl.where(
                            x_had_scaled_abs > 1.25,
                            1.5,
                            tl.where(
                                x_had_scaled_abs > 0.75,
                                1,
                                tl.where(
                                    x_had_scaled_abs > 0.25,
                                    0.5,
                                    0,
                                )
                            )
                        )
                    )
                )
            )
        ) * x_had_scaled_sign


    # dequantize
    if quest:
        x_dequantized = x_fp4 * shared_exps
        tl.store(
            clip_mask_ptr + offsets,
            tl.reshape(x_had_scaled_abs < 6, (BLOCK_SIZE,)),
            mask=mask
        )
    else:
        x_dequantized = x_fp4 * shared_exps * (4/3) # 3/4 is constant. In CUDA, scale the GEMM output by 16/9

    # Reshape back to flat form for storage
    x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,))

    # store
    tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask)


def mxfp4_forward_kernel_wrapper(
    x,
    hadamard_matrix,
    stochastic_round=False,
    quest=True,
):
    # Make sure inputs are contiguous
    x = x.contiguous()

    # Create output tensor
    output = torch.empty_like(x)
    if quest:
        clip_mask = torch.empty_like(x, dtype=torch.bool)
    else:
        clip_mask = None

    # Get total number of elements and calculate grid for launching the kernel
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch optimized kernel
    mxfp4_forward_kernel[grid](
        x_ptr=x,
        hadamard_matrix_ptr=hadamard_matrix,
        output_ptr=output,
        clip_mask_ptr=clip_mask,
        n_elements=n_elements,
        hadamard_dim=hadamard_matrix.shape[-1],
        group_size=32,
        seed=42,
        quest=quest,
        stochastic_round=stochastic_round,
    )

    return output, clip_mask


class TritonGemm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, deterministic=True):
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]

        ctx.deterministic = deterministic

        input_hf, input_mask_hf = mxfp4_forward_kernel_wrapper(
            input,
            FORWARD_HADAMARD_MATRIX,
            stochastic_round=False,
            quest=True,
        )
        weight_hf, weight_mask_hf = mxfp4_forward_kernel_wrapper(
            weight,
            FORWARD_HADAMARD_MATRIX,
            stochastic_round=False,
            quest=True,
        )

        ctx.save_for_backward(input_hf, weight_hf, input_mask_hf, weight_mask_hf)
        return F.linear(input_hf, weight_hf)

    @staticmethod
    def backward(ctx, grad_output):
        global BACKWARD_HADAMARD_MATRIX
        input_hf, weight_hf, input_mask_hf, weight_mask_hf = ctx.saved_tensors

        if not ctx.deterministic:
            BACKWARD_HADAMARD_MATRIX = BACKWARD_HADAMARD_MATRIX @ torch.diag(
                torch.randint(
                    0, 2, (32,),
                    device=BACKWARD_HADAMARD_MATRIX.device,
                    dtype=BACKWARD_HADAMARD_MATRIX.dtype
                ) * 2 - 1
            )

        grad_output_hb, _ = mxfp4_forward_kernel_wrapper(
            grad_output,
            BACKWARD_HADAMARD_MATRIX,
            stochastic_round=not ctx.deterministic,
            quest=False,
        )
        hft_weightt_hb, _ = mxfp4_forward_kernel_wrapper(
            weight_hf.T,
            BACKWARD_HADAMARD_MATRIX,
            stochastic_round=not ctx.deterministic,
            quest=False,
        )
        grad_input_hf = F.linear(grad_output_hb, hft_weightt_hb)
        grad_input = (
            (grad_input_hf.view(-1, 32) * input_mask_hf.view(-1, 32).to(grad_input_hf.dtype))
            @ FORWARD_HADAMARD_MATRIX.T
        ).view(ctx.batch, ctx.seq, ctx.in_dim)

        grad_outputt_hb, _ = mxfp4_forward_kernel_wrapper(
            grad_output.view(-1, grad_output.size(-1)).T,
            BACKWARD_HADAMARD_MATRIX,
            stochastic_round=not ctx.deterministic,
            quest=False,
        )
        hft_inputt_hb, _ = mxfp4_forward_kernel_wrapper(
            input_hf.view(-1, ctx.in_dim).T,
            BACKWARD_HADAMARD_MATRIX,
            stochastic_round=not ctx.deterministic,
            quest=False,
        )
        grad_weight_hf = F.linear(grad_outputt_hb, hft_inputt_hb)
        grad_weight = (
            (grad_weight_hf.view(-1, 32) * weight_mask_hf.view(-1, 32).to(grad_weight_hf.dtype))
            @ FORWARD_HADAMARD_MATRIX.T
        ).view(ctx.out_dim, ctx.in_dim)
        return grad_input, grad_weight, None

In [5]:
DETERMINISTIC_FOR_TESTS = True

x = torch.randn(1, 128, 4096, device="cuda", dtype=DTYPE, requires_grad=True)
w = torch.randn(128, 4096, device="cuda", dtype=DTYPE, requires_grad=True)


y = F.linear(x, w)
y_grad = torch.randn_like(y)
y.backward(y_grad)

grad = w.grad.clone()
w.grad = None

y_had = HadamardGemm.apply(x, w, DETERMINISTIC_FOR_TESTS)
y_had.backward(y_grad)
y_had_grad = w.grad.clone()
w.grad = None

y_fp4 = MXFP4Gemm.apply(x, w, DETERMINISTIC_FOR_TESTS)
y_fp4.backward(y_grad)
y_fp4_grad = w.grad.clone()
w.grad = None

y_triton = TritonGemm.apply(x, w, DETERMINISTIC_FOR_TESTS)
y_triton.backward(y_grad)
y_triton_grad = w.grad.clone()
w.grad = None

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


In [6]:
had_l2_error = (torch.linalg.norm(y - y_had) / torch.linalg.norm(y)).pow(2).detach().item()
fp4_l2_error = (torch.linalg.norm(y - y_fp4) / torch.linalg.norm(y)).pow(2).detach().item()
triton_l2_discrepancy = (torch.linalg.norm(y_fp4 - y_triton) / torch.linalg.norm(y)).pow(2).detach().item()

print(f"Hadamard L2 error: {had_l2_error:.1e}")
print(f"FP4 L2 error: {fp4_l2_error:.1e}")
print(f"Triton L2 discrepancy: {triton_l2_discrepancy:.1e}")
assert had_l2_error < 1e-4
assert 2e-2 < fp4_l2_error < 6e-2
assert triton_l2_discrepancy < fp4_l2_error / 10

had_grad_l2_error = (torch.linalg.norm(grad - y_had_grad) / torch.linalg.norm(grad)).pow(2).detach().item()
fp4_grad_l2_error = (torch.linalg.norm(grad - y_fp4_grad) / torch.linalg.norm(grad)).pow(2).detach().item()
triton_grad_l2_discrepancy = (torch.linalg.norm(y_fp4_grad - y_triton_grad) / torch.linalg.norm(grad)).pow(2).detach().item()

print(f"Hadamard grad L2 error: {had_grad_l2_error:.1e}")
print(f"FP4 grad L2 error: {fp4_grad_l2_error:.1e}")
print(f"Triton grad L2 discrepancy: {triton_grad_l2_discrepancy:.1e}")

assert had_grad_l2_error < 1e-4
assert 6e-2 < fp4_grad_l2_error < 15e-2
assert triton_grad_l2_discrepancy < fp4_grad_l2_error / 10

Hadamard L2 error: 1.5e-05
FP4 L2 error: 5.6e-02
Triton L2 discrepancy: 1.6e-03
Hadamard grad L2 error: 1.6e-05
FP4 grad L2 error: 1.2e-01
Triton grad L2 discrepancy: 4.5e-03


In [7]:
gemm_fns = {
    "baseline": F.linear,
    "+hadamard": HadamardGemm.apply,
    "+mxfp4": MXFP4Gemm.apply,
    "+triton": TritonGemm.apply,
}

gemm_compiled_fns = {
    k: torch.compile(v) for k, v in gemm_fns.items()
}


def benchmark_gpu(fn, input_size, weight_size, num_iterations=100, warmup=10):
    """Benchmark a function on GPU"""
    input = torch.randn(*input_size, device="cuda", dtype=DTYPE, requires_grad=True)
    weight = torch.randn(*weight_size, device="cuda", dtype=DTYPE, requires_grad=True)

    # # Warmup
    # for _ in range(warmup):
    #     result = fn(input, weight)

    # torch.cuda.synchronize()

    # # Measure forward pass
    # start_time = time.time()
    # for _ in range(num_iterations):
    #     result = fn(input, weight)
    #     torch.cuda.synchronize()
    # forward_time = (time.time() - start_time) / num_iterations

    # # Warmup
    # grad = torch.randn_like(result)
    # for _ in range(warmup):
    #     result.backward(grad, retain_graph=True)

    # # Measure backward pass
    # start_time = time.time()
    # for _ in range(num_iterations):
    #     result = fn(input, weight)
    #     result.backward(grad, retain_graph=True)
    #     torch.cuda.synchronize()
    # backward_time = (time.time() - start_time) / num_iterations - forward_time

    # return {
    #     "forward_ms": forward_time * 1000,
    #     "backward_ms": backward_time * 1000,
    #     "total_ms": (forward_time + backward_time) * 1000
    # }
    
    result = fn(input, weight)
    grad = torch.randn_like(result)
    
    forward_time = triton.testing.do_bench(
        lambda: fn(input, weight), warmup=warmup, rep=num_iterations,
    )
    backward_time = triton.testing.do_bench(
        lambda: result.backward(grad, retain_graph=True), warmup=warmup, rep=num_iterations,
    )
    
    return {
        "forward_ms": forward_time,
        "backward_ms": backward_time,
        "total_ms": (forward_time + backward_time)
    }


def run_gpu_benchmarks(batch_size=64, seq_len=512, hidden_size=1024):
    """Run benchmarks for different GEMM implementations on GPU"""
    input_size = (batch_size, seq_len, hidden_size)
    weight_size = (hidden_size, hidden_size)

    results = {}
    for name, fn in gemm_compiled_fns.items():
        print(f"Benchmarking {name}...")
        results[name] = benchmark_gpu(fn, input_size, weight_size)

    # Print results
    print("\nGPU Benchmark Results (ms):")
    print(f"{'Method':<15} {'Forward':<10} {'Backward':<10} {'Total':<10}")
    print("-" * 45)
    for name, timings in results.items():
        forward = f"{timings['forward_ms']:.2f}"
        backward = f"{timings['backward_ms']:.2f}" if timings['backward_ms'] is not None else "N/A"
        total = f"{timings['total_ms']:.2f}"
        print(f"{name:<15} {forward:<10} {backward:<10} {total:<10}")

    return results

In [8]:
_ = run_gpu_benchmarks(hidden_size=4096)

Benchmarking baseline...
Benchmarking +hadamard...
Benchmarking +mxfp4...
Benchmarking +triton...

GPU Benchmark Results (ms):
Method          Forward    Backward   Total     
---------------------------------------------
baseline        5.06       11.43      16.50     
+hadamard       6.33       13.65      19.98     
+mxfp4          7.34       15.04      22.39     
+triton         7.34       14.17      21.51     


# CUDA Kernels Needed

---

## Forward HT+Quant:

 * **Inputs**: 
    * `x`: torch.Tensor; shape=[M, K]; dtype=float32; already contiguous.
    * `hadamard_matrix`: torch.Tensor; shape=[32, 32]; dtype=float32.
 * **Outputs**:
    * `q`: torch.Tensor; shape=[M, K]; dtype=MXFP4 with scales along K;
    * `mask`: torch.Tensor; shape=[M, K]; dtype=bool;
 * **Reference Impl**:
    * Triton impl up to pseudoquant:
        ```
        q = mxfp4_forward_kernel_wrapper(
            x,
            hadamard_matrix,
            stochastic_round=False,
            quest=True,
        )
        ```
 * **Features**:
    * Assume `M,K` divisible by 128.
    * RTN projection.
    * Scales based on STD.


---
    
## Backward HT+Quant:

 * **Inputs**: 
    * `x`: torch.Tensor; shape=[M, K]; dtype=float32; already contiguous.
    * `hadamard_matrix`: torch.Tensor; shape=[32, 32]; dtype=float32.
 * **Outputs**:
    * `q`: torch.Tensor; shape=[M, K]; dtype=MXFP4 with scales along K;
 * **Reference Impl**:
    * Triton impl up to pseudoquant:
        ```
        q = mxfp4_forward_kernel_wrapper(
            x,
            hadamard_matrix,
            stochastic_round=True,
            quest=False,
        )
        ```
 * **Features**:
    * Assume `M,K` divisible by 128.
    * Stochastic Rounding.
    * Scales based on absmax.

---

## Backward Transpose+HT+Quant:

 * **Inputs**: 
    * `x`: torch.Tensor; shape=[M, K]; dtype=float32; already contiguous.
    * `hadamard_matrix`: torch.Tensor; shape=[32, 32]; dtype=float32.
 * **Outputs**:
    * `q`: torch.Tensor; shape=[K, M]; dtype=MXFP4 with scales along M;
 * **Reference Impl**:
    * Triton impl up to pseudoquant:
        ```
        q = mxfp4_forward_kernel_wrapper(
            x.T,
            hadamard_matrix,
            stochastic_round=True,
            quest=False,
        )
        ```
 * **Features**:
    * Assume `M,K` divisible by 128.
    * Stochastic Rounding.
    * Scales based on absmax.

---
    
## Backward Dequant+Transpose+HT+Quant:

 * **Inputs**:
    * `q`: torch.Tensor; shape=[M, K]; dtype=MXFP4 with scales along K;
    * `hadamard_matrix`: torch.Tensor; shape=[32, 32]; dtype=float32.
 * **Outputs**:
    * `qq`: torch.Tensor; shape=[K, M]; dtype=MXFP4 with scales along M;
 * **Reference Impl**:
    * Triton impl up to pseudoquant:
        ```
        qq = mxfp4_forward_kernel_wrapper(
            q.T,
            hadamard_matrix,
            stochastic_round=True,
            quest=False,
        )
        ```
 * **Features**:
    * Assume `M,K` divisible by 128.
    * Stochastic Rounding.
    * Scales based on absmax.

In [9]:
from tqdm import tqdm

In [167]:
from qutlass import matmul_mxf4_bf16_tn

@torch.library.custom_op("quartet::matmul_mxf4_bf16_tn_op", mutates_args=())
def matmul_mxf4_bf16_tn_op(
    x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: torch.Tensor
) -> torch.Tensor:
    return matmul_mxf4_bf16_tn(
        x.view(torch.uint8), w.view(torch.uint8), xs.view(torch.float8_e8m0fnu), ws.view(torch.float8_e8m0fnu), alpha
    )

@matmul_mxf4_bf16_tn_op.register_fake
def _(x, w, xs, ws, alpha):
    return x.new_empty(x.shape[0], w.shape[0], dtype=DTYPE)


from qutlass import fusedQuantizeMx

@torch.library.custom_op("quartet::fusedQuantizeMx_op", mutates_args=())
def fusedQuantizeMx_op(
    x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, return_mask: bool
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if return_mask:
        return fusedQuantizeMx(x_flat, hadamard_matrix, return_mask=True)
    else:
        return fusedQuantizeMx(x_flat, hadamard_matrix, return_mask=False) + (None,)

@fusedQuantizeMx_op.register_fake
def _(x_flat, hadamard_matrix, return_mask):
    rows, cols = x_flat.shape[0], x_flat.shape[1] // 32
    padded_rows = ((rows + 128 - 1) // 128) * 128
    padded_cols = ((cols + 4 - 1) // 4) * 4

    xh_e2m1 = torch.empty(
        x_flat.shape[0], x_flat.shape[1] // 2, dtype=torch.uint8, device=x_flat.device
    )
    xh_e8m0 = torch.empty(
        padded_rows, padded_cols, dtype=torch.uint8, device=x_flat.device
    )
    clip_mask = torch.empty(*x_flat.shape[:-1], x_flat.size(-1) // 8,  dtype=torch.uint8, device=x_flat.device) if return_mask else None
    return xh_e2m1, xh_e8m0, clip_mask


from qutlass import backward_t_bf16

@torch.library.custom_op("quartet::backward_t_bf16_op", mutates_args=())
def backward_t_bf16_op(
    grad_output_flat: torch.Tensor, hadamard_matrix: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    return backward_t_bf16(grad_output_flat, hadamard_matrix)

@backward_t_bf16_op.register_fake
def _(grad_output_flat, hadamard_matrix):
    xh_e2m1 = torch.empty(grad_output_flat.shape[1], grad_output_flat.shape[0] // 2,  dtype=torch.uint8, device=grad_output_flat.device)
    xh_e8m0 = torch.empty(grad_output_flat.shape[1], grad_output_flat.shape[0] // 32, dtype=torch.uint8, device=grad_output_flat.device)

    return xh_e2m1, xh_e8m0


from qutlass import backward_qt_bf16

@torch.library.custom_op("quartet::backward_qt_bf16_op", mutates_args=())
def backward_qt_bf16_op(
    x_e2m1: torch.Tensor,
    x_e8m0: torch.Tensor,
    h: torch.Tensor,
    alpha: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert x_e2m1.dim() == 2
    return backward_qt_bf16(x_e2m1, x_e8m0, h, alpha)

@backward_qt_bf16_op.register_fake
def _(x_e2m1, x_e8m0, h, alpha):
    assert x_e2m1.dim() == 2
    xh_e2m1 = torch.empty(x_e2m1.shape[1] * 2, x_e2m1.shape[0] // 2, dtype=torch.uint8, device=h.device)
    xh_e8m0 = torch.empty(x_e8m0.shape[1] * 32, x_e8m0.shape[0] // 32, dtype=torch.uint8, device=h.device)
    return xh_e2m1, xh_e8m0


from qutlass import matmul_mxf8_bf16_tn

@torch.library.custom_op("quartet::matmul_mxf8_bf16_tn_op", mutates_args=())
def matmul_mxf8_bf16_tn_op(
    x: torch.Tensor, w: torch.Tensor, xs: torch.Tensor, ws: torch.Tensor, alpha: torch.Tensor
) -> torch.Tensor:
    return matmul_mxf8_bf16_tn(
        x, w, xs.view(torch.float8_e8m0fnu), ws.view(torch.float8_e8m0fnu), alpha
    )

@matmul_mxf8_bf16_tn_op.register_fake
def _(x, w, xs, ws, alpha):
    return x.new_empty(x.shape[0], w.shape[0], dtype=DTYPE)


from qutlass.utils import to_blocked

def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor:
    clip_mask_unpacked_dq = torch.zeros(*clip_mask.shape[:-1], clip_mask.size(-1) * 8, dtype=torch.bool, device=clip_mask.device)
    for i in range(8):
        clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1
    return clip_mask_unpacked_dq

In [177]:
ALPHA_FWD = torch.tensor(1., device="cuda")
ALPHA_BWD = torch.tensor(1./9., device="cuda")

class CudaGemm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, deterministic=True):
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]

        ctx.deterministic = deterministic

        input_hf_e2m1, input_hf_e8m0, input_hf_mask = fusedQuantizeMx_op(
            input.flatten(end_dim=-2),
            FORWARD_HADAMARD_MATRIX,
            return_mask=input.requires_grad,
        )

        weight_hf_e2m1, weight_hf_e8m0, weight_hf_mask = fusedQuantizeMx_op(
            weight,
            FORWARD_HADAMARD_MATRIX,
            return_mask=input.requires_grad,
        )

        ctx.save_for_backward(input_hf_e2m1, input_hf_e8m0, input_hf_mask, weight_hf_e2m1, weight_hf_e8m0, weight_hf_mask)

        input_hf_scale_block = to_blocked(input_hf_e8m0, False)
        weight_hf_scale_block = to_blocked(weight_hf_e8m0, False)

        out = matmul_mxf4_bf16_tn_op(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            ALPHA_FWD,
        )
        return out.view(*input.shape[:-1], weight.size(-2))

    @staticmethod
    def backward(ctx, grad_output):
        global BACKWARD_HADAMARD_MATRIX
        input_hf_e2m1, input_hf_e8m0, input_hf_mask, weight_hf_e2m1, weight_hf_e8m0, weight_hf_mask = ctx.saved_tensors

        if not ctx.deterministic:
            BACKWARD_HADAMARD_MATRIX = BACKWARD_HADAMARD_MATRIX * (
                torch.randint(0, 2, (32,), device=BACKWARD_HADAMARD_MATRIX.device, dtype=BACKWARD_HADAMARD_MATRIX.dtype)
                * 2. - 1.
            )

        grad_output_hb_e2m1, grad_output_hb_e8m0, _ = fusedQuantizeMx_op(
            grad_output.flatten(end_dim=-2),
            BACKWARD_HADAMARD_MATRIX,
            False,
        )

        hft_weightt_hb_e2m1, hft_weightt_hb_e8m0 = backward_qt_bf16_op(weight_hf_e2m1, weight_hf_e8m0, BACKWARD_HADAMARD_MATRIX, ALPHA_FWD)
        grad_output_hb_scale_block = to_blocked(grad_output_hb_e8m0, False)
        hft_weightt_hb_scale_block = to_blocked(hft_weightt_hb_e8m0, False)
        grad_input_hf = matmul_mxf4_bf16_tn_op(
            grad_output_hb_e2m1,
            hft_weightt_hb_e2m1,
            grad_output_hb_scale_block,
            hft_weightt_hb_scale_block,
            ALPHA_BWD,
        )

        input_mask_hf = _unpack_mask(input_hf_mask)
        grad_input = (
            (grad_input_hf.view(-1, 32) * input_mask_hf.view(-1, 32).to(grad_input_hf.dtype))
            @ FORWARD_HADAMARD_MATRIX.T
        ).view(*grad_output.shape[:-1], weight_hf_e2m1.size(-1) * 2)

        grad_outputt_hb_e2m1, grad_outputt_hb_e8m0 = backward_t_bf16_op(grad_output.flatten(end_dim=-2), BACKWARD_HADAMARD_MATRIX)
        hft_inputt_hb_e2m1, hft_inputt_hb_e8m0 = backward_qt_bf16_op(input_hf_e2m1, input_hf_e8m0, BACKWARD_HADAMARD_MATRIX, ALPHA_FWD)
        grad_outputt_hb_scale_block = to_blocked(grad_outputt_hb_e8m0, False)
        hft_inputt_hb_scale_block = to_blocked(hft_inputt_hb_e8m0, False)
        grad_weight_hf = matmul_mxf4_bf16_tn_op(
            grad_outputt_hb_e2m1,
            hft_inputt_hb_e2m1,
            grad_outputt_hb_scale_block,
            hft_inputt_hb_scale_block,
            ALPHA_BWD,
        )
        
        # torch._assert(grad_weight_hf.shape == (weight_hf_e2m1.size(0), weight_hf_e2m1.size(1) * 2), f"{grad_outputt_hb_e2m1.shape=} {hft_inputt_hb_e2m1.shape=} {grad_weight_hf.shape=} {weight_hf_e2m1.shape=}")

        weight_mask_hf = _unpack_mask(weight_hf_mask)
        grad_weight = (
            (grad_weight_hf.view(-1, 32) * weight_mask_hf.view(-1, 32).to(grad_weight_hf.dtype))
            @ FORWARD_HADAMARD_MATRIX.T
        ).view(grad_output.size(-1), weight_hf_e2m1.size(-1) * 2)
        return grad_input, grad_weight, None

In [178]:
DUMMY_E8M0 = torch.ones(2 ** 30, dtype=torch.float8_e8m0fnu, device="cuda").view(torch.uint8)

class Fp8Gemm(torch.autograd.Function):
    @staticmethod
    def get_dummy_e8m0(x_e4m3: torch.Tensor) -> torch.Tensor:
        x_e8m0 = DUMMY_E8M0[:x_e4m3.numel() // 32].view(*x_e4m3.shape[:-1], x_e4m3.size(-1) // 32)
        return x_e8m0

    @staticmethod
    def mm_fp8(a_e4m3: torch.Tensor, b_e4m3: torch.Tensor) -> torch.Tensor:
        c_bf16 = matmul_mxf8_bf16_tn_op(a_e4m3, b_e4m3, Fp8Gemm.get_dummy_e8m0(a_e4m3), Fp8Gemm.get_dummy_e8m0(b_e4m3), ALPHA_FWD)
        return c_bf16

    @staticmethod
    def forward(ctx, input, weight):
        input_e4m3 = input.flatten(end_dim=-2).to(dtype=torch.float8_e4m3fn)
        weight_e4m3 = weight.to(dtype=torch.float8_e4m3fn)
        ctx.save_for_backward(input_e4m3, weight_e4m3)

        return Fp8Gemm.mm_fp8(
            input_e4m3,
            weight_e4m3,
        ).view(*input.shape[:-1], weight.size(-2))

    @staticmethod
    def backward(ctx, grad_output):
        input_e4m3, weight_e4m3 = ctx.saved_tensors

        grad_output_e4m3 = grad_output.flatten(end_dim=-2).to(dtype=torch.float8_e4m3fn)
        
        grad_input = Fp8Gemm.mm_fp8(
            grad_output_e4m3,
            weight_e4m3.T.contiguous(),
        ).view(*grad_output.shape[:-1], weight_e4m3.size(-1))

        grad_outputt_e4m3 = grad_output.flatten(end_dim=-2).to(dtype=torch.float8_e4m3fn)
        grad_weight = Fp8Gemm.mm_fp8(
            grad_outputt_e4m3.T.contiguous(),
            input_e4m3.T.contiguous(),
        ).view(grad_output.size(-1), weight_e4m3.size(-1))

        return grad_input, grad_weight

In [179]:
torch.set_grad_enabled(True)

y_cuda = CudaGemm.apply(x, w, DETERMINISTIC_FOR_TESTS)
y_cuda.backward(y_grad)
y_cuda_grad = w.grad.clone()
w.grad = None

In [180]:
y_fp8 = Fp8Gemm.apply(x, w)
y_fp8.backward(y_grad)
y_fp8_grad = w.grad.clone()
w.grad = None

In [181]:
had_l2_error = (torch.linalg.norm(y - y_had) / torch.linalg.norm(y)).pow(2).detach().item()
fp4_l2_error = (torch.linalg.norm(y - y_fp4) / torch.linalg.norm(y)).pow(2).detach().item()
fp8_l2_error = (torch.linalg.norm(y - y_fp8) / torch.linalg.norm(y)).pow(2).detach().item()
triton_l2_discrepancy = (torch.linalg.norm(y_fp4 - y_triton) / torch.linalg.norm(y)).pow(2).detach().item()
cuda_l2_discrepancy = (torch.linalg.norm(y_fp4 - y_cuda) / torch.linalg.norm(y)).pow(2).detach().item()

print(f"Hadamard L2 error: {had_l2_error:.1e}")
print(f"FP4 L2 error: {fp4_l2_error:.1e}")
print(f"FP8 L2 error: {fp8_l2_error:.1e}")
print(f"Triton L2 discrepancy: {triton_l2_discrepancy:.1e}")
print(f"Cuda L2 discrepancy: {cuda_l2_discrepancy:.1e}")
assert had_l2_error < 1e-4
assert fp8_l2_error < 2e-2 < fp4_l2_error < 6e-2
assert triton_l2_discrepancy < fp4_l2_error / 10
assert cuda_l2_discrepancy < fp4_l2_error / 10

had_grad_l2_error = (torch.linalg.norm(grad - y_had_grad) / torch.linalg.norm(grad)).pow(2).detach().item()
fp4_grad_l2_error = (torch.linalg.norm(grad - y_fp4_grad) / torch.linalg.norm(grad)).pow(2).detach().item()
fp8_grad_l2_error = (torch.linalg.norm(grad - y_fp8_grad) / torch.linalg.norm(grad)).pow(2).detach().item()
triton_grad_l2_discrepancy = (torch.linalg.norm(y_fp4_grad - y_triton_grad) / torch.linalg.norm(grad)).pow(2).detach().item()
cuda_grad_l2_discrepancy = (torch.linalg.norm(y_fp4_grad - y_cuda_grad) / torch.linalg.norm(grad)).pow(2).detach().item()

print(f"Hadamard grad L2 error: {had_grad_l2_error:.1e}")
print(f"FP4 grad L2 error: {fp4_grad_l2_error:.1e}")
print(f"FP8 grad L2 error: {fp8_grad_l2_error:.1e}")
print(f"Triton grad L2 discrepancy: {triton_grad_l2_discrepancy:.1e}")
print(f"Cuda grad L2 discrepancy: {cuda_grad_l2_discrepancy:.1e}")

assert had_grad_l2_error < 1e-4
assert fp8_grad_l2_error < 6e-2 < fp4_grad_l2_error < 15e-2
assert triton_grad_l2_discrepancy < fp4_grad_l2_error / 10
assert cuda_grad_l2_discrepancy < fp4_grad_l2_error / 10

Hadamard L2 error: 1.5e-05
FP4 L2 error: 5.6e-02
FP8 L2 error: 1.4e-03
Triton L2 discrepancy: 1.6e-03
Cuda L2 discrepancy: 1.6e-03
Hadamard grad L2 error: 1.6e-05
FP4 grad L2 error: 1.2e-01
FP8 grad L2 error: 1.4e-03
Triton grad L2 discrepancy: 4.5e-03
Cuda grad L2 discrepancy: 4.5e-03


In [182]:
gemm_fns = {
    "baseline": F.linear,
    # "+hadamard": HadamardGemm.apply,
    # "+mxfp4": MXFP4Gemm.apply,
    # "+triton": TritonGemm.apply,
    "+cuda": CudaGemm.apply,
    "+fp8": Fp8Gemm.apply,
}

torch._dynamo.config.compiled_autograd = True
compile_kwargs = {"fullgraph": True}


def benchmark_gpu(fn, input_size, weight_size, num_iterations=100, warmup=10):
    """Benchmark a function on GPU"""
    input = torch.randn(*input_size, device="cuda", dtype=DTYPE, requires_grad=True)
    weight = torch.randn(*weight_size, device="cuda", dtype=DTYPE, requires_grad=True)


    # Forward
    torch.set_grad_enabled(False)
    
    compiled_forward_fn = torch.compile(fn, **compile_kwargs)
    
    ms = triton.testing.do_bench(
        lambda: compiled_forward_fn(input, weight), warmup=warmup, rep=num_iterations,
    )
    forward_time = ms

    # Forward+Backward
    grad = torch.randn_like(fn(input, weight))
    torch.set_grad_enabled(True)    
    
    def compiled_forward_backward(input, weight, grad):
        with torch._dynamo.compiled_autograd._enable(torch.compile(**compile_kwargs)):
        # with torch._dynamo.utils.maybe_enable_compiled_autograd(True, dynamic=False, **compile_kwargs):
            output = fn(input, weight)
            output.backward(grad)
    
    ms = triton.testing.do_bench(
        lambda: compiled_forward_backward(input, weight, grad), warmup=warmup, rep=num_iterations,
    )
    total_time = ms

    return {
        "forward_ms": forward_time,
        "total_ms": total_time,
    }


def run_gpu_benchmarks(batch_size=64, seq_len=512, hidden_size=1024):
    """Run benchmarks for different GEMM implementations on GPU"""
    input_size = (batch_size, seq_len, hidden_size)
    weight_size = (hidden_size, hidden_size)

    results = {}
    for name, fn in gemm_fns.items():
        print(f"Benchmarking {name}...")
        results[name] = benchmark_gpu(fn, input_size, weight_size)

    # Print results
    print("\nGPU Benchmark Results (ms):")
    print(f"{'Method':<15} {'Forward':<10} {'Total':<10}")
    print("-" * 45)
    for name, timings in results.items():
        forward = f"{timings['forward_ms']:.2f}"
        total = f"{timings['total_ms']:.2f}"
        print(f"{name:<15} {forward:<10} {total:<10}")

    return results

In [183]:
_ = run_gpu_benchmarks(hidden_size=1024)

Benchmarking baseline...
Benchmarking +cuda...
Benchmarking +fp8...

GPU Benchmark Results (ms):
Method          Forward    Total     
---------------------------------------------
baseline        0.38       1.25      
+cuda           0.15       1.41      
+fp8            0.21       0.89      


In [184]:
# shapes = {
#     # Q K V Down Up Gate Down
#     "100M": [(1024 * 3, 1024), (1024, 1024), (2816 * 2, 1024), (1024, 2816)],
#     "800M": [(2048 * 3, 2048), (2048, 2048), (5632 * 2, 2048), (2048, 5632)],
#     "3B": [(3072 * 3, 3072), (3072, 3072), (8192 * 2, 3072), (3072, 8192)],
#     "7B": [(4096 * 3, 4096), (4096, 4096), (11008 * 2, 4096), (4096, 11008)],
#     "22B": [(6144 * 3, 6144), (6144, 6144), (16384 * 2, 6144), (6144, 16384)],
#     "52B": [(8192 * 3, 8192), (8192, 8192), (22016 * 2, 8192), (8192, 22016)],
# }

In [185]:
from itertools import chain

def from_num_head(n_head):
    h = n_head * 128
    inter = h * 8 / 3
    inter = int((inter - 1) // 256) * 256 + 256
    
    shapes = [(h * 3, h), (h, h), (inter * 2, h), (h, inter)]
    assert sum(map(lambda x: x[0] * x[1], shapes)) == 4 * h**2 + 3 * h * inter
    
    return (4 * h**2 + 3 * h * inter) * n_head, shapes


shapes = {
    n: shapes for (n, shapes) in map(from_num_head, chain(range(8, 13), range(16, 65, 4)))
}

In [186]:
RESULTS = {}

def run_gpu_benchmarks_layer(batch_size=64, seq_len=512):
    all_shapes = set()
    for shape_list in shapes.values():
        all_shapes.update(shape_list)
    all_shapes = sorted(all_shapes)

    for name in ['baseline', '+cuda', '+fp8']:
        print(f"Benchmarking {name}...")
        fn = gemm_fns[name]
        if name not in RESULTS:
            RESULTS[name] = {}
        for weight_size in tqdm(all_shapes):
            input_size = (batch_size, seq_len, weight_size[1])
            if weight_size not in RESULTS[name]:
                print(f"Benchmarking {name} {weight_size}...")
                RESULTS[name][weight_size] = benchmark_gpu(fn, input_size, weight_size)
            else:
                print(f"Skipping {name} {weight_size} because it already exists")

In [187]:
run_gpu_benchmarks_layer()

Benchmarking baseline...


  0%|          | 0/72 [00:00<?, ?it/s]

Benchmarking baseline (1024, 1024)...


  1%|▏         | 1/72 [00:00<00:16,  4.39it/s]

Benchmarking baseline (1024, 2816)...


  3%|▎         | 2/72 [00:02<01:39,  1.42s/it]

Benchmarking baseline (1152, 1152)...


  4%|▍         | 3/72 [00:04<01:59,  1.74s/it]

Benchmarking baseline (1152, 3072)...


  6%|▌         | 4/72 [00:04<01:18,  1.15s/it]

Benchmarking baseline (1280, 1280)...


  7%|▋         | 5/72 [00:05<00:54,  1.22it/s]

Benchmarking baseline (1280, 3584)...


  8%|▊         | 6/72 [00:05<00:41,  1.60it/s]

Benchmarking baseline (1408, 1408)...


 10%|▉         | 7/72 [00:05<00:32,  2.01it/s]

Benchmarking baseline (1408, 3840)...


 11%|█         | 8/72 [00:05<00:26,  2.37it/s]

Benchmarking baseline (1536, 1536)...


 12%|█▎        | 9/72 [00:06<00:22,  2.75it/s]

Benchmarking baseline (1536, 4096)...


 14%|█▍        | 10/72 [00:06<00:20,  3.00it/s]

Benchmarking baseline (2048, 2048)...


 15%|█▌        | 11/72 [00:06<00:18,  3.24it/s]

Benchmarking baseline (2048, 5632)...


 17%|█▋        | 12/72 [00:06<00:18,  3.28it/s]

Benchmarking baseline (2560, 2560)...


 18%|█▊        | 13/72 [00:07<00:17,  3.42it/s]

Benchmarking baseline (2560, 6912)...


 19%|█▉        | 14/72 [00:07<00:17,  3.23it/s]

Benchmarking baseline (3072, 1024)...


 21%|██        | 15/72 [00:07<00:16,  3.45it/s]

Benchmarking baseline (3072, 3072)...


 22%|██▏       | 16/72 [00:08<00:16,  3.46it/s]

Benchmarking baseline (3072, 8192)...


 24%|██▎       | 17/72 [00:08<00:17,  3.10it/s]

Benchmarking baseline (3456, 1152)...


 25%|██▌       | 18/72 [00:08<00:16,  3.34it/s]

Benchmarking baseline (3584, 3584)...


 26%|██▋       | 19/72 [00:08<00:15,  3.31it/s]

Benchmarking baseline (3584, 9728)...


 28%|██▊       | 20/72 [00:09<00:18,  2.80it/s]

Benchmarking baseline (3840, 1280)...


 29%|██▉       | 21/72 [00:09<00:16,  3.07it/s]

Benchmarking baseline (4096, 4096)...


 31%|███       | 22/72 [00:10<00:16,  3.03it/s]

Benchmarking baseline (4096, 11008)...


 32%|███▏      | 23/72 [00:10<00:19,  2.53it/s]

Benchmarking baseline (4224, 1408)...


 33%|███▎      | 24/72 [00:10<00:17,  2.81it/s]

Benchmarking baseline (4608, 1536)...


 35%|███▍      | 25/72 [00:11<00:15,  3.05it/s]

Benchmarking baseline (4608, 4608)...


 36%|███▌      | 26/72 [00:11<00:15,  2.94it/s]

Benchmarking baseline (4608, 12288)...


 38%|███▊      | 27/72 [00:12<00:19,  2.28it/s]

Benchmarking baseline (5120, 5120)...


 39%|███▉      | 28/72 [00:12<00:18,  2.33it/s]

Benchmarking baseline (5120, 13824)...


 40%|████      | 29/72 [00:13<00:22,  1.91it/s]

Benchmarking baseline (5632, 1024)...


 42%|████▏     | 30/72 [00:13<00:18,  2.25it/s]

Benchmarking baseline (5632, 5632)...


 43%|████▎     | 31/72 [00:14<00:18,  2.25it/s]

Benchmarking baseline (5632, 15104)...


 44%|████▍     | 32/72 [00:14<00:22,  1.75it/s]

Benchmarking baseline (6144, 1152)...


 46%|████▌     | 33/72 [00:15<00:18,  2.09it/s]

Benchmarking baseline (6144, 2048)...


 47%|████▋     | 34/72 [00:15<00:16,  2.34it/s]

Benchmarking baseline (6144, 6144)...


 49%|████▊     | 35/72 [00:15<00:16,  2.23it/s]

Benchmarking baseline (6144, 16384)...


 50%|█████     | 36/72 [00:16<00:22,  1.63it/s]

Benchmarking baseline (6656, 6656)...


 51%|█████▏    | 37/72 [00:17<00:20,  1.69it/s]

Benchmarking baseline (6656, 17920)...


 53%|█████▎    | 38/72 [00:18<00:26,  1.30it/s]

Benchmarking baseline (7168, 1280)...


 54%|█████▍    | 39/72 [00:18<00:20,  1.60it/s]

Benchmarking baseline (7168, 7168)...


 56%|█████▌    | 40/72 [00:19<00:19,  1.61it/s]

Benchmarking baseline (7168, 19200)...


 57%|█████▋    | 41/72 [00:20<00:25,  1.20it/s]

Benchmarking baseline (7680, 1408)...


 58%|█████▊    | 42/72 [00:21<00:20,  1.49it/s]

Benchmarking baseline (7680, 2560)...


 60%|█████▉    | 43/72 [00:21<00:16,  1.74it/s]

Benchmarking baseline (7680, 7680)...


 61%|██████    | 44/72 [00:22<00:16,  1.68it/s]

Benchmarking baseline (7680, 20480)...


 62%|██████▎   | 45/72 [00:23<00:23,  1.15it/s]

Benchmarking baseline (8192, 1536)...


 64%|██████▍   | 46/72 [00:23<00:18,  1.43it/s]

Benchmarking baseline (8192, 8192)...


 65%|██████▌   | 47/72 [00:24<00:17,  1.42it/s]

Benchmarking baseline (8192, 22016)...


 67%|██████▋   | 48/72 [00:26<00:24,  1.02s/it]

Benchmarking baseline (9216, 3072)...


 68%|██████▊   | 49/72 [00:26<00:19,  1.20it/s]

Benchmarking baseline (10752, 3584)...


 69%|██████▉   | 50/72 [00:27<00:16,  1.35it/s]

Benchmarking baseline (11264, 2048)...


 71%|███████   | 51/72 [00:27<00:13,  1.58it/s]

Benchmarking baseline (12288, 4096)...


 72%|███████▏  | 52/72 [00:28<00:12,  1.60it/s]

Benchmarking baseline (13824, 2560)...


 74%|███████▎  | 53/72 [00:28<00:10,  1.73it/s]

Benchmarking baseline (13824, 4608)...


 75%|███████▌  | 54/72 [00:29<00:10,  1.66it/s]

Benchmarking baseline (15360, 5120)...


 76%|███████▋  | 55/72 [00:30<00:11,  1.51it/s]

Benchmarking baseline (16384, 3072)...


 78%|███████▊  | 56/72 [00:30<00:10,  1.55it/s]

Benchmarking baseline (16896, 5632)...


 79%|███████▉  | 57/72 [00:31<00:11,  1.36it/s]

Benchmarking baseline (18432, 6144)...


 81%|████████  | 58/72 [00:32<00:11,  1.18it/s]

Benchmarking baseline (19456, 3584)...


 82%|████████▏ | 59/72 [00:33<00:10,  1.23it/s]

Benchmarking baseline (19968, 6656)...


 83%|████████▎ | 60/72 [00:34<00:11,  1.05it/s]

Benchmarking baseline (21504, 7168)...


 85%|████████▍ | 61/72 [00:36<00:12,  1.11s/it]

Benchmarking baseline (22016, 4096)...


 86%|████████▌ | 62/72 [00:37<00:10,  1.05s/it]

Benchmarking baseline (23040, 7680)...


 88%|████████▊ | 63/72 [00:39<00:11,  1.25s/it]

Benchmarking baseline (24576, 4608)...


 89%|████████▉ | 64/72 [00:40<00:09,  1.21s/it]

Benchmarking baseline (24576, 8192)...


 90%|█████████ | 65/72 [00:42<00:09,  1.41s/it]

Benchmarking baseline (27648, 5120)...


 92%|█████████▏| 66/72 [00:43<00:08,  1.39s/it]

Benchmarking baseline (30208, 5632)...


 93%|█████████▎| 67/72 [00:45<00:07,  1.46s/it]

Benchmarking baseline (32768, 6144)...


 94%|█████████▍| 68/72 [00:46<00:06,  1.58s/it]

Benchmarking baseline (35840, 6656)...


 96%|█████████▌| 69/72 [00:49<00:05,  1.77s/it]

Benchmarking baseline (38400, 7168)...


 97%|█████████▋| 70/72 [00:51<00:04,  2.02s/it]

Benchmarking baseline (40960, 7680)...


 99%|█████████▊| 71/72 [00:54<00:02,  2.30s/it]

Benchmarking baseline (44032, 8192)...


100%|██████████| 72/72 [00:58<00:00,  1.24it/s]


Benchmarking +cuda...


  0%|          | 0/72 [00:00<?, ?it/s]

Benchmarking +cuda (1024, 1024)...


  1%|▏         | 1/72 [00:00<00:15,  4.50it/s]

Benchmarking +cuda (1024, 2816)...


  3%|▎         | 2/72 [00:07<05:18,  4.55s/it]

Benchmarking +cuda (1152, 1152)...


  4%|▍         | 3/72 [00:23<10:58,  9.54s/it]

Benchmarking +cuda (1152, 3072)...


  6%|▌         | 4/72 [00:23<06:39,  5.87s/it]

Benchmarking +cuda (1280, 1280)...


  7%|▋         | 5/72 [00:23<04:16,  3.84s/it]

Benchmarking +cuda (1280, 3584)...


  8%|▊         | 6/72 [00:23<02:52,  2.61s/it]

Benchmarking +cuda (1408, 1408)...


 10%|▉         | 7/72 [00:24<01:59,  1.84s/it]

Benchmarking +cuda (1408, 3840)...


 11%|█         | 8/72 [00:24<01:25,  1.33s/it]

Benchmarking +cuda (1536, 1536)...


 12%|█▎        | 9/72 [00:24<01:02,  1.01it/s]

Benchmarking +cuda (1536, 4096)...


 14%|█▍        | 10/72 [00:24<00:46,  1.32it/s]

Benchmarking +cuda (2048, 2048)...


 15%|█▌        | 11/72 [00:25<00:36,  1.67it/s]

Benchmarking +cuda (2048, 5632)...


 17%|█▋        | 12/72 [00:25<00:29,  2.01it/s]

Benchmarking +cuda (2560, 2560)...


 18%|█▊        | 13/72 [00:25<00:24,  2.37it/s]

Benchmarking +cuda (2560, 6912)...


 19%|█▉        | 14/72 [00:25<00:22,  2.63it/s]

Benchmarking +cuda (3072, 1024)...


 21%|██        | 15/72 [00:26<00:19,  2.98it/s]

Benchmarking +cuda (3072, 3072)...


 22%|██▏       | 16/72 [00:26<00:17,  3.23it/s]

Benchmarking +cuda (3072, 8192)...


 24%|██▎       | 17/72 [00:26<00:16,  3.25it/s]

Benchmarking +cuda (3456, 1152)...


 25%|██▌       | 18/72 [00:27<00:15,  3.50it/s]

Benchmarking +cuda (3584, 3584)...


 26%|██▋       | 19/72 [00:27<00:14,  3.60it/s]

Benchmarking +cuda (3584, 9728)...


 28%|██▊       | 20/72 [00:27<00:15,  3.39it/s]

Benchmarking +cuda (3840, 1280)...


 29%|██▉       | 21/72 [00:27<00:14,  3.61it/s]

Benchmarking +cuda (4096, 4096)...


 31%|███       | 22/72 [00:28<00:13,  3.63it/s]

Benchmarking +cuda (4096, 11008)...


 32%|███▏      | 23/72 [00:28<00:14,  3.33it/s]

Benchmarking +cuda (4224, 1408)...


 33%|███▎      | 24/72 [00:28<00:13,  3.53it/s]

Benchmarking +cuda (4608, 1536)...


 35%|███▍      | 25/72 [00:28<00:12,  3.68it/s]

Benchmarking +cuda (4608, 4608)...


 36%|███▌      | 26/72 [00:29<00:12,  3.65it/s]

Benchmarking +cuda (4608, 12288)...


 38%|███▊      | 27/72 [00:29<00:13,  3.27it/s]

Benchmarking +cuda (5120, 5120)...


 39%|███▉      | 28/72 [00:29<00:13,  3.30it/s]

Benchmarking +cuda (5120, 13824)...


 40%|████      | 29/72 [00:30<00:14,  3.00it/s]

Benchmarking +cuda (5632, 1024)...


 42%|████▏     | 30/72 [00:30<00:12,  3.27it/s]

Benchmarking +cuda (5632, 5632)...


 43%|████▎     | 31/72 [00:30<00:12,  3.26it/s]

Benchmarking +cuda (5632, 15104)...


 44%|████▍     | 32/72 [00:31<00:13,  2.87it/s]

Benchmarking +cuda (6144, 1152)...


 46%|████▌     | 33/72 [00:31<00:12,  3.14it/s]

Benchmarking +cuda (6144, 2048)...


 47%|████▋     | 34/72 [00:31<00:11,  3.33it/s]

Benchmarking +cuda (6144, 6144)...


 49%|████▊     | 35/72 [00:32<00:11,  3.25it/s]

Benchmarking +cuda (6144, 16384)...


 50%|█████     | 36/72 [00:32<00:12,  2.84it/s]

Benchmarking +cuda (6656, 6656)...


 51%|█████▏    | 37/72 [00:32<00:12,  2.89it/s]

Benchmarking +cuda (6656, 17920)...


 53%|█████▎    | 38/72 [00:33<00:13,  2.48it/s]

Benchmarking +cuda (7168, 1280)...


 54%|█████▍    | 39/72 [00:33<00:11,  2.79it/s]

Benchmarking +cuda (7168, 7168)...


 56%|█████▌    | 40/72 [00:34<00:11,  2.83it/s]

Benchmarking +cuda (7168, 19200)...


 57%|█████▋    | 41/72 [00:34<00:13,  2.35it/s]

Benchmarking +cuda (7680, 1408)...


 58%|█████▊    | 42/72 [00:34<00:11,  2.68it/s]

Benchmarking +cuda (7680, 2560)...


 60%|█████▉    | 43/72 [00:35<00:09,  2.90it/s]

Benchmarking +cuda (7680, 7680)...


 61%|██████    | 44/72 [00:35<00:09,  2.84it/s]

Benchmarking +cuda (7680, 20480)...


 62%|██████▎   | 45/72 [00:36<00:11,  2.34it/s]

Benchmarking +cuda (8192, 1536)...


 64%|██████▍   | 46/72 [00:36<00:09,  2.66it/s]

Benchmarking +cuda (8192, 8192)...


 65%|██████▌   | 47/72 [00:36<00:09,  2.66it/s]

Benchmarking +cuda (8192, 22016)...


 67%|██████▋   | 48/72 [00:37<00:11,  2.16it/s]

Benchmarking +cuda (9216, 3072)...


 68%|██████▊   | 49/72 [00:37<00:09,  2.43it/s]

Benchmarking +cuda (10752, 3584)...


 69%|██████▉   | 50/72 [00:38<00:08,  2.61it/s]

Benchmarking +cuda (11264, 2048)...


 71%|███████   | 51/72 [00:38<00:07,  2.83it/s]

Benchmarking +cuda (12288, 4096)...


 72%|███████▏  | 52/72 [00:38<00:07,  2.85it/s]

Benchmarking +cuda (13824, 2560)...


 74%|███████▎  | 53/72 [00:39<00:06,  2.96it/s]

Benchmarking +cuda (13824, 4608)...


 75%|███████▌  | 54/72 [00:39<00:06,  2.87it/s]

Benchmarking +cuda (15360, 5120)...


 76%|███████▋  | 55/72 [00:39<00:06,  2.76it/s]

Benchmarking +cuda (16384, 3072)...


 78%|███████▊  | 56/72 [00:40<00:05,  2.83it/s]

Benchmarking +cuda (16896, 5632)...


 79%|███████▉  | 57/72 [00:40<00:05,  2.65it/s]

Benchmarking +cuda (18432, 6144)...


 81%|████████  | 58/72 [00:41<00:05,  2.41it/s]

Benchmarking +cuda (19456, 3584)...


 82%|████████▏ | 59/72 [00:41<00:05,  2.48it/s]

Benchmarking +cuda (19968, 6656)...


 83%|████████▎ | 60/72 [00:41<00:05,  2.27it/s]

Benchmarking +cuda (21504, 7168)...


 85%|████████▍ | 61/72 [00:42<00:05,  2.06it/s]

Benchmarking +cuda (22016, 4096)...


 86%|████████▌ | 62/72 [00:42<00:04,  2.15it/s]

Benchmarking +cuda (23040, 7680)...


 88%|████████▊ | 63/72 [00:43<00:04,  1.97it/s]

Benchmarking +cuda (24576, 4608)...


 89%|████████▉ | 64/72 [00:44<00:03,  2.00it/s]

Benchmarking +cuda (24576, 8192)...


 90%|█████████ | 65/72 [00:44<00:03,  1.81it/s]

Benchmarking +cuda (27648, 5120)...


 92%|█████████▏| 66/72 [00:45<00:03,  1.84it/s]

Benchmarking +cuda (30208, 5632)...


 93%|█████████▎| 67/72 [00:45<00:02,  1.78it/s]

Benchmarking +cuda (32768, 6144)...


 94%|█████████▍| 68/72 [00:46<00:02,  1.69it/s]

Benchmarking +cuda (35840, 6656)...


 96%|█████████▌| 69/72 [00:47<00:01,  1.57it/s]

Benchmarking +cuda (38400, 7168)...


 97%|█████████▋| 70/72 [00:48<00:01,  1.43it/s]

Benchmarking +cuda (40960, 7680)...


 99%|█████████▊| 71/72 [00:49<00:00,  1.29it/s]

Benchmarking +cuda (44032, 8192)...


100%|██████████| 72/72 [00:50<00:00,  1.44it/s]


Benchmarking +fp8...


  0%|          | 0/72 [00:00<?, ?it/s]

Benchmarking +fp8 (1024, 1024)...


  1%|▏         | 1/72 [00:00<00:15,  4.54it/s]

Benchmarking +fp8 (1024, 2816)...


  3%|▎         | 2/72 [00:02<01:52,  1.61s/it]

Benchmarking +fp8 (1152, 1152)...


  4%|▍         | 3/72 [00:05<02:27,  2.14s/it]

Benchmarking +fp8 (1152, 3072)...


  6%|▌         | 4/72 [00:05<01:34,  1.39s/it]

Benchmarking +fp8 (1280, 1280)...


  7%|▋         | 5/72 [00:06<01:04,  1.03it/s]

Benchmarking +fp8 (1280, 3584)...


  8%|▊         | 6/72 [00:06<00:47,  1.39it/s]

Benchmarking +fp8 (1408, 1408)...


 10%|▉         | 7/72 [00:06<00:36,  1.80it/s]

Benchmarking +fp8 (1408, 3840)...


 11%|█         | 8/72 [00:06<00:29,  2.20it/s]

Benchmarking +fp8 (1536, 1536)...


 12%|█▎        | 9/72 [00:06<00:24,  2.60it/s]

Benchmarking +fp8 (1536, 4096)...


 14%|█▍        | 10/72 [00:07<00:21,  2.95it/s]

Benchmarking +fp8 (2048, 2048)...


 15%|█▌        | 11/72 [00:07<00:18,  3.26it/s]

Benchmarking +fp8 (2048, 5632)...


 17%|█▋        | 12/72 [00:07<00:17,  3.44it/s]

Benchmarking +fp8 (2560, 2560)...


 18%|█▊        | 13/72 [00:07<00:16,  3.63it/s]

Benchmarking +fp8 (2560, 6912)...


 19%|█▉        | 14/72 [00:08<00:16,  3.60it/s]

Benchmarking +fp8 (3072, 1024)...


 21%|██        | 15/72 [00:08<00:15,  3.80it/s]

Benchmarking +fp8 (3072, 3072)...


 22%|██▏       | 16/72 [00:08<00:14,  3.84it/s]

Benchmarking +fp8 (3072, 8192)...


 24%|██▎       | 17/72 [00:08<00:15,  3.65it/s]

Benchmarking +fp8 (3456, 1152)...


 25%|██▌       | 18/72 [00:09<00:14,  3.84it/s]

Benchmarking +fp8 (3584, 3584)...


 26%|██▋       | 19/72 [00:09<00:13,  3.85it/s]

Benchmarking +fp8 (3584, 9728)...


 28%|██▊       | 20/72 [00:09<00:14,  3.56it/s]

Benchmarking +fp8 (3840, 1280)...


 29%|██▉       | 21/72 [00:10<00:13,  3.76it/s]

Benchmarking +fp8 (4096, 4096)...


 31%|███       | 22/72 [00:10<00:13,  3.72it/s]

Benchmarking +fp8 (4096, 11008)...


 32%|███▏      | 23/72 [00:10<00:14,  3.37it/s]

Benchmarking +fp8 (4224, 1408)...


 33%|███▎      | 24/72 [00:10<00:13,  3.58it/s]

Benchmarking +fp8 (4608, 1536)...


 35%|███▍      | 25/72 [00:11<00:12,  3.72it/s]

Benchmarking +fp8 (4608, 4608)...


 36%|███▌      | 26/72 [00:11<00:12,  3.65it/s]

Benchmarking +fp8 (4608, 12288)...


 38%|███▊      | 27/72 [00:11<00:14,  3.19it/s]

Benchmarking +fp8 (5120, 5120)...


 39%|███▉      | 28/72 [00:12<00:13,  3.22it/s]

Benchmarking +fp8 (5120, 13824)...


 40%|████      | 29/72 [00:12<00:15,  2.85it/s]

Benchmarking +fp8 (5632, 1024)...


 42%|████▏     | 30/72 [00:12<00:13,  3.16it/s]

Benchmarking +fp8 (5632, 5632)...


 43%|████▎     | 31/72 [00:13<00:13,  3.15it/s]

Benchmarking +fp8 (5632, 15104)...


 44%|████▍     | 32/72 [00:13<00:14,  2.72it/s]

Benchmarking +fp8 (6144, 1152)...


 46%|████▌     | 33/72 [00:13<00:12,  3.04it/s]

Benchmarking +fp8 (6144, 2048)...


 47%|████▋     | 34/72 [00:14<00:11,  3.24it/s]

Benchmarking +fp8 (6144, 6144)...


 49%|████▊     | 35/72 [00:14<00:11,  3.14it/s]

Benchmarking +fp8 (6144, 16384)...


 50%|█████     | 36/72 [00:15<00:14,  2.51it/s]

Benchmarking +fp8 (6656, 6656)...


 51%|█████▏    | 37/72 [00:15<00:13,  2.59it/s]

Benchmarking +fp8 (6656, 17920)...


 53%|█████▎    | 38/72 [00:16<00:15,  2.13it/s]

Benchmarking +fp8 (7168, 1280)...


 54%|█████▍    | 39/72 [00:16<00:13,  2.48it/s]

Benchmarking +fp8 (7168, 7168)...


 56%|█████▌    | 40/72 [00:16<00:12,  2.53it/s]

Benchmarking +fp8 (7168, 19200)...


 57%|█████▋    | 41/72 [00:17<00:15,  1.99it/s]

Benchmarking +fp8 (7680, 1408)...


 58%|█████▊    | 42/72 [00:17<00:12,  2.34it/s]

Benchmarking +fp8 (7680, 2560)...


 60%|█████▉    | 43/72 [00:17<00:11,  2.62it/s]

Benchmarking +fp8 (7680, 7680)...


 61%|██████    | 44/72 [00:18<00:10,  2.55it/s]

Benchmarking +fp8 (7680, 20480)...


 62%|██████▎   | 45/72 [00:19<00:14,  1.88it/s]

Benchmarking +fp8 (8192, 1536)...


 64%|██████▍   | 46/72 [00:19<00:11,  2.23it/s]

Benchmarking +fp8 (8192, 8192)...


 65%|██████▌   | 47/72 [00:19<00:11,  2.27it/s]

Benchmarking +fp8 (8192, 22016)...


 67%|██████▋   | 48/72 [00:20<00:14,  1.66it/s]

Benchmarking +fp8 (9216, 3072)...


 68%|██████▊   | 49/72 [00:21<00:11,  1.95it/s]

Benchmarking +fp8 (10752, 3584)...


 69%|██████▉   | 50/72 [00:21<00:10,  2.16it/s]

Benchmarking +fp8 (11264, 2048)...


 71%|███████   | 51/72 [00:21<00:08,  2.43it/s]

Benchmarking +fp8 (12288, 4096)...


 72%|███████▏  | 52/72 [00:22<00:07,  2.51it/s]

Benchmarking +fp8 (13824, 2560)...


 74%|███████▎  | 53/72 [00:22<00:07,  2.65it/s]

Benchmarking +fp8 (13824, 4608)...


 75%|███████▌  | 54/72 [00:22<00:06,  2.58it/s]

Benchmarking +fp8 (15360, 5120)...


 76%|███████▋  | 55/72 [00:23<00:07,  2.41it/s]

Benchmarking +fp8 (16384, 3072)...


 78%|███████▊  | 56/72 [00:23<00:06,  2.49it/s]

Benchmarking +fp8 (16896, 5632)...


 79%|███████▉  | 57/72 [00:24<00:06,  2.24it/s]

Benchmarking +fp8 (18432, 6144)...


 81%|████████  | 58/72 [00:24<00:07,  2.00it/s]

Benchmarking +fp8 (19456, 3584)...


 82%|████████▏ | 59/72 [00:25<00:06,  2.06it/s]

Benchmarking +fp8 (19968, 6656)...


 83%|████████▎ | 60/72 [00:26<00:06,  1.77it/s]

Benchmarking +fp8 (21504, 7168)...


 85%|████████▍ | 61/72 [00:27<00:07,  1.52it/s]

Benchmarking +fp8 (22016, 4096)...


 86%|████████▌ | 62/72 [00:27<00:06,  1.62it/s]

Benchmarking +fp8 (23040, 7680)...


 88%|████████▊ | 63/72 [00:28<00:06,  1.37it/s]

Benchmarking +fp8 (24576, 4608)...


 89%|████████▉ | 64/72 [00:29<00:05,  1.40it/s]

Benchmarking +fp8 (24576, 8192)...


 90%|█████████ | 65/72 [00:30<00:05,  1.21it/s]

Benchmarking +fp8 (27648, 5120)...


 92%|█████████▏| 66/72 [00:31<00:04,  1.21it/s]

Benchmarking +fp8 (30208, 5632)...


 93%|█████████▎| 67/72 [00:32<00:04,  1.16it/s]

Benchmarking +fp8 (32768, 6144)...


 94%|█████████▍| 68/72 [00:33<00:03,  1.07it/s]

Benchmarking +fp8 (35840, 6656)...


 96%|█████████▌| 69/72 [00:34<00:03,  1.05s/it]

Benchmarking +fp8 (38400, 7168)...


 97%|█████████▋| 70/72 [00:36<00:02,  1.19s/it]

Benchmarking +fp8 (40960, 7680)...


 99%|█████████▊| 71/72 [00:37<00:01,  1.33s/it]

Benchmarking +fp8 (44032, 8192)...


100%|██████████| 72/72 [00:39<00:00,  1.82it/s]


In [188]:
speedups = {}
for model_size, tensor_shapes in shapes.items():
    fp8_forward_latency = 0
    fp8_backward_latency = 0
    our_forward_latency = 0
    our_backward_latency = 0
    for key in tensor_shapes:
        fp8_forward_latency += RESULTS['+fp8'][key]["forward_ms"]
        fp8_backward_latency += RESULTS['+fp8'][key]["total_ms"] - RESULTS['+fp8'][key]["forward_ms"]
        our_forward_latency += RESULTS['+cuda'][key]["forward_ms"]
        our_backward_latency += RESULTS['+cuda'][key]["total_ms"] - RESULTS['+cuda'][key]["forward_ms"]
    speedups[model_size] = (fp8_forward_latency / our_forward_latency, fp8_backward_latency/our_backward_latency)
print("Quartet vs FP8 speedups:\n")
speedups

Quartet vs FP8 speedups:



{102760448: (1.6421884192100795, 0.5700544965507398),
 143327232: (1.7507068972351185, 0.5428460629888379),
 203161600: (1.816469204026392, 0.5630831684896017),
 265650176: (1.8337929823925465, 0.5852455896793946),
 339738624: (1.8373143452188634, 0.6084983531164263),
 822083584: (1.94544993557363, 0.7102217258283788),
 1585971200: (2.016410358863029, 0.7844006064602403),
 2717908992: (2.04055388509346, 0.8540075062870569),
 4367319040: (2.0577915015558075, 0.9298757910443761),
 6476005376: (2.1262369945933037, 0.9957348479609069),
 9172942848: (2.4712320149265006, 1.0801941994467603),
 12687769600: (2.4354341004521824, 1.1340499333657252),
 16811294720: (2.298580006518396, 1.1847066991846906),
 21743271936: (2.2977885692458155, 1.2498281167452674),
 27821867008: (2.284662048985173, 1.306921984970655),
 34630270976: (2.2658029762562637, 1.3405308805024048),
 42467328000: (2.1909100647863364, 1.3656890776400838),
 51808043008: (2.127026201426419, 1.3737888056720073)}

In [189]:
speedups = {}
for model_size, tensor_shapes in shapes.items():
    fp8_forward_latency = 0
    fp8_backward_latency = 0
    our_forward_latency = 0
    our_backward_latency = 0
    for key in tensor_shapes:
        fp8_forward_latency += RESULTS['baseline'][key]["forward_ms"]
        fp8_backward_latency += RESULTS['baseline'][key]["total_ms"] - RESULTS['baseline'][key]["forward_ms"]
        our_forward_latency += RESULTS['+cuda'][key]["forward_ms"]
        our_backward_latency += RESULTS['+cuda'][key]["total_ms"] - RESULTS['+cuda'][key]["forward_ms"]
    speedups[model_size] = (fp8_forward_latency / our_forward_latency, fp8_backward_latency/our_backward_latency)
print("Quartet vs BF16 speedups:\n")
speedups

Quartet vs BF16 speedups:



{102760448: (3.5049470226024093, 0.8646094275384334),
 143327232: (3.715967665970758, 0.9213451205718582),
 203161600: (3.7432328443731238, 0.9898115334503184),
 265650176: (3.9949393126868133, 1.1018071412556414),
 339738624: (3.9681931586149175, 1.1402885192169),
 822083584: (4.343412928598179, 1.3829555547147108),
 1585971200: (4.53723356118794, 1.585533786598462),
 2717908992: (4.6859598161053375, 1.777270307633622),
 4367319040: (4.753695314032934, 1.9254506999337986),
 6476005376: (4.777181752338976, 2.0752000752899664),
 9172942848: (4.73225151546531, 2.195375242005379),
 12687769600: (4.688234075118633, 2.3125219626443494),
 16811294720: (4.32008093026321, 2.3977843198505044),
 21743271936: (4.114755971002955, 2.480224805446807),
 27821867008: (3.992000854253022, 2.5316271814060287),
 34630270976: (3.9050340451224814, 2.6144872925084353),
 42467328000: (3.756596215509311, 2.6503525698540353),
 51808043008: (3.6367347234294076, 2.7054283751005626)}

Paste them into plots.ipynb!