In [1]:
# Environment setup for the NF4 dequantization benchmark in this notebook.
# `--no-deps` tells pip not to install the dependencies of the listed packages.
# NOTE(review): only xformers is pinned; the recorded run resolved
# bitsandbytes 0.46.0 and trl 0.18.1 — consider pinning for reproducibility.
# NOTE(review): the recorded output of a later cell warns that this xformers
# build targets PyTorch 2.5.1 while 2.6.0 is installed — confirm the pin.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth
!pip install transformers tf-keras

Collecting bitsandbytes
  Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting xformers==0.0.29
  Downloading xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting trl
  Downloading trl-0.18.1-py3-none-any.whl.metadata (11 kB)
Downloading xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl (15.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.3/15.3 MB[0m [31m80.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl (67.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading trl-0.18.1-py3-none-any.whl (366 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m366.3/366.3 kB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xformers, trl, bitsandbytes
Successfully installed bitsandbytes-0.46.0 trl-0.18.1 xformers-0

In [2]:
# Helpful functions used through the entire notebook
import torch
import torch.nn as nn
from transformers import set_seed
import time
import inspect
import os
# Compute capability (major, minor) reported by torch for the current CUDA device.
major_version, minor_version = torch.cuda.get_device_capability()
# Flag is true for compute capability 8.x and newer.
# NOTE(review): presumably because bf16 is natively supported from SM 8.0 — confirm.
HAS_BFLOAT16 = (major_version >= 8)
from inspect import currentframe as _C, getframeinfo
# `_F(_C())` evaluates to the line number of the call site; used to tag assertion failures.
_F = lambda c: getframeinfo(c).lineno # Gets line number
# Prints `x` wrapped in the ANSI escape codes for red text.
WARN = lambda x: print(f"\033[31m{x}\033[0m") # Red colored warnings

# https://stackoverflow.com/questions/18425225/getting-the-name-of-a-variable-as-a-string
def NAME(var):
    callers_local_vars = inspect.currentframe().f_back.f_locals.items()
    names = [var_name for var_name, var_val in callers_local_vars if var_val is var]
    return names[0] if len(names) != 0 else ""

def assert_same(x, y, line, dtype):
    """Assert that ``x`` has dtype ``dtype`` and is numerically close to ``y``.

    Args:
        x: Tensor under test.
        y: Reference tensor.
        line: Line number (or any label) included in the failure message.
        dtype: Expected ``x.dtype``.

    Raises:
        AssertionError: if ``x.dtype`` differs from ``dtype``.
        RuntimeError: if ``torch.testing.assert_close`` rejects the pair
            (atol = rtol = 1e-1, strides must match). The original error is
            attached as ``__cause__``.
    """
    assert x.dtype == dtype, f"Expected dtype {dtype}, got {x.dtype}"
    try:
        torch.testing.assert_close(x, y, check_stride = True, atol=1e-1, rtol=1e-1)
    except Exception as error:
        # Look the tensors up in the *caller's* locals. Resolving them from
        # inside this function would always yield the parameter names "x"/"y".
        caller_locals = inspect.currentframe().f_back.f_locals

        def _caller_name(value):
            # First caller-local name bound to this exact object, or "".
            return next((k for k, v in caller_locals.items() if v is value), "")

        raise RuntimeError(
            f"Failed allclose at line [{line}]: {_caller_name(x)}, {_caller_name(y)}\n{str(error)}"
        ) from error

# Sets the HF_HUB_ENABLE_HF_TRANSFER environment variable for this process.
# NOTE(review): this runs after `transformers` was imported above; presumably
# huggingface_hub reads the variable at import time, in which case it should
# be set before those imports — confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [3]:
from bitsandbytes.nn import Linear4bit
from transformers.activations import ACT2FN
from unsloth.kernels.utils import fast_dequantize
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
def unsloth_dequantize(weight):
    """Dequantize a layer with Unsloth's ``fast_dequantize``.

    ``weight`` is the layer object; its ``.weight`` parameter and that
    parameter's ``quant_state`` are forwarded to ``fast_dequantize``.
    """
    packed_param = weight.weight
    return fast_dequantize(packed_param, packed_param.quant_state)

def bnb_Linear4bit(hd, m, dtype = torch.float16):
    """Build a bias-free ``Linear4bit(hd, m)`` using NF4 with compressed statistics.

    ``dtype`` is passed through as the layer's ``compute_dtype``.
    """
    quantization_options = {
        "bias": None,
        "compute_dtype": dtype,
        "compress_statistics": True,
        "quant_type": "nf4",
    }
    return Linear4bit(hd, m, **quantization_options)

# [NEW] as at 18th Feb 2025
def assert_correct_bnb(weight, dtype):
    """Assert the quantized layout this notebook expects from a 4-bit layer.

    Checks, in order: the packed weight is ``uint8``; the quant state reports
    ``dtype``; the first-level ``absmax`` is ``uint8`` with ``float32`` code
    and offset and a block size of 64; the nested ``state2`` holds
    ``float32`` absmax/code with a block size of 256.

    Raises:
        AssertionError: on the first expectation that does not hold.
    """
    packed_param = weight.weight
    assert packed_param.dtype == torch.uint8

    state = packed_param.quant_state
    assert state.dtype == dtype
    assert state.absmax.dtype == torch.uint8
    assert state.code.dtype == torch.float32
    assert state.offset.dtype == torch.float32
    assert state.blocksize == 64

    nested_state = state.state2
    assert nested_state.absmax.dtype == torch.float32
    assert nested_state.code.dtype == torch.float32
    assert nested_state.blocksize == 256

class MLP(nn.Module):
    """Gated MLP built from three 4-bit linear layers placed on CUDA.

    ``forward`` computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``
    with ``act_fn`` taken from ``ACT2FN["silu"]``.
    """

    def __init__(self, hd = 4096, m = 14336, dtype = torch.float16):
        super().__init__()
        # Creation order (gate, up, down) is kept fixed so that seeded
        # initialisation yields the same layers on every run.
        projection_shapes = (
            ("gate_proj", hd, m),
            ("up_proj",   hd, m),
            ("down_proj", m, hd),
        )
        for attr_name, in_features, out_features in projection_shapes:
            layer = bnb_Linear4bit(in_features, out_features, dtype = dtype).to("cuda")
            setattr(self, attr_name, layer)
        # After all layers exist, record the requested dtype on each quant state.
        for attr_name, _, _ in projection_shapes:
            getattr(self, attr_name).weight.quant_state.dtype = dtype
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        activated_gate = self.act_fn(self.gate_proj(x))
        up_branch = self.up_proj(x)
        return self.down_proj(activated_gate * up_branch)

def mlp_forward(X, mlp, fx):
    """Run the gated MLP using weights produced by ``fx``.

    ``fx`` is applied to each of ``mlp.up_proj``, ``mlp.gate_proj`` and
    ``mlp.down_proj``; each result is transposed and used as a plain matmul
    operand. Returns ``(act_fn(X @ gate.T) * (X @ up.T)) @ down.T``.
    """
    up_weight_t = fx(mlp.up_proj).t()
    up_branch = X @ up_weight_t
    gate_weight_t = fx(mlp.gate_proj).t()
    gate_branch = X @ gate_weight_t
    hidden = mlp.act_fn(gate_branch) * up_branch
    down_weight_t = fx(mlp.down_proj).t()
    return hidden @ down_weight_t

def mlp_dequantize(X, mlp, fx):
    """Dequantize the three projections with ``fx`` and return their transposes.

    Projections are processed in the order up, gate, down, with a
    ``torch.cuda.synchronize()`` after each one. ``X`` is accepted for
    signature parity with ``mlp_forward`` but is not read.
    """
    transposed_weights = []
    for projection_name in ("up_proj", "gate_proj", "down_proj"):
        transposed_weights.append(fx(getattr(mlp, projection_name)).t())
        torch.cuda.synchronize()
    return tuple(transposed_weights)

def test_dequantize(dequantize_fx, num_iterations = 1000):
    """Validate ``dequantize_fx`` and time it on the benchmark configurations.

    For every configuration a seeded ``MLP`` is built, then two warmup rounds
    check that:
      * ``mlp_forward`` with ``dequantize_fx`` matches the module's own forward,
      * the bitsandbytes quant-state layout is as expected, and
      * each dequantized projection matches ``unsloth_dequantize``.
    Afterwards ``mlp_dequantize`` is called ``num_iterations`` times and timed.

    Args:
        dequantize_fx: Callable taking a 4-bit linear layer and returning its
            dequantized weight.
        num_iterations: Timed calls per configuration (default 1000, the
            value previously hard-coded).

    Returns:
        Total elapsed seconds over all configurations.

    Raises:
        AssertionError / RuntimeError: from the warmup checks.
    """
    elapsed = 0
    options = [
        (2, 3333, 2048,  8192, 3407, torch.float16),
        # (5,  777, 1024,  4096, 3409, torch.bfloat16),
        # (3, 2048, 4096, 14336, 3408, torch.bfloat16),
    ]
    for (bsz, qlen, hd, m, seed, dt) in options:
        set_seed(seed)
        torch.set_default_dtype(torch.float32)
        mlp = MLP(hd = hd, m = m, dtype = dt)
        X = torch.randn((bsz, qlen, hd), device = "cuda", dtype = dt)
        torch.cuda.synchronize()

        # Warmup (also compiles any JIT kernels before the timed section)
        for _ in range(2):
            assert_same( mlp_forward(X, mlp, dequantize_fx), mlp(X), _F(_C()), dt)
            # [NEW] as at 18th Feb 2025
            assert_correct_bnb(mlp.  up_proj, dt)
            assert_correct_bnb(mlp.gate_proj, dt)
            assert_correct_bnb(mlp.down_proj, dt)
            a, b, c = mlp_dequantize(X, mlp, dequantize_fx)
            A, B, C = mlp_dequantize(X, mlp, unsloth_dequantize)
            assert_same(a, A, _F(_C()), dt)
            assert_same(b, B, _F(_C()), dt)
            assert_same(c, C, _F(_C()), dt)

        # Benchmarking. perf_counter is monotonic and high resolution, so the
        # measurement cannot be skewed by system clock adjustments.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iterations): mlp_dequantize(X, mlp, dequantize_fx)
        elapsed += time.perf_counter() - start
    return elapsed


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth.kernels.utils import fast_dequantize


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.6.0+cu124)
    Python  3.11.11 (you have 3.11.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!


In [4]:
from unsloth.kernels.utils import fast_dequantize
def unsloth_dequantize(weight):
    """Reference dequantizer: forward the layer's packed weight and its quant state to ``fast_dequantize``."""
    quantized_param = weight.weight
    state = quantized_param.quant_state
    return fast_dequantize(quantized_param, state)
# Baseline: total seconds spent in the timed loop using Unsloth's dequantizer.
test_dequantize(unsloth_dequantize)

1.1453566551208496

In [5]:
from peft.utils.integrations import dequantize_module_weight as peft_dequantize
# Same benchmark with PEFT's dequantizer, for comparison with the Unsloth baseline.
test_dequantize(peft_dequantize)

1.372269868850708

In [15]:
import torch
import triton
import triton.language as tl
import math

# --- 1. Triton JIT Kernel (V13c - Cautious internal changes for speed) ---
@triton.jit
def _your_dequantize_nf4_kernel( # V13 kernel base
    w_ptr, abs_idx_ptr, offset_ptr, abs2_scales_ptr, code2_ptr, nf4_code_ptr, output_ptr,
    TOTAL_ELEMENTS_IN_OUTPUT_TENSOR: tl.constexpr,
    BLOCK_SIZE_BYTES_PER_CHUNK: tl.constexpr, # Not read by the body; kept so existing keyword callers still work
    BLOCK_SIZE_ELEMENTS_PER_CHUNK: tl.constexpr,
    NUM_GROUPS_PER_CHUNK: tl.constexpr,
    ELEMENTS_PER_GROUP_CONST: tl.constexpr, # NF4 block size, e.g., 64
    LOG2_L2_BLOCK_SIZE_CONST_KERNEL: tl.constexpr,
    gsize_num_chunks: tl.constexpr
):
    """Dequantize one chunk of a doubly-quantized 4-bit weight.

    Program ``pid`` produces output elements
    ``[pid * BLOCK_SIZE_ELEMENTS_PER_CHUNK, (pid + 1) * BLOCK_SIZE_ELEMENTS_PER_CHUNK)``,
    masked against ``TOTAL_ELEMENTS_IN_OUTPUT_TENSOR``.

    Pointers:
        w_ptr:           packed bytes, two 4-bit codes per byte (even element
                         index -> high nibble, odd -> low nibble).
        abs_idx_ptr:     one index per group, used to look up ``code2_ptr``.
        offset_ptr:      single scalar added to every group scale.
        abs2_scales_ptr: second-level scales, indexed by
                         ``group >> LOG2_L2_BLOCK_SIZE_CONST_KERNEL``.
        code2_ptr:       lookup table for the first-level group scale.
        nf4_code_ptr:    lookup table mapping a 4-bit code to its value.
        output_ptr:      destination, one value per element.

    Per element: ``out = nf4_code[nibble] * (abs2[group >> log2] * code2[abs_idx[group]] + offset)``.

    Requirements (not checked here): ``NUM_GROUPS_PER_CHUNK * ELEMENTS_PER_GROUP_CONST``
    must equal ``BLOCK_SIZE_ELEMENTS_PER_CHUNK`` for the final reshape, and both
    ``tl.arange`` extents must be powers of two.
    """
    pid = tl.program_id(0)
    if pid >= gsize_num_chunks: return

    log2_l2_block_size = LOG2_L2_BLOCK_SIZE_CONST_KERNEL

    # --- Per-group scale: abs2[group >> log2] * code2[abs_idx[group]] + offset ---
    chunk_element_start_offset = pid * BLOCK_SIZE_ELEMENTS_PER_CHUNK
    group_arange_local = tl.arange(0, NUM_GROUPS_PER_CHUNK)
    absmax_group_indices_potential = (chunk_element_start_offset // ELEMENTS_PER_GROUP_CONST) + group_arange_local
    # A group is live when its first element lies inside the tensor.
    group_mask = (absmax_group_indices_potential * ELEMENTS_PER_GROUP_CONST) < TOTAL_ELEMENTS_IN_OUTPUT_TENSOR

    quantized_absmax_indices = tl.load(abs_idx_ptr + absmax_group_indices_potential, mask=group_mask, other=0, eviction_policy="evict_first")
    dequantized_l1_scales = tl.load(code2_ptr + quantized_absmax_indices.to(tl.int32), mask=group_mask, other=0.0, eviction_policy="evict_last")

    absmax_l2_group_indices_potential = absmax_group_indices_potential >> log2_l2_block_size
    l2_scales = tl.load(abs2_scales_ptr + absmax_l2_group_indices_potential, mask=group_mask, other=0.0, eviction_policy="evict_last")

    offset_val = tl.load(offset_ptr + 0)
    intermediate_product_scales = l2_scales * dequantized_l1_scales
    final_group_scales_masked = intermediate_product_scales + offset_val # Shape: (NUM_GROUPS_PER_CHUNK,)

    # --- Per-element decode (vectorised over the whole chunk, no loop) ---
    element_arange_pid_local = tl.arange(0, BLOCK_SIZE_ELEMENTS_PER_CHUNK)
    global_element_indices = chunk_element_start_offset + element_arange_pid_local
    element_op_mask = global_element_indices < TOTAL_ELEMENTS_IN_OUTPUT_TENSOR

    # Two elements share a byte: even index -> high nibble, odd index -> low nibble.
    byte_index_global_for_element = global_element_indices // 2
    is_low_nibble_flag = (global_element_indices % 2) != 0

    # Neighbouring elements request the same byte; the mask keeps tail
    # elements from reading past the end of w_ptr.
    packed_byte_for_element = tl.load(w_ptr + byte_index_global_for_element, mask=element_op_mask, other=0)

    nibble_shift = tl.where(is_low_nibble_flag, 0, 4)
    quantized_idx_for_element = (packed_byte_for_element >> nibble_shift) & 0x0F

    # Codebook lookup for the 4-bit code.
    dequant_val_for_element = tl.load(nf4_code_ptr + quantized_idx_for_element.to(tl.int32), mask=element_op_mask, other=0.0)

    # Expand (NUM_GROUPS,) scales to (BLOCK_SIZE_ELEMENTS,) via
    # reshape -> broadcast -> reshape, so element i receives the scale of
    # local group i // ELEMENTS_PER_GROUP_CONST without an explicit gather.
    scales_reshaped_for_broadcast = tl.reshape(final_group_scales_masked, (NUM_GROUPS_PER_CHUNK, 1))
    scales_broadcasted_to_elements = tl.broadcast_to(scales_reshaped_for_broadcast, (NUM_GROUPS_PER_CHUNK, ELEMENTS_PER_GROUP_CONST))
    element_scales_vector = tl.reshape(scales_broadcasted_to_elements, (BLOCK_SIZE_ELEMENTS_PER_CHUNK,))

    # Out-of-range lanes get a zero scale (the store below is masked as well).
    final_scale_for_element = tl.where(element_op_mask, element_scales_vector, 0.0)

    scaled_element_output = dequant_val_for_element * final_scale_for_element

    tl.store(output_ptr + global_element_indices, scaled_element_output, mask=element_op_mask)

    return

# --- Python Launcher for the V13c Kernel ---
# (This remains largely the same as the V13 launcher, just calls the new kernel name)
def _your_dequantize_nf4(
    weight_data_actual,
    quant_state_actual,
    CHUNK_SIZE_ELEMENTS_Br: int,
    num_warps_to_use: int,
    num_stages_to_use: int
    ):
    """Allocate the output tensor and launch ``_your_dequantize_nf4_kernel``.

    Args:
        weight_data_actual: Packed 4-bit weight tensor (kernel ``w_ptr``).
        quant_state_actual: Quant state providing ``shape``, ``dtype``,
            ``blocksize``, ``absmax``, ``offset``, ``code`` and a nested
            ``state2`` with ``blocksize``, ``absmax`` and ``code``.
        CHUNK_SIZE_ELEMENTS_Br: Requested number of output elements per
            program. It is shrunk for small tensors and rounded up to a
            power of two that is at least one quantization group.
        num_warps_to_use: Triton ``num_warps`` launch option.
        num_stages_to_use: Triton ``num_stages`` launch option.

    Returns:
        Tensor of shape ``quant_state_actual.shape`` on the weight's device.
        The dtype is ``quant_state_actual.dtype``, except that bfloat16 is
        replaced by float16 on compute capability 7.5 devices.

    Raises:
        ValueError: if ``state2.blocksize`` is missing or non-positive, or
            if either block size is not a power of two.
    """
    device = weight_data_actual.device
    output_shape = quant_state_actual.shape
    total_elements_in_output = output_shape.numel()

    # On SM 7.5 (T4) a bfloat16 request is served as float16 instead.
    final_out_dtype = quant_state_actual.dtype
    if final_out_dtype == torch.bfloat16 and device.type == 'cuda':
        props = torch.cuda.get_device_properties(device)
        if props.major == 7 and props.minor == 5:
            final_out_dtype = torch.float16

    output_tensor = torch.empty(output_shape, dtype=final_out_dtype, device=device, requires_grad=False)
    if total_elements_in_output == 0:
        # Nothing to dequantize; skip validation and the kernel launch.
        return output_tensor

    ELEMENTS_PER_GROUP_CONST = quant_state_actual.blocksize
    if ELEMENTS_PER_GROUP_CONST <= 0 or (ELEMENTS_PER_GROUP_CONST & (ELEMENTS_PER_GROUP_CONST - 1)) != 0:
        raise ValueError(f"quant_state.blocksize must be a power of 2 and > 0. Got {ELEMENTS_PER_GROUP_CONST}")

    nested_state = getattr(quant_state_actual, 'state2', None)
    L2_BLOCK_SIZE_CONST_PYTHON = getattr(nested_state, 'blocksize', None)
    if L2_BLOCK_SIZE_CONST_PYTHON is None or L2_BLOCK_SIZE_CONST_PYTHON <= 0:
        raise ValueError("quant_state.state2.blocksize is missing or invalid.")
    if (L2_BLOCK_SIZE_CONST_PYTHON & (L2_BLOCK_SIZE_CONST_PYTHON - 1)) != 0:
        raise ValueError(f"L2_BLOCK_SIZE_CONST_PYTHON must be a power of 2 and > 0. Got {L2_BLOCK_SIZE_CONST_PYTHON}")

    # Small tensors: shrink the chunk to the groups actually present.
    current_br = CHUNK_SIZE_ELEMENTS_Br
    if total_elements_in_output < current_br:
        current_br = triton.cdiv(total_elements_in_output, ELEMENTS_PER_GROUP_CONST) * ELEMENTS_PER_GROUP_CONST
    # The kernel builds tl.arange(0, current_br) and
    # tl.arange(0, current_br // group); both extents must be powers of two.
    # Rounding only a multiple of the group size (e.g. 192) is not enough.
    # With a power-of-two group size the rounded chunk is still a whole
    # number of groups; surplus lanes are masked inside the kernel.
    current_br = triton.next_power_of_2(max(current_br, ELEMENTS_PER_GROUP_CONST))

    NUM_GROUPS_PER_CHUNK_Br_div_64 = current_br // ELEMENTS_PER_GROUP_CONST
    gsize_num_chunks_val = triton.cdiv(total_elements_in_output, current_br)
    grid = (gsize_num_chunks_val,)

    _your_dequantize_nf4_kernel[grid](
        weight_data_actual, quant_state_actual.absmax, quant_state_actual.offset,
        quant_state_actual.state2.absmax, quant_state_actual.state2.code, quant_state_actual.code,
        output_tensor,
        TOTAL_ELEMENTS_IN_OUTPUT_TENSOR=total_elements_in_output,
        BLOCK_SIZE_BYTES_PER_CHUNK=current_br // 2, # accepted by the kernel signature but not read there
        BLOCK_SIZE_ELEMENTS_PER_CHUNK=current_br,
        NUM_GROUPS_PER_CHUNK=NUM_GROUPS_PER_CHUNK_Br_div_64,
        ELEMENTS_PER_GROUP_CONST=ELEMENTS_PER_GROUP_CONST,
        LOG2_L2_BLOCK_SIZE_CONST_KERNEL=int(math.log2(L2_BLOCK_SIZE_CONST_PYTHON)),
        gsize_num_chunks=gsize_num_chunks_val,
        num_warps=num_warps_to_use,
        num_stages=num_stages_to_use
    )
    return output_tensor

# --- 3. Top-Level Wrapper (Set your best launch parameters) ---
def your_dequantize_nf4(weight_param):
    """Dequantize a 4-bit layer with the Triton kernel using fixed launch settings.

    ``weight_param`` must expose ``.weight`` with ``.data`` and
    ``.quant_state``; otherwise a ``ValueError`` is raised. The launch uses a
    chunk of 8192 elements, 16 warps and 1 stage.
    """
    launch_settings = {
        "CHUNK_SIZE_ELEMENTS_Br": 8192,
        "num_warps_to_use": 16,
        "num_stages_to_use": 1,
    }

    has_expected_layout = (
        hasattr(weight_param, 'weight')
        and hasattr(weight_param.weight, 'data')
        and hasattr(weight_param.weight, 'quant_state')
    )
    if not has_expected_layout:
        raise ValueError("Input 'weight_param' is not the expected Linear4bit layer format or lacks necessary attributes.")

    packed_bytes = weight_param.weight.data
    state = weight_param.weight.quant_state
    return _your_dequantize_nf4(packed_bytes, state, **launch_settings)

In [19]:
### Correctness + timing for the custom kernel on its own (currently disabled):
# test_dequantize(your_dequantize_nf4)

### Speedup = Unsloth seconds / custom-kernel seconds; values above 1 mean the
### custom kernel is faster (target: 1.15x or more).
test_dequantize(unsloth_dequantize) / test_dequantize(your_dequantize_nf4)

1.1819698468989508

In [8]:
import torch
import triton
import triton.language as tl
import math
from transformers import set_seed
from bitsandbytes.nn import Linear4bit
# Ensure fast_dequantize is accessible from your Unsloth installation
from unsloth.kernels.utils import fast_dequantize # Assuming this path is correct

# --- 1. Triton JIT Kernel (V13c - Fast Version with element-centric processing) ---
@triton.jit
def _your_dequantize_nf4_kernel(
    w_ptr, abs_idx_ptr, offset_ptr, abs2_scales_ptr, code2_ptr, nf4_code_ptr, output_ptr,
    TOTAL_ELEMENTS_IN_OUTPUT_TENSOR: tl.constexpr,
    BLOCK_SIZE_BYTES_PER_CHUNK: tl.constexpr, # Not read by the body; kept so existing keyword callers still work
    BLOCK_SIZE_ELEMENTS_PER_CHUNK: tl.constexpr, # Output elements handled by one program
    NUM_GROUPS_PER_CHUNK: tl.constexpr,
    ELEMENTS_PER_GROUP_CONST: tl.constexpr, # NF4 block size, e.g., 64
    LOG2_L2_BLOCK_SIZE_CONST_KERNEL: tl.constexpr,
    gsize_num_chunks: tl.constexpr
):
    """Element-centric NF4 dequantization of one chunk.

    Each program writes ``BLOCK_SIZE_ELEMENTS_PER_CHUNK`` consecutive output
    elements starting at ``pid * BLOCK_SIZE_ELEMENTS_PER_CHUNK``; lanes at or
    beyond ``TOTAL_ELEMENTS_IN_OUTPUT_TENSOR`` are masked out.

    For element ``i`` in group ``g = i // ELEMENTS_PER_GROUP_CONST``:
        scale = abs2_scales[g >> LOG2_L2_BLOCK_SIZE_CONST_KERNEL] * code2[abs_idx[g]] + offset
        code  = high nibble of w[i // 2] if i is even, else the low nibble
        out   = nf4_code[code] * scale

    Requirements (not checked here): ``NUM_GROUPS_PER_CHUNK * ELEMENTS_PER_GROUP_CONST``
    must equal ``BLOCK_SIZE_ELEMENTS_PER_CHUNK`` for the final reshape, and both
    ``tl.arange`` extents must be powers of two.
    """
    pid = tl.program_id(0)
    if pid >= gsize_num_chunks: return

    log2_l2_block_size = LOG2_L2_BLOCK_SIZE_CONST_KERNEL

    # --- Group scales for this chunk ---
    chunk_element_start_offset = pid * BLOCK_SIZE_ELEMENTS_PER_CHUNK
    group_arange_local = tl.arange(0, NUM_GROUPS_PER_CHUNK)
    absmax_group_indices_potential = (chunk_element_start_offset // ELEMENTS_PER_GROUP_CONST) + group_arange_local
    # A group participates when its first element is inside the tensor.
    group_mask = (absmax_group_indices_potential * ELEMENTS_PER_GROUP_CONST) < TOTAL_ELEMENTS_IN_OUTPUT_TENSOR

    quantized_absmax_indices = tl.load(abs_idx_ptr + absmax_group_indices_potential, mask=group_mask, other=0, eviction_policy="evict_first")
    dequantized_l1_scales = tl.load(code2_ptr + quantized_absmax_indices.to(tl.int32), mask=group_mask, other=0.0, eviction_policy="evict_last")

    absmax_l2_group_indices_potential = absmax_group_indices_potential >> log2_l2_block_size
    l2_scales = tl.load(abs2_scales_ptr + absmax_l2_group_indices_potential, mask=group_mask, other=0.0, eviction_policy="evict_last")

    offset_val = tl.load(offset_ptr + 0)
    intermediate_product_scales = l2_scales * dequantized_l1_scales
    final_group_scales_masked = intermediate_product_scales + offset_val # Shape: (NUM_GROUPS_PER_CHUNK,)

    # --- Per-element decode (vectorised; one lane per output element) ---
    element_arange_pid_local = tl.arange(0, BLOCK_SIZE_ELEMENTS_PER_CHUNK)
    global_element_indices = chunk_element_start_offset + element_arange_pid_local
    element_op_mask = global_element_indices < TOTAL_ELEMENTS_IN_OUTPUT_TENSOR

    byte_index_global_for_element = global_element_indices // 2
    is_low_nibble_flag = (global_element_indices % 2) != 0

    # Both elements of a pair load the same byte; masked so tail lanes do
    # not read past the end of w_ptr.
    packed_byte_for_element = tl.load(w_ptr + byte_index_global_for_element, mask=element_op_mask, other=0)

    nibble_shift = tl.where(is_low_nibble_flag, 0, 4)
    quantized_idx_for_element = (packed_byte_for_element >> nibble_shift) & 0x0F

    dequant_val_for_element = tl.load(nf4_code_ptr + quantized_idx_for_element.to(tl.int32), mask=element_op_mask, other=0.0)

    # Expand group scales to one scale per element via
    # reshape -> broadcast -> reshape (used in place of a gather).
    scales_reshaped_for_broadcast = tl.reshape(final_group_scales_masked, (NUM_GROUPS_PER_CHUNK, 1))
    scales_broadcasted_to_elements = tl.broadcast_to(scales_reshaped_for_broadcast, (NUM_GROUPS_PER_CHUNK, ELEMENTS_PER_GROUP_CONST))
    element_scales_vector = tl.reshape(scales_broadcasted_to_elements, (BLOCK_SIZE_ELEMENTS_PER_CHUNK,))
    final_scale_for_element = tl.where(element_op_mask, element_scales_vector, 0.0)

    scaled_element_output = dequant_val_for_element * final_scale_for_element

    tl.store(output_ptr + global_element_indices, scaled_element_output, mask=element_op_mask)
    return

# --- 2. Python Launcher for the V13c Kernel ---
def _your_dequantize_nf4(
    weight_data_actual,
    quant_state_actual,
    OPTIMIZED_BR: int,
    OPTIMIZED_WARPS: int,
    OPTIMIZED_STAGES: int
    ):
    """Allocate the output tensor and launch ``_your_dequantize_nf4_kernel``.

    Args:
        weight_data_actual: Packed 4-bit weight tensor (kernel ``w_ptr``).
        quant_state_actual: Quant state providing ``shape``, ``dtype``,
            ``blocksize``, ``absmax``, ``offset``, ``code`` and a nested
            ``state2`` with ``blocksize``, ``absmax`` and ``code``.
        OPTIMIZED_BR: Requested number of output elements per program. It is
            shrunk for small tensors and rounded up to a power of two that
            is at least one quantization group.
        OPTIMIZED_WARPS: Triton ``num_warps`` launch option.
        OPTIMIZED_STAGES: Triton ``num_stages`` launch option.

    Returns:
        Tensor of shape ``quant_state_actual.shape`` on the weight's device.
        The dtype is ``quant_state_actual.dtype``, except that bfloat16 is
        replaced by float16 on compute capability 7.5 devices.

    Raises:
        ValueError: if ``state2.blocksize`` is missing or non-positive, or
            if either block size is not a power of two.
    """
    device = weight_data_actual.device
    output_shape = quant_state_actual.shape
    total_elements_in_output = output_shape.numel()

    # On SM 7.5 (T4) a bfloat16 request is served as float16 instead.
    final_out_dtype = quant_state_actual.dtype
    if final_out_dtype == torch.bfloat16 and device.type == 'cuda':
        props = torch.cuda.get_device_properties(device)
        if props.major == 7 and props.minor == 5:
            final_out_dtype = torch.float16

    output_tensor = torch.empty(output_shape, dtype=final_out_dtype, device=device, requires_grad=False)
    if total_elements_in_output == 0:
        # Nothing to dequantize; skip validation and the kernel launch.
        return output_tensor

    nf4_block_size_elements = quant_state_actual.blocksize # ELEMENTS_PER_GROUP_CONST for the kernel
    if nf4_block_size_elements <= 0 or (nf4_block_size_elements & (nf4_block_size_elements - 1)) != 0:
        raise ValueError(f"quant_state.blocksize must be a power of 2 and > 0. Got {nf4_block_size_elements}")

    nested_state = getattr(quant_state_actual, 'state2', None)
    block_size2_absmax_val = getattr(nested_state, 'blocksize', None)
    if block_size2_absmax_val is None or block_size2_absmax_val <= 0:
        raise ValueError("quant_state.state2.blocksize is missing or invalid.")
    if (block_size2_absmax_val & (block_size2_absmax_val - 1)) != 0:
        raise ValueError(f"block_size2_absmax_val must be a power of 2 and > 0. Got {block_size2_absmax_val}")

    # Small tensors: shrink the chunk to the groups actually present.
    current_br_chunk_size_elements = OPTIMIZED_BR # BLOCK_SIZE_ELEMENTS_PER_CHUNK for the kernel
    if total_elements_in_output < current_br_chunk_size_elements:
        current_br_chunk_size_elements = triton.cdiv(total_elements_in_output, nf4_block_size_elements) * nf4_block_size_elements
    # The kernel's tl.arange extents (chunk size and groups per chunk) must
    # be powers of two; a bare multiple of the group size (e.g. 192) is not
    # enough. Surplus lanes introduced by rounding are masked in the kernel.
    current_br_chunk_size_elements = triton.next_power_of_2(
        max(current_br_chunk_size_elements, nf4_block_size_elements)
    )

    num_nf4_blocks_per_br_chunk_val = current_br_chunk_size_elements // nf4_block_size_elements
    gsize_val = triton.cdiv(total_elements_in_output, current_br_chunk_size_elements)
    grid = (gsize_val,)

    _your_dequantize_nf4_kernel[grid](
        weight_data_actual,             # w_ptr
        quant_state_actual.absmax,      # abs_idx_ptr (L1 uint8 indices)
        quant_state_actual.offset,      # offset_ptr
        quant_state_actual.state2.absmax, # abs2_scales_ptr (L2 float32 scales)
        quant_state_actual.state2.code, # code2_ptr (L2 codebook for L1 absmax)
        quant_state_actual.code,        # nf4_code_ptr (NF4 codebook for weights)
        output_tensor,                  # output_ptr
        TOTAL_ELEMENTS_IN_OUTPUT_TENSOR=total_elements_in_output,
        BLOCK_SIZE_BYTES_PER_CHUNK=current_br_chunk_size_elements // 2, # accepted by the kernel signature but not read there
        BLOCK_SIZE_ELEMENTS_PER_CHUNK=current_br_chunk_size_elements,
        NUM_GROUPS_PER_CHUNK=num_nf4_blocks_per_br_chunk_val,
        ELEMENTS_PER_GROUP_CONST=nf4_block_size_elements,
        LOG2_L2_BLOCK_SIZE_CONST_KERNEL=int(math.log2(block_size2_absmax_val)),
        gsize_num_chunks=gsize_val,
        num_warps=OPTIMIZED_WARPS,
        num_stages=OPTIMIZED_STAGES
    )
    return output_tensor

# --- 3. Top-Level Wrapper ---
def your_dequantize_nf4(weight_param):
    """Public entry point: dequantize a 4-bit layer with the Triton kernel.

    Validates that ``weight_param`` exposes ``.weight`` with ``.data`` and
    ``.quant_state`` (raising ``ValueError`` otherwise), then launches with a
    chunk of 8192 elements, 16 warps and 1 stage.
    """
    tuned_launch = dict(OPTIMIZED_BR=8192, OPTIMIZED_WARPS=16, OPTIMIZED_STAGES=1)

    layout_ok = hasattr(weight_param, 'weight')
    layout_ok = layout_ok and hasattr(weight_param.weight, 'data')
    layout_ok = layout_ok and hasattr(weight_param.weight, 'quant_state')
    if not layout_ok:
        raise ValueError("Input 'weight_param' is not the expected Linear4bit layer format or lacks necessary attributes.")

    return _your_dequantize_nf4(
        weight_param.weight.data,
        weight_param.weight.quant_state,
        **tuned_launch,
    )

# --- 4. Reference Dequantization Function ---
def unsloth_dequantize_for_test(weight_param):
    """Reference dequantization via Unsloth's ``fast_dequantize``.

    If ``fast_dequantize`` is not yet a module-level name, it is imported from
    ``unsloth.kernels.utils`` and stored in the module globals first.

    Raises:
        ImportError: if that import fails.
    """
    module_scope = globals()
    if 'fast_dequantize' not in module_scope:
        try:
            from unsloth.kernels.utils import fast_dequantize as fd_temp
        except ImportError:
            raise ImportError("Please ensure 'fast_dequantize' from 'unsloth.kernels.utils' is available.")
        module_scope['fast_dequantize'] = fd_temp
    packed_param = weight_param.weight
    return fast_dequantize(packed_param, packed_param.quant_state)

# --- 5. Numerical Correctness Check Function (Same as before) ---
def check_numerical_correctness_one_shot(
    your_kernel_func,
    reference_func,
    hidden_size=2048,
    intermediate_size=8192,
    dtype_str="fp16",
    seed=3407,
    atol_strict=1e-3,
    rtol_strict=1e-3
    ):
    """Compare ``your_kernel_func`` against ``reference_func`` on one fresh layer.

    A seeded ``Linear4bit(hidden_size, intermediate_size)`` NF4 layer is
    created on CUDA. Both callables are applied to it and the results are
    compared with ``torch.testing.assert_close`` (strides ignored). Progress
    and, on mismatch, diagnostics are printed.

    Args:
        your_kernel_func: Callable under test; takes the layer, returns a tensor.
        reference_func: Trusted callable with the same signature.
        hidden_size: Input features of the test layer.
        intermediate_size: Output features of the test layer.
        dtype_str: ``"bf16"`` requests bfloat16 on the quant state; any other
            value uses float16. Below compute capability 8 the layer's
            compute dtype stays float16 even when bf16 is requested.
        seed: Seed passed to ``set_seed``.
        atol_strict: Absolute tolerance for the comparison.
        rtol_strict: Relative tolerance for the comparison.

    Returns:
        ``True`` when outputs match within tolerance. ``False`` on mismatch
        or when layer setup or either callable raises (the error is printed,
        not propagated).

    Side effects:
        Sets the torch default dtype to float32 and reseeds via ``set_seed``.
    """
    print(f"\n--- Numerical Correctness Check ---")
    print(f"Your Kernel: {your_kernel_func.__name__} (Testing V13c element-centric kernel)")
    print(f"Reference  : {reference_func.__name__}")
    print(f"Parameters : hidden_size={hidden_size}, intermediate_size={intermediate_size}, dtype={dtype_str}, seed={seed}")
    print(f"Tolerances : atol={atol_strict}, rtol={rtol_strict}")

    set_seed(seed)
    torch.set_default_dtype(torch.float32)

    current_dtype = torch.float16
    quant_state_target_dtype = torch.float16

    if dtype_str == "bf16":
        quant_state_target_dtype = torch.bfloat16
        major_version, minor_version = torch.cuda.get_device_capability()
        if major_version >= 8:
            current_dtype = torch.bfloat16
        else:
            print(f"Note: bf16 requested. GPU (SM {major_version}.{minor_version}) will use fp16 for Linear4bit compute_dtype. Kernel output path for T4 (if applicable) will be fp16.")
            current_dtype = torch.float16

    # Each stage reports its own failure so the notebook cell keeps running
    # and shows which step broke.
    try:
        layer_to_test = Linear4bit(
            hidden_size,
            intermediate_size,
            bias=None,
            compute_dtype=current_dtype,
            compress_statistics=True,
            quant_type="nf4",
        ).to("cuda")
        layer_to_test.weight.quant_state.dtype = quant_state_target_dtype
    except Exception as e:
        print(f"Error during Linear4bit layer setup: {e}")
        return False

    try:
        W_your_kernel = your_kernel_func(layer_to_test)
    except Exception as e:
        print(f"Error running YOUR dequantization kernel: {e}")
        import traceback
        traceback.print_exc()
        return False

    try:
        W_reference = reference_func(layer_to_test)
    except Exception as e:
        print(f"Error running REFERENCE dequantization function: {e}")
        return False

    # Bring the reference onto the same device/dtype so only values are compared.
    if W_your_kernel.device != W_reference.device:
        W_reference = W_reference.to(W_your_kernel.device)
    if W_your_kernel.dtype != W_reference.dtype:
        print(f"Casting reference tensor from {W_reference.dtype} to {W_your_kernel.dtype} for comparison.")
        W_reference = W_reference.to(W_your_kernel.dtype)

    print(f"\nComparing outputs (Your Kernel Dtype: {W_your_kernel.dtype}, Reference Dtype After Cast: {W_reference.dtype})...")
    try:
        torch.testing.assert_close(W_your_kernel, W_reference, atol=atol_strict, rtol=rtol_strict, check_stride=False)
    except AssertionError:
        print(f"FAILURE: Outputs are NOT numerically close within atol={atol_strict}, rtol={rtol_strict}.")
        _report_dequant_mismatch(W_your_kernel, W_reference)
        return False

    print(f"SUCCESS: Outputs are numerically close within atol={atol_strict}, rtol={rtol_strict}.")
    return True


def _report_dequant_mismatch(W_your_kernel, W_reference, threshold_context=0.01):
    # Print max/mean absolute error, the worst element, and how many elements
    # exceed `threshold_context`, for two same-shaped tensors.
    abs_diff = torch.abs(W_your_kernel - W_reference)
    flat_diff = abs_diff.flatten()
    max_abs_diff = torch.max(abs_diff)
    mean_abs_diff = torch.mean(abs_diff)
    idx_max_abs_diff_flat = torch.argmax(flat_diff)
    val_your_kernel_at_max_diff = W_your_kernel.flatten()[idx_max_abs_diff_flat]
    val_reference_at_max_diff = W_reference.flatten()[idx_max_abs_diff_flat]
    print(f"  Max absolute difference: {max_abs_diff.item():.8e}")
    print(f"  Mean absolute difference: {mean_abs_diff.item():.8e}")
    print(f"  Value from your kernel at max diff flat_index ({idx_max_abs_diff_flat.item()}): {val_your_kernel_at_max_diff.item():.8f}")
    print(f"  Value from reference at max diff flat_index ({idx_max_abs_diff_flat.item()}): {val_reference_at_max_diff.item():.8f}")
    num_elements_over_threshold = torch.sum(abs_diff > threshold_context).item()
    total_elements_val = W_your_kernel.numel()
    percentage_over_threshold = (num_elements_over_threshold / total_elements_val) * 100 if total_elements_val > 0 else 0
    print(f"  Number of elements with abs_diff > {threshold_context}: {num_elements_over_threshold} / {total_elements_val} ({percentage_over_threshold:.4f}%)")
    if 0 < num_elements_over_threshold < 20:
        # flatten() rather than squeeze(): with exactly one offending element
        # squeeze() yields a 0-dim tensor, which cannot be sliced.
        offending_indices = torch.nonzero(flat_diff > threshold_context).flatten()[:5].tolist()
        print(f"  First few flat_indices where abs_diff > {threshold_context}: {offending_indices}")

# --- 6. Example Usage for Numerical Correctness Test ---
# fp16 case: 2048 -> 8192 layer, tolerances of 1e-3.
print("Running numerical correctness check for V13c Hybrid Kernel (fp16)...")
is_correct_fp16_v13c = check_numerical_correctness_one_shot(
    your_dequantize_nf4,
    unsloth_dequantize_for_test,
    hidden_size=2048, intermediate_size=8192,
    dtype_str="fp16", seed=3407,
    atol_strict=1e-3, rtol_strict=1e-3,
)
print(f"Overall numerical correctness for V13c Hybrid Kernel fp16 (strict check): {'PASSED' if is_correct_fp16_v13c else 'FAILED'}")

# bf16 case: smaller 1024 -> 4096 layer, looser tolerances of 1e-2.
print("\nRunning numerical correctness check for V13c Hybrid Kernel (bf16)...")
is_correct_bf16_v13c = check_numerical_correctness_one_shot(
    your_dequantize_nf4,
    unsloth_dequantize_for_test,
    hidden_size=1024, intermediate_size=4096,
    dtype_str="bf16", seed=3409,
    atol_strict=1e-2, rtol_strict=1e-2,
)
print(f"Overall numerical correctness for V13c Hybrid Kernel bf16 input (strict check): {'PASSED' if is_correct_bf16_v13c else 'FAILED'}")

Running numerical correctness check for V13c Hybrid Kernel (fp16)...

--- Numerical Correctness Check ---
Your Kernel: your_dequantize_nf4 (Testing V13c element-centric kernel)
Reference  : unsloth_dequantize_for_test
Parameters : hidden_size=2048, intermediate_size=8192, dtype=fp16, seed=3407
Tolerances : atol=0.001, rtol=0.001

Comparing outputs (Your Kernel Dtype: torch.float16, Reference Dtype After Cast: torch.float16)...
SUCCESS: Outputs are numerically close within atol=0.001, rtol=0.001.
Overall numerical correctness for V13c Hybrid Kernel fp16 (strict check): PASSED

Running numerical correctness check for V13c Hybrid Kernel (bf16)...

--- Numerical Correctness Check ---
Your Kernel: your_dequantize_nf4 (Testing V13c element-centric kernel)
Reference  : unsloth_dequantize_for_test
Parameters : hidden_size=1024, intermediate_size=4096, dtype=bf16, seed=3409
Tolerances : atol=0.01, rtol=0.01
Note: bf16 requested. GPU (SM 7.5) will use fp16 for Linear4bit compute_dtype. Kernel ou

In [9]:
import torch
import triton
import triton.language as tl
import math
from bitsandbytes.nn import Linear4bit
import torch._dynamo as dynamo

# --- 1. Your V13c Kernel and Launcher (as provided) ---
@triton.jit
def _your_dequantize_nf4_kernel( # V13 kernel base
    w_ptr, abs_idx_ptr, offset_ptr, abs2_scales_ptr, code2_ptr, nf4_code_ptr, output_ptr,
    TOTAL_ELEMENTS_IN_OUTPUT_TENSOR: tl.constexpr,
    BLOCK_SIZE_BYTES_PER_CHUNK: tl.constexpr,
    BLOCK_SIZE_ELEMENTS_PER_CHUNK: tl.constexpr,
    NUM_GROUPS_PER_CHUNK: tl.constexpr,
    ELEMENTS_PER_GROUP_CONST: tl.constexpr, # NF4 block size, e.g., 64
    LOG2_L2_BLOCK_SIZE_CONST_KERNEL: tl.constexpr,
    gsize_num_chunks: tl.constexpr
):
    """Dequantize one chunk of packed 4-bit weights into `output_ptr`.

    Each program (`tl.program_id(0)`) handles `BLOCK_SIZE_ELEMENTS_PER_CHUNK`
    consecutive output elements. For output element `i` in group
    `g = i // ELEMENTS_PER_GROUP_CONST` the stored value is

        nf4_code[nibble(i)] * (abs2_scales[g >> LOG2_L2] * code2[abs_idx[g]] + offset[0])

    where `nibble(i)` is the high 4 bits of byte `w[i // 2]` for even `i` and
    the low 4 bits for odd `i`.

    Pointer arguments (as indexed by this kernel):
        w_ptr:           packed bytes, two 4-bit codes per byte.
        abs_idx_ptr:     one integer code per group, used as an index into `code2_ptr`.
        offset_ptr:      only element 0 is read; added to every group scale.
        abs2_scales_ptr: second-level scales, one per `2**LOG2_L2_BLOCK_SIZE_CONST_KERNEL` groups.
        code2_ptr:       lookup table for the per-group codes.
        nf4_code_ptr:    lookup table for the 4-bit element codes.
        output_ptr:      destination, written at flat element indices.

    Compile-time arguments:
        TOTAL_ELEMENTS_IN_OUTPUT_TENSOR: bound used for all tail masks.
        BLOCK_SIZE_BYTES_PER_CHUNK:      accepted but not referenced in the body.
        BLOCK_SIZE_ELEMENTS_PER_CHUNK:   elements handled per program.
        NUM_GROUPS_PER_CHUNK:            groups handled per program; the final reshape
                                         requires NUM_GROUPS_PER_CHUNK * ELEMENTS_PER_GROUP_CONST
                                         == BLOCK_SIZE_ELEMENTS_PER_CHUNK.
        gsize_num_chunks:                programs with pid >= this value exit immediately.

    NOTE(review): `tl.arange` and block shapes are used with
    NUM_GROUPS_PER_CHUNK, ELEMENTS_PER_GROUP_CONST and
    BLOCK_SIZE_ELEMENTS_PER_CHUNK; Triton expects these extents to be powers of
    two, so the launcher must guarantee that.
    """
    pid = tl.program_id(0)
    # Guard against programs beyond the number of chunks.
    if pid >= gsize_num_chunks: return
    log2_l2_block_size = LOG2_L2_BLOCK_SIZE_CONST_KERNEL
    # Flat index of the first output element owned by this program.
    chunk_element_start_offset = pid * BLOCK_SIZE_ELEMENTS_PER_CHUNK

    # ---- Per-group scales ----
    group_arange_local = tl.arange(0, NUM_GROUPS_PER_CHUNK)
    # Global group indices covered by this chunk.
    absmax_group_indices_potential = (chunk_element_start_offset // ELEMENTS_PER_GROUP_CONST) + group_arange_local
    # A group is valid when its first element lies inside the output tensor.
    group_mask = (absmax_group_indices_potential * ELEMENTS_PER_GROUP_CONST) < TOTAL_ELEMENTS_IN_OUTPUT_TENSOR
    # Per-group code (read once, hence the "evict_first" cache hint) ...
    quantized_absmax_indices = tl.load(abs_idx_ptr + absmax_group_indices_potential, mask=group_mask, other=0, eviction_policy="evict_first")
    # ... decoded through the code2 lookup table.
    dequantized_l1_scales = tl.load(code2_ptr + quantized_absmax_indices.to(tl.int32), mask=group_mask, other=0.0, eviction_policy="evict_last")
    # Several consecutive groups share one second-level scale; the shift is the
    # integer division by the (power-of-two) second-level block size.
    absmax_l2_group_indices_potential = absmax_group_indices_potential >> log2_l2_block_size
    l2_scales = tl.load(abs2_scales_ptr + absmax_l2_group_indices_potential, mask=group_mask, other=0.0, eviction_policy="evict_last")
    offset_val = tl.load(offset_ptr + 0)
    intermediate_product_scales = l2_scales * dequantized_l1_scales
    # Masked-off groups end up holding just `offset_val`; that is harmless
    # because their elements are masked out again before the store.
    final_group_scales_masked = intermediate_product_scales + offset_val

    # ---- Per-element codes ----
    element_arange_pid_local = tl.arange(0, BLOCK_SIZE_ELEMENTS_PER_CHUNK)
    global_element_indices = chunk_element_start_offset + element_arange_pid_local
    # Tail mask for the last (possibly partial) chunk.
    element_op_mask = global_element_indices < TOTAL_ELEMENTS_IN_OUTPUT_TENSOR
    # Two elements share one packed byte.
    byte_index_global_for_element = global_element_indices // 2
    # Odd element indices live in the low nibble, even ones in the high nibble.
    is_low_nibble_flag = (global_element_indices % 2) != 0
    packed_byte_for_element = tl.load(w_ptr + byte_index_global_for_element, mask=element_op_mask, other=0)
    nibble_shift = tl.where(is_low_nibble_flag, 0, 4)
    quantized_idx_for_element = (packed_byte_for_element >> nibble_shift) & 0x0F
    # Decode the 4-bit code through the NF4 lookup table.
    dequant_val_for_element = tl.load(nf4_code_ptr + quantized_idx_for_element.to(tl.int32), mask=element_op_mask, other=0.0)

    # ---- Expand group scales to one scale per element ----
    # (groups,) -> (groups, 1) -> (groups, elements_per_group) -> (elements,)
    scales_reshaped_for_broadcast = tl.reshape(final_group_scales_masked, (NUM_GROUPS_PER_CHUNK, 1))
    scales_broadcasted_to_elements = tl.broadcast_to(scales_reshaped_for_broadcast, (NUM_GROUPS_PER_CHUNK, ELEMENTS_PER_GROUP_CONST))
    element_scales_vector = tl.reshape(scales_broadcasted_to_elements, (BLOCK_SIZE_ELEMENTS_PER_CHUNK,))
    # Zero the scale of out-of-range elements before multiplying.
    final_scale_for_element = tl.where(element_op_mask, element_scales_vector, 0.0)
    scaled_element_output = dequant_val_for_element * final_scale_for_element
    # Masked store; the value is written into whatever element type `output_ptr` has.
    tl.store(output_ptr + global_element_indices, scaled_element_output, mask=element_op_mask)
    return

def _your_dequantize_nf4(
    weight_data_actual, quant_state_actual,
    CHUNK_SIZE_ELEMENTS_Br: int, num_warps_to_use: int, num_stages_to_use: int
):
    """Launch `_your_dequantize_nf4_kernel` for one packed 4-bit weight tensor.

    Args:
        weight_data_actual: Packed weight tensor. Its ``device`` selects where the
            output is allocated; the tensor itself is passed to the kernel.
        quant_state_actual: Quantization state providing ``shape``, ``dtype``,
            ``blocksize``, ``absmax``, ``offset``, ``code`` and a nested
            ``state2`` with ``blocksize``, ``absmax`` and ``code``.
        CHUNK_SIZE_ELEMENTS_Br: Requested number of output elements per Triton
            program. The effective value is clamped to the tensor size, raised
            to at least one quantization group, and rounded up to a power of two.
        num_warps_to_use: Forwarded to the kernel launch as ``num_warps``.
        num_stages_to_use: Forwarded to the kernel launch as ``num_stages``.

    Returns:
        A tensor of shape ``quant_state_actual.shape``. Its dtype is
        ``quant_state_actual.dtype``, except that bfloat16 is replaced by
        float16 on compute capability 7.5 devices.

    Raises:
        ValueError: If ``state2`` is missing (``None``) or either block size is
            not a positive power of two.
    """
    device = weight_data_actual.device
    output_shape = quant_state_actual.shape
    total_elements_in_output = output_shape.numel()
    if total_elements_in_output == 0:
        # Nothing to dequantize: skip device queries and the kernel launch.
        return torch.empty(output_shape, device=device, dtype=quant_state_actual.dtype)

    nested_state = quant_state_actual.state2
    if nested_state is None:
        raise ValueError(
            "quant_state.state2 is None; this kernel requires nested (double-quantized) absmax statistics."
        )

    ELEMENTS_PER_GROUP_CONST = quant_state_actual.blocksize
    L2_BLOCK_SIZE_CONST_PYTHON = nested_state.blocksize
    # The kernel replaces the division by the second-level block size with a shift.
    if not (L2_BLOCK_SIZE_CONST_PYTHON > 0 and (L2_BLOCK_SIZE_CONST_PYTHON & (L2_BLOCK_SIZE_CONST_PYTHON - 1) == 0)):
        raise ValueError(f"L2_BLOCK_SIZE_CONST_PYTHON must be a power of 2 and > 0. Got {L2_BLOCK_SIZE_CONST_PYTHON}")
    # The group size becomes a Triton block dimension, which must be a power of two.
    if not (ELEMENTS_PER_GROUP_CONST > 0 and (ELEMENTS_PER_GROUP_CONST & (ELEMENTS_PER_GROUP_CONST - 1) == 0)):
        raise ValueError(f"quant_state.blocksize must be a power of 2 and > 0. Got {ELEMENTS_PER_GROUP_CONST}")

    props = torch.cuda.get_device_properties(device)
    is_t4 = (props.major == 7 and props.minor == 5)
    final_out_dtype = quant_state_actual.dtype
    # SM 7.5 has no bfloat16 support, so fall back to float16 output there.
    if is_t4 and final_out_dtype == torch.bfloat16:
        final_out_dtype = torch.float16
    output_tensor = torch.empty(output_shape, dtype=final_out_dtype, device=device, requires_grad=False)

    # Effective chunk size:
    #   * never larger than needed for a small tensor,
    #   * never smaller than one quantization group,
    #   * always a power of two, because the kernel feeds it to tl.arange and
    #     tl.reshape. Rounding up is safe: the kernel masks every load and
    #     store against TOTAL_ELEMENTS_IN_OUTPUT_TENSOR.
    # Previously a small tensor produced a multiple of the group size
    # (e.g. 192), which Triton rejects at compile time.
    current_br = max(min(CHUNK_SIZE_ELEMENTS_Br, total_elements_in_output), ELEMENTS_PER_GROUP_CONST)
    current_br = triton.next_power_of_2(current_br)

    # Both values are powers of two with current_br >= group size, so the
    # division below is exact and at least 1.
    CHUNK_SIZE_BYTES_Br_half = current_br // 2
    NUM_GROUPS_PER_CHUNK_Br_div_64 = current_br // ELEMENTS_PER_GROUP_CONST
    LOG2_L2_BLOCK_SIZE_CONST_PYTHON = int(math.log2(L2_BLOCK_SIZE_CONST_PYTHON))

    # total_elements_in_output > 0 here, so there is always at least one chunk.
    gsize_num_chunks_val = triton.cdiv(total_elements_in_output, current_br)
    grid = (gsize_num_chunks_val,)
    const_args_for_kernel = {
        'TOTAL_ELEMENTS_IN_OUTPUT_TENSOR': total_elements_in_output,
        'BLOCK_SIZE_BYTES_PER_CHUNK': CHUNK_SIZE_BYTES_Br_half,
        'BLOCK_SIZE_ELEMENTS_PER_CHUNK': current_br,
        'NUM_GROUPS_PER_CHUNK': NUM_GROUPS_PER_CHUNK_Br_div_64,
        'ELEMENTS_PER_GROUP_CONST': ELEMENTS_PER_GROUP_CONST,
        'LOG2_L2_BLOCK_SIZE_CONST_KERNEL': LOG2_L2_BLOCK_SIZE_CONST_PYTHON,
        'gsize_num_chunks': gsize_num_chunks_val
    }
    _your_dequantize_nf4_kernel[grid](
        weight_data_actual, quant_state_actual.absmax, quant_state_actual.offset,
        nested_state.absmax, nested_state.code, quant_state_actual.code,
        output_tensor,
        **const_args_for_kernel,
        num_warps=num_warps_to_use,
        num_stages=num_stages_to_use
    )
    return output_tensor

def your_dequantize_nf4(weight_param):
    """Dequantize the NF4 weight of a Linear4bit-style layer with tuned launch settings.

    Args:
        weight_param: Object exposing ``weight.data`` (packed weights) and
            ``weight.quant_state`` (quantization state).

    Returns:
        Whatever `_your_dequantize_nf4` returns for that weight.

    Raises:
        ValueError: If ``weight_param`` lacks ``weight``, ``weight.data`` or
            ``weight.quant_state``.
    """
    # Launch configuration used for every call.
    launch_settings = dict(
        CHUNK_SIZE_ELEMENTS_Br=4096,
        num_warps_to_use=4,
        num_stages_to_use=2,
    )

    # `and` short-circuits, so `.weight` is only touched once it is known to exist.
    has_expected_layout = (
        hasattr(weight_param, 'weight')
        and hasattr(weight_param.weight, 'data')
        and hasattr(weight_param.weight, 'quant_state')
    )
    if not has_expected_layout:
        raise ValueError("Input 'weight_param' is not the expected Linear4bit layer format or lacks necessary attributes.")

    packed_weight = weight_param.weight.data
    quant_state = weight_param.weight.quant_state
    return _your_dequantize_nf4(packed_weight, quant_state, **launch_settings)

# --- 2. Corrected Test Harness for torch.compile ---
def run_torch_compile_test():
    """Smoke-test `your_dequantize_nf4` under `torch.compile`.

    Three stages, each reported via `print`:
      1. build an NF4 `Linear4bit` layer (1024 -> 2048) and move it to CUDA,
      2. wrap `your_dequantize_nf4` with `torch.compile(mode="max-autotune")`,
      3. call the compiled function once and print the output shape and dtype.

    A failure in any stage is printed (with a traceback for stage 3) and the
    function returns early; nothing is raised and nothing is returned.
    """
    import traceback  # used only to print the stack if the compiled call fails

    # Stage 1: sample layer.
    print("--- Setting up test for torch.compile ---")
    try:
        # Linear4bit takes `input_features` / `output_features` keyword names.
        layer_under_test = Linear4bit(
            input_features=1024,
            output_features=2048,
            bias=None,
            compute_dtype=torch.bfloat16,
            quant_type='nf4'
        ).to("cuda")
        print("Sample Linear4bit layer created successfully.")
    except Exception as setup_error:
        print(f"\n[ERROR] Failed to create sample layer: {setup_error}")
        return

    # Stage 2: wrap the dequantizer with torch.compile.
    print("\n--- Attempting to JIT compile `your_dequantize_nf4` ---")
    try:
        compiled_dequantize = torch.compile(your_dequantize_nf4, mode="max-autotune")
        print("torch.compile() call finished without error.")
    except Exception as compile_error:
        print(f"\n[FAILED] torch.compile threw an error during the initial compilation step: {compile_error}")
        return

    # Stage 3: the first call of the compiled function.
    print("\n--- Running the compiled function for the first time (triggers JIT) ---")
    try:
        dequantized = compiled_dequantize(layer_under_test)
        print("\n[SUCCESS] Compiled function ran without errors!")
        print(f"  Output tensor shape: {dequantized.shape}")
        print(f"  Output tensor dtype: {dequantized.dtype}")
    except Exception:
        print("\n[FAILED] The compiled function threw an error on its first run:")
        traceback.print_exc()

# --- 3. Execute the Test ---
# Runs the smoke test when this cell executes. All results are reported through
# the function's own prints; it catches its own failures and returns None.
run_torch_compile_test()

--- Setting up test for torch.compile ---
Sample Linear4bit layer created successfully.

--- Attempting to JIT compile `your_dequantize_nf4` ---
torch.compile() call finished without error.

--- Running the compiled function for the first time (triggers JIT) ---

[SUCCESS] Compiled function ran without errors!
  Output tensor shape: torch.Size([2048, 1024])
  Output tensor dtype: torch.float32
