# 04 — Summarizer fine-tune with LoRA (DistilBART)

We’ll adapt a compact English summarizer (`sshleifer/distilbart-cnn-12-6`) using **LoRA** so training is fast and light.
Inputs are `sum_train.csv`, `sum_val.csv`, `sum_test.csv` containing columns: `doc`, `summary`.
We’ll evaluate with **ROUGE** and save a LoRA adapter for your app.

---

## A) Setup & load data

In [1]:
# Cell A1 — Imports & paths
from pathlib import Path
import pandas as pd
import numpy as np
import torch, transformers, evaluate

from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, PeftModel

print("transformers:", transformers.__version__)
print("torch:", torch.__version__)

# Paths (adjust if your notebook lives elsewhere)
DATA_DIR = Path("../data")
OUT_DIR  = Path("..") / "models" / "summarizer_lora_bart"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# CSVs expected: columns -> doc, summary
train_csv = DATA_DIR / "sum_train.csv"
val_csv   = DATA_DIR / "sum_val.csv"
test_csv  = DATA_DIR / "sum_test.csv"

df_tr = pd.read_csv(train_csv)
df_va = pd.read_csv(val_csv)
df_te = pd.read_csv(test_csv)

print("rows:", len(df_tr), len(df_va), len(df_te))
display(df_tr.head(2))

transformers: 4.55.2
torch: 2.8.0+cpu
rows: 1 1 1


Unnamed: 0,video_id,doc,summary
0,iV46TJKL8cU,I'll give Disney some credit. They are brave e...,I'll give Disney some credit. They are brave e...


---

## B) Tokenizer & datasets

In [2]:
# Tokenizer (DistilBART, English)
BASE = "sshleifer/distilbart-cnn-12-6"   # compact, English summarizer
tok  = AutoTokenizer.from_pretrained(BASE)

max_src = 512
max_tgt = 128

def preprocess(batch):
    # encode inputs
    model_inputs = tok(batch["doc"], max_length=max_src, truncation=True)
    # encode targets
    with tok.as_target_tokenizer():
        labels = tok(batch["summary"], max_length=max_tgt, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [3]:
# Build HF datasets
ds = DatasetDict({
    "train": Dataset.from_pandas(df_tr[["doc","summary"]]),
    "val":   Dataset.from_pandas(df_va[["doc","summary"]]),
    "test":  Dataset.from_pandas(df_te[["doc","summary"]]),
}).map(preprocess, batched=True, remove_columns=["doc","summary"])

# Keep raw references for manual evaluation later
val_docs = df_va["doc"].tolist()
val_refs = df_va["summary"].tolist()
test_docs = df_te["doc"].tolist()
test_refs = df_te["summary"].tolist()

Map:   0%|          | 0/1 [00:00<?, ? examples/s]



Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

---

## C) Base model + LoRA adapter

In [4]:
# Cell C1 — Load base model
base = AutoModelForSeq2SeqLM.from_pretrained(BASE)

In [5]:
# Auto-detect attention projection module names & attach LoRA
def find_attn_proj_names(model):
    names = set()
    for n, m in model.named_modules():
        # Typical BART proj leaves: q_proj, k_proj, v_proj, out_proj
        leaf = n.split(".")[-1]
        if leaf in {"q_proj","k_proj","v_proj","out_proj"}:
            names.add(leaf)
    return sorted(names)

present = find_attn_proj_names(base)
print("Detected projection modules:", present)
assert present, "No attention projection modules found. Model layout unexpected."

lora = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=present,
    task_type="SEQ_2_SEQ_LM",
)
model = get_peft_model(base, lora)
model.config.use_cache = False

# Show trainable params to confirm LoRA is active
model.print_trainable_parameters()
n_all = sum(p.numel() for p in model.parameters())
n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {n_train:,} / {n_all:,} ({100*n_train/max(1,n_all):.3f}%)")
assert n_train > 0, "LoRA did not attach: 0 trainable params."

Detected projection modules: ['k_proj', 'out_proj', 'q_proj', 'v_proj']
trainable params: 3,145,728 || all params: 308,656,128 || trainable%: 1.0192
Trainable params: 3,145,728 / 308,656,128 (1.019%)


---

## D) Training setup (compat-friendly)

In [6]:
# TrainingArguments (CPU/GPU friendly)
args = TrainingArguments(
    output_dir=str(OUT_DIR / "ckpt"),
    learning_rate=1e-4,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    dataloader_pin_memory=False,
    seed=42,
)

In [7]:
# Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["val"],
    data_collator=DataCollatorForSeq2Seq(tok, model=model),
    tokenizer=tok
)

# Sanity check: loss must require grad
batch = next(iter(trainer.get_train_dataloader()))
model.train()
out = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
)
print("Loss requires_grad:", out.loss.requires_grad)  # must be True
assert out.loss.requires_grad, "Loss has no grad; LoRA might not be attached or model is frozen."

  trainer = Trainer(


Loss requires_grad: True


---

## E) Train

In [8]:
# Train
trainer.train()

Step,Training Loss


TrainOutput(global_step=5, training_loss=0.407194709777832, metrics={'train_runtime': 14.0808, 'train_samples_per_second': 0.355, 'train_steps_per_second': 0.355, 'total_flos': 3918098595840.0, 'train_loss': 0.407194709777832, 'epoch': 5.0})

In [9]:
# Baseline (no LoRA): fair comparison using identical decoding.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
base_id = "sshleifer/distilbart-cnn-12-6"
tok0 = AutoTokenizer.from_pretrained(base_id)
m0   = AutoModelForSeq2SeqLM.from_pretrained(base_id).eval()

def gen_base(texts, bs=4, max_in=max_src, max_out=max_tgt):
    outs=[]
    for i in range(0, len(texts), bs):
        enc = tok0(texts[i:i+bs], return_tensors="pt", truncation=True, padding=True, max_length=max_in)
        with torch.no_grad():
            g = m0.generate(
                **enc,
                max_length=max_out,
                min_length=int(max_out*0.7),
                num_beams=4,
                no_repeat_ngram_size=3,
                length_penalty=2.0,
                early_stopping=True
            )
        outs += tok0.batch_decode(g, skip_special_tokens=True)
    return outs

base_preds = gen_base(val_docs)
import evaluate; rouge = evaluate.load("rouge")
print("BASE vs VAL:", {k: round(v,4) for k,v in rouge.compute(predictions=base_preds, references=val_refs).items()})

BASE vs VAL: {'rouge1': np.float64(0.5287), 'rouge2': np.float64(0.3372), 'rougeL': np.float64(0.3103), 'rougeLsum': np.float64(0.3103)}


---

## F) Manual evaluation (ROUGE) with generation

In [10]:
# ROUGE scorer
rouge = evaluate.load("rouge")

In [11]:
# Batch generation helper
def generate_batch(texts, batch_size=4, max_in=512, max_out=128, num_beams=4):
    preds = []
    for i in range(0, len(texts), batch_size):
        chunk = texts[i:i+batch_size]
        enc = tok(chunk, return_tensors="pt", truncation=True, padding=True, max_length=max_in)
        with torch.no_grad():
            out = model.generate(
                **enc,
                max_length=max_out,
                min_length=int(max_out * 0.7),
                num_beams=num_beams,
                no_repeat_ngram_size=3,
                length_penalty=2.0,
                early_stopping=True
            )
        preds.extend(tok.batch_decode(out, skip_special_tokens=True))
    return preds

In [12]:
# Compute ROUGE on val/test
val_preds  = generate_batch(val_docs, batch_size=4, max_in=max_src, max_out=max_tgt, num_beams=4)
test_preds = generate_batch(test_docs, batch_size=4, max_in=max_src, max_out=max_tgt, num_beams=4)

val_scores  = rouge.compute(predictions=val_preds,  references=val_refs)
test_scores = rouge.compute(predictions=test_preds, references=test_refs)

print("VAL ROUGE:", {k: round(v,4) for k,v in val_scores.items()})
print("TEST ROUGE:", {k: round(v,4) for k,v in test_scores.items()})



VAL ROUGE: {'rouge1': np.float64(0.3952), 'rouge2': np.float64(0.1818), 'rougeL': np.float64(0.2156), 'rougeLsum': np.float64(0.2156)}
TEST ROUGE: {'rouge1': np.float64(0.4828), 'rouge2': np.float64(0.2907), 'rougeL': np.float64(0.3103), 'rougeLsum': np.float64(0.3103)}


---

## G) Save adapter + tokenizer

In [13]:
# Save LoRA adapter + tokenizer
trainer.save_model(str(OUT_DIR))
tok.save_pretrained(str(OUT_DIR))
print("Saved adapter to:", OUT_DIR)


Saved adapter to: ..\models\summarizer_lora_bart


---

## H) Inference helper (drop into your repo)

> Copy this into `ml/src/infer_summary.py` or similar so your app can load the adapter easily.

In [14]:
# Inference helper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import torch

class Summarizer:
    def __init__(self,
                 base="sshleifer/distilbart-cnn-12-6",
                 adapter_path="../../models/summarizer_lora_bart",
                 max_in=512, max_out=128, num_beams=4):
        self.tok = AutoTokenizer.from_pretrained(base)
        base_model = AutoModelForSeq2SeqLM.from_pretrained(base)
        self.model = PeftModel.from_pretrained(base_model, adapter_path).eval()
        self.max_in, self.max_out, self.num_beams = max_in, max_out, num_beams

    def __call__(self, text: str):
        enc = self.tok([text], truncation=True, padding=True, max_length=self.max_in, return_tensors="pt")
        with torch.no_grad():
            out = self.model.generate(
                **enc, max_length=self.max_out, num_beams=self.num_beams
            )
        return self.tok.batch_decode(out, skip_special_tokens=True)[0]

---

## ✅ Wrap-up: What we just did (04_summarizer_finetune_peft)

**TL;DR:** We fine-tuned a compact BART summarizer using **LoRA** (fast, low-VRAM), evaluated with **ROUGE**, and saved the adapter so your Streamlit app can load it on top of the base model.

**Outputs**

* Adapter + tokenizer at: `models/summarizer_lora_bart/`

**Why LoRA?**

* Much smaller to train/save, quicker iterations, and easy to swap on top of a stable base model.