In [1]:
%env CUDA_VISIBLE_DEVICES=8
%env CUDA_LAUNCH_BLOCKING=1

import sys
sys.path.append("../src")

import torch
from models.quantization.quantizers import QuestMXFP4Quantizer, AlbertTsengQuantizer, EdenSRQuantizer, IsolatedEdenQuantizer, QuestNvfp4Quantizer, Nvfp4Quantizer

from tqdm.auto import trange, tqdm

import numpy as np

env: CUDA_VISIBLE_DEVICES=8
env: CUDA_LAUNCH_BLOCKING=1




In [2]:
# Synthetic input for the quantizer sweep: a (2**20, 128) standard-normal matrix
# whose i-th row is multiplied by the i-th of 2**20 factors log-spaced between
# 2**0 and 2**8, then flattened to 1-D.
row_scales = torch.logspace(0, 8, 2**20, base=2, device="cuda")
x = torch.mul(torch.randn(2**20, 128, device="cuda"), row_scales.unsqueeze(1)).flatten()

# (scale dtype, group size) pairs to sweep; commented-out entries are disabled.
scale_dtype_group = [
    # ("fp32", 128),
    ("e4m3", 16),
    # ("e8m0", 32),
]

# Value passed as `scale_override` to the quantizer, looked up as
# optimal_scale_override[unbiased mode][scale dtype].
# NOTE(review): presumably tuned empirically per configuration -- confirm.
optimal_scale_override = dict(
    eden=dict(fp32=0.96, e4m3=0.93, e8m0=0.91),
    no=dict(fp32=0.96, e4m3=0.93, e8m0=0.84),
    sr=dict(fp32=1.00, e4m3=1.00, e8m0=1.00),
)

# Sweep every enabled (scale dtype, group size) pair against each unbiased mode
# and record how well EdenSRQuantizer reconstructs `x`.
table_rows = []  # one tuple per configuration, in sweep order
data = {}  # (group_dim, scale_dtype, unbiased) -> (rate in bits, magnitude misalignment)

for scale_dtype, group_dim in scale_dtype_group:
    for unbiased in ("no", "sr", "eden"):
        scale_override = optimal_scale_override[unbiased][scale_dtype]
        quantizer = EdenSRQuantizer(
            hadamard_dim=128,
            group_dim=group_dim,
            scale_dtype=scale_dtype,
            unbiased=unbiased,
            scale_override=scale_override,
            four_over_six=False,
        )

        # Both tensors are compared in rows of `hadamard_dim` elements. The
        # quantizer output is multiplied by the quantizer's Hadamard matrix
        # first (presumably to rotate it back to the input basis -- TODO confirm).
        dq = quantizer(x).view(-1, quantizer.hadamard_dim) @ quantizer.hadamard_matrix
        ref = x.view(-1, quantizer.hadamard_dim)

        # Mean over rows of ||ref - dq||^2 / ||ref||^2.
        rel_sq_err = (ref - dq).pow(2).sum(dim=-1) / ref.pow(2).sum(dim=-1)
        quad_err = rel_sq_err.mean()
        # Effective bit-width: -log2(relative error) / 2.
        eff_bitwidth = (-0.5 * torch.log2(quad_err)).item()

        # Mean over rows of <ref, dq> / <ref, ref>; equals 1 when dq has no
        # systematic shrinkage or growth along ref.
        projection = (ref * dq).sum(dim=-1) / (ref * ref).sum(dim=-1)
        magnitude_alignment = projection.mean().item()
        misalignment = 1 - magnitude_alignment

        data[(group_dim, scale_dtype, unbiased)] = (eff_bitwidth, misalignment)
        table_rows.append(
            (scale_dtype, group_dim, unbiased, quad_err, eff_bitwidth, misalignment)
        )

# Render the collected sweep results as a markdown table.
# The "MSE" column shows the relative squared error multiplied by 1e3.
print("| Scales DType | Group Size | Unbiased | MSE        | Rate, bits | Magnitude Misalignment |")
print("|--------------|------------|----------|------------|------------|------------------------|")
for dtype_name, group_size, mode, rel_err, rate_bits, misalign in table_rows:
    cells = [
        f"{dtype_name.upper():<12}",
        f"{str(group_size):<10}",
        f"{str(mode).upper():<8}",
        f"{rel_err * 1e3:>10.1f}",
        f"{rate_bits:>10.2f}",
        f"{misalign:>22.5f}",
    ]
    print("| " + " | ".join(cells) + " |")




| Scales DType | Group Size | Unbiased | MSE        | Rate, bits | Magnitude Misalignment |
|--------------|------------|----------|------------|------------|------------------------|
| E4M3         | 16         | NO       |        9.0 |       3.40 |                0.00792 |
| E4M3         | 16         | SR       |       23.5 |       2.71 |                0.00000 |
| E4M3         | 16         | EDEN     |        9.8 |       3.34 |                0.00056 |


In [3]:
# Relative quantization error of Nvfp4Quantizer on a standard-normal
# 8192x8192 matrix, for every combination of the `square` and
# `four_over_six` constructor flags.
x = torch.randn((8192, 8192), device="cuda")

for square in [False, True]:
    for four_over_six in [False, True]:
        quantizer = Nvfp4Quantizer(square=square, four_over_six=four_over_six)
        dq = quantizer(x)

        # Mean over rows of ||x - dq||^2 / ||x||^2.
        quad_err = ((x - dq).pow(2).sum(dim=-1) / x.pow(2).sum(dim=-1)).mean()
        # Effective bit-width: -log2(relative error) / 2.
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()

        # Printed columns: relative error x 1e3, then effective bits.
        print(f"{four_over_six=},{square=}: {quad_err*1e3:.1f} {eff_bitwidth:.3f}")

four_over_six=False,square=False: 9.0 3.394
four_over_six=True,square=False: 7.6 3.524
four_over_six=False,square=True: 12.4 3.167
four_over_six=True,square=True: 12.4 3.168


In [4]:
# How the error of a quantized matmul shrinks when the product is averaged
# over several independently re-randomized quantizations of the same operands.
x = torch.randn((4096, 4096), device="cuda")
y = torch.randn((4096, 4096), device="cuda")

unbiased = "sr"

quantizer = EdenSRQuantizer(hadamard_dim=128, group_dim=16, scale_dtype="e4m3", unbiased=unbiased, scale_override=optimal_scale_override[unbiased]["e4m3"], rerotate='signs', four_over_six=True)

# The exact product and its mean square do not depend on the accumulation
# count, so compute them once instead of twice per outer iteration.
ref_prod = x @ y.T
ref_power = ref_prod.pow(2).mean()

for acc_steps in tqdm([1, 4, 16, 64]):
    acc_prod = torch.zeros((x.shape[0], y.shape[0]), device="cuda")
    for step in trange(acc_steps, leave=False):
        # Re-randomize once per step; x and y are then quantized by the same
        # quantizer instance.
        quantizer.re_randomize()
        xq = quantizer(x)
        yq = quantizer(y)
        acc_prod += xq @ yq.T

    # Relative squared error of the averaged quantized product.
    quad_err = (acc_prod / acc_steps - ref_prod).pow(2).mean() / ref_power
    # Effective bit-width: -log2(relative error) / 2.
    eff_bitwidth = (-torch.log2(quad_err) / 2).item()
    print(f"{acc_steps}: {quad_err * 1e3:.2f} {eff_bitwidth:.3f}")


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

1: 35.84 2.401


  0%|          | 0/4 [00:00<?, ?it/s]

4: 9.11 3.389


  0%|          | 0/16 [00:00<?, ?it/s]

16: 2.44 4.339


  0%|          | 0/64 [00:00<?, ?it/s]

64: 0.77 5.170


In [5]:
from models.quantization.quantizers.nvfp4_triton import sr_1x16s_fp4_kernel_wrapper

# Same accumulation experiment as a plain quantizer-based matmul, but calling
# the Triton kernel wrapper `sr_1x16s_fp4_kernel_wrapper` directly.
x = torch.randn((4096, 4096), device="cuda")
y = torch.randn((4096, 4096), device="cuda")

# The exact product and its mean square are loop-invariant: compute them once.
ref_prod = x @ y.T
ref_power = ref_prod.pow(2).mean()

for acc_steps in tqdm([1, 4, 16, 64, 256, 1024]):
    acc_prod = torch.zeros((x.shape[0], y.shape[0]), device="cuda")
    for step in trange(acc_steps, leave=False):
        # No quantizer object is involved here: the wrapper is called with the
        # tensor and literal arguments only, so the former
        # `quantizer.re_randomize()` call (a leftover that depended on a
        # quantizer built in another cell) has been dropped.
        # NOTE(review): the positional arguments 17/16, 16, False are
        # presumably scale override, group size and a mode flag -- confirm
        # against the wrapper's signature.
        xq = sr_1x16s_fp4_kernel_wrapper(x, 17/16, 16, False)
        yq = sr_1x16s_fp4_kernel_wrapper(y, 17/16, 16, False)
        acc_prod += xq @ yq.T

    # Relative squared error of the averaged quantized product.
    quad_err = (acc_prod / acc_steps - ref_prod).pow(2).mean() / ref_power
    # Effective bit-width: -log2(relative error) / 2.
    eff_bitwidth = (-torch.log2(quad_err) / 2).item()
    print(f"{acc_steps}: {eff_bitwidth:.3f}")


  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

1: 2.197


  0%|          | 0/4 [00:00<?, ?it/s]

4: 3.196


  0%|          | 0/16 [00:00<?, ?it/s]

16: 4.196


  0%|          | 0/64 [00:00<?, ?it/s]

64: 5.197


  0%|          | 0/256 [00:00<?, ?it/s]

256: 6.197


  0%|          | 0/1024 [00:00<?, ?it/s]

1024: 7.196


In [9]:
from models.quantization.schemes.tetrajetv2 import TetraJetV2Linear, TetraJetV2_fn
from models.quantization.schemes.quartet_2 import Quartet_II_Linear

# Gradient-accumulation experiment: compare gradients obtained with backward
# quantization enabled (averaged over several backward passes) against the
# gradients obtained with `disable_backward_quant = True`.
x = torch.randn((3, 128, 4096), device="cuda", requires_grad=True)

# Alternative layer under test (disabled):
# linear = TetraJetV2Linear(4096, 1024, device="cuda", dtype=torch.float32, bias=False, disable_forward_quant=True)
linear = Quartet_II_Linear(
    4096,
    1024,
    device="cuda",
    dtype=torch.float32,
    bias=False,
    disable_forward_quant=True,
    hadamard_dim=128,
)

head = torch.nn.Linear(1024, 1, device="cuda")
target = torch.randn(3, 128, 1, device="cuda")


def get_loss(x, linear, head, target):
    """Return the mean squared error between ``head(linear(x))`` and ``target``."""
    prediction = head(linear(x))
    return (prediction - target).pow(2).mean()


def grad_eff_bitwidth(acc_grad, ref_grad, acc_steps):
    """Effective bits of ``acc_grad / acc_steps`` relative to ``ref_grad``.

    Computed as -log2(relative mean squared error) / 2 and returned as a float.
    """
    avg_grad = acc_grad / acc_steps
    rel_err = (avg_grad - ref_grad).pow(2).mean() / ref_grad.pow(2).mean()
    return (-torch.log2(rel_err) / 2).item()


# Reference gradients, taken with backward quantization switched off.
linear.disable_backward_quant = True
x.grad = None
linear.weight.grad = None
get_loss(x, linear, head, target).backward()
ref_x_grad = x.grad.detach().clone()
ref_w_grad = linear.weight.grad.detach().clone()


# Gradients with backward quantization switched on.
linear.disable_backward_quant = False
for acc_steps in tqdm([1, 4, 16, 64]):
    x.grad = None
    linear.weight.grad = None

    # One forward pass; the graph is retained so that backward can be run
    # `acc_steps` times, each call adding into `.grad`.
    loss = get_loss(x, linear, head, target)
    for step in trange(acc_steps, leave=False):
        loss.backward(retain_graph=True)

    x_eff_bitwidth = grad_eff_bitwidth(x.grad, ref_x_grad, acc_steps)
    w_eff_bitwidth = grad_eff_bitwidth(linear.weight.grad, ref_w_grad, acc_steps)
    print(f"{acc_steps}:\n\tx: {x_eff_bitwidth:.3f} bits")
    print(f"\tw: {w_eff_bitwidth:.3f} bits")


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

1:
	x: 2.847 bits
	w: 2.834 bits


  0%|          | 0/4 [00:00<?, ?it/s]

4:
	x: 3.839 bits
	w: 3.848 bits


  0%|          | 0/16 [00:00<?, ?it/s]

16:
	x: 4.843 bits
	w: 4.828 bits


  0%|          | 0/64 [00:00<?, ?it/s]

64:
	x: 5.840 bits
	w: 5.797 bits
