<a href="https://colab.research.google.com/github/KaustubhUp025/KU_Unsloth_Challenge_Solutions/blob/main/ChallengeCSolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

In [None]:
# NOTE(review): this upgrades torch after xformers==0.0.29 was pinned (with --no-deps)
# in the previous cell; confirm the upgraded torch is still compatible with that
# xformers build and with the bitsandbytes version installed here.
!pip install --upgrade torch torchvision bitsandbytes



In [None]:
import torch
import torch.nn as nn
import transformers.models.llama.modeling_llama as llama_mod
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
import torch._dynamo
import torch._inductor
import logging
import torch.autograd as autograd
import bitsandbytes.functional as F  # BitsAndBytes functional API
from types import SimpleNamespace  # For quant_state attributes

# -----------------------------------------------------------------------------
# Patch the transpose method of BitsAndBytes Params4bit so that it runs in eager mode.
# Calls to it from compiled code then run eagerly (as a graph break) instead of being traced by TorchDynamo.
# -----------------------------------------------------------------------------
import bitsandbytes as bnb
# Params4bit is defined in bitsandbytes.nn (bnb.nn.modules.Params4bit is used
# further down). The package root does not necessarily re-export it, in which
# case a plain `hasattr(bnb, "Params4bit")` check silently skips the patch.
_Params4bit = getattr(bnb, "Params4bit", None)
if _Params4bit is None:
    _Params4bit = getattr(getattr(bnb, "nn", None), "Params4bit", None)

# The marker check keeps a re-executed cell from wrapping the wrapper again.
if _Params4bit is not None and not getattr(_Params4bit.t, "_eager_t_patch", False):
    _orig_t = _Params4bit.t

    @torch._dynamo.disable
    def safe_t(self):
        """Return the original ``Params4bit.t()`` result, executed eagerly (never traced)."""
        return _orig_t(self)

    safe_t._eager_t_patch = True
    _Params4bit.t = safe_t

# -----------------------------------------------------------------------------
# Eager dequantization+transpose wrapper for BitsAndBytes that is not traced.
# -----------------------------------------------------------------------------
@torch._dynamo.disable
def dequantize_4bit_eager(weight, quant_state, dtype):
    """Dequantize a packed 4-bit weight, cast it to ``dtype`` and return its transpose.

    Wrapped in ``torch._dynamo.disable`` so the bitsandbytes call always runs
    eagerly, outside any TorchDynamo trace.

    Args:
        weight: packed 4-bit weight tensor (as handed to ``bnb.matmul_4bit``).
        quant_state: quantization metadata consumed by ``F.dequantize_4bit``.
        dtype: floating-point dtype for the result.

    Returns:
        ``F.dequantize_4bit(weight, quant_state)`` cast to ``dtype`` and transposed.
    """
    dense = F.dequantize_4bit(weight, quant_state)
    return dense.to(dtype).t()

# -----------------------------------------------------------------------------
# Custom 4-bit MatMul Operator using the original dequantization behavior.
# -----------------------------------------------------------------------------
class CustomMatMul4BitFunction(autograd.Function):
    """``F.linear(x, W, bias)`` where ``W`` is dequantized from a bitsandbytes 4-bit weight.

    Gradients are produced only for inputs that need them (``ctx.needs_input_grad``).
    The packed 4-bit ``weight`` presumably never requires grad (integer storage),
    so the full (out x in) weight-gradient matmul is skipped in that case. Only
    the packed weight and its quant_state are kept for backward, where the weight
    is dequantized again. This avoids holding a full-precision copy of every
    layer's weight from forward until backward.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, quant_state):
        """Dequantize ``weight`` to ``x.dtype`` and apply ``torch.nn.functional.linear``.

        Args:
            x: input activations; the last dimension is the input-feature dimension.
            weight: packed 4-bit weight (whatever ``bnb.matmul_4bit`` receives).
            bias: optional bias tensor, or None.
            quant_state: quantization metadata for ``weight``.

        Returns:
            ``torch.nn.functional.linear(x, dequantized_weight, bias)``.
        """
        deq_weight = dequantize_4bit_eager(weight, quant_state, x.dtype)
        output = torch.nn.functional.linear(x, deq_weight, bias)

        # quant_state is stored as a plain attribute rather than via
        # save_for_backward: dequantization reads its Python attributes.
        ctx.quant_state = quant_state
        ctx.x_dtype = x.dtype
        ctx.x_shape = x.shape
        # x is only needed for the weight gradient; don't keep it alive otherwise.
        keep_x = ctx.needs_input_grad[1]
        ctx.save_for_backward(x if keep_x else None, weight, bias)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, quant_state)``; None where not required."""
        x, weight, bias = ctx.saved_tensors
        needs_x, needs_weight, needs_bias, _ = ctx.needs_input_grad
        grad_x = grad_weight = grad_bias = None

        grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])

        if needs_x:
            # Rebuild exactly the forward's weight (dequantized to x's dtype),
            # then match grad_output's dtype.
            deq_weight = dequantize_4bit_eager(weight, ctx.quant_state, ctx.x_dtype).to(grad_output.dtype)
            grad_x = grad_output_flat.matmul(deq_weight).reshape(ctx.x_shape)
        if needs_weight:
            x_flat = x.reshape(-1, x.shape[-1]).to(grad_output.dtype)
            grad_weight = grad_output_flat.transpose(0, 1).matmul(x_flat)
        if needs_bias and bias is not None:
            grad_bias = grad_output_flat.sum(dim=0)
        return grad_x, grad_weight, grad_bias, None

# -----------------------------------------------------------------------------
# Wrapper for 4-bit matmul.
# This function ensures quant_state is a Tensor carrying the necessary attributes.
# It is wrapped in torch._dynamo.disable so its dynamic operations (such as setattr on a tensor) run eagerly instead of being traced.
# -----------------------------------------------------------------------------
@torch._dynamo.disable
def matmul_4bit_wrapper(x, weight, bias, quant_state):
    """Replacement for ``bnb.matmul_4bit`` that routes through CustomMatMul4BitFunction.

    A non-tensor ``quant_state`` is converted to a tensor stand-in that carries
    the original's ``shape`` (as its tensor shape) and copies of its quantization
    attributes.

    NOTE(review): the parameter order (x, weight, bias, quant_state) presumably
    differs from bitsandbytes' own matmul_4bit, so callers should pass ``bias``
    and ``quant_state`` by keyword.

    Raises:
        ValueError: if ``quant_state`` is None (the weight cannot be dequantized).
    """
    if quant_state is None:
        raise ValueError("matmul_4bit_wrapper requires the quant_state of the 4-bit weight")
    if not isinstance(quant_state, torch.Tensor):
        dummy_shape = getattr(quant_state, "shape", weight.shape)
        # The stand-in's contents are never read, only its shape/dtype/device and
        # the copied attributes. An expanded 0-d tensor therefore has the
        # required shape without allocating a full (out x in) buffer on every call.
        dummy = torch.empty((), device=x.device, dtype=x.dtype).expand(dummy_shape)
        for attr in ("absmax", "nested", "state2", "blocksize", "quant_type", "offset"):
            if hasattr(quant_state, attr):
                setattr(dummy, attr, getattr(quant_state, attr))
        quant_state = dummy
    return CustomMatMul4BitFunction.apply(x, weight, bias, quant_state)

# Keep a handle on the stock implementation, then reroute bnb.matmul_4bit
# through the wrapper. Only callers that look the function up as
# `bnb.matmul_4bit` at call time will see the replacement.
_original_matmul_4bit, bnb.matmul_4bit = bnb.matmul_4bit, matmul_4bit_wrapper

######################################
# torch.compile options and Llama model patching (carried over unchanged from the original notebook)
######################################
# Inductor options shared by every torch.compile call in this notebook.
torch_compile_options = dict(
    [
        ("epilogue_fusion", True),     # fuse pointwise epilogues into template (e.g. matmul) kernels
        ("max_autotune", True),        # benchmark candidate kernel configs (slower compiles)
        ("shape_padding", True),       # pad shapes for better-aligned kernels
        ("trace.enabled", True),       # emit Inductor debug trace artifacts
        ("triton.cudagraphs", False),  # do not wrap compiled graphs in CUDA graphs
    ]
)

try:
    # NOTE(review): this asks Dynamo to keep Params4bit.t in the graph, while the
    # Params4bit patch near the top of this cell wraps the same method in
    # torch._dynamo.disable; the two requests pull in opposite directions.
    torch.compiler.allow_in_graph(bnb.nn.modules.Params4bit.t)
except Exception as exc:
    # Best effort: report and continue if this torch/bitsandbytes combination refuses.
    print("Warning: Could not mark Params4bit.t as allowed:", exc)

@torch.compile(fullgraph=False, dynamic=True, options=torch_compile_options)
def compiled_llama_mlp(self, x):
    """Compiled replacement for ``LlamaMLP.forward``.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.
    """
    gated = self.act_fn(self.gate_proj(x))
    up = self.up_proj(x)
    return self.down_proj(gated * up)


llama_mod.LlamaMLP.forward = compiled_llama_mlp

@torch.compile(fullgraph=False, dynamic=True, options=torch_compile_options)
def compiled_flex_attention(self, hidden_states, position_embeddings, attention_mask, past_key_value=None, cache_position=None, **kwargs):
    """Compiled replacement for ``LlamaAttention.forward`` built on PyTorch SDPA.

    Despite the name, attention is computed with
    ``torch.nn.functional.scaled_dot_product_attention``, not flex attention.

    Args:
        self: the ``LlamaAttention`` module being patched.
        hidden_states: input activations, unpacked below as (batch, seq, hidden).
        position_embeddings: ``(cos, sin)`` rotary tables for ``apply_rotary_pos_emb``.
        attention_mask: mask forwarded to SDPA, or None.
        past_key_value: optional KV cache object; ignored while training.
        cache_position: forwarded to the cache update.
        **kwargs: accepted for signature compatibility and ignored.

    Returns:
        ``(attn_output, None)``; attention weights are never materialized.
    """
    # Disable caching during training to avoid dynamic past_key_value issues.
    if self.training:
        past_key_value = None

    batch_size, seq_len, _ = hidden_states.shape
    head_dim = self.head_dim
    num_query_heads = self.q_proj.out_features // head_dim
    num_kv_heads = self.k_proj.out_features // head_dim

    # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
    query_states = self.q_proj(hidden_states).reshape(batch_size, seq_len, num_query_heads, head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).reshape(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).reshape(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)

    cos, sin = position_embeddings
    cos = cos[:seq_len, :] if cos.dim() == 2 else cos[..., :seq_len, :]
    sin = sin[:seq_len, :] if sin.dim() == 2 else sin[..., :seq_len, :]
    query_states, key_states = llama_mod.apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # Each rotary table goes under its own key ("sin" -> sin, "cos" -> cos).
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Repeat the KV heads so they line up with the query heads.
    key_states = llama_mod.repeat_kv(key_states, self.num_key_value_groups)
    value_states = llama_mod.repeat_kv(value_states, self.num_key_value_groups)

    # A 4-D mask may be sized for a longer cache; trim it to the keys present.
    if attention_mask is not None and attention_mask.dim() == 4:
        attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]

    # SDPA raises when attn_mask and is_causal=True are both given. Its implicit
    # causal mask is also top-left aligned, which is wrong for a single query
    # attending to a longer cache. So use is_causal only for unmasked
    # multi-token inputs, and otherwise rely on the explicit mask.
    is_causal = bool(self.is_causal and attention_mask is None and seq_len > 1)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        is_causal=is_causal,
    )

    attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_query_heads * head_dim)
    attn_output = self.o_proj(attn_output)
    return attn_output, None


llama_mod.LlamaAttention.forward = compiled_flex_attention

# Keep the stock implementation; the compiled wrapper delegates to it.
# NOTE(review): confirm this model actually contains nn.LayerNorm modules —
# Llama blocks typically use their own RMSNorm class, which this patch would not touch.
original_layernorm_forward = nn.LayerNorm.forward


def compiled_layernorm_forward(self, input):
    """Run the original ``nn.LayerNorm.forward`` under torch.compile (full graph)."""
    return original_layernorm_forward(self, input)


compiled_layernorm_forward = torch.compile(
    compiled_layernorm_forward,
    fullgraph=True,
    dynamic=True,
    options=torch_compile_options,
)
nn.LayerNorm.forward = compiled_layernorm_forward

import os
# NOTE(review): torch and bitsandbytes are already imported at this point; if CUDA
# was initialized during those imports, this no longer changes device visibility
# for the current process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

max_seq_length = 1024
dtype = torch.float16
# fp16 becomes the global default for floating-point tensors created from here on.
torch.set_default_dtype(dtype)

model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
# 4-bit NF4 weights with double quantization; matmuls computed in `dtype`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)

# Load the pre-quantized checkpoint with SDPA attention, letting HF place it on devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    attn_implementation="sdpa",
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad training batches on the right-hand side.
tokenizer.padding_side = "right"

# ----- Patch quant_state for quantized weights only if needed ---------
# For quantized params whose quant_state.shape has a leading dimension of 1,
# replace that shape with quant_state.logical_shape when one exists and differs.
# NOTE(review): if bitsandbytes' QuantState defines no `logical_shape`, or no
# weight has shape[0] == 1, this loop never changes anything — confirm it is still needed.
with torch.no_grad():
    for name, param in model.named_parameters():
        qstate = getattr(param, "quant_state", None)
        if not hasattr(qstate, "shape") or qstate.shape[0] != 1:
            continue
        target_shape = getattr(qstate, "logical_shape", qstate.shape)
        if target_shape == qstate.shape:
            continue
        print(f"Patching quant_state shape for {name}: {qstate.shape} -> {target_shape}")
        qstate.shape = target_shape

# Rank-32 LoRA (alpha 64, no dropout, no bias) on every attention and MLP projection.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,
    lora_alpha=64,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

model = get_peft_model(model, lora_config)

# Train only the LoRA adapter matrices; freeze every other parameter.
with torch.no_grad():
    for name, param in model.named_parameters():
        param.requires_grad_(".lora_A." in name or ".lora_B." in name)

# Make the input-embedding outputs require grad so backward can reach the adapters.
model.enable_input_require_grads()

# First 10% of the OIG unified_chip2 JSON-lines file, used as the training split.
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset(
    "json",
    data_files={"train": url},
    split="train[:10%]",
)

os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.capture_dynamic_output_shape_ops = True

torch._inductor.config.debug = True
torch._logging.set_logs(
    dynamo=logging.WARN,
    inductor=logging.WARN,
    graph_breaks=True,
    recompiles=True,
    recompiles_verbose=True,
    compiled_autograd_verbose=True,
)
torch._dynamo.config.verbose = True
torch._dynamo.config.suppress_errors = False

def log_gpu_stats():
    """Print CUDA allocator usage: currently allocated and reserved bytes, in decimal MB."""
    allocated_bytes = torch.cuda.memory_allocated()
    reserved_bytes = torch.cuda.memory_reserved()
    print(f"Allocated VRAM: {allocated_bytes / 1e6:.1f} MB, Reserved: {reserved_bytes / 1e6:.1f} MB")


log_gpu_stats()

# Mixed-precision mode follows the dtype the input-embedding weights were loaded in.
embedding_dtype = model.get_input_embeddings().weight.dtype

# Short smoke-test run: 10 optimizer steps of batch 1 with 2-step accumulation.
training_args = SFTConfig(
    output_dir="outputs",
    seed=3407,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    warmup_steps=1,
    max_steps=10,
    logging_steps=1,
    max_seq_length=max_seq_length,
    fp16=(embedding_dtype == torch.float16),
    bf16=(embedding_dtype == torch.bfloat16),
    report_to="none",
    dataset_num_proc=4,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Allocated VRAM: 1118.9 MB, Reserved: 1153.4 MB


Tokenizing train dataset (num_proc=4):   0%|          | 0/21029 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=4):   0%|          | 0/21029 [00:00<?, ? examples/s]

V0221 12:19:38.816000 9693 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks] Graph break in user code at /usr/local/lib/python3.11/dist-packages/bitsandbytes/nn/modules.py:484
V0221 12:19:38.816000 9693 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks] Reason: Unsupported: call_method UserDefinedObjectVariable(Params4bit) t [] {}
V0221 12:19:38.816000 9693 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks] User code traceback:
V0221 12:19:38.816000 9693 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks]   File "<ipython-input-3-9e33b440181a>", line 108, in compiled_flex_attention
V0221 12:19:38.816000 9693 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks]     query_states = self.q_proj(hidden_states).reshape(batch_size, seq_len, num_query_heads, self.head_dim).transpose(1, 2)
V0221 12:19:38.816000 9693 torch/_dynamo/symbolic_convert.py:435] [0/0] [__graph_breaks]   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/lora/b

Step,Training Loss
1,1.5189
2,2.3934
3,2.5027
4,3.5342
5,2.1386
6,2.9794
7,2.2491
8,1.6307
9,2.2227
10,2.6863


TrainOutput(global_step=10, training_loss=2.385595905780792, metrics={'train_runtime': 27.3406, 'train_samples_per_second': 0.732, 'train_steps_per_second': 0.366, 'total_flos': 10592155496448.0, 'train_loss': 2.385595905780792})