1) Реализовать кернель для квантизации 2D матрицы из fp16 в int4
и последующей упаковки квантизованной матрицы в int8 или int32.
При этом потребляемая память должна уменьшиться в 4 раза.

In [2]:
import torch
import triton
import triton.language as tl

# NOTE: the kernel body never reads BLOCK_SIZE, so the two configs below differ
# only in num_warps / num_stages. The autotuner is kept because it is what
# supplies the BLOCK_SIZE argument that the launch site does not pass.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
    ],
    key=["n_elements"],
)
@triton.jit
def _quantize_rowwise(x_ptr, output_ptr, output_maxs, n_elements, BLOCK_SIZE: tl.constexpr, P2: tl.constexpr):
    # Row-wise symmetric int4 quantization with two values packed per byte.
    #
    # x_ptr        -- input matrix; rows are addressed as pid * n_elements, i.e.
    #                 the row stride is taken to be n_elements (contiguous rows)
    # output_ptr   -- packed output, (n_elements + 1) // 2 bytes per row
    # output_maxs  -- one value per row: max(|x|) of that row
    # n_elements   -- number of columns in a row
    # BLOCK_SIZE   -- unused by the body (provided by the autotuner)
    # P2           -- compile-time width; P2 // 2 lanes each handle one column
    #                 pair, so P2 must be >= n_elements to cover the whole row
    #
    # Each program instance processes exactly one row.

    # Widen the row index: in int32, pid * n_elements wraps around as soon as
    # the matrix holds 2**31 or more elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    row_start_ptr = x_ptr + pid * n_elements

    # Lane i handles columns 2*i (even) and 2*i + 1 (odd).
    idx = tl.arange(0, P2 // 2)
    off_even = 2 * idx
    off_odd = off_even + 1

    mask_even = off_even < n_elements
    mask_odd = off_odd < n_elements

    # Out-of-range lanes read 0.0, which neither affects the max nor the
    # packed result. The explicit fp32 cast makes the arithmetic below
    # independent of Triton's implicit promotion rules for half precision.
    x_even = tl.load(row_start_ptr + off_even, mask=mask_even, other=0.0).to(tl.float32)
    x_odd = tl.load(row_start_ptr + off_odd, mask=mask_odd, other=0.0).to(tl.float32)

    absmax = tl.maximum(tl.max(tl.abs(x_even), axis=0), tl.max(tl.abs(x_odd), axis=0))

    # Map [-absmax, absmax] onto [-7, 7]; an all-zero row gets scale 0.
    scale = tl.where(absmax == 0, 0.0, 7.0 / absmax)

    s_even = x_even * scale
    s_odd = x_odd * scale

    # Add +/-0.5 and let the float->int cast drop the fraction, then keep the
    # low 4 bits of the int8 bit pattern (two's-complement nibble).
    q_even = tl.where(s_even >= 0, s_even + 0.5, s_even - 0.5).to(tl.int8).to(tl.uint8) & 0xF
    q_odd = tl.where(s_odd >= 0, s_odd + 0.5, s_odd - 0.5).to(tl.int8).to(tl.uint8) & 0xF

    # Even column -> low nibble, odd column -> high nibble.
    packed = (q_odd << 4) | q_even

    packed_cols = (n_elements + 1) // 2
    packed_mask = idx < packed_cols

    tl.store(output_ptr + pid * packed_cols + idx, packed, mask=packed_mask)
    # Store the row maximum in the element type of the output buffer.
    tl.store(output_maxs + pid, absmax.to(output_maxs.dtype.element_ty))


def quantize_rowwise(x: torch.Tensor):
    """Quantize a 2D matrix row by row to int4 and pack two values per byte.

    Args:
        x: 2D floating-point tensor of shape (N, M).

    Returns:
        A pair ``(packed, row_maxs)``:

        * ``packed`` -- uint8 tensor of shape (N, (M + 1) // 2). In each byte
          the low nibble holds the even column and the high nibble the odd
          column, both as 4-bit two's-complement values in [-7, 7].
        * ``row_maxs`` -- float16 tensor of shape (N,) with max(|x|) per row.
          A quantized value q maps back to roughly ``q * row_max / 7``.

    Raises:
        ValueError: if ``x`` is not 2-dimensional.
    """
    if x.dim() != 2:
        raise ValueError(f"expected a 2D tensor, got {x.dim()} dimension(s)")

    N, M = x.shape
    out_cols = (M + 1) // 2

    output_tensor = torch.empty((N, out_cols), dtype=torch.uint8, device=x.device)

    # Nothing to launch for an empty matrix; report zero maxima.
    if N == 0 or M == 0:
        return output_tensor, torch.zeros(N, dtype=torch.float16, device=x.device)

    output_maxs = torch.empty(N, dtype=torch.float16, device=x.device)

    # The kernel addresses row r at offset r * M, so rows must be contiguous.
    x = x.contiguous()

    # Smallest power of two >= M, computed with exact integer arithmetic.
    # Going through a float16 log2 rounds down for widths just above a power
    # of two (e.g. 1025 -> 1024), which left the tail of each row unquantized.
    # The lower bound of 2 keeps the kernel's P2 // 2 lane count at >= 1.
    P2 = max(2, 1 << (M - 1).bit_length())

    # One program instance per row.
    grid = (N,)
    _quantize_rowwise[grid](x_ptr=x, output_ptr=output_tensor, output_maxs=output_maxs, n_elements=M, P2=P2)

    return output_tensor, output_maxs

In [None]:
# Going back the other way: split one packed byte into its two 4-bit fields.
# The fields are raw nibbles; the kernel stores them in two's complement, so a
# field value >= 8 stands for (value - 16).
v = 33  # 0x21
low, high = v & 0xF, (v >> 4) & 0xF  # low nibble -> 1, high nibble -> 2

2

In [3]:
# Location of the saved tensor to quantize (presumably a gate_proj weight,
# judging by the file name).
path = "/kaggle/working/gate_proj_weight_0.pt"
# NOTE(review): torch.load unpickles the file -- only use it on trusted files.
w = torch.load(path)

In [4]:
# Move the tensor to the GPU before it is handed to the Triton kernel.
w = w.to('cuda')

In [5]:
# Cast to float16: the task is stated for fp16 input.
w = w.to(torch.float16)

In [6]:
# Display the (rows, columns) shape of the input matrix.
w.shape

torch.Size([8192, 2048])

In [7]:
# q: packed uint8 tensor; max_: the per-row values written by the kernel
# (row-wise max of |w|).
q, max_ = quantize_rowwise(w)

In [23]:
# Report the size of the unquantized tensor in MiB
# (bytes per element * number of elements).
print(w.dtype)
# The variable name is kept as-is in case other cells read it; note that the
# tensor is float16 (it was cast above), not bfloat16.
memory_bf16 = w.element_size() * w.nelement()
print(f"FP16 память: {memory_bf16 / 1024 ** 2}")

torch.float16
BF16 память: 32.0


In [29]:
# Display the input shape again for comparison with the packed shape.
w.shape

torch.Size([8192, 2048])

In [19]:
# Display the packed bytes; each uint8 holds two 4-bit quantized values.
q

tensor([[243,  14,  47,  ..., 238,  14,  67],
        [ 47, 224,  18,  ..., 254, 222,  14],
        [254,  16, 241,  ...,   0,   0, 239],
        ...,
        [226,  14,  34,  ..., 241,  49, 252],
        [240,  64, 192,  ...,  81,   2, 241],
        [ 14,   1, 224,  ...,  34, 241,  31]], device='cuda:0',
       dtype=torch.uint8)

In [8]:
import torch
# Persist the packed quantized tensor.
path = "/kaggle/working/gate_proj_weight_0_quant.pt"
torch.save(q, path)

In [20]:
# Display the packed shape: (rows, (columns + 1) // 2).
q.shape

torch.Size([8192, 1024])

In [27]:
# Display the per-row maxima of |w| returned by quantize_rowwise.
max_

tensor([0.0610, 0.0645, 0.0806,  ..., 0.0732, 0.0732, 0.0713], device='cuda:0',
       dtype=torch.float16)

In [30]:
# Display the shape of the per-row maxima: one value per row.
max_.shape

torch.Size([8192])

In [9]:
import torch
# Persist the per-row values needed for dequantization.
# NOTE(review): the file name says "scale", but max_ holds the row-wise
# max(|w|) written by the kernel; the dequantization step would be max_ / 7.
path = "/kaggle/working/gate_proj_weight_0_quant_scale.pt"
torch.save(max_, path)

In [25]:
# Size of the packed tensor in MiB (element count * bytes per element).
# Only q is counted; the per-row maxima in max_ are not included.
print(q.dtype)
memory_int4 = q.numel() * q.element_size()
print(f"INT4 память: {memory_int4 / 1024 ** 2}")

torch.uint8
INT4 память: 8.0
