# Acme Conversation Converter

### Imports

In [27]:
import math
import random
from dataclasses import dataclass
from typing import Dict, Any, List
from collections import Counter

import pandas as pd
import numpy as np
from datasets import load_dataset

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, BartForConditionalGeneration, get_linear_schedule_with_warmup

from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True      # safe, faster matmul
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x2c9c26eb530>

### Load and Explore Samsum Dataset

In [28]:
# Load Samsum dataset
df = None
for repo in ["samsum", "knkarthick/samsum"]:
    try:
        df = load_dataset(repo)
        break
    except:
        continue

df_train = df["train"].to_pandas()[["id","dialogue","summary"]]
df_val   = df["validation"].to_pandas()[["id","dialogue","summary"]]
df_test  = df["test"].to_pandas()[["id","dialogue","summary"]] if "test" in df else None

print(len(df_train), len(df_val), len(df_test) if df_test is not None else 0)

print("Train shape:", df_train.shape, "| Val shape:", df_val.shape)
print("Sample rows:\n", df_train.head(5))

14731 818 819
Train shape: (14731, 3) | Val shape: (818, 3)
Sample rows:
          id                                           dialogue  \
0  13818513  Amanda: I baked  cookies. Do you want some?\nJ...   
1  13728867  Olivia: Who are you voting for in this electio...   
2  13681000  Tim: Hi, what's up?\nKim: Bad mood tbh, I was ...   
3  13730747  Edward: Rachel, I think I'm in ove with Bella....   
4  13728094  Sam: hey  overheard rick say something\nSam: i...   

                                             summary  
0  Amanda baked cookies and will bring Jerry some...  
1  Olivia and Olivier are voting for liberals in ...  
2  Kim may try the pomodoro technique recommended...  
3  Edward thinks he is in love with Bella. Rachel...  
4  Sam is confused, because he overheard Rick com...  


In [29]:
def add_len_cols(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df["dialogue_word_count"] = df["dialogue"].str.split().apply(len)
    df["summary_word_count"]  = df["summary"].str.split().apply(len)
    return df

eda_train = add_len_cols(df_train)
eda_val   = add_len_cols(df_val)

print("\n=== Train word count summary ===")
print(eda_train[["dialogue_word_count", "summary_word_count"]].describe())

print("\n=== Validation word count summary ===")
print(eda_val[["dialogue_word_count", "summary_word_count"]].describe())


=== Train word count summary ===
       dialogue_word_count  summary_word_count
count         14731.000000        14731.000000
mean             93.792750           20.318444
std              74.031937           11.153570
min               7.000000            1.000000
25%              39.000000           12.000000
50%              73.000000           18.000000
75%             128.000000           27.000000
max             803.000000           64.000000

=== Validation word count summary ===
       dialogue_word_count  summary_word_count
count           818.000000          818.000000
mean             91.641809           20.283619
std              74.479672           11.211454
min              10.000000            3.000000
25%              38.000000           12.000000
50%              70.000000           18.000000
75%             127.000000           26.000000
max             540.000000           59.000000


### BART Model

In [None]:
# TOKENIZERS (BERT + GPT)
# BART tokenizer & model
from transformers import DataCollatorForSeq2Seq

bart_name = "sshleifer/distilbart-cnn-12-6"
bart_tok = AutoTokenizer.from_pretrained(bart_name)
model = BartForConditionalGeneration.from_pretrained(bart_name).to(device)

model.gradient_checkpointing_disable()   # saves memory
model.config.use_cache = False

model.config.output_attentions = False
model.config.output_hidden_states = False



class BartSumDataset(Dataset):
    def __init__(self, df, tokenizer, max_source_len=512, max_target_len=128):
        self.df = df.reset_index(drop=True)
        self.tok = tokenizer
        self.src_len = max_source_len
        self.tgt_len = max_target_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        row = self.df.iloc[idx]
        dialogue = str(row["dialogue"])
        summary  = str(row["summary"])

        # Encode source (dialogue)
        enc = self.tok(
            dialogue,
            max_length=self.src_len,
            truncation=True,
            return_tensors="pt",          # returns torch.Tensor
        )

        # Encode target (summary)
        dec = self.tok(
            text_target=summary,
            max_length=self.tgt_len,
            truncation=True,
            return_tensors="pt",          # returns torch.Tensor
        )

        labels = dec["input_ids"].clone()
        labels[labels == self.tok.pad_token_id] = -100

        # Squeeze away the batch dim since return_tensors="pt" adds it
        return {
            "input_ids": enc["input_ids"].squeeze(0),         # torch.LongTensor
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels.squeeze(0),
        }

# lengths & batch (full training settings)
MAX_SOURCE_LEN  = 320
MAX_TARGET_LEN = 96
BATCH_SIZE   = 8
NUM_WORKERS  = 0
PIN_MEMORY   = torch.cuda.is_available()
PREFETCH     = None

# For TRAINING speed:
model.config.use_cache = False        # cache hurts training speed/mem
model.gradient_checkpointing_enable()# disable for speed (enable only if you need memory)
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")


train_ds = BartSumDataset(df_train, bart_tok, MAX_SOURCE_LEN, MAX_TARGET_LEN)
val_ds   = BartSumDataset(df_val,   bart_tok, MAX_SOURCE_LEN, MAX_TARGET_LEN)
test_ds  = BartSumDataset(df_test,  bart_tok, MAX_SOURCE_LEN, MAX_TARGET_LEN) if df_test is not None else None

pad_to_mult = 8 if torch.cuda.is_available() else None

collator = DataCollatorForSeq2Seq(
    tokenizer=bart_tok,
    model=model,                     
    label_pad_token_id=-100,
    pad_to_multiple_of=pad_to_mult
)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collator, num_workers=NUM_WORKERS,
                          pin_memory=PIN_MEMORY, prefetch_factor=PREFETCH, persistent_workers=False)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=collator, num_workers=NUM_WORKERS,
                          pin_memory=PIN_MEMORY, prefetch_factor=PREFETCH, persistent_workers=False)
test_loader  = (DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                           collate_fn=collator, num_workers=NUM_WORKERS, 
                           pin_memory=PIN_MEMORY, prefetch_factor=PREFETCH, persistent_workers=False)
                if test_ds is not None else None)
print("Batches/epoch:", len(train_loader), "val:", len(val_loader))

# Prompt template for causal LM
PROMPT_PREFIX = "Summarize the following dialogue:\n"
PROMPT_SUFFIX = "\nSummary:"

Batches/epoch: 1842 val: 103


In [32]:

model = BartForConditionalGeneration.from_pretrained(bart_name).to(device)
model.config.pad_token_id = bart_tok.pad_token_id
model.config.bos_token_id = bart_tok.bos_token_id
model.config.eos_token_id = bart_tok.eos_token_id

# Epochs / lr / regularization
EPOCHS            = 3
LR                = 3e-5
WEIGHT_DECAY      = 0.01
GRAD_ACCUM_STEPS  = 4      # effective batch ~= BATCH_SIZE * GRAD_ACCUM_STEPS
MAX_GRAD_NORM     = 1.0
WARMUP_RATIO      = 0.06

# Optimizer (AdamW) with weight decay on non-bias/LayerNorm
decay, no_decay = [], []
for n,p in model.named_parameters():
    (decay if not any(nd in n for nd in ["bias","LayerNorm.weight"]) else no_decay).append(p)
optimizer = AdamW(
    [{"params": decay, "weight_decay": WEIGHT_DECAY},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=LR,
    fused=torch.cuda.is_available()
)


updates_per_epoch = (len(train_loader) + GRAD_ACCUM_STEPS - 1) // GRAD_ACCUM_STEPS
total_updates     = EPOCHS * updates_per_epoch
warmup_steps      = max(1, int(WARMUP_RATIO * total_updates))

scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_updates)

scaler = torch.amp.GradScaler("cuda", enabled=(torch.cuda.is_available()))

In [7]:
import evaluate
rouge = evaluate.load("rouge")



### Training BART Model

In [33]:
def train():
    device = next(model.parameters()).device
    model.train()
    global_step = 0
    log_every_updates = 10     # prints every N optimizer updates (not batches)

    for epoch in range(1, EPOCHS + 1):
        running_loss = 0.0
        running_updates = 0
        optimizer.zero_grad(set_to_none=True)

        # OPTIONAL: warm up one batch to trigger kernel JIT etc.
        # (Helps make your first timings realistic)
        # _warm = next(iter(train_loader))
        # _warm = {k: v.to(device, non_blocking=True) for k,v in _warm.items()}
        # with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
        #     _ = model(**_warm).loss.backward()
        # optimizer.zero_grad(set_to_none=True)
        # if device.type == "cuda": torch.cuda.synchronize()

        for step, batch in enumerate(train_loader, 1):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

            with torch.amp.autocast("cuda", enabled=(device.type=="cuda")):
                out  = model(**batch)                         # Seq2SeqLM returns .loss
                loss = out.loss / GRAD_ACCUM_STEPS

            scaler.scale(loss).backward()
            running_loss += loss.item()

            if step % GRAD_ACCUM_STEPS == 0:
                # Unscale before clipping for stable norms
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()

                global_step += 1
                running_updates += 1

                if running_updates % log_every_updates == 0:
                    # report avg loss over the last `log_every_updates` updates
                    avg = running_loss / (log_every_updates)
                    print(f"[E{epoch}] upd {global_step}/{total_updates} | loss {avg:.4f}")
                    running_loss = 0.0

        # end epoch
        if running_updates % log_every_updates != 0 and running_updates > 0:
            avg = running_loss / max(1, (running_updates % log_every_updates))
            print(f"[E{epoch}] upd {global_step}/{total_updates} | loss {avg:.4f}")


train()

[E1] upd 10/1383 | loss 2.9553
[E1] upd 20/1383 | loss 2.2832
[E1] upd 30/1383 | loss 1.9457
[E1] upd 40/1383 | loss 1.8372
[E1] upd 50/1383 | loss 1.8606
[E1] upd 60/1383 | loss 1.7278
[E1] upd 70/1383 | loss 1.7562
[E1] upd 80/1383 | loss 1.7223
[E1] upd 90/1383 | loss 1.6419
[E1] upd 100/1383 | loss 1.7306
[E1] upd 110/1383 | loss 1.6942
[E1] upd 120/1383 | loss 1.6607
[E1] upd 130/1383 | loss 1.6507
[E1] upd 140/1383 | loss 1.7588
[E1] upd 150/1383 | loss 1.6814
[E1] upd 160/1383 | loss 1.6132
[E1] upd 170/1383 | loss 1.6255
[E1] upd 180/1383 | loss 1.6999
[E1] upd 190/1383 | loss 1.6131
[E1] upd 200/1383 | loss 1.6252
[E1] upd 210/1383 | loss 1.6923
[E1] upd 220/1383 | loss 1.5971
[E1] upd 230/1383 | loss 1.5886
[E1] upd 240/1383 | loss 1.6558
[E1] upd 250/1383 | loss 1.6051
[E1] upd 260/1383 | loss 1.6142
[E1] upd 270/1383 | loss 1.6184
[E1] upd 280/1383 | loss 1.6106
[E1] upd 290/1383 | loss 1.5640
[E1] upd 300/1383 | loss 1.5570
[E1] upd 310/1383 | loss 1.5437
[E1] upd 320/1383

### Rouge-lite function for evaluation

In [34]:
import re

_word_re = re.compile(r"\w+('\w+)?", re.UNICODE)

def _tokenize(s: str):
    return _word_re.findall(s.lower())

def _f1(overlap, ref_count, pred_count):
    if overlap == 0: return 0.0
    precision = overlap / pred_count if pred_count else 0.0
    recall    = overlap / ref_count  if ref_count  else 0.0
    if precision + recall == 0: return 0.0
    return 2 * precision * recall / (precision + recall)

def _ngram_counts(tokens, n):
    return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1))

def _rouge_n(ref_tokens, pred_tokens, n):
    ref = _ngram_counts(ref_tokens, n)
    pred = _ngram_counts(pred_tokens, n)
    overlap = sum((ref & pred).values())
    return _f1(overlap, sum(ref.values()), sum(pred.values()))

def _lcs_len(a, b):
    # classic DP LCS length
    m, n = len(a), len(b)
    dp = [0]*(n+1)
    for i in range(1, m+1):
        prev = 0
        for j in range(1, n+1):
            tmp = dp[j]
            if a[i-1] == b[j-1]:
                dp[j] = prev + 1
            else:
                dp[j] = max(dp[j], dp[j-1])
            prev = tmp
    return dp[-1]

def _rouge_l(ref_tokens, pred_tokens):
    lcs = _lcs_len(ref_tokens, pred_tokens)
    return _f1(lcs, len(ref_tokens), len(pred_tokens))

def compute_rouge_light(preds, refs):
    """NLTK-free ROUGE: returns {'rouge1','rouge2','rougeL'} (F1)."""
    assert len(preds) == len(refs)
    r1=r2=rl=0.0
    for p, r in zip(preds, refs):
        pt = _tokenize(p); rt = _tokenize(r)
        r1 += _rouge_n(rt, pt, 1)
        r2 += _rouge_n(rt, pt, 2)
        rl += _rouge_l(rt, pt)
    n = max(1, len(preds))
    return {"rouge1": r1/n, "rouge2": r2/n, "rougeL": rl/n}

bertscore_metric = evaluate.load("bertscore")

def get_bertscore(preds, refs, lang="en"):
    """
    Compute BERTScore F1 between predictions and references.
    Args:
        preds (list[str]): Model-generated summaries
        refs (list[str]): Reference summaries
        lang (str): Language code (default: English)
    Returns:
        float: Average F1 score
    """
    results = bertscore_metric.compute(predictions=preds, references=refs, lang=lang)
    return float(sum(results["f1"]) / len(results["f1"]))
    
def basic_diag(preds, refs):
    import numpy as np, re
    def wc(s): return len(re.findall(r"\w+", s))
    len_pred = np.array([wc(p) for p in preds]); len_ref = np.array([wc(r) for r in refs])
    ratios   = len_pred / np.maximum(1, len_ref)
    return {
        "avg_len_pred": float(len_pred.mean()),
        "avg_len_ref":  float(len_ref.mean()),
        "len_ratio_mean": float(ratios.mean()),
        "len_ratio_median": float(np.median(ratios)),
    }

def decode_refs_from_labels(tokenizer, label_tensor_batch):
    refs = []
    for t in label_tensor_batch:
        t = t[t != -100]
        refs.append(tokenizer.decode(t, skip_special_tokens=True).strip())
    return refs

### BART Evaluation

In [35]:
from typing import List
import torch

@torch.no_grad()
def evaluate_bart(model, loader, tokenizer, max_len=128, num_beams=4, limit=None):
    model.eval()
    preds, refs = [], []
    for i, batch in enumerate(loader):
        if limit is not None and i >= limit: break
        inp = batch["input_ids"].to(model.device)
        att = batch["attention_mask"].to(model.device)
        gen = model.generate(
            input_ids=inp, attention_mask=att,
            max_length=max_len, num_beams=num_beams, early_stopping=True,
            pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
        )
        preds.extend(tokenizer.batch_decode(gen, skip_special_tokens=True))
        refs.extend(decode_refs_from_labels(tokenizer, batch["labels"]))
    r = compute_rouge_light(preds, refs)
    bs = get_bertscore(preds, refs) 
    diag = basic_diag(preds, refs)
    df = pd.DataFrame({"model":"BART", "pred": preds, "ref": refs})
    return r, bs, diag, df




In [None]:
# Use existing loaders;
MAX_TARGET_LEN = 96  # keep consistent with training

preds, refs, scores = evaluate_bart(
    model, val_loader, bart_tok,
    max_len=MAX_TARGET_LEN,
    num_beams=4,        # try 2 for speed, or 1 for greedy
    limit= None          # set to e.g. 100 to do a quick subset
)

print({k: round(v, 4) for k, v in scores.items()})

# Save side-by-side for inspection
import pandas as pd
pd.DataFrame({"pred": preds, "ref": refs}).to_csv("bart_val_predictions.csv", index=False)
print("Saved: bart_val_predictions.csv")


In [None]:
bertscore = evaluate.load("bertscore")
scores = bertscore.compute(predictions=preds, references=refs, lang="en")
print("BERTScore F1:", sum(scores["f1"])/len(scores["f1"]))

#### GPT Model

In [5]:
GPT_NAME = "distilgpt2"

gpt_tok = AutoTokenizer.from_pretrained(GPT_NAME)
if gpt_tok.pad_token is None:  # ensure proper padding & clean decoding
    gpt_tok.pad_token = gpt_tok.eos_token

gpt = AutoModelForCausalLM.from_pretrained(GPT_NAME)
gpt.resize_token_embeddings(len(gpt_tok))  # if pad token was added
gpt.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [38]:
# Building GPT Dataset
class GPTDataset(Dataset):
    """
    Dataset for auto-regressive LM summarization.
    Builds: prompt = PREFIX + dialogue + SUFFIX + " " + summary
    Masks prompt tokens with -100 so only summary contributes to loss.
    """
    def __init__(self, df, tokenizer, max_length=1024,
                 prompt_prefix="Summarize the following dialogue:\n",
                 prompt_suffix="\nSummary:"):
        self.df = df.reset_index(drop=True)
        self.tok = tokenizer
        self.max_length = max_length
        self.prefix = prompt_prefix
        self.suffix = prompt_suffix

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        dialogue = str(row["dialogue"])
        summary  = str(row["summary"])

        # prompt = instruction + dialogue + suffix
        prompt = f"{self.prefix}{dialogue}{self.suffix} "
        tok_prompt = self.tok(prompt, add_special_tokens=False,
                              truncation=True, max_length=self.max_length)
        tok_full   = self.tok(prompt + summary, add_special_tokens=False,
                              truncation=True, max_length=self.max_length)

        ids = tok_full["input_ids"][: self.max_length]
        attn = [1] * len(ids)

        # mask out prompt portion
        prompt_len = min(len(tok_prompt["input_ids"]), len(ids))
        labels = [-100] * prompt_len + ids[prompt_len:]

        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

In [39]:
# GPT Collator
@dataclass
class GPTCollator:
    tokenizer: any
    label_pad_id: int = -100

    def __call__(self, features):
        input_ids = [f["input_ids"] for f in features]
        attn      = [f["attention_mask"] for f in features]
        labels    = [f["labels"] for f in features]

        input_ids_padded = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        attn_padded = torch.nn.utils.rnn.pad_sequence(
            attn, batch_first=True, padding_value=0
        )
        labels_padded = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=self.label_pad_id
        )

        return {"input_ids": input_ids_padded,
                "attention_mask": attn_padded,
                "labels": labels_padded}

In [67]:
MAX_GPT_LEN  = 1024
BATCH_SIZE   = 8
NUM_WORKERS  = 0
PIN_MEMORY   = torch.cuda.is_available()
PREFETCH     = None

torch.set_float32_matmul_precision("high")  # TF32 fast matmuls on Ampere+
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

gpt.gradient_checkpointing_disable()
gpt.config.use_cache = False

gpt_train_ds = GPTDataset(df_train, gpt_tok, max_length=MAX_GPT_LEN,
                                prompt_prefix=PROMPT_PREFIX, prompt_suffix=PROMPT_SUFFIX)
gpt_val_ds   = GPTDataset(df_val,   gpt_tok, max_length=MAX_GPT_LEN,
                                prompt_prefix=PROMPT_PREFIX, prompt_suffix=PROMPT_SUFFIX)

gpt_collator = GPTCollator(tokenizer=gpt_tok)

gpt_train_loader = DataLoader(
    gpt_train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=gpt_collator, num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY, prefetch_factor=PREFETCH, persistent_workers=False
)
gpt_val_loader = DataLoader(
    gpt_val_ds, batch_size=16, shuffle=False,
    collate_fn=gpt_collator, num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY, prefetch_factor=PREFETCH, persistent_workers=False
)
print("GPT loaders ready:", len(gpt_train_loader), len(gpt_val_loader))

GPT loaders ready: 1842 52


### GPT Training

In [42]:
EPOCHS             = 3
LR                 = 5e-5
WEIGHT_DECAY       = 0.01
GRAD_ACCUM_STEPS   = 8
MAX_GRAD_NORM      = 1.0
WARMUP_RATIO       = 0.06

# Optimizer with weight decay on non-bias/LayerNorm
decay, nodecay = [], []
for n,p in gpt.named_parameters():
    (decay if not any(nd in n for nd in ["bias","LayerNorm.weight"]) else nodecay).append(p)
gpt_optim = AdamW([{"params": decay, "weight_decay": WEIGHT_DECAY},
                   {"params": nodecay, "weight_decay": 0.0}], lr=LR, fused=True)

updates_per_epoch = math.ceil(len(gpt_train_loader) / GRAD_ACCUM_STEPS)
total_updates     = EPOCHS * updates_per_epoch
warmup_steps      = max(1, int(WARMUP_RATIO * total_updates))

gpt_sched = get_linear_schedule_with_warmup(gpt_optim, warmup_steps, total_updates)
scaler    = torch.amp.GradScaler(enabled=(device.type=="cuda"))

def train_gpt():
    global_step = 0
    gpt.train()
    for epoch in range(1, EPOCHS+1):
        running = 0.0
        gpt_optim.zero_grad(set_to_none=True)
        for step, batch in enumerate(gpt_train_loader, 1):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            # PyTorch 2.x autocast
            with torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")):
                out  = gpt(**batch)                 # labels mask the prompt to -100
                loss = out.loss / GRAD_ACCUM_STEPS

            scaler.scale(loss).backward()
            running += loss.item()

            if step % GRAD_ACCUM_STEPS == 0:
                scaler.unscale_(gpt_optim)
                clip_grad_norm_(gpt.parameters(), MAX_GRAD_NORM)
                scaler.step(gpt_optim); scaler.update()
                gpt_optim.zero_grad(set_to_none=True)
                gpt_sched.step()
                global_step += 1

            LOG_EVERY = max(10, GRAD_ACCUM_STEPS * 5)

            if step % LOG_EVERY == 0:
                print(f"[GPT][E{epoch}] step {step}/{len(gpt_train_loader)} | running_loss {running/LOG_EVERY:.4f}")
                running = 0.0


train_gpt()

[GPT][E1] step 40/1842 | running_loss 0.2598
[GPT][E1] step 80/1842 | running_loss 0.2719
[GPT][E1] step 120/1842 | running_loss 0.2621
[GPT][E1] step 160/1842 | running_loss 0.2669
[GPT][E1] step 200/1842 | running_loss 0.2719
[GPT][E1] step 240/1842 | running_loss 0.2554
[GPT][E1] step 280/1842 | running_loss 0.2733
[GPT][E1] step 320/1842 | running_loss 0.2604
[GPT][E1] step 360/1842 | running_loss 0.2614
[GPT][E1] step 400/1842 | running_loss 0.2654
[GPT][E1] step 440/1842 | running_loss 0.2721
[GPT][E1] step 480/1842 | running_loss 0.2661
[GPT][E1] step 520/1842 | running_loss 0.2672
[GPT][E1] step 560/1842 | running_loss 0.2688
[GPT][E1] step 600/1842 | running_loss 0.2658
[GPT][E1] step 640/1842 | running_loss 0.2600
[GPT][E1] step 680/1842 | running_loss 0.2654
[GPT][E1] step 720/1842 | running_loss 0.2619
[GPT][E1] step 760/1842 | running_loss 0.2615
[GPT][E1] step 800/1842 | running_loss 0.2616
[GPT][E1] step 840/1842 | running_loss 0.2708
[GPT][E1] step 880/1842 | running_lo

### GPT Evaluation

In [76]:
# EVAL tokenizer (left padding)
gpt_tok_eval = AutoTokenizer.from_pretrained(GPT_NAME, padding_side="left")
if gpt_tok_eval.pad_token is None:
    gpt_tok_eval.pad_token = gpt_tok_eval.eos_token

# Left-pad collator for decoder-only models
from dataclasses import dataclass
from typing import Dict, List
import torch

@dataclass
class GPTLeftPadCollator:
    tokenizer: any
    label_pad_id: int = -100
    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id
        bs = len(features)
        max_len = max(f["input_ids"].size(0) for f in features)

        input_ids  = torch.full((bs, max_len), pad_id, dtype=torch.long)
        attn_mask  = torch.zeros((bs, max_len), dtype=torch.long)
        labels_out = torch.full((bs, max_len), self.label_pad_id, dtype=torch.long)

        for i, f in enumerate(features):
            L = f["input_ids"].size(0)
            input_ids[i,  -L:] = f["input_ids"]
            attn_mask[i,  -L:] = f["attention_mask"]
            labels_out[i,  -L:] = f["labels"]
        return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels_out}

gpt_collator_eval = GPTLeftPadCollator(tokenizer=gpt_tok_eval)

gpt_val_loader = DataLoader(
    gpt_val_ds, batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=gpt_collator_eval, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
)


In [78]:

import torch
import torch.nn.functional as F

@torch.no_grad()
def evaluate_gpt(
    model,
    loader,
    tokenizer,
    max_new_tokens=96,
    num_beams=1,
    limit=None,
    pad_to_multiple_of=None,     # None to disable; 8 helps tensor cores
    min_new_tokens=8,
    no_repeat_ngram_size=3,   # small repetition guard; set 0/None to disable
    do_bertscore=False,        # turn off if you want max speed
):
    device = next(model.parameters()).device

    # Speed up AR decoding
    was_ckpt = getattr(model, "is_gradient_checkpointing", False)
    if was_ckpt:
        model.gradient_checkpointing_disable()
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = True

    model.eval()
    all_preds, all_refs = [], []
    seen = 0

    with torch.inference_mode(), torch.amp.autocast(device_type=device.type, enabled=(device.type=="cuda")):
        for batch in loader:
            if limit is not None and seen >= limit:
                break

            # ---- keep labels on CPU (we only need them to decode refs)
            labs_cpu = batch["labels"]  # (don’t move to GPU)

            # Move inputs to GPU once
            ids  = batch["input_ids"].to(device, non_blocking=True).long()
            attn = batch["attention_mask"].to(device, non_blocking=True).long()

            # Sort by length to reduce padding/FLOPs (done on GPU)
            lengths = attn.sum(dim=1)
            order   = torch.argsort(lengths)          # GPU
            inv     = torch.empty_like(order); inv[order] = torch.arange(order.numel(), device=device)

            ids  = ids[order]
            attn = attn[order]
            lens_sorted = lengths[order]

            # Optional: LEFT-pad to multiple of 8 (do it on GPU)
            if pad_to_multiple_of:
                pad_to = ((ids.size(1) + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
                if pad_to > ids.size(1):
                    left_pad = pad_to - ids.size(1)
                    ids  = F.pad(ids,  (left_pad, 0), value=tokenizer.pad_token_id)
                    attn = F.pad(attn, (left_pad, 0), value=0)

            # Compute safe new-token budget for the whole sub-batch
            max_ctx = getattr(tokenizer, "model_max_length", 1024) or 1024
            longest_prompt = int(attn.sum(dim=1).max().item())
            max_new = max(1, min(max_new_tokens, max_ctx - longest_prompt))

            # ---- One batched generate() on GPU
            gen = model.generate(
                input_ids=ids,
                attention_mask=attn,
                max_new_tokens=max_new,
                min_new_tokens=min_new_tokens,
                num_beams=num_beams,          # 1 = greedy
                do_sample=False,
                no_repeat_ngram_size=no_repeat_ngram_size,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )  # [B, prompt+gen] on GPU

            # Slice completions by *padded input width* (padding-agnostic)
            start = ids.size(1)                  # same for the whole sorted sub-batch
            comps = gen[:, start:]               # still on GPU

            # Decode preds (minimal CPU touch)
            # Move only the completions we need to CPU for decoding
            pred_texts_sorted = tokenizer.batch_decode(comps.cpu(), skip_special_tokens=True)

            # Decode refs on CPU (labels never moved to GPU)
            ref_texts_sorted = [
                tokenizer.decode(l[l != -100], skip_special_tokens=True).strip()
                for l in labs_cpu[order.cpu()]
            ]

            # Restore original order
            inv_cpu = inv.cpu().tolist()
            pred_texts = [pred_texts_sorted[k] for k in inv_cpu]
            ref_texts  = [ref_texts_sorted[k]  for k in inv_cpu]

            all_preds.extend(pred_texts)
            all_refs.extend(ref_texts)
            seen += len(pred_texts)

            if limit is not None and seen >= limit:
                break

    # Restore training-time flags
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False
    if was_ckpt:
        model.gradient_checkpointing_enable()

    # Filter empty pairs for metrics stability
    pairs = [(p, r) for p, r in zip(all_preds, all_refs) if p and r]
    if not pairs:
        return all_preds, all_refs, {"rouge1":0, "rouge2":0, "rougeL":0}, 0.0
    preds_f, refs_f = map(list, zip(*pairs))

    # ROUGE (lightweight) — requires strings (CPU)
    rouge = compute_rouge_light(preds_f, refs_f)

    # (Optional) BERTScore — can be slow; toggle with do_bertscore
    bert_f1 = 0.0
    if do_bertscore:
        try:
            import evaluate
            bertscore = evaluate.load("bertscore")
            bs = bertscore.compute(predictions=preds_f, references=refs_f, lang="en")
            bert_f1 = float(sum(bs["f1"]) / len(bs["f1"]))
        except Exception:
            bert_f1 = 0.0

    return all_preds, all_refs, rouge, bert_f1


In [79]:
# Run full eval (set limit to e.g. 200 to sample quickly)
preds, refs, rouge_metrics, bert_f1 = evaluate_gpt(
    gpt, gpt_val_loader, gpt_tok_eval, max_new_tokens=96, num_beams=1, limit=None
)

print({k: round(v, 4) for k, v in rouge_metrics.items()}, "| BERTScore F1:", round(bert_f1, 4))

pd.DataFrame({"pred": preds, "ref": refs}).to_csv("gpt_val_predictions.csv", index=False)
print("Saved: gpt_val_predictions.csv")

{'rouge1': 0.4268, 'rouge2': 0.4054, 'rougeL': 0.4255} | BERTScore F1: 0.0
Saved: gpt_val_predictions.csv


In [80]:
gpt.save_pretrained("./checkpoints/distilgpt2-samsum")
gpt_tok.save_pretrained("./checkpoints/distilgpt2-samsum")

# Later:
# gpt_tok = AutoTokenizer.from_pretrained("./checkpoints/distilgpt2-samsum")
# gpt     = AutoModelForCausalLM.from_pretrained("./checkpoints/distilgpt2-samsum").to(device)


('./checkpoints/distilgpt2-samsum\\tokenizer_config.json',
 './checkpoints/distilgpt2-samsum\\special_tokens_map.json',
 './checkpoints/distilgpt2-samsum\\vocab.json',
 './checkpoints/distilgpt2-samsum\\merges.txt',
 './checkpoints/distilgpt2-samsum\\added_tokens.json',
 './checkpoints/distilgpt2-samsum\\tokenizer.json')

### Samples from the Models

In [None]:
def show_samples(df, k=5, seed=13):
    import numpy as np
    np.random.seed(seed)
    idx = np.random.choice(len(df), size=min(k, len(df)), replace=False)
    for i, j in enumerate(idx, 1):
        row = df.iloc[j]
        print(f"\n[{row['model']}] Example {i}")
        print("PRED:", row["pred"])
        print("REF :", row["ref"])
gpt_df = pd.DataFrame({"pred": preds, "ref": refs})
gpt_df["model"] = "GPT"

print("\n--- GPT samples ---")
show_samples(gpt_df,  k=5)


--- GPT samples ---

[GPT] Example 1
PRED:  Carol will book them.    Charles will pay a little more for the better seats.  Carol will let Carol know.  Charles is going to the Opera next month and will book the tickets.  He will book Carmen for the opera.  She will book her tickets. Carol is going on a trip to the theatre next month, so she will book a few more seats. Carol has to book the seats. She will be going to a theatre next year. 
REF : will book tickets for Carmen for himself and Carol.

[GPT] Example 2
PRED:  Jenny will let her in.    Jenny will get her flu.  Jenny has a key.  She will let herself in so she will take her key out.  Joins her to the shop in the morning.  The key is in the door.  It will be taken by Jenny.  He will let Jenny in. Joins Jenny in the afternoon.  They will have a cold and they will have flu. Join's key is taken by Jo
REF : is coming down with a cold. Sue is doing grocery shopping for Jenny.

[GPT] Example 3
PRED:  Oliver is fed up.    Oliver is goin