**Link to Colab:** [NF4 dequantization with Triton](https://colab.research.google.com/drive/1zp8zvbRl1V3_WKwSlbWnf5I6p30GdFkJ?usp=sharing)

### Installing Required Libraries

In [1]:
# Code to install Unsloth, Triton, Torch etc
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

### Importing Libraries

In [2]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
import unsloth
from transformers import set_seed
import time
import inspect
import os
from bitsandbytes.nn import Linear4bit
from unsloth.kernels.utils import fast_dequantize
from transformers.activations import ACT2FN
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
# Compute capability of the current CUDA device, as (major, minor).
major_version, minor_version = torch.cuda.get_device_capability()
# True when the device's major compute capability is 8 or higher.
HAS_BFLOAT16 = major_version >= 8
from inspect import currentframe as _C, getframeinfo

def _F(c):
    """Return the line number recorded for frame `c` (use as `_F(_C())`)."""
    return getframeinfo(c).lineno

def WARN(x):
    """Print `x` wrapped in ANSI escape codes for red text."""
    print(f"\033[31m{x}\033[0m")

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    """Return the first local name in the *calling* frame bound to `var`, else "".

    Matching is by identity (`is`), so only that exact object is found.
    """
    caller_scope = inspect.currentframe().f_back.f_locals
    return next(
        (label for label, value in caller_scope.items() if value is var),
        "",
    )

def assert_same(x, y, line, dtype):
    """Assert that tensor `x` has dtype `dtype` and is close to reference `y`.

    Args:
        x: tensor under test; its dtype must equal `dtype`.
        y: reference tensor compared via `torch.testing.assert_close`
           (atol=0.01, rtol=0.1, strides checked).
        line: caller line number, included in the failure message.
        dtype: expected dtype of `x`.

    Raises:
        AssertionError: if `x.dtype != dtype`.
        RuntimeError: if the closeness check fails; the original error is
            chained as `__cause__` and the caller's variable names for
            `x` / `y` (if any) are reported.
    """
    assert x.dtype == dtype, f"expected dtype {dtype}, got {x.dtype}"
    try:
        torch.testing.assert_close(x, y, check_stride = True, atol = 0.01, rtol = 0.1)
    except Exception as error:
        # Resolve names in the *caller's* frame. Calling a name helper from here
        # would only ever see this function's own parameters ("x", "y").
        caller_locals = inspect.currentframe().f_back.f_locals
        x_name = next((k for k, v in caller_locals.items() if v is x), "")
        y_name = next((k for k, v in caller_locals.items() if v is y), "")
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {x_name}, {y_name}\n{str(error)}"
        ) from error

# Presumably enables huggingface_hub's hf_transfer download backend.
# NOTE(review): transformers/unsloth are imported earlier in this cell; if the
# hub library reads this variable only at import time, setting it here may have
# no effect. Consider moving it above the imports.
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "1"})

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


### Initial Setup

In [3]:
def unsloth_dequantize(weight):
    """Dequantize a 4-bit layer with Unsloth's `fast_dequantize`.

    `weight` is the layer module (as passed by `mlp_forward`). Its `.weight`
    parameter supplies the packed data and its attached `quant_state`.
    """
    packed = weight.weight
    return fast_dequantize(packed, packed.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    """Create a bitsandbytes `Linear4bit` layer with `hd` inputs and `m` outputs.

    Uses NF4 quantization, `compress_statistics=True`, `bias=None`, and `dtype`
    as the compute dtype.
    """
    layer_kwargs = dict(
        bias                = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )
    return Linear4bit(hd, m, **layer_kwargs)

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    """Assert the bitsandbytes quantization layout the benchmark relies on.

    Checks, in order, that:
      * the packed weight is uint8;
      * its quant state reports `dtype`;
      * absmax is uint8, with float32 code and offset;
      * blocksize is 64;
      * the nested `state2` has float32 absmax/code and blocksize 256.

    Raises AssertionError on the first mismatch.
    """
    packed = weight.weight
    assert packed.dtype == torch.uint8
    state = packed.quant_state
    assert state.dtype == dtype
    assert state.absmax.dtype == torch.uint8
    assert state.code.dtype == torch.float32
    assert state.offset.dtype == torch.float32
    assert state.blocksize == 64
    nested = state.state2
    assert nested.absmax.dtype == torch.float32
    assert nested.code.dtype == torch.float32
    assert nested.blocksize == 256

class MLP(nn.Module):
    """Gated MLP made of three NF4 `Linear4bit` layers moved to CUDA.

    forward(x) = down_proj(silu(gate_proj(x)) * up_proj(x))
    """
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        # Layers are built in the order gate, up, down (kept as in the
        # original, presumably so seeded initialisation is reproducible).
        self.gate_proj = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype = dtype).to("cuda")
        # [NEW] as at 18th Feb 2025: make every quant state report `dtype`.
        for proj in (self.gate_proj, self.up_proj, self.down_proj):
            proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    """Evaluate the MLP by hand using weights dequantized through `fx`.

    Computes (act_fn(X @ Wg^T) * (X @ Wu^T)) @ Wd^T, where each W is
    `fx(<projection layer>)`. The weights are dequantized in the order up, gate, down.
    """
    def project(layer):
        return fx(layer).t()

    up_out   = X @ project(mlp.up_proj)
    gate_out = X @ project(mlp.gate_proj)
    return (mlp.act_fn(gate_out) * up_out) @ project(mlp.down_proj)

def mlp_dequantize(X, mlp, fx):
    """Dequantize up, gate and down projections with `fx`; return their transposes.

    `X` is not used by the body. A `torch.cuda.synchronize()` follows each `fx`
    call, presumably so the timing loop in `test_dequantize` measures finished
    work. Returns the tuple (up^T, gate^T, down^T).
    """
    transposed = []
    for layer in (mlp.up_proj, mlp.gate_proj, mlp.down_proj):
        transposed.append(fx(layer).t())
        torch.cuda.synchronize()
    return tuple(transposed)

def test_dequantize(dequantize_fx, n_iters = 1000):
    """Validate `dequantize_fx` against Unsloth's dequantizer, then time it.

    For each (bsz, qlen, hd, m, seed, dtype) configuration below, a seeded
    `MLP` and random input `X` are created. Two warmup rounds then check that:
      * the hand-written forward using Unsloth's weights and using
        `dequantize_fx`'s weights both match `mlp(X)`;
      * every layer still has the expected bitsandbytes quant-state layout;
      * the matrices returned by `dequantize_fx` match Unsloth's.
    Finally `mlp_dequantize(X, mlp, dequantize_fx)` runs `n_iters` times
    and is timed.

    Args:
        dequantize_fx: callable mapping a 4-bit layer to its dense weight.
        n_iters: timed repetitions per configuration (default 1000, as before).

    Returns:
        Total seconds spent in the timed loops, summed over all configurations.

    Raises:
        AssertionError / RuntimeError from `assert_correct_bnb` / `assert_same`
        when a check fails.
    """
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.float16),
        (3, 2048, 4096, 14336, 3408, torch.float16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup: also performs the correctness checks before anything is timed.
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, unsloth_dequantize), mlp(X), _F(_C()), dt)
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking. perf_counter is monotonic and high-resolution, whereas
        # time.time() follows the wall clock and can jump if it is adjusted.
        # mlp_dequantize synchronizes after each layer, so no extra sync is
        # needed before reading the end time.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.perf_counter() - start
    return elapsed

### Unsloth Dequantize Speed Test

In [4]:
# Total seconds spent in the timed dequantize loops (lower is faster).
unsloth_elapsed = test_dequantize(unsloth_dequantize)
unsloth_elapsed

6.7735817432403564

### Custom Triton Code

In [5]:
from triton import jit
import triton
import triton.language as tl

# Map a 4-bit code (0..15) to its NF4 table value with a chain of tl.where selects.
# The conditions are mutually exclusive, so the order of the selects does not
# matter; any input outside 0..15 keeps the 0.0 fallback.
@triton.jit
def lookup_const(x):
    result = tl.where(x == 15, 1.0, 0.0)
    result = tl.where(x == 14, 0.7229568362236023, result)
    result = tl.where(x == 13, 0.5626170039176941, result)
    result = tl.where(x == 12, 0.44070982933044434, result)
    result = tl.where(x == 11, 0.33791524171829224, result)
    result = tl.where(x == 10, 0.24611230194568634, result)
    result = tl.where(x == 9,  0.16093020141124725, result)
    result = tl.where(x == 8,  0.07958029955625534, result)
    result = tl.where(x == 7,  0.0, result)
    result = tl.where(x == 6, -0.09105003625154495, result)
    result = tl.where(x == 5, -0.18477343022823334, result)
    result = tl.where(x == 4, -0.28444138169288635, result)
    result = tl.where(x == 3, -0.39491748809814453, result)
    result = tl.where(x == 2, -0.5250730514526367, result)
    result = tl.where(x == 1, -0.6961928009986877, result)
    result = tl.where(x == 0, -1.0, result)
    return result

# Triton kernel: decode packed NF4 weights whose absmax is itself 8-bit quantized.
#
# Each program handles BLOCK_SIZE packed input elements (two 4-bit codes each,
# high nibble first) and writes 2 * BLOCK_SIZE output values.
#   weight_ptr  : packed 4-bit codes.
#   out_ptr     : output buffer holding 2 * n_elements values.
#   absmax_ptr  : quantized absmax, one entry per 32 packed elements (64 codes).
#   absmax2_ptr : second-level absmax, one entry per 8192 packed elements
#                 (i.e. per 256 first-level absmax entries).
#   code2_ptr   : table indexed by the quantized absmax values to decode them.
#   residue     : scalar added after decoding the absmax (the caller passes
#                 the quant state's offset).
#   n_elements  : number of packed input elements.
# Per code: value = NF4[code] * (code2[absmax_q] * absmax2 + residue).
@triton.jit
def _your_dequantize_nf4_kernel(weight_ptr, out_ptr, absmax_ptr, absmax2_ptr, code2_ptr, residue, n_elements, BLOCK_SIZE: tl.constexpr):

    pid = tl.program_id(0)
    # Packed-input indices for this program and the matching 2x-wide output indices.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    out_offsets = 2 * pid * BLOCK_SIZE + tl.arange(0, 2*BLOCK_SIZE)
    # offsets >= n_elements correspond exactly to out_offsets >= 2 * n_elements,
    # so both masks drop the same tail lanes.
    mask = offsets < n_elements
    out_mask = out_offsets < 2 * n_elements

    weight_val = tl.load(weight_ptr + offsets, mask=mask)

    # Split each packed element into its two 4-bit codes.
    high_nibble = weight_val >> 4
    low_nibble  = weight_val & 0x0F

    # Map codes to NF4 table values (float literals, so fp32).
    high_nibble_f32  = lookup_const(high_nibble)
    low_nibble_f32  = lookup_const(low_nibble)

    # First-level absmax: one quantized entry per 32 packed elements (>> 5),
    # decoded through the code2 lookup table.
    current_absmax = tl.load(absmax_ptr + (offsets >> 5), mask = mask)
    decoded_absmax = tl.load(code2_ptr + current_absmax, mask = mask)

    # Second-level absmax: one entry per 8192 packed elements (>> 13).
    current_absmax2 = tl.load(absmax2_ptr + (offsets >> 13), mask=mask)

    # Effective scale = decoded_absmax * absmax2 + residue, as a single FMA.
    scaling = tl.fma(decoded_absmax, current_absmax2, residue)

    high_val = high_nibble_f32 * scaling
    low_val = low_nibble_f32 * scaling

    # Interleave so each element's high-nibble value precedes its low-nibble value.
    out = tl.interleave(high_val, low_val)

    # NOTE(review): values are fp32 here; presumably converted to out_ptr's
    # element type (the caller allocates quant_state.dtype) on store.
    tl.store(out_ptr + out_offsets,  out, mask=out_mask)

def _your_dequantize_nf4(weight, quant_state, offset):
    """Launch the Triton NF4 dequantization kernel and return the dense weight.

    Args:
        weight: packed 4-bit weight storage (two NF4 codes per element).
        quant_state: quant state providing `absmax`, `state2.absmax`,
            `state2.code`, `offset`, `dtype` and `shape`.
        offset: scalar added to every decoded absmax. May be a 0-dim tensor or
            a Python number. `None` falls back to `quant_state.offset`.

    Returns:
        Tensor of dtype `quant_state.dtype`, viewed as `quant_state.shape`.
    """
    if offset is None:
        offset = quant_state.offset
    # The kernel takes the offset as a host scalar.
    # NOTE(review): `.item()` on a CUDA tensor copies it to the host (likely a
    # stream sync per call); loading it inside the kernel would avoid that.
    residue = offset.item() if isinstance(offset, torch.Tensor) else float(offset)

    n_elements = weight.numel()
    out = torch.empty(2 * n_elements, device = weight.device, dtype = quant_state.dtype)
    if n_elements == 0:
        # Nothing to decode; avoid launching an empty grid.
        return out.view(quant_state.shape)

    BLOCK_SIZE = 1024
    # 1-D launch: ceil(n_elements / BLOCK_SIZE) programs.
    grid = lambda meta: ((n_elements + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

    _your_dequantize_nf4_kernel[grid](
        weight,
        out,
        quant_state.absmax,
        quant_state.state2.absmax,
        quant_state.state2.code,
        residue,
        n_elements,
        BLOCK_SIZE,
    )
    return out.view(quant_state.shape)

def your_dequantize_nf4(weight):
    """Dequantize a 4-bit layer with the custom Triton kernel.

    `weight` is the layer module. Its `.weight` parameter supplies the packed
    data and quant state, and the quant state's `offset` is forwarded explicitly.
    """
    param = weight.weight
    state = param.quant_state
    return _your_dequantize_nf4(param.data, state, state.offset)

### Custom Code Speed Test

In [6]:
# Total seconds spent in the timed dequantize loops (lower is faster).
custom_elapsed = test_dequantize(your_dequantize_nf4)
custom_elapsed

4.116127014160156

In [9]:
# Average, over `sample_runs` repetitions, of (Unsloth time / custom time).
# Values above 1 mean the custom kernel finished the benchmark faster.
# Note: despite its name, `time_taken` accumulates these ratios, not seconds.
sample_runs = 5
time_taken = 0
for sample_run in range(sample_runs):
    unsloth_seconds = test_dequantize(unsloth_dequantize)
    custom_seconds = test_dequantize(your_dequantize_nf4)
    val = unsloth_seconds / custom_seconds
    print(val)
    time_taken += val
print("Ratio Of Speed between Unsloth & Custom Code Is", time_taken/sample_runs)

1.1454133907231538
1.1647778856211157
1.417303095420293
1.1747517897251316
1.1881643938005837
Ratio Of Speed between Unsloth & Custom Code Is 1.2180821110580555


### Compilation Check With `torch.compile`

In [10]:
from triton import jit
import triton
import triton.language as tl

# Same NF4 code table as `lookup_const` (code 0..15 -> value), used by the
# torch.compile variant of the kernel. The selects are mutually exclusive, so
# their order is irrelevant; inputs outside 0..15 keep the 0.0 fallback.
@triton.jit
def lookup_const_compiled(x):
    result = tl.where(x == 15, 1.0, 0.0)
    result = tl.where(x == 14, 0.7229568362236023, result)
    result = tl.where(x == 13, 0.5626170039176941, result)
    result = tl.where(x == 12, 0.44070982933044434, result)
    result = tl.where(x == 11, 0.33791524171829224, result)
    result = tl.where(x == 10, 0.24611230194568634, result)
    result = tl.where(x == 9,  0.16093020141124725, result)
    result = tl.where(x == 8,  0.07958029955625534, result)
    result = tl.where(x == 7,  0.0, result)
    result = tl.where(x == 6, -0.09105003625154495, result)
    result = tl.where(x == 5, -0.18477343022823334, result)
    result = tl.where(x == 4, -0.28444138169288635, result)
    result = tl.where(x == 3, -0.39491748809814453, result)
    result = tl.where(x == 2, -0.5250730514526367, result)
    result = tl.where(x == 1, -0.6961928009986877, result)
    result = tl.where(x == 0, -1.0, result)
    return result

# Triton kernel used by the torch.compile check: decodes packed NF4 weights
# whose absmax is itself 8-bit quantized.
# NOTE(review): this body is a line-for-line copy of `_your_dequantize_nf4_kernel`
# except that it calls `lookup_const_compiled`. Keep the two in sync.
#
# Each program handles BLOCK_SIZE packed input elements (two 4-bit codes each,
# high nibble first) and writes 2 * BLOCK_SIZE output values.
#   absmax_ptr  : quantized absmax, one entry per 32 packed elements.
#   absmax2_ptr : second-level absmax, one entry per 8192 packed elements.
#   code2_ptr   : table indexed by the quantized absmax values to decode them.
#   residue     : scalar added after decoding the absmax.
# Per code: value = NF4[code] * (code2[absmax_q] * absmax2 + residue).
@triton.jit
def _your_dequantize_nf4_kernel_compiled(weight_ptr, out_ptr, absmax_ptr, absmax2_ptr, code2_ptr, residue, n_elements, BLOCK_SIZE: tl.constexpr):

    pid = tl.program_id(0)
    # Packed-input indices for this program and the matching 2x-wide output indices.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    out_offsets = 2 * pid * BLOCK_SIZE + tl.arange(0, 2*BLOCK_SIZE)
    # Both masks drop the same tail lanes (out index = 2 * input index + {0, 1}).
    mask = offsets < n_elements
    out_mask = out_offsets < 2 * n_elements

    weight_val = tl.load(weight_ptr + offsets, mask=mask)

    # Split each packed element into its two 4-bit codes.
    high_nibble = weight_val >> 4
    low_nibble  = weight_val & 0x0F

    high_nibble_f32  = lookup_const_compiled(high_nibble)
    low_nibble_f32  = lookup_const_compiled(low_nibble)

    # First-level absmax (one per 32 packed elements), decoded via code2.
    current_absmax = tl.load(absmax_ptr + (offsets >> 5), mask = mask)
    decoded_absmax = tl.load(code2_ptr + current_absmax, mask = mask)

    # Second-level absmax (one per 8192 packed elements).
    current_absmax2 = tl.load(absmax2_ptr + (offsets >> 13), mask=mask)

    # Effective scale = decoded_absmax * absmax2 + residue.
    scaling = tl.fma(decoded_absmax, current_absmax2, residue)

    high_val = high_nibble_f32 * scaling
    low_val = low_nibble_f32 * scaling

    # High-nibble value first, then low-nibble value, for each packed element.
    out = tl.interleave(high_val, low_val)

    tl.store(out_ptr + out_offsets,  out, mask=out_mask)

# Presumably lets dynamo trace scalar-producing ops (e.g. `.item()`) without a
# graph break.
torch._dynamo.config.capture_scalar_outputs = True

# Inductor options passed to `torch.compile` in the next definition.
# NOTE(review): "trace.enabled" and "max_autotune" look like debugging/tuning
# choices that can slow compilation and write trace files — confirm they are wanted.
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : True,
    "triton.cudagraphs" : False,
}

# torch.compile-wrapped launcher for `_your_dequantize_nf4_kernel_compiled`.
#   weight      : packed 4-bit weight storage (two codes per element).
#   quant_state : provides absmax, state2.absmax, state2.code, dtype and shape.
#   offset      : scalar added to every decoded absmax. The visible caller passes
#                 a Python float obtained with `.item()` outside this function.
# Returns a `quant_state.dtype` tensor viewed as `quant_state.shape`.
# NOTE(review): whether a distinct float `offset` per layer triggers a
# recompilation depends on the PyTorch version — check with TORCH_LOGS=recompiles.
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def _your_dequantize_nf4_compiled(weight, quant_state, offset):

    # Two output values per packed input element.
    n_elements = weight.numel()
    out = torch.empty(2*n_elements, device=weight.device, dtype=quant_state.dtype)

    BLOCK_SIZE = 1024
    # 1-D launch: ceil(n_elements / BLOCK_SIZE) programs.
    grid = lambda meta: ((n_elements + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

    _your_dequantize_nf4_kernel_compiled[grid](
        weight,
        out,
        quant_state.absmax,
        quant_state.state2.absmax,
        quant_state.state2.code,
        offset,
        n_elements,
        BLOCK_SIZE
    )
    return out.view(quant_state.shape)

def your_dequantize_nf4_compiled(weight):
    """Dequantize a 4-bit layer through the torch.compile-wrapped Triton launcher.

    The quant state's offset is converted to a Python float here, outside the
    compiled function, and passed in explicitly.
    """
    param = weight.weight
    state = param.quant_state
    return _your_dequantize_nf4_compiled(param.data, state, state.offset.item())

# Run the full correctness + timing harness through the compiled wrapper.
# Any ordinary exception counts as a failure and is printed so the cause is
# visible. KeyboardInterrupt/SystemExit are no longer swallowed (the previous
# bare `except:` caught them too).
try:
    test_dequantize(your_dequantize_nf4_compiled)
    print("torch.compile works with triton code -- test passed")
except Exception as error:
    print("torch.compile does not work with triton code -- test failed")
    print(f"{type(error).__name__}: {error}")

torch.compile works with triton code -- test passed
