In [1]:
!pip install -q bitsandbytes
!pip install -q triton
!pip install -q transformers
!pip install -q peft
!pip install -q unsloth

import torch
import time
import torch.nn as nn
import numpy as np

# Helper functions for testing
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

def _F(s):
    return f"File '{s}'"

def _C():
    import traceback
    return traceback.extract_stack()[-3][0]

def assert_same(a, b, file, dtype):
    if a.shape != b.shape:
        print(f"Shape mismatch: {a.shape} vs {b.shape}")
        assert(False)
    if dtype == torch.float16:
        assert(torch.allclose(a, b, rtol = 1e-3, atol = 3e-3))
    else:
        assert(torch.allclose(a, b, rtol = 1e-3, atol = 3e-3))

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m107.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m82.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m56.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
def unsloth_dequantize(weight):
    return fast_dequantize(weight.weight, weight.weight.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    return Linear4bit(
        hd, m, bias = None,
        compute_dtype       = dtype,
        compress_statistics = True,
        quant_type          = "nf4",
    )

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    assert(weight.weight.dtype == torch.uint8)
    assert(weight.weight.quant_state.dtype == dtype)
    assert(weight.weight.quant_state.absmax.dtype == torch.uint8)
    assert(weight.weight.quant_state.code.dtype == torch.float32)
    assert(weight.weight.quant_state.offset.dtype == torch.float32)
    assert(weight.weight.quant_state.blocksize == 64)
    assert(weight.weight.quant_state.state2.absmax.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.code.dtype == torch.float32)
    assert(weight.weight.quant_state.state2.blocksize == 256)

class MLP(nn.Module):
    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        self.gate_proj = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.up_proj   = bnb_Linear4bit(hd, m, dtype = dtype).to("cuda")
        self.down_proj = bnb_Linear4bit(m, hd, dtype = dtype).to("cuda")
        # [NEW] as at 18th Feb 2025
        self.gate_proj.weight.quant_state.dtype = dtype
        self.up_proj  .weight.quant_state.dtype = dtype
        self.down_proj.weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]
    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

def mlp_forward(X, mlp, fx):
    up   = X @ fx(mlp.  up_proj).t()
    gate = X @ fx(mlp.gate_proj).t()
    h = mlp.act_fn(gate) * up
    down = h @ fx(mlp.down_proj).t()
    return down

def mlp_dequantize(X, mlp, fx):
    a = fx(mlp.  up_proj).t(); torch.cuda.synchronize()
    b = fx(mlp.gate_proj).t(); torch.cuda.synchronize()
    c = fx(mlp.down_proj).t(); torch.cuda.synchronize()
    return a, b, c

def test_dequantize(dequantize_fx):
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        (5,  777, 1024,  4096, 3409, torch.bfloat16),
        (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(1000): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.time() - start
    return elapsed


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth.kernels.utils import fast_dequantize


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [12]:
from triton import jit
import torch
import triton
import triton.language as tl

# Global buffers for memory reuse
WEIGHT_BUFFERS = {}

@triton.jit
def _dequantize_nf4_kernel(
    # Input pointers
    weight_ptr,
    absmax_ptr,
    code_ptr,
    offset_ptr,
    state2_absmax_ptr,
    state2_code_ptr,
    output_ptr,

    # Dimensions and parameters
    n_elements,
    blocksize,
    state2_blocksize,

    # Block size for parallelism
    BLOCK_SIZE: tl.constexpr,
):
    """Simple linear kernel for NF4 dequantization."""
    # Get program ID
    pid = tl.program_id(0)

    # Calculate start offset
    offset = pid * BLOCK_SIZE

    # Create a simple 1D processing range
    offsets = offset + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load quantized weights
    weight_ptrs = weight_ptr + offsets
    weights = tl.load(weight_ptrs, mask=mask, other=0)

    # Extract NF4 values (lower 4 bits of each byte)
    nf4_indices = weights & 0xF

    # Calculate block indices
    block_indices = offsets // blocksize
    state2_block_indices = offsets // state2_blocksize

    # Load quantization parameters
    absmax_ptrs = absmax_ptr + block_indices
    absmax_indices = tl.load(absmax_ptrs, mask=mask, other=0)

    offset_ptrs = offset_ptr + block_indices
    offset_vals = tl.load(offset_ptrs, mask=mask, other=0)

    state2_absmax_ptrs = state2_absmax_ptr + state2_block_indices
    state2_absmax = tl.load(state2_absmax_ptrs, mask=mask, other=0)

    state2_code_ptrs = state2_code_ptr + state2_block_indices
    state2_code = tl.load(state2_code_ptrs, mask=mask, other=0)

    # Load code values
    code_ptrs = code_ptr + nf4_indices
    code_vals = tl.load(code_ptrs, mask=None, other=0)

    absmax_code_ptrs = code_ptr + absmax_indices
    absmax_code_vals = tl.load(absmax_code_ptrs, mask=None, other=0)

    # Perform dequantization
    dequantized_absmax = absmax_code_vals * state2_absmax * state2_code
    dequantized = code_vals * dequantized_absmax + offset_vals

    # Store result
    output_ptrs = output_ptr + offsets
    tl.store(output_ptrs, dequantized, mask=mask)

def _your_dequantize_nf4(weight, quant_state, out=None, use_global_buffer=True):
    """
    Dequantize NF4 weights using a simplified approach for compatibility.
    """
    # Get device and shape
    device = weight.device
    shape = weight.shape
    n_elements = weight.numel()

    # Extract quantization parameters
    absmax = quant_state.absmax
    dtype = quant_state.dtype
    blocksize = quant_state.blocksize
    offset = quant_state.offset
    state2 = quant_state.state2
    absmax2 = state2.absmax
    code2 = state2.code
    blocksize2 = state2.blocksize

    # Reuse buffer if requested
    if use_global_buffer:
        global WEIGHT_BUFFERS
        device_idx = device.index if hasattr(device, 'index') else 0

        if device_idx not in WEIGHT_BUFFERS or WEIGHT_BUFFERS[device_idx] is None:
            WEIGHT_BUFFERS[device_idx] = torch.empty(n_elements, dtype=dtype, device=device)

        buffer = WEIGHT_BUFFERS[device_idx]
        if n_elements > buffer.numel():
            buffer.resize_(n_elements)

        out = buffer[:n_elements].view(shape)
    else:
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=device)

    # Get data pointers
    weight_ptr = weight.data_ptr()
    absmax_ptr = absmax.data_ptr()
    code_ptr = quant_state.code.data_ptr()
    offset_ptr = offset.data_ptr()
    state2_absmax_ptr = absmax2.data_ptr()
    state2_code_ptr = code2.data_ptr()
    output_ptr = out.data_ptr()

    # Use a simple 1D grid for better compatibility
    BLOCK_SIZE = 1024  # Good balance for T4
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    # Launch kernel
    _dequantize_nf4_kernel[grid](
        weight_ptr,
        absmax_ptr,
        code_ptr,
        offset_ptr,
        state2_absmax_ptr,
        state2_code_ptr,
        output_ptr,
        n_elements,
        blocksize,
        blocksize2,
        BLOCK_SIZE,
    )

    # Handle transposition
    is_transposed = (weight.shape[0] == 1)
    return out.t() if is_transposed else out

def your_dequantize_nf4(weight):
    """
    Entry point function for dequantizing weights.
    """
    return _your_dequantize_nf4(weight.weight.data, weight.weight.quant_state, use_global_buffer=True)

In [13]:
### TEST IT BELOW:
test_dequantize(your_dequantize_nf4)

### CALCULATE SPEEDUP (hopefully 1.15x faster or more)
# test_dequantize(unsloth_dequantize) / test_dequantize(your_dequantize_nf4)

CompilationError: at 32:14:
    pid = tl.program_id(0)

    # Calculate start offset
    offset = pid * BLOCK_SIZE

    # Create a simple 1D processing range
    offsets = offset + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load quantized weights
    weight_ptrs = weight_ptr + offsets
    weights = tl.load(weight_ptrs, mask=mask, other=0)
              ^