In [1]:
%env CUDA_VISIBLE_DEVICES=9
%env CUDA_LAUNCH_BLOCKING=1

import sys
sys.path.append("../src")

import torch
from models.quantization.quantizers import QuestMXFP4Quantizer, AlbertTsengQuantizer, EdenSRQuantizer

from tqdm.auto import trange, tqdm

env: CUDA_VISIBLE_DEVICES=9
env: CUDA_LAUNCH_BLOCKING=1


In [2]:
# Synthetic input: 2**20 rows of 128 standard-normal samples. Each row is multiplied by
# its own factor, log-spaced (base 2) from 2**0 to 2**8, and the result is flattened to 1-D.
row_scales = torch.logspace(0, 8, 2**20, base=2, device="cuda").unsqueeze(1)
x = (torch.randn(2**20, 128, device="cuda") * row_scales).flatten()

# (scale dtype, group size) pairs to benchmark.
scale_dtype_group = [("fp32", 128), ("e4m3", 16), ("e8m0", 32)]
# Per-dtype `scale_override` passed to the quantizer when unbiased == "eden".
optimal_scale_override = dict(fp32=0.96, e4m3=0.93, e8m0=0.91)

table_rows = []  # one tuple per configuration, in evaluation order (used for printing)
data = {}        # (group_dim, scale_dtype, unbiased) -> (eff_bitwidth, misalignment)

for scale_dtype, group_dim in scale_dtype_group:
    for unbiased in ("no", "sr", "eden"):
        # Only the "eden" mode gets a tuned override; the other modes use 1.0.
        scale_override = optimal_scale_override[scale_dtype] if unbiased == "eden" else 1.0

        quantizer = EdenSRQuantizer(
            hadamard_dim=128,
            group_dim=group_dim,
            scale_dtype=scale_dtype,
            unbiased=unbiased,
            scale_override=scale_override,
        )

        # Reference rows and the quantizer output multiplied by the quantizer's Hadamard
        # matrix (presumably mapping back out of the rotated domain — confirm against
        # EdenSRQuantizer).
        ref = x.view(-1, quantizer.hadamard_dim)
        dq = quantizer(x).view(-1, quantizer.hadamard_dim) @ quantizer.hadamard_matrix

        # Mean over rows of the per-row relative squared error, reported as -log2(err) / 2.
        row_rel_err = (ref - dq).pow(2).sum(dim=-1) / ref.pow(2).sum(dim=-1)
        quad_err = row_rel_err.mean()
        eff_bitwidth = (-torch.log2(quad_err) / 2).item()

        # Mean over rows of |<ref, dq> / <ref, ref>|; its offset from 1 is the misalignment.
        row_projection = (ref * dq).sum(dim=-1) / (ref * ref).sum(dim=-1)
        magnitude_alignment = row_projection.abs().mean().item()
        misalignment = magnitude_alignment - 1

        data[(group_dim, scale_dtype, unbiased)] = (eff_bitwidth, misalignment)
        table_rows.append((scale_dtype, group_dim, unbiased, eff_bitwidth, misalignment))

# Print markdown table
print("| Scales DType | Group Size | Unbiased | MSE, rate-distortion bits | Magnitude Misalignment |")
print("|--------------|------------|----------|---------------------------|------------------------|")
for dtype_name, group_size, mode, bits, offset in table_rows:
    print(f"| {dtype_name.upper():<12} | {str(group_size):<10} | {mode.upper():<8} | {bits:>25.2f} | {offset:>22.5f} |")


| Scales DType | Group Size | Unbiased | MSE, rate-distortion bits | Magnitude Misalignment |
|--------------|------------|----------|---------------------------|------------------------|
| FP32         | 128        | NO       |                      3.20 |               -0.00836 |
| FP32         | 128        | SR       |                      2.69 |               -0.00001 |
| FP32         | 128        | EDEN     |                      3.22 |                0.00000 |
| E4M3         | 16         | NO       |                      3.28 |                0.01119 |
| E4M3         | 16         | SR       |                      2.70 |               -0.00000 |
| E4M3         | 16         | EDEN     |                      3.34 |               -0.00054 |
| E8M0         | 32         | NO       |                      3.10 |               -0.00946 |
| E8M0         | 32         | SR       |                      2.59 |                0.00001 |
| E8M0         | 32         | EDEN     |                    

In [3]:
# Disabled sweep, kept for reference (uncomment to re-run). It tries 50 evenly spaced
# `scale_override` values in [0.8, 1.0] for the fp32 / group_dim=128 / "eden" configuration
# and prints, for each value, the effective bit-width and the magnitude misalignment,
# computed with the same formulas as the benchmark cell above.
# NOTE(review): relies on `x` from an earlier cell and on numpy; presumably this is how the
# "optimal" overrides noted below were picked — confirm.

# import numpy as np

# for scale_override in np.linspace(0.8, 1.0, 50):
#     quantizer = EdenSRQuantizer(hadamard_dim=128, group_dim=128, scale_dtype="fp32", unbiased="eden", scale_override=scale_override)
    
#     dq = quantizer(x).view(-1, quantizer.hadamard_dim) @ quantizer.hadamard_matrix
#     ref = x.view(-1, quantizer.hadamard_dim)
#     quad_err = ((ref - dq).pow(2).sum(dim=-1) / ref.pow(2).sum(dim=-1)).mean()
#     eff_bitwidth = (-torch.log2(quad_err) / 2).item()
#     magnitude_alignment = ((ref * dq).sum(dim=-1) / (ref * ref).sum(dim=-1)).abs().mean().item()
    
#     print(f"{scale_override:.2f}: {eff_bitwidth:.3f}, {magnitude_alignment-1:.5f}")
    

Optimal `scale_override` per scale dtype (the values used in `optimal_scale_override` above):

 - `fp32`: 0.96
 - `e8m0`: 0.91
 - `e4m3`: 0.93

In [4]:
# Same construction as the earlier benchmark input, but the per-row factors are log-spaced
# (base 2) from 2**0 to 2**20.
x = (
    torch.randn(2**20, 128, device="cuda")
    * torch.logspace(0, 20, 2**20, base=2, device="cuda").unsqueeze(1)
).flatten()

quantizer = EdenSRQuantizer(
    hadamard_dim=128,
    group_dim=16,
    scale_dtype="e4m3",
    unbiased="eden",
    rerotate="signs",
    scale_override=0.93,
)


def _fresh_reconstruction():
    """Re-randomize the quantizer, quantize `x`, and return the result shaped like `x`.

    The quantizer output is viewed as rows whose width is the Hadamard matrix's first
    dimension and multiplied by that matrix (presumably undoing the rotation — confirm
    against EdenSRQuantizer).
    """
    quantizer.re_randomize()
    rows = quantizer(x).view(-1, quantizer.hadamard_matrix.shape[0])
    return (rows @ quantizer.hadamard_matrix).view_as(x)


# For each step count, average that many independently re-randomized reconstructions and
# report the squared error of the average relative to the squared norm of `x`.
accum_schedule = [1, 4, 16, 64, 256, 1024]
for accum_steps in tqdm(accum_schedule):
    acc = torch.zeros_like(x)

    for _ in trange(accum_steps, leave=False):
        acc += _fresh_reconstruction()

    residual_energy = (acc / accum_steps - x).pow(2).sum()
    signal_energy = x.pow(2).sum()
    err = residual_energy / signal_energy
    print(accum_steps, ":", err.item())
    

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

1 : 0.009754621423780918


  0%|          | 0/4 [00:00<?, ?it/s]

4 : 0.0024393717758357525


  0%|          | 0/16 [00:00<?, ?it/s]

16 : 0.0006100585451349616


  0%|          | 0/64 [00:00<?, ?it/s]

64 : 0.00015281731612049043


  0%|          | 0/256 [00:00<?, ?it/s]

256 : 3.8397196476580575e-05


  0%|          | 0/1024 [00:00<?, ?it/s]

1024 : 9.84456801234046e-06
