In [1]:
# Install required packages
!pip install -q transformers peft datasets bitsandbytes
!pip install -q accelerate triton trl
!pip install -q torch torchvision

In [2]:
import torch
import os
torch_compile_options = torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : True,
    "triton.cudagraphs" : False,
}

@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def compiled_llama_mlp(self, x):
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    return down_proj

import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaMLP.forward = compiled_llama_mlp

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \
    "expandable_segments:True,"\
    "roundup_power2_divisions:[32:256,64:128,256:64,>:32]"

max_seq_length = 1024
torch.set_default_dtype(torch.float16)
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
dtype = torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit              = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type       = "nf4",
    bnb_4bit_compute_dtype    = dtype,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map = "auto",
    attn_implementation = "sdpa",
    quantization_config = bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

lora_config = LoraConfig(
    r = 32,
    lora_alpha = 64,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_dropout = 0,
    bias = "none",
    task_type = TaskType.CAUSAL_LM,
)

# Get LoRA and setup model
model = get_peft_model(model, lora_config)
with torch.no_grad():
    for name, param in model.named_parameters():
        if ".lora_A." in name or ".lora_B." in name: param.requires_grad_(True)
        else: param.requires_grad_(False)

# Currently GC will cause torch.compile to be disabled, so disable it
# model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Get dataset
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train[:10%]")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
# Must show all graph breaks are not seen with torch.compile
import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

import logging
torch._inductor.config.debug = True
torch._logging.set_logs(
    dynamo = logging.WARN,
    inductor = logging.WARN,
    graph_breaks = True,
    recompiles = True,
    recompiles_verbose = True,
    compiled_autograd_verbose = True,
    # aot_joint_graph = True, # Enable for more logs
    # aot_graphs = True,
)
torch._dynamo.config.verbose = True
torch._dynamo.config.suppress_errors = False

This is the uncompiled one on A100 GPU

In [5]:
# -----------------------------------------------------------------------------
# Define helper function to log VRAM usage
# -----------------------------------------------------------------------------
def log_vram(stage: str):
    """Prints out the current VRAM usage (allocated and reserved) in MB."""
    allocated = torch.cuda.memory_allocated() / (1024 ** 2)
    reserved = torch.cuda.memory_reserved() / (1024 ** 2)
    print(f"{stage} – VRAM allocated: {allocated:.2f} MB, reserved: {reserved:.2f} MB")

# Assume other parts of your code (imports, model setup, etc.) are already defined.

# -----------------------------------------------------------------------------
# Uncompiled Branch: Trainer Setup & Training
# -----------------------------------------------------------------------------

# (Optional) Set up your dataset and tokenizer as you normally do.
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

# Example dataset URL (use your actual dataset URL if different)
url_data = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files={"train": url_data}, split="train[:10%]")

# Setup SFTTrainer with uncompiled settings
trainer = SFTTrainer(
    model=model,  # Your uncompiled model (or you can disable torch.compile for this run)
    train_dataset=dataset,
    processing_class=tokenizer,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        warmup_steps=1,
        max_steps=10,
        logging_steps=1,
        output_dir="outputs",
        seed=3407,
        max_seq_length=max_seq_length,
        fp16=(model.get_input_embeddings().weight.dtype == torch.float16),
        bf16=(model.get_input_embeddings().weight.dtype == torch.bfloat16),
        report_to="none",  # e.g., disable reporting for W&B
        dataset_num_proc=4,
    ),
)

# Log VRAM usage before training (for the uncompiled branch)
log_vram("Before Training (Uncompiled)")

# Run the training loop
trainer.train()

# Log VRAM usage after training (for the uncompiled branch)
log_vram("After Training (Uncompiled)")


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Before Training (Uncompiled) – VRAM allocated: 1066.08 MB, reserved: 1548.00 MB


V0611 16:20:25.328000 39833 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks] Graph break in user code at /usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py:496
V0611 16:20:25.328000 39833 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks] Reason: Unsupported: call_method UserDefinedObjectVariable(Params4bit) t [] {}
V0611 16:20:25.328000 39833 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks] User code traceback:
V0611 16:20:25.328000 39833 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks]   File "<ipython-input-2-3830617507>", line 13, in compiled_llama_mlp
V0611 16:20:25.328000 39833 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks]     down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
V0611 16:20:25.328000 39833 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks]   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/lora/bnb.py", line 494, in forward
V0611 16:20:25.

Step,Training Loss
1,1.5384
2,2.3972
3,2.5093
4,3.5343
5,2.1524
6,2.9821
7,2.2626
8,1.6414
9,2.2215
10,2.7038


After Training (Uncompiled) – VRAM allocated: 1254.33 MB, reserved: 4004.00 MB


In [11]:
# ===========================================================================
# 1. Imports
# ===========================================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
import triton
import triton.language as tl
import math
import os
import logging
import time
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
# *** CORRECTED IMPORTS FOR PATCHING ***
from transformers.models.llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaMLP,
    LlamaRMSNorm,
)
from peft import get_peft_model, LoraConfig, TaskType
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
import bitsandbytes.nn as bnb_nn
from bitsandbytes.nn.modules import Linear4bit

# ===========================================================================
# 2. Environment & Logging Configuration
# Set environment variables for detailed graph break and re-compilation logs.
# ===========================================================================
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
# Set to 0 to disable recompilation logs if they become too noisy
os.environ["TORCHDYNAMO_REPRO_AFTER"] = "0"
torch._logging.set_logs(
    dynamo=logging.INFO,
    aot=logging.INFO,
    inductor=logging.INFO,
    graph_breaks=True,
    recompiles=True,
    recompiles_verbose=True,
)
torch._dynamo.config.verbose = True
torch._dynamo.config.suppress_errors = False

# ===========================================================================
# 3. Task A: Triton Kernel for NF4 Dequantization
# This kernel and its launcher are from your puzzle_1.ipynb solution.
# It is torch.compile-compatible and will replace the bnb forward pass.
# ===========================================================================
@triton.jit
def _your_dequantize_nf4_kernel(
    w_ptr, abs_idx_ptr, offset_ptr, abs2_scales_ptr, code2_ptr, nf4_code_ptr, output_ptr,
    TOTAL_ELEMENTS_IN_OUTPUT_TENSOR: tl.constexpr,
    BLOCK_SIZE_BYTES_PER_CHUNK: tl.constexpr,
    BLOCK_SIZE_ELEMENTS_PER_CHUNK: tl.constexpr,
    NUM_GROUPS_PER_CHUNK: tl.constexpr,
    ELEMENTS_PER_GROUP_CONST: tl.constexpr,
    LOG2_L2_BLOCK_SIZE_CONST_KERNEL: tl.constexpr,
    gsize_num_chunks: tl.constexpr,
):
    pid = tl.program_id(0)
    if pid >= gsize_num_chunks: return

    log2_l2_block_size = LOG2_L2_BLOCK_SIZE_CONST_KERNEL
    chunk_element_start_offset = pid * BLOCK_SIZE_ELEMENTS_PER_CHUNK
    group_arange_local = tl.arange(0, NUM_GROUPS_PER_CHUNK)
    absmax_group_indices_potential = (chunk_element_start_offset // ELEMENTS_PER_GROUP_CONST) + group_arange_local
    group_mask = (absmax_group_indices_potential * ELEMENTS_PER_GROUP_CONST) < TOTAL_ELEMENTS_IN_OUTPUT_TENSOR

    quantized_absmax_indices = tl.load(abs_idx_ptr + absmax_group_indices_potential, mask=group_mask, other=0, eviction_policy="evict_first")
    dequantized_l1_scales = tl.load(code2_ptr + quantized_absmax_indices.to(tl.int32), mask=group_mask, other=0.0, eviction_policy="evict_last")

    absmax_l2_group_indices_potential = absmax_group_indices_potential >> log2_l2_block_size
    l2_scales = tl.load(abs2_scales_ptr + absmax_l2_group_indices_potential, mask=group_mask, other=0.0, eviction_policy="evict_last")

    offset_val = tl.load(offset_ptr + 0)
    final_group_scales_masked = l2_scales * dequantized_l1_scales + offset_val

    element_arange_pid_local = tl.arange(0, BLOCK_SIZE_ELEMENTS_PER_CHUNK)
    global_element_indices = chunk_element_start_offset + element_arange_pid_local
    element_op_mask = global_element_indices < TOTAL_ELEMENTS_IN_OUTPUT_TENSOR

    byte_index_global_for_element = global_element_indices // 2
    is_low_nibble_flag = (global_element_indices % 2) != 0

    packed_byte_for_element = tl.load(w_ptr + byte_index_global_for_element, mask=element_op_mask, other=0)
    nibble_shift = tl.where(is_low_nibble_flag, 0, 4)
    quantized_idx_for_element = (packed_byte_for_element >> nibble_shift) & 0x0F

    dequant_val_for_element = tl.load(nf4_code_ptr + quantized_idx_for_element.to(tl.int32), mask=element_op_mask, other=0.0)

    scales_reshaped_for_broadcast = tl.reshape(final_group_scales_masked, (NUM_GROUPS_PER_CHUNK, 1))
    scales_broadcasted_to_elements = tl.broadcast_to(scales_reshaped_for_broadcast, (NUM_GROUPS_PER_CHUNK, ELEMENTS_PER_GROUP_CONST))
    element_scales_vector = tl.reshape(scales_broadcasted_to_elements, (BLOCK_SIZE_ELEMENTS_PER_CHUNK,))
    final_scale_for_element = tl.where(element_op_mask, element_scales_vector, 0.0)

    scaled_element_output = dequant_val_for_element * final_scale_for_element
    tl.store(output_ptr + global_element_indices, scaled_element_output, mask=element_op_mask)

def _your_dequantize_nf4(weight_data, quant_state):
    device = weight_data.device
    output_shape = quant_state.shape
    total_elements_in_output = output_shape.numel()

    # Kernel launch parameters
    OPTIMIZED_BR = 4096
    OPTIMIZED_WARPS = 4
    OPTIMIZED_STAGES = 2

    output_tensor = torch.empty(output_shape, dtype=quant_state.dtype, device=device)

    grid = lambda META: (triton.cdiv(total_elements_in_output, META['BLOCK_SIZE_ELEMENTS_PER_CHUNK']),)

    _your_dequantize_nf4_kernel[grid](
        weight_data, quant_state.absmax, quant_state.offset,
        quant_state.state2.absmax, quant_state.state2.code, quant_state.code,
        output_tensor,
        TOTAL_ELEMENTS_IN_OUTPUT_TENSOR=total_elements_in_output,
        BLOCK_SIZE_ELEMENTS_PER_CHUNK=OPTIMIZED_BR,
        ELEMENTS_PER_GROUP_CONST=quant_state.blocksize,
        LOG2_L2_BLOCK_SIZE_CONST_KERNEL=int(math.log2(quant_state.state2.blocksize)),
        num_warps=OPTIMIZED_WARPS,
        num_stages=OPTIMIZED_STAGES
    )
    return output_tensor

def your_dequantize_nf4(weight_param):
    return _your_dequantize_nf4(weight_param.weight.data, weight_param.weight.quant_state)


# ===========================================================================
# 4. Compiled Forward Passes & Monkey-Patching
# We define our custom compiled functions and then patch them into the
# original library classes. This happens *before* the model is loaded.
# ===========================================================================

# ---- Patch 1: BitsAndBytes Linear4bit (The main fix for graph breaks) ----
@torch.compile(fullgraph=True)
def compiled_bnb_forward(self, x):
    # Dequantize using the Triton kernel from Task A
    dequantized_weight = your_dequantize_nf4(self)
    # Perform standard linear operation
    output = F.linear(x, dequantized_weight, self.bias)
    return output

# ---- Patch 2: LlamaMLP (as in the original notebook) ----
@torch.compile(fullgraph=True)
def compiled_llama_mlp_forward(self, x):
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

# ---- Patch 3: LlamaRMSNorm (Fixing reviewer feedback) ----
# Using rsqrt for better numerics and ensuring it's patched.
@torch.compile(fullgraph=True)
def compiled_rmsnorm_forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    # Using rsqrt as suggested
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return (self.weight * hidden_states).to(input_dtype)

# ---- Patch 4: LlamaForCausalLM (Fixing reviewer feedback) ----
# This patches the entire model's forward pass to correctly upcast logits.
@torch.compile(fullgraph=False, dynamic=True)
def compiled_causal_lm_forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    # This is a simplified wrapper around the original forward pass
    # It ensures full compilation of the core logic while allowing flexibility.

    # We call the original forward method of the *base* LlamaForCausalLM class
    # to avoid infinite recursion after patching.
    outputs = LlamaForCausalLM.forward(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=None, # Pass labels as None to prevent internal loss calculation
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    logits = outputs.logits
    loss = None
    if labels is not None:
        # **Reviewer Fix:** Upcast logits to float32 before loss calculation
        logits_for_loss = logits.float()
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits_for_loss.view(-1, self.config.vocab_size), labels.view(-1))

    # Manually reconstruct the output object
    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    from transformers.modeling_outputs import CausalLMOutputWithPast
    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

print("--- Applying patches to library classes ---")
bnb_nn.Linear4bit.forward = compiled_bnb_forward
LlamaMLP.forward = compiled_llama_mlp_forward
LlamaRMSNorm.forward = compiled_rmsnorm_forward
# Apply the main model forward patch last
LlamaForCausalLM.forward = compiled_causal_lm_forward
print("--- Patches applied successfully ---")


# ===========================================================================
# 5. Model & Tokenizer Loading
# ===========================================================================
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
max_seq_length = 1024

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Use Unsloth's FastLanguageModel for optimized loading and setup
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# ===========================================================================
# 6. PEFT & QLoRA Configuration
# ===========================================================================
model = FastLanguageModel.get_peft_model(
    model,
    r = 32,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 64,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = False, # GC must be disabled for fullgraph compile
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

# ===========================================================================
# 7. Dataset and Trainer Setup
# ===========================================================================
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files={"train": url}, split="train[:2%]")

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    # *** CORRECTED ARGUMENT NAME ***
    processing_class = tokenizer,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = SFTConfig(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 15,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
    ),
)

# ===========================================================================
# 8. Training and Analysis
# ===========================================================================
print("\n--- Starting Training with Compiled & Patched Model ---")
start_time = time.time()
trainer_stats = trainer.train()
end_time = time.time()
training_time = end_time - start_time

print(f"\n--- Training Finished ---")
print(f"Total training time: {training_time:.2f} seconds")

# *** CORRECTED FINAL LOSS PRINTING ***
if trainer.state.log_history:
    final_log = trainer.state.log_history[-1]
    final_loss = final_log.get('loss', final_log.get('train_loss', 'N/A'))
    if isinstance(final_loss, float):
        print(f"Final training loss: {final_loss:.4f}")

# Print dynamo stats
torch._dynamo.utils.compile_times()

--- Applying patches to library classes ---
--- Patches applied successfully ---
==((====))==  Unsloth 2025.6.2: Fast Llama patching. Transformers: 4.52.4.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!

--- Starting Training with Compiled & Patched Model ---


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 4,206 | Num Epochs = 1 | Total steps = 15
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 22,544,384/1,000,000,000 (2.25% trained)


Step,Training Loss
1,1.3256
2,1.181
3,1.1612
4,0.6844
5,0.2314
6,0.1871
7,0.1038
8,0.1487
9,0.1303
10,0.1712



--- Training Finished ---
Total training time: 15.70 seconds
Final training loss: 0.3908


'TorchDynamo compilation metrics:\nFunction                              Runtimes (s)\n------------------------------------  ------------------------------------------------------------------------------------------------------------------------------\n_compile.compile_inner                0.2370, 0.0813, 2.4950, 0.0208, 0.6170, 0.5355, 0.3118, 0.1966, 0.3928, 0.4289, 0.0081, 0.1536\nOutputGraph.call_user_compiler        2.3667, 0.3787, 0.4043, 0.1752, 0.1085, 0.2440, 0.3266, 0.0939\n_recursive_pre_grad_passes            0.0046, 0.0011, 0.0028, 0.0007, 0.0018, 0.0008, 0.0008, 0.0009\ncreate_aot_dispatcher_function        2.3577, 0.3707, 0.3955, 0.1715, 0.1023, 0.2367, 0.3203, 0.0897\n_recursive_joint_graph_passes         0.1581, 0.0121, 0.0254, 0.0085, 0.0016, 0.0120, 0.0254, 0.0017\ncompile_fx.<locals>.fw_compiler_base  1.8923, 0.1321, 0.0915, 0.0276, 0.0251, 0.0608, 0.0454, 0.0241\ncompile_fx_inner                      1.8913, 0.0697, 0.1311, 0.0761, 0.0906, 0.0820, 0.0267, 0.0265, 0