<a href="https://colab.research.google.com/github/Aryan8912/Unsolth.ai-challenage/blob/main/Convert_nf4_to_Triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Code to install Unsloth, Triton, Torch etc
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

In [2]:
import torch
import triton
from triton import language as tl
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorWithPadding,
)
import bitsandbytes as bnb

In [3]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = torch.bfloat16
)

In [15]:
import triton
import triton.language as tl

@triton.jit
def compute_absmax_kernel(
    input_ptr, absmax_ptr, num_elements, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements
    data = tl.load(input_ptr + offsets, mask=mask, other=0.0)
    abs_data = tl.abs(data)
    block_absmax = tl.max(abs_data, axis=0)
    tl.store(absmax_ptr + pid, block_absmax)


In [11]:
import torch

BLOCK_SIZE = 1024

input_tensor = torch.randn(10, 10).cuda()


num_elements = input_tensor.numel()
num_blocks = (num_elements + BLOCK_SIZE - 1) // BLOCK_SIZE


absmax_tensor = torch.empty(num_blocks, device=input_tensor.device, dtype=torch.float32)


compute_absmax_kernel[(num_blocks,)](
    input_tensor, absmax_tensor, num_elements, BLOCK_SIZE=BLOCK_SIZE
)

<triton.compiler.compiler.CompiledKernel at 0x7c8127bd1ad0>

In [16]:
@triton.jit
def dequantize_kernel(
    quantized_ptr, absmax_ptr, output_ptr, M, N,
    stride_qm, stride_qn, stride_om, stride_on,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    start_m = pid_m * BLOCK_M
    start_n = pid_n * BLOCK_N
    rm = start_m + tl.arange(0, BLOCK_M)
    rn = start_n + tl.arange(0, BLOCK_N)
    quantized = tl.load(
        quantized_ptr + rm[:, None] * stride_qm + rn[None, :] * stride_qn,
        mask=(rm[:, None] < M) & (rn[None, :] < N), other=0
    )
    absmax = tl.load(absmax_ptr + rm, mask=rm < M, other=1.0)
    scale = absmax[:, None] / 7.0
    dequantized = (quantized - 8) * scale
    tl.store(
        output_ptr + rm[:, None] * stride_om + rn[None, :] * stride_on,
        dequantized,
        mask=(rm[:, None] < M) & (rn[None, :] < N)
    )


In [17]:
# Define block sizes
BLOCK_SIZE = 1024
BLOCK_M, BLOCK_N = 128, 128

# Using torch.compile to optimize performance
@torch.compile
def main():
    # Input tensor
    input_tensor = torch.randn(10, 10).cuda()

    # Compute absmax
    num_elements = input_tensor.numel()
    num_blocks = (num_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
    absmax_tensor = torch.empty(num_blocks, device=input_tensor.device, dtype=torch.float32)
    compute_absmax_kernel[(num_blocks,)](input_tensor, absmax_tensor, num_elements, BLOCK_SIZE=BLOCK_SIZE)

    # Quantized tensor
    quantized_tensor = torch.randint(0, 16, (10, 10), dtype=torch.int32, device='cuda')
    M, N = quantized_tensor.shape
    dequantized_tensor = torch.empty((M, N), device=quantized_tensor.device, dtype=torch.float32)

    # Dequantize
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    dequantize_kernel[grid](
        quantized_tensor, absmax_tensor, dequantized_tensor, M, N,
        quantized_tensor.stride(0), quantized_tensor.stride(1),
        dequantized_tensor.stride(0), dequantized_tensor.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
    )

    print("Dequantized tensor:", dequantized_tensor)

# Execute the main function
if __name__ == "__main__":
    main()


Dequantized tensor: tensor([[ 7.7998e-01,  2.3399e+00, -7.7998e-01,  1.5600e+00, -2.3399e+00,
          1.9499e+00,  1.5600e+00,  7.7998e-01,  3.8999e-01, -3.8999e-01],
        [ 5.0262e-01, -5.0262e-01,  1.7592e+00, -1.0052e+00, -5.0262e-01,
          1.5078e+00,  2.5131e-01,  0.0000e+00,  0.0000e+00, -1.5078e+00],
        [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-1.9591e+00, -4.8977e-01, -2.4489e-01,  2.4489e-01,  1.4693e+00,
         -1.9591e+00,  0.0000e+00, -2.4489e-01, -1.9591e+00,  1.2244e+00],
        [ 1.0541e+19,  2.1082e+19, -5.2705e+18, -2.1082e+19, -4.2164e+19,
          2.1082e+19, -5.2705e+18, -2.6352e+19,  3.1623e+19,  0.0000e+00],
        [-1.2118e+00, -2.4235e-01,  7.2706e-01, -9.6942e-01, -7.2706e-01,
         -1.4541e+00, -1.2118e+00, -1.9388e+00, -1.9388e+00,  1.6965e+00],
        [-1.0541e+19, -2.6352e+19,  3.1623e+19,  2.6352e+19, -3.6893e+19,
         -1.

In [18]:
import torch

def reference_dequantize(quantized_tensor, absmax_tensor):
    # Ensure the tensors are on the same device
    device = quantized_tensor.device
    absmax_tensor = absmax_tensor.to(device)

    # Calculate the scale factor
    scale = absmax_tensor[:, None] / 7.0  # Assuming 4-bit quantization ranges from -7 to 7

    # Dequantize the tensor
    dequantized = (quantized_tensor - 8) * scale  # Map [0, 15] to [-7, 7] and scale
    return dequantized


In [19]:
import torch
import triton
import triton.language as tl

def test_dequantize_function():
    # Define tensor dimensions
    M, N = 10, 10  # You can adjust these dimensions as needed

    # Generate random quantized tensor and absmax tensor
    quantized_tensor = torch.randint(0, 16, (M, N), dtype=torch.int32, device='cuda')
    absmax_tensor = torch.rand(M, device='cuda') * 10  # Random absmax values between 0 and 10

    # Allocate output tensor for Triton kernel
    dequantized_tensor = torch.empty((M, N), device='cuda', dtype=torch.float32)

    # Define block sizes
    BLOCK_M, BLOCK_N = 128, 128

    # Launch Triton kernel
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    dequantize_kernel[grid](
        quantized_tensor, absmax_tensor, dequantized_tensor,
        M, N,
        quantized_tensor.stride(0), quantized_tensor.stride(1),
        dequantized_tensor.stride(0), dequantized_tensor.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
    )

    # Run reference dequantization
    reference_output = reference_dequantize(quantized_tensor, absmax_tensor)

    # Compare the outputs
    if torch.allclose(dequantized_tensor, reference_output, atol=1e-6):
        print("Test passed: Triton kernel output matches reference implementation.")
    else:
        print("Test failed: Outputs do not match.")
        print("Triton output:", dequantized_tensor)
        print("Reference output:", reference_output)

# Ensure the Triton kernel is defined before running this test
test_dequantize_function()


Test passed: Triton kernel output matches reference implementation.
