<a href="https://colab.research.google.com/github/Aryan8912/Unsolth.ai-challenage/blob/main/nf4_with_triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import triton
import triton.language as tl
from dataclasses import dataclass
from typing import Optional, Tuple
import time

@dataclass
class NF4Config:
    CLIP_MIN: int = -8
    CLIP_MAX: int = 7
    DTYPE_MIN: int = 0
    DTYPE_MAX: int = 15

class MemoryFormat:
    CONTIGUOUS = "contiguous"
    CHANNELS_LAST = "channels_last"

@triton.jit
def compute_absmax_kernel(
    input_ptr,
    absmax_ptr,
    num_elements,
    BLOCK_SIZE: tl.constexpr
):
    """Compute absolute maximum values using efficient reduction."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Load and compute absolute values
    x = tl.load(input_ptr + offsets, mask=mask, other=0.0)
    x_abs = tl.abs(x)

    # Perform reduction to find maximum
    block_max = tl.max(x_abs, axis=0)

    # Store result
    tl.store(absmax_ptr + pid, block_max)

@triton.jit
def dequantize_kernel(
    quantized_ptr,
    absmax_ptr,
    double_quant_scale_ptr,
    output_ptr,
    M, N,
    stride_qm, stride_qn,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    MEMORY_FORMAT: tl.constexpr,
    USE_DOUBLE_QUANT: tl.constexpr,
):
    """Dequantize NF4 values with support for double quantization and different memory formats."""
    # Constants for NF4
    NF4_CLIP_MIN = -8
    NF4_CLIP_MAX = 7

    # Program ID for 2D grid
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Calculate start indices
    start_m = pid_m * BLOCK_M
    start_n = pid_n * BLOCK_N

    # Create ranges for the block
    rm = start_m + tl.arange(0, BLOCK_M)
    rn = start_n + tl.arange(0, BLOCK_N)

    # Create masks for valid elements
    mask_m = rm[:, None] < M
    mask_n = rn[None, :] < N
    mask = mask_m & mask_n

    # Shared memory for frequently accessed scales
    scale_cache = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Load quantized values based on memory format
    if MEMORY_FORMAT == 1:  # channels_last
        quantized = tl.load(
            quantized_ptr + rm[:, None] * stride_qn + rn[None, :] * stride_qm,
            mask=mask, other=0
        )
    else:  # contiguous
        quantized = tl.load(
            quantized_ptr + rm[:, None] * stride_qm + rn[None, :] * stride_qn,
            mask=mask, other=0
        )

    # Load and cache absmax values
    absmax = tl.load(absmax_ptr + rm, mask=rm < M, other=1.0)
    scale_cache = absmax / NF4_CLIP_MAX

    # Apply double quantization if enabled
    if USE_DOUBLE_QUANT:
        double_scale = tl.load(double_quant_scale_ptr + rm, mask=rm < M, other=1.0)
        scale_cache = scale_cache * double_scale

    # Dequantize
    dequantized = (quantized - 8) * scale_cache[:, None]

    # Store result
    tl.store(
        output_ptr + rm[:, None] * stride_om + rn[None, :] * stride_on,
        dequantized,
        mask=mask
    )

class NF4Dequantizer:
    def __init__(
        self,
        use_double_quant: bool = True,
        memory_format: str = MemoryFormat.CONTIGUOUS,
        block_size: int = 1024,
        compile_mode: str = "reduce-overhead"
    ):
        self.use_double_quant = use_double_quant
        self.memory_format = memory_format
        self.block_size = block_size
        self.config = NF4Config()

        # Compile the compute methods
        self._compute_absmax_compiled = torch.compile(
            self._compute_absmax_impl,
            mode=compile_mode,
            fullgraph=True
        )
        self._dequantize_compiled = torch.compile(
            self._dequantize_impl,
            mode=compile_mode,
            fullgraph=True
        )

    @torch.no_grad()
    def _compute_absmax_impl(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Implementation of absmax computation."""
        num_elements = input_tensor.numel()
        num_blocks = (num_elements + self.block_size - 1) // self.block_size

        absmax = torch.empty(num_blocks, device=input_tensor.device, dtype=torch.float32)

        compute_absmax_kernel[(num_blocks,)](
            input_tensor,
            absmax,
            num_elements,
            BLOCK_SIZE=self.block_size
        )

        return absmax

    @torch.no_grad()
    def compute_absmax(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Compute absolute maximum values for input tensor using compiled implementation."""
        return self._compute_absmax_compiled(input_tensor)

    @torch.no_grad()
    def _dequantize_impl(
        self,
        quantized_tensor: torch.Tensor,
        absmax_tensor: torch.Tensor,
        double_quant_scale: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Implementation of dequantization."""
        M, N = quantized_tensor.shape

        output = torch.empty(
            (M, N),
            device=quantized_tensor.device,
            dtype=torch.float32
        )

        BLOCK_M, BLOCK_N = 128, 128
        grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
        memory_format_int = 1 if self.memory_format == MemoryFormat.CHANNELS_LAST else 0

        dequantize_kernel[grid](
            quantized_tensor,
            absmax_tensor,
            double_quant_scale if double_quant_scale is not None else absmax_tensor,
            output,
            M, N,
            quantized_tensor.stride(0), quantized_tensor.stride(1),
            output.stride(0), output.stride(1),
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            MEMORY_FORMAT=memory_format_int,
            USE_DOUBLE_QUANT=self.use_double_quant
        )

        return output

    @torch.no_grad()
    def dequantize(
        self,
        quantized_tensor: torch.Tensor,
        absmax_tensor: Optional[torch.Tensor] = None,
        double_quant_scale: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Dequantize NF4 tensor to fp16/bf16 using compiled implementation."""
        # Input validation
        if not torch.is_tensor(quantized_tensor):
            raise TypeError("quantized_tensor must be a torch.Tensor")

        if not quantized_tensor.is_cuda:
            raise ValueError("quantized_tensor must be on CUDA device")

        if torch.any(quantized_tensor < self.config.DTYPE_MIN) or \
           torch.any(quantized_tensor > self.config.DTYPE_MAX):
            raise ValueError(f"Quantized values must be in range [{self.config.DTYPE_MIN}, {self.config.DTYPE_MAX}]")

        # Compute absmax if not provided
        if absmax_tensor is None:
            absmax_tensor = self.compute_absmax(quantized_tensor)

        return self._dequantize_compiled(quantized_tensor, absmax_tensor, double_quant_scale)

def benchmark_dequantizer(
    shapes: list[Tuple[int, int]] = None,
    num_warmup: int = 5,
    num_runs: int = 100
):
    """Benchmark the NF4Dequantizer with various matrix shapes."""
    if shapes is None:
        shapes = [
            (10, 10),      # Small square matrix
            (1024, 32),    # Tall matrix
            (32, 1024),    # Wide matrix
            (1024, 1024),  # Large square matrix
            (4096, 4096),  # Very large square matrix
        ]

    # Create dequantizers - one with compilation and one without
    dequantizer_compiled = NF4Dequantizer(use_double_quant=True, compile_mode="reduce-overhead")
    dequantizer_normal = NF4Dequantizer(use_double_quant=True)

    print("\nBenchmarking NF4 Dequantization:")
    print("=" * 60)
    print(f"{'Shape':>12} | {'Normal (ms)':>12} | {'Compiled (ms)':>12} | {'Speedup':>8}")
    print("-" * 60)

    for M, N in shapes:
        # Generate test data
        quantized = torch.randint(0, 16, (M, N), dtype=torch.int32, device='cuda')
        absmax = torch.rand(M, device='cuda') * 10
        double_quant_scale = torch.rand(M, device='cuda') * 2

        # Warmup
        for _ in range(num_warmup):
            _ = dequantizer_normal.dequantize(quantized, absmax, double_quant_scale)
            _ = dequantizer_compiled.dequantize(quantized, absmax, double_quant_scale)

        # Benchmark normal version
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(num_runs):
            _ = dequantizer_normal.dequantize(quantized, absmax, double_quant_scale)
        end.record()
        torch.cuda.synchronize()
        normal_time = start.elapsed_time(end) / num_runs

        # Benchmark compiled version
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(num_runs):
            _ = dequantizer_compiled.dequantize(quantized, absmax, double_quant_scale)
        end.record()
        torch.cuda.synchronize()
        compiled_time = start.elapsed_time(end) / num_runs

        # Calculate speedup
        speedup = normal_time / compiled_time

        print(f"{M}x{N:>7} | {normal_time:>10.3f} | {compiled_time:>10.3f} | {speedup:>7.2f}x")

    print("=" * 60)

if __name__ == "__main__":
    print("Running NF4 dequantization benchmarks...")
    benchmark_dequantizer()


Running NF4 dequantization benchmarks...

Benchmarking NF4 Dequantization:
       Shape |  Normal (ms) | Compiled (ms) |  Speedup
------------------------------------------------------------
10x     10 |      0.365 |      0.386 |    0.95x
1024x     32 |      0.346 |      0.326 |    1.06x
32x   1024 |      0.358 |      0.455 |    0.79x
1024x   1024 |      0.385 |      0.361 |    1.07x
4096x   4096 |      2.284 |      2.232 |    1.02x
