In [1]:
import torch
import sys
sys.path.append("/home/msst/repo/Quantization")
from qlib.quantizers.trellis_quantizer import TrellisQuantizer
DEVICE = 'cuda:0'
from time import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
quantizer = TrellisQuantizer(
    values_type="LUTFREE_FP8",
    T=1024,
    V=4,
    K=2,
    viterbi_batch_size=1024, # rtx5070ti
    # viterbi_batch_size=512, # rtx3060
    use_kernel=False
).to(DEVICE)

print("Trellis elems:", quantizer.T)
print("Viterby bs:", quantizer.viterbi_bs)

Trellis elems: 1024
Viterby bs: 1024


In [3]:
# MSE + warmup
x = torch.randn(2048, 2048).to(DEVICE)
quantizer.weight_shape = x.shape

trellis = quantizer.quantize(x)
x_q1 = quantizer.dequantize(trellis)
print("QTIP MSE", ((x_q1 - x)**2).mean())

trellis = quantizer.quantize(x, fast_tail_bite=True)
x_q2 = quantizer.dequantize(trellis)
print("NEW MSE", ((x_q2 - x)**2).mean())

QTIP MSE tensor(0.0746, device='cuda:0')
NEW MSE tensor(0.0755, device='cuda:0')


In [4]:
# print("QTIP speed")
# t0 = time()
# for s in 4*[4096,] + 3*[11008,]:
#     x = torch.randn(4096, s).cuda()
#     quantizer.weight_shape = x.shape
#     trellis = quantizer.quantize(x)
#     # x_q = quantizer.dequantize(trellis)
#     # print("QTIP", s, ((x_q - x)**2).mean())
# print(time() - t0)

In [5]:
print("NEW speed")
t0 = time()
for s in 4*[4096,] + 3*[11008,]:
    x = torch.randn(4096, s).to(DEVICE)
    quantizer.weight_shape = x.shape
    trellis = quantizer.quantize(x, fast_tail_bite=True)
    # x_q = quantizer.dequantize(trellis)
    # print("NEW", s, ((x_q - x)**2).mean())
print(time() - t0)

NEW speed
54.3755087852478


In [6]:
raise

RuntimeError: No active exception to reraise

In [None]:
# torch.manual_seed(0)

# x = torch.randn(2048, 2048).cuda()
# # x = torch.randn(quantizer.viterbi_bs, quantizer.T).cuda()
# quantizer.weight_shape = x.shape

# trellis = quantizer.quantize(x)
# x_q1 = quantizer.dequantize(trellis)
# print("QTIP", ((x_q1 - x)**2).mean())

# print()

# trellis = quantizer.quantize(x, fast_tail_bite=True)
# x_q2 = quantizer.dequantize(trellis)
# # x_q2 = quantizer.quantize(x, fast_tail_bite=True).reshape_as(x)
# print("NEW", ((x_q2 - x)**2).mean())

In [None]:
torch.manual_seed(0)

x = torch.randn(2048, 2048).cuda()
# x = torch.randn(quantizer.viterbi_bs, quantizer.T).cuda()
quantizer.weight_shape = x.shape

x_q1 = quantizer.quantize(x, return_reco=True)
print("QTIP", ((x_q1 - x)**2).mean())

print()

x_q2 = quantizer.quantize(x, fast_tail_bite=True, return_reco=True)
print("NEW", ((x_q2 - x)**2).mean())

QTIP tensor(0.0746, device='cuda:0')

NEW tensor(0.0755, device='cuda:0')


In [None]:
((x_q1.reshape(-1, quantizer.T)[:, :4] - x.reshape(-1, quantizer.T)[:, :4])**2).mean()

tensor(0.0743, device='cuda:0')

In [None]:
print(((x_q2.reshape(-1, quantizer.T)[:, :4] - x.reshape(-1, quantizer.T)[:, :4])**2).mean())
print()
# ((x_q2[:, :4] - x[:, :4])**2).sum(axis=1)

tensor(0.0487, device='cuda:0')



In [None]:
raise

RuntimeError: No active exception to reraise

In [None]:
trellis = quantizer.quantize(x)
x_q = quantizer.dequantize(trellis)
print(((x_q - x)**2).mean())

tensor(0.0746, device='cuda:0')


In [None]:
trellis = quantizer.quantize(x, fast_tail_bite=True)
x_q = quantizer.dequantize(trellis)
print(((x_q - x)**2).mean())

tensor(0.0755, device='cuda:0')


In [None]:
raise

RuntimeError: No active exception to reraise

In [None]:
training_lut = quantizer.codebook.get_training_lut().cuda()
training_lut.unique()

tensor([-2.2265, -2.0781, -1.9297, -1.7812, -1.6328, -1.4844, -1.3359, -1.1875,
        -1.1133, -1.0391, -0.9648, -0.8906, -0.8164, -0.7422, -0.6680, -0.5937,
        -0.5566, -0.5195, -0.4824, -0.4453, -0.4082, -0.3711, -0.3340, -0.2969,
        -0.2598, -0.2227, -0.1855, -0.1484, -0.1113, -0.0742, -0.0371,  0.0000,
         0.0371,  0.0742,  0.1113,  0.1484,  0.1855,  0.2227,  0.2598,  0.2969,
         0.3340,  0.3711,  0.4082,  0.4453,  0.4824,  0.5195,  0.5566,  0.5937,
         0.6680,  0.7422,  0.8164,  0.8906,  0.9648,  1.0391,  1.1133,  1.1875,
         1.3359,  1.4844,  1.6328,  1.7812,  1.9297,  2.0781,  2.2265],
       device='cuda:0')

In [None]:
# raise

In [None]:
torch.manual_seed(0)

x = torch.randn(4096, 4096).cuda()
quantizer.weight_shape = x.shape
training_lut = quantizer.codebook.get_training_lut().to(x.device)


In [None]:
import math

# @torch.compile
def update(quantizer, training_lut, cost, orig_seq_part, state_candidates):
    B = orig_seq_part.shape[0]  # batch size
    R = 2 ** (quantizer.L - quantizer.K * quantizer.V)  # reduced state size
    D = 2 ** (quantizer.K * quantizer.V)  # delta size
    S = 2 ** quantizer.L  # total states

    # Gather candidate costs (B, R, D)
    cand_cost = torch.gather(
        input=cost.view(B, 1, S).expand(-1, R, -1), 
        dim=-1, 
        index=state_candidates.expand(B, R, D)
    )

    # Find best candidate for each reduced state (B, R)
    best = torch.min(cand_cost, dim=-1)

    # Calculate state reconstruction error (B, S)
    state_err = (training_lut - orig_seq_part.unsqueeze(1)).square().sum(dim=-1)

    # Update cost (B, S)
    cost = state_err + best.values.view(B, R, 1).expand(-1, -1, D).reshape(B, S)

    # Get previous states (B, R)
    prev_state = torch.gather(
        input=state_candidates.expand(B, R, D), 
        dim=-1, 
        index=best.indices.unsqueeze(-1)
    )[..., 0]

    return prev_state, cost


def viterbi(quantizer, training_lut, X):
    """Optimized Viterbi decoding with time-major storage"""

    # State transition buffers
    sumdelta = (torch.arange(2 ** (quantizer.K * quantizer.V), device=X.device) << (quantizer.L - quantizer.K * quantizer.V)).view(1, 1, -1)

    # State candidates: maps (reduced_state, delta) -> full_state
    # Shape: (1, 2^(L-K*V), 2^(K*V))
    state_candidates = (torch.arange(2**quantizer.L, device=X.device).unsqueeze(0) >> (quantizer.K * quantizer.V))[
        0, :: 2 ** (quantizer.K * quantizer.V)
    ].unsqueeze(-1) + sumdelta

    # print("state_candidates")
    # print(state_candidates[0, 0], state_candidates[0, 0].shape)
    

    B = X.shape[0]
    T_v = quantizer.T // quantizer.V

    # Forward pass
    cost = (training_lut - X[:, : quantizer.V].unsqueeze(1)).square().sum(dim=-1)

    # print(cost)
    # print(cost.shape)
    top_k = 1 #32 #64
    values, indices = torch.topk(cost, k=top_k, dim=1, largest=False)
    # print("topk cost:", values, indices)
    first_bytes = indices & 0xFF
    # print("first_bytes", first_bytes, first_bytes.shape)
    # print("unique", first_bytes.sort(dim=-1))
    mode, _ = torch.mode(first_bytes, dim=1, keepdim=True)
    # mode = mode.unsqueeze(-1)
    cost_mask = (torch.arange(1 << 16, device=cost.device) & 0xFF).expand_as(cost)
    cost_mask = (cost_mask != mode.expand_as(cost)).float()
    cost_mask[cost_mask != 0] = torch.tensor(torch.inf)
    cost = cost + cost_mask
    # print(cost)

    # Time-major storage for efficient backtrace
    from_state = torch.zeros(T_v, B, 2 ** (quantizer.L - quantizer.K * quantizer.V), dtype=torch.long, device=X.device)

    for i in range(1, T_v):
        obs = X[:, i * quantizer.V : (i + 1) * quantizer.V]
        prev_state, cost = quantizer.update(
            training_lut.to(torch.float32),
            cost.to(torch.float32),
            obs.to(torch.float32),
            state_candidates,
        )
        from_state[i] = prev_state

    # Backtrace
    backtrace_cost_mask = ((torch.arange(1 << 16, device=cost.device) >> 8) & 0xFF).expand_as(cost)
    backtrace_cost_mask = (backtrace_cost_mask != mode.expand_as(cost)).float()
    backtrace_cost_mask[backtrace_cost_mask != 0] = torch.tensor(torch.inf)
    cost = cost + backtrace_cost_mask

    final_state = torch.zeros(T_v, B, dtype=quantizer.idx_dtype, device=X.device)
    final_state[T_v - 1] = torch.argmin(cost, dim=-1)

    for i in range(T_v - 1, 0, -1):
        reduced_idx = (final_state[i] >> (quantizer.K * quantizer.V)).long().unsqueeze(1)
        final_state[i - 1] = torch.gather(from_state[i], 1, reduced_idx).squeeze(1)

    return final_state.transpose(0, 1)  # Return as (B, T_v)


def quantize_seq(quantizer, training_lut, X, **kwargs):
    """Quantize sequence with batch processing"""
    n_seq, T = X.shape
    batch_padding_len = math.ceil(n_seq / quantizer.viterbi_bs) * quantizer.viterbi_bs - n_seq
    X = torch.nn.functional.pad(X.T, (0, batch_padding_len)).T

    n_seq_padded = X.shape[0]
    X = X.reshape(n_seq_padded // quantizer.viterbi_bs, quantizer.viterbi_bs, T).contiguous()

    Qidxs = torch.zeros(
        n_seq_padded // quantizer.viterbi_bs, quantizer.viterbi_bs, T // quantizer.V, dtype=quantizer.idx_dtype, device=X.device
    )
    for i in range(len(X)):
        Qidxs[i] = viterbi(quantizer, training_lut, X[i])
    Qidxs = Qidxs.reshape(n_seq_padded, T // quantizer.V)[:n_seq]
    return Qidxs


In [None]:
state = quantize_seq(quantizer, training_lut, x.reshape(-1, quantizer.T))
x_q = training_lut[state.int().to(training_lut.device)].to(state.device).reshape_as(x)
((x_q - x)**2).mean()

tensor(0.0795, device='cuda:0')

In [None]:
# x = torch.tensor(15 << 8).to(torch.uint16).unsqueeze(0)
# x.view(torch.uint8)

In [None]:
# for s in 4*[4096,] + 3*[11008,]:
#     x = torch.randn(4096, s).cuda()
#     quantizer.weight_shape = x.shape
#     print(s)
#     x_q = quantizer.quantize(x, return_reco=False)#.reshape_as(x)

In [None]:
# ((x_q - x)**2).mean()