# Week 2 — Notebook 2: QLoRA Fine-Tuning (SFT Warm-up → DPO)

**Strategy:** Two-phase training to counteract the alignment tax observed in production:

1. **Phase 1 — SFT warm-up** on the highest-scoring chosen responses (`score_chosen >= 0.80`, roughly 700 samples)  
   → anchors the model in the correct factual space before preference learning
2. **Phase 2 — DPO** on chosen/rejected pairs from Notebook 01  
   → with raised KL penalty (β) to prevent over-optimization

---
> **GPU requirement:** 1× A100/H100 40GB+ recommended. Works on 2× A40 with `load_in_4bit=True`.  
> **Default model:** `Qwen/Qwen2.5-7B-Instruct` (swap in Llama-3 or Mistral as needed).

## 0. Install & Imports

In [None]:
# !pip install "transformers>=4.40" peft "trl>=0.8" bitsandbytes accelerate datasets wandb pandas python-dotenv

In [None]:
import os
import torch
from pathlib import Path
from dotenv import load_dotenv

load_dotenv("../../week1/.env")

import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from trl import SFTTrainer, DPOTrainer, DPOConfig
from datasets import Dataset, load_from_disk

# Relative locations of the input data and of the trained checkpoints.
DATA_DIR = Path("../data")
MODELS_DIR = Path("../models")
MODELS_DIR.mkdir(exist_ok=True)

# Pick the accelerator once; the inference cell moves its inputs to `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

if device == "cuda":
    # Report the first visible GPU and its total memory in decimal gigabytes.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")

## 1. Config

In [None]:
# Model: base checkpoint to fine-tune and the prefix used for run/output names.
# Alternative base model: meta-llama/Meta-Llama-3-8B-Instruct
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
RUN_NAME = "qwen2.5-7b-customer-support-dpo"

# QLoRA: adapter rank, scaling, dropout, and the projection layers that
# receive LoRA adapters.
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# SFT warm-up (phase 1).
SFT_EPOCHS = 1
SFT_BATCH_SIZE = 4
SFT_LR = 2e-4
SFT_MAX_SEQ_LEN = 1024

# DPO (phase 2). Beta is raised from the 0.1 default to suppress
# over-optimization / alignment tax.
DPO_BETA = 0.2
DPO_EPOCHS = 1
DPO_BATCH_SIZE = 2
DPO_LR = 5e-5
DPO_MAX_LENGTH = 1024

print(f"Base model : {BASE_MODEL}")
print(f"DPO beta   : {DPO_BETA}  (raised to mitigate alignment tax)")

## 2. Load Model in 4-bit (QLoRA)

In [None]:
# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Fall back to EOS only when the tokenizer ships without a pad token.
# Overwriting an existing, dedicated pad token with EOS makes pad and EOS
# share one id, so collators that mask padding out of the loss would also
# mask genuine end-of-sequence tokens.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load the base model quantized, letting accelerate place layers on devices.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
# Prepare the quantized model for k-bit training before attaching adapters.
model = prepare_model_for_kbit_training(model)
print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B params")

In [None]:
# Adapter settings come straight from the constants in the config cell.
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Wrap the quantized model with LoRA adapters and report the trainable
# parameter count.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

## 3. Phase 1 — SFT Warm-up on High-Score Chosen Responses

Filter to rows where `score_chosen >= 0.80` (top quality bar from NB01) for the SFT pass.  
This anchors the model in correct factual behavior before DPO reshapes style.

In [None]:
import pandas as pd

# Load pairs produced by NB01 (floor strategy: score_chosen >= 0.70, delta >= 0.10)
df_dpo = pd.read_json(DATA_DIR / "dpo_floor.jsonl", lines=True)

# SFT warm-up: use only the highest-scoring chosen responses.
# score_chosen was saved alongside prompt/chosen/rejected in NB01.
SFT_SCORE_FLOOR = 0.80
SFT_FALLBACK_SAMPLES = 700
if "score_chosen" in df_dpo.columns:
    df_sft = df_dpo[df_dpo["score_chosen"] >= SFT_SCORE_FLOOR].copy()
else:
    # No scores in the file, so rows cannot be ranked: take a reproducible
    # random sample of at most SFT_FALLBACK_SAMPLES rows instead.
    df_sft = df_dpo.sample(min(SFT_FALLBACK_SAMPLES, len(df_dpo)), random_state=42)

# Fail early with a clear message instead of erroring later while building
# or indexing an empty dataset.
if df_sft.empty:
    raise ValueError(
        f"No SFT warm-up samples selected from {len(df_dpo):,} pairs "
        f"(score floor {SFT_SCORE_FLOOR}); lower SFT_SCORE_FLOOR or check the NB01 output."
    )

print(f"SFT warm-up candidates: {len(df_sft):,}")


def to_sft_text(row):
    """Render one prompt/chosen pair as a single causal-LM training string.

    Args:
        row: Mapping-like record exposing ``"prompt"`` and ``"chosen"`` keys
            (for example a DataFrame row).

    Returns:
        The prompt under a ``### User:`` header, a blank line, then the
        chosen response under a ``### Assistant:`` header.
    """
    user_turn = f"### User:\n{row['prompt']}"
    assistant_turn = f"### Assistant:\n{row['chosen']}"
    return f"{user_turn}\n\n{assistant_turn}"


# Format every selected row into one training string, then keep only that
# column for the SFT dataset.
sft_texts = df_sft.apply(to_sft_text, axis=1)
df_sft["text"] = sft_texts
ds_sft = Dataset.from_pandas(df_sft[["text"]])

print(f"SFT warm-up samples: {len(ds_sft):,}")
# Preview the first 300 characters of the first sample as a format check.
first_sample = ds_sft[0]["text"]
print(first_sample[:300])

In [None]:
# Names and destinations for the SFT phase.
sft_run = f"{RUN_NAME}-sft"
sft_final_dir = str(MODELS_DIR / f"{RUN_NAME}-sft-final")
# Log to Weights & Biases only when an API key is present in the environment.
sft_report_to = "wandb" if os.getenv("WANDB_API_KEY") else "none"

sft_args = TrainingArguments(
    output_dir=str(MODELS_DIR / sft_run),
    run_name=sft_run,
    report_to=sft_report_to,
    num_train_epochs=SFT_EPOCHS,
    learning_rate=SFT_LR,
    per_device_train_batch_size=SFT_BATCH_SIZE,
    gradient_accumulation_steps=4,
    bf16=True,
    fp16=False,
    optim="paged_adamw_8bit",
    logging_steps=10,
    save_strategy="epoch",
)

# Train on the pre-formatted "text" column of the warm-up dataset.
sft_trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=sft_args,
    train_dataset=ds_sft,
    dataset_text_field="text",
    max_seq_length=SFT_MAX_SEQ_LEN,
)

print("Starting SFT warm-up...")
sft_trainer.train()
# Persist the SFT result separately from the per-epoch checkpoints.
sft_trainer.save_model(sft_final_dir)
print("SFT complete.")

## 4. Phase 2 — DPO with Raised β

`β = 0.2` (vs. the default 0.1) → stronger KL penalty → the policy stays closer to the reference model, reducing correctness decay.

In [None]:
# Build the DPO pairs dataset.
#
# The SFT text and the inference check both wrap the prompt in the
# "### User / ### Assistant" template. Apply the same wrapper here so DPO
# optimizes preferences under the prompt format the model was warmed up on
# and is later queried with, rather than on raw, untemplated prompts.
def _format_dpo_prompt(prompt):
    """Wrap a raw prompt in the chat-style template, ending at the assistant turn."""
    return f"### User:\n{prompt}\n\n### Assistant:\n"


# Work on a copy so df_dpo itself keeps the raw prompts.
df_pairs = df_dpo[["prompt", "chosen", "rejected"]].copy()
df_pairs["prompt"] = df_pairs["prompt"].map(_format_dpo_prompt)

ds_dpo = Dataset.from_pandas(df_pairs)
# Hold out 5% of the pairs for evaluation; the seed keeps the split reproducible.
split = ds_dpo.train_test_split(test_size=0.05, seed=42)

print(f"DPO train: {len(split['train']):,}  |  eval: {len(split['test']):,}")

In [None]:
# Reference policy for DPO.
#
# `model` is a PEFT model. With ref_model=None and no reference adapter, TRL
# computes reference log-probs with the adapters disabled, i.e. against the
# original base model rather than the SFT checkpoint. To anchor the KL
# penalty to the SFT policy, load the saved SFT adapter a second time as a
# frozen adapter and tell the trainer to use it as the reference.
TRAIN_ADAPTER_NAME = "default"    # adapter created by get_peft_model and trained in SFT
REF_ADAPTER_NAME = "reference"    # frozen copy of the SFT adapter
sft_adapter_dir = str(MODELS_DIR / f"{RUN_NAME}-sft-final")
if REF_ADAPTER_NAME not in model.peft_config:  # keeps the cell safe to re-run
    model.load_adapter(sft_adapter_dir, adapter_name=REF_ADAPTER_NAME)

dpo_config = DPOConfig(
    beta=DPO_BETA,
    output_dir=str(MODELS_DIR / f"{RUN_NAME}-dpo"),
    num_train_epochs=DPO_EPOCHS,
    per_device_train_batch_size=DPO_BATCH_SIZE,
    per_device_eval_batch_size=DPO_BATCH_SIZE,
    gradient_accumulation_steps=8,
    learning_rate=DPO_LR,
    bf16=True,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=100,
    max_length=DPO_MAX_LENGTH,
    max_prompt_length=512,
    optim="paged_adamw_8bit",
    # Log to Weights & Biases only when an API key is present.
    report_to="wandb" if os.getenv("WANDB_API_KEY") else "none",
    run_name=f"{RUN_NAME}-dpo",
    # Policy = trainable adapter; reference = frozen SFT adapter.
    model_adapter_name=TRAIN_ADAPTER_NAME,
    ref_adapter_name=REF_ADAPTER_NAME,
)

dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,   # no second model copy: the reference is the frozen SFT adapter named above
    args=dpo_config,
    train_dataset=split["train"],
    eval_dataset=split["test"],
    tokenizer=tokenizer,
)

print(f"Starting DPO (β={DPO_BETA})...")
dpo_trainer.train()
# Persist the DPO result separately from the step checkpoints.
dpo_trainer.save_model(str(MODELS_DIR / f"{RUN_NAME}-dpo-final"))
print("DPO complete.")

## 5. Quick Inference Check

In [None]:
# Switch to evaluation mode before sampling.
model.eval()

test_prompt = "My order hasn't arrived after 2 weeks. What should I do?"
# Use the same "### User / ### Assistant" template the SFT text was built with.
prompt_text = f"### User:\n{test_prompt}\n\n### Assistant:\n"
inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

# Sampling settings for the smoke test.
generation_kwargs = dict(
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,
)

with torch.no_grad():
    output = model.generate(**inputs, **generation_kwargs)

# Drop the prompt tokens so only the newly generated continuation is decoded.
prompt_len = inputs["input_ids"].shape[1]
response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
print("Prompt:", test_prompt)
print("\nResponse:\n", response)

---
## Summary

| Phase | Method | Key setting | Purpose |
|-------|--------|------------|--------|
| 1 | SFT warm-up | 700 high-correctness samples | Anchor factual quality |
| 2 | DPO | β=0.2 (raised KL) | Style alignment without correctness decay |

**Next:** `03_model_evaluation.ipynb` — measure correctness, groundedness, style before and after.