# SAMSum Dialogue Summarization (MVP)
DistilBERT encoder â†’ DistilGPT-2 decoder (EncoderDecoderModel)

Note- Please install pip cell as is. May take a few min to execute. You will see red error messages after installing. Restart runtime after insalling, and proceed to next cells as normal.

In [None]:
!pip -q uninstall -y transformers tokenizers datasets huggingface_hub accelerate rouge-score >/dev/null 2>&1

# 2) Install a mutually compatible bundle
#    Note: transformers 4.45.x pairs with tokenizers 0.20.x
!pip -q install --no-cache-dir --upgrade --upgrade-strategy eager \
  "transformers==4.45.2" \
  "tokenizers==0.20.1" \
  "accelerate==0.34.2" \
  "datasets==2.20.0" \
  "huggingface_hub==0.25.2" \
  "rouge-score==0.1.2"

B) Imports, seed, device

In [1]:
import os, random, gc, numpy as np, pandas as pd, torch

from datasets import load_dataset, DatasetDict
from transformers import (AutoModel, AutoConfig, GPT2LMHeadModel, EncoderDecoderModel,
                          BertTokenizerFast, GPT2TokenizerFast,
                          DataCollatorForSeq2Seq, Trainer, TrainingArguments)
from rouge_score import rouge_scorer, scoring

os.environ["WANDB_DISABLED"] = "true"
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device, torch.cuda.get_device_name(0) if device=="cuda" else "")

Device: cuda NVIDIA A100-SXM4-40GB


C) Config knobs

In [2]:
# You can bump these later if you have headroom.
MAX_INPUT_LEN  = 512
MAX_TARGET_LEN = 64

TRAIN_SAMPLES = 1000
VAL_SAMPLES   = 1000

ENC_NAME = "distilbert-base-uncased"
DEC_NAME = "distilgpt2"  # smaller than gpt2

print("Config ->", dict(MAX_INPUT_LEN=MAX_INPUT_LEN, MAX_TARGET_LEN=MAX_TARGET_LEN,
                        TRAIN_SAMPLES=TRAIN_SAMPLES, VAL_SAMPLES=VAL_SAMPLES,
                        ENC_NAME=ENC_NAME, DEC_NAME=DEC_NAME))

Config -> {'MAX_INPUT_LEN': 512, 'MAX_TARGET_LEN': 64, 'TRAIN_SAMPLES': 1000, 'VAL_SAMPLES': 1000, 'ENC_NAME': 'distilbert-base-uncased', 'DEC_NAME': 'distilgpt2'}


D) Load SAMSum (with reliable fallback)

In [3]:
def load_samsum():
    try:
        return load_dataset("knkarthick/samsum")
    except Exception as e:
        print("Hub issue, falling back to raw JSON:", e)
        return load_dataset(
            "json",
            data_files={
                "train":"https://huggingface.co/datasets/samsum/resolve/main/train.json",
                "validation":"https://huggingface.co/datasets/samsum/resolve/main/validation.json",
                "test":"https://huggingface.co/datasets/samsum/resolve/main/test.json",
            }
        )

ds = load_samsum()

# Trim for MVP speed
train_small = ds["train"].select(range(min(TRAIN_SAMPLES, len(ds["train"]))))
val_small   = ds["validation"].select(range(min(VAL_SAMPLES, len(ds["validation"]))))
print(ds)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14731
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
})


E) Tokenizers & preprocessing

In [4]:
bert_tok = BertTokenizerFast.from_pretrained(ENC_NAME)
gpt2_tok  = GPT2TokenizerFast.from_pretrained(DEC_NAME)

# Ensure PAD token for GPT-2 family
if gpt2_tok.pad_token is None:
    gpt2_tok.pad_token = gpt2_tok.eos_token   # simple & safe

def preprocess(batch):
    # Encoder inputs (DistilBERT)
    model_inputs = bert_tok(
        batch["dialogue"],
        truncation=True,
        padding="max_length",
        max_length=MAX_INPUT_LEN
    )
    # Decoder targets (DistilGPT-2)
    with gpt2_tok.as_target_tokenizer():
        labels = gpt2_tok(
            batch["summary"],
            truncation=True,
            padding="max_length",
            max_length=MAX_TARGET_LEN
        )["input_ids"]

    # Mask PAD with -100 so it doesn't contribute to loss
    labels = [
        [(tid if tid != gpt2_tok.pad_token_id else -100) for tid in seq]
        for seq in labels
    ]
    model_inputs["labels"] = labels
    return model_inputs

proc_train = train_small.map(preprocess, batched=True, remove_columns=train_small.column_names)
proc_val   = val_small.map(preprocess,   batched=True, remove_columns=val_small.column_names)
proc = DatasetDict({"train": proc_train, "validation": proc_val})
proc

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizerFast'.


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]



DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 818
    })
})

F) Build model (DistilBERT â†’ DistilGPT-2) + memory savers

In [5]:
# Encoder
encoder = AutoModel.from_pretrained(ENC_NAME)

# Decoder config with cross-attention enabled
dec_config = AutoConfig.from_pretrained(DEC_NAME)
dec_config.is_decoder = True
dec_config.add_cross_attention = True

# Decoder
decoder = GPT2LMHeadModel.from_pretrained(DEC_NAME, config=dec_config)
decoder.resize_token_embeddings(len(gpt2_tok))

# Compose encoder-decoder
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# Special tokens + generation defaults
model.config.decoder_start_token_id = gpt2_tok.eos_token_id  # GPT-2 has no BOS; EOS works fine
model.config.eos_token_id = gpt2_tok.eos_token_id
model.config.pad_token_id = gpt2_tok.pad_token_id
model.config.max_length = MAX_TARGET_LEN
model.config.no_repeat_ngram_size = 3
model.config.num_beams = 2

model.to(device)

# Memory savers
model.gradient_checkpointing_enable()  # big VRAM win
model.config.use_cache = False         # avoid cache during training

model

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.2.crossattention.c_attn.bias', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.2.crossattention.c_proj.bias', 'transformer.h.2.cr

EncoderDecoderModel(
  (encoder): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): 

G) Data collator (fp16-friendly)

In [6]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=bert_tok,
    model=model,
    padding="longest",
    pad_to_multiple_of=8,   # helps fp16
    return_tensors="pt"
)
data_collator

DataCollatorForSeq2Seq(tokenizer=BertTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}, model=EncoderDecoderModel(
  (encoder): DistilBertModel(
    (embeddings): Embeddings(

H) ROUGE metric (robust to HF outputs)

In [7]:
rouge_scorer_fn = rouge_scorer.RougeScorer(["rouge1","rouge2","rougeL"], use_stemmer=True)

def compute_metrics(eval_pred):
    import numpy as np
    preds, labels = eval_pred
    # Some HF versions return a tuple (preds, other)
    if isinstance(preds, tuple):
        preds = preds[0]
    preds  = np.asarray(preds)
    labels = np.asarray(labels)
    # Replace masked positions so we can decode
    labels = np.where(labels != -100, labels, gpt2_tok.pad_token_id)

    decoded_preds  = gpt2_tok.batch_decode(preds,  skip_special_tokens=True)
    decoded_labels = gpt2_tok.batch_decode(labels, skip_special_tokens=True)

    agg = scoring.BootstrapAggregator()
    for ref, hyp in zip(decoded_labels, decoded_preds):
        agg.add_scores(rouge_scorer_fn.score(ref, hyp))

    res = agg.aggregate()
    return {k: round(v.mid.fmeasure*100, 2) for k,v in res.items()}

I) Trainer setup (epoch style, small VRAM)

In [8]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

bs = 2
args = Seq2SeqTrainingArguments(
    output_dir="./distilbert_distilgpt2_samsum_mvp",
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=bs,
    gradient_accumulation_steps=16,
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    save_total_limit=1,

    predict_with_generate=True,
    generation_max_length=MAX_TARGET_LEN,
    generation_num_beams=2,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=proc["train"],
    eval_dataset=proc["validation"],
    tokenizer=bert_tok,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

start_id = gpt2_tok.bos_token_id if gpt2_tok.bos_token_id is not None else gpt2_tok.eos_token_id

# keep model.config populated
model.config.decoder_start_token_id = start_id
model.config.bos_token_id          = start_id
model.config.eos_token_id          = gpt2_tok.eos_token_id
model.config.pad_token_id          = gpt2_tok.pad_token_id

# ðŸ”‘ keep generation_config in sync (Transformers >= 4.27)
gen = model.generation_config
gen.decoder_start_token_id = model.config.decoder_start_token_id
gen.bos_token_id           = model.config.bos_token_id
gen.eos_token_id           = model.config.eos_token_id
gen.pad_token_id           = model.config.pad_token_id
gen.max_length             = MAX_TARGET_LEN
gen.num_beams              = 2

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


J) Train and Evaluate

In [9]:
train_result = trainer.train()
trainer.save_model()
metrics = trainer.evaluate()
print("ROUGE on validation:", {k: round(v, 2) if isinstance(v, float) else v for k, v in metrics.items()})

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss
50,4.1442




ROUGE on validation: {'eval_loss': 3.86, 'eval_rouge1': 14.96, 'eval_rouge2': 2.06, 'eval_rougeL': 12.26, 'eval_runtime': 326.77, 'eval_samples_per_second': 2.5, 'eval_steps_per_second': 1.25, 'epoch': 2.98}


K) Sample generations + lightweight eval (fallback)

In [10]:
# K) Sample generations + lightweight eval (fixed version)
model.eval()
model.config.use_cache = True  # faster inference

def generate_summary(dialogue):
    # Encode input text
    enc = bert_tok(
        dialogue,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LEN
    )
    # Keep only tensors the model actually uses (avoid token_type_ids error)
    enc = {k: v.to(device) for k, v in enc.items() if k in ("input_ids", "attention_mask")}

    # Generate output
    with torch.no_grad():
        gen = model.generate(
            **enc,
            max_new_tokens=MAX_TARGET_LEN,
            num_beams=2,
            do_sample=False,
        )
    return gpt2_tok.decode(gen[0], skip_special_tokens=True)

# --- Show a few validation examples ---
for i in range(3):
    d = ds["validation"][i]["dialogue"]
    ref = ds["validation"][i]["summary"]
    pred = generate_summary(d)
    print(f"\n=== Example {i+1} ===")
    print("REF :", ref)
    print("PRED:", pred)

# --- Lightweight manual ROUGE on small slice (no dependence on Trainer.evaluate) ---
small = ds["validation"].select(range(50))
agg = scoring.BootstrapAggregator()

for ex in small:
    pred = generate_summary(ex["dialogue"])
    agg.add_scores(rouge_scorer_fn.score(ex["summary"], pred))

lite_res = agg.aggregate()
print("\nLite ROUGE on 50 examples:", {k: round(v.mid.fmeasure * 100, 2) for k, v in lite_res.items()})



=== Example 1 ===
REF : A will go to the animal shelter tomorrow to get a puppy for her son. They already visited the shelter last Monday and the son chose the puppy. 
PRED: Sophie is going to a party at the end of the week. She's going to the party at 8 pm. Â  Â Â   Â  Â  Â  Â  Â  Â  Â Â  Â  Â  Â  Â  Â  Â  Â  Â  Â  Â  Â  Â  Â  Â  Â 
Â  Â Â Â  Â Â Â Â  Â 

=== Example 2 ===
REF : Emma and Rob love the advent calendar. Lauren fits inside calendar various items, for instance, small toys and Christmas decorations. Her children are excited whenever they get the calendar.
PRED: Karen is going to a party at the end of the week. Â   Â  She will be there for a few hours.  Â Â     Â   Â Â Â   Ã‚   Â Â    Â  Â    Â  Â  Â  Â    Â  Â  Â  Â  Â  Â  Â  Â   
Â  Â  Â  Â Â  Â 

=== Example 3 ===
REF : Madison is pregnant but she doesn't want to talk about it. Patricia Stevens got married and she thought she was pregnant. 
PRED: Sophie is going to a party at the end of the week. She's going to have a dr