In [1]:
import os
import torch
import math

In [2]:
from q8_matmul.gemm._C import q8_mm
from q8_matmul.quantizer._C import tokenwise_quant
from q8_matmul.ops._C import rms_norm

from fast_hadamard_transform import hadamard_transform

In [3]:

from fast_hadamard_transform import hadamard_transform

from safetensors.torch import load_file

def hadamard_quant(x):
    k = x.shape[-1]
    x_hadamard = hadamard_transform(x, scale=1/math.sqrt(k))
    x_abs_max_hadamard = x_hadamard.float().abs().max(-1, False).values
    x_scale_hadamard = x_abs_max_hadamard/127.0
    x_q8_hadamard = (x_hadamard.float() / x_scale_hadamard[..., None]).round().to(torch.int8)
    return x_q8_hadamard, x_scale_hadamard

def quant(x):
    x_abs_max = x.float().abs().max(-1, False).values
    x_scale = x_abs_max/127.0
    x_q8 =  (x.float() / x_scale[..., None]).round().to(torch.int8)
    return x_q8, x_scale


l_idx = 3
x = torch.load(f"/data/LTXVideo/acts/ffn/hs-{l_idx}.pt", map_location="cuda")[:, :, :]
model_weights = load_file("/data/ltx_weights/unet/unet_diffusion_pytorch_model.safetensors", device="cpu")
w = model_weights[f"transformer_blocks.{l_idx}.ff.net.0.proj.weight"].cuda()

In [4]:
model_weights["transformer_blocks.15.ff.net.0.proj.bias"].shape

torch.Size([8192])

In [5]:
x.shape

torch.Size([2, 3795, 2048])

In [6]:
for k in model_weights:
    if "norm" in k:
        print(k)

transformer_blocks.0.attn1.k_norm.weight
transformer_blocks.0.attn1.q_norm.weight
transformer_blocks.0.attn2.k_norm.weight
transformer_blocks.0.attn2.q_norm.weight
transformer_blocks.1.attn1.k_norm.weight
transformer_blocks.1.attn1.q_norm.weight
transformer_blocks.1.attn2.k_norm.weight
transformer_blocks.1.attn2.q_norm.weight
transformer_blocks.10.attn1.k_norm.weight
transformer_blocks.10.attn1.q_norm.weight
transformer_blocks.10.attn2.k_norm.weight
transformer_blocks.10.attn2.q_norm.weight
transformer_blocks.11.attn1.k_norm.weight
transformer_blocks.11.attn1.q_norm.weight
transformer_blocks.11.attn2.k_norm.weight
transformer_blocks.11.attn2.q_norm.weight
transformer_blocks.12.attn1.k_norm.weight
transformer_blocks.12.attn1.q_norm.weight
transformer_blocks.12.attn2.k_norm.weight
transformer_blocks.12.attn2.q_norm.weight
transformer_blocks.13.attn1.k_norm.weight
transformer_blocks.13.attn1.q_norm.weight
transformer_blocks.13.attn2.k_norm.weight
transformer_blocks.13.attn2.q_norm.weight


In [7]:
norm_weights = model_weights["transformer_blocks.0.attn2.k_norm.weight"].cuda().float()

In [8]:
norm_weights.shape

torch.Size([2048])

In [9]:
norm_weights

tensor([0.0258, 0.0448, 0.0280,  ..., 0.0207, 0.0159, 0.0105], device='cuda:0')

In [10]:
x.shape

torch.Size([2, 3795, 2048])

In [55]:
x_fp8 = x.to(torch.float8_e4m3fn)
s = torch.cuda.Event(True)
e = torch.cuda.Event(True)
s.record()
x_normed = rms_norm(x_fp8, norm_weights)
e.record()
torch.cuda.synchronize()
print(s.elapsed_time(e))

2.309119939804077


In [56]:
x.dtype

torch.bfloat16

In [57]:
norm_weights.dtype

torch.float32

In [58]:
norm_weights_fp16 = norm_weights.to(torch.bfloat16)

In [77]:
x_normed_torch = (norm_weights_fp16 * (x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True)).to(torch.bfloat16)))
torch.cuda.synchronize()
s = torch.cuda.Event(True)
e = torch.cuda.Event(True)
s.record()
x_normed_torch = (norm_weights_fp16 * (x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True)).to(torch.bfloat16)))
e.record()
torch.cuda.synchronize()
print(s.elapsed_time(e))

19.932159423828125


In [26]:
x_normed_torch.dtype

torch.float32

In [12]:
x_fp8.shape

torch.Size([2, 3795, 2048])

In [13]:
torch.cuda.synchronize()

In [14]:
x_normed

tensor([[[ 0.0059, -0.0117,  0.0039,  ...,  0.0176, -0.0039,  0.0117],
         [ 0.0117, -0.0293,  0.0000,  ...,  0.0176, -0.0078,  0.0078],
         [ 0.0156, -0.0352, -0.0020,  ...,  0.0215,  0.0039,  0.0078],
         ...,
         [-0.0195, -0.0234,  0.0117,  ...,  0.0020, -0.0020, -0.0020],
         [-0.0117, -0.0215,  0.0039,  ...,  0.0059, -0.0020,  0.0059],
         [-0.0078, -0.0430,  0.0020,  ...,  0.0059, -0.0059,  0.0039]],

        [[ 0.0020, -0.0195,  0.0020,  ...,  0.0195, -0.0039,  0.0098],
         [ 0.0098, -0.0352, -0.0039,  ...,  0.0215, -0.0078,  0.0078],
         [ 0.0117, -0.0430, -0.0059,  ...,  0.0254,  0.0020,  0.0078],
         ...,
         [-0.0195, -0.0137,  0.0098,  ...,  0.0039, -0.0020, -0.0020],
         [-0.0117, -0.0156,  0.0059,  ...,  0.0059, -0.0039,  0.0039],
         [-0.0059, -0.0391,  0.0039,  ...,  0.0078, -0.0059,  0.0020]]],
       device='cuda:0', dtype=torch.float8_e4m3fn)

In [15]:
x_normed_torch = (norm_weights * x.to(torch.float8_e4m3fn).float() * torch.rsqrt(x_fp8.float().pow(2).mean(-1, keepdim=True))).to(torch.float8_e4m3fn)

In [16]:
(x_normed_torch.float() - x_normed.float()).abs().max()

tensor(0.0156, device='cuda:0')

In [8]:
x_normed

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       dtype=torch.float8_e4m3fn)

In [5]:
for k in model_weights:
    if "transformer_blocks.0" in k:
        print(k)

transformer_blocks.0.attn1.k_norm.weight
transformer_blocks.0.attn1.q_norm.weight
transformer_blocks.0.attn1.to_k.bias
transformer_blocks.0.attn1.to_k.weight
transformer_blocks.0.attn1.to_out.0.bias
transformer_blocks.0.attn1.to_out.0.weight
transformer_blocks.0.attn1.to_q.bias
transformer_blocks.0.attn1.to_q.weight
transformer_blocks.0.attn1.to_v.bias
transformer_blocks.0.attn1.to_v.weight
transformer_blocks.0.attn2.k_norm.weight
transformer_blocks.0.attn2.q_norm.weight
transformer_blocks.0.attn2.to_k.bias
transformer_blocks.0.attn2.to_k.weight
transformer_blocks.0.attn2.to_out.0.bias
transformer_blocks.0.attn2.to_out.0.weight
transformer_blocks.0.attn2.to_q.bias
transformer_blocks.0.attn2.to_q.weight
transformer_blocks.0.attn2.to_v.bias
transformer_blocks.0.attn2.to_v.weight
transformer_blocks.0.ff.net.0.proj.bias
transformer_blocks.0.ff.net.0.proj.weight
transformer_blocks.0.ff.net.2.bias
transformer_blocks.0.ff.net.2.weight
transformer_blocks.0.scale_shift_table


In [14]:
k = x.shape[-1]
x_hadamard = hadamard_transform(x.to(torch.float8_e4m3fn), scale=1/math.sqrt(k))
w_hadamard = hadamard_transform(w.to(torch.float8_e4m3fn), scale=1/math.sqrt(k))

x_quant_h, x_scales_h = hadamard_quant(x.to(torch.float8_e4m3fn))
w_quant_h, w_scales_h = hadamard_quant(w.to(torch.float8_e4m3fn))

x_quant, x_scales = quant(x.to(torch.float8_e4m3fn))
w_quant, w_scales = quant(w.to(torch.float8_e4m3fn))


x_quant_h, x_scales_h = tokenwise_quant(x_hadamard)
w_quant_h, w_scales_h = tokenwise_quant(w_hadamard)

x_quant, x_scales = tokenwise_quant(x.to(torch.float8_e4m3fn))
w_quant, w_scales = tokenwise_quant(w.to(torch.float8_e4m3fn))

In [7]:
o_q8_h = q8_mm(x_quant_h[0].contiguous(), w_quant_h, x_scales_h, w_scales_h, False)

In [8]:
def print_scales(scales_tensor, ix, idx, type="A"):
    block = scales_tensor[ix*128:(ix+1)*128]
    warp = idx // 32
    if type == "A":
        row = warp % 2 * 16 + idx // 4
        arr = []
        for i in range(4*4):
            v = i % 4
            v_row = v // 2 * 8
            mma_m = i // 4
            arr.append(block[row+v_row + mma_m*32].item())
        print(arr)
    else:
        col = warp // 2 * 8 + (idx % 4) * 2
        arr = []
        for i in range(4*8):
            v = i % 4
            v_col = v % 2
            mma_n = i // 4
            arr.append(block[col+v_col + mma_n*16].item())
        print(arr)
# print_scales(w_scales_h, 8, 2, "B")

In [9]:
o_q8_h.shape

torch.Size([1, 3795, 8192])

In [10]:
o_q8_h = q8_mm(x_quant_h, w_quant_h, x_scales_h, w_scales_h, False)
o_q8 = q8_mm(x_quant, w_quant, x_scales, w_scales, False)

o_q8_torch_h = ((x_scales_h[..., None] * w_scales_h[None, None, :]) * torch.matmul(x_quant_h.float(), w_quant_h.float().t())).to(torch.float8_e4m3fn)
o_q8_torch = ((x_scales[..., None] * w_scales[None, None, :]) * torch.matmul(x_quant.float(), w_quant.float().t())).to(torch.float8_e4m3fn)


o_orig = torch.matmul(x.half(), w.half().t())# nn.functional.linear(x, w, bias=None)

In [11]:
def diff_max(a, b):
    return (a.float() - b.float()).abs().max()

def diff_mean(a, b):
    return (a.float() - b.float()).abs().mean()


def diff_quantiles(a, b):
    if a.ndim > 2:
        return torch.quantile((a.float() - b.float()).abs()[1, :2048, :], torch.tensor([0.25, 0.5, 0.75, 0.9, 0.99, 1.0]).cuda())
    else:
        return torch.quantile((a.float() - b.float()).abs()[:2048, :], torch.tensor([0.25, 0.5, 0.75, 0.9, 0.99, 1.0]).cuda())
        
diff_q8_h = diff_max(o_q8_h, o_orig)
diff_q8 = diff_max(o_q8, o_orig)
diff_q8_torch_h = diff_max(o_q8_torch_h, o_orig)
diff_q8_torch = diff_max(o_q8_torch, o_orig)

In [12]:
o_fp8_orig = torch._scaled_mm(x[1].to(torch.float8_e4m3fn),  w.to(torch.float8_e4m3fn).contiguous().t(), scale_a=torch.tensor([1.0]).cuda(), scale_b=torch.tensor([1.0]).cuda())

In [13]:
print("DIFF Hadamard: ", diff_q8_h)
print("DIFF no Hadamard: ", diff_q8)

print("DIFF Hadamard Torch: ", diff_q8_torch_h)
print("DIFF no Hadamard Torch: ", diff_q8_torch)

print("Diff mean hadamard: ", diff_max(o_orig, o_q8_h))
print("Diff mean: ", diff_max(o_orig, o_q8))

print("torch q8: ", diff_max(o_q8_torch_h, o_q8_h))
print("torch fp8: ", diff_max(o_fp8_orig, o_orig[1]))

print("diff torch fp8 q8: ", diff_max(o_fp8_orig, o_q8_h[1]))


DIFF Hadamard:  tensor(0.5859, device='cuda:0')
DIFF no Hadamard:  tensor(0.7031, device='cuda:0')
DIFF Hadamard Torch:  tensor(0.5859, device='cuda:0')
DIFF no Hadamard Torch:  tensor(0.7031, device='cuda:0')
Diff mean hadamard:  tensor(0.5859, device='cuda:0')
Diff mean:  tensor(0.7031, device='cuda:0')
torch q8:  tensor(0., device='cuda:0')
torch fp8:  tensor(0.4609, device='cuda:0')
diff torch fp8 q8:  tensor(1., device='cuda:0')


In [None]:
# DIFF Hadamard:  tensor(0.5859, device='cuda:0')
# DIFF no Hadamard:  tensor(0.7031, device='cuda:0')
# DIFF Hadamard Torch:  tensor(0.5859, device='cuda:0')
# DIFF no Hadamard Torch:  tensor(0.7031, device='cuda:0')
# Diff mean hadamard:  tensor(0.5859, device='cuda:0')
# Diff mean:  tensor(0.7031, device='cuda:0')
# torch q8:  tensor(0., device='cuda:0')
# torch fp8:  tensor(0.4609, device='cuda:0')
# diff torch fp8 q8:  tensor(1., device='cuda:0')

In [33]:
o_q8_torch_h.shape

torch.Size([2, 3795, 8192])

In [34]:
o_fp8_orig.shape

torch.Size([3795, 8192])

In [35]:
((o_q8_h[1].float() * o_orig[1].float()) < 0).sum()/o_orig[1].numel()

tensor(0.0029, device='cuda:0')

In [36]:
((o_q8.float() * o_orig.float()) < 0).sum()/o_orig.numel()

tensor(0.0044, device='cuda:0')

In [37]:
((o_fp8_orig.float() * o_orig[1].float()) < 0).sum()/o_orig.numel()

tensor(0.0009, device='cuda:0')

In [42]:
o_orig[(o_q8_h.float() * o_orig.float()) < 0].abs().max()

tensor(0.2113, device='cuda:0', dtype=torch.float16)

In [41]:
o_q8_h.float()[(o_q8_h.float() * o_orig.float()) < 0].abs().max()

tensor(0.2188, device='cuda:0')

In [43]:
o_orig[1][(o_fp8_orig.float() * o_orig[1].float()) < 0].abs().max()

tensor(0.1755, device='cuda:0', dtype=torch.float16)

In [21]:
diff_quantiles(o_orig, o_q8_h)

tensor([0.0210, 0.0454, 0.0801, 0.1113, 0.1914, 0.4102], device='cuda:0')

In [22]:
diff_quantiles(o_orig[1], o_fp8_orig)

tensor([0.0137, 0.0381, 0.0762, 0.1074, 0.1875, 0.3594], device='cuda:0')

In [23]:
batch = x.shape[0]
m = x.shape[1]
n = w.shape[0]
k = x.shape[-1]
int8_tflops = []
TFLOPS = 2 * batch * m *n *k 

In [40]:

torch.cuda.synchronize()
N_ROUNDS = 10
N_OUTER_ROUNDS = 1

for _ in range(5):
    o_q8_h = q8_mm(x_quant_h, w_quant_h, x_scales_h, w_scales_h, True)

torch.cuda.synchronize()

for _ in range(N_OUTER_ROUNDS):
    start_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    end_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    for i in range(N_ROUNDS):
        start_events[i].record()
        o_q8_h = q8_mm(x_quant_h, w_quant_h, x_scales_h, w_scales_h, True)
        end_events[i].record()
    torch.cuda.synchronize()
    elapsed_times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    int8_tflops.append((TFLOPS * 1e-12)/(min(elapsed_times) * 1e-3))
print(max(int8_tflops))

109.99504667097293


In [35]:
fp8_tflops = []

In [36]:

torch.cuda.synchronize()
N_ROUNDS = 10
N_OUTER_ROUNDS = 1
_w = w_hadamard.contiguous().t()
_scale = torch.tensor([1.0]).cuda()
for _ in range(5):
    o_fp8_orig = torch._scaled_mm(x_hadamard[0],  _w, scale_a=_scale, scale_b=_scale)

torch.cuda.synchronize()

for _ in range(N_OUTER_ROUNDS):
    start_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    end_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    for i in range(N_ROUNDS):
        start_events[i].record()
        o_fp8_orig = torch._scaled_mm(x_hadamard[0],  _w, scale_a=torch.tensor([1.0]).cuda(), scale_b=torch.tensor([1.0]).cuda())
        end_events[i].record()
    torch.cuda.synchronize()
    elapsed_times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    fp8_tflops.append((TFLOPS/2 * 1e-12)/(min(elapsed_times) * 1e-3))
print(max(fp8_tflops))

58.16396682309988


In [39]:
print("INT8/FP8: ", max(int8_tflops)/max(fp8_tflops))

INT8/FP8:  1.8911200985571068


In [38]:
fp16_tflops = []

In [92]:
# for _ in range(10):
#     o = flash_attention_int8(q_quant, k_quant, v_quant, q_scales, k_scales, v_scales)
# flash_attention_int8_4stages
torch.cuda.synchronize()
N_ROUNDS = 10
N_OUTER_ROUNDS = 1

a_half = x.half()
b_half = w.half().t().contiguous()

for _ in range(5):
    o_half = torch.matmul(a_half, b_half)
torch.cuda.synchronize()

for _ in range(N_OUTER_ROUNDS):
    start_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    end_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    for i in range(N_ROUNDS):
        start_events[i].record()
        o_half = torch.matmul(a_half, b_half)
        end_events[i].record()
    torch.cuda.synchronize()
    elapsed_times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    fp16_tflops.append((TFLOPS * 1e-12)/(min(elapsed_times) * 1e-3))
print(max(fp16_tflops))

22.924014044917687


In [1]:
import torch
import math
from attention_cutlass_fp8 import flash_attention_fp8_even, flash_attention_fp8
from fast_hadamard_transform import hadamard_transform

l = 12
q = torch.load(f"/data/LTXVideo/acts/attn/q-{l}.pt")[:, :, :].cuda().contiguous().half()
k = torch.load(f"/data/LTXVideo/acts/attn/k-{l}.pt")[:, :, :].cuda().contiguous().half()
v = torch.load(f"/data/LTXVideo/acts/attn/v-{l}.pt")[:, :, :].cuda().contiguous().half()

In [2]:
q.shape

torch.Size([2, 32, 3795, 64])

In [3]:
head_dim = q.shape[-1]
sm_scale = 1/math.sqrt(head_dim)
sm_scale_fp8 = sm_scale*1.44269504

q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.transpose(2, 3).contiguous().to(torch.float8_e4m3fn)

q_hadamard = hadamard_transform(q, scale=1/math.sqrt(head_dim)).to(torch.float8_e4m3fn)
k_hadamard = hadamard_transform(k, scale=1/math.sqrt(head_dim)).to(torch.float8_e4m3fn)

In [4]:
torch.cuda.synchronize()
s = torch.cuda.Event(True)
e = torch.cuda.Event(True)
v_tokens = v.shape[-2]
v_tokens_pad = ((v_tokens + 15)//16)*16 - v_tokens
s.record()
v_fp8_padded = torch.nn.functional.pad(v_fp8, (0, v_tokens_pad))
e.record()
torch.cuda.synchronize()
print(s.elapsed_time(e))

10.720255851745605


In [5]:
TFLOPS_PER_ATTN = 4*q.shape[0]*q.shape[1]*q.shape[2]*q.shape[2]*q.shape[3]
int8_tflops = []
fp16_tflops = []

In [6]:
# for _ in range(10):
#     o = flash_attention_int8(q_quant, k_quant, v_quant, q_scales, k_scales, v_scales)
# flash_attention_int8_4stages
torch.cuda.synchronize()
N_ROUNDS = 10
N_OUTER_ROUNDS = 1
sm_scale_fp8 =sm_scale*1.44269504
for _ in range(5):
    o = flash_attention_fp8_even(q_fp8, k_fp8, v_fp8_padded, sm_scale_fp8)
torch.cuda.synchronize()

for _ in range(N_OUTER_ROUNDS):
    start_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    end_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    for i in range(N_ROUNDS):
        start_events[i].record()
        v_fp8_padded = torch.nn.functional.pad(v_fp8, (0, v_tokens_pad))
        o = flash_attention_fp8_even(q_fp8, k_fp8, v_fp8_padded, sm_scale_fp8)
        end_events[i].record()
    torch.cuda.synchronize()
    elapsed_times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    int8_tflops.append((TFLOPS_PER_ATTN * 1e-12)/(min(elapsed_times) * 1e-3))
print(max(int8_tflops))

36.8809854556419


In [7]:
min(elapsed_times)

6.397952079772949

In [8]:
# for _ in range(10):
#     o = flash_attention_int8(q_quant, k_quant, v_quant, q_scales, k_scales, v_scales)
# flash_attention_int8_4stages
torch.cuda.synchronize()
N_ROUNDS = 10
N_OUTER_ROUNDS = 1
for _ in range(5):
    o_half = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale)

torch.cuda.synchronize()

for _ in range(N_OUTER_ROUNDS):
    start_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    end_events = [ torch.cuda.Event(True) for _ in range(N_ROUNDS)]
    for i in range(N_ROUNDS):
        start_events[i].record()
        
        o_half = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale)

        end_events[i].record()
    torch.cuda.synchronize()
    elapsed_times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    fp16_tflops.append((TFLOPS_PER_ATTN * 1e-12)/(min(elapsed_times) * 1e-3))
print(max(fp16_tflops))

28.043372120269883


In [9]:
v_fp8.shape

torch.Size([2, 32, 64, 3795])

In [10]:
v_fp8.stride()

(7772160, 242880, 3795, 1)

In [11]:
o_fp8 = flash_attention_fp8_even(q_fp8, k_fp8, v_fp8_padded, sm_scale_fp8)
o_fp8_h = flash_attention_fp8_even(q_hadamard, k_hadamard, v_fp8_padded, sm_scale_fp8)
o_half = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale)
o_half_ref = torch.nn.functional.scaled_dot_product_attention(q.to(torch.float8_e4m3fn).half(), k.to(torch.float8_e4m3fn).half(), v.to(torch.float8_e4m3fn).half(), scale=sm_scale).to(torch.float8_e4m3fn)

In [12]:

def diff_max(a, b):
    return (a.float() - b.float()).abs().max()

def diff_quantiles(a, b):
    return torch.quantile((a.float() - b.float()).abs()[1, :, :], torch.tensor([0.25, 0.5, 0.75, 0.9, 0.99, 1.0]).cuda())

def diff_rms(a, b):
    return torch.sqrt(((a.float() - b.float()).square().sum()/a.numel()))

def cos_sim(a, b):
    a = a.float()
    b = b.float()
    a_len = a.norm(dim=-1, p=2)
    b_len = b.norm(dim=-1, p=2)
    dot_prod = (a * b).sum(dim=-1)
    return dot_prod/(a_len*b_len)


diff_fp8 = diff_max(o_fp8, o_half)
diff_h = diff_max(o_fp8_h, o_half)
diff_ideal = diff_max(o_half, o_half.to(torch.float8_e4m3fn))

In [13]:
diff_rms(o_half, o_fp8)

tensor(0.0153, device='cuda:0')

In [14]:
diff_rms(o_half, o_fp8_h)

tensor(0.0139, device='cuda:0')

In [15]:
diff_rms(o_half, o_half_ref)

tensor(0.0134, device='cuda:0')

In [16]:
diff_quantiles(o_half_ref, o_fp8_h)

tensor([0.0000, 0.0000, 0.0078, 0.0312, 0.0625, 0.8125], device='cuda:0')

In [17]:
diff_quantiles(o_half_ref, o_fp8)

tensor([0.0000, 0.0000, 0.0000, 0.0156, 0.0625, 0.5000], device='cuda:0')

In [23]:
diff_quantiles(o_half, o_half_ref)

tensor([0.0018, 0.0052, 0.0146, 0.0303, 0.0898, 0.6367], device='cuda:0')

In [24]:
diff_quantiles(o_half, o_fp8_h)

tensor([0.0021, 0.0060, 0.0173, 0.0364, 0.1309, 1.1367], device='cuda:0')

In [25]:
diff_quantiles(o_half, o_fp8)

tensor([0.0020, 0.0056, 0.0156, 0.0320, 0.1147, 0.7324], device='cuda:0')

In [26]:
diff_quantiles(o_half, o_half_ref)

tensor([0.0018, 0.0052, 0.0146, 0.0303, 0.0898, 0.6367], device='cuda:0')

In [28]:
cos_sim(o_half, o_half_ref).min()

tensor(0.9684, device='cuda:0')

In [34]:
torch.quantile(cos_sim(o_fp8_h, o_half), torch.tensor([0.01, 0.25, 0.5, 0.75, 0.9]).cuda())

tensor([0.9909, 0.9973, 0.9987, 0.9996, 0.9997], device='cuda:0')

In [35]:
torch.quantile(cos_sim(o_fp8, o_half), torch.tensor([0.01, 0.25, 0.5, 0.75, 0.9]).cuda())

tensor([0.9933, 0.9979, 0.9989, 0.9996, 0.9997], device='cuda:0')

In [31]:
cos_sim(o_fp8, o_half).min()

tensor(0.9681, device='cuda:0')

In [19]:
o_fp8.float().max()

tensor(6., device='cuda:0')

In [98]:
o_half.max()

tensor(6.0664, device='cuda:0', dtype=torch.float16)

In [99]:
o_fp8_h.float().max()

tensor(6., device='cuda:0')

In [100]:
((o_fp8_h.float() * o_half.float()) < 0).sum()/o_half.numel()

tensor(0.0131, device='cuda:0')

In [102]:
((o_fp8.float() * o_half.float()) < 0).sum()/o_half.numel()

tensor(0.0113, device='cuda:0')

In [103]:
((o_half.float() * o_half_ref.float()) < 0).sum()/o_half.numel()

tensor(0.0102, device='cuda:0')

In [104]:
((o_fp8.float() * o_half_ref.float()) < 0).sum()/o_half.numel()

tensor(0.0023, device='cuda:0')

In [105]:
((o_fp8_h.float() * o_half_ref.float()) < 0).sum()/o_half.numel()

tensor(0.0058, device='cuda:0')