2) Реализовать кернель для перемножения матрицы в bf16 на квантизованную матрицу в int4 (X16 @ W4^T)

In [40]:
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64,  "BLOCK_K": 64},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64,  "BLOCK_K": 64},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 128, "BLOCK_K": 64},  num_warps=4, num_stages=2),
    ],
    # Re-tune whenever the problem size changes.
    key=["B", "IN", "OUT"],
)
@triton.jit
def _forward_int4_fused_kernel(x_q_ptr,
                               w_q_ptr, w_scale_ptr,
                               b_ptr, y_ptr,
                               B, IN, OUT,
                               BLOCK_M: tl.constexpr,
                               BLOCK_N: tl.constexpr,
                               BLOCK_K: tl.constexpr,
                               PER_CHANNEL: tl.constexpr,
                               HAS_BIAS: tl.constexpr):
    """Fused int4-weight GEMM: y = (x @ W^T) * scale (+ bias).

    Each program computes one (BLOCK_M, BLOCK_N) tile of y: program axis 0
    walks the B rows of x, program axis 1 walks the OUT output channels.

    Pointer layouts (row-major and contiguous; the caller must guarantee it):
      x_q_ptr:     (B, IN) floating-point activations; the dot runs in this dtype.
      w_q_ptr:     (OUT, ceil(IN / 2)) integer elements whose low 8 bits hold
                   two 4-bit two's-complement values (-8..7): input index
                   2*j in the low nibble and 2*j + 1 in the high nibble.
      w_scale_ptr: OUT scales when PER_CHANNEL, otherwise one scalar (element 0).
      b_ptr:       OUT bias values, read only when HAS_BIAS.
      y_ptr:       (B, OUT) output; the fp32 accumulator is cast to its element type.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Tile coordinates and bounds masks. None of this depends on k, so it is
    # computed once rather than on every iteration of the reduction loop.
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    row_ok = rows < B
    col_ok = cols < OUT
    # Clamp out-of-range channels to row 0 so even masked lanes form in-bounds addresses.
    safe_cols = tl.where(col_ok, cols, 0)
    packed_in = (IN + 1) // 2  # packed elements per weight row

    # 64-bit row bases: rows * IN and cols * packed_in can exceed 2**31 on large problems.
    x_row_base = rows.to(tl.int64) * IN
    w_row_base = safe_cols.to(tl.int64) * packed_in

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, IN, BLOCK_K):
        ks = k + tl.arange(0, BLOCK_K)
        k_ok = ks < IN

        # (BLOCK_M, BLOCK_K) activation tile; padding lanes read as 0.
        x = tl.load(x_q_ptr + x_row_base[:, None] + ks[None, :],
                    mask=row_ok[:, None] & k_ok[None, :], other=0.0)

        # (BLOCK_K, BLOCK_N) tile of W^T: input index ks lives in packed element
        # ks // 2 of weight row `col`, so each pair of adjacent ks reads the same element.
        w_packed = tl.load(w_q_ptr + w_row_base[None, :] + (ks // 2)[:, None],
                           mask=k_ok[:, None] & col_ok[None, :], other=0)

        # Unpack: even ks -> low nibble, odd ks -> high nibble. Widening to uint32
        # and masking with 0xF keeps exactly the 4 bits of interest whether the
        # storage dtype is signed or unsigned.
        w_u32 = w_packed.to(tl.uint32)
        is_high = (ks & 1) == 1
        nibble = tl.where(is_high[:, None], (w_u32 >> 4) & 0xF, w_u32 & 0xF).to(tl.int32)
        # Two's-complement sign extension of a 4-bit value: 8..15 -> -8..-1.
        w_int = tl.where(nibble < 8, nibble, nibble - 16)

        # -8..7 is exact in both fp16 and bf16, so matching x's dtype loses nothing
        # and avoids down-casting bf16 activations to fp16 (overflows above 65504).
        acc += tl.dot(x, w_int.to(x.dtype))

    # Epilogue in fp32: dequantize with the weight scale(s), then add the bias.
    if PER_CHANNEL:
        scale = tl.load(w_scale_ptr + cols, mask=col_ok, other=0.0).to(tl.float32)
        acc = acc * scale[None, :]
    else:
        acc = acc * tl.load(w_scale_ptr).to(tl.float32)

    if HAS_BIAS:
        bias = tl.load(b_ptr + cols, mask=col_ok, other=0.0).to(tl.float32)
        acc = acc + bias[None, :]

    y_ptrs = y_ptr + rows.to(tl.int64)[:, None] * OUT + cols[None, :]
    tl.store(y_ptrs, acc.to(y_ptr.dtype.element_ty),
             mask=row_ok[:, None] & col_ok[None, :])

def matmul_int4_fused(x: torch.Tensor,
                      w_q: torch.Tensor,
                      w_scale: torch.Tensor,
                      bias: torch.Tensor | None = None,
                      *, per_channel: bool = True) -> torch.Tensor:

    B, IN = x.shape
    OUT = w_scale.shape[0]

    x_f16 = x.to(torch.float16)
    w_scale_f16 = (w_scale.to(dtype=torch.float16, device=x.device) / 7)
    y = torch.empty((B, OUT), dtype=torch.float16, device=x.device)

    grid = lambda meta: (triton.cdiv(B, meta["BLOCK_M"]),
                     triton.cdiv(OUT, meta["BLOCK_N"]))

    _forward_int4_fused_kernel[grid](x_q_ptr=x_f16,
                               w_q_ptr=w_q, w_scale_ptr=w_scale_f16,
                               b_ptr=bias, y_ptr=y,
                               B=B, IN=IN, OUT=OUT,
                               PER_CHANNEL=per_channel,
                               HAS_BIAS=(bias is not None))

    return y

In [2]:
# Load the packed int4 weight and its scales from /kaggle/working.
# NOTE(review): torch.load unpickles the file -- only use it on trusted files
# (or pass weights_only=True on torch versions that support it).
path = "/kaggle/working/gate_proj_weight_0_quant.pt"
w_quant = torch.load(path)


path = "/kaggle/working/gate_proj_weight_0_quant_scale.pt"
w_scale = torch.load(path)

In [19]:
# One scale per output row (8192 rows, per the output below).
w_scale.shape

torch.Size([8192])

In [18]:
# 8192 rows x 1024 packed columns: two int4 values per element, so IN = 2048,
# matching x below and the kernel's ceil(IN / 2) packed row width.
w_quant.shape

torch.Size([8192, 1024])

In [7]:
# Random activations of shape (512, 2048) on the GPU.
# NOTE(review): the task statement asks for bf16 activations, but this test uses float16.
x = torch.randn(512, 2048, dtype=torch.float16, device='cuda')

In [41]:
# Run the fused kernel with the defaults: per-channel scales, no bias.
res = matmul_int4_fused(x, w_quant, w_scale)

In [42]:
res.shape

torch.Size([512, 8192])

In [43]:
res

tensor([[ 0.6948,  1.2178, -0.1401,  ..., -0.1566, -0.9468, -0.6016],
        [ 0.6704,  0.8579, -1.0967,  ..., -0.4187, -1.3291,  1.1094],
        [-0.4497,  0.9160,  1.7734,  ..., -1.6924, -1.0352,  0.4187],
        ...,
        [ 0.7725, -0.5796,  0.0335,  ...,  1.1172, -0.6138, -0.9458],
        [ 0.1462,  0.4265,  0.4204,  ..., -1.6094, -0.7749, -1.1484],
        [ 0.2812,  0.2021, -0.7271,  ..., -1.3047,  0.7363,  0.1017]],
       device='cuda:0', dtype=torch.float16)

In [25]:
# Unquantized weight for the same layer, used as the reference.
# Presumably w_quant / w_scale were produced from this tensor -- TODO confirm.
path = "/kaggle/working/gate_proj_weight_0.pt"
w = torch.load(path)

In [28]:
w = w.to('cuda')

In [31]:
w = w.to(torch.float16)

In [34]:
# fp16 reference result. The grad_fn in the output below shows w requires grad;
# NOTE(review): wrapping this in torch.no_grad() would avoid recording an autograd graph.
res_wo_q = x @ w.T

In [36]:
res_wo_q.shape

torch.Size([512, 8192])

In [37]:
res_wo_q

tensor([[ 0.7295,  1.3838, -0.0294,  ..., -0.3203, -0.9102, -0.4028],
        [ 0.5757,  0.8213, -0.9316,  ..., -0.2351, -1.1084,  1.4385],
        [-0.3137,  0.8125,  1.7363,  ..., -1.5654, -0.9619,  0.5479],
        ...,
        [ 0.7905, -0.5903,  0.2510,  ...,  1.0312, -1.0000, -1.0908],
        [ 0.1451,  0.6016,  0.1072,  ..., -1.6182, -0.8105, -1.3008],
        [ 0.1694,  0.1641, -0.8584,  ..., -1.3818,  0.6714,  0.2216]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>)

In [45]:
# Mean absolute difference between the int4 kernel and the fp16 reference.
# NOTE(review): this mixes int4 rounding error with any kernel error; comparing
# against x @ dequantized(w_quant).T would isolate kernel bugs.
torch.abs(res - res_wo_q).mean()

tensor(0.1085, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)