3) Сравнить скорость перемножения (X16@W4^T) с (X16@W16^T). Размеры матрицы W такие же, как размеры матриц весов для модели Llama-3.2-1B-Instruct (https://huggingface.co/unsloth/Llama-3.2-1B-Instruct).
Количество строк (токенов) в матрице активаций X: 128, 512, 2048

In [18]:
import torch

# Check CUDA availability and GPU model
if torch.cuda.is_available():
    print(f"CUDA is available")
    print(f"GPU Model: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.current_device()}")
else:
    print("CUDA is not available")


CUDA is available
GPU Model: Tesla T4
CUDA Version: 12.4
Number of GPUs: 2
Current GPU: 0


In [19]:
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 128, "BLOCK_K": 32},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 128, "BLOCK_K": 64},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64},  num_warps=4, num_stages=3),

        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 256, "BLOCK_K": 32},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 256, "BLOCK_K": 64},  num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32},  num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64},  num_warps=8, num_stages=3),
    ],
    key=["B", "IN", "OUT"],
)
@triton.jit
def _forward_int4_fused_kernel(x_q_ptr,
                               w_q_ptr, w_scale_ptr,
                               b_ptr, y_ptr,
                               B, IN, OUT,
                               BLOCK_M: tl.constexpr,
                               BLOCK_N: tl.constexpr,
                               BLOCK_K: tl.constexpr,
                               PER_CHANNEL: tl.constexpr,
                               HAS_BIAS: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    acc = tl.full((BLOCK_M, BLOCK_N), 0.0, dtype=tl.float32)

    pid_0_off = (tl.arange(0, BLOCK_M) + pid_0 * BLOCK_M) * OUT
    pid_1_off = tl.arange(0, BLOCK_N) + pid_1 * BLOCK_N
    off = pid_0_off[:, None] + pid_1_off[None, :]
    
    out_mask = ((tl.arange(0, BLOCK_M) + pid_0 * BLOCK_M) < B)[:, None] & \
               (pid_1_off < OUT)[None, :]  

    for k in range(0, IN, BLOCK_K):
        off_x_d0 = (tl.arange(0, BLOCK_M) + pid_0 * BLOCK_M) * IN
        off_x_d1 = (tl.arange(0, BLOCK_K) + k)
        off_x = off_x_d0[:, None] + off_x_d1[None, :]
        mask_x = (off_x_d1 < IN)[None, :] & ((tl.arange(0, BLOCK_M) + pid_0 * BLOCK_M) < B)[:, None]

        packed_IN = (IN + 1) // 2
        global_cols = pid_1 * BLOCK_N + tl.arange(0, BLOCK_N)
        out_guard = global_cols < OUT
        safe_cols = tl.where(out_guard, global_cols, 0)
        k_indices = tl.arange(0, BLOCK_K) + k
        row_offsets = safe_cols[None, :] * packed_IN
        byte_cols = (k_indices // 2)[:, None]
        off_w = row_offsets + byte_cols
        mask_w = (k_indices[:, None] < IN) & out_guard[None, :]
        is_high = (k_indices & 1) == 1
        

        x = tl.load(x_q_ptr + off_x, mask_x, 0)
        w_byte = tl.load(w_q_ptr + off_w, mask_w, 0)

        w_u32 = w_byte.to(tl.uint32)
        low = w_u32 & 0xF
        high = (w_u32 >> 4) & 0xF
        sel = is_high[:, None]
        w_nib = tl.where(sel, high, low)
        w_i32 = w_nib.to(tl.int32)
        w_signed_i32 = tl.where(w_i32 < 8, w_i32, w_i32 - 16)

        x_f16 = x.to(tl.float16)
        w_f16 = w_signed_i32.to(tl.float16)
        acc += tl.dot(x_f16, w_f16)
    
        

    if PER_CHANNEL:
        w_scale_mask = pid_1_off < OUT
        w_scale = tl.load(w_scale_ptr + pid_1_off, mask=w_scale_mask)
        alpha = w_scale[None, :].to(tl.float32)
    else:
        w_scale = tl.load(w_scale_ptr)
        alpha = w_scale.to(tl.float32)

    if HAS_BIAS:
        bias_mask = pid_1_off < OUT
        bias = tl.load(b_ptr + pid_1_off, mask=bias_mask, other=0).to(tl.float32)
        acc = acc * alpha + bias[None, :]
    else:
        acc = acc * alpha

   
    tl.store(y_ptr + off, acc.to(tl.float16), out_mask)               

def matmul_int4_fused(x: torch.Tensor,
                      w_q: torch.Tensor,
                      w_scale: torch.Tensor,
                      bias: torch.Tensor | None = None,
                      *, per_channel: bool = True) -> torch.Tensor:

    B, IN = x.shape
    OUT = w_scale.shape[0]

    x_f16 = x.to(torch.float16)
    w_scale_f16 = (w_scale.to(dtype=torch.float16, device=x.device) / 7)
    y = torch.empty((B, OUT), dtype=torch.float16, device=x.device)

    grid = lambda meta: (triton.cdiv(B, meta["BLOCK_M"]),
                     triton.cdiv(OUT, meta["BLOCK_N"]))

    _forward_int4_fused_kernel[grid](x_q_ptr=x_f16,
                               w_q_ptr=w_q, w_scale_ptr=w_scale_f16,
                               b_ptr=bias, y_ptr=y,
                               B=B, IN=IN, OUT=OUT,
                               PER_CHANNEL=per_channel,
                               HAS_BIAS=(bias is not None))

    return y

In [20]:
x = torch.randn(512, 2048, dtype=torch.float16, device='cuda')

In [21]:
path = "/kaggle/working/gate_proj_weight_0_quant.pt"
w_quant = torch.load(path)


path = "/kaggle/working/gate_proj_weight_0_quant_scale.pt"
w_scale = torch.load(path)

In [22]:
res = matmul_int4_fused(x, w_quant, w_scale)

In [25]:
res = matmul_int4_fused(x, w_quant, w_scale)

In [26]:
path = "/kaggle/working/gate_proj_weight_0.pt"
w = torch.load(path)

In [27]:
w = w.to('cuda')

In [28]:
w = w.to(torch.float16)

In [31]:
res_wo_q = x @ w.T

In [32]:
res

tensor([[-0.3804, -0.3999, -0.5967,  ..., -1.2314, -0.8862,  0.3347],
        [ 1.5898,  0.2499,  0.1747,  ...,  2.7715,  1.2373, -1.1924],
        [ 1.8408,  0.3977, -1.5039,  ..., -0.1364, -1.0596,  0.0104],
        ...,
        [-1.4541, -0.3726,  1.3721,  ...,  1.3975, -0.3584, -1.9141],
        [ 0.5132, -1.2920,  0.0789,  ...,  0.5181, -1.5557, -0.7412],
        [-0.2737,  0.7788, -0.0873,  ..., -1.3936,  1.9121, -1.2314]],
       device='cuda:0', dtype=torch.float16)

In [33]:
res_wo_q

tensor([[-0.5435, -0.5117, -0.2964,  ..., -1.3330, -0.7231,  0.2429],
        [ 1.5557,  0.2181,  0.2837,  ...,  2.7109,  1.2432, -1.0713],
        [ 1.6670,  0.1923, -1.6064,  ..., -0.3875, -1.0605,  0.0747],
        ...,
        [-1.3945, -0.3428,  1.4385,  ...,  1.3604, -0.1771, -1.9512],
        [ 0.6309, -1.1504,  0.1685,  ...,  0.2430, -1.5176, -0.7949],
        [-0.3479,  0.6191,  0.2129,  ..., -1.2236,  1.9395, -1.2617]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>)

In [34]:
torch.abs(res - res_wo_q).mean()

tensor(0.1083, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)

In [36]:
import torch

w_quant = w_quant.to('cuda')
w_scale = w_scale.to('cuda')
w = w.to('cuda').to(torch.float16)

def bench(fn, iters=50, warmup=10):
    for _ in range(warmup):
        fn()
        torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(iters):
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return sum(times) / len(times)

sizes = [128, 512, 2048]
results = []
for b in sizes:
    x_cur = torch.randn(b, 2048, dtype=torch.float16, device='cuda')
    t_int4 = bench(lambda: matmul_int4_fused(x_cur, w_quant, w_scale))
    t_fp16 = bench(lambda: x_cur @ w.T)
    results.append((b, t_int4, t_fp16))

for b, t1, t2 in results:
    print(f"B={b}: int4={t1:.3f} ms, fp16={t2:.3f} ms, speedup={t2/t1:.2f}x")


B=128: int4=0.531 ms, fp16=0.226 ms, speedup=0.43x
B=512: int4=1.299 ms, fp16=0.581 ms, speedup=0.45x
B=2048: int4=4.806 ms, fp16=2.884 ms, speedup=0.60x


- оферхед на распаковку и деквантизацию