In [1]:
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install unsloth
!pip install ninja

In [2]:
!pip install --upgrade triton

Collecting triton
  Downloading triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.2/253.2 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
  Attempting uninstall: triton
    Found existing installation: triton 3.1.0
    Uninstalling triton-3.1.0:
      Successfully uninstalled triton-3.1.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.5.1+cu124 requires triton==3.1.0; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13", but you have triton 3.2.0 which is incompatible.[0m[31m
[0mSuccessfully installed triton-3.2.0


In [3]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
major_version, minor_version = torch.cuda.get_device_capability()
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
_F = lambda c: getframeinfo(c).lineno # Gets line number
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""

def assert_same(x, y, line, dtype):
    assert(x.dtype == dtype)
    try: torch.testing.assert_close(x, y, check_stride = True)
    except Exception as error:
        raise RuntimeError(
            f"Failed allclose at line [{line}]: {NAME(x)}, {NAME(y)}\n{str(error)}"
        )

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [4]:
# Global variables used in fast_dequantize.
CUDA_STREAM = None
WEIGHT_BUFFER = None
ABSMAX_BUFFER = None

# ===== Step 1: Monkey-patch dataclasses.fields (if needed) =====
import dataclasses

# Save the original dataclasses.fields.
_original_fields = dataclasses.fields

def safe_fields(obj):
    try:
        return _original_fields(obj)
    except TypeError:
        # If obj is not a dataclass, return an empty tuple (or customize as needed)
        return tuple()

# Monkey-patch fields to avoid errors during unsloth imports.
dataclasses.fields = safe_fields

# ===== Step 2: Import required modules =====
import torch
from bitsandbytes.nn import Linear4bit
from transformers import set_seed
from transformers.activations import ACT2FN
from peft.utils.integrations import dequantize_module_weight as peft_dequantize

# Now import unsloth utilities (which might call dataclasses.fields).
from unsloth.kernels.utils import (
    get_ptr,
    cdequantize_blockwise_fp32,
    ctypes_c_int,
    cdequantize_blockwise_fp16_nf4,
    cdequantize_blockwise_bf16_nf4
)

# ===== Step 3: Import dataclasses utilities and define our override =====
from dataclasses import is_dataclass, dataclass

# Define a new dataclass for the quantization state.
@dataclass
class NewQuantState:
    absmax: torch.Tensor
    code: torch.Tensor
    offset: torch.Tensor   # or float, depending on your usage
    blocksize: int
    dtype: torch.dtype
    state2: object       # refine this type if possible
    shape: tuple         # shape of the dequantized weight matrix

# Helper: ensure the quant_state is a dataclass instance.
def ensure_dataclass_quant_state(qs):
    if is_dataclass(qs):
        return qs
    return NewQuantState(
        absmax = qs.absmax,
        code = qs.code,
        offset = qs.offset,
        blocksize = qs.blocksize,
        dtype = qs.dtype,
        state2 = qs.state2,
        shape = qs.shape
    )

# New fast_dequantize implementation using our helper.
def new_fast_dequantize(W, quant_state=None, out=None, use_global_buffer=False):
    if quant_state is None:
        return W
    if type(quant_state) is not list:
        quant_state = ensure_dataclass_quant_state(quant_state)
        absmax     = quant_state.absmax
        shape      = quant_state.shape
        dtype      = quant_state.dtype
        blocksize  = quant_state.blocksize
        offset     = quant_state.offset
        state2     = quant_state.state2
        absmax2    = state2.absmax
        code2      = state2.code
        blocksize2 = state2.blocksize
    else:
        # Handle the old quant_state (if needed)
        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
        offset, state2 = compressed_stats
        absmax2, code2, blocksize2, _, _, _, _ = state2
    global CUDA_STREAM
    if CUDA_STREAM is None:
        CUDA_STREAM = torch.cuda.current_stream("cuda:0")
    n_elements_absmax = absmax.numel()
    if use_global_buffer:
        size = shape[0] * shape[1]
        global WEIGHT_BUFFER, ABSMAX_BUFFER
        if WEIGHT_BUFFER is None:
            WEIGHT_BUFFER = torch.empty(size, dtype=dtype, device="cuda:0", requires_grad=False)
            ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype=torch.float32, device="cuda:0", requires_grad=False)
        if size > WEIGHT_BUFFER.numel():
            WEIGHT_BUFFER.resize_(size)
        if n_elements_absmax > ABSMAX_BUFFER.numel():
            ABSMAX_BUFFER.resize_(n_elements_absmax)
        out = WEIGHT_BUFFER[:size].view(shape)
        out_absmax = ABSMAX_BUFFER[:n_elements_absmax]
    else:
        if out is None:
            out = torch.empty(shape, dtype=dtype, device="cuda:0", requires_grad=False)
        else:
            assert(out.shape == shape)
            assert(out.dtype == dtype)
        out_absmax = torch.empty(n_elements_absmax, dtype=torch.float32, device="cuda:0", requires_grad=False)
    ptr_out_absmax = get_ptr(out_absmax)
    cdequantize_blockwise_fp32(
        get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
        ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM,
    )
    out_absmax += offset
    fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else cdequantize_blockwise_bf16_nf4
    fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
       ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM)
    is_transposed = (True if W.shape[0] == 1 else False)
    return out.t() if is_transposed else out

# Override the fast_dequantize in unsloth.kernels.utils.
import unsloth.kernels.utils as uutils
uutils.fast_dequantize = new_fast_dequantize

# (Optionally restore dataclasses.fields here if needed)
dataclasses.fields = _original_fields

# ===== Step 4: Define your model and test functions =====
def dequantize_wrapper(module):
    # module is a Linear4bit instance; extract its weight and quant_state.
    return new_fast_dequantize(module.weight, module.weight.quant_state)


def bnb_Linear4bit(hd, m, dtype=torch.float16):
    return Linear4bit(
        hd, m, bias=None,
        compute_dtype=dtype,
        compress_statistics=True,
        quant_type="nf4",
    )

def assert_correct_bnb(weight, dtype):
    assert(weight.weight.dtype == torch.uint8)
    assert(weight.weight.quant_state.dtype == dtype)
    assert(weight.weight.quant_state.absmax.dtype == torch.uint8)
    assert(weight.weight.quant_state.code.dtype == torch.float32)
    assert(weight.weight.quant_state.offset.dtype == torch.float32)
    assert(weight.weight.quant_state.blocksize == 64)
    assert(weight.weight.quant_state.state2.absmax.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.code.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.blocksize == 256)

import torch.nn as nn
class MLP(nn.Module):
    def __init__(self, hd=4096, m=14336, dtype=torch.float16):
        super().__init__()
        self.gate_proj = bnb_Linear4bit(hd, m, dtype=dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype=dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype=dtype).to("cuda")
        self.gate_proj.weight.quant_state.dtype = dtype
        self.up_proj.weight.quant_state.dtype = dtype
        self.down_proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]
    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    up = X @ fx(mlp.up_proj).t()
    gate = X @ fx(mlp.gate_proj).t()
    h = mlp.act_fn(gate) * up
    down = h @ fx(mlp.down_proj).t()
    return down

def mlp_dequantize(X, mlp, fx):
    a = fx(mlp.up_proj).t(); torch.cuda.synchronize()
    b = fx(mlp.gate_proj).t(); torch.cuda.synchronize()
    c = fx(mlp.down_proj).t(); torch.cuda.synchronize()
    return a, b, c

def test_dequantize(dequantize_fx):
    import time
    elapsed = 0
    options = [
        (2, 3333, 2048, 8192, 3407, torch.float16),
        (5, 777, 1024, 4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd=hd, m=m, dtype=dt)
        X = torch.randn((bsz, qlen, hd), device="cuda", dtype=dt)
        torch.cuda.synchronize()

        # Warmup: use dequantize_wrapper everywhere.
        for _ in range(2):
            # Use the wrapper so that a tensor is returned.
            assert_same(mlp_forward(X, mlp, dequantize_wrapper), mlp(X), _F(_C()), dt)
            assert_correct_bnb(mlp.up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_wrapper)
            A, B, C = mlp_dequantize(X, mlp, dequantize_wrapper)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        torch.cuda.synchronize()
        start = time.time()
        for _ in range(1000):
            mlp_dequantize(X, mlp, dequantize_wrapper)
        elapsed += time.time() - start
    return elapsed

# ===== Step 5: Run your tests =====
elapsed = test_dequantize(dequantize_wrapper)
print("Elapsed time for your kernel: ", elapsed)



🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Elapsed time for your kernel:  5.022521495819092


In [5]:
import torch
from triton import jit, cdiv
import triton.language as tl

# Define a tile size: each thread will process TILE outputs.
TILE = 16

@jit
def _your_dequantize_nf4_kernel_tiled(
    weight_ptr,             # pointer to the uint8 weight tensor (each byte packs 2 nf4 values)
    absmax_ptr,             # pointer to absmax tensor (dtype=torch.uint8)
    nested_absmax_ptr,      # pointer to nested_absmax tensor (dtype=torch.float32)
    out_ptr,                # pointer to output tensor (fp16 or bf16)
    offset,                 # scalar offset (float)
    num_elements: tl.constexpr,   # total number of dequantized outputs
    blocksize: tl.constexpr,      # block size for weight dequantization (e.g., 64)
    nested_ratio: tl.constexpr,   # number of weight blocks per nested scaling factor (e.g., 256/64 = 4)
    BLOCK_SIZE: tl.constexpr      # total outputs processed per kernel instance (must be a multiple of TILE)
):
    # Each kernel instance processes BLOCK_SIZE outputs.
    # Each thread processes TILE outputs.
    THREADS = BLOCK_SIZE // TILE

    # Get thread index (per kernel instance) as int32.
    tid = tl.arange(0, THREADS)
    tid = tl.cast(tid, tl.int32)
    # Compute a tile of global indices for each thread.
    base = tl.cast(tid, tl.int32) * TILE  # starting index for each thread within the kernel instance
    # Global base index for this kernel instance.
    pid = tl.cast(tl.program_id(0), tl.int32)
    global_base = pid * BLOCK_SIZE
    # Each thread will compute indices: global_base + base + i, for i in [0, TILE)
    # We'll loop over i.
    for i in range(TILE):
        idx = global_base + base + i  # each idx is int32, shape (THREADS,)
        m = idx < tl.cast(num_elements, tl.int32)  # mask

        # Compute weight block index: each block covers 'blocksize' outputs.
        bs = int(blocksize)   # e.g., 64
        nr = int(nested_ratio)  # e.g., 4
        wb = tl.floordiv(idx, bs)       # weight block index (int32)
        nest_idx = tl.floordiv(wb, nr)    # nested index (int32)

        # Asynchronously copy scaling factors from global memory.
        # (If tl.copy_async is not available, this call will fallback to tl.load.)
        abs_val = tl.cast(tl.copy_async(absmax_ptr + wb), tl.float32)
        nest_val = tl.cast(tl.copy_async(nested_absmax_ptr + nest_idx), tl.float32)
        tl.wait_async()  # Ensure async copy has completed.
        scale = (abs_val / 255.0) * (nest_val / 255.0) + offset

        # For the weight: each uint8 packs 2 nf4 values.
        # Compute the byte index: each output corresponds to an nf4 value.
        byte_index = tl.floordiv(idx, 2)
        byte_index = tl.cast(byte_index, tl.int32)
        weight_byte = tl.load(weight_ptr + byte_index, mask=m)
        # Determine if this output comes from lower or upper 4 bits.
        is_even = (idx % 2) == 0
        lower_val = weight_byte & 0x0F
        upper_val = (weight_byte >> 4) & 0x0F
        quant_val = tl.where(is_even, lower_val, upper_val)
        # Dequantize: normalize to [0,1] and multiply by scale.
        dequant = (tl.cast(quant_val, tl.float32) / 15.0) * scale

        tl.store(out_ptr + idx, dequant, mask=m)

###############################################
# Launcher and Wrapper Functions
###############################################
def _your_dequantize_nf4(weight, quant_state):
    shape = quant_state.shape
    total_elements = shape[0] * shape[1]
    out = torch.empty(total_elements, dtype=quant_state.dtype, device=weight.device)

    blocksize = quant_state.blocksize         # e.g., 64
    nested_ratio = quant_state.state2.blocksize // blocksize  # e.g., 256 // 64 = 4
    offset = float(quant_state.offset.item()) if isinstance(quant_state.offset, torch.Tensor) else float(quant_state.offset)

    weight_ptr = weight.data_ptr()
    absmax_ptr = quant_state.absmax.data_ptr()
    nested_absmax_ptr = quant_state.state2.absmax.data_ptr()
    out_ptr = out.data_ptr()

    # Choose a BLOCK_SIZE that is a multiple of TILE.
    BLOCK_SIZE = 1024
    grid = lambda meta: (cdiv(total_elements, meta['BLOCK_SIZE']),)

    _your_dequantize_nf4_kernel_tiled[grid](
        weight_ptr,
        absmax_ptr,
        nested_absmax_ptr,
        out_ptr,
        offset,
        total_elements,
        blocksize,
        nested_ratio,
        BLOCK_SIZE,
    )

    out = out.view(shape)
    if weight.shape[0] == 1:
        out = out.t()
    return out

def your_dequantize_nf4(weight_obj):
    """
    weight_obj: an object with attributes:
         - weight_obj.weight.data         (torch.Tensor, dtype=torch.uint8)
         - weight_obj.weight.quant_state  (with the quantization state as described)
    """
    return _your_dequantize_nf4(weight_obj.weight.data, weight_obj.weight.quant_state)



elapsed = test_dequantize(your_dequantize_nf4)
print("Elapsed time for your kernel: ", elapsed)

print(test_dequantize(dequantize_wrapper) / test_dequantize(your_dequantize_nf4))

Elapsed time for your kernel:  5.098713159561157
1.0129438812848506
