In [2]:
import torch
import numpy as np

In [3]:
def vector_quantize_batched(w, codebook, batch_size=1024):
    """
    Nearest-neighbour vector quantization of a tensor against a codebook.

    ``w`` is flattened, zero-padded at the end to a multiple of the codebook's
    vector dimension and split into consecutive ``vecdim``-long vectors. Each
    vector is replaced by the codebook row at the smallest Euclidean distance.
    Distances are computed ``batch_size`` vectors at a time, which bounds the
    size of each (batch, codebook_size) distance matrix.

    Args:
        w: tensor of any shape (e.g. a 2-D [m, n] weight matrix).
        codebook: [codebook_size, vecdim] tensor of candidate vectors; assumed to
            live on the same device as ``w`` (not checked here).
        batch_size: number of vectors per ``torch.cdist`` call.

    Returns:
        (w_q, indices):
            w_q: quantized tensor with the same shape as ``w`` and the dtype of
                ``codebook``; padding elements are dropped.
            indices: 1-D long tensor with one chosen codebook row per (padded)
                vector, i.e. ceil(w.numel() / vecdim) entries.
    """
    orig_shape = w.shape
    _, vecdim = codebook.shape  # also enforces a 2-D codebook

    flat = w.reshape(-1)
    # Zero-pad the tail so the flat tensor splits evenly into vecdim-sized vectors.
    # The padding takes w's dtype/device: a default float32 CPU tensor would force a
    # host->device copy and promote e.g. fp16/bf16 weights to float32, which can then
    # disagree with the codebook's dtype inside torch.cdist.
    pad = (-flat.shape[0]) % vecdim
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])

    w_vectors = flat.reshape(-1, vecdim)  # (n_vectors, vecdim)
    n_vectors = w_vectors.shape[0]

    indices = torch.empty(n_vectors, dtype=torch.long, device=w.device)
    for start in range(0, n_vectors, batch_size):
        end = min(start + batch_size, n_vectors)
        distances = torch.cdist(w_vectors[start:end], codebook, p=2)  # (end - start, codebook_size)
        indices[start:end] = distances.argmin(dim=1)

    w_q = codebook[indices].reshape(-1)
    if pad:
        w_q = w_q[:-pad]  # drop the reconstructed padding elements
    return w_q.reshape(orig_shape), indices

In [4]:
def fp4_s1e2m1_subnormal(x4):
    """
    Decode 4-bit integer codes as FP4 E2M1 floats.

    Bit layout (MSB to LSB): 1 sign bit, 2 exponent bits, 1 mantissa bit; the
    exponent bias is 1. An exponent field of 0 encodes subnormals
    (0.5 * mantissa); otherwise the value is (1 + 0.5 * mantissa) * 2**(exp - 1).

    Args:
        x4: integer tensor whose low 4 bits hold the code.

    Returns:
        float tensor of the same shape with the decoded values.
    """
    sign_bit = (x4 >> 3) & 0b1
    exp_field = (x4 >> 1) & 0b11
    mantissa = (x4 & 0b1).float()

    subnormal = 0.5 * mantissa
    normal = (1.0 + 0.5 * mantissa) * (2.0 ** (exp_field.float() - 1))
    magnitude = torch.where(exp_field == 0, subnormal, normal)

    return magnitude * torch.where(sign_bit == 0, 1.0, -1.0)

def fp6_s1e3m2_subnormal(x6):
    """
    Decode 6-bit integer codes as FP6 E3M2 floats.

    Bit layout (MSB to LSB): 1 sign bit, 3 exponent bits, 2 mantissa bits;
    exponent bias 3 (= 2**(3-1) - 1). An exponent field of 0 encodes
    subnormals (mantissa / 4) * 2**(1 - bias); otherwise the value is
    (1 + mantissa / 4) * 2**(exp - bias).

    Args:
        x6: integer tensor whose low 6 bits hold the code.

    Returns:
        float tensor of the same shape with the decoded values.
    """
    bias = 3
    sign_bit = (x6 >> 5) & 0b1
    exp_field = (x6 >> 2) & 0b111
    fraction = (x6 & 0b11).float() / (2**2)

    magnitude = torch.where(
        exp_field == 0,
        fraction * (2.0 ** (1 - bias)),
        (1.0 + fraction) * (2.0 ** (exp_field.float() - bias)),
    )
    return magnitude * torch.where(sign_bit == 0, 1.0, -1.0)

def fp8_s1e4m3_subnormal(x8):
    """
    Decode 8-bit integer codes as FP8 E4M3 floats.

    Bit layout (MSB to LSB): 1 sign bit, 4 exponent bits, 3 mantissa bits;
    exponent bias 7. An exponent field of 0 encodes subnormals
    (mantissa / 8) * 2**(1 - bias); otherwise the value is
    (1 + mantissa / 8) * 2**(exp - bias).

    NOTE(review): no NaN/Inf codes are special-cased; e.g. code 0x7F decodes
    to 480.0. Confirm this matches the intended FP8 variant (OCP E4M3FN
    reserves that code for NaN).

    Args:
        x8: integer tensor whose low 8 bits hold the code.

    Returns:
        float tensor of the same shape with the decoded values.
    """
    bias = 7
    # Split the code into its bit fields.
    sign_bit = (x8 >> 7) & 0b1
    exp_field = (x8 >> 3) & 0b1111
    fraction = (x8 & 0b111).float() / (2**3)

    magnitude = torch.where(
        exp_field == 0,  # subnormal range
        fraction * (2.0 ** (1 - bias)),
        (1.0 + fraction) * (2.0 ** (exp_field.float() - bias)),
    )
    return magnitude * torch.where(sign_bit == 0, 1.0, -1.0)


def int8_to_fp4_pair(x):
    """
    Split every byte of ``x`` into its two 4-bit nibbles and decode both as FP4 E2M1.

    Only the low 8 bits of each element are used, so negative int8 values are
    read as their two's-complement byte.

    Returns:
        float tensor of shape (*x.shape, 2): index 0 along the last dim holds the
        decoded high nibble (bits 7..4), index 1 the low nibble (bits 3..0).
    """
    byte = x.int() & 0xFF
    high_nibble = (byte >> 4) & 0b1111
    low_nibble = byte & 0b1111
    nibbles = torch.stack([high_nibble, low_nibble], dim=-1)
    return fp4_s1e2m1_subnormal(nibbles)

In [5]:
def permute_16bits(x):
    """
    Scramble integers into pseudo-random 16-bit values.

    The input is cast to int32. A fixed multiply-add is applied, then
    v * (v + 1), all in int32 arithmetic; intermediate products exceed 32 bits,
    so this presumably relies on PyTorch's wrap-around integer overflow. Bits
    9..24 of the result are kept.

    NOTE(review): despite the name, nothing here shows the map is a bijection
    on 16-bit inputs; confirm before relying on it as a permutation.

    Returns:
        int32 tensor with values in 0..65535.
    """
    mixed = x.to(torch.int32) * 34038481 + 76625530
    squared = mixed * (mixed + 1)
    return (squared >> 9) & 0xFFFF


def _int64_low_mask(num_bits):
    """All-ones mask of the low ``num_bits`` bits, as a value that fits a signed int64."""
    mask = (1 << num_bits) - 1
    # 2**64 - 1 does not fit int64; its two's-complement bit pattern is -1.
    return mask - (1 << 64) if num_bits == 64 else mask


def pack_bits(matrix, in_bits=4, out_bits=16):
    """
    Pack the low ``in_bits`` bits of every element of a 2-D matrix into
    ``out_bits``-wide words, row by row.

    Values are laid out LSB-first within each row. Column 0 occupies the lowest
    ``in_bits`` bits of word 0, column 1 the next ``in_bits`` bits, and so on. A
    value that straddles a word boundary is split, and its high part goes into
    the low bits of the next word.

    Args:
        matrix: [M, N] integer tensor; only the low ``in_bits`` bits of each
            element are used (values are expected in 0..2**in_bits - 1).
        in_bits: bits per input value, 1 <= in_bits <= out_bits.
        out_bits: word width of the packed output: 8, 16, 32 or 64.

    Returns:
        [M, ceil(N * in_bits / out_bits)] tensor of dtype uint8 / uint16 /
        int32 / int64 for out_bits = 8 / 16 / 32 / 64. PyTorch lacks a general
        uint32, so 32-bit words come back as int32 with the bit pattern kept
        (words >= 2**31 appear negative); the same holds for 64-bit words.

    Raises:
        ValueError: if out_bits is not 8, 16, 32 or 64, or in_bits is outside
            1..out_bits (wider values would span more than two words).
    """
    # Validate before doing any work.
    if out_bits not in (8, 16, 32, 64):
        raise ValueError("out_bits должен быть 8, 16, 32 или 64")
    if not 1 <= in_bits <= out_bits:
        raise ValueError(f"in_bits must be in 1..{out_bits}, got {in_bits}")

    M, N = matrix.shape
    num_out = (N * in_bits + out_bits - 1) // out_bits  # ceil

    # Accumulate in int64 so shifts up to 63 bits are representable.
    packed = torch.zeros((M, num_out), dtype=torch.int64, device=matrix.device)
    mask_in = _int64_low_mask(in_bits)
    mask_out = _int64_low_mask(out_bits)

    for col in range(N):
        word, shift = divmod(col * in_bits, out_bits)
        raw = matrix[:, col].to(torch.int64) & mask_in
        packed[:, word] |= (raw << shift) & mask_out

        # Value crosses the word boundary: its top bits go into the next word.
        # Shift the unshifted value right rather than (raw << shift) >> out_bits,
        # which loses those bits when out_bits == 64.
        if shift + in_bits > out_bits:
            packed[:, word + 1] |= raw >> (out_bits - shift)

    # Convert to the target word type.
    if out_bits == 8:
        return packed.to(torch.uint8)
    elif out_bits == 16:
        return packed.to(torch.uint16)
    elif out_bits == 32:
        return packed.to(torch.int32)  # PyTorch has no general uint32
    else:
        return packed.to(torch.int64)
    
    
def unpack_bits(packed, N, in_bits=4, out_bits=16):
    """
    Inverse of ``pack_bits``: extract N values of ``in_bits`` bits per row.

    Reads the same layout ``pack_bits`` writes. Values are stored LSB-first in
    consecutive ``out_bits``-wide words, and a value straddling a word boundary
    takes its high bits from the low bits of the next word.

    Args:
        packed: [M, num_words] integer tensor of packed words. Signed storage
            (int32/int64 words with the top bit set) is handled.
        N: number of values to extract per row.
        in_bits: bits per value.
        out_bits: word width used when packing.

    Returns:
        [M, N] int64 tensor of values in 0..2**in_bits - 1.
    """
    M, _ = packed.shape
    words = packed.to(torch.int64)
    unpacked = torch.zeros((M, N), dtype=torch.int64, device=packed.device)
    mask_in = (1 << in_bits) - 1

    for col in range(N):
        word, shift = divmod(col * in_bits, out_bits)
        value = words[:, word] >> shift
        if shift + in_bits > out_bits:
            # `>>` on int64 is arithmetic. A word whose top bit is set (negative
            # int32/int64 storage) fills the high bits with copies of its sign bit,
            # and those would corrupt the bits OR-ed in from the next word. Keep
            # only the out_bits - shift bits that really come from this word.
            low_part = value & ((1 << (out_bits - shift)) - 1)
            value = low_part | (words[:, word + 1] << (out_bits - shift))
        unpacked[:, col] = value & mask_in

    return unpacked

# indices = torch.arange(1 << 16)
# windows = sliding_bit_windows_16bit(indices, 4, 2)
# #codebook = fp4_s1e2m1_subnormal(windows).cuda()

# packed_windows = pack_bits(windows, in_bits=4, out_bits=16)
# packed_windows = permute_16bits(packed_windows)
# windows = unpack_bits(packed_windows, N=windows.shape[-1], in_bits=4, out_bits=16)

In [6]:
def sliding_bit_windows_16bit(X: torch.Tensor, bit_per_value: int, bit_step: int) -> torch.Tensor:
    """
    Cut the low 16 bits of every element of ``X`` into overlapping, circular bit windows.

    The 16 bits are read MSB-first. Window k starts ``k * bit_step`` bits after
    the MSB and spans ``bit_per_value`` bits, wrapping back to the MSB past the
    LSB. Its bits are reassembled MSB-first into an integer. There are
    ceil(16 / bit_step) windows. As a side effect, prints the window count and
    16 / window count.

    Args:
        X: integer tensor of any shape.
        bit_per_value: window width in bits (must be <= 8).
        bit_step: distance in bits between consecutive window starts.

    Returns:
        integer tensor of shape (*X.shape, ceil(16 / bit_step)) with window
        values in 0..2**bit_per_value - 1.
    """
    width = 16
    assert bit_per_value <= 8  # each window is assembled with uint8 weights
    num_windows = (width + bit_step - 1) // bit_step
    print("num windows:", num_windows)
    print("bit per window:", width / num_windows)

    dev = X.device
    low_bits = X.to(torch.int32) & 0xFFFF

    # bits[..., p] is bit (15 - p) of the value, i.e. MSB first.
    shifts = torch.arange(width - 1, -1, -1, dtype=torch.int32, device=dev)
    bits = ((low_bits.unsqueeze(-1) >> shifts) & 1).to(torch.uint8)

    window_starts = torch.arange(0, width, bit_step, device=dev)[:num_windows]
    place_values = (1 << torch.arange(bit_per_value - 1, -1, -1, device=dev)).to(torch.uint8)

    # (num_windows, bit_per_value) bit positions, wrapping around the 16-bit word.
    positions = (window_starts.unsqueeze(1) + torch.arange(bit_per_value, device=dev)) % width
    window_bits = bits[..., positions]

    return torch.sum(window_bits * place_values, dim=-1)


def sliding_bit_windows_8bit(X: torch.Tensor, bit_per_value: int, bit_step: int) -> torch.Tensor:
    """
    Cut the low 8 bits of every element of ``X`` into overlapping, circular bit windows.

    The 8 bits are read MSB-first. Window k starts ``k * bit_step`` bits after
    the MSB and spans ``bit_per_value`` bits, wrapping back to the MSB past the
    LSB. Its bits are reassembled MSB-first into an integer. There are
    ceil(8 / bit_step) windows. As a side effect, prints the window count and
    8 / window count.

    Args:
        X: integer tensor of any shape.
        bit_per_value: window width in bits (must be <= 8).
        bit_step: distance in bits between consecutive window starts.

    Returns:
        integer tensor of shape (*X.shape, ceil(8 / bit_step)) with window
        values in 0..2**bit_per_value - 1.
    """
    width = 8
    assert bit_per_value <= 8  # each window is assembled with uint8 weights
    num_windows = (width + bit_step - 1) // bit_step
    print("num windows:", num_windows)
    print("bit per window:", width / num_windows)

    dev = X.device
    low_bits = X.to(torch.int32) & 0xFF

    # bits[..., p] is bit (7 - p) of the value, i.e. MSB first.
    shifts = torch.arange(width - 1, -1, -1, dtype=torch.int32, device=dev)
    bits = ((low_bits.unsqueeze(-1) >> shifts) & 1).to(torch.uint8)

    window_starts = torch.arange(0, width, bit_step, device=dev)[:num_windows]
    place_values = (1 << torch.arange(bit_per_value - 1, -1, -1, device=dev)).to(torch.uint8)

    # (num_windows, bit_per_value) bit positions, wrapping around the 8-bit word.
    positions = (window_starts.unsqueeze(1) + torch.arange(bit_per_value, device=dev)) % width
    window_bits = bits[..., positions]

    return torch.sum(window_bits * place_values, dim=-1)

In [7]:
def get_scalars(bits, generator=None):
    """
    Return the 2**bits evenly spaced values of linspace(-3, 3) in random order.

    Args:
        bits: number of index bits; the result has ``1 << bits`` entries.
        generator: optional ``torch.Generator`` that makes the shuffle
            reproducible. If None, the global RNG is used (previous behavior).

    Returns:
        1-D float tensor of length ``1 << bits``, a permutation of
        ``torch.linspace(-3, 3, 1 << bits)``.
    """
    count = 1 << bits
    values = torch.linspace(-3, 3, steps=count)
    return values[torch.randperm(count, generator=generator)]

In [19]:
# Codebook with one 8-dim row per 16-bit index. Each index is cut into eight
# overlapping, circular 4-bit windows (stride 2 bits), and every window is
# decoded as an FP4 E2M1 value. Falls back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

indices = torch.arange(1 << 16)
windows = sliding_bit_windows_16bit(indices, 4, 2)

# Optional: scramble the windows through the 16-bit hash before decoding.
# packed_windows = pack_bits(windows, in_bits=4, out_bits=16)
# packed_windows = permute_16bits(packed_windows)
# windows = unpack_bits(packed_windows, N=windows.shape[-1], in_bits=4, out_bits=16)

codebook = fp4_s1e2m1_subnormal(windows).to(device)

# Alternative codebooks that were tried (replace the active lines above):
#
# FP6 E3M2 values from 6-bit windows, stride 2:
# windows = sliding_bit_windows_16bit(indices, 6, 2)
# codebook = fp6_s1e3m2_subnormal(windows).to(device)
#
# 8-bit windows, stride 3, decoded as FP8 E4M3 or through a random Gaussian table:
# windows = sliding_bit_windows_16bit(indices, 8, 3)
# codebook = fp8_s1e4m3_subnormal(windows).to(device)
# codebook = torch.randn(1 << 8)[windows].to(device)
# Shuffled scalar grid (needs 4-bit windows: get_scalars(4) has only 16 entries):
# codebook = get_scalars(4)[windows].to(device)
#
# 8-bit indices (256-row codebooks):
# indices = torch.arange(1 << 8)
# windows = sliding_bit_windows_8bit(indices, 4, 2)
# codebook = fp4_s1e2m1_subnormal(windows).to(device)
# windows = sliding_bit_windows_8bit(indices, 8, 2)
# codebook = torch.randn(1 << 8)[windows].to(device)
#
# Unstructured baseline: 256 random 8-dim vectors.
# codebook = torch.randn(2**8, 8).to(device)

# Normalize to unit standard deviation, presumably to match the scale of the
# randn test weights quantized below.
codebook = codebook / codebook.std()
codebook.shape

num windows: 8
bit per window: 2.0


torch.Size([65536, 8])

In [20]:
# Quantize a random Gaussian weight matrix with the codebook built above, then report
# the reconstruction MSE and the fraction of reconstructed weights that are exactly 0.
torch.manual_seed(0)  # fixed seed so the reported numbers are reproducible
w = torch.randn(2048, 2048, device=codebook.device)  # follow the codebook's device instead of hard-coding CUDA

w_q, best_indices = vector_quantize_batched(w, codebook, batch_size=1024)

loss_fn = torch.nn.MSELoss()
loss = loss_fn(w, w_q)
print("mse err:", loss)

print("zeros:", (w_q == 0).sum() / w_q.numel())

mse err: tensor(0.1423, device='cuda:0')
zeros: tensor(0.0987, device='cuda:0')
