In [1]:
# Restrict this kernel to one GPU by setting CUDA_VISIBLE_DEVICES; this cell
# runs ahead of the cell that imports torch.
# NOTE(review): the device index 9 is hard-coded, presumably for the authoring
# machine - adjust it for your host.
%env CUDA_VISIBLE_DEVICES=9

import sys
# Add the sibling `src` directory to the import path.
# NOTE(review): the later `models.quantization.quantizers` import presumably
# resolves from here - confirm the notebook is launched from its own folder,
# since the path is relative to the working directory.
sys.path.append("../src")

env: CUDA_VISIBLE_DEVICES=9


In [2]:
import torch
from models.quantization.quantizers import QuestMXFP4Quantizer, AlbertTsengQuantizer, EdenSRQuantizer

# Seed torch's global generators so the synthetic input below is repeatable
# across Restart & Run All.
# NOTE(review): the "sr"/"eden" modes are presumably stochastic; they become
# repeatable too only if the quantizer draws from torch's default generators -
# confirm against EdenSRQuantizer.
SEED = 0
torch.manual_seed(SEED)

# Synthetic input: 2**20 rows of 128 standard-normal samples. Row i is scaled by
# a factor spaced log-uniformly from 2**0 to 2**1, and the result is flattened
# to a 1-D tensor.
x = (torch.randn(2**20, 128, device="cuda") * torch.logspace(0, 1, 2**20, base=2, device="cuda").unsqueeze(1)).flatten()

# (scale dtype, hadamard_dim) configurations to benchmark.
scale_dtype_hadamard_dim = [("fp32", 128), ("e4m3", 16), ("e8m0", 32)]

# One tuple per configuration:
# (scale_dtype, hadamard_dim, unbiased, eff_bitwidth, 1 - magnitude_alignment)
table_rows = []

for (scale_dtype, hadamard_dim) in scale_dtype_hadamard_dim:
    for unbiased in ["no", "sr", "eden"]:
        quantizer = EdenSRQuantizer(hadamard_dim=hadamard_dim, scale_dtype=scale_dtype, unbiased=unbiased)
        # Multiply the quantizer output, grouped into rows of `hadamard_dim`
        # values, by the quantizer's Hadamard matrix and restore x's shape.
        # NOTE(review): presumably this undoes the rotation applied inside the
        # quantizer so that `dq` is comparable to `x` - confirm.
        dq = (quantizer(x).view(-1, quantizer.hadamard_dim) @ quantizer.hadamard_matrix).view_as(x)
        # Relative mean squared error of the reconstruction.
        quad_err = (x - dq).pow(2).mean() / x.pow(2).mean()
        # Express the error in bits: -log2(relative MSE) / 2.
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()
        # Projection coefficient <x, dq> / <x, x>; 1.0 means no magnitude bias.
        magnitude_alignment = ((x @ dq) / (x @ x)).item()

        table_rows.append(
            (scale_dtype, hadamard_dim, unbiased, eff_bitwidth, 1 - magnitude_alignment)
        )

# Print markdown table
print("| Scales DType | Group Size | Unbiased | MSE, rate-distortion bits | Magnitude Misalignment |")
print("|--------------|------------|----------|---------------------------|------------------------|")
for row in table_rows:
    print(f"| {str(row[0]).upper():<12} | {str(row[1]):<10} | {str(row[2]).upper():<8} | {row[3]:>25.2f} | {row[4]:>22.4f} |")



| Scales DType | Group Size | Unbiased | MSE, rate-distortion bits | Magnitude Misalignment |
|--------------|------------|----------|---------------------------|------------------------|
| FP32         | 128        | NO       |                      3.20 |                 0.0084 |
| FP32         | 128        | SR       |                      2.69 |                 0.0000 |
| FP32         | 128        | EDEN     |                      3.20 |                 0.0000 |
| E4M3         | 16         | NO       |                      3.27 |                -0.0121 |
| E4M3         | 16         | SR       |                      2.70 |                 0.0000 |
| E4M3         | 16         | EDEN     |                      3.27 |                -0.0000 |
| E8M0         | 32         | NO       |                      3.09 |                 0.0093 |
| E8M0         | 32         | SR       |                      2.58 |                -0.0000 |
| E8M0         | 32         | EDEN     |                    

In [3]:
def rtn_fp4(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Round every element of ``x`` to the nearest value in ``grid``.

    Args:
        x: Tensor of values to quantize.
        grid: 1-D tensor of quantization levels sorted in ascending order
            (``torch.bucketize`` expects sorted boundaries). It may hold any
            number of levels; the bounds are derived from its length.

    Returns:
        A tensor shaped like ``x`` whose entries are all taken from ``grid``.
        Values outside the grid range saturate to the first or last level.
        When a value is exactly halfway between two neighbouring levels, the
        higher level is chosen.
    """
    # Upper index bound taken from the grid itself rather than a fixed 15, so
    # grids with fewer or more than 16 levels are handled correctly.
    last = grid.numel() - 1

    # inds[i] is the first position whose grid value is >= x[i]; it equals
    # len(grid) when x[i] lies above every level.
    inds = torch.bucketize(x, grid)

    # Neighbouring levels on either side of x, clamped so that out-of-range
    # inputs see the same edge level on both sides.
    lo = torch.clamp(inds - 1, min=0, max=last)
    hi = torch.clamp(inds, min=0, max=last)

    low = grid[lo]
    high = grid[hi]

    # Pick the closer neighbour; `<=` sends exact ties to the higher level.
    return torch.where(
        (high - x) <= (x - low),
        high,
        low,
    )
    
    
# Quantization levels passed to `rtn_fp4`: 16 entries in ascending order
# (torch.bucketize expects sorted boundaries). 0.0 appears twice, so there are
# 15 distinct levels, symmetric around zero.
# NOTE(review): presumably the FP4 (E2M1) value set with -0 and +0 both stored
# as 0.0 - confirm.
grid = torch.tensor(
    [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0,
    0.0,  0.5,  1.0,  1.5,  2.0,  3.0,  4.0, 6.0],
    device="cuda",
)

In [12]:
# Sanity check on a uniform sweep: 1000 evenly spaced points over [-6, 6] are
# rounded onto `grid`, and the projection coefficient <t, q> / <t, t> is shown
# as the cell output (1.0 would mean rounding introduces no magnitude bias).
t = torch.linspace(start=-6, end=6, steps=1000, device="cuda")
q = rtn_fp4(x=t, grid=grid)

torch.dot(t, q) / torch.dot(t, t)

tensor(1.0067, device='cuda:0')