In [None]:
# Show which GPU the runtime exposes, together with its driver/CUDA versions
# and current memory use, before any kernels are compiled.
!nvidia-smi


Fri Apr 11 13:22:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   51C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
!pip install -q triton


In [None]:
import torch
import triton

# Report the software stack so the results can be tied to concrete versions.
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("Triton version:", triton.__version__)
# torch.cuda.get_device_name raises when no CUDA device is usable, so only
# query it when one is present instead of crashing the environment report.
print(
    "CUDA device:",
    torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none",
)


Torch version: 2.6.0+cu124
CUDA available: True
Triton version: 3.2.0
CUDA device: Tesla T4


In [None]:
import torch

def simulate_nf4_input(batch_size=2, num_tokens=512, dim=4096, dtype=torch.float16, device="cuda"):
    """Create a random square weight matrix and a packed 4-bit quantization of it.

    The quantizer is a uniform 16-level grid over the row-normalised weights
    (``round((w / absmax + 1) * 7.5)``); it only produces data with the same
    *layout* as packed NF4 and does not use the NF4 codebook.

    Args:
        batch_size: Unused; kept so existing calls keep working.
        num_tokens: Unused; kept so existing calls keep working.
        dim: Side length of the square weight matrix. Must be a positive even
            number because two 4-bit codes are packed into every byte.
        dtype: Floating-point dtype of the random weights and of the scale.
        device: Device on which all tensors are created.

    Returns:
        ``(packed, scale)`` where ``packed`` is a contiguous ``uint8`` tensor of
        shape ``(dim, dim // 2)`` (even column in the low nibble, odd column in
        the high nibble) and ``scale`` is the per-row absolute maximum with
        shape ``(dim,)``.

    Raises:
        ValueError: If ``dim`` is not a positive even integer.
    """
    if dim <= 0 or dim % 2 != 0:
        raise ValueError(f"dim must be a positive even integer, got {dim}")

    # Step 1: Create dummy float weights
    weights = torch.randn(dim, dim, device=device, dtype=dtype)

    # Step 2: Normalise each row to [-1, 1] and map it onto the codes 0..15.
    # The small epsilon avoids a division by zero for an all-zero row.
    absmax = weights.abs().amax(dim=-1, keepdim=True)
    normed = weights / (absmax + 1e-5)
    quantized = torch.clamp(((normed + 1) * 7.5).round(), 0, 15).to(torch.uint8)

    # Step 3: Pack two 4-bit values into 1 byte
    low = quantized[..., ::2]    # even columns -> low nibble
    high = quantized[..., 1::2]  # odd columns  -> high nibble
    packed = ((high << 4) | low).contiguous()

    # Step 4: One scale per row
    scale = absmax.squeeze(-1)

    # Step 5: Report shapes so the layout can be checked by eye
    print("NF4 packed shape:", packed.shape, "| dtype:", packed.dtype)
    print("Absmax shape:", scale.shape, "| dtype:", scale.dtype)

    return packed, scale

# Test simulate: run once with the default arguments and keep the packed bytes
# and the per-row scale. The function prints both shapes as a sanity check.
nf4_weights, absmax = simulate_nf4_input()


NF4 packed shape: torch.Size([4096, 2048]) | dtype: torch.uint8
Absmax shape: torch.Size([4096]) | dtype: torch.float16


In [None]:
# NF4 codebook: the 16 levels of the 4-bit NormalFloat data type, indexed by
# the 4-bit code. The table is asymmetric (seven negative levels, an exact
# zero at code 7, eight positive levels), so it must not be built by mirroring
# the negative half.
NF4_LUT = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
], device="cuda", dtype=torch.float16)  # use bfloat16 later if needed


In [None]:
import triton
import triton.language as tl

@triton.jit
def nf4_dequant_kernel(
    packed_ptr,       # uint8 [M, N//2]
    absmax_ptr,       # float16 [M]
    lut_ptr,          # float16 [16]
    out_ptr,          # output [M, N] float16 or bfloat16
    M, N,
    BLOCK_N: tl.constexpr,
    DTYPE: tl.constexpr,
):
    """Dequantize one tile of a row of packed 4-bit codes.

    Program axis 0 selects the row, program axis 1 selects a tile of
    ``BLOCK_N`` packed bytes inside that row. Every byte yields two outputs:
    the low nibble goes to column ``2 * c`` and the high nibble to column
    ``2 * c + 1``. Each code is looked up in ``lut_ptr`` and scaled by the
    row's ``absmax``.

    ``DTYPE`` is accepted only so existing launches keep working; the output
    precision is taken from ``out_ptr`` itself.
    """
    row_idx = tl.program_id(0)
    # Bounds check: we skip rows out of range
    if row_idx >= M:
        return

    half_cols = N // 2
    # Row offsets are computed in 64 bits so large matrices cannot overflow
    # 32-bit pointer arithmetic.
    row = row_idx.to(tl.int64)

    # ---- Load absmax (the multiply below is done in float32) ----
    absmax = tl.load(absmax_ptr + row_idx).to(tl.float32)

    # ---- Load packed 4-bit weights ----
    cols = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    valid = cols < half_cols

    # Each element holds 2 values -> [M, N//2]. Masked-off lanes read 0, which
    # is still a valid LUT index.
    packed_vals = tl.load(packed_ptr + row * half_cols + cols, mask=valid, other=0)

    low_bits = (packed_vals & 0x0F).to(tl.int32)          # even columns
    high_bits = ((packed_vals >> 4) & 0x0F).to(tl.int32)  # odd columns

    # Lookup float value from LUT and scale by absmax
    low_fp = tl.load(lut_ptr + low_bits).to(tl.float32) * absmax
    high_fp = tl.load(lut_ptr + high_bits).to(tl.float32) * absmax

    # Store into output: each byte becomes two values. Cast straight to the
    # element type of the output buffer so a bfloat16 output is not rounded
    # through float16 first.
    base_out = out_ptr + row * N + 2 * cols
    tl.store(base_out + 0, low_fp.to(out_ptr.dtype.element_ty), mask=valid)
    tl.store(base_out + 1, high_fp.to(out_ptr.dtype.element_ty), mask=valid)


In [None]:
def dequant_nf4_triton(packed: torch.Tensor, absmax: torch.Tensor, dtype=torch.float16):
    """Dequantize packed 4-bit codes with ``nf4_dequant_kernel``.

    Args:
        packed: ``uint8`` tensor of shape ``(M, N // 2)``; each byte stores the
            code of an even column in its low nibble and of the following odd
            column in its high nibble.
        absmax: ``float16``/``bfloat16`` tensor with one scale per row.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(M, N)`` on ``packed``'s device holding
        ``NF4_LUT[code] * absmax[row]``.

    Raises:
        AssertionError: If the dtypes, row counts or devices of the inputs do
            not match the expectations above.
    """
    assert packed.dtype == torch.uint8
    assert absmax.dtype in [torch.float16, torch.bfloat16]
    assert packed.shape[0] == absmax.shape[0]
    assert packed.device == absmax.device
    M, N_half = packed.shape
    N = N_half * 2

    out = torch.empty((M, N), device=packed.device, dtype=dtype)
    if out.numel() == 0:
        # Nothing to compute; avoid launching an empty grid.
        return out

    # The kernel addresses both inputs with dense row-major arithmetic, so a
    # strided view (e.g. a slice or transpose) has to be materialised first.
    packed = packed.contiguous()
    absmax = absmax.contiguous()

    # Launch Triton kernel: one program per (row, tile of BLOCK_N packed bytes)
    BLOCK_N = 256  # Can tune this later
    grid = (M, triton.cdiv(N_half, BLOCK_N))

    # Cast LUT dtype if needed and make sure it lives next to the data
    lut = NF4_LUT.to(device=packed.device, dtype=dtype)

    nf4_dequant_kernel[grid](
        packed_ptr = packed,
        absmax_ptr = absmax,
        lut_ptr = lut,
        out_ptr = out,
        M = M,
        N = N,
        BLOCK_N = BLOCK_N,
        # constexpr flag required by the kernel signature (0 = float16, 1 = other)
        DTYPE = 0 if dtype == torch.float16 else 1,
    )

    # `out` was allocated with `dtype`, so no further conversion is needed.
    return out


In [None]:
def dequant_nf4_reference(packed: torch.Tensor, absmax: torch.Tensor, dtype=torch.float16, lut=None):
    """Pure-PyTorch reference for the packed 4-bit dequantization.

    The whole matrix is processed with tensor operations instead of a Python
    loop with a device-to-host ``.item()`` sync per row.

    Args:
        packed: ``uint8`` tensor of shape ``(M, N // 2)``; low nibble = even
            column, high nibble = odd column.
        absmax: One scale per row (``M`` elements).
        dtype: dtype of the returned tensor; the lookup table is rounded to
            this dtype before use.
        lut: Optional 16-entry lookup table. Defaults to the module-level
            ``NF4_LUT``.

    Returns:
        Tensor of shape ``(M, N)`` with ``lut[code] * absmax[row]``.
    """
    if lut is None:
        lut = NF4_LUT
    lut = lut.to(device=packed.device, dtype=dtype)

    M, N_half = packed.shape
    N = N_half * 2

    low = (packed & 0x0F).long()
    high = ((packed >> 4) & 0x0F).long()

    # Multiply in float32 and round once to `dtype`; the scale is broadcast
    # over the columns of its row.
    scale = absmax.to(device=packed.device, dtype=torch.float32).reshape(M, 1)

    out = torch.empty((M, N), device=packed.device, dtype=dtype)
    # Interleave low and high into output
    out[:, 0::2] = (lut[low].to(torch.float32) * scale).to(dtype)
    out[:, 1::2] = (lut[high].to(torch.float32) * scale).to(dtype)
    return out


In [None]:
# Seed the generators so the comparison in the following cell is reproducible.
torch.manual_seed(0)

# Simulate NF4 packed data
M, N = 4096, 2048 * 2  # 4096 x 2048 packed, means original is 4096 x 4096
packed = torch.randint(0, 256, (M, N // 2), dtype=torch.uint8, device="cuda")

# Simulate corresponding absmax values (this rebinds `absmax` from the earlier cell)
absmax = torch.rand(M, dtype=torch.float16, device="cuda") * 2.0  # Keep scale reasonable


In [None]:
out_triton = dequant_nf4_triton(packed, absmax, dtype=torch.float16)
out_ref = dequant_nf4_reference(packed, absmax, dtype=torch.float16)

# Both implementations must agree on the output layout before values are compared.
assert out_triton.shape == out_ref.shape

# Compute the element-wise error once and reuse it for both statistics.
abs_diff = (out_triton - out_ref).abs()
print("Max absolute difference:", abs_diff.max())
print("Mean absolute difference:", abs_diff.mean())

# Fail loudly instead of relying on someone reading the printed numbers.
assert abs_diff.max().item() <= 1e-3, "Triton kernel disagrees with the PyTorch reference"


Max absolute difference: tensor(0., device='cuda:0', dtype=torch.float16)
Mean absolute difference: tensor(0., device='cuda:0', dtype=torch.float16)


In [None]:
# Step 1 - Define NF4 lookup table (LUT)
# These are the 16 levels of the 4-bit NormalFloat data type, indexed by the
# 4-bit code. The table is asymmetric: seven negative levels, an exact zero at
# code 7 and eight positive levels.
NF4_LUT = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
], dtype=torch.float16, device="cuda")


In [None]:
# Show the table as it is stored in float16 and check that it has exactly one
# entry per 4-bit code (2**4 = 16).
print(NF4_LUT)
assert NF4_LUT.shape == (16,)


tensor([-1.0000, -0.6958, -0.4780, -0.3350, -0.2390, -0.1680, -0.1120, -0.0660,
         0.0660,  0.1120,  0.1680,  0.2390,  0.3350,  0.4780,  0.6958,  1.0000],
       device='cuda:0', dtype=torch.float16)


In [None]:
# NF4 codebook used by the kernel defined next: the 16 levels of the 4-bit
# NormalFloat data type, indexed by the 4-bit code (asymmetric, with an exact
# zero at code 7).
NF4_LUT = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
], device="cuda", dtype=torch.float16)


In [None]:
import triton
import triton.language as tl

@triton.jit
def dequant_nf4_kernel(
    x_q_ptr, x_scale_ptr, x_out_ptr,
    M, N,
    lut_ptr,
    BLOCK: tl.constexpr
):
    """Dequantize one full row of packed 4-bit codes per program.

    ``x_q_ptr`` is read as a dense ``[M, N]`` byte matrix and ``x_out_ptr`` is
    written as a dense ``[M, 2 * N]`` matrix: byte ``c`` of a row produces
    output columns ``2 * c`` (low nibble) and ``2 * c + 1`` (high nibble).
    Each code is looked up in ``lut_ptr`` and multiplied by the row's entry in
    ``x_scale_ptr``. A single tile of ``BLOCK`` lanes is processed per row, so
    ``BLOCK`` has to be at least ``N`` for the row to be covered.
    """
    row = tl.program_id(0)
    if row >= M:
        return

    offs = tl.arange(0, BLOCK)
    mask = offs < N
    # 64-bit row offsets keep the pointer arithmetic from overflowing on large
    # matrices.
    row64 = row.to(tl.int64)

    # Masked-off lanes read 0 so they still form a valid LUT index.
    x_q = tl.load(x_q_ptr + row64 * N + offs, mask=mask, other=0).to(tl.uint8)

    # Unpack 2 NF4 values per byte
    low  = (x_q & 0x0F).to(tl.int32)
    high = ((x_q >> 4) & 0x0F).to(tl.int32)

    # LUT dequant and scaling, done in float32 and rounded once on store
    scale = tl.load(x_scale_ptr + row).to(tl.float32)
    deq_low  = tl.load(lut_ptr + low).to(tl.float32) * scale
    deq_high = tl.load(lut_ptr + high).to(tl.float32) * scale

    # Compute final output indices
    out_ptr = x_out_ptr + row64 * N * 2
    out_offs = offs * 2

    # Store interleaved dequantized values in the output buffer's own dtype
    tl.store(out_ptr + out_offs + 0, deq_low.to(x_out_ptr.dtype.element_ty), mask=mask)
    tl.store(out_ptr + out_offs + 1, deq_high.to(x_out_ptr.dtype.element_ty), mask=mask)


In [None]:
def dequant_nf4_triton(x_q, x_scale, dtype=torch.float16):
    """Dequantize packed 4-bit codes with ``dequant_nf4_kernel``.

    Note that this definition replaces any earlier ``dequant_nf4_triton`` in
    the notebook namespace.

    Args:
        x_q: ``uint8`` tensor of shape ``(M, N)``; every byte holds two codes
            (low nibble first).
        x_scale: Tensor of shape ``(M,)`` with one scale per row, on the same
            device as ``x_q``.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(M, 2 * N)`` with ``NF4_LUT[code] * x_scale[row]``.

    Raises:
        AssertionError: If the dtype, shape or device checks fail.
    """
    M, N = x_q.shape
    assert x_q.dtype == torch.uint8
    assert x_scale.shape == (M,)
    assert x_q.device == x_scale.device

    x_out = torch.empty((M, N * 2), device=x_q.device, dtype=dtype)
    if x_out.numel() == 0:
        # Nothing to compute; avoid launching an empty grid.
        return x_out

    # The kernel indexes with dense row-major arithmetic.
    x_q = x_q.contiguous()
    x_scale = x_scale.contiguous()

    # One tile spans the full row. tl.arange requires a power-of-two extent,
    # so round N up; the kernel masks the lanes beyond N.
    BLOCK = triton.next_power_of_2(N)

    dequant_nf4_kernel[(M,)](
        x_q, x_scale, x_out,
        M, N,
        NF4_LUT.to(device=x_q.device, dtype=torch.float16),
        BLOCK=BLOCK,
        num_warps=4,
        num_stages=1
    )

    return x_out


In [None]:
# Tiny smoke test: 4 rows of 8 packed bytes, i.e. a 4 x 16 dequantized matrix.
# Note that this rebinds the notebook-level M and N used by earlier cells.
M, N = 4, 8
x_q = torch.randint(0, 256, (M, N), dtype=torch.uint8, device="cuda")
# The per-row scales are float32 here while the requested output is float16.
x_scale = torch.rand(M, device="cuda", dtype=torch.float32)

out_fp16 = dequant_nf4_triton(x_q, x_scale, dtype=torch.float16)
print(out_fp16)


tensor([[ 0.0411,  0.1196,  0.0113,  0.0822, -0.1196, -0.0289, -0.0576,  0.0822,
          0.1196, -0.0576, -0.0411,  0.1719, -0.0113, -0.0576,  0.0411, -0.0192],
        [ 0.0077,  0.0232,  0.0077,  0.0331, -0.0116,  0.0232,  0.0692, -0.0165,
          0.0077, -0.0165, -0.0232,  0.0481,  0.0046, -0.0331, -0.0046,  0.0165],
        [-0.3552, -0.0490, -0.1776, -0.1776, -0.1776,  0.1776, -0.7432, -0.2489,
          0.2489, -0.1248,  0.1248,  0.0490, -0.5171,  0.3552, -0.3552,  0.2489],
        [-0.0343, -0.0717, -0.0047,  0.0120, -0.0343, -0.0080,  0.0120, -0.0047,
          0.0717, -0.0047, -0.0172,  0.0172, -0.0080,  0.0343, -0.0080, -0.0047]],
       device='cuda:0', dtype=torch.float16)


In [None]:
# !pip install -U triton

import torch
import triton
import triton.language as tl
import time

# 1. Kernel
@triton.jit
def dequantize_nf4_kernel(
    weight_ptr,     # [rows, cols // 2]
    absmax_ptr,     # [rows]
    lut_ptr,        # [16]
    output_ptr,     # [rows, cols]
    rows, cols,
    BLOCK_COLS: tl.constexpr,
):
    """Dequantize one row of packed 4-bit codes per program.

    The program walks across its row in tiles of ``BLOCK_COLS`` output
    columns, so the whole row is written even when it is wider than one tile.
    Output column ``c`` reads byte ``c // 2`` of the row and takes the low
    nibble for even ``c`` and the high nibble for odd ``c``; the code is looked
    up in ``lut_ptr`` and multiplied by the row's ``absmax``.
    """
    row_idx = tl.program_id(0)
    # 64-bit row offsets keep the pointer arithmetic from overflowing on large
    # matrices.
    row = row_idx.to(tl.int64)
    cols_half = cols // 2

    absmax = tl.load(absmax_ptr + row_idx)
    lane = tl.arange(0, BLOCK_COLS)

    for col_start in range(0, cols, BLOCK_COLS):
        col_offsets = col_start + lane
        mask = col_offsets < cols

        # Two neighbouring output columns share one packed byte.
        byte_offsets = col_offsets // 2
        packed_vals = tl.load(
            weight_ptr + row * cols_half + byte_offsets,
            mask=byte_offsets < cols_half,
            other=0,
        )

        is_even = (col_offsets % 2) == 0
        nf4_indices = tl.where(is_even, packed_vals & 0x0F, (packed_vals >> 4) & 0x0F)
        dequant_vals = tl.load(lut_ptr + nf4_indices.to(tl.int32))

        output_vals = dequant_vals * absmax.to(dequant_vals.dtype)
        tl.store(output_ptr + row * cols + col_offsets, output_vals, mask=mask)

# 2. Python wrapper
def get_nf4_lookup_table(dtype=torch.float16):
    """Return the 16-entry NF4 codebook as a CPU tensor of the given dtype.

    Entry ``i`` is the value represented by the 4-bit code ``i``. The codebook
    is asymmetric: codes 0-6 are negative, code 7 is exactly zero and codes
    8-15 are positive, ending at 1.0.

    Args:
        dtype: dtype of the returned tensor; the values are defined in float32
            and rounded to ``dtype``.

    Returns:
        1-D tensor with 16 strictly increasing values in ``[-1, 1]``.
    """
    base = torch.tensor([
        -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
        -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
        0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
        0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
    ], dtype=torch.float32)
    return base.to(dtype)

def dequantize_nf4(qweight, absmax, dtype=torch.float16):
    """Dequantize a packed 4-bit weight matrix with ``dequantize_nf4_kernel``.

    Args:
        qweight: ``uint8`` tensor of shape ``(rows, cols // 2)``; each byte
            holds the codes of two neighbouring columns (low nibble first).
        absmax: Tensor with one scale per row.
        dtype: dtype of the returned tensor and of the lookup table.

    Returns:
        Tensor of shape ``(rows, cols)`` on ``qweight``'s device.

    Raises:
        AssertionError: If ``absmax`` does not provide exactly one scale per
            row or lives on a different device than ``qweight``.
    """
    rows, cols_half = qweight.shape
    cols = cols_half * 2
    assert absmax.numel() == rows
    assert absmax.device == qweight.device

    output = torch.empty((rows, cols), dtype=dtype, device=qweight.device)
    if output.numel() == 0:
        # Nothing to compute; avoid launching an empty grid.
        return output

    # The kernel indexes with dense row-major arithmetic.
    qweight = qweight.contiguous()
    absmax = absmax.contiguous()

    # One program is launched per row. Sizing the tile to span the entire row
    # guarantees that this single program writes every column; tl.arange needs
    # a power-of-two width, and the kernel masks the lanes beyond `cols`.
    BLOCK_COLS = triton.next_power_of_2(cols)

    lut = get_nf4_lookup_table(dtype).to(qweight.device)

    dequantize_nf4_kernel[(rows,)](
        qweight, absmax, lut, output,
        rows, cols,
        BLOCK_COLS=BLOCK_COLS
    )
    return output

# 3. Test function
def test_dequantize(dequantize_fx):
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.float16),
        (3, 2048, 4096, 14336, 3408, torch.float16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        torch.manual_seed(seed)
        torch.set_default_dtype(torch.float32)

        # Simulate 3 separate quantized weight matrices like Unsloth (up, gate, down)
        for _ in range(3):
            rows, cols = m, hd
            cols_half = cols // 2
            qweight = torch.randint(0, 256, (rows, cols_half), dtype=torch.uint8, device="cuda")
            absmax = torch.rand((rows,), dtype=dt, device="cuda")

            # Warmup
            for _ in range(2):
                _ = dequantize_fx(qweight, absmax, dtype=dt)
            torch.cuda.synchronize()

            # Benchmark
            start = time.time()
            for _ in range(1000):
                _ = dequantize_fx(qweight, absmax, dtype=dt)
            torch.cuda.synchronize()
            elapsed += time.time() - start

    return elapsed

# 4. Run
# Time `dequantize_nf4` over the benchmark's built-in shapes and report the
# accumulated wall-clock time of the timed loops.
print("Running benchmark...")
total_time = test_dequantize(dequantize_nf4)
print(f"Total time: {total_time:.2f} seconds")


Running benchmark...
Total time: 0.96 seconds
