In [1]:
%env CUDA_VISIBLE_DEVICES=9
%env CUDA_LAUNCH_BLOCKING=1

import sys
sys.path.append("../src")

import torch
from models.quantization.quantizers import QuestMXFP4Quantizer, AlbertTsengQuantizer, EdenSRQuantizer

from tqdm.auto import trange, tqdm

env: CUDA_VISIBLE_DEVICES=9
env: CUDA_LAUNCH_BLOCKING=1


In [None]:
# Synthetic input: 2**20 rows of 128 standard-normal samples, where row i is
# multiplied by a factor log-spaced (base 2) between 2**0 and 2**10; the result
# is flattened to 1-D before being handed to the quantizers.
x = (torch.randn(2**20, 128, device="cuda") * torch.logspace(0, 10, 2**20, base=2, device="cuda").unsqueeze(1)).flatten()

# (scale dtype, group size) pairs swept by the outer loop.
scale_dtype_group = [("fp32", 128), ("e4m3", 16), ("e8m0", 32)]

# table_rows feeds the markdown table printed at the end of this cell;
# data keeps the same two metrics keyed by (group_dim, scale_dtype, unbiased).
table_rows = []
data = {}

for (scale_dtype, group_dim) in scale_dtype_group:
    for unbiased in ["no", "sr", "eden"]:
        quantizer = EdenSRQuantizer(hadamard_dim=128, group_dim=group_dim, scale_dtype=scale_dtype, unbiased=unbiased)
        # Quantizer output reshaped to rows of hadamard_dim and multiplied by the
        # quantizer's hadamard_matrix (presumably this maps the output back from
        # the rotated basis so it is comparable with x — confirm in quantizers.py).
        dq = quantizer(x).view(-1, quantizer.hadamard_dim) @ quantizer.hadamard_matrix
        ref = x.view(-1, quantizer.hadamard_dim)
        # Per-row relative squared error, averaged over all rows.
        quad_err = ((ref - dq).pow(2).sum(dim=-1) / ref.pow(2).sum(dim=-1)).mean()
        # -log2(relative error) / 2; shown as "rate-distortion bits" in the table.
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()
        # Mean absolute value of the per-row ratio <ref, dq> / <ref, ref>.
        magnitude_alignment = ((ref * dq).sum(dim=-1) / (ref * ref).sum(dim=-1)).abs().mean().item()
        
        # Both containers store 1 - magnitude_alignment, i.e. the misalignment.
        data[(group_dim, scale_dtype, unbiased)] = (eff_bitwidth, 1 - magnitude_alignment)
        
        table_rows.append(
            (scale_dtype, group_dim, unbiased, eff_bitwidth, 1 - magnitude_alignment)
        )

# Print the collected rows as a markdown table (one row per configuration).
print("| Scales DType | Group Size | Unbiased | MSE, rate-distortion bits | Magnitude Misalignment |")
print("|--------------|------------|----------|---------------------------|------------------------|")
for row in table_rows:
    print(f"| {str(row[0].upper()):<12} | {str(row[1]):<10} | {str(row[2]).upper():<8} | {row[3]:>25.2f} | {row[4]:>22.5f} |")


| Scales DType | Group Size | Unbiased | MSE, rate-distortion bits | Magnitude Misalignment |
|--------------|------------|----------|---------------------------|------------------------|
| FP32         | 128        | NO       |                      3.20 |                0.00835 |
| FP32         | 128        | SR       |                      2.69 |               -0.00000 |
| FP32         | 128        | EDEN     |                      3.20 |                0.00000 |
| E4M3         | 16         | NO       |                      3.28 |               -0.01119 |
| E4M3         | 16         | SR       |                      2.70 |               -0.00001 |
| E4M3         | 16         | EDEN     |                      3.25 |               -0.00125 |
| E8M0         | 32         | NO       |                      3.10 |                0.00947 |
| E8M0         | 32         | SR       |                      2.58 |               -0.00001 |
| E8M0         | 32         | EDEN     |                    

In [3]:
# Synthetic input: 2**20 rows of 128 standard-normal samples with per-row
# scales log-spaced (base 2) between 2**0 and 2**16, flattened to 1-D.
x = (torch.randn(2**20, 128, device="cuda") * torch.logspace(0, 16, 2**20, base=2, device="cuda").unsqueeze(1)).flatten()

quantizer = EdenSRQuantizer(hadamard_dim=128, group_dim=16, scale_dtype="e4m3", unbiased="eden", rerotate="signs")

# For each accumulation count, average that many quantizer outputs of the same
# x and print the relative squared error of the average against x.
for accum_steps in tqdm([1, 4, 16, 64, 256, 1024, 4096]):
    acc = torch.zeros_like(x)

    for _ in trange(accum_steps, leave=False):
        # Presumably redraws the quantizer's random state so each pass is an
        # independent sample — confirm against EdenSRQuantizer.re_randomize.
        quantizer.re_randomize()
        # Same post-processing as a single pass: reshape to rows, multiply by
        # hadamard_matrix, and restore the flat shape of x before accumulating.
        acc += (quantizer(x).view(-1, quantizer.hadamard_matrix.shape[0]) @ quantizer.hadamard_matrix).view_as(x)
    # Relative squared error of the averaged reconstruction.
    err = (acc / accum_steps - x).pow(2).sum() / x.pow(2).sum()
    print(accum_steps, ":", err.item())
    

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

1 : 0.011094212532043457


  0%|          | 0/4 [00:00<?, ?it/s]

4 : 0.0027771450113505125


  0%|          | 0/16 [00:00<?, ?it/s]

16 : 0.0006977391894906759


  0%|          | 0/64 [00:00<?, ?it/s]

64 : 0.0001779412996256724


  0%|          | 0/256 [00:00<?, ?it/s]

256 : 4.796600842382759e-05


  0%|          | 0/1024 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
def sr_e4m3(x: torch.Tensor) -> torch.Tensor:
    if (x > 448.001).any():
        raise ValueError(f"Can't SR overflowing tensor: {x.max().item()} > 448")
    x = torch.clamp(x, -447.99, 447.99)
    
    if x.isnan().any():
        raise ValueError("x has NaNs")
    
    q = x.to(torch.float8_e4m3fn)
    nextdq = (q.view(torch.uint8) + 1).view(torch.float8_e4m3fn).float()
    prevdq = (q.view(torch.uint8) - 1).view(torch.float8_e4m3fn).float()
    dq = q.float()

    low = torch.where(
        dq > x,
        prevdq,
        dq,
    )
    
    high = torch.where(
        dq > x,
        dq,
        nextdq,
    )
    
    return torch.where(
        torch.bernoulli(
            (x - low) / (high - low)
        ) == 1.0,
        high,
        low,
    )

In [None]:
# Boundary probe: 128 copies of 448.001, the same constant sr_e4m3 uses as its
# overflow threshold. The call is the cell's last expression, so its result is
# what the notebook displays.
t = torch.ones((128,), device="cuda") * 448.001
sr_e4m3(t)

tensor([448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448., 448.,
        448., 448., 448., 448., 448., 448., 448., 448.], device='cuda:0')

In [None]:
def rtn_fp4(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Round every element of ``x`` to the nearest value in ``grid``.

    Ties (an element exactly half-way between two grid values) resolve to the
    larger grid value. Elements outside the grid range saturate to the first
    or last grid entry.

    Args:
        x: Tensor of values to round.
        grid: 1-D tensor of candidate values sorted in ascending order
            (duplicates are allowed). Any non-zero length is supported.

    Returns:
        Tensor with the shape of ``x`` whose entries are all taken from ``grid``.

    Raises:
        ValueError: If ``grid`` is empty.
    """
    if grid.numel() == 0:
        raise ValueError("grid must contain at least one value")

    # Highest valid index, derived from the grid instead of assuming 16 entries.
    last = grid.numel() - 1

    # bucketize returns i such that grid[i - 1] < x <= grid[i]; i ranges from 0
    # (x at or below the first entry) to len(grid) (x above the last entry).
    inds = torch.bucketize(x, grid)

    # Clamp both neighbour indices so out-of-range inputs saturate.
    lo = torch.clamp(inds - 1, min=0, max=last)
    hi = torch.clamp(inds, min=0, max=last)

    low = grid[lo]
    high = grid[hi]

    # "<=" sends exact ties to the upper neighbour.
    return torch.where(
        (high - x) <= (x - low),
        high,
        low,
    )
    
    
# Rounding grid passed to rtn_fp4: 16 entries in ascending order, symmetric
# around zero, with 0.0 listed twice (presumably the +0 / -0 encodings of a
# 4-bit float format — confirm).
grid = torch.tensor(
    [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0,
    0.0,  0.5,  1.0,  1.5,  2.0,  3.0,  4.0, 6.0],
    device="cuda",
)

In [None]:
# 1000 evenly spaced points covering the grid's range [-6, 6].
t = torch.linspace(-6, 6, 1000, device="cuda")

# Round-to-nearest onto the grid defined above.
q = rtn_fp4(t, grid)

# Ratio <t, q> / <t, t>: the scalar c for which c * t is the least-squares fit
# to q; a value of 1 means rounding left the overall magnitude unchanged.
(t @ q) / (t @ t)

tensor(1.0067, device='cuda:0')