<a href="https://colab.research.google.com/github/Datbwoyyy/SlothAi/blob/main/CONVERT_NF4_Quantized_Tensor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##**Convert** a nf4 quantized tensor into fp16 or bf16 into a **single Triton kernel**

In [7]:
import triton
import triton.language as tl
import torch
import time

@triton.jit
def nf4_dequant_kernel(weight_ptr, absmax_ptr, out_ptr,
                       M: int, N: int, stride_row: int, stride_col: int,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Compute row and column indices of the block this program instance handles
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Use tl.arange instead of tl.static_range for creating ranges that can be used in arithmetic operations
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # row indices
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # column indices

    # Create a 2D mask to avoid out–of–bounds accesses
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    # Load a tile of nf4 weights
    weight_tile_ptr = weight_ptr + offs_m[:, None] * stride_row + offs_n[None, :] * stride_col
    w_q = tl.load(weight_tile_ptr, mask=mask, other=8)

    # Convert [0,15] range to [-8,7] range
    w_centered = tl.cast(w_q, tl.int32) - 8

    # Load per-row absmax scaling factors
    absmax = tl.load(absmax_ptr + offs_m, mask=(offs_m < M), other=0.0)
    absmax_fp16 = tl.cast(absmax, tl.float16)

    # Convert weights to fp16 and apply scaling
    w_centered_fp16 = tl.cast(w_centered, tl.float16)
    deq = (w_centered_fp16 * absmax_fp16[:, None]) / 7.0

    # Store dequantized tile
    out_ptr_tile = out_ptr + offs_m[:, None] * stride_row + offs_n[None, :] * stride_col
    tl.store(out_ptr_tile, deq, mask=mask)



# Wrapper function
def nf4_to_fp16(nf4, absmax):
    M, N = nf4.shape
    out = torch.empty((M, N), device=nf4.device, dtype=torch.float16)
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
    nf4_dequant_kernel[grid](
        nf4, absmax, out,
        M, N,
        nf4.stride(0), nf4.stride(1),
        BLOCK_M=64, BLOCK_N=64
    )
    return out


# Testing the fix
def test_dequantize_function():
    M, N = 1024, 1024
    nf4 = torch.randint(0, 16, (M, N), device='cuda', dtype=torch.int8)
    absmax = (torch.rand(M, device='cuda', dtype=torch.float16) * 0.9 + 0.1)

    torch.cuda.synchronize()
    out_triton = nf4_to_fp16(nf4, absmax)
    torch.cuda.synchronize()
    out_ref = ((nf4.to(torch.int32) - 8).to(torch.float16) * absmax.unsqueeze(1)) / 7.0

    print("Results match:", torch.allclose(out_triton, out_ref, atol=1e-2))

    # Benchmark
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(10):
        nf4_to_fp16(nf4, absmax)
    torch.cuda.synchronize()
    triton_time = (time.time() - t0) / 10.0

    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(10):
        out_ref = ((nf4.to(torch.int32) - 8).to(torch.float16) * absmax.unsqueeze(1)) / 7.0
    torch.cuda.synchronize()
    ref_time = (time.time() - t0) / 10.0

    print(f"Triton Speedup: {ref_time / triton_time:.2f}x")

# Run test
test_dequantize_function()


Results match: True
Triton Speedup: 1.03x


In [3]:
!pip install --upgrade triton

Collecting triton
  Downloading triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (253.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.2/253.2 MB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton
  Attempting uninstall: triton
    Found existing installation: triton 2.0.0
    Uninstalling triton-2.0.0:
      Successfully uninstalled triton-2.0.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.0.1 requires triton==2.0.0; platform_system == "Linux" and platform_machine == "x86_64", but you have triton 3.2.0 which is incompatible.
torchaudio 2.5.1+cu124 requires torch==2.5.1, but you have torch 2.0.1 which is incompatible.
torchvision 0.20.1+cu124 requires torch==2.5.1,

## IMPROVED SPEED TO 1.51X

In [4]:
import triton
import triton.language as tl
import torch
import time

@triton.jit
def nf4_dequant_kernel(weight_ptr, absmax_ptr, out_ptr,
                       M: int, N: int, stride_row: int, stride_col: int,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    weight_tile_ptr = weight_ptr + offs_m[:, None] * stride_row + offs_n[None, :] * stride_col
    w_q = tl.load(weight_tile_ptr, mask=mask, other=8)

    w_centered = tl.cast(w_q, tl.int32) - 8

    absmax = tl.load(absmax_ptr + offs_m, mask=(offs_m < M), other=0.0)
    absmax_fp16 = tl.cast(absmax, tl.float16)

    w_centered_fp16 = tl.cast(w_centered, tl.float16)
    deq = (w_centered_fp16 * absmax_fp16[:, None]) / 7.0

    out_ptr_tile = out_ptr + offs_m[:, None] * stride_row + offs_n[None, :] * stride_col
    tl.store(out_ptr_tile, deq, mask=mask)

def nf4_to_fp16(nf4, absmax):
    M, N = nf4.shape
    out = torch.empty((M, N), device=nf4.device, dtype=torch.float16)
    grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))  # Adjusted block sizes
    nf4_dequant_kernel[grid](
        nf4, absmax, out,
        M, N,
        nf4.stride(0), nf4.stride(1),
        BLOCK_M=128, BLOCK_N=128  # Adjusted block sizes
    )
    return out

def test_dequantize_function():
    M, N = 1024, 1024
    nf4 = torch.randint(0, 16, (M, N), device='cuda', dtype=torch.int8)
    absmax = (torch.rand(M, device='cuda', dtype=torch.float16) * 0.9 + 0.1)

    torch.cuda.synchronize()
    out_triton = nf4_to_fp16(nf4, absmax)
    torch.cuda.synchronize()
    out_ref = ((nf4.to(torch.int32) - 8).to(torch.float16) * absmax.unsqueeze(1)) / 7.0

    print("Results match:", torch.allclose(out_triton, out_ref, atol=1e-2))

    # Benchmark
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(10):
        nf4_to_fp16(nf4, absmax)
    torch.cuda.synchronize()
    triton_time = (time.time() - t0) / 10.0

    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(10):
        out_ref = ((nf4.to(torch.int32) - 8).to(torch.float16) * absmax.unsqueeze(1)) / 7.0
    torch.cuda.synchronize()
    ref_time = (time.time() - t0) / 10.0

    print(f"Triton Speedup: {ref_time / triton_time:.2f}x")

# Run test
test_dequantize_function()

Results match: True
Triton Speedup: 1.51x
