In [77]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class TopKSparse(nn.Module):
    def __init__(self, token_dim, code_dim, discrete_size,
                 randomized=True, random_p=0.1):
        super().__init__()
        self.token_dim = token_dim
        self.code_dim = code_dim
        self.discrete_size = discrete_size
        self.sparsity_ratio = np.log2(discrete_size) / 16
        self.randomized = randomized
        self.random_p = random_p

    def compress(self, x):
        x_shape = x.shape                        # [B,S,H]
        flattened_x = x.view(-1, x.shape[-1])        # [B*S,H]
        H = flattened_x.shape[1]

        # ---------- 1. Top-K ----------
        k = max(1, int(self.sparsity_ratio * H))
        _, topk_idx = torch.topk(flattened_x.abs(), k, dim=1)

        mask = torch.zeros_like(flattened_x, dtype=torch.bool)
        mask.scatter_(1, topk_idx, True)

        # ---------- 2. random augmentation ----------
        if self.randomized and self.random_p > 0:
            rand_mask = torch.rand_like(flattened_x) < self.random_p
            mask |= rand_mask

        # ---------- 3. FINAL idx ----------
        # 每一行不同数量，需转成 ragged / packed
        idx_list = mask.nonzero(as_tuple=False)  # [N,2] -> (row, col)

        # ---------- 4. values FROM idx ----------
        values = flattened_x[idx_list[:, 0], idx_list[:, 1]]

        payload = {
            "indices": idx_list,   # (row_idx, col_idx)
            "values": values
        }

        aux = {
            "embedding_shape": x_shape,
            "code_dim": H
        }

        return payload, aux, torch.tensor(0.0, device=x.device)

    def decompress(self, payload, aux):
        indices = payload["indices"]
        values = payload["values"]
        x_shape = aux["embedding_shape"]
        H = aux["code_dim"]

        BxS = x_shape[0] * x_shape[1]
        flattened_x = torch.zeros(
            (BxS, H), device=values.device, dtype=values.dtype
        )

        flattened_x[indices[:, 0], indices[:, 1]] = values
        return flattened_x.view(x_shape)

In [79]:
s=TopKSparse(128,128,4)
x=torch.randn(4,10,128)
payload, aux, _=s.compress(x)
e=s.decompress(payload, aux)
e.shape

torch.Size([4, 10, 128])

In [82]:
payload["indices"].shape

torch.Size([1115, 2])

In [98]:
def cosine_similarity(x, y, dim=-1, eps=1e-4):
    dot_product = (x * y).sum(dim=dim)
    x_norm = x.norm(p=2, dim=dim)
    y_norm = y.norm(p=2, dim=dim)
    return dot_product / (x_norm * y_norm + eps)
class CosineLoss(nn.Module):
    def __init__(self):
        super(CosineLoss, self).__init__()
    def forward(self,x1,x2,target,eps=1e-3):
        sim=cosine_similarity(x1,x2,dim=-1,eps=eps)
        loss=1-sim
        return(loss.sum()/loss.shape[0])
def pack_2bit_tensor(q: torch.Tensor):
    """
    q: torch.Tensor, dtype=torch.uint8, values in {0,1,2,3}
    returns:
        packed: torch.Tensor, dtype=torch.uint8
        pad: int
    """
    assert q.dtype == torch.uint8
    q_shape=q.shape
    q = q.reshape(-1)

    pad = (-q.numel()) % 4
    if pad:
        q = torch.cat([
            q,
            torch.zeros(pad, dtype=q.dtype, device=q.device)
        ])
    else:
        pad=0

    q = q.view(-1, 4)
    packed = (
        (q[:, 0] << 0) |
        (q[:, 1] << 2) |
        (q[:, 2] << 4) |
        (q[:, 3] << 6)
    ).to(torch.uint8)

    return packed,q_shape, pad

def unpack_2bit_tensor(packed: torch.Tensor,q_shape,pad=0):
    """
    packed: torch.uint8 tensor
    pad: int
    """
    assert packed.dtype == torch.uint8

    q = torch.empty(
        (packed.numel(), 4),
        dtype=torch.uint8,
        device=packed.device
    )

    q[:, 0] = (packed >> 0) & 0b11
    q[:, 1] = (packed >> 2) & 0b11
    q[:, 2] = (packed >> 4) & 0b11
    q[:, 3] = (packed >> 6) & 0b11

    q = q.view(-1)
    if pad:
        q = q[:-pad]
    q=q.reshape(q_shape)
    return q.float()
    
    
def robust_minmax(X,width=3):
    mean = X.mean()
    std = X.std()
    q_min = mean - width * std
    q_max = mean + width * std
    X_clipped = torch.clamp(X, q_min, q_max)
    X_norm = 2 * (X_clipped - q_min+1e-4) / (q_max - q_min+1e-4) - 1
    #X_norm = 2 * (X_clipped - X_clipped.min()+1e-4) / (X_clipped.max() - X_clipped.min()+1e-4) - 1
    return X_norm,q_min,q_max
def quantize(X, size):
    device=X.device
    half_width=(size-1)/2
    offset=((size-1)%2)/2
    X=X*half_width-offset
    X_round=torch.round(X)
    X_ste=(X_round-X).detach()+X
    X_ste=(X_ste+offset)/half_width
    #X_scaled=X*half_width-offset #modified on 11.3
    #X_round=torch.round(X_scaled) #modified on 11.3
    #X_round=(X_round+offset)/half_width #modified on 11.3
    #X_ste=(X_round-X).detach()+X #modified on 11.3
    indices=torch.round(X_ste*half_width+half_width)
    return X_ste, indices.detach().int()
    

class FSQ(nn.Module):
    def __init__(self, token_dim, code_dim, discrete_size=4):
        super(FSQ, self).__init__()
        assert discrete_size == 4, "2-bit packing requires discrete_size=4"

        self.token_dim = token_dim
        self.code_dim = code_dim
        self.discrete_size = discrete_size

        self.loss = CosineLoss()
        self.in_proj = nn.Linear(token_dim, code_dim)
        self.out_proj = nn.Linear(code_dim, token_dim)

        self.levels = torch.linspace(-1, 1, steps=discrete_size)

    # =========================
    # Client-side
    # =========================
    def compress(self, x):
        """
        x: [B, S, token_dim]
        """
        x_shape = x.shape
        x = self.in_proj(x)
        x, q_min, q_max = robust_minmax(x)
        flattened_x = x.view(-1, x.shape[-1])   # [B*S, H]
        flattened_x_q, indices = quantize(
            flattened_x, self.discrete_size
        )                                       # indices ∈ {0,1,2,3}
        target = torch.ones(flattened_x_q.shape[0], device=x.device)
        L_comm = self.loss(flattened_x, flattened_x_q.detach(), target)
        indices_uint8 = indices.to(torch.uint8)
        packed,q_shape, pad = pack_2bit_tensor(indices_uint8)

        payload = {
            "packed_indices": packed,   # torch.uint8
            "pad": pad,
            "q_shape":q_shape
        }

        aux = {
            "embedding_shape": x_shape,
        }

        return payload, aux, L_comm

    # =========================
    # Server-side
    # =========================
    def decompress(self, payload, aux):
        packed = payload["packed_indices"]
        pad = payload["pad"]
        q_shape = payload["q_shape"]
        x_shape = aux["embedding_shape"]

        indices = unpack_2bit_tensor(packed,q_shape, pad)   # float tensor
        indices = indices.view(-1, self.code_dim)

        half_len = (self.discrete_size - 1) / 2
        flattened_x_q = (indices - half_len) / half_len
        output = self.out_proj(flattened_x_q)

        return output.view(x_shape)

In [108]:
s=FSQ(128,128,4)

payload, aux, _=s.compress(x)
e=s.decompress(payload, aux)
e.shape

RuntimeError: mat1 and mat2 shapes cannot be multiplied (40x1280 and 128x128)

In [119]:


class NFNDoubleQuantizer_split(nn.Module):
    def __init__(self, bits=4, block_size=64, use_double_quant=True):
        super().__init__()
        self.bits = bits
        self.block_size = block_size
        self.use_double_quant = use_double_quant
        self.table=generate_nf_table(bits)

    def compress(self, x):
        original_shape=x.shape
        x=x.view(-1,x.shape[2])
        x_shape = x.shape
        x = x.view(x_shape[0], x_shape[1] // self.block_size, self.block_size)
        x_min = x.min(dim=2).values.unsqueeze(-1)
        x_max = x.max(dim=2).values.unsqueeze(-1)
        scales = (x_max - x_min).squeeze(-1)

        x_norm = 2 * (x - x_min) / (x_max - x_min + 1e-8) - 1
        dist = torch.abs(x_norm.unsqueeze(-1) - self.table.to(x.device))
        q_idx = torch.argmin(dist, dim=-1).to(torch.uint8)

        if self.use_double_quant:
            s_min = scales.min(dim=-1).values.unsqueeze(-1)
            s_max = scales.max(dim=-1).values.unsqueeze(-1)
            scales_q = ((scales - s_min) / (s_max - s_min + 1e-8) * 255).round().to(torch.uint8)
        else:
            scales_q, s_min, s_max = None, scales, None
    
        packed,q_shape, pad = pack_2bit_tensor(q_idx)
        payload = {
            "packed_indices": packed,   # torch.uint8
            "pad": pad,
            "q_shape":q_shape
        }

        aux = {
            "scales_q": scales_q,
            "s_min": s_min,
            "s_max": s_max,
            "mins": x_min,
            "x_shape": x_shape,
            "original_shape":original_shape
        }

        return payload, aux, 0
    def decompress(self, payload, aux):
        packed = payload["packed_indices"]
        pad = payload["pad"]
        q_shape = payload["q_shape"]
        q_idx=unpack_2bit_tensor(packed,q_shape,pad)
        scales_q = aux["scales_q"]
        s_min = aux["s_min"]
        s_max = aux["s_max"]
        mins = aux["mins"]
        x_shape = aux["x_shape"]
        original_shape=aux["original_shape"]

        if scales_q is not None:
            scales = s_min + (scales_q.float() / 255) * (s_max - s_min)
        else:
            scales = s_min

        scales = scales.unsqueeze(-1)

        w_block = self.table[q_idx.long()].to(dtype=torch.float32, device=scales.device)
        w_block = (w_block + 1) / 2 * scales + mins
        flatten_x=w_block.view(x_shape)
        return flatten_x.view(original_shape)
    
    

In [121]:
from scipy.stats import norm
s=NFNDoubleQuantizer_split()
x=torch.randn(4,10,1280)
payload, aux, _=s.compress(x)
e=s.decompress(payload, aux)
e.shape

torch.Size([4, 10, 1280])

(tensor([139,  92,  61,  ..., 253, 216, 118], dtype=torch.uint8),
 torch.Size([40, 20, 64]),
 0)