# Unsloth Puzzles: All Challenges

This notebook contains implementations and benchmarks for Challenges A, B, and C.

In [None]:
!pip install torch>=2.4 transformers peft trl accelerate bitsandbytes unsloth -U --quiet

## Challenge A: NF4 Triton Kernel

In [None]:
import torch
import triton
import triton.language as tl
import time
from unsloth.kernels import fast_dequantize
from bitsandbytes.nn import LinearNF4

@triton.jit
def _your_dequantize_nf4_kernel(
    weight_ptr,
    absmax_ptr,
    code_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Dequantize one BLOCK_SIZE-element tile of packed 4-bit weights.

    Each program handles BLOCK_SIZE consecutive output elements, i.e.
    BLOCK_SIZE // 2 packed bytes (BLOCK_SIZE must therefore be even).

    Args:
        weight_ptr: packed codes, two 4-bit codes per byte.
        absmax_ptr: per-block scales; one scale is read for every 64
            consecutive output elements. The value is multiplied in directly,
            so the caller must pass already-dequantized scales.
        code_ptr: lookup table indexed by the 4-bit code (16 entries are read).
        out_ptr: destination for the n_elements dequantized values.
        n_elements: total number of output elements (2 per packed byte).
        BLOCK_SIZE: compile-time tile size in output elements.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    # One lane per packed byte; every byte expands to two adjacent outputs.
    pair_idx = tl.arange(0, BLOCK_SIZE // 2)
    byte_offsets = (block_start // 2) + pair_idx
    mask = byte_offsets < (n_elements // 2)

    # `other=0` keeps masked-off lanes defined so the table lookups below
    # never index with garbage.
    packed_weights = tl.load(weight_ptr + byte_offsets, mask=mask, other=0)

    # bitsandbytes stores the FIRST element of each pair in the HIGH nibble and
    # the second element in the LOW nibble. The extra `& 0xF` on the shifted
    # value guards against sign extension if the buffer is ever a signed type.
    first_code = ((packed_weights >> 4) & 0xF).to(tl.int32)
    second_code = (packed_weights & 0xF).to(tl.int32)

    # Note: Use tl.load from ptr for indexing
    val_first = tl.load(code_ptr + first_code)
    val_second = tl.load(code_ptr + second_code)

    out_offsets_first = block_start + pair_idx * 2
    out_offsets_second = out_offsets_first + 1

    # One absmax scale per 64 consecutive output elements.
    abs_first = tl.load(absmax_ptr + out_offsets_first // 64, mask=mask)
    abs_second = tl.load(absmax_ptr + out_offsets_second // 64, mask=mask)

    tl.store(out_ptr + out_offsets_first, val_first * abs_first, mask=mask)
    tl.store(out_ptr + out_offsets_second, val_second * abs_second, mask=mask)

def your_dequantize_nf4(linear):
    """Dequantize the packed 4-bit weight of ``linear`` with the Triton kernel.

    Args:
        linear: a module whose ``weight`` carries packed 4-bit data in
            ``weight.data`` and a bitsandbytes ``quant_state`` (``absmax``,
            ``code``, ``shape``, ``dtype`` and, when double quantization is
            on, ``nested`` / ``state2`` / ``offset``).

    Returns:
        A tensor of shape ``quant_state.shape`` and dtype ``quant_state.dtype``
        on the same device as the packed weight.

    Raises:
        ValueError: if the quantization blocksize is not 64 (the kernel
            hard-codes it) or the packed buffer does not hold exactly two
            codes per output element.
    """
    weight = linear.weight.data
    quant_state = linear.weight.quant_state

    # The kernel maps every 64 output elements to one absmax entry.
    blocksize = getattr(quant_state, "blocksize", 64)
    if blocksize != 64:
        raise ValueError(
            f"_your_dequantize_nf4_kernel assumes blocksize 64, got {blocksize}"
        )

    # With double quantization the absmax vector is itself blockwise-quantized
    # and must be restored (dequantize with state2, then add the stored offset)
    # before it can be used as a scale.
    absmax = quant_state.absmax
    if getattr(quant_state, "nested", False):
        from bitsandbytes.functional import dequantize_blockwise

        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        absmax = absmax + quant_state.offset
    absmax = absmax.to(torch.float32).contiguous()

    n_elements = weight.numel() * 2
    out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=weight.device)
    if out.numel() != n_elements:
        # Odd element counts or a non-byte quant storage dtype would make the
        # kernel read or write out of bounds; fail loudly instead.
        raise ValueError(
            f"packed weight holds {n_elements} codes but output has {out.numel()} elements"
        )
    if n_elements == 0:
        return out

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    _your_dequantize_nf4_kernel[grid](
        weight, absmax, quant_state.code, out, n_elements, BLOCK_SIZE=BLOCK_SIZE,
    )
    return out

# Correctness & Speed Check
device = "cuda"
shape = (4096, 4096)
linear = LinearNF4(shape[1], shape[0], bias=False).to(device)
out_ref = fast_dequantize(linear.weight.data, linear.weight.quant_state)
out_custom = your_dequantize_nf4(linear)
print(f"A Correctness: {torch.allclose(out_ref, out_custom, atol=1e-5)}")


def _mean_ms(fn, warmup=3, iters=20):
    """Return the mean wall-clock time of ``fn()`` in milliseconds.

    A few warm-up calls run first so one-off costs (e.g. Triton JIT
    compilation) are not timed. CUDA work is asynchronous, so the device is
    synchronized before starting and before stopping the clock.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1e3 / iters


ref_ms = _mean_ms(lambda: fast_dequantize(linear.weight.data, linear.weight.quant_state))
custom_ms = _mean_ms(lambda: your_dequantize_nf4(linear))
print(f"A Speed: reference {ref_ms:.3f} ms | custom {custom_ms:.3f} ms | speedup {ref_ms / custom_ms:.2f}x")

## Challenge B: FSDP2 + QLoRA

In [None]:
import os
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

# NOTE(review): CUDA_VISIBLE_DEVICES is only honored when set before CUDA is
# initialized in the process. The Challenge A cell above already runs CUDA work,
# so on a Run All this assignment likely has no effect — confirm, or set it in
# the launcher environment instead.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
model_name = "unsloth/llama-3-8b-bnb-4bit"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load each process's copy onto its own GPU. Distributed launchers export
# LOCAL_RANK; without one this falls back to GPU 0 (single-process behaviour).
local_rank = int(os.environ.get("LOCAL_RANK", 0))
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": local_rank},
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(r=64, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)

dataset = load_dataset("philschmid/dolly-15k-llama-3-format", split="train[:100]")
# NOTE(review): FSDP presumably requires a distributed launch (one process per
# GPU via `accelerate launch` / `torchrun`); confirm how this cell is run.
training_args = SFTConfig(
    output_dir="./outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    max_steps=5,
    bf16=True,
    dataset_text_field="text",
    max_seq_length=128,
    # CPU offload is requested with the "offload" option of `fsdp`; an
    # "offload_params" key inside `fsdp_config` is not read by the Trainer.
    fsdp="full_shard auto_wrap offload",
    fsdp_config={
        "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
        "activation_checkpointing": True,
    },
)
trainer = SFTTrainer(model=model, train_dataset=dataset, args=training_args)
trainer.train()
print("B Completed!")

## Challenge C: torch.compile

In [None]:
import torch._dynamo
@torch.compile(dynamic=True, fullgraph=False)
def compiled_forward(model, input_ids, labels):
    """Run ``model`` on ``input_ids``/``labels`` and return the ``loss`` of its output.

    Compiled with dynamic shapes; graph breaks are tolerated (``fullgraph=False``).
    """
    outputs = model(input_ids=input_ids, labels=labels)
    return outputs.loss

# Random token ids, reused as both inputs and labels, to drive one forward pass.
input_ids = torch.randint(0, 32000, (1, 128)).cuda()
print("Checking C graph breaks...")
# explain(f)(*args) is the supported call form; passing the arguments directly
# to explain(f, *args) is deprecated and emits a FutureWarning.
explanation = torch._dynamo.explain(compiled_forward)(model, input_ids, input_ids)
print(f"Graph breaks: {explanation.graph_break_count}")
print("C Completed!")