In [None]:
!mamba create -n py311 -y
!source /opt/conda/bin/activate py311 && mamba install python=3.11 jupyter mamba -y

!sudo rm /opt/conda/bin/python3
!sudo ln -sf /opt/conda/envs/py311/bin/python3 /opt/conda/bin/python3
!sudo rm /opt/conda/bin/python3.7
!sudo ln -sf /opt/conda/envs/py311/bin/python3 /opt/conda/bin/python3.7
!sudo rm /opt/conda/bin/python
!sudo ln -sf /opt/conda/envs/py311/bin/python3 /opt/conda/bin/python

In [None]:
!python --version

In [1]:
!pip install  bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install  cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install  unsloth

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting accelerate
  Downloading accelerate-1.5.2-py3-none-any.whl.metadata (19 kB)
Collecting xformers==0.0.29
  Downloading xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting peft
  Downloading peft-0.14.0-py3-none-any.whl.metadata (13 kB)
Collecting trl
  Downloading trl-0.15.2-py3-none-any.whl.metadata (11 kB)
Collecting triton
  Downloading triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting numpy (from xformers==0.0.29)
  Downloading numpy-2.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting torch==2.5.1 (from xformers==0.0.29)
  Downloading torch-2.5.1-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting filelock (from torch==2.5.1->xformers==0.0.29)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting networkx (from

In [3]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
import unsloth
from transformers import set_seed
import time
import inspect
import os
from bitsandbytes.nn import Linear4bit
from unsloth.kernels.utils import fast_dequantize
from transformers.activations import ACT2FN
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
_F = lambda c: getframeinfo(c).lineno # Gets line number
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""

def assert_same(x, y, line, dtype):
    assert(x.dtype == dtype)
    try: torch.testing.assert_close(x, y, check_stride = True, atol=0.01, rtol=0.1)
    except Exception as error:
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
        )

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [4]:
def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    return Linear4bit(
        hd, m, bias = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    assert(weight.weight.dtype == torch.uint8)
    assert(weight.weight.quant_state.dtype == dtype)
    assert(weight.weight.quant_state.absmax.dtype == torch.uint8)
    assert(weight.weight.quant_state.code.dtype == torch.float32)
    assert(weight.weight.quant_state.offset.dtype == torch.float32)
    assert(weight.weight.quant_state.blocksize == 64)
    assert(weight.weight.quant_state.state2.absmax.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.code.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.blocksize == 256)

class MLP(nn.Module):
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        self.gate_proj = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype = dtype).to("cuda")
        # [NEW] as at 18th Feb 2025
        self.gate_proj.weight.quant_state.dtype = dtype
        self.up_proj  .weight.quant_state.dtype = dtype
        self.down_proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]
    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    up   = X @ fx(mlp.  up_proj).t()
    gate = X @ fx(mlp.gate_proj).t()
    h = mlp.act_fn(gate) * up
    down = h @ fx(mlp.down_proj).t()
    return down

def mlp_dequantize(X, mlp, fx):
    a = fx(mlp.  up_proj).t(); torch.cuda.synchronize()
    b = fx(mlp.gate_proj).t(); torch.cuda.synchronize()
    c = fx(mlp.down_proj).t(); torch.cuda.synchronize()
    return a, b, c

def test_dequantize(dequantize_fx):
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.float16),
        (3, 2048, 4096, 14336, 3408, torch.float16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, unsloth_dequantize), mlp(X), _F(_C()), dt)
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.time() - start
    return elapsed

In [5]:
from unsloth.kernels.utils import fast_dequantize
def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)
test_dequantize(unsloth_dequantize)

4.673144340515137

In [6]:
# writing custom triton kernal for blockwise dequantization
# https://github.com/bitsandbytes-foundation/bitsandbytes/blob/86b6c37a8ad448230cedb60753f63150b603a112/bitsandbytes/functional.py#L958
import torch
from triton import jit
import triton
import triton.language as tl

# https://github.com/bitsandbytes-foundation/bitsandbytes/blob/e772a9e8723cfc2036fecc830c328ad3b9705250/csrc/kernels.cu#L116
lut_table = torch.tensor([-1.0, 
                       -0.6961928009986877, 
                       -0.5250730514526367, 
                       -0.39491748809814453,
                       -0.28444138169288635, 
                       -0.18477343022823334, 
                       -0.09105003625154495, 
                       0,  
                       0.07958029955625534, 
                       0.16093020141124725, 
                       0.24611230194568634, 
                       0.33791524171829224,
                       0.44070982933044434, 
                       0.5626170039176941, 
                       0.7229568362236023, 
                       1.0], 
                      dtype = torch.float32)


@triton.jit
def _your_dequantize_nf4_kernel(
    output_ptr,
    quant_weights_ptr, weights_code_ptr,
    absmax_residue, absmax_scale_ptr, absmax_quant_ptr, absmax_code_ptr,
    blocksize_weights, blocksize_absmax, numel_weights,
    BLOCK_SIZE: tl.constexpr
):
    input_offset = tl.program_id(axis = 0)* BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    input_mask = input_offset < numel_weights

    numel_absmax = numel_weights // blocksize_weights
    absmax_offset = input_offset // blocksize_weights
    absmax_mask = absmax_offset < numel_absmax

    quant_absmax = tl.load(
        absmax_quant_ptr + absmax_offset,
        mask = absmax_mask
    )

    absmax_scale_offset = input_offset // blocksize_weights // blocksize_absmax
    absmax_scale_mask = absmax_offset < numel_absmax
    absmax_scale = tl.load(absmax_scale_ptr + absmax_scale_offset, 
                           mask = absmax_scale_mask) 

    absmax_decode = tl.load(absmax_code_ptr + quant_absmax, 
                            mask = absmax_mask)
    absmax = tl.fma(absmax_decode ,absmax_scale ,absmax_residue)

    weights_quant_offset = tl.max_contiguous(
        tl.multiple_of(input_offset, BLOCK_SIZE), BLOCK_SIZE
    )
    weights_quant = tl.load(quant_weights_ptr + weights_quant_offset, mask=input_mask)

    weights_upper = tl.load(weights_code_ptr + (weights_quant >> 4)) * absmax
    weights_lower = tl.load(weights_code_ptr + (weights_quant & 0x0F)) *absmax

    weights = tl.reshape(
        tl.interleave(weights_upper, weights_lower), 2 * BLOCK_SIZE, can_reorder=False
    )

    output_offset = tl.program_id(0) * 2 * BLOCK_SIZE + tl.arange(0, 2 * BLOCK_SIZE)
    output_mask = output_offset < 2 * numel_weights
    tl.store(output_ptr + output_offset, weights, mask=output_mask)
    pass


def _your_dequantize_nf4(weight, quant_state):
    dtype_in = dtype_out = quant_state.dtype
    if dtype_in == torch.bfloat16:
        device = torch.cuda.current_device()
        major, _ = torch.cuda.get_device_capability(device)
        dtype_in = torch.float32 if major < 8 else torch.bfloat16
    
    weights = torch.empty(
        quant_state.shape, 
        dtype = dtype_in,
        device=weight.device
    )
    
    weights_code = lut_table.to(device = weight.device)
    launch_grid = lambda meta: (triton.cdiv(weight.numel(), meta["BLOCK_SIZE"]),)

    _your_dequantize_nf4_kernel[launch_grid](
        output_ptr = weights,
        quant_weights_ptr = weight,
        weights_code_ptr = weights_code,
        absmax_residue = quant_state.offset.item(),
        absmax_scale_ptr = quant_state.state2.absmax, 
        absmax_quant_ptr = quant_state.absmax, 
        absmax_code_ptr = quant_state.state2.code,
        blocksize_weights = quant_state.blocksize // 2, 
        blocksize_absmax = quant_state.state2.blocksize, 
        numel_weights = weight.numel(),
        BLOCK_SIZE = 512
    )
    
    return weights.to(dtype_out)

def your_dequantize_nf4(weight):
    return _your_dequantize_nf4(weight.weight.data, weight.weight.quant_state)

In [7]:
test_dequantize(your_dequantize_nf4)

3.767548084259033

In [8]:
time_taken = 0
sample_runs = 5
for sample_run in range(sample_runs):
    val = (test_dequantize(unsloth_dequantize)/test_dequantize(your_dequantize_nf4))
    print(val)
    time_taken += val
print("Ratio Of Speed between Unsloth & Custom Code Is", time_taken/sample_runs)

1.2825994649800163
1.2915359332444512
1.320830369359403
1.336775728190073
1.3971278842952632
Ratio Of Speed between Unsloth & Custom Code Is 1.3257738760138413


In [9]:
# writing custom triton kernal for blockwise dequantization
# https://github.com/bitsandbytes-foundation/bitsandbytes/blob/86b6c37a8ad448230cedb60753f63150b603a112/bitsandbytes/functional.py#L958
import torch
from triton import jit
import triton
import triton.language as tl


# https://github.com/bitsandbytes-foundation/bitsandbytes/blob/e772a9e8723cfc2036fecc830c328ad3b9705250/csrc/kernels.cu#L116
lut_table = torch.tensor([-1.0, 
                       -0.6961928009986877, 
                       -0.5250730514526367, 
                       -0.39491748809814453,
                       -0.28444138169288635, 
                       -0.18477343022823334, 
                       -0.09105003625154495, 
                       0,  
                       0.07958029955625534, 
                       0.16093020141124725, 
                       0.24611230194568634, 
                       0.33791524171829224,
                       0.44070982933044434, 
                       0.5626170039176941, 
                       0.7229568362236023, 
                       1.0], 
                      dtype = torch.float32)

@triton.jit
def _your_dequantize_nf4_kernel_compiled(
    output_ptr,
    quant_weights_ptr, weights_code_ptr,
    absmax_residue, absmax_scale_ptr, absmax_quant_ptr, absmax_code_ptr,
    blocksize_weights, blocksize_absmax, numel_weights,
    BLOCK_SIZE: tl.constexpr
):
    input_offset = tl.program_id(axis = 0)* BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    input_mask = input_offset < numel_weights

    numel_absmax = numel_weights // blocksize_weights
    absmax_offset = input_offset // blocksize_weights
    absmax_mask = absmax_offset < numel_absmax

    quant_absmax = tl.load(
        absmax_quant_ptr + absmax_offset,
        mask = absmax_mask
    )

    absmax_scale_offset = input_offset // blocksize_weights // blocksize_absmax
    absmax_scale_mask = absmax_offset < numel_absmax
    absmax_scale = tl.load(absmax_scale_ptr + absmax_scale_offset, 
                           mask = absmax_scale_mask) 

    absmax_decode = tl.load(absmax_code_ptr + quant_absmax, 
                            mask = absmax_mask)
    absmax = tl.fma(absmax_decode ,absmax_scale ,absmax_residue)

    weights_quant_offset = tl.max_contiguous(
        tl.multiple_of(input_offset, BLOCK_SIZE), BLOCK_SIZE
    )
    weights_quant = tl.load(quant_weights_ptr + weights_quant_offset, mask=input_mask)

    weights_upper = tl.load(weights_code_ptr + (weights_quant >> 4)) * absmax
    weights_lower = tl.load(weights_code_ptr + (weights_quant & 0x0F)) *absmax

    weights = tl.reshape(
        tl.interleave(weights_upper, weights_lower), 2 * BLOCK_SIZE, can_reorder=False
    )

    output_offset = tl.program_id(0) * 2 * BLOCK_SIZE + tl.arange(0, 2 * BLOCK_SIZE)
    output_mask = output_offset < 2 * numel_weights
    tl.store(output_ptr + output_offset, weights, mask=output_mask)
    pass

torch._dynamo.config.capture_scalar_outputs = True

torch_compile_options = torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : True,
    "triton.cudagraphs" : False,
}

@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def _your_dequantize_nf4_compiled(weight, quant_state, residue):
    # Fix for T4 gpu 
    dtype_in = dtype_out = quant_state.dtype
    if dtype_in == torch.bfloat16:
        device = torch.cuda.current_device()
        major, _ = torch.cuda.get_device_capability(device)
        dtype_in = torch.float32 if major < 8 else torch.bfloat16
    
    weights = torch.empty(
        quant_state.shape, 
        dtype = dtype_in,
        device=weight.device
    )
    
    weights_code = lut_table.to(device = weight.device)
    launch_grid = lambda meta: (triton.cdiv(weight.numel(), meta["BLOCK_SIZE"]),)

    _your_dequantize_nf4_kernel_compiled[launch_grid](
        output_ptr = weights,
        quant_weights_ptr = weight,
        weights_code_ptr = weights_code,
        absmax_residue = residue,
        absmax_scale_ptr = quant_state.state2.absmax, 
        absmax_quant_ptr = quant_state.absmax, 
        absmax_code_ptr = quant_state.state2.code,
        blocksize_weights = quant_state.blocksize // 2, 
        blocksize_absmax = quant_state.state2.blocksize, 
        numel_weights = weight.numel(),
        BLOCK_SIZE = 512
    )
    
    return weights.to(dtype_out)

def your_dequantize_nf4_compiled(weight):
    # Fix for torch compile
    residue = weight.weight.quant_state.offset.item()
    return _your_dequantize_nf4_compiled(weight.weight.data, weight.weight.quant_state, residue)


try:
    test_dequantize(your_dequantize_nf4_compiled)
    print("torch.compile works with triton code -- test passed")
except:
    print("torch.compile does not work with triton code -- test failed")

torch.compile works with triton code -- test passed


In [10]:
time_taken = 0
sample_runs = 5
for sample_run in range(sample_runs):
    val = (test_dequantize(unsloth_dequantize)/test_dequantize(your_dequantize_nf4_compiled))
    print(val)
    time_taken += val
print("Ratio Of Speed between Unsloth & Custom Code Is", time_taken/sample_runs)

1.0855464290407564
1.0946592390159167
1.1000718816097488
1.0859838132329163
1.0819061476592704
Ratio Of Speed between Unsloth & Custom Code Is 1.0896335021117216
