# Challenge A: NF4 to Triton Benchmark

This notebook checks that a custom Triton kernel for dequantizing bitsandbytes NF4 weights produces the same output as Unsloth's `fast_dequantize`, and measures whether it is at least 1.15x faster.

In [None]:
# Install/upgrade the kernel compiler, quantization library, and reference implementation used below.
# NOTE(review): `-U` may upgrade triton past the version pinned by the installed torch build —
# confirm the two stay compatible.
!pip install triton bitsandbytes unsloth -U --quiet

In [None]:
import torch
import triton
import triton.language as tl
from bitsandbytes.functional import dequantize_nf4
import time

# NF4 Triton Kernel Implementation
@triton.jit
def _your_dequantize_nf4_kernel(
    weight_ptr,
    absmax_ptr,
    code_ptr,
    absmax2_ptr,
    code2_ptr,
    offset_ptr,
    out_ptr,
    n_elements,
    QUANT_BLOCK: tl.constexpr,
    QUANT_BLOCK2: tl.constexpr,
    NESTED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Dequantize packed NF4 codes into `out_ptr`.
    #
    # Each program produces BLOCK_SIZE output elements from BLOCK_SIZE // 2 packed
    # bytes. Within a byte, the HIGH nibble is the even output element and the LOW
    # nibble the following odd one (bitsandbytes packing order).
    #
    # Per-element scale for quantization block b = element_index // QUANT_BLOCK:
    #   NESTED=False: absmax[b] (already a float)
    #   NESTED=True : code2[absmax[b]] * absmax2[b // QUANT_BLOCK2] + offset[0]
    #                 (absmax holds 8-bit codes; double-quantized statistics)
    # absmax2_ptr / code2_ptr / offset_ptr are only dereferenced when NESTED is True.
    pid = tl.program_id(0)
    byte_offsets = pid * (BLOCK_SIZE // 2) + tl.arange(0, BLOCK_SIZE // 2)
    out_even = byte_offsets * 2
    # A byte exists iff its even element exists; the odd element may be absent
    # when the total element count is odd.
    mask_even = out_even < n_elements
    mask_odd = (out_even + 1) < n_elements

    packed = tl.load(weight_ptr + byte_offsets, mask=mask_even, other=0).to(tl.int32)
    # `& 0xF` keeps both indices in [0, 16) regardless of storage signedness,
    # so the 16-entry table lookups below are always in bounds.
    code_even = (packed >> 4) & 0xF
    code_odd = packed & 0xF

    # QUANT_BLOCK is even, so both elements of one byte share a quantization block.
    block_idx = out_even // QUANT_BLOCK
    if NESTED:
        absmax_code = tl.load(absmax_ptr + block_idx, mask=mask_even, other=0).to(tl.int32)
        scale = tl.load(code2_ptr + absmax_code) * tl.load(
            absmax2_ptr + block_idx // QUANT_BLOCK2, mask=mask_even, other=0.0
        )
        scale = scale + tl.load(offset_ptr)
    else:
        scale = tl.load(absmax_ptr + block_idx, mask=mask_even, other=0.0).to(tl.float32)

    val_even = tl.load(code_ptr + code_even) * scale
    val_odd = tl.load(code_ptr + code_odd) * scale

    # tl.store casts the float32 values to out_ptr's element type.
    tl.store(out_ptr + out_even, val_even, mask=mask_even)
    tl.store(out_ptr + out_even + 1, val_odd, mask=mask_odd)

def _your_dequantize_nf4(weight, quant_state):
    """Dequantize a bitsandbytes NF4 weight with the Triton kernel.

    Args:
        weight: packed NF4 codes (two per byte), e.g. ``Params4bit.data``.
        quant_state: the matching bitsandbytes quantization state. The fields
            ``absmax``, ``code``, ``blocksize``, ``shape`` and ``dtype`` are read,
            plus ``state2`` / ``offset`` when the statistics are double-quantized
            (``state2`` is not None).

    Returns:
        A new tensor of shape ``quant_state.shape`` and dtype ``quant_state.dtype``.
    """
    out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=weight.device)
    # Use the true element count (may be odd) rather than 2 * bytes, so the
    # kernel never writes past the end of `out`.
    n_elements = out.numel()
    if n_elements == 0:
        return out  # nothing to dequantize; skip the launch

    state2 = quant_state.state2
    nested = state2 is not None
    if nested:
        absmax2, code2, blocksize2 = state2.absmax, state2.code, state2.blocksize
        # The kernel needs a float32 device pointer for `offset`; as_tensor is a
        # no-op when it already is one. NOTE(review): its exact type can vary with
        # how the state was constructed — confirm across bitsandbytes versions.
        offset = torch.as_tensor(quant_state.offset, dtype=torch.float32, device=weight.device)
    else:
        # Placeholders only: the NESTED=False specialization never reads them.
        absmax2 = code2 = offset = quant_state.absmax
        blocksize2 = 1

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    _your_dequantize_nf4_kernel[grid](
        weight, quant_state.absmax, quant_state.code,
        absmax2, code2, offset,
        out, n_elements,
        QUANT_BLOCK=quant_state.blocksize,
        QUANT_BLOCK2=blocksize2,
        NESTED=nested,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out

def your_dequantize_nf4(weight_param):
    """Dequantize the NF4 weight of a 4-bit linear module.

    Args:
        weight_param: a module whose ``.weight`` holds the packed data and carries
            a ``quant_state`` attribute (e.g. a bitsandbytes ``LinearNF4``).
    """
    return _your_dequantize_nf4(weight_param.weight.data, weight_param.weight.quant_state)

In [None]:
from unsloth.kernels import fast_dequantize
from bitsandbytes.nn import LinearNF4

# Setup
device = "cuda"
shape = (4096, 4096)
# NOTE(review): Unsloth's fast_dequantize appears to dispatch only to fp16/bf16 NF4
# paths, so compare in fp16 — confirm against the installed unsloth version.
dtype = torch.float16
torch.manual_seed(0)  # reproducible weights, so a failure can be re-run exactly
linear = LinearNF4(shape[1], shape[0], bias=False).to(device)
# Both dequantizers size/type their output from quant_state.dtype.
linear.weight.quant_state.dtype = dtype

# Correctness
out_ref = fast_dequantize(linear.weight.data, linear.weight.quant_state)
out_custom = your_dequantize_nf4(linear)
# Check layout first: allclose on mismatched shapes/dtypes errors instead of returning False.
same_layout = out_ref.shape == out_custom.shape and out_ref.dtype == out_custom.dtype
correct = same_layout and torch.allclose(out_ref, out_custom, atol=1e-5)
print(f"Correctness check: {correct}")
if same_layout:
    max_err = (out_ref.float() - out_custom.float()).abs().max().item()
    print(f"Max abs error: {max_err:.3e}")
else:
    print(f"Layout mismatch: ref {tuple(out_ref.shape)} {out_ref.dtype}, "
          f"custom {tuple(out_custom.shape)} {out_custom.dtype}")

# Benchmark
def benchmark(fn, name, iters=100, warmup=10):
    """Time ``fn`` on the GPU, print the mean, and return seconds per call.

    ``warmup`` untimed calls run first, so one-off costs (allocator growth,
    clock ramp-up) are excluded. The timed loop is bracketed by
    ``torch.cuda.synchronize()`` so queued asynchronous CUDA work is counted.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution wall clock
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    per_call = (time.perf_counter() - start) / iters
    print(f"{name}: {per_call * 1000:.4f} ms")
    return per_call

t_ref = benchmark(lambda: fast_dequantize(linear.weight.data, linear.weight.quant_state), "Unsloth fast_dequantize")
t_custom = benchmark(lambda: your_dequantize_nf4(linear), "Your Triton kernel")
speedup = t_ref / t_custom
print(f"Speedup: {speedup:.4f}x")
# A fast but wrong kernel must not be reported as passing.
if not correct:
    print("FAILED: Triton kernel output does not match the reference")
elif speedup >= 1.15:
    print("PASSED: Speedup is >= 1.15x")
else:
    print("FAILED: Speedup is < 1.15x")