Implementing Triton (or CUDA, if you prefer) kernels for quantizing LLM weights and for running inference with the quantized model.

Plan:
1) Implement a kernel that quantizes a 2D fp16 matrix to int4
and then packs the quantized matrix into int8 or int32.
Memory consumption should drop by a factor of 4.
2) Implement a kernel that multiplies a bf16 matrix by the int4-quantized matrix (X16@W4^T).
3) Compare the speed of (X16@W4^T) with (X16@W16^T). The W matrices have the same shapes as the weight matrices of Llama-3.2-1B-Instruct (https://huggingface.co/unsloth/Llama-3.2-1B-Instruct).
Number of rows (tokens) in the activation matrix X: 128, 512, 2048.
4) Using these kernels, write a quantized linear layer and apply it to the linear layers of Llama-3.2-1B-Instruct.
5) Measure inference speed and perplexity on wikitext2.
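Quick arithmetic for step 1: fp16 uses 16 bits per weight and int4 uses 4, so the packed payload is exactly 4× smaller. Storing one fp16 scale per row (which is what `quantize_rowwise` below does) adds 2 bytes per output row, so the real ratio lands slightly under 4. A minimal sketch, assuming row-wise scales and two nibbles per byte (`int4_compression_ratio` is a hypothetical helper):

```python
def int4_compression_ratio(out_features: int, in_features: int) -> float:
    """fp16 bytes divided by (packed int4 bytes + one fp16 scale per row)."""
    fp16_bytes = 2 * out_features * in_features
    packed_bytes = out_features * ((in_features + 1) // 2)  # two nibbles per byte
    scale_bytes = 2 * out_features                          # one fp16 scale per row
    return fp16_bytes / (packed_bytes + scale_bytes)

# gate_proj / up_proj of Llama-3.2-1B have shape (8192, 2048)
print(f"{int4_compression_ratio(8192, 2048):.3f}")  # prints 3.992
```

For a row of length IN the ratio is 2·IN / (IN/2 + 2), so the scale overhead shrinks as rows get longer.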

In [53]:
!pip install -U transformers

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [72]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
	messages,
	add_generation_prompt=True,
	tokenize=True,
	return_dict=True,
	return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))

I'm an artificial intelligence model known as Llama. Llama stands for "Large Language Model Meta AI."<|eot_id|>


In [73]:
import torch

def get_model_size_mb(model: torch.nn.Module) -> float:
    param_size = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_size + buffer_size) / (1024 ** 2)

size_mb = get_model_size_mb(model)
print(f"Model size (MB): {size_mb:.2f}")

Model size (MB): 4714.26


In [74]:
model = model.eval()

In [75]:
model.model.layers[0].mlp.gate_proj.weight.shape
# model.model.layers[0].mlp.gate_proj

torch.Size([8192, 2048])
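`nn.Linear` stores its weight as (out_features, in_features), which is why `gate_proj` (2048 → 8192) reports `torch.Size([8192, 2048])` and why the layer computes X @ W^T. Row-wise quantization therefore gives one scale per output feature and packs along the input (reduction) dimension. A small CPU check with toy sizes:

```python
import torch
import torch.nn.functional as F

lin = torch.nn.Linear(in_features=16, out_features=64, bias=False)
print(lin.weight.shape)  # torch.Size([64, 16]), i.e. (out_features, in_features)

x = torch.randn(3, 16)
# nn.Linear computes x @ W^T (plus bias, absent here)
torch.testing.assert_close(lin(x), x @ lin.weight.T)
torch.testing.assert_close(F.linear(x, lin.weight), x @ lin.weight.T)
```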

In [76]:
# Freeze all model parameters
for param in model.parameters():
    param.requires_grad = False


In [None]:
# import torch
# path = "/kaggle/working/gate_proj_weight_0.pt"
# torch.save(model.model.layers[0].mlp.gate_proj.weight, path)

In [77]:
model.to("cuda:0")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,)

In [60]:
model.model.layers[0].mlp.gate_proj.weight[0][0]

tensor(0.0260, device='cuda:0')

In [61]:
model.model.layers[0].mlp.gate_proj.weight

Parameter containing:
tensor([[ 0.0260, -0.0123, -0.0140,  ...,  0.0035,  0.0272,  0.0364],
        [-0.0096,  0.0225,  0.0015,  ..., -0.0272, -0.0151,  0.0024],
        [-0.0261, -0.0095,  0.0002,  ..., -0.0056, -0.0120, -0.0287],
        ...,
        [ 0.0198, -0.0228, -0.0190,  ...,  0.0289, -0.0427, -0.0137],
        [-0.0026, -0.0109,  0.0011,  ..., -0.0037,  0.0061, -0.0140],
        [-0.0231,  0.0049,  0.0086,  ..., -0.0120, -0.0069,  0.0070]],
       device='cuda:0')

In [62]:
model.model.layers[0].mlp.gate_proj.weight.to(torch.float16)[0][0]

tensor(0.0260, device='cuda:0', dtype=torch.float16)

In [78]:
model.device

device(type='cuda', index=0)

In [79]:
import torch
import triton
import triton.language as tl

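# BLOCK_SIZE is not read inside the kernel (each program processes a whole row
# padded to P2 elements), so autotuning only picks between the num_warps/num_stages settings.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_stages=1),
    ],
    key=["n_elements"],
)
@triton.jit
def _quantize_rowwise(x_ptr, output_ptr, output_maxs, n_elements, BLOCK_SIZE: tl.constexpr, P2: tl.constexpr):
    # One program per row: n_elements is the row length M, P2 = next_power_of_2(M).
    pid = tl.program_id(axis=0)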
    block_start = pid * n_elements
    row_start_ptr = x_ptr + block_start

    idx = tl.arange(0, P2 // 2)
    off_even = 2 * idx
    off_odd = 2 * idx + 1

    mask_even = off_even < n_elements
    mask_odd = off_odd < n_elements

    x_even = tl.load(row_start_ptr + off_even, mask=mask_even, other=0.0)
    x_odd = tl.load(row_start_ptr + off_odd, mask=mask_odd, other=0.0)

    absmax_even = tl.max(tl.abs(x_even))
    absmax_odd = tl.max(tl.abs(x_odd))
    absmax = tl.maximum(absmax_even, absmax_odd)

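    # Symmetric scaling: map [-absmax, absmax] onto [-7, 7]; an all-zero row gets scale 0.
    scale = tl.where(absmax == 0, 0.0, 7.0 / absmax)

    s_even = x_even * scale
    s_odd = x_odd * scale

    # Round half away from zero (the int8 cast truncates), then keep the low 4 bits,
    # i.e. a two's-complement nibble (for example -3 -> 0b1101).
    q_even = tl.where(s_even >= 0, s_even + 0.5, s_even - 0.5).to(tl.int8).to(tl.uint8) & 0xF
    q_odd = tl.where(s_odd >= 0, s_odd + 0.5, s_odd - 0.5).to(tl.int8).to(tl.uint8) & 0xF

    # Column 2i goes to the low nibble and column 2i + 1 to the high nibble of byte i.
    packed = (q_odd << 4) | q_even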

    packed_block_start = pid * ((n_elements + 1) // 2)
    packed_mask = idx < ((n_elements + 1) // 2)

    tl.store(output_ptr + packed_block_start + idx, packed, mask=packed_mask)
    tl.store(output_maxs + pid, absmax)


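def quantize_rowwise(x: torch.Tensor):
    """Row-wise int4 quantization of a 2D fp16 matrix.

    Returns the packed uint8 matrix of shape (N, ceil(M / 2)) and the per-row
    absmax values in fp16 (the dequantization scale is absmax / 7).
    """
    assert x.dim() == 2 and x.dtype == torch.float16
    x = x.contiguous()  # the kernel addresses row r as x_ptr + r * M
    N, M = x.shape

    out_cols = (M + 1) // 2

    output_tensor = torch.empty((N, out_cols), dtype=torch.uint8, device=x.device)

    output_maxs = torch.empty(N, dtype=torch.float16, device=x.device)

    # Row width rounded up to a power of two, as required by tl.arange.
    P2 = triton.next_power_of_2(M)

    grid = lambda meta: (N,)
    _quantize_rowwise[grid](x_ptr=x, output_ptr=output_tensor, output_maxs=output_maxs, n_elements=M, P2=P2)

    return output_tensor, output_maxs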
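To validate the Triton kernel it helps to have a plain-PyTorch reference with the same layout: symmetric per-row scaling to [-7, 7], round half away from zero, two's-complement nibbles, even column in the low nibble and odd column in the high nibble. A sketch, where `quantize_rowwise_ref` is a hypothetical helper and not part of the notebook. It works in fp32, so on GPU its output should match `quantize_rowwise` except possibly for a few nibbles that differ by one where the kernel's arithmetic rounds differently:

```python
import torch
import torch.nn.functional as F

def quantize_rowwise_ref(w: torch.Tensor):
    """Reference row-wise int4 quantization + nibble packing (same layout as the kernel)."""
    w = w.float()
    if w.shape[1] % 2:
        w = F.pad(w, (0, 1))  # odd row length: pad with a zero column
    absmax = w.abs().amax(dim=1, keepdim=True)
    scale = torch.where(absmax == 0, torch.zeros_like(absmax), 7.0 / absmax)
    s = w * scale
    # Round half away from zero; the float -> int8 cast truncates toward zero.
    q = torch.where(s >= 0, s + 0.5, s - 0.5).to(torch.int8)
    q = (q & 0xF).to(torch.uint8)  # two's-complement nibble in [0, 15]
    packed = (q[:, 1::2] << 4) | q[:, 0::2]  # odd column -> high nibble, even -> low
    return packed, absmax.squeeze(1)

packed, amax = quantize_rowwise_ref(torch.tensor([[7.0, -7.0, 3.0, 0.0]]))
print(packed.tolist(), amax.tolist())  # [[151, 3]] [7.0]
```

Here -7 becomes the nibble 0b1001 = 9, so the first byte is (9 << 4) | 7 = 151.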

In [65]:
weight_quant, max_values = quantize_rowwise(model.model.layers[0].mlp.gate_proj.weight.to('cuda').to(torch.float16))
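The inverse mapping is useful for measuring how far `weight_quant` is from the original fp16 weight. A sketch (`dequantize_rowwise_ref` is a hypothetical helper) that undoes the packing and applies the per-row scale absmax / 7:

```python
import torch

def dequantize_rowwise_ref(packed: torch.Tensor, absmax: torch.Tensor, in_features: int) -> torch.Tensor:
    """Unpack (OUT, ceil(IN/2)) uint8 nibbles and rescale to an fp32 (OUT, IN) matrix."""
    low = (packed & 0xF).to(torch.int16)   # even columns
    high = (packed >> 4).to(torch.int16)   # odd columns
    q = torch.stack((low, high), dim=-1).flatten(1)[:, :in_features]  # re-interleave
    q = torch.where(q < 8, q, q - 16)      # two's-complement nibble -> [-8, 7]
    return q.float() * (absmax.float()[:, None] / 7.0)

# Usage on GPU with the tensors from the cell above:
# w = model.model.layers[0].mlp.gate_proj.weight.to(torch.float16)
# err = (dequantize_rowwise_ref(weight_quant, max_values, w.shape[1]) - w.float()).abs().max()
```

With round-to-nearest, every element's error should stay within half a step, i.e. absmax / 14 for its row.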

In [None]:
# model.model.layers[0].mlp.gate_proj.weight = torch.nn.Parameter(weight_quant, requires_grad=False)
# model.model.layers[0].mlp.gate_proj.max_values = max_values

In [80]:
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64,  "BLOCK_K": 64},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64,  "BLOCK_K": 64},  num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 64,  "BLOCK_N": 128, "BLOCK_K": 64},  num_warps=4, num_stages=2)
    ],
    key=["B", "IN", "OUT"],
)
@triton.jit
def _forward_int4_fused_kernel(x_q_ptr,
                               w_q_ptr, w_scale_ptr,
                               b_ptr, y_ptr,
                               B, IN, OUT,
                               BLOCK_M: tl.constexpr,
                               BLOCK_N: tl.constexpr,
                               BLOCK_K: tl.constexpr,
                               PER_CHANNEL: tl.constexpr,
                               HAS_BIAS: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    acc = tl.full((BLOCK_M, BLOCK_N), 0.0, dtype=tl.float32)

    pid_0_off = (tl.arange(0, BLOCK_M) + pid_0 * BLOCK_M) * OUT
    pid_1_off = tl.arange(0, BLOCK_N) + pid_1 * BLOCK_N
    off = pid_0_off[:, None] + pid_1_off[None, :]
    
    out_mask = ((tl.arange(0, BLOCK_M) + pid_0 * BLOCK_M) < B)[:, None] & \
               (pid_1_off < OUT)[None, :]  

    for k in range(0, IN, BLOCK_K):
        off_x_d0 = (tl.arange(0, BLOCK_M) + pid_0 * BLOCK_M) * IN
        off_x_d1 = (tl.arange(0, BLOCK_K) + k)
        off_x = off_x_d0[:, None] + off_x_d1[None, :]
        mask_x = (off_x_d1 < IN)[None, :] & ((tl.arange(0, BLOCK_M) + pid_0 * BLOCK_M) < B)[:, None]

        # W is stored row-major as (OUT, packed_IN) bytes. Element (k, n) of this
        # (BLOCK_K, BLOCK_N) tile of W^T is the k % 2 nibble of byte k // 2 in row n.
        packed_IN = (IN + 1) // 2
        global_cols = pid_1 * BLOCK_N + tl.arange(0, BLOCK_N)
        out_guard = global_cols < OUT
        safe_cols = tl.where(out_guard, global_cols, 0)  # keep masked-out addresses in range
        k_indices = tl.arange(0, BLOCK_K) + k
        row_offsets = safe_cols[None, :] * packed_IN
        byte_cols = (k_indices // 2)[:, None]
        off_w = row_offsets + byte_cols
        mask_w = (k_indices[:, None] < IN) & out_guard[None, :]
        is_high = (k_indices & 1) == 1  # odd k -> high nibble

        x = tl.load(x_q_ptr + off_x, mask_x, 0)
        w_byte = tl.load(w_q_ptr + off_w, mask_w, 0)

        # Select the nibble and sign-extend it from [0, 15] to [-8, 7].
        w_u32 = w_byte.to(tl.uint32)
        low = w_u32 & 0xF
        high = (w_u32 >> 4) & 0xF
        sel = is_high[:, None]
        w_nib = tl.where(sel, high, low)
        w_i32 = w_nib.to(tl.int32)
        w_signed_i32 = tl.where(w_i32 < 8, w_i32, w_i32 - 16)

        # Small integers are exact in fp16; the per-row scale is applied after the loop.
        x_f16 = x.to(tl.float16)
        w_f16 = w_signed_i32.to(tl.float16)
        acc += tl.dot(x_f16, w_f16)

    if PER_CHANNEL:
        w_scale_mask = pid_1_off < OUT
        w_scale = tl.load(w_scale_ptr + pid_1_off, mask=w_scale_mask)
        alpha = w_scale[None, :].to(tl.float32)
    else:
        w_scale = tl.load(w_scale_ptr)
        alpha = w_scale.to(tl.float32)

    if HAS_BIAS:
        bias_mask = pid_1_off < OUT
        bias = tl.load(b_ptr + pid_1_off, mask=bias_mask, other=0).to(tl.float32)
        acc = acc * alpha + bias[None, :]
    else:
        acc = acc * alpha

   
    tl.store(y_ptr + off, acc.to(tl.float16), out_mask)               

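def matmul_int4_fused(x: torch.Tensor,
                      w_q: torch.Tensor,
                      w_scale: torch.Tensor,
                      bias: torch.Tensor | None = None,
                      *, per_channel: bool = True) -> torch.Tensor:
    """y = x @ W^T (+ bias), with W stored as packed int4 by quantize_rowwise.

    x:       (B, IN) activations, cast to fp16 for the kernel
    w_q:     (OUT, ceil(IN / 2)) uint8, two nibbles per byte
    w_scale: per-row absmax values (OUT,), or a single value if per_channel=False;
             the dequantization scale absmax / 7 is computed here
    Returns a (B, OUT) fp16 tensor.
    """
    B, IN = x.shape
    OUT = w_q.shape[0]  # from w_q, so that per_channel=False (a single scale) also works
    assert w_q.dtype == torch.uint8 and w_q.shape[1] == (IN + 1) // 2

    # The kernel assumes contiguous row-major inputs.
    x_f16 = x.to(torch.float16).contiguous()
    w_q = w_q.contiguous()
    w_scale_f16 = (w_scale.to(dtype=torch.float16, device=x.device) / 7)
    y = torch.empty((B, OUT), dtype=torch.float16, device=x.device)

    grid = lambda meta: (triton.cdiv(B, meta["BLOCK_M"]),
                         triton.cdiv(OUT, meta["BLOCK_N"]))

    _forward_int4_fused_kernel[grid](x_q_ptr=x_f16,
                                     w_q_ptr=w_q, w_scale_ptr=w_scale_f16,
                                     b_ptr=bias, y_ptr=y,
                                     B=B, IN=IN, OUT=OUT,
                                     PER_CHANNEL=per_channel,
                                     HAS_BIAS=(bias is not None))

    return y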
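A correctness check for the fused kernel can compare it against an fp32 reference that follows the same order of operations: integer-weight matmul first, per-output-channel scale absmax / 7 afterwards. A sketch (`matmul_int4_ref` is a hypothetical helper):

```python
import torch

def matmul_int4_ref(x, w_q, w_absmax, bias=None):
    """fp32 reference for matmul_int4_fused: y = x @ W^T with W unpacked from int4 nibbles."""
    in_features = x.shape[1]
    low = (w_q & 0xF).to(torch.int16)
    high = (w_q >> 4).to(torch.int16)
    q = torch.stack((low, high), dim=-1).flatten(1)[:, :in_features]
    q = torch.where(q < 8, q, q - 16).float()          # (OUT, IN) integers in [-8, 7]
    y = (x.float() @ q.T) * (w_absmax.float() / 7.0)   # scale each output column
    if bias is not None:
        y = y + bias.float()
    return y

# On GPU, e.g. with weight_quant / max_values from quantize_rowwise above
# (the tolerances are a rough guess for an fp16 output):
# x = torch.randn(128, 2048, device="cuda", dtype=torch.float16)
# torch.testing.assert_close(matmul_int4_fused(x, weight_quant, max_values).float(),
#                            matmul_int4_ref(x, weight_quant, max_values), rtol=1e-2, atol=1e-2)
```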

In [81]:
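class QuantLinear(torch.nn.Module):
    """nn.Linear replacement that stores int4-packed weights and calls matmul_int4_fused."""
    def __init__(self, in_features: int, out_features: int, weight_quant: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor | None = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Stored as plain tensor attributes, not registered parameters/buffers, so they are
        # not listed by parameters()/buffers()/state_dict() and are not moved by .to().
        self.weight_q = weight_quant        # (out_features, ceil(in_features / 2)) uint8
        self.weight_scale = weight_scale    # per-row absmax in fp16; the scale is absmax / 7
        self.bias = bias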

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape
        x_2d = x.reshape(-1, orig_shape[-1])
        y_2d = matmul_int4_fused(
            x=x_2d,
            w_q=self.weight_q,
            w_scale=self.weight_scale,
            bias=self.bias,
        )
        y = y_2d.reshape(*orig_shape[:-1], self.out_features)
        return y
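Step 3 of the plan (timing X16@W4^T against X16@W16^T for the Llama-3.2-1B weight shapes with 128, 512 and 2048 tokens) is not covered by the cells here. A minimal harness sketch, assuming `torch.utils.benchmark.Timer` for timing (it synchronizes CUDA around measurements) and taking the quantizer and int4 matmul as arguments so that `quantize_rowwise` and `matmul_int4_fused` can be plugged in. `LLAMA_3_2_1B_SHAPES`, `bench_ms` and `compare_matmuls` are hypothetical names; the shapes come from the module printout above, with `lm_head` left out. The plan mentions bf16 activations, but `matmul_int4_fused` casts inputs to fp16, so fp16 is the default here:

```python
import torch
from torch.utils import benchmark

# (out_features, in_features) of the Linear layers in each Llama-3.2-1B decoder layer.
LLAMA_3_2_1B_SHAPES = {
    "q_proj/o_proj": (2048, 2048),
    "k_proj/v_proj": (512, 2048),
    "gate_proj/up_proj": (8192, 2048),
    "down_proj": (2048, 8192),
}

def bench_ms(fn, min_run_time=0.2):
    """Median runtime of fn() in milliseconds."""
    fn()  # warm-up, e.g. Triton compilation and autotuning
    timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    return timer.blocked_autorange(min_run_time=min_run_time).median * 1e3

def compare_matmuls(quantize, matmul_q, tokens=(128, 512, 2048), shapes=LLAMA_3_2_1B_SHAPES,
                    device="cuda", dtype=torch.float16):
    """Rows of (layer, n_tokens, ms_16bit, ms_int4, speedup) for X @ W^T vs the int4 path."""
    rows = []
    for name, (out_f, in_f) in shapes.items():
        w = torch.randn(out_f, in_f, device=device, dtype=dtype)
        w_q, w_s = quantize(w)
        for n in tokens:
            x = torch.randn(n, in_f, device=device, dtype=dtype)
            t16 = bench_ms(lambda: x @ w.T)
            t4 = bench_ms(lambda: matmul_q(x, w_q, w_s))
            rows.append((name, n, t16, t4, t16 / t4))
    return rows

# On GPU, with the kernels above:
# for row in compare_matmuls(quantize_rowwise, matmul_int4_fused):
#     print("{:>18} n={:>5} fp16={:.3f} ms int4={:.3f} ms speedup={:.2f}x".format(*row))
```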

In [82]:
device = next(model.parameters()).device

def quantize_linear_module(linear: torch.nn.Linear) -> QuantLinear:
    in_features = linear.in_features
    out_features = linear.out_features
    weight = linear.weight.detach().to(device=device, dtype=torch.float16)
    weight_q, max_values = quantize_rowwise(weight)
    bias = None
    if linear.bias is not None:
        bias = linear.bias.detach().to(device=device, dtype=torch.float16)
    return QuantLinear(in_features, out_features, weight_q, max_values, bias)

def replace_linear_with_quantlinear(module: torch.nn.Module):
    for child_name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            quant_linear = quantize_linear_module(child)
            setattr(module, child_name, quant_linear)
        else:
            replace_linear_with_quantlinear(child)

replace_linear_with_quantlinear(model)
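`replace_linear_with_quantlinear` swaps every `nn.Linear`, including `lm_head`, which in Llama-3.2-1B shares its weight with `embed_tokens`. Here that is 16 layers × 7 projections + `lm_head` = 113 modules. A variant that returns the names it replaced and accepts a skip list makes coverage easy to verify and lets selected layers stay in their original precision for an ablation. A sketch with a hypothetical `replace_linear` helper and a generic replacement factory:

```python
import torch

def replace_linear(module: torch.nn.Module, make_replacement, skip=frozenset(), prefix: str = "") -> list[str]:
    """Recursively replace nn.Linear submodules except those whose dotted name is in `skip`.

    Returns the dotted names of the replaced modules.
    """
    replaced = []
    for name, child in module.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear) and full_name not in skip:
            setattr(module, name, make_replacement(child))
            replaced.append(full_name)
        else:
            replaced += replace_linear(child, make_replacement, skip, full_name)
    return replaced

# With the helpers above, e.g. keeping the tied lm_head unquantized:
# replaced = replace_linear(model, quantize_linear_module, skip={"lm_head"})
```

Whether keeping `lm_head` unquantized changes the perplexity is not measured in this notebook.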

In [89]:
model.model.layers[5].mlp.gate_proj.weight_scale

tensor([0.0957, 0.0957, 0.0645,  ..., 0.0767, 0.1133, 0.0591], device='cuda:0',
       dtype=torch.float16)

In [None]:
# model.model.layers[0].mlp.gate_proj = QuantLinear(2048, 8192, model.model.layers[0].mlp.gate_proj.weight, model.model.layers[0].mlp.gate_proj.max_values)

In [90]:
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
	messages,
	add_generation_prompt=True,
	tokenize=True,
	return_dict=True,
	return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))

I'm an artificial intelligence (AI) called "Luna" (or "Luni" for short). I'm an artificial intelligence designed to simulate conversation, answer questions, and even provide information and
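The changed answer ("Luna" instead of "Llama") is only anecdotal evidence of quantization damage. A more direct check compares one layer's outputs before and after quantization on the same input, for example with the relative Frobenius error and the mean cosine similarity between rows. A small sketch (`output_error` is a hypothetical helper); using it requires keeping a copy of an original `nn.Linear` before `replace_linear_with_quantlinear` runs:

```python
import torch
import torch.nn.functional as F

def output_error(y_ref: torch.Tensor, y_q: torch.Tensor) -> tuple[float, float]:
    """Relative Frobenius error and mean per-row cosine similarity between two outputs."""
    y_ref, y_q = y_ref.float(), y_q.float()
    rel = (torch.linalg.norm(y_q - y_ref) / torch.linalg.norm(y_ref)).item()
    cos = F.cosine_similarity(y_q, y_ref, dim=-1).mean().item()
    return rel, cos

# ref = copy.deepcopy(model.model.layers[0].mlp.gate_proj)  # taken before the replacement
# x = torch.randn(128, 2048, device="cuda")
# print(output_error(ref(x), model.model.layers[0].mlp.gate_proj(x)))
```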


In [91]:
import torch

def get_model_size_mb(model: torch.nn.Module) -> float:
    param_size = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_size + buffer_size) / (1024 ** 2)

size_mb = get_model_size_mb(model)
print(f"Model size (MB): {size_mb:.2f}")

Model size (MB): 1002.26
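`get_model_size_mb` only sees registered parameters and buffers. `QuantLinear` keeps `weight_q` and `weight_scale` as plain attributes, so the packed weights are not counted at all: 1002.26 is almost exactly the fp32 `embed_tokens` table (128256 × 2048 × 4 bytes = 1,050,673,152 bytes, i.e. 1002.0 in the units printed) plus the 33 RMSNorm weights (33 × 2048 × 4 bytes ≈ 0.26). A sketch that also counts those attributes and deduplicates shared storage by `data_ptr()`, which also handles the tied `lm_head`/`embed_tokens` weight. The attribute names are the ones `QuantLinear` uses above; `get_model_size_mb_incl_attrs` is a hypothetical name:

```python
import torch

def get_model_size_mb_incl_attrs(model: torch.nn.Module,
                                 attr_names=("weight_q", "weight_scale", "bias")) -> float:
    """Size in MiB of parameters, buffers and selected plain tensor attributes (deduplicated)."""
    seen, total = set(), 0

    def add(t):
        nonlocal total
        if isinstance(t, torch.Tensor) and t.data_ptr() not in seen:
            seen.add(t.data_ptr())
            total += t.numel() * t.element_size()

    for p in model.parameters():
        add(p)
    for b in model.buffers():
        add(b)
    for m in model.modules():
        for name in attr_names:
            add(getattr(m, name, None))  # plain attributes, e.g. QuantLinear.weight_q
    return total / (1024 ** 2)

# print(f"Model size incl. quantized weights (MB): {get_model_size_mb_incl_attrs(model):.2f}")
```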


In [92]:
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
eval_dataset = dataset["validation"]

def tokenize_function(examples):
    return tokenizer(examples["text"], return_special_tokens_mask=False)

tokenized = eval_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

block_size = 2048

def group_texts(examples):
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_dataset = tokenized.map(group_texts, batched=True)

lm_dataset.set_format(type="torch", columns=["input_ids", "labels", "attention_mask"])
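Tokenizing wikitext line by line, as above, means the Llama 3 tokenizer prepends its default `<|begin_of_text|>` token to every line (including the empty ones), and `group_texts` then concatenates those BOS tokens into the middle of each 2048-token block. A common alternative is to join the split with "\n\n", tokenize once, and cut non-overlapping windows. This changes the perplexity, so numbers from the two setups are not directly comparable. A sketch (`make_eval_blocks` is a hypothetical helper):

```python
import torch

def make_eval_blocks(token_ids, block_size: int = 2048) -> torch.Tensor:
    """Cut a 1-D sequence of token ids into non-overlapping (n_blocks, block_size) windows; the tail is dropped."""
    ids = torch.as_tensor(token_ids).flatten()
    n_blocks = ids.numel() // block_size
    return ids[: n_blocks * block_size].view(n_blocks, block_size)

# With the objects defined above:
# ids = tokenizer("\n\n".join(eval_dataset["text"]), return_tensors="pt").input_ids[0]
# blocks = make_eval_blocks(ids, block_size)  # each row serves as input_ids and labels
```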


In [93]:
import torch
from torch.utils.data import DataLoader
import time
import math

batch_size = 1

dataloader = DataLoader(lm_dataset, batch_size=batch_size, pin_memory=True)

device = next(model.parameters()).device

start_time = time.time()
num_tokens = 0
loss_sum = 0.0
count = 0

model.eval()
with torch.no_grad():
    for batch in dataloader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch.get("attention_mask"),
            labels=batch["labels"],
        )
        loss = outputs.loss
        num_tokens += batch["labels"].numel()
        loss_sum += loss.item()
        count += 1

if device.type == "cuda":
    torch.cuda.synchronize()
elapsed = time.time() - start_time
mean_loss = loss_sum / max(count, 1)
ppl = math.exp(mean_loss)
tps = num_tokens / elapsed if elapsed > 0 else float("nan")

print(f"Perplexity: {ppl:.4f}")
print(f"Eval time (s): {elapsed:.2f}")
print(f"Tokens processed: {num_tokens}")
print(f"Tokens/sec: {tps:.2f}")


Perplexity: 71.0608
Eval time (s): 143.44
Tokens processed: 249856
Tokens/sec: 1741.88
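The reported perplexity is exp of the mean next-token cross-entropy. Hugging Face causal-LM models shift `labels` internally, so each 2048-token block is scored on 2047 predictions (122 × 2047 = 249,734 scored tokens versus the 249,856 tokens fed in). The same computation written out explicitly, as a sketch (`perplexity_from_logits` is a hypothetical helper); running the evaluation loop before `replace_linear_with_quantlinear` would give the unquantized baseline that 71.06 should be compared against:

```python
import math
import torch
import torch.nn.functional as F

def perplexity_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """exp(mean next-token NLL): position t predicts labels[t + 1], as in the HF causal-LM loss."""
    shift_logits = logits[:, :-1, :].float()
    shift_labels = labels[:, 1:]
    nll = F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
    return math.exp(nll.item())
```

As a sanity check, uniform logits over a vocabulary of size V give a perplexity of exactly V.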


In [99]:
messages = [
    {"role": "user", "content": "how to hire a good engineer?"},
]
inputs = tokenizer.apply_chat_template(
	messages,
	add_generation_prompt=True,
	tokenize=True,
	return_dict=True,
	return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=400)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))

Here are some tips on how to hire a good engineer:

1.  **Define the Job Requirements**: Before you can start interviewing engineers, you need to define their requirements. Consider factors like:
    * Education and experience (Bachelor's or Master's degree, typically 1-2 years, and relevant work experience)
    * Language (in English or a language) and skills (in computer programming or other relevant technical skills)
    * Technical skills (software or hardware relevant skills, such as:
        * Database design (SQL, MySQL, or Oracle)
        * Data visualization (Table, Chart, and Graph)
        * Data mining (SQL, MySQL, or Oracle)
    * Software or hardware relevant skills (e.g., SQL Server, Oracle Database)
    * Domain (web development, e.g., PHP, Java, or.NET)
    * Database (database design, database design, database design, database design, database design, database design, database design, database design, database design, database design, database design, database design,