In [2]:
import os
import torch
import math

from q8_gemm import q8_mm, q8_mm_bias
from fast_hadamard_transform import hadamard_transform

from safetensors.torch import load_file

def hadamard_quant(x):
    k = x.shape[-1]
    x_hadamard = hadamard_transform(x, scale=1/math.sqrt(k))
    x_abs_max_hadamard = x_hadamard.float().abs().max(-1, False).values
    x_scale_hadamard = x_abs_max_hadamard/127.0
    x_q8_hadamard = (x_hadamard.float() / x_scale_hadamard[..., None]).round().to(torch.int8)
    return x_q8_hadamard, x_scale_hadamard

def quant(x):
    x_abs_max = x.float().abs().max(-1, False).values
    x_scale = x_abs_max/127.0
    x_q8 =  (x.float() / x_scale[..., None]).round().to(torch.int8)
    return x_q8, x_scale


l_idx = 3
x = torch.load(f"/data/LTXVideo/acts/ffn/hs-{l_idx}.pt", map_location="cuda")[:, :, :]
model_weights = load_file("/data/ltx_weights/unet/unet_diffusion_pytorch_model.safetensors", device="cpu")
w = model_weights[f"transformer_blocks.{l_idx}.ff.net.0.proj.weight"].cuda()
bias = model_weights[f"transformer_blocks.{l_idx}.ff.net.0.proj.bias"].cuda().float()

In [3]:
k = x.shape[-1]
x_hadamard = hadamard_transform(x.to(torch.float8_e4m3fn), scale=1/math.sqrt(k))
w_hadamard = hadamard_transform(w.to(torch.float8_e4m3fn), scale=1/math.sqrt(k))

x_quant_h, x_scales_h = hadamard_quant(x.to(torch.float8_e4m3fn))
w_quant_h, w_scales_h = hadamard_quant(w.to(torch.float8_e4m3fn))

x_quant, x_scales = quant(x)
w_quant, w_scales = quant(w)

In [9]:
o_q8_h = q8_mm_bias(x_quant_h, w_quant_h, bias, x_scales_h, w_scales_h, False)
o_q8 = q8_mm_bias(x_quant, w_quant, bias, x_scales, w_scales, False)

o_q8_torch_h = ((x_scales_h[..., None] * w_scales_h[None, None, :]) * torch.matmul(x_quant_h.float(), w_quant_h.float().t()) + bias[None, None, :]).to(torch.float8_e4m3fn)
o_q8_torch = (((x_scales[..., None] * w_scales[None, None, :]) * torch.matmul(x_quant.float(), w_quant.float().t())) + bias[None, None, :]).to(torch.float8_e4m3fn)


o_orig = torch.matmul(x.half(), w.half().t()) + bias[None, None, :].half()

In [10]:
def diff_max(a, b):
    return (a.float() - b.float()).abs().max()

def diff_mean(a, b):
    return (a.float() - b.float()).abs().mean()


def diff_quantiles(a, b):
    if a.ndim > 2:
        return torch.quantile((a.float() - b.float()).abs()[1, :2048, :], torch.tensor([0.25, 0.5, 0.75, 0.9, 0.99, 1.0]).cuda())
    else:
        return torch.quantile((a.float() - b.float()).abs()[:2048, :], torch.tensor([0.25, 0.5, 0.75, 0.9, 0.99, 1.0]).cuda())
        
diff_q8_h = diff_max(o_q8_h, o_orig)
diff_q8 = diff_max(o_q8, o_orig)
diff_q8_torch_h = diff_max(o_q8_torch_h, o_orig)
diff_q8_torch = diff_max(o_q8_torch, o_orig)

In [11]:
print("DIFF Hadamard: ", diff_q8_h)
print("DIFF no Hadamard: ", diff_q8)

print("DIFF Hadamard Torch: ", diff_q8_torch_h)
print("DIFF no Hadamard Torch: ", diff_q8_torch)

print("Diff mean hadamard: ", diff_max(o_orig, o_q8_h))
print("Diff mean: ", diff_max(o_orig, o_q8))

print("torch q8: ", diff_max(o_q8_torch_h, o_q8_h))

DIFF Hadamard:  tensor(0.5938, device='cuda:0')
DIFF no Hadamard:  tensor(0.7031, device='cuda:0')
DIFF Hadamard Torch:  tensor(0.5938, device='cuda:0')
DIFF no Hadamard Torch:  tensor(0.7031, device='cuda:0')
Diff mean hadamard:  tensor(0.5938, device='cuda:0')
Diff mean:  tensor(0.7031, device='cuda:0')
torch q8:  tensor(0.5000, device='cuda:0')


In [20]:
print("torch q8: ", diff_quantiles(o_q8_torch_h, o_q8_h))

torch q8:  tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2500], device='cuda:0')


In [16]:
(o_q8_h.float() - o_q8_torch_h.float()).abs().argmax()

tensor(27261435, device='cuda:0')

In [17]:
o_q8_torch_h.flatten()[(o_q8_h.float() - o_q8_torch_h.float()).abs().argmax()]

tensor(-4.5000, device='cuda:0', dtype=torch.float8_e4m3fn)

In [18]:
o_q8_h.flatten()[(o_q8_h.float() - o_q8_torch_h.float()).abs().argmax()]

tensor(-5., device='cuda:0', dtype=torch.float8_e4m3fn)

In [19]:
o_orig.flatten()[(o_q8_h.float() - o_q8_torch_h.float()).abs().argmax()]

tensor(-4.6914, device='cuda:0', dtype=torch.float16)