In [1]:
import torch
import sys
sys.path.append("/home/msst/repo/Quantization")
from qlib.quantizers.trellis_quantizer import TrellisQuantizer
DEVICE = 'cuda:0'


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
quantizer = TrellisQuantizer(
    values_type="LUTFREE_FP8",
    T=256,
    V=4,
    K=2,
    viterbi_batch_size=1024,
    use_kernel=False
).cuda()

In [3]:
x = torch.randn(2048, 2048).cuda()
quantizer.weight_shape = x.shape
x_q = quantizer.quantize(x, return_reco=True).reshape_as(x)
((x_q - x)**2).mean()

tensor(0.0746, device='cuda:0')

In [4]:
raise

RuntimeError: No active exception to reraise

In [None]:
training_lut = quantizer.codebook.get_training_lut().to(x.device)
training_lut.unique()

tensor([-2.1717, -1.8615, -1.5512, -1.2410, -1.0858, -0.9307, -0.7756, -0.6205,
        -0.5429, -0.4654, -0.3878, -0.3102, -0.2327, -0.1551, -0.0776,  0.0000,
         0.0776,  0.1551,  0.2327,  0.3102,  0.3878,  0.4654,  0.5429,  0.6205,
         0.7756,  0.9307,  1.0858,  1.2410,  1.5512,  1.8615,  2.1717],
       device='cuda:0')

In [None]:
raise

RuntimeError: No active exception to reraise

In [None]:
x = torch.randn(4096, 4096).cuda()
quantizer.weight_shape = x.shape
training_lut = quantizer.codebook.get_training_lut().to(x.device)


In [None]:
import math

@torch.compile
def update(quantizer, training_lut, cost, orig_seq_part, state_candidates):
    B = orig_seq_part.shape[0]  # batch size
    R = 2 ** (quantizer.L - quantizer.K * quantizer.V)  # reduced state size
    D = 2 ** (quantizer.K * quantizer.V)  # delta size
    S = 2 ** quantizer.L  # total states

    # Gather candidate costs (B, R, D)
    cand_cost = torch.gather(
        input=cost.view(B, 1, S).expand(-1, R, -1), 
        dim=-1, 
        index=state_candidates.expand(B, R, D)
    )

    # Find best candidate for each reduced state (B, R)
    best = torch.min(cand_cost, dim=-1)

    # Calculate state reconstruction error (B, S)
    state_err = (training_lut - orig_seq_part.unsqueeze(1)).square().sum(dim=-1)

    # Update cost (B, S)
    cost = state_err + best.values.view(B, R, 1).expand(-1, -1, D).reshape(B, S)

    # Get previous states (B, R)
    prev_state = torch.gather(
        input=state_candidates.expand(B, R, D), 
        dim=-1, 
        index=best.indices.unsqueeze(-1)
    )[..., 0]

    return prev_state, cost

def viterbi(quantizer, training_lut, X):
    """Optimized Viterbi decoding with time-major storage"""

    # State transition buffers
    sumdelta = (torch.arange(2 ** (quantizer.K * quantizer.V), device=X.device) << (quantizer.L - quantizer.K * quantizer.V)).view(1, 1, -1)

    # State candidates: maps (reduced_state, delta) -> full_state
    # Shape: (1, 2^(L-K*V), 2^(K*V))
    state_candidates = (torch.arange(2**quantizer.L, device=X.device).unsqueeze(0) >> (quantizer.K * quantizer.V))[
        0, :: 2 ** (quantizer.K * quantizer.V)
    ].unsqueeze(-1) + sumdelta

    B = X.shape[0]
    T_v = quantizer.T // quantizer.V

    # Forward pass
    cost = (training_lut - X[:, : quantizer.V].unsqueeze(1)).square().sum(dim=-1)

    # Time-major storage for efficient backtrace
    from_state = torch.zeros(T_v, B, 2 ** (quantizer.L - quantizer.K * quantizer.V), dtype=torch.long, device=X.device)

    for i in range(1, T_v):
        obs = X[:, i * quantizer.V : (i + 1) * quantizer.V]
        prev_state, cost = quantizer.update(
            training_lut.to(torch.float32),
            cost.to(torch.float32),
            obs.to(torch.float32),
            state_candidates,
        )
        from_state[i] = prev_state

    # Backtrace
    final_state = torch.zeros(T_v, B, dtype=quantizer.idx_dtype, device=X.device)
    final_state[T_v - 1] = torch.argmin(cost, dim=-1)

    for i in range(T_v - 1, 0, -1):
        reduced_idx = (final_state[i] >> (quantizer.K * quantizer.V)).long().unsqueeze(1)
        final_state[i - 1] = torch.gather(from_state[i], 1, reduced_idx).squeeze(1)

    return final_state.transpose(0, 1)  # Return as (B, T_v)

def quantize_seq(quantizer, training_lut, X, **kwargs):
    """Quantize sequence with batch processing"""
    n_seq, T = X.shape
    batch_padding_len = math.ceil(n_seq / quantizer.viterbi_bs) * quantizer.viterbi_bs - n_seq
    X = torch.nn.functional.pad(X.T, (0, batch_padding_len)).T

    n_seq_padded = X.shape[0]
    X = X.reshape(n_seq_padded // quantizer.viterbi_bs, quantizer.viterbi_bs, T).contiguous()

    Qidxs = torch.zeros(
        n_seq_padded // quantizer.viterbi_bs, quantizer.viterbi_bs, T // quantizer.V, dtype=quantizer.idx_dtype, device=X.device
    )
    for i in range(len(X)):
        Qidxs[i] = quantizer.viterbi(training_lut, X[i])
    Qidxs = Qidxs.reshape(n_seq_padded, T // quantizer.V)[:n_seq]
    return Qidxs


In [None]:
state = quantize_seq(quantizer, training_lut, x.reshape(-1, quantizer.T))
x_q = training_lut[state.int().to(training_lut.device)].to(state.device).reshape_as(x)
((x_q - x)**2).mean()

tensor(0.0734, device='cuda:0')

In [None]:
# for s in 4*[4096,] + 3*[11008,]:
#     x = torch.randn(4096, s).cuda()
#     quantizer.weight_shape = x.shape
#     print(s)
#     x_q = quantizer.quantize(x, return_reco=False)#.reshape_as(x)

In [None]:
# ((x_q - x)**2).mean()