<a href="https://colab.research.google.com/github/Aryan8912/Unsolth.ai-challenage/blob/main/nf4_with_triton_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-

In [3]:
import torch
import triton
import triton.language as tl
from dataclasses import dataclass
from typing import Optional, Tuple
import time
from transformers import BitsAndBytesConfig
import warnings
import torch._dynamo

# Configure torch._dynamo to suppress errors and fall back to eager mode
torch._dynamo.config.suppress_errors = True

# Utility function to check if compilation is supported
def is_compilation_supported():
    try:
        @torch.compile
        def dummy(x):
            return x + 1

        test_tensor = torch.tensor([1.0], device="cuda")
        dummy(test_tensor)
        return True
    except Exception as e:
        warnings.warn(f"Torch compile not supported: {str(e)}. Falling back to eager mode.")
        return False

@dataclass
class NF4Config:
    CLIP_MIN: int = -8
    CLIP_MAX: int = 7
    DTYPE_MIN: int = 0
    DTYPE_MAX: int = 15

class MemoryFormat:
    CONTIGUOUS = "contiguous"
    CHANNELS_LAST = "channels_last"

@triton.jit
def compute_absmax_kernel(
    input_ptr,
    absmax_ptr,
    num_elements,
    BLOCK_SIZE: tl.constexpr
):
    """Compute absolute maximum values using efficient reduction."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Load and compute absolute values
    x = tl.load(input_ptr + offsets, mask=mask, other=0.0)
    x_abs = tl.abs(x)

    # Perform reduction to find maximum
    block_max = tl.max(x_abs, axis=0)

    # Store result
    tl.store(absmax_ptr + pid, block_max)

@triton.jit
def dequantize_kernel(
    quantized_ptr,
    absmax_ptr,
    double_quant_scale_ptr,
    output_ptr,
    M, N,
    stride_qm, stride_qn,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    MEMORY_FORMAT: tl.constexpr,
    USE_DOUBLE_QUANT: tl.constexpr,
):
    """Dequantize NF4 values with support for double quantization and different memory formats."""
    # Constants for NF4
    NF4_CLIP_MIN = -8
    NF4_CLIP_MAX = 7

    # Program ID for 2D grid
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Calculate start indices
    start_m = pid_m * BLOCK_M
    start_n = pid_n * BLOCK_N

    # Create ranges for the block
    rm = start_m + tl.arange(0, BLOCK_M)
    rn = start_n + tl.arange(0, BLOCK_N)

    # Create masks for valid elements
    mask_m = rm[:, None] < M
    mask_n = rn[None, :] < N
    mask = mask_m & mask_n

    # Shared memory for frequently accessed scales
    scale_cache = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Load quantized values based on memory format
    if MEMORY_FORMAT == 1:  # channels_last
        quantized = tl.load(
            quantized_ptr + rm[:, None] * stride_qn + rn[None, :] * stride_qm,
            mask=mask, other=0
        )
    else:  # contiguous
        quantized = tl.load(
            quantized_ptr + rm[:, None] * stride_qm + rn[None, :] * stride_qn,
            mask=mask, other=0
        )

    # Load and cache absmax values
    absmax = tl.load(absmax_ptr + rm, mask=rm < M, other=1.0)
    scale_cache = absmax / NF4_CLIP_MAX

    # Apply double quantization if enabled
    if USE_DOUBLE_QUANT:
        double_scale = tl.load(double_quant_scale_ptr + rm, mask=rm < M, other=1.0)
        scale_cache = scale_cache * double_scale

    # Dequantize
    dequantized = (quantized - 8) * scale_cache[:, None]

    # Store result
    tl.store(
        output_ptr + rm[:, None] * stride_om + rn[None, :] * stride_on,
        dequantized,
        mask=mask
    )

class BitsAndBytesNF4:
    """Wrapper for BitsAndBytes NF4 configuration and quantization."""
    def __init__(
        self,
        load_in_4bit: bool = True,
        bnb_4bit_use_double_quant: bool = True,
        bnb_4bit_quant_type: str = "nf4",
        bnb_4bit_compute_dtype: torch.dtype = torch.bfloat16
    ):
        self.config = BitsAndBytesConfig(
            load_in_4bit=load_in_4bit,
            bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype
        )

        # Validate configuration
        if bnb_4bit_quant_type != "nf4":
            raise ValueError("Only NF4 quantization is supported")

        if bnb_4bit_compute_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError("Compute dtype must be float16 or bfloat16")

        self.compute_dtype = bnb_4bit_compute_dtype
        self.use_double_quant = bnb_4bit_use_double_quant

class NF4Dequantizer:
    def __init__(
        self,
        bnb_config: Optional[BitsAndBytesConfig] = None,
        memory_format: str = MemoryFormat.CONTIGUOUS,
        block_size: int = 1024,
        compile_mode: str = "reduce-overhead"
    ):
        # Initialize with BitsAndBytes config if provided
        if bnb_config is None:
            bnb_config = BitsAndBytesNF4().config

        self.use_double_quant = bnb_config.bnb_4bit_use_double_quant
        self.compute_dtype = bnb_config.bnb_4bit_compute_dtype
        self.memory_format = memory_format
        self.block_size = block_size
        self.config = NF4Config()

        # Check if compilation is supported
        self.use_compilation = is_compilation_supported()

        if self.use_compilation:
            # Compile the compute methods
            self._compute_absmax_compiled = torch.compile(
                self._compute_absmax_impl,
                mode=compile_mode,
                fullgraph=True
            )
            self._dequantize_compiled = torch.compile(
                self._dequantize_impl,
                mode=compile_mode,
                fullgraph=True
            )
        else:
            # Use non-compiled implementations
            self._compute_absmax_compiled = self._compute_absmax_impl
            self._dequantize_compiled = self._dequantize_impl

    @torch.no_grad()
    def _compute_absmax_impl(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Implementation of absmax computation."""
        try:
            num_elements = input_tensor.numel()
            num_blocks = (num_elements + self.block_size - 1) // self.block_size

            absmax = torch.empty(num_blocks, device=input_tensor.device, dtype=self.compute_dtype)

            compute_absmax_kernel[(num_blocks,)](
                input_tensor,
                absmax,
                num_elements,
                BLOCK_SIZE=self.block_size
            )

            return absmax
        except Exception as e:
            warnings.warn(f"Triton kernel failed: {str(e)}. Falling back to PyTorch implementation.")
            return torch.max(torch.abs(input_tensor)).to(self.compute_dtype)

    @torch.no_grad()
    def _dequantize_impl(
        self,
        quantized_tensor: torch.Tensor,
        absmax_tensor: torch.Tensor,
        double_quant_scale: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Implementation of dequantization."""
        try:
            M, N = quantized_tensor.shape

            output = torch.empty(
                (M, N),
                device=quantized_tensor.device,
                dtype=self.compute_dtype
            )

            BLOCK_M, BLOCK_N = 128, 128
            grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
            memory_format_int = 1 if self.memory_format == MemoryFormat.CHANNELS_LAST else 0

            dequantize_kernel[grid](
                quantized_tensor,
                absmax_tensor,
                double_quant_scale if double_quant_scale is not None else absmax_tensor,
                output,
                M, N,
                quantized_tensor.stride(0), quantized_tensor.stride(1),
                output.stride(0), output.stride(1),
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                MEMORY_FORMAT=memory_format_int,
                USE_DOUBLE_QUANT=self.use_double_quant
            )

            return output
        except Exception as e:
            warnings.warn(f"Triton kernel failed: {str(e)}. Falling back to PyTorch implementation.")
            # PyTorch fallback implementation
            scale = absmax_tensor[:, None] / self.config.CLIP_MAX
            if double_quant_scale is not None and self.use_double_quant:
                scale = scale * double_quant_scale[:, None]
            return ((quantized_tensor - 8) * scale).to(self.compute_dtype)

    @torch.no_grad()
    def compute_absmax(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Compute absolute maximum values for input tensor using compiled implementation."""
        return self._compute_absmax_compiled(input_tensor)

    @torch.no_grad()
    def dequantize(
        self,
        quantized_tensor: torch.Tensor,
        absmax_tensor: Optional[torch.Tensor] = None,
        double_quant_scale: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Dequantize NF4 tensor to fp16/bf16 using compiled implementation."""
        # Input validation
        if not torch.is_tensor(quantized_tensor):
            raise TypeError("quantized_tensor must be a torch.Tensor")

        if not quantized_tensor.is_cuda:
            raise ValueError("quantized_tensor must be on CUDA device")

        if torch.any(quantized_tensor < self.config.DTYPE_MIN) or \
           torch.any(quantized_tensor > self.config.DTYPE_MAX):
            raise ValueError(f"Quantized values must be in range [{self.config.DTYPE_MIN}, {self.config.DTYPE_MAX}]")

        # Compute absmax if not provided
        if absmax_tensor is None:
            absmax_tensor = self.compute_absmax(quantized_tensor)

        return self._dequantize_compiled(quantized_tensor, absmax_tensor, double_quant_scale)

def benchmark_dequantizer(
    shapes: list[Tuple[int, int]] = None,
    num_warmup: int = 5,
    num_runs: int = 100
):
    """Benchmark the NF4Dequantizer with various matrix shapes."""
    if shapes is None:
        shapes = [
            (10, 10),      # Small square matrix
            (1024, 32),    # Tall matrix
            (32, 1024),    # Wide matrix
            (1024, 1024),  # Large square matrix
            (4096, 4096),  # Very large square matrix
        ]

    # Create BitsAndBytes config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # Create dequantizers with different configurations
    print("\nInitializing dequantizers...")
    dequantizer_compiled = NF4Dequantizer(
        bnb_config=bnb_config,
        compile_mode="reduce-overhead"
    )
    dequantizer_normal = NF4Dequantizer(
        bnb_config=bnb_config
    )

    print("\nBenchmarking NF4 Dequantization with BitsAndBytes config:")
    print("=" * 60)
    print(f"Config: {bnb_config}")
    print(f"Using compilation: {dequantizer_compiled.use_compilation}")
    print("-" * 60)
    print(f"{'Shape':>12} | {'Normal (ms)':>12} | {'Compiled (ms)':>12} | {'Speedup':>8}")
    print("-" * 60)

    for M, N in shapes:
        try:
            # Generate test data
            quantized = torch.randint(0, 16, (M, N), dtype=torch.int32, device='cuda')
            absmax = torch.rand(M, device='cuda') * 10
            double_quant_scale = torch.rand(M, device='cuda') * 2

            # Convert to appropriate dtype
            absmax = absmax.to(bnb_config.bnb_4bit_compute_dtype)
            double_quant_scale = double_quant_scale.to(bnb_config.bnb_4bit_compute_dtype)

            # Warmup
            for _ in range(num_warmup):
                _ = dequantizer_normal.dequantize(quantized, absmax, double_quant_scale)
                _ = dequantizer_compiled.dequantize(quantized, absmax, double_quant_scale)

            # Benchmark normal version
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            for _ in range(num_runs):
                _ = dequantizer_normal.dequantize(quantized, absmax, double_quant_scale)
            end.record()
            torch.cuda.synchronize()
            normal_time = start.elapsed_time(end) / num_runs

            # Benchmark compiled version
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            for _ in range(num_runs):
                _ = dequantizer_compiled.dequantize(quantized, absmax, double_quant_scale)
            end.record()
            torch.cuda.synchronize()
            compiled_time = start.elapsed_time(end) / num_runs

            # Calculate speedup
            speedup = normal_time / compiled_time

            print(f"{M}x{N:>7} | {normal_time:>10.3f} | {compiled_time:>10.3f} | {speedup:>7.2f}x")
        except Exception as e:
            print(f"{M}x{N:>7} | Error: {str(e)}")

    print("=" * 60)

if __name__ == "__main__":
    print("Running NF4 dequantization benchmarks with BitsAndBytes config...")
    benchmark_dequantizer()


Running NF4 dequantization benchmarks with BitsAndBytes config...

Initializing dequantizers...

Benchmarking NF4 Dequantization with BitsAndBytes config:
Config: BitsAndBytesConfig {
  "_load_in_4bit": true,
  "_load_in_8bit": false,
  "bnb_4bit_compute_dtype": "bfloat16",
  "bnb_4bit_quant_storage": "uint8",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": true,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}

Using compilation: True
------------------------------------------------------------
       Shape |  Normal (ms) | Compiled (ms) |  Speedup
------------------------------------------------------------


W0223 05:56:12.076000 612 torch/_dynamo/convert_frame.py:844] [4/1050] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:12.076000 612 torch/_dynamo/convert_frame.py:844] [4/1050]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py:605)
W0223 05:56:12.076000 612 torch/_dynamo/convert_frame.py:844] [4/1050]    last reason: Unable to find recompilation reasons
W0223 05:56:12.076000 612 torch/_dynamo/convert_frame.py:844] [4/1050] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0223 05:56:12.076000 612 torch/_dynamo/convert_frame.py:844] [4/1050] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W0223 05:56:12.082000 612 torch/_dynamo/convert_frame.py:844] [4/1051] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:12.082000 612 torch/_dynamo/convert_frame.py:844] [4/1051]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/

10x     10 |      5.966 |      3.139 |    1.90x


W0223 05:56:13.240000 612 torch/_dynamo/convert_frame.py:844] [4/1310] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:13.240000 612 torch/_dynamo/convert_frame.py:844] [4/1310]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py:605)
W0223 05:56:13.240000 612 torch/_dynamo/convert_frame.py:844] [4/1310]    last reason: Unable to find recompilation reasons
W0223 05:56:13.240000 612 torch/_dynamo/convert_frame.py:844] [4/1310] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0223 05:56:13.240000 612 torch/_dynamo/convert_frame.py:844] [4/1310] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W0223 05:56:13.243000 612 torch/_dynamo/convert_frame.py:844] [4/1311] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:13.243000 612 torch/_dynamo/convert_frame.py:844] [4/1311]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/

1024x     32 |      5.214 |      4.187 |    1.25x


W0223 05:56:14.201000 612 torch/_dynamo/convert_frame.py:844] [4/1530] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:14.201000 612 torch/_dynamo/convert_frame.py:844] [4/1530]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py:605)
W0223 05:56:14.201000 612 torch/_dynamo/convert_frame.py:844] [4/1530]    last reason: Unable to find recompilation reasons
W0223 05:56:14.201000 612 torch/_dynamo/convert_frame.py:844] [4/1530] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0223 05:56:14.201000 612 torch/_dynamo/convert_frame.py:844] [4/1530] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W0223 05:56:14.205000 612 torch/_dynamo/convert_frame.py:844] [4/1531] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:14.205000 612 torch/_dynamo/convert_frame.py:844] [4/1531]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/

32x   1024 |      3.785 |      7.679 |    0.49x


W0223 05:56:15.381000 612 torch/_dynamo/convert_frame.py:844] [4/1736] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:15.381000 612 torch/_dynamo/convert_frame.py:844] [4/1736]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py:605)
W0223 05:56:15.381000 612 torch/_dynamo/convert_frame.py:844] [4/1736]    last reason: Unable to find recompilation reasons
W0223 05:56:15.381000 612 torch/_dynamo/convert_frame.py:844] [4/1736] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0223 05:56:15.381000 612 torch/_dynamo/convert_frame.py:844] [4/1736] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W0223 05:56:15.383000 612 torch/_dynamo/convert_frame.py:844] [4/1737] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:15.383000 612 torch/_dynamo/convert_frame.py:844] [4/1737]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/

1024x   1024 |      3.122 |      4.281 |    0.73x


W0223 05:56:16.169000 612 torch/_dynamo/convert_frame.py:844] [4/1913] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:16.169000 612 torch/_dynamo/convert_frame.py:844] [4/1913]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py:605)
W0223 05:56:16.169000 612 torch/_dynamo/convert_frame.py:844] [4/1913]    last reason: Unable to find recompilation reasons
W0223 05:56:16.169000 612 torch/_dynamo/convert_frame.py:844] [4/1913] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0223 05:56:16.169000 612 torch/_dynamo/convert_frame.py:844] [4/1913] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W0223 05:56:16.178000 612 torch/_dynamo/convert_frame.py:844] [4/1914] torch._dynamo hit config.accumulated_cache_size_limit (256)
W0223 05:56:16.178000 612 torch/_dynamo/convert_frame.py:844] [4/1914]    function: 'run' (/usr/local/lib/python3.11/dist-packages/triton/runtime/

4096x   4096 |      7.798 |      7.342 |    1.06x


