In [None]:
!pip install -q transformers==4.44.2 datasets evaluate rouge-score

from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from torch.optim import AdamW
import torch

In [None]:
import os, torch, pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from transformers import (
    T5TokenizerFast, T5ForConditionalGeneration, AdamW
)
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import evaluate

In [None]:
# --- Experiment configuration ---
# Fraction of every split to keep; sub-sampling only happens for values strictly between 0 and 1.
SAMPLE_FRAC = 0.1
# Token budgets for the encoder input (article) and the decoder target (summary).
MAX_INPUT_LEN, MAX_TARGET_LEN = 512, 150
BATCH_SIZE, EPOCHS = 8, 2
MODEL_NAME = "t5-small"
# Directory holding train.csv / validation.csv / test.csv when running on Kaggle.
KAGGLE_PATH = Path("/kaggle/input/newspaper-text-summarization-cnn-dailymail/cnn_dailymail")

# Seed torch's random number generators so that DataLoader shuffling and dropout
# are repeatable across "Restart & Run All" (some GPU kernels may still be
# non-deterministic).
SEED = 42
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def _select_text_columns(frame):
    """Keep only the article/highlights columns, drop rows missing either, and renumber the index."""
    return frame[["article", "highlights"]].dropna().reset_index(drop=True)


# Prefer the local Kaggle CSVs; otherwise download the dataset from the Hugging Face Hub.
# Both branches go through the same cleaning step so downstream cells see identical schemas.
if KAGGLE_PATH.exists():
    train_df = _select_text_columns(pd.read_csv(KAGGLE_PATH / "train.csv"))
    val_df = _select_text_columns(pd.read_csv(KAGGLE_PATH / "validation.csv"))
    test_df = _select_text_columns(pd.read_csv(KAGGLE_PATH / "test.csv"))
else:
    # Imported lazily: `datasets` is only needed when the CSVs are not available.
    from datasets import load_dataset
    ds = load_dataset("cnn_dailymail", "3.0.0")
    # Dataset.to_pandas() converts the whole table at once instead of iterating row by row.
    train_df = _select_text_columns(ds["train"].to_pandas())
    val_df = _select_text_columns(ds["validation"].to_pandas())
    test_df = _select_text_columns(ds["test"].to_pandas())

In [None]:
# Optionally shrink every split to the same fraction (fixed random_state keeps the draw repeatable).
if 0 < SAMPLE_FRAC < 1:
    train_df, val_df, test_df = (
        split.sample(frac=SAMPLE_FRAC, random_state=42).reset_index(drop=True)
        for split in (train_df, val_df, test_df)
    )


def _add_text_columns(frame):
    """Add the prefixed model input and the stripped reference summary to `frame` in place."""
    frame["input"] = "summarize: " + frame["article"].astype(str).str.strip()
    frame["target"] = frame["highlights"].astype(str).str.strip()


for df in (train_df, val_df, test_df):
    _add_text_columns(df)

In [None]:
# Shared tokenizer, loaded from the MODEL_NAME checkpoint.
tok = T5TokenizerFast.from_pretrained(MODEL_NAME)

class SummDataset(Dataset):
    """Tokenizes (input, target) text pairs from a DataFrame on the fly.

    Args:
        df: frame with an "input" column (model prompt) and a "target" column
            (reference summary).
        tokenizer: callable tokenizer exposing `pad_token_id`; defaults to the
            notebook-level `tok`.
        max_input_len: pad/truncate length for inputs; defaults to MAX_INPUT_LEN.
        max_target_len: pad/truncate length for targets; defaults to MAX_TARGET_LEN.

    Each item is a tuple `(input_ids, attention_mask, labels)` of 1-D tensors,
    where padded label positions are set to -100.
    """

    def __init__(self, df, tokenizer=None, max_input_len=None, max_target_len=None):
        self.x, self.y = df["input"].tolist(), df["target"].tolist()
        # Fall back to the notebook globals only when the caller did not supply a value.
        self.tokenizer = tok if tokenizer is None else tokenizer
        self.max_input_len = MAX_INPUT_LEN if max_input_len is None else max_input_len
        self.max_target_len = MAX_TARGET_LEN if max_target_len is None else max_target_len

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        enc = self.tokenizer(
            self.x[i],
            truncation=True,
            max_length=self.max_input_len,
            padding="max_length",
            return_tensors="pt",
        )
        dec = self.tokenizer(
            self.y[i],
            truncation=True,
            max_length=self.max_target_len,
            padding="max_length",
            return_tensors="pt",
        )
        # squeeze(0) removes only the leading batch axis; a bare squeeze() would
        # also collapse the sequence axis whenever its length is 1.
        ids = enc["input_ids"].squeeze(0)
        mask = enc["attention_mask"].squeeze(0)
        labels = dec["input_ids"].squeeze(0)
        # Mark padding in the labels with -100.
        labels[labels == self.tokenizer.pad_token_id] = -100
        return ids, mask, labels

def collate(b):
    """Turn a list of (input_ids, attention_mask, labels) triples into three batched tensors."""
    input_ids, attention_masks, label_ids = zip(*b)
    stacked = [torch.stack(field) for field in (input_ids, attention_masks, label_ids)]
    return tuple(stacked)

# Batch each split through the shared collate function; only the training split is reshuffled.
train_loader = DataLoader(
    dataset=SummDataset(train_df),
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate,
)
val_loader = DataLoader(
    dataset=SummDataset(val_df),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate,
)

In [None]:
from transformers import T5ForConditionalGeneration, T5Config
from torch.optim import AdamW

# T5 has a single dropout hyper-parameter, `dropout_rate`, which its attention
# layers use as well. T5Config defines no `attention_dropout_rate` field, so
# passing one is silently ignored; only `dropout_rate` is overridden here.
config = T5Config.from_pretrained(MODEL_NAME, dropout_rate=0.2)

# Load the pretrained weights with the modified config and move them to the target device.
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, config=config).to(device)

# AdamW resolves to the most recent import above this line (torch.optim.AdamW).
opt = AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

In [None]:
# Fine-tune for EPOCHS passes over train_loader, reporting the mean training
# loss and the mean held-out loss on val_loader after every epoch.
for ep in range(EPOCHS):
    # ---- training pass ----
    model.train()
    total = 0.0
    for ids, mask, labels in tqdm(train_loader, desc=f"Epoch {ep+1}"):
        ids, mask, labels = ids.to(device), mask.to(device), labels.to(device)
        opt.zero_grad()
        loss = model(input_ids=ids, attention_mask=mask, labels=labels).loss
        loss.backward()
        # Cap the global gradient norm at 1.0 before the optimizer step.
        clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        total += loss.item()
    # max(..., 1) avoids a ZeroDivisionError if the loader is empty.
    print(f"Epoch {ep+1} Loss: {total/max(len(train_loader), 1):.4f}")

    # ---- validation pass (no gradients, dropout disabled) ----
    model.eval()
    val_total = 0.0
    with torch.no_grad():
        for ids, mask, labels in tqdm(val_loader, desc=f"Epoch {ep+1} [val]"):
            ids, mask, labels = ids.to(device), mask.to(device), labels.to(device)
            val_total += model(input_ids=ids, attention_mask=mask, labels=labels).loss.item()
    print(f"Epoch {ep+1} Val Loss: {val_total/max(len(val_loader), 1):.4f}")

In [None]:
# Write the fine-tuned model and its tokenizer into the same directory so both
# can be restored together with from_pretrained.
save_dir = "models/t5-summarizer"
model.save_pretrained(save_dir)
tok.save_pretrained(save_dir)

In [None]:
# Restore tokenizer and model from the saved directory and switch to inference mode.
reload_dir = "models/t5-summarizer"
tok = T5TokenizerFast.from_pretrained(reload_dir)
model = T5ForConditionalGeneration.from_pretrained(reload_dir)
model = model.to(device)
model.eval()

In [None]:
def summarize(text):
    """Generate a summary for one article with the notebook-level `model` and `tok`.

    Args:
        text: the article to summarize. It is converted with `str()` and
            stripped, mirroring how the "input" column is built for training.

    Returns:
        The decoded summary string with special tokens removed.
    """
    # Same normalization as the training inputs: cast to str, strip, add the task prefix.
    prompt = "summarize: " + str(text).strip()
    enc = tok(prompt, truncation=True, max_length=MAX_INPUT_LEN, return_tensors="pt").to(device)
    # Beam search with 4 beams; no_repeat_ngram_size=3 blocks repeated trigrams.
    ids = model.generate(
        **enc,
        max_new_tokens=MAX_TARGET_LEN,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    return tok.decode(ids[0], skip_special_tokens=True)

# Smoke test: summarize the first test article and show the result.
print(summarize(test_df["article"].iloc[0]))

In [None]:
# Score generated summaries against the references with ROUGE on at most the
# first 200 test rows.
rouge = evaluate.load("rouge")
n_eval = min(200, len(test_df))
eval_rows = test_df.iloc[:n_eval]
preds = [summarize(article).strip() for article in eval_rows["article"]]
refs = [reference.strip() for reference in eval_rows["target"]]
scores = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
# Round every metric to four decimals for display.
res = {metric: round(value, 4) for metric, value in scores.items()}
res

In [None]:
# Also dump the raw state dict to a single file.
weights_path = "t5_weights.pt"
state = model.state_dict()
torch.save(state, weights_path)
print("Weights Saved!")