<a href="https://colab.research.google.com/github/kissflow/prompt2finetune/blob/main/tinyllama_ft_for_t4gpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install uv

In [None]:
!uv pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
!uv pip install trl peft accelerate bitsandbytes

In [None]:
from unsloth import FastLanguageModel
import torch
from datasets import load_dataset, Dataset
from trl import SFTTrainer, SFTConfig
from transformers import TrainingArguments, TextStreamer
import json
import pandas as pd
from typing import Dict, List
import random

In [None]:
class Config:
    """Single source of hyper-parameters and paths for the T4 fine-tuning run.

    Every setting is a plain class attribute, so values can be read either from
    the class itself or from the ``config`` instance created right after it.
    """

    # --- Input data and output location -----------------------------------
    RAW_JSON = "/content/ipl_2023.json"
    OUTPUT_DIR = "tinyllama-ipl-unsloth-t45"

    # --- Base model ---------------------------------------------------------
    BASE_MODEL = "unsloth/tinyllama"
    # Maximum tokens per training example; kept at 512 (rather than 1024) to
    # limit activation memory on a T4.
    MAX_SEQ_LENGTH = 512

    # --- LoRA adapter -------------------------------------------------------
    LORA_R = 32
    LORA_ALPHA = LORA_R * 2  # alpha = 2 * r  (i.e. 64)
    LORA_DROPOUT = 0.05
    # Attention projections first, then the MLP projections.
    TARGET_MODULES = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]

    # --- Batch shape: effective batch = BATCH_SIZE * ACCUMULATION = 32 ------
    BATCH_SIZE = 1
    GRADIENT_ACCUMULATION_STEPS = 32
    NUM_EPOCHS = 5

    # --- Optimizer and schedule ---------------------------------------------
    LEARNING_RATE = 2e-4
    WARMUP_RATIO = 0.03  # fraction of total optimizer steps used for warm-up
    WEIGHT_DECAY = 0.01
    MAX_GRAD_NORM = 0.3  # gradient-clipping threshold

    # --- Evaluation / checkpoint cadence (in optimizer steps) ---------------
    TRAIN_TEST_SPLIT = 0.1  # fraction of rows held out for validation
    EVAL_STEPS = 25
    SAVE_STEPS = 50
    LOGGING_STEPS = 5

    # --- Memory / speed switches --------------------------------------------
    USE_GRADIENT_CHECKPOINTING = True
    # NOTE(review): not referenced by the visible training code — confirm
    # whether this flag is still needed.
    USE_FLASH_ATTENTION = True

    # --- Reproducibility ----------------------------------------------------
    RANDOM_SEED = 42


config = Config()



#============================================================================
# 4. DATA LOADING & PREPROCESSING
#============================================================================

banner = "=" * 80
print(banner)
print("Loading IPL 2023 Dataset")
print(banner)

# Presumably a JSON list of {"instruction": ..., "output": ...} records (the
# shape to_chatml reads) — TODO confirm against the data source.
with open(config.RAW_JSON, encoding="utf-8") as raw_file:
    rows = json.load(raw_file)

print(f"✅ Loaded {len(rows)} examples")

In [None]:
def to_chatml(r):
    """Render one instruction/output record as a single training string.

    Despite the function name, the result is not ChatML (which uses
    ``<|im_start|>`` / ``<|im_end|>``). It uses the Zephyr-style role tags of
    the TinyLlama-Chat template, with ``</s>`` closing each turn::

        <|user|>
        {instruction}</s>
        <|assistant|>
        {output}</s>

    Args:
        r: Mapping with ``"instruction"`` and ``"output"`` entries.

    Returns:
        The formatted two-turn conversation, ending with a newline.

    Raises:
        KeyError: if ``r`` lacks ``"instruction"`` or ``"output"``.
    """
    user_turn = f"<|user|>\n{r['instruction']}</s>\n"
    assistant_turn = f"<|assistant|>\n{r['output']}</s>\n"
    return user_turn + assistant_turn

# Render every raw record with the prompt template, then hold out a
# reproducible validation split (seeded from config).
formatted_texts = [to_chatml(record) for record in rows]
full_dataset = Dataset.from_dict({"text": formatted_texts})

dataset_split = full_dataset.train_test_split(
    test_size=config.TRAIN_TEST_SPLIT, seed=config.RANDOM_SEED,
)
train_dataset, eval_dataset = dataset_split["train"], dataset_split["test"]

divider = "-" * 80
print(f"✅ Training samples: {len(train_dataset)}")
print(f"✅ Validation samples: {len(eval_dataset)}")
print("\nSample formatted text:")
print(divider)
# Show only the first 200 characters of one example as a sanity check.
print(train_dataset[0]["text"][:200] + "...")
print(divider + "\n")

In [None]:
# Fix 1: make TorchDynamo fall back to eager execution instead of raising
# when graph compilation fails.
import torch._dynamo
torch._dynamo.config.suppress_errors = True

# Fix 2: allow TF32 matmuls/convolutions. This is a speed setting for
# Ampere-or-newer GPUs that slightly lowers matmul precision. It does not
# improve numerical stability, and it has no effect on a T4 (Turing), which
# has no TF32 support.
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Fix 3: seed every RNG from the single configured seed, so changing
# config.RANDOM_SEED changes all of them consistently.
import random
import numpy as np
torch.manual_seed(config.RANDOM_SEED)
random.seed(config.RANDOM_SEED)
np.random.seed(config.RANDOM_SEED)


model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=config.BASE_MODEL,
    max_seq_length=config.MAX_SEQ_LENGTH,
    # Auto-detect: float16 on T4/V100, bfloat16 on Ampere+.
    # NOTE(review): the training config below forces fp16, so on a bf16-capable
    # GPU the auto-picked dtype would not match — confirm before reusing
    # this notebook off a T4.
    dtype=None,
    load_in_4bit=True,  # 4-bit quantization for efficiency
)

print(f"✅ Base model loaded: {config.BASE_MODEL}")
# Report the dtype actually selected instead of assuming BF16 (a T4 has none).
print(f"✅ Precision: {model.dtype}")
print("✅ Quantization: 4-bit NF4")

# Configure tokenizer: reuse EOS as the padding token and pad on the right.
# NOTE(review): with pad == eos some collators mask EOS out of the labels.
# With per-device batch size 1 (config.BATCH_SIZE) no training example is padded.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"✅ Tokenizer configured (vocab size: {len(tokenizer)})")

In [None]:
# Wrap the 4-bit base model with LoRA adapters on the projections listed in
# config.TARGET_MODULES; only the adapter weights are trainable afterwards.
model = FastLanguageModel.get_peft_model(
    model,
    r=config.LORA_R,
    target_modules=config.TARGET_MODULES,
    lora_alpha=config.LORA_ALPHA,
    lora_dropout=config.LORA_DROPOUT,
    bias="none",
    # Honour config.USE_GRADIENT_CHECKPOINTING instead of always enabling it.
    # "unsloth" selects Unsloth's memory-optimized checkpointing variant.
    use_gradient_checkpointing="unsloth" if config.USE_GRADIENT_CHECKPOINTING else False,
    random_state=config.RANDOM_SEED,
    use_rslora=False,  # standard LoRA scaling (alpha / r)
    loftq_config=None,
)


print("\n📊 Trainable Parameters:")
model.print_trainable_parameters()

In [None]:
# Trainer configuration. Every value that has a counterpart in `config` is
# read from it, so the Config class stays the single source of truth.
training_args = SFTConfig(
    output_dir=config.OUTPUT_DIR,

    # Dataset: one un-packed example per row from the "text" column,
    # truncated to config.MAX_SEQ_LENGTH tokens.
    dataset_text_field="text",
    max_length=config.MAX_SEQ_LENGTH,
    packing=False,

    # Training: effective batch per device per optimizer step is
    # BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS.
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.BATCH_SIZE,
    per_device_eval_batch_size=config.BATCH_SIZE,
    gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,

    # Optimizer: paged 8-bit AdamW keeps optimizer state small on low-VRAM GPUs.
    learning_rate=config.LEARNING_RATE,
    optim="paged_adamw_8bit",
    weight_decay=config.WEIGHT_DECAY,
    max_grad_norm=config.MAX_GRAD_NORM,

    # Scheduler: warm-up fraction comes from config rather than a duplicated literal.
    warmup_ratio=config.WARMUP_RATIO,
    lr_scheduler_type="cosine",

    # Precision: the T4 has no bf16 support, so train and evaluate in fp16.
    # NOTE(review): from_pretrained(dtype=None) picks bf16 on Ampere+ GPUs,
    # which would not match fp16=True — confirm before running off a T4.
    fp16=True,
    bf16=False,
    fp16_full_eval=True,

    # Gradient checkpointing, following config. Non-reentrant checkpointing is
    # the variant PyTorch recommends.
    # NOTE(review): get_peft_model already set up Unsloth's checkpointing;
    # confirm this Trainer-level flag does not replace it.
    gradient_checkpointing=config.USE_GRADIENT_CHECKPOINTING,
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # Logging
    logging_steps=config.LOGGING_STEPS,
    logging_dir=f"{config.OUTPUT_DIR}/logs",
    logging_first_step=True,  # log step 1 to catch divergence early

    # Saving: keep at most 3 checkpoints on disk.
    save_strategy="steps",
    save_steps=config.SAVE_STEPS,
    save_total_limit=3,

    # Evaluation. load_best_model_at_end requires SAVE_STEPS to be a multiple
    # of EVAL_STEPS (50 and 25 in Config); the lowest eval_loss wins.
    eval_strategy="steps",
    eval_steps=config.EVAL_STEPS,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # Misc
    seed=config.RANDOM_SEED,
    report_to="none",
    group_by_length=True,

    # Dataloader: load batches in the main process; pin host memory for
    # faster host-to-GPU copies.
    dataloader_num_workers=0,
    dataloader_pin_memory=True,
)


print("Initializing SFT Trainer...")

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # Validation split
    args=training_args,
)



In [None]:
# Run fine-tuning; the return value carries the step count, final training
# loss and runtime metrics. Because training_args sets
# load_best_model_at_end=True, `model` afterwards holds the checkpoint with
# the lowest eval_loss rather than simply the last one.
trainer_stats = trainer.train()


In [None]:
from pathlib import Path

# Write the (PEFT-wrapped) model and its tokenizer into the same directory:
# <OUTPUT_DIR>/lora_adapters.
output_path = Path(config.OUTPUT_DIR, "lora_adapters")
output_path.mkdir(parents=True, exist_ok=True)

adapter_dir = str(output_path)
for artifact in (model, tokenizer):
    artifact.save_pretrained(adapter_dir)

In [None]:
# Switch the Unsloth model into inference mode before calling generate() below.
FastLanguageModel.for_inference(model)

def build_prompt(user_msg: str) -> str:
    """Build the generation prompt for a single user question.

    The result is exactly the training template produced by ``to_chatml``,
    cut off right after the assistant tag, so the model continues with the
    answer. The tags are Zephyr-style ``<|user|>`` / ``<|assistant|>``
    markers with ``</s>`` closing a turn, not ChatML's
    ``<|im_start|>`` / ``<|im_end|>``.

    Args:
        user_msg: The question text, inserted verbatim.

    Returns:
        The prompt string ending in ``"<|assistant|>\\n"``.
    """
    return f"<|user|>\n{user_msg}</s>\n<|assistant|>\n"

def extract_answer(full_text: str, eos_token=None) -> str:
    """Extract the assistant's reply from a decoded prompt+completion string.

    Takes the text after the FIRST ``<|assistant|>`` tag (the one the prompt
    ends with). It then cuts at the first end-of-sequence token and at any
    ``<|user|>`` tag, so a follow-up turn the model invents on its own is not
    returned as part of the answer.

    Args:
        full_text: Decoded model output, prompt included.
        eos_token: End-of-sequence string to cut at. Defaults to the
            notebook-global ``tokenizer.eos_token``, so one-argument calls
            behave as before.

    Returns:
        The stripped answer, or ``full_text`` unchanged if it contains no
        ``<|assistant|>`` tag.
    """
    marker = "<|assistant|>"
    if marker not in full_text:
        return full_text
    if eos_token is None:
        eos_token = tokenizer.eos_token
    # First tag, not the last: if the model hallucinates another
    # "<|assistant|>" turn, split(...)[-1] would return that invented turn.
    ans = full_text.split(marker, 1)[1]
    if eos_token:  # guard: str.split("") raises ValueError
        ans = ans.split(eos_token, 1)[0]
    # Drop a run-on user turn in case the model never emitted EOS.
    ans = ans.split("<|user|>", 1)[0]
    return ans.strip()

def generate_response(question: str, max_new_tokens: int = 120):
    """Answer ``question`` with the fine-tuned model.

    Builds the prompt with ``build_prompt`` and generates up to
    ``max_new_tokens`` new tokens. It then returns the reply isolated by
    ``extract_answer``.

    Args:
        question: User question to ask.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The extracted answer text.
    """
    prompt = build_prompt(question)
    # Follow the model's actual device rather than hard-coding "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Near-greedy sampling for factual answers. Output still varies
            # run-to-run unless the torch RNG is re-seeded.
            temperature=0.1,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep special tokens so extract_answer can locate the EOS boundary.
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return extract_answer(full_text)

# Smoke-test prompts about IPL 2023 fixtures: one match result, two venue
# questions (phrased differently) and one Man-of-the-Match question.
test_questions: List[str] = [
    "In IPL 2023, who won when Gujarat Titans played against Chennai Super Kings?",
    "Where was the IPL 2023 match between Rajasthan Royals and Sunrisers Hyderabad played?",
    "Which stadium hosted the Punjab Kings vs Rajasthan Royals match in IPL 2023?",
    "Who was the Man of the Match when Mumbai Indians played Chennai Super Kings in IPL 2023?",
]

In [None]:
print("\n🧪 Running Test Predictions:\n")

# Ask each smoke-test question in order and print a Q/A transcript.
total = len(test_questions)
thin_rule, thick_rule = "-" * 80, "=" * 80
for idx, question in enumerate(test_questions, start=1):
    print(f"[Test {idx}/{total}]")
    print(f"Q: {question}")
    print(thin_rule)
    answer = generate_response(question)
    print(f"A: {answer}")
    print(thick_rule + "\n")