In [1]:
# Both CUDA variables are set here, ahead of the `import torch` further down in this cell.
# Expose only the GPU with index 8 to this kernel.
# NOTE(review): the index is machine-specific — adjust it for the host you run on.
%env CUDA_VISIBLE_DEVICES=8
# Force synchronous CUDA kernel launches.
# NOTE(review): this is normally a debugging aid and slows execution — confirm it is still wanted.
%env CUDA_LAUNCH_BLOCKING=1

import sys
# Add ../src to the import path so that the `models.quantization` package imported below resolves.
# The path is relative, so it depends on the directory the kernel was started from.
sys.path.append("../src")

import torch
from models.quantization.quantizers import QuestMXFP4Quantizer, AlbertTsengQuantizer, EdenSRQuantizer, IsolatedEdenQuantizer, QuestNvfp4Quantizer, Nvfp4Quantizer

from tqdm.auto import trange, tqdm

import numpy as np

env: CUDA_VISIBLE_DEVICES=8
env: CUDA_LAUNCH_BLOCKING=1


In [2]:
# Synthetic input: 2**20 rows of 128 Gaussian samples, where row i is multiplied by the
# i-th point of a base-2 logspace running from 2**0 to 2**8; the result is flattened to 1-D.
x = (
    torch.randn(2**20, 128, device="cuda")
    * torch.logspace(0, 8, 2**20, base=2, device="cuda").unsqueeze(1)
).flatten()

# (scale dtype, group size) pairs to evaluate; commented-out entries are disabled configs.
scale_dtype_group = [
    # ("fp32", 128),
    ("e4m3", 16),
    # ("e8m0", 32),
]

# Value passed as `scale_override`, keyed first by the `unbiased` mode and then by scale dtype.
optimal_scale_override = {
    "eden": {"fp32": 0.96, "e4m3": 0.93, "e8m0": 0.91},
    "no": {"fp32": 0.96, "e4m3": 0.93, "e8m0": 0.84},
    "sr": {"fp32": 1.00, "e4m3": 1.00, "e8m0": 1.00},
}

table_rows = []
data = {}

# Flatten the (scale dtype, group size) x (unbiased mode) grid into one list of configs.
configs = [
    (scale_dtype, group_dim, unbiased)
    for (scale_dtype, group_dim) in scale_dtype_group
    for unbiased in ["no", "sr", "eden"]
]

for scale_dtype, group_dim, unbiased in configs:
    scale_override = optimal_scale_override[unbiased][scale_dtype]
    quantizer = EdenSRQuantizer(
        hadamard_dim=128,
        group_dim=group_dim,
        scale_dtype=scale_dtype,
        unbiased=unbiased,
        scale_override=scale_override,
        four_over_six=False,
    )

    # The quantizer output is multiplied by the quantizer's Hadamard matrix before being
    # compared with the raw input (presumably undoing the rotation applied inside — confirm).
    dq = quantizer(x).view(-1, quantizer.hadamard_dim) @ quantizer.hadamard_matrix
    ref = x.view(-1, quantizer.hadamard_dim)

    # Relative squared error per row of `hadamard_dim` values, averaged over all rows.
    residual_energy = (ref - dq).pow(2).sum(dim=-1)
    ref_energy = ref.pow(2).sum(dim=-1)
    quad_err = (residual_energy / ref_energy).mean()

    # "Rate" figure reported in the table: -log2(relative error) / 2.
    eff_bitwidth = (-torch.log2(quad_err) / 2).item()

    # Mean ratio <ref, dq> / <ref, ref>; the table reports its distance from 1.
    magnitude_alignment = ((ref * dq).sum(dim=-1) / (ref * ref).sum(dim=-1)).mean().item()
    misalignment = 1 - magnitude_alignment

    data[(group_dim, scale_dtype, unbiased)] = (eff_bitwidth, misalignment)
    table_rows.append(
        (scale_dtype, group_dim, unbiased, quad_err, eff_bitwidth, misalignment)
    )

# Print markdown table
print("| Scales DType | Group Size | Unbiased | MSE        | Rate, bits | Magnitude Misalignment |")
print("|--------------|------------|----------|------------|------------|------------------------|")
for dtype_name, group, mode, mse, bits, misalign in table_rows:
    # The MSE column shows the relative error multiplied by 1e3.
    cells = (
        f"{str(dtype_name.upper()):<12}",
        f"{str(group):<10}",
        f"{str(mode).upper():<8}",
        f"{mse * 1e3:>10.1f}",
        f"{bits:>10.2f}",
        f"{misalign:>22.5f}",
    )
    print("| " + " | ".join(cells) + " |")


| Scales DType | Group Size | Unbiased | MSE        | Rate, bits | Magnitude Misalignment |
|--------------|------------|----------|------------|------------|------------------------|
| E4M3         | 16         | NO       |        9.0 |       3.40 |                0.00790 |
| E4M3         | 16         | SR       |       23.5 |       2.71 |                0.00002 |
| E4M3         | 16         | EDEN     |        9.8 |       3.34 |                0.00055 |


In [3]:
# Relative quantization error of Nvfp4Quantizer on an 8192x8192 standard-normal matrix,
# for every combination of the `square` and `four_over_six` options.
x = torch.randn((8192, 8192), device="cuda")

for square in [False, True]:
    for four_over_six in [False, True]:
        quantizer = Nvfp4Quantizer(square=square, four_over_six=four_over_six)
        dq = quantizer(x)

        # Squared error of each row normalised by that row's energy, averaged over rows.
        quad_err = ((x - dq).pow(2).sum(dim=-1) / x.pow(2).sum(dim=-1)).mean()
        # "Rate" figure: -log2(relative error) / 2.
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()

        # Printed columns: relative error x 1e3, then the rate in bits.
        print(f"{four_over_six=},{square=}: {quad_err*1e3:.1f} {eff_bitwidth:.3f}")

four_over_six=False,square=False: 9.0 3.394
four_over_six=True,square=False: 7.6 3.524
four_over_six=False,square=True: 12.4 3.168
four_over_six=True,square=True: 12.4 3.169


In [4]:
# Measures how the relative error of a quantized matrix product x @ y.T shrinks when the
# product is averaged over several independently re-randomized quantization passes.
x = torch.randn((4096, 4096), device="cuda")
y = torch.randn((4096, 4096), device="cuda")

unbiased = "eden"

# `optimal_scale_override` is defined in the EdenSRQuantizer sweep cell earlier in the notebook.
quantizer = EdenSRQuantizer(hadamard_dim=128, group_dim=16, scale_dtype="e4m3", unbiased=unbiased, scale_override=optimal_scale_override[unbiased]["e4m3"], rerotate='signs', four_over_six=True)

# The exact product and its mean square do not depend on `acc_steps`, so compute them once
# instead of re-running two 4096x4096 matmuls for every setting.
ref_prod = x @ y.T
ref_power = ref_prod.pow(2).mean()

for acc_steps in tqdm([1, 4, 16, 64, 256]):
    acc_prod = torch.zeros((x.shape[0], y.shape[0]), device="cuda")
    # `_step` rather than `_`: assigning to `_` would shadow IPython's last-result variable.
    for _step in trange(acc_steps, leave=False):
        # Called once per step, so x and y are quantized under the same quantizer state
        # (presumably the same random rotation/signs — confirm against the quantizer).
        quantizer.re_randomize()
        xq = quantizer(x)
        yq = quantizer(y)
        acc_prod += xq @ yq.T

    # Relative squared error of the averaged quantized product against the exact product.
    quad_err = (acc_prod / acc_steps - ref_prod).pow(2).mean() / ref_power
    eff_bitwidth = (-torch.log2(quad_err) / 2).item()
    print(f"{acc_steps}: {quad_err * 1e3:.2f} {eff_bitwidth:.3f}")


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

1: 16.29 2.970


  0%|          | 0/4 [00:00<?, ?it/s]

4: 4.08 3.968


  0%|          | 0/16 [00:00<?, ?it/s]

16: 1.03 4.960


  0%|          | 0/64 [00:00<?, ?it/s]

64: 0.27 5.930


  0%|          | 0/256 [00:00<?, ?it/s]

256: 0.08 6.820


In [5]:
# NOTE(review): bare literal with no computation — looks like a leftover scratch value
# (it is close to the 1-step figure printed by the accumulation cell); confirm and remove.
16.28

16.28