# Llama 3.2 1B: SFT + DPO Training on Colab

Training pipeline for multiple-choice reasoning with Chain-of-Thought

**Runtime:** Make sure you're using a **GPU (T4)** runtime!

Runtime → Change runtime type → T4 GPU

## 1. Setup: Mount Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## 2. Navigate to Project

In [None]:
import os
os.chdir('/content/drive/MyDrive/llama32-mcq-cot')

# Verify
!pwd
!ls
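Before installing anything, it can help to confirm that the Drive folder actually contains the files the later cells call. This is a minimal standard-library sketch that assumes the layout implied by this notebook (`requirements.txt` plus the scripts under `src/`). `missing_project_files` is a hypothetical helper, not part of the project.

```python
from pathlib import Path

# Files invoked by later cells in this notebook (assumed layout)
REQUIRED_FILES = [
    "requirements.txt",
    "src/prepare_data.py",
    "src/build_dpo_data.py",
    "src/train_sft.py",
    "src/train_dpo.py",
    "src/evaluate.py",
]

def missing_project_files(root=".", required=REQUIRED_FILES):
    """Return the required relative paths that are not regular files under root."""
    root = Path(root)
    return [rel for rel in required if not (root / rel).is_file()]

missing = missing_project_files()
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All expected project files found")
```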

## 3. Install Dependencies

In [None]:
!pip install -q -r requirements.txt

In [None]:
# Verify installations
import transformers
import torch

print(f"Transformers: {transformers.__version__}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

## 4. Login to HuggingFace & Wandb

In [None]:
# HuggingFace login
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Wandb login
import wandb
wandb.login()

## 5. Test Data Loading

In [None]:
!python src/prepare_data.py

## 6. Build DPO Preference Pairs

In [None]:
!python src/build_dpo_data.py
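Before spending hours on training, it is worth checking that the preference file is well formed. This sketch assumes `data/dpo_pairs.jsonl` (the path checked in the Project Status cell) holds one JSON object per line with string fields `prompt`, `chosen`, and `rejected`, the layout TRL's `DPOTrainer` conventionally expects. The actual schema written by `build_dpo_data.py` is not shown here, so adjust `REQUIRED_KEYS` if it differs. `check_dpo_pairs` is a hypothetical helper.

```python
import json
from pathlib import Path

# Assumed schema: the key names conventionally used for DPO preference data
REQUIRED_KEYS = ("prompt", "chosen", "rejected")

def check_dpo_pairs(path, required_keys=REQUIRED_KEYS):
    """Count valid pairs and collect (line_number, problem) tuples."""
    n_valid, problems = 0, []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                problems.append((lineno, f"invalid JSON: {e.msg}"))
                continue
            if not isinstance(record, dict):
                problems.append((lineno, "not a JSON object"))
                continue
            missing = [k for k in required_keys if not isinstance(record.get(k), str)]
            if missing:
                problems.append((lineno, f"missing or non-string keys: {missing}"))
            elif record["chosen"] == record["rejected"]:
                # Identical pairs carry no preference signal
                problems.append((lineno, "chosen and rejected are identical"))
            else:
                n_valid += 1
    return n_valid, problems

if Path("data/dpo_pairs.jsonl").exists():
    n_valid, problems = check_dpo_pairs("data/dpo_pairs.jsonl")
    print(f"{n_valid} valid pairs, {len(problems)} problems")
    for lineno, msg in problems[:5]:
        print(f"  line {lineno}: {msg}")
```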

## 7. Train SFT Model

**⏱️ Expected time:** ~3-4 hours

**Note:** Don't close the browser tab while training!

In [None]:
!python src/train_sft.py

### Check SFT Training Status

In [None]:
from pathlib import Path

if Path("outputs/sft-llama32-1b-mcq-merged").exists():
    print("✓ SFT training completed!")
    !ls outputs/sft-llama32-1b-mcq-merged/
else:
    print("✗ SFT training not complete yet")
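A folder can exist while a save is still in progress or after it was interrupted, so checking only for the directory can give a false positive. A slightly stricter sketch, assuming the merged model was written with Hugging Face `save_pretrained`, which puts a `config.json` next to the weight files (`*.safetensors` by default, `pytorch_model*.bin` in older setups). `looks_like_saved_model` is a hypothetical helper and only a heuristic: it does not verify that every shard of a sharded checkpoint is present.

```python
from pathlib import Path

def looks_like_saved_model(path):
    """Heuristic: the folder has config.json plus at least one weight file."""
    p = Path(path)
    if not (p / "config.json").is_file():
        return False
    weights = list(p.glob("*.safetensors")) + list(p.glob("pytorch_model*.bin"))
    return len(weights) > 0

sft_ok = looks_like_saved_model("outputs/sft-llama32-1b-mcq-merged")
print("✓ SFT model looks complete" if sft_ok else "✗ SFT model incomplete or missing")
```

The same call works for `outputs/dpo-llama32-1b-mcq-merged` after section 8.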

## 8. Train DPO Model

**⏱️ Expected time:** ~2-3 hours

In [None]:
!python src/train_dpo.py
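For intuition about what this stage optimizes, here is a standard-library sketch of the standard sigmoid DPO loss for a single preference pair: `-log σ(β · [(log π(y_w|x) - log π_ref(y_w|x)) - (log π(y_l|x) - log π_ref(y_l|x))])`, where `y_w` is the chosen completion, `y_l` the rejected one, and `π_ref` the frozen reference model (typically the SFT model in a pipeline like this). `train_dpo.py` presumably delegates this to a library trainer; `dpo_loss` and the `beta=0.1` value are illustrative, not taken from the project's config.

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one pair; inputs are summed log-probs of each completion."""
    # Log-ratio of policy vs. reference for each completion
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    # Implicit reward margin: positive when the policy shifted toward the chosen answer
    margin = beta * (chosen_ratio - rejected_ratio)
    # -log(sigmoid(margin)) == softplus(-margin), computed in a numerically stable form
    return max(0.0, -margin) + math.log1p(math.exp(-abs(margin)))

# When the policy equals the reference, the margin is 0 and the loss is log(2) ≈ 0.693
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
```

Training pushes the margin up, which lowers the loss below log(2); a policy that drifts toward the rejected answers pushes it above log(2).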

### Check DPO Training Status

In [None]:
from pathlib import Path

if Path("outputs/dpo-llama32-1b-mcq-merged").exists():
    print("✓ DPO training completed!")
    !ls outputs/dpo-llama32-1b-mcq-merged/
else:
    print("✗ DPO training not complete yet")

## 9. Evaluate All Models

**⏱️ Expected time:** ~15-30 minutes

Compares the base, SFT, and DPO models.

In [None]:
!python src/evaluate.py
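The prompt asks the model to finish with `Answer: X`, so scoring comes down to pulling that letter out of each generation. `evaluate.py` is not shown here; the sketch below is one plausible approach, assuming options are labelled A to E and that the last `Answer: X` in the text is the final answer. `extract_answer` and `accuracy` are hypothetical names.

```python
import re

# "Answer:" followed by a single option letter A-E (optionally in parentheses);
# \b stops matches like "answer as" or "Answer: Apple"
ANSWER_RE = re.compile(r"answer\s*:\s*\(?([A-E])\b", re.IGNORECASE)

def extract_answer(text):
    """Return the letter from the last 'Answer: X' in text, or None."""
    matches = ANSWER_RE.findall(text)
    return matches[-1].upper() if matches else None

def accuracy(generations, gold_labels):
    """Fraction of generations whose extracted letter matches the gold letter."""
    if not gold_labels:
        return 0.0
    correct = sum(extract_answer(g) == gold.upper()
                  for g, gold in zip(generations, gold_labels))
    return correct / len(gold_labels)

print(extract_answer("Jellyfish live freely in the sea.\nAnswer: A"))
```

Generations with no parsable answer count as wrong, which is usually the right choice when comparing models that differ in how reliably they follow the output format.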

## 10. Quick Test Inference

Test your DPO model on a sample question

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load DPO model
model_path = "outputs/dpo-llama32-1b-mcq-merged"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Sample question
prompt = """Answer the following question with step-by-step reasoning.

Question: Where would you find a jellyfish that has not been captured?
Options:
A. ocean
B. store
C. tank
D. internet
E. aquarium

Think through this step by step, then provide your answer as "Answer: X"."""

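# Generate; do_sample=True is needed for temperature to have any effect
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=150,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,  # Llama tokenizers may not define a pad token
)
# Decode only the newly generated tokens, not the echoed prompt
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)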

print(response)

## 11. Download Models (Optional)

Download the trained models to your local machine

In [None]:
# Zip models
!zip -r trained_models.zip outputs/*-merged/

# Download
from google.colab import files
files.download('trained_models.zip')
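If the `zip` command is unavailable, the same archive can be built with Python's standard `zipfile` module. A sketch assuming the `outputs/*-merged/` layout used above; `zip_dirs` is a hypothetical helper. It defaults to `ZIP_STORED` (no compression) because model weight files compress poorly and deflating several gigabytes mostly costs time.

```python
import zipfile
from pathlib import Path

def zip_dirs(archive_path, dirs, compression=zipfile.ZIP_STORED):
    """Write every file under each dir into one archive, keeping relative paths."""
    count = 0
    with zipfile.ZipFile(archive_path, "w", compression=compression) as zf:
        for d in dirs:
            for file in sorted(Path(d).rglob("*")):
                if file.is_file():
                    zf.write(file, arcname=file.as_posix())
                    count += 1
    return count

merged = sorted(Path("outputs").glob("*-merged"))
if merged:
    n = zip_dirs("trained_models.zip", merged)
    print(f"Archived {n} files from {len(merged)} folders")
```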

## Utilities

### Clear GPU Memory

In [None]:
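import torch
import gc

# Training runs in separate `!python` processes, whose GPU memory is released when they exit.
# Inside this notebook, memory is held by objects such as the model loaded in section 10;
# empty_cache() cannot free tensors that are still referenced, so drop those names first.
for name in ("model", "inputs", "outputs"):
    globals().pop(name, None)

gc.collect()
torch.cuda.empty_cache()
print("GPU memory cleared!")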

### Check GPU Memory Usage

In [None]:
!nvidia-smi
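`nvidia-smi` prints a full table; to get the numbers inside Python you can request CSV output with its `--query-gpu` and `--format` options and parse it. A small sketch; `parse_gpu_csv` and `gpu_memory` are hypothetical helpers, and `gpu_memory` returns an empty list when `nvidia-smi` is not on the PATH. With `nounits`, memory values are reported in MiB.

```python
import shutil
import subprocess

def parse_gpu_csv(text):
    """Parse 'name, used, total' CSV lines (MiB, no units) into dicts."""
    gpus = []
    for line in text.strip().splitlines():
        # rsplit keeps any commas inside the GPU name intact
        name, used, total = (field.strip() for field in line.rsplit(",", 2))
        gpus.append({"name": name, "used_mib": int(used), "total_mib": int(total)})
    return gpus

def gpu_memory():
    """Query nvidia-smi for per-GPU memory; returns [] if the tool is missing."""
    if shutil.which("nvidia-smi") is None:
        return []
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.used,memory.total",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_csv(out)

for gpu in gpu_memory():
    pct = 100 * gpu["used_mib"] / gpu["total_mib"]
    print(f"{gpu['name']}: {gpu['used_mib']} / {gpu['total_mib']} MiB ({pct:.0f}%)")
```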

### Project Status

In [None]:
from pathlib import Path

print("Project Status Check")
print("="*50)

# Check data
if Path("data/dpo_pairs.jsonl").exists():
    print("✓ DPO data prepared")
else:
    print("✗ DPO data not ready")

# Check SFT
if Path("outputs/sft-llama32-1b-mcq-merged").exists():
    print("✓ SFT training completed")
else:
    print("✗ SFT training pending")

# Check DPO
if Path("outputs/dpo-llama32-1b-mcq-merged").exists():
    print("✓ DPO training completed")
else:
    print("✗ DPO training pending")

print("="*50)
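Because the project lives on Google Drive, everything written under `outputs/` (the merged models, plus any checkpoints the training scripts save there) counts against your Drive storage. A small sketch to report how much each output folder takes, assuming the `outputs/` layout used above; `dir_size_bytes` is a hypothetical helper.

```python
import os
from pathlib import Path

def dir_size_bytes(path):
    """Total size of regular files under path (0 if path does not exist)."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            fp = os.path.join(root, name)
            if not os.path.islink(fp):  # skip symlinks to avoid double counting
                total += os.path.getsize(fp)
    return total

outputs_dir = Path("outputs")
if outputs_dir.is_dir():
    for sub in sorted(p for p in outputs_dir.iterdir() if p.is_dir()):
        print(f"{sub.name:<40} {dir_size_bytes(sub) / 1024**3:6.2f} GiB")
```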

---

## Done! üéâ

You've completed:
- ‚úÖ SFT training with QLoRA
- ‚úÖ DPO training for preference optimization
- ‚úÖ Model evaluation and comparison

Check your wandb dashboard for training metrics!