In [1]:
import gc
import os
import random

import numpy as np
import torch
import triton
import triton.language as tl

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_stages=1, num_warps=8),
        triton.Config({'BLOCK_SIZE': 256}, num_stages=1, num_warps=8),
        triton.Config({'BLOCK_SIZE': 512}, num_stages=1, num_warps=8),
    ],
    key=['n_cols'],
)
@triton.jit
def _quantize_rowwise_int4(
    x_ptr, output_ptr, output_maxs, n_rows, n_cols,
    BLOCK_SIZE: tl.constexpr
):
    """Row-wise symmetric int4 quantization with two codes packed per byte.

    One program handles one row of a row-major (n_rows, n_cols) input:
      1. compute the row's abs-max over all n_cols columns (in BLOCK_SIZE chunks),
      2. store it (clamped to >= 1e-12) as float16 in output_maxs[row],
      3. quantize each value to rint(x * 8 / absmax) clamped to [-8, 7], add 8 to
         get an unsigned nibble, and store (even-col nibble << 4) | odd-col nibble
         into output_ptr[row, col // 2].
    """
    row_idx = tl.program_id(0)
    if row_idx >= n_rows:
        return

    # Assumes a contiguous row-major input: row stride == n_cols.
    row_start = row_idx * n_cols

    # Pass 1: the abs-max must cover the WHOLE row, not only the first
    # BLOCK_SIZE columns, otherwise wide rows are scaled from a partial max and
    # every later column larger than it gets clipped.
    abs_max_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for col_start in range(0, n_cols, BLOCK_SIZE):
        cols = col_start + tl.arange(0, BLOCK_SIZE)
        # Masked lanes load 0.0, which cannot raise the abs-max.
        vals = tl.load(x_ptr + row_start + cols, mask=cols < n_cols, other=0.0)
        abs_max_acc = tl.maximum(abs_max_acc, tl.abs(vals.to(tl.float32)))
    row_max = tl.max(abs_max_acc, axis=0)

    # Avoid division by zero for all-zero rows.
    row_max_safe = tl.maximum(row_max, 1e-12)
    scale = 8.0 / row_max_safe

    tl.store(output_maxs + row_idx, row_max_safe.to(tl.float16))

    packed_row_stride = n_cols // 2
    packed_row_start = row_idx * packed_row_stride

    # Pass 2: quantize and pack column pairs (2p, 2p + 1) into byte p.
    for packed_start in range(0, packed_row_stride, BLOCK_SIZE):
        packed_offs = packed_start + tl.arange(0, BLOCK_SIZE)
        mask_packed = packed_offs < packed_row_stride

        orig_offs = packed_offs * 2
        mask1 = orig_offs < n_cols
        mask2 = (orig_offs + 1) < n_cols

        val1 = tl.load(x_ptr + row_start + orig_offs, mask=mask1 & mask_packed, other=0.0)
        val2 = tl.load(x_ptr + row_start + orig_offs + 1, mask=mask2 & mask_packed, other=0.0)

        quant1 = tl.extra.cuda.libdevice.rint(val1 * scale)
        quant2 = tl.extra.cuda.libdevice.rint(val2 * scale)

        # NOTE(review): a value equal to +absmax maps to 8 and is clamped to 7
        # (an error of one full step), while -absmax maps exactly to -8.
        quant1 = tl.minimum(tl.maximum(quant1, -8.0), 7.0)
        quant2 = tl.minimum(tl.maximum(quant2, -8.0), 7.0)

        # Shift [-8, 7] to the unsigned nibble range [0, 15].
        uint4_1 = (quant1 + 8).to(tl.uint8)
        uint4_2 = (quant2 + 8).to(tl.uint8)
        packed = (uint4_1 << 4) | uint4_2

        tl.store(output_ptr + packed_row_start + packed_offs, packed, mask=mask_packed)

def quantize_rowwise_int4(x: torch.Tensor):
    """Quantize a 2-D CUDA tensor row-wise to packed int4.

    Args:
        x: CUDA tensor of shape (n_rows, n_cols); n_cols must be even.

    Returns:
        (q_packed, absmaxs):
          - q_packed: uint8 tensor (n_rows, n_cols // 2); each byte holds
            (code[2p] << 4) | code[2p + 1], where code = quantized value + 8.
          - absmaxs: float16 tensor (n_rows,) with each row's abs-max.
    """
    assert x.is_cuda and x.dim() == 2
    n_rows, n_cols = x.shape
    assert n_cols % 2 == 0

    # The kernel addresses row r at offset r * n_cols, so it needs a
    # contiguous row-major buffer (a transposed/sliced view would be misread).
    x = x.contiguous()

    packed_n_cols = n_cols // 2
    q_packed = torch.empty((n_rows, packed_n_cols), device=x.device, dtype=torch.uint8)
    absmaxs = torch.empty(n_rows, device=x.device, dtype=torch.float16)

    # Nothing to quantize; avoid launching a zero-sized grid.
    if n_rows == 0:
        return q_packed, absmaxs

    grid = (n_rows,)
    _quantize_rowwise_int4[grid](x, q_packed, absmaxs, n_rows, n_cols)

    return q_packed, absmaxs

def dequantize_rowwise_int4(packed: torch.Tensor, absmaxs: torch.Tensor):
    """Expand row-wise int4-packed weights back to a float16 matrix.

    Each byte of ``packed`` carries two 4-bit codes: the high nibble belongs to
    the even output column and the low nibble to the following odd column. A
    code ``u`` in [0, 15] stands for the integer ``u - 8``, which is rescaled by
    ``absmaxs[row] / 8``.

    Args:
        packed: CUDA uint8 tensor of shape (n_rows, n_cols // 2).
        absmaxs: CUDA float16 tensor of shape (n_rows,).

    Returns:
        float16 tensor of shape (n_rows, n_cols) on ``packed``'s device.
    """
    assert packed.is_cuda and absmaxs.is_cuda
    assert packed.dtype == torch.uint8
    assert absmaxs.dtype == torch.float16
    assert packed.dim() == 2 and absmaxs.dim() == 1
    assert packed.size(0) == absmaxs.size(0)

    rows, half_cols = packed.shape

    # Recover the signed codes for the even (high nibble) and odd (low nibble) columns.
    even_codes = ((packed >> 4) & 0x0F).to(torch.int8) - 8
    odd_codes = (packed & 0x0F).to(torch.int8) - 8

    # Stacking on a trailing axis and flattening interleaves even/odd columns.
    interleaved = torch.stack((even_codes, odd_codes), dim=-1).reshape(rows, 2 * half_cols)

    row_scale = (absmaxs / 8.0).unsqueeze(1)
    return interleaved.to(torch.float16) * row_scale

def matmul_fp16_int4(a: torch.Tensor, b_packed: torch.Tensor, absmaxs: torch.Tensor):
    """Compute ``a @ dequantize_rowwise_int4(b_packed, absmaxs).T`` in float16.

    Args:
        a: CUDA float16 activations of shape (M, K); K must be even.
        b_packed: CUDA uint8 packed int4 weights of shape (N, K // 2).
        absmaxs: CUDA float16 per-row abs-max of shape (N,).

    Returns:
        float16 tensor of shape (M, N).

    If the Triton kernel raises, a warning is emitted and the result is computed
    by dequantizing the weights and calling ``torch.matmul`` instead.
    """
    import warnings

    assert a.is_cuda and b_packed.is_cuda and absmaxs.is_cuda
    assert a.dtype == torch.float16, f"Expected torch.float16, got {a.dtype}"
    assert b_packed.dtype == torch.uint8
    assert absmaxs.dtype == torch.float16
    assert a.dim() == 2 and b_packed.dim() == 2 and absmaxs.dim() == 1

    M, K = a.shape
    N, packed_K = b_packed.shape

    expected_packed_K = K // 2
    # An odd K would otherwise pass the packed_K check and silently drop a column.
    assert K % 2 == 0 and packed_K == expected_packed_K, f"Expected packed_K={expected_packed_K}, got {packed_K}. K must be divisible by 2"
    assert absmaxs.size(0) == N, f"Expected {N} row scales, got {absmaxs.size(0)}"

    # The kernel reads scales as absmaxs_ptr + n, i.e. it assumes unit stride.
    absmaxs = absmaxs.contiguous()

    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    if M == 0 or N == 0:
        # Empty output: nothing to launch.
        return c

    # Derive the grid from the block sizes the autotuner actually picked,
    # so changing the autotune configs cannot desynchronize the launch grid.
    def grid(meta):
        return (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(N, meta['BLOCK_SIZE_N']))

    try:
        _matmul_fp16_int4_kernel[grid](
            a, b_packed, absmaxs, c,
            M, N, K,
            a.stride(0), a.stride(1),
            b_packed.stride(0), b_packed.stride(1),
            c.stride(0), c.stride(1),
        )
    except Exception as e:  # deliberate best-effort fallback for compile/launch failures
        # warnings.warn is de-duplicated by default, unlike a print on every forward call.
        warnings.warn(
            f"Error in kernel execution: {e}. Using fallback matmul",
            RuntimeWarning,
        )
        dequantized_weights = dequantize_rowwise_int4(b_packed, absmaxs)
        c = torch.matmul(a, dequantized_weights.t())

    return c

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _matmul_fp16_int4_kernel(
    a_ptr, b_ptr, absmaxs_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bn, stride_bk,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """C[m, n] = sum_k A[m, k] * q[n, k] * absmax[n] / 8, q being the int4 code.

    A is (M, K); B is uint8 (N, K // 2) holding two codes per byte (high nibble
    for even k, low nibble for odd k, each stored as code + 8). Each program
    computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    mask_m = rm < M
    mask_n = rn < N

    packed_K = K // 2
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        mask_k = rk < K

        a_ptrs = a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak
        a_block = tl.load(a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0)

        # Load B directly in (K, N) orientation so no transpose is needed.
        # Column pair (2p, 2p + 1) shares byte p, so each byte is addressed twice.
        packed_k = rk // 2
        b_ptrs = b_ptr + packed_k[:, None] * stride_bk + rn[None, :] * stride_bn
        b_mask = (packed_k < packed_K)[:, None] & mask_n[None, :]
        b_packed = tl.load(b_ptrs, mask=b_mask, other=0)

        b_hi = ((b_packed >> 4) & 0x0F).to(tl.int8) - 8
        b_lo = (b_packed & 0x0F).to(tl.int8) - 8
        b_int = tl.where((rk % 2 == 0)[:, None], b_hi, b_lo)
        # Zero the padded K positions explicitly.
        b_int = tl.where(mask_k[:, None], b_int, 0)

        # Integer codes in [-8, 7] are exact in float16, so an fp16 dot with
        # fp32 accumulation loses nothing while using tensor cores.
        acc += tl.dot(a_block.to(tl.float16), b_int.to(tl.float16))

    # The scale depends only on n, so it factors out of the K reduction and is
    # applied once per tile instead of once per K step.
    scales = tl.load(absmaxs_ptr + rn, mask=mask_n, other=1.0).to(tl.float32) / 8.0
    acc = acc * scales[None, :]

    c_ptrs = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16), mask=mask_m[:, None] & mask_n[None, :])

def timing_wrapper(func, input_data):
    """Call ``func(input_data)`` and measure it with CUDA events.

    Returns:
        (result, elapsed_ms): the call's return value and the milliseconds
        between the two recorded CUDA events (after a device synchronize).
    """
    begin, finish = (torch.cuda.Event(enable_timing=True) for _ in range(2))

    begin.record()
    output = func(input_data)
    finish.record()

    # elapsed_time is only valid once both events have completed.
    torch.cuda.synchronize()
    return output, begin.elapsed_time(finish)

def load_testset(seed=42, seqlen=2048, tokenizer=None):
    """Tokenize the whole WikiText-2 (raw) test split as a single sequence.

    Args:
        seed: seeds Python's ``random`` and NumPy's global RNG.
        seqlen: accepted for call compatibility; not used by this function.
        tokenizer: tokenizer to apply; ``facebook/opt-125m``'s is loaded when None.

    Returns:
        The tokenizer output for the joined text (``return_tensors='pt'``).
    """
    test_split = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    tok = AutoTokenizer.from_pretrained("facebook/opt-125m") if tokenizer is None else tokenizer

    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    # Documents are joined with blank lines before tokenizing in one call.
    corpus = "\n\n".join(test_split["text"])
    return tok(corpus, return_tensors='pt')


def perplexity_evaluator(model, testenc, bs=1, device=None):
    """Compute perplexity over non-overlapping 2048-token windows of ``testenc``.

    Side effect: sets ``model.seqlen = 2048``. Trailing tokens that do not fill
    a whole window are ignored. Each batch's mean next-token loss (over
    seqlen - 1 predictions per window) is weighted by ``seqlen * windows`` and
    the total is divided by ``nsamples * seqlen`` before exponentiating.

    Returns:
        The perplexity as a Python float, or ``float('inf')`` when there are no
        tokens or no full window.
    """
    seq_len = 2048
    model.seqlen = seq_len
    token_ids = testenc.input_ids.to(device)

    if token_ids.numel() == 0:
        print("Warning: Empty testenc tensor")
        return float('inf')

    window_count = token_ids.numel() // seq_len
    weighted_nlls = []
    print(f"nsamples {window_count}")

    for first in range(0, window_count, bs):
        last = min(first + bs, window_count)
        n_windows = last - first
        batch = token_ids[:, first * seq_len:last * seq_len].to(device)
        batch = batch.reshape(n_windows, seq_len)

        with torch.no_grad():
            logits = model(batch).logits

        # Position t predicts token t + 1.
        predictions = logits[:, :-1, :].contiguous()
        targets = batch[:, 1:]

        if targets.numel() == 0:
            continue

        criterion = torch.nn.CrossEntropyLoss()
        mean_loss = criterion(
            predictions.reshape(-1, predictions.size(-1)),
            targets.reshape(-1),
        )
        weighted_nlls.append(mean_loss.float() * seq_len * n_windows)

    if not weighted_nlls:
        print("Warning: No valid batches processed")
        return float('inf')

    ppl = torch.exp(torch.stack(weighted_nlls).sum() / (window_count * seq_len))

    torch.cuda.empty_cache()

    return ppl.item()

def evaluate_model(model, tokenizer, device=torch.device("cuda:0")):
    """Return ``model``'s WikiText-2 test perplexity, evaluated with batch size 1."""
    return perplexity_evaluator(
        model,
        load_testset(tokenizer=tokenizer),
        bs=1,
        device=device,
    )

class QuantizedLinearModule(torch.nn.Module):
    """Linear layer whose weights are stored as row-wise packed int4.

    ``weight`` is a frozen uint8 Parameter of shape (output_dim, input_dim // 2)
    with two 4-bit codes per byte; ``weight_scale`` is a float16 buffer with one
    abs-max per output row. The forward pass calls ``matmul_fp16_int4``.
    """

    def __init__(self, input_dim, output_dim, has_bias=True):
        super().__init__()
        self.input_features = input_dim
        self.output_features = output_dim

        # Packed int4 weights: two codes per byte along the input dimension.
        self.weight = torch.nn.Parameter(
            torch.empty(output_dim, input_dim // 2, dtype=torch.uint8),
            requires_grad=False,
        )

        if has_bias:
            self.bias = torch.nn.Parameter(
                torch.empty(self.output_features, dtype=torch.float16),
                requires_grad=False,
            )
        else:
            self.register_parameter("bias", None)

        # float16 because matmul_fp16_int4 asserts that dtype; a float32 buffer
        # would also keep float32 after load_state_dict and fail that check.
        self.register_buffer("weight_scale", torch.ones(output_dim, dtype=torch.float16))

    def forward(self, input_tensor_3d):
        """Apply the layer to ``(..., input_dim)`` input; returns ``(..., output_dim)``.

        The matmul runs in float16; the result is cast back to the input dtype.
        """
        input_dtype = input_tensor_3d.dtype
        # reshape (not view) so non-contiguous inputs are accepted as well.
        reshaped_input = input_tensor_3d.reshape(-1, input_tensor_3d.size(-1))
        reshaped_input = reshaped_input.to(torch.float16)

        intermediate_result = matmul_fp16_int4(
            reshaped_input,
            self.weight,
            self.weight_scale,
        ).view(*input_tensor_3d.size()[:-1], -1)

        if self.bias is not None:
            intermediate_result = intermediate_result + self.bias
        return intermediate_result.to(input_dtype)

    @classmethod
    def convert_from_linear(cls, original_linear):
        """Build a quantized copy of ``original_linear`` (in_features must be even)."""
        assert original_linear.in_features % 2 == 0, "Input features must be divisible by 2 for int4 packing"

        quantized_instance = cls(
            original_linear.in_features,
            original_linear.out_features,
            original_linear.bias is not None,
        )

        if original_linear.bias is not None:
            # Assigning a plain tensor to a registered Parameter raises TypeError,
            # so the copied bias has to be wrapped in a Parameter.
            quantized_instance.bias = torch.nn.Parameter(
                original_linear.bias.detach().clone().to(torch.float16),
                requires_grad=False,
            )

        # The quantizer only reads its input, so no defensive clone is needed.
        original_weight = original_linear.weight.detach().to(torch.float16).contiguous()

        weight_quantized, weight_scale = quantize_rowwise_int4(original_weight)

        quantized_instance.weight_scale = weight_scale.contiguous()
        quantized_instance.weight.data = weight_quantized.contiguous()

        return quantized_instance

    def __repr__(self):
        return f'Quantized_Linear({self.input_features}, {self.output_features}, bias={self.bias is not None})'

def module_replacer(root_module):
    """Swap every ``torch.nn.Linear`` under ``root_module`` for a QuantizedLinearModule.

    The module tree is snapshotted up front, so replacing children does not
    disturb the iteration. Each replacement is reported with ``print``.
    """
    modules_by_name = dict(root_module.named_modules())

    for qualified_name, module in modules_by_name.items():
        if not isinstance(module, torch.nn.Linear):
            continue

        # "a.b.c" -> parent "a.b", attribute "c"; top-level "c" -> parent "".
        parent_name, _, child_name = qualified_name.rpartition(".")
        parent = modules_by_name[parent_name]

        replacement = QuantizedLinearModule.convert_from_linear(module)
        setattr(parent, child_name, replacement)

        print(f"replace layer {qualified_name} with {replacement}")

def _timed_perplexity(model, testenc, device):
    """Run one perplexity evaluation and return ``(ppl, elapsed_seconds)``.

    Only the evaluation sits between the CUDA events, so dataset download and
    tokenization are not part of the measured time.
    """
    start_timer = torch.cuda.Event(enable_timing=True)
    end_timer = torch.cuda.Event(enable_timing=True)

    start_timer.record()
    ppl = perplexity_evaluator(model, testenc, bs=1, device=device)
    end_timer.record()
    torch.cuda.synchronize()

    return ppl, start_timer.elapsed_time(end_timer) / 1000


def benchmark(model_name="unsloth/Llama-3.2-1B-Instruct"):
    """Compare WikiText-2 perplexity and eval time before/after int4 quantization.

    Loads ``model_name`` in float16 on cuda:0 and evaluates it. It then replaces
    every ``nn.Linear`` inside ``model.model`` with ``QuantizedLinearModule`` and
    evaluates again on the same tokenized test set.
    """
    # NOTE(review): trust_remote_code=True lets the Hub repository execute
    # arbitrary Python; Llama is supported natively, so consider removing it.
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map='cuda:0',
    )

    tokenizer_instance = AutoTokenizer.from_pretrained(model_name)

    # 0 is a valid token id; only fall back to EOS when no pad token exists.
    if tokenizer_instance.pad_token_id is None:
        tokenizer_instance.pad_token = tokenizer_instance.eos_token

    target_device = torch.device("cuda:0")
    print(f"Using device: {target_device}")
    print(f"GPU: {torch.cuda.get_device_name(target_device)}")

    # Load and tokenize once; both passes evaluate the same token stream.
    testenc = load_testset(tokenizer=tokenizer_instance)

    perplexity_score, seconds = _timed_perplexity(loaded_model, testenc, target_device)
    print("ppl:", perplexity_score)
    print(f"time (s){seconds: .2f}")

    module_replacer(loaded_model.model)
    gc.collect()
    torch.cuda.empty_cache()

    perplexity_score_quantized, seconds_quantized = _timed_perplexity(
        loaded_model, testenc, target_device
    )
    print("ppl:", perplexity_score_quantized)
    print(f"time (s){seconds_quantized: .2f}")

In [2]:
# Entry point: evaluate the FP16 model, swap its nn.Linear layers for int4
# QuantizedLinearModule, then evaluate again and print both results.
benchmark()

config.json:   0%|          | 0.00/894 [00:00<?, ?B/s]

2025-11-17 21:26:52.707078: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763414813.135884      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763414813.247823      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

Using device: cuda:0
GPU: Tesla T4


README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (289077 > 131072). Running this sequence through the model will result in indexing errors


nsamples 141
ppl: 13.159286499023438
time (s) 48.89
replace layer layers.0.self_attn.q_proj with Quantized_Linear(2048, 2048, bias=False)
replace layer layers.0.self_attn.k_proj with Quantized_Linear(2048, 512, bias=False)
replace layer layers.0.self_attn.v_proj with Quantized_Linear(2048, 512, bias=False)
replace layer layers.0.self_attn.o_proj with Quantized_Linear(2048, 2048, bias=False)
replace layer layers.0.mlp.gate_proj with Quantized_Linear(2048, 8192, bias=False)
replace layer layers.0.mlp.up_proj with Quantized_Linear(2048, 8192, bias=False)
replace layer layers.0.mlp.down_proj with Quantized_Linear(8192, 2048, bias=False)
replace layer layers.1.self_attn.q_proj with Quantized_Linear(2048, 2048, bias=False)
replace layer layers.1.self_attn.k_proj with Quantized_Linear(2048, 512, bias=False)
replace layer layers.1.self_attn.v_proj with Quantized_Linear(2048, 512, bias=False)
replace layer layers.1.self_attn.o_proj with Quantized_Linear(2048, 2048, bias=False)
replace layer lay