<a href="https://colab.research.google.com/github/Aryan8912/Unsolth.ai-challenage/blob/main/nf4_with_triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import triton
import triton.language as tl
from dataclasses import dataclass
from typing import Optional, Tuple
import time

@dataclass
class NF4Config:
    CLIP_MIN: int = -8
    CLIP_MAX: int = 7
    DTYPE_MIN: int = 0
    DTYPE_MAX: int = 15

class MemoryFormat:
    CONTIGUOUS = "contiguous"
    CHANNELS_LAST = "channels_last"

@triton.jit
def compute_absmax_kernel(
    input_ptr,
    absmax_ptr,
    num_elements,
    BLOCK_SIZE: tl.constexpr
):
    """Compute absolute maximum values using efficient reduction."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Load and compute absolute values
    x = tl.load(input_ptr + offsets, mask=mask, other=0.0)
    x_abs = tl.abs(x)

    # Perform reduction to find maximum
    block_max = tl.max(x_abs, axis=0)

    # Store result
    tl.store(absmax_ptr + pid, block_max)

@triton.jit
def dequantize_kernel(
    quantized_ptr,
    absmax_ptr,
    double_quant_scale_ptr,
    output_ptr,
    M, N,
    stride_qm, stride_qn,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    MEMORY_FORMAT: tl.constexpr,
    USE_DOUBLE_QUANT: tl.constexpr,
):
    """Dequantize NF4 values with support for double quantization and different memory formats."""
    # Constants for NF4
    NF4_CLIP_MIN = -8
    NF4_CLIP_MAX = 7

    # Program ID for 2D grid
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Calculate start indices
    start_m = pid_m * BLOCK_M
    start_n = pid_n * BLOCK_N

    # Create ranges for the block
    rm = start_m + tl.arange(0, BLOCK_M)
    rn = start_n + tl.arange(0, BLOCK_N)

    # Create masks for valid elements
    mask_m = rm[:, None] < M
    mask_n = rn[None, :] < N
    mask = mask_m & mask_n

    # Shared memory for frequently accessed scales
    scale_cache = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Load quantized values based on memory format
    if MEMORY_FORMAT == 1:  # channels_last
        quantized = tl.load(
            quantized_ptr + rm[:, None] * stride_qn + rn[None, :] * stride_qm,
            mask=mask, other=0
        )
    else:  # contiguous
        quantized = tl.load(
            quantized_ptr + rm[:, None] * stride_qm + rn[None, :] * stride_qn,
            mask=mask, other=0
        )

    # Load and cache absmax values
    absmax = tl.load(absmax_ptr + rm, mask=rm < M, other=1.0)
    scale_cache = absmax / NF4_CLIP_MAX

    # Apply double quantization if enabled
    if USE_DOUBLE_QUANT:
        double_scale = tl.load(double_quant_scale_ptr + rm, mask=rm < M, other=1.0)
        scale_cache = scale_cache * double_scale

    # Dequantize
    dequantized = (quantized - 8) * scale_cache[:, None]

    # Store result
    tl.store(
        output_ptr + rm[:, None] * stride_om + rn[None, :] * stride_on,
        dequantized,
        mask=mask
    )

class NF4Dequantizer:
    def __init__(
        self,
        use_double_quant: bool = True,
        memory_format: str = MemoryFormat.CONTIGUOUS,
        block_size: int = 1024
    ):
        self.use_double_quant = use_double_quant
        self.memory_format = memory_format
        self.block_size = block_size
        self.config = NF4Config()

    @torch.no_grad()
    def compute_absmax(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Compute absolute maximum values for input tensor."""
        num_elements = input_tensor.numel()
        num_blocks = (num_elements + self.block_size - 1) // self.block_size

        absmax = torch.empty(num_blocks, device=input_tensor.device, dtype=torch.float32)

        compute_absmax_kernel[(num_blocks,)](
            input_tensor,
            absmax,
            num_elements,
            BLOCK_SIZE=self.block_size
        )

        return absmax

    @torch.no_grad()
    def dequantize(
        self,
        quantized_tensor: torch.Tensor,
        absmax_tensor: Optional[torch.Tensor] = None,
        double_quant_scale: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Dequantize NF4 tensor to fp16/bf16."""
        # Input validation
        if not torch.is_tensor(quantized_tensor):
            raise TypeError("quantized_tensor must be a torch.Tensor")

        if not quantized_tensor.is_cuda:
            raise ValueError("quantized_tensor must be on CUDA device")

        if torch.any(quantized_tensor < self.config.DTYPE_MIN) or \
           torch.any(quantized_tensor > self.config.DTYPE_MAX):
            raise ValueError(f"Quantized values must be in range [{self.config.DTYPE_MIN}, {self.config.DTYPE_MAX}]")

        # Get tensor dimensions
        M, N = quantized_tensor.shape

        # Compute or validate absmax
        if absmax_tensor is None:
            absmax_tensor = self.compute_absmax(quantized_tensor)

        # Prepare output tensor
        output = torch.empty(
            (M, N),
            device=quantized_tensor.device,
            dtype=torch.float32
        )

        # Define block sizes for the kernel
        BLOCK_M, BLOCK_N = 128, 128

        # Calculate grid dimensions
        grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

        # Convert memory format to integer for kernel
        memory_format_int = 1 if self.memory_format == MemoryFormat.CHANNELS_LAST else 0

        # Launch kernel
        dequantize_kernel[grid](
            quantized_tensor,
            absmax_tensor,
            double_quant_scale if double_quant_scale is not None else absmax_tensor,  # Fallback
            output,
            M, N,
            quantized_tensor.stride(0), quantized_tensor.stride(1),
            output.stride(0), output.stride(1),
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            MEMORY_FORMAT=memory_format_int,
            USE_DOUBLE_QUANT=self.use_double_quant
        )

        return output

def test_dequantizer(shapes: list[Tuple[int, int]] = None):
    """Test the NF4Dequantizer with various matrix shapes."""
    if shapes is None:
        shapes = [
            (10, 10),      # Small square matrix
            (1024, 32),    # Tall matrix
            (32, 1024),    # Wide matrix
            (1, 1024),     # Single row
            (1024, 1),     # Single column
            (256, 256)     # Medium square matrix
        ]

    dequantizer = NF4Dequantizer(use_double_quant=True)

    for M, N in shapes:
        print(f"\nTesting shape: {M}x{N}")

        # Generate test data
        quantized = torch.randint(
            0, 16, (M, N),
            dtype=torch.int32,
            device='cuda'
        )
        absmax = torch.rand(M, device='cuda') * 10
        double_quant_scale = torch.rand(M, device='cuda') * 2

        try:
            # Time the dequantization
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            output = dequantizer.dequantize(
                quantized,
                absmax,
                double_quant_scale
            )
            end_event.record()

            torch.cuda.synchronize()
            elapsed_time = start_event.elapsed_time(end_event)

            print(f"✓ Dequantization successful")
            print(f"✓ Output shape: {output.shape}")
            print(f"✓ Processing time: {elapsed_time:.2f} ms")

            # Basic output validation
            assert not torch.isnan(output).any(), "Output contains NaN values"
            assert not torch.isinf(output).any(), "Output contains Inf values"

        except Exception as e:
            print(f"✗ Test failed: {str(e)}")

if __name__ == "__main__":
    # Run tests
    print("Running NF4 dequantization tests...")
    test_dequantizer()


Running NF4 dequantization tests...

Testing shape: 10x10
✓ Dequantization successful
✓ Output shape: torch.Size([10, 10])
✓ Processing time: 5456.33 ms

Testing shape: 1024x32
✓ Dequantization successful
✓ Output shape: torch.Size([1024, 32])
✓ Processing time: 412.36 ms

Testing shape: 32x1024
✓ Dequantization successful
✓ Output shape: torch.Size([32, 1024])
✓ Processing time: 0.42 ms

Testing shape: 1x1024
✓ Dequantization successful
✓ Output shape: torch.Size([1, 1024])
✓ Processing time: 681.82 ms

Testing shape: 1024x1
✓ Dequantization successful
✓ Output shape: torch.Size([1024, 1])
✓ Processing time: 649.02 ms

Testing shape: 256x256
✓ Dequantization successful
✓ Output shape: torch.Size([256, 256])
✓ Processing time: 0.39 ms
