In [1]:
# Install the quantization / fine-tuning stack. `--no-deps` is presumably used so pip does
# not swap out the torch build that is preinstalled on the machine.
# NOTE(review): the import warning further down reports that xformers 0.0.29 was built for
# PyTorch 2.5.1 while 2.6.0 is installed, so its memory-efficient kernels are disabled;
# pin an xformers release built for the installed torch — TODO confirm the exact version.
# NOTE(review): `%pip` targets the running kernel's environment more reliably than `!pip`.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth
!pip install transformers tf-keras

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting xformers==0.0.29
  Downloading xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting trl
  Downloading trl-0.16.0-py3-none-any.whl.metadata (12 kB)
Downloading xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl (15.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.3/15.3 MB[0m [31m46.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bitsandbytes-0.45.4-py3-none-manylinux_2_24_x86_64.whl (76.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.0/76.0 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading trl-0.16.0-py3-none-any.whl (335 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m335.7/335.7 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xformers, trl, bitsandbytes
Successfully installed bitsandbytes-0.45.4 trl-0.16.0 xformers-0.

In [2]:
# Helper functions used throughout the notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
# Compute capability of the current CUDA device. Fall back to (0, 0) when no GPU is
# visible, so this helper cell can still be executed on a CPU-only machine instead of
# raising inside torch.cuda.get_device_capability().
if torch.cuda.is_available():
    major_version, minor_version = torch.cuda.get_device_capability()
else:
    major_version, minor_version = 0, 0
# Native bfloat16 support starts at compute capability 8.x (Ampere).
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
_F = lambda c: getframeinfo(c).lineno # Gets line number of frame `c`
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings (ANSI escape codes)

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    """
    Best-effort name of `var` as seen by the calling function.

    Scans the caller's local variables and returns the first name bound to this
    exact object (identity, not equality), or "" when no local refers to it.
    Used to label tensors in assertion messages.
    """
    caller_locals = inspect.currentframe().f_back.f_locals
    for local_name, local_value in caller_locals.items():
        if local_value is var:
            return local_name
    return ""

def assert_same(x, y, line, dtype):
    """
    Assert that `x` has dtype `dtype` and matches `y` within atol = rtol = 0.1,
    with identical strides.

    Raises:
        AssertionError: if `x.dtype != dtype`.
        RuntimeError: if values or strides differ; the message carries the
            caller-supplied `line` and the underlying comparison error.
    """
    assert x.dtype == dtype
    tolerance = dict(check_stride = True, atol = 1e-1, rtol = 1e-1)
    try:
        torch.testing.assert_close(x, y, **tolerance)
    except Exception as error:
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
        )

# Opt in to the hf_transfer downloader (installed in the first cell).
# huggingface_hub reads this variable into `huggingface_hub.constants` when it is first
# imported, and `transformers` (imported above) normally imports it already, so setting
# os.environ alone would not take effect in this session. Also update the cached
# constant when the module is already loaded.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import sys as _sys
_hf_constants = _sys.modules.get("huggingface_hub.constants")
if _hf_constants is not None and hasattr(_hf_constants, "HF_HUB_ENABLE_HF_TRANSFER"):
    _hf_constants.HF_HUB_ENABLE_HF_TRANSFER = True

In [3]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
def unsloth_dequantize(weight):
    """Reference dequantizer: run unsloth's `fast_dequantize` on a 4-bit layer's packed weight."""
    packed = weight.weight
    return fast_dequantize(packed, packed.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    """
    Build a bias-free bitsandbytes `Linear4bit` (NF4, compressed statistics)
    mapping `hd` input features to `m` output features, computing in `dtype`.
    """
    layer_kwargs = dict(
        bias                = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )
    return Linear4bit(hd, m, **layer_kwargs)

# [NEW] as of 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    """
    Check that a bitsandbytes 4-bit layer still has the expected quantized layout:
    uint8 packed weight, quant-state dtype `dtype`, uint8 absmax with blocksize 64,
    float32 code/offset, and a float32 nested `state2` with blocksize 256.

    Raises AssertionError on the first mismatch.
    """
    packed = weight.weight
    assert packed.dtype == torch.uint8
    qs = packed.quant_state
    assert qs.dtype == dtype
    assert qs.absmax.dtype == torch.uint8
    assert qs.code.dtype == torch.float32
    assert qs.offset.dtype == torch.float32
    assert qs.blocksize == 64
    nested = qs.state2
    assert nested.absmax.dtype == torch.float32
    assert nested.code.dtype == torch.float32
    assert nested.blocksize == 256

class MLP(nn.Module):
    """
    SwiGLU-style feed-forward block built from three 4-bit bitsandbytes layers:
    ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    Args:
        hd: hidden size (input features of gate/up, output features of down).
        m: intermediate size.
        dtype: compute dtype of the layers, also written into each quant state.
    """
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        # Creation order (gate, up, down) is kept fixed so seeded weight
        # initialisation stays reproducible.
        layer_specs = (
            ("gate_proj", hd, m),
            ("up_proj",   hd, m),
            ("down_proj", m,  hd),
        )
        for attr_name, in_features, out_features in layer_specs:
            setattr(self, attr_name, bnb_Linear4bit(in_features, out_features, dtype = dtype).to("cuda"))
        # [NEW] as of 18th Feb 2025: record the requested dtype on every quant state
        for attr_name, _, _ in layer_specs:
            getattr(self, attr_name).weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        activated_gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated_gate * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    """
    Recompute the MLP forward pass with explicitly dequantized weights.

    `fx` maps a layer to its dense weight (``out_features x in_features``); every
    projection is then ``inputs @ weight.T``. Returns
    ``(act_fn(X @ Wg.T) * (X @ Wu.T)) @ Wd.T``.
    """
    def _project(inputs, layer):
        return inputs @ fx(layer).t()

    up = _project(X, mlp.up_proj)
    gate = _project(X, mlp.gate_proj)
    return _project(mlp.act_fn(gate) * up, mlp.down_proj)

def mlp_dequantize(X, mlp, fx):
    """
    Dequantize the up, gate and down projections with `fx` (in that order) and
    return their transposes as ``(up, gate, down)``.

    `X` is unused; it is kept so the signature mirrors `mlp_forward`. A CUDA
    synchronize follows each call, so callers timing this function measure
    completed GPU work.
    """
    transposed = []
    for layer in (mlp.up_proj, mlp.gate_proj, mlp.down_proj):
        transposed.append(fx(layer).t())
        torch.cuda.synchronize()
    return tuple(transposed)

def test_dequantize(dequantize_fx, options = None):
    """
    Validate `dequantize_fx` and return its benchmark time.

    For each configuration this builds a fresh `MLP` and runs two warm-up rounds
    that assert:
      (a) the dequantized forward (`mlp_forward`) matches `mlp(X)`;
      (b) each layer's bitsandbytes quant state has the expected layout;
      (c) every dequantized weight matches `unsloth_dequantize`.
    It then times 1000 calls of `mlp_dequantize` with `dequantize_fx`.

    Args:
        dequantize_fx: callable taking a 4-bit layer and returning its dense weight.
        options: iterable of ``(bsz, qlen, hd, m, seed, dtype)`` tuples. Defaults to a
            single float16 configuration. Larger bfloat16 configurations such as
            ``(5, 777, 1024, 4096, 3409, torch.bfloat16)`` and
            ``(3, 2048, 4096, 14336, 3408, torch.bfloat16)`` can be passed explicitly
            on GPUs with bfloat16 support.

    Returns:
        Total benchmark wall time in seconds, summed over all configurations.
    """
    if options is None:
        options = [
            (2, 3333, 2048,  8192, 3407, torch.float16),
        ]
    elapsed = 0
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup + correctness checks
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as of 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking. perf_counter is monotonic and high-resolution (time.time()
        # can jump with system clock adjustments); mlp_dequantize synchronizes after
        # each call, so the GPU work is finished when the timer stops.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.perf_counter() - start
    return elapsed


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth.kernels.utils import fast_dequantize


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.6.0+cu124)
    Python  3.11.11 (you have 3.11.11)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!


In [4]:
from unsloth.kernels.utils import fast_dequantize
# `unsloth_dequantize` is already defined (identically) in the harness cell above;
# redefining it here would only shadow that definition, so just run the benchmark.
test_dequantize(unsloth_dequantize)

1.1017403602600098

In [5]:
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
# Benchmark PEFT's dequantizer on the same harness (returns total seconds)
test_dequantize(peft_dequantize)

1.1417860984802246

In [6]:
from triton import jit
import triton
import triton.language as tl
import torch

@triton.jit
def _your_dequantize_nf4_kernel(w_ptr, w_out, abs_ptrs,
                               offset_ptr,
                               abs2_ptrs, code2,
                               block_size2, gsize, code, blocks: tl.constexpr, Br: tl.constexpr):
    """
    Dequantize one tile of a double-quantized NF4 weight.

    Each program writes ``Br`` output elements, i.e. ``blocks`` quantization
    blocks of ``Br // blocks`` elements each, and reads ``Br // 2`` packed bytes.

    Args:
        w_ptr:       packed weights, two 4-bit codes per byte (high nibble first).
        w_out:       output buffer; tl.store converts values to its element type.
        abs_ptrs:    per-block quantized absmax, used as an index into ``code2``.
        offset_ptr:  scalar offset added to the dequantized absmax.
        abs2_ptrs:   second-level scales, one per ``block_size2`` absmax entries.
        code2:       lookup table for the quantized absmax values.
        block_size2: absmax entries per second-level scale; must be a power of two
                     (the index is computed with a shift derived from CLZ).
        gsize:       number of real tiles; programs with ``pid >= gsize`` do nothing.
        code:        NF4 codebook indexed by the 4-bit codes.
        blocks:      quantization blocks per tile (constexpr; power of two for tl.arange).
        Br:          output elements per tile (constexpr).

    NOTE(review): no load/store masks are used, so the total element count must be
    an exact multiple of ``Br``; otherwise the last tile reads and writes out of bounds.
    """
    pid = tl.program_id(0)

    # The grid may be padded beyond gsize; padded programs skip all work
    if pid < gsize:
        # Indices of the `blocks` quantization blocks covered by this tile
        absmax_group = pid*blocks + tl.arange(0, blocks)

        # Quantized absmax per block; each entry is read by one program only,
        # hence the evict_first cache hint
        absmax = tl.load(abs_ptrs + absmax_group, eviction_policy="evict_first")

        # lz = count of leading zeros of block_size2; for a power of two,
        # 31 - lz == log2(block_size2), so `>> (31 - lz)` divides by block_size2.
        # Inline PTX: this only compiles for NVIDIA GPUs.
        lz = tl.inline_asm_elementwise(
            asm="""
            {
                clz.b32 $0, $1;
            }
            """,
            constraints=("=r,r"),
            args=[block_size2],
            dtype=(tl.int32),
            is_pure=True,
            pack=1,
        )

        # Index of the second-level scale for each block
        absmax_group2 = (absmax_group)>>(31-lz)

        # Table lookups (evict_last hints: these small tables are re-read by many
        # programs): code2[absmax] recovers the absmax values, absmax2 holds the
        # second-level scales, offset is a single scalar.
        real_absmax = tl.load(code2 + absmax, eviction_policy="evict_last")
        absmax2 = tl.load(abs2_ptrs + absmax_group2, eviction_policy="evict_last")
        offset = tl.load(offset_ptr, eviction_policy="evict_last")

        # Dequantized per-block scale: absmax2 * code2[absmax] + offset
        # (whether this becomes a single FMA is up to the compiler)
        final_absmax = absmax2 * real_absmax + offset

        # Byte offsets of this tile's packed weights as a (blocks, Br // (2*blocks))
        # grid: row r holds the bytes of quantization block r
        w_off = pid*(Br//2) + tl.arange(0, blocks)[:, None]*(Br//(2*blocks)) + tl.arange(0, Br//(2*blocks))[None, :]

        # Packed bytes for the tile (streamed once, hence evict_first)
        w_packed = tl.load(w_ptr + w_off, eviction_policy="evict_first")

        # Duplicate every byte along the last axis -> (blocks, Br // blocks), so
        # output positions 2k and 2k+1 both see packed byte k
        w_packed2 = tl.interleave(w_packed, w_packed)

        # Even positions take the high nibble (shift 4), odd positions the low nibble
        shift_sh = tl.arange(0, blocks)[:, None]*(Br//(blocks)) + tl.arange(0, Br//(blocks))[None, :]
        shift = tl.where(shift_sh % 2 == 0, 4, 0)

        # 4-bit code (0..15) for each output element
        shifted_w = (w_packed2 >> shift) & 0xF

        # NF4 codebook lookup
        real_w = tl.load(shifted_w + code, eviction_policy="evict_last")

        # Scale each row (one quantization block) by its dequantized absmax
        scaled_w = (real_w * final_absmax[:, None])

        # Contiguous output offsets for the tile, laid out as (blocks, Br // blocks)
        out_off = pid*Br + tl.arange(0, blocks)[:, None]*(Br//blocks) + tl.arange(0, Br//blocks)[None, :]

        # Store; conversion to w_out's element type happens here
        tl.store(w_out + out_off, scaled_w, eviction_policy="evict_first")
    return

def _your_dequantize_nf4(weight, quant_state):
    """
    Dequantize a double-quantized NF4 weight with `_your_dequantize_nf4_kernel`.

    Args:
        weight: packed weight tensor; each byte holds two 4-bit codes, so
            ``2 * weight.shape[0]`` elements are produced (assumes the packed bytes
            lie along dim 0, e.g. shape ``(n // 2, 1)`` — TODO confirm).
        quant_state: bitsandbytes quant state providing ``dtype``, ``shape``,
            ``code``, ``absmax``, ``blocksize``, ``offset`` and the nested
            ``state2`` (``absmax``, ``code``, ``blocksize``).

    Returns:
        A new tensor of shape ``quant_state.shape`` and dtype ``quant_state.dtype``
        on ``weight.device``.

    Raises:
        ValueError: if the element count is not a multiple of ``quant_state.blocksize``
            (the kernel only handles whole quantization blocks).
    """
    device = weight.device

    # Quantization metadata
    out_dtype   = quant_state.dtype
    code        = quant_state.code
    absmax      = quant_state.absmax
    real_shape  = quant_state.shape
    block_size  = quant_state.blocksize
    absmax2     = quant_state.state2.absmax
    code2       = quant_state.state2.code
    block_size2 = quant_state.state2.blocksize
    offset      = quant_state.offset

    # Two 4-bit values per packed byte
    size = weight.shape[0]
    out_size = size * 2

    if out_size == 0:
        return torch.empty(real_shape, device=device, dtype=out_dtype, requires_grad=False)
    if out_size % block_size != 0:
        raise ValueError(
            f"NF4 dequantize: {out_size} elements is not a multiple of blocksize {block_size}"
        )

    # Tile size Br: up to 8192 outputs, i.e. `blocks` quantization blocks per program.
    # The kernel has no bounds masks, so halve `blocks` (keeping it a power of two, as
    # tl.arange requires) until whole tiles cover the tensor exactly.
    n_blocks = out_size // block_size
    blocks = max(8192 // block_size, 1)
    while blocks > 1 and n_blocks % blocks != 0:
        blocks //= 2
    Br = blocks * block_size

    gsize = triton.cdiv(out_size, Br)

    props = torch.cuda.get_device_properties(device)

    # Programs per SM used to pad the grid into whole "waves". (8, 0) is A100-class
    # Ampere and (8, 9) is Ada; every other architecture, including sm_86/sm_87,
    # uses 16 instead of leaving max_th unassigned.
    per_sm = {(8, 0): 32, (8, 9): 24}.get((props.major, props.minor), 16)
    max_th = per_sm * props.multi_processor_count

    # Round the grid up to a multiple of max_th; programs with pid >= gsize exit at once
    wave_sze = triton.cdiv(gsize, max_th) * max_th

    # Turing (sm_75, e.g. T4) has no native bfloat16: the kernel writes float16 there
    # and the result is converted back to the requested dtype below. Every other
    # dtype is written directly.
    is_t4 = (props.major == 7 and props.minor == 5)
    kernel_dtype = torch.float16 if (is_t4 and out_dtype == torch.bfloat16) else out_dtype

    w_out = torch.empty(real_shape, device=device,
                        dtype=kernel_dtype,
                        requires_grad=False)

    grid = lambda META: (wave_sze,)

    _your_dequantize_nf4_kernel[grid](
        weight, w_out,
        absmax, offset,
        absmax2, code2,
        block_size2, gsize, code, blocks, Br,
        num_warps=16,
        num_stages=1,
        # NOTE(review): far above the 255 registers-per-thread hardware limit, so this
        # presumably does not constrain register allocation — confirm it is needed.
        maxnreg=8192,
    )

    # Hand back exactly the dtype recorded in the quant state
    if w_out.dtype != out_dtype:
        w_out = w_out.to(out_dtype)
    return w_out


def your_dequantize_nf4(weight):
    """Dequantize a 4-bit layer's weight with the Triton kernel (same calling convention as `unsloth_dequantize`)."""
    params = weight.weight
    return _your_dequantize_nf4(params.data, params.quant_state)

In [7]:
### TEST IT BELOW:
# test_dequantize(your_dequantize_nf4) runs the correctness checks and returns the
# benchmark time; the speedup expression below already calls it, so no separate
# call is needed here.

### CALCULATE SPEEDUP (hopefully 1.15x faster or more)
# Ratio of total benchmark times (unsloth / yours): values > 1 mean the custom kernel is faster.
test_dequantize(unsloth_dequantize) / test_dequantize(your_dequantize_nf4)

1.2712974987482513