# ![Banner](https://github.com/LittleHouse75/flatiron-resources/raw/main/NevitsBanner.png)
---
# Experiment 1 — BERT Encoder → GPT-2 Decoder  
### *“Frankenstein” Encoder–Decoder Summarization Model*
---

This notebook runs Experiment 1 for the project:

**Goal:**  
Evaluate a custom encoder–decoder architecture where:

- **Encoder:** `bert-base-uncased`  
- **Decoder:** `gpt2` (with cross-attention layers added by Hugging Face's `EncoderDecoderModel`)  

This is intentionally *not* a pretrained summarization model.  
The purpose is to test whether a glued-together architecture can learn dialogue summarization with curriculum training (warmup → finetune).

All reusable code is imported from `src/`, keeping this notebook clean.

## 1. Environment Setup

In [None]:
# Disable tokenizers parallelism warning
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from pathlib import Path
import pandas as pd

# Ensure project root is importable
PROJECT_ROOT = Path("..").resolve()
import sys
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Mute warnings
import warnings
warnings.filterwarnings("ignore", message="Mem Efficient attention")
warnings.filterwarnings(
    "ignore",
    message=".*copy construct from a tensor.*"
)
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=".*better way to train encoder-decoder models.*"
)
warnings.filterwarnings("ignore", message=".*requires_grad=True.*")
warnings.filterwarnings("ignore", message=".*Flash Efficient attention.*")

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## 2. Project Imports (Shared Utilities)
We import:
- SAMSum loader  
- Dataset wrapper  
- Model builder  
- Trainer  
- Qualitative preview  

In [None]:
from src.data.load_data import load_samsum
from src.data.preprocess import SummaryDataset
from src.models.build_bert_gpt2 import build_bert_gpt2_model
from src.train.trainer_seq2seq import train_model
from src.eval.qualitative import qualitative_samples

## 3. Constants & Hyperparameters  
These values were chosen based on EDA and practical training needs.

In [None]:
MAX_SOURCE_LEN = 512       # <= BERT's max_position_embeddings
MAX_TARGET_LEN = 128
EPOCHS = 10

BATCH_SIZE = 1
GRAD_ACCUM = 4             # Accumulate gradients over 4 batches before updating

LEARNING_RATE = 1e-5
BRIDGE_LR = 1e-3

RUN_TRAINING = True

# =============================================================================
# WARMUP CONFIGURATION
# =============================================================================
# The BERT encoder and GPT-2 decoder were pretrained separately.
# The cross-attention layers that connect them are RANDOMLY INITIALIZED.
# 
# Warmup trains ONLY the decoder (including cross-attention) while keeping
# the encoder frozen. This helps the cross-attention layers learn to "read"
# the encoder's output before we fine-tune everything together.
#
# WARMUP_TARGET_BATCHES: How many batches we WANT to process
#   - If the dataset has fewer batches, warmup will end early (that's OK)
#   - One epoch of SAMSum training ≈ 14,732 batches (with batch_size=1)
#   - So 3000 batches ≈ 20% of one epoch
#
# Actual weight updates = batches / GRAD_ACCUM
#   - With 3000 batches and GRAD_ACCUM=4: 750 weight updates
# =============================================================================

WARMUP_TARGET_BATCHES = 3000  # Target number of batches (may be less if dataset is smaller)

# Pre-calculate expected updates (will be confirmed when we know dataset size)
WARMUP_EXPECTED_UPDATES = WARMUP_TARGET_BATCHES // GRAD_ACCUM

print(f"Warmup configuration:")
print(f"  Target batches:           {WARMUP_TARGET_BATCHES}")
print(f"  Gradient accumulation:    {GRAD_ACCUM}")
print(f"  Expected weight updates:  {WARMUP_EXPECTED_UPDATES}")
print(f"  (Actual numbers will be confirmed after data loading)")

HIST_PATH = PROJECT_ROOT / "models" / "bert-gpt2" / "history.csv"
BEST_DIR = PROJECT_ROOT / "models" / "bert-gpt2" / "best"
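
The batch and update arithmetic described in the warmup comments can be checked with a few lines of plain Python. This is only a sketch: `warmup_plan` is a hypothetical helper that is not part of `src/`, and the 14,732-batch epoch size is the figure quoted in the comments for `batch_size=1`.

```python
def warmup_plan(target_batches, dataset_batches, grad_accum):
    """Summarize how many batches and optimizer steps a warmup run will use."""
    # Warmup cannot consume more batches than one pass over the loader provides
    batches = min(target_batches, dataset_batches)
    # Full accumulation windows, plus any leftover batches at the end
    full_updates, leftover = divmod(batches, grad_accum)
    return {
        "batches": batches,
        "full_updates": full_updates,
        "leftover_batches": leftover,
        "total_updates": full_updates + (1 if leftover else 0),
        "epoch_fraction": batches / dataset_batches,
    }


plan = warmup_plan(target_batches=3000, dataset_batches=14732, grad_accum=4)
print(plan)
```

With the settings above this gives 750 full updates and no leftover batches, covering roughly a fifth of one epoch.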


## 4. Load SAMSum Data
Data is pulled from `src/data/load_data.py`.  
Local parquet cache is used automatically if available.

In [None]:
train_df, val_df, test_df = load_samsum()
len(train_df), len(val_df), len(test_df)

## 5. Tokenizers & Datasets

GPT-2 has **no pad token**, so we set pad = eos.  

We then build the shared `SummaryDataset`.
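
One consequence of aliasing pad to eos is that padding can no longer be told apart from a genuine end-of-sequence token by id alone. A common approach, sketched here, is to mask label positions using the attention mask and replace them with `-100`, the index that PyTorch's cross-entropy loss ignores by default. Whether `SummaryDataset` does exactly this is an assumption; `mask_padding_labels` is a hypothetical helper used only for illustration.

```python
GPT2_EOS_ID = 50256  # id of <|endoftext|> in the GPT-2 vocabulary


def mask_padding_labels(token_ids, attention_mask, ignore_index=-100):
    """Replace padded positions with ignore_index, keeping real tokens (including EOS)."""
    return [
        tok if keep == 1 else ignore_index
        for tok, keep in zip(token_ids, attention_mask)
    ]


# Two content tokens, one real EOS, then two padding slots that reuse the EOS id
token_ids = [464, 3290, GPT2_EOS_ID, GPT2_EOS_ID, GPT2_EOS_ID]
attention_mask = [1, 1, 1, 0, 0]

labels = mask_padding_labels(token_ids, attention_mask)
print(labels)
```

Masking by attention mask rather than by token id keeps the first, genuine EOS in the loss, so the decoder can still learn when to stop.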

In [None]:
from transformers import BertTokenizer, GPT2Tokenizer

bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_tokenizer.model_max_length = 512  # native BERT limit

gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# =============================================================================
# GPT-2 ships without a pad_token, so we set one explicitly.
# GPT2Tokenizer already defines bos_token as <|endoftext|>; the bos check
# below is only a safeguard in case a tokenizer variant leaves it unset.
# =============================================================================

# Set pad token to eos token (common practice for GPT-2)
if gpt_tokenizer.pad_token is None:
    gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
    # Note: We use assignment instead of add_special_tokens to avoid resizing
    # the embedding matrix unnecessarily when we're just aliasing an existing token

# GPT-2 uses <|endoftext|> as a general delimiter, so bos and eos share one id.
# This matters for decoder_start_token_id in encoder-decoder models.
if gpt_tokenizer.bos_token is None:
    gpt_tokenizer.bos_token = gpt_tokenizer.eos_token

# Verify the tokens are set correctly
print(f"GPT-2 Tokenizer Configuration:")
print(f"  pad_token: '{gpt_tokenizer.pad_token}' (id: {gpt_tokenizer.pad_token_id})")
print(f"  bos_token: '{gpt_tokenizer.bos_token}' (id: {gpt_tokenizer.bos_token_id})")
print(f"  eos_token: '{gpt_tokenizer.eos_token}' (id: {gpt_tokenizer.eos_token_id})")

# Use RIGHT padding for training (labels need to be left-aligned)
gpt_tokenizer.padding_side = "right"

gpt_tokenizer.model_max_length = 1024



# PyTorch datasets
train_dataset = SummaryDataset(train_df, bert_tokenizer, gpt_tokenizer,
                               MAX_SOURCE_LEN, MAX_TARGET_LEN)

val_dataset = SummaryDataset(val_df, bert_tokenizer, gpt_tokenizer,
                             MAX_SOURCE_LEN, MAX_TARGET_LEN)

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=0)

val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=0)

## 6. Build the BERT→GPT-2 Model
This calls the modular builder in `src/models/build_bert_gpt2.py`.
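
To see where the newly added cross-attention weights live, the sketch below builds a tiny BERT-to-GPT-2 `EncoderDecoderModel` from configs, so nothing is downloaded. It assumes the builder in `src/` wraps the same Hugging Face class (the checkpoint-loading cell later in this notebook uses `EncoderDecoderModel.from_pretrained`). The tiny layer sizes are arbitrary and chosen only to keep the example fast.

```python
from transformers import (
    BertConfig,
    EncoderDecoderConfig,
    EncoderDecoderModel,
    GPT2Config,
)

# Tiny, randomly initialized stand-ins for bert-base-uncased and gpt2
encoder_config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
decoder_config = GPT2Config(vocab_size=100, n_embd=32, n_layer=1, n_head=2)

# This helper marks the decoder as a decoder and enables cross-attention
config = EncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config, decoder_config
)
tiny_model = EncoderDecoderModel(config=config)

# Cross-attention modules only exist on the decoder side
cross_attention_names = [
    name
    for name, _ in tiny_model.named_parameters()
    if "crossattention" in name or "ln_cross_attn" in name
]
print(len(cross_attention_names), "cross-attention parameter tensors")
print(cross_attention_names[:3])
```

In the full-size model these are the tensors with no pretrained counterpart in the GPT-2 checkpoint, which is why they need a warm-up phase.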

In [None]:
model = build_bert_gpt2_model(
    gpt_pad_token_id=gpt_tokenizer.pad_token_id,
    gpt_bos_token_id=gpt_tokenizer.bos_token_id,
    decoder_tokenizer=gpt_tokenizer,
    max_length=MAX_TARGET_LEN,
).to(device)

# Disable cache for gradient checkpointing
model.config.use_cache = False

# Turn on gradient checkpointing
model.encoder.gradient_checkpointing_enable()
model.decoder.gradient_checkpointing_enable()

# or, if your HF version supports it:
# model.gradient_checkpointing_enable()

model
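
Gradient checkpointing trades compute for memory: activations inside a checkpointed block are discarded after the forward pass and recomputed during backward. The sketch below uses `torch.utils.checkpoint` directly on a toy block (not the notebook's model) to show that the resulting gradients are unchanged.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Toy block standing in for one transformer layer
block = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
x = torch.randn(2, 8, requires_grad=True)

# Plain forward/backward: intermediate activations are kept in memory
block(x).sum().backward()
grad_plain = x.grad.clone()

# Reset gradients before the second pass
x.grad = None
block.zero_grad()

# Checkpointed forward/backward: activations are recomputed in backward
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_checkpointed = x.grad.clone()

print(torch.allclose(grad_plain, grad_checkpointed))
```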

## 7. Optimizer (Warm-up → Fine-tune)

The training loop is shared, but **Experiment-1’s warmup logic is unique**.  
We handle it here in the notebook and pass the correct optimizer into `train_model()`.

In [None]:
import torch.optim as optim

# Phase 1 — train decoder only (encoder frozen)
for name, p in model.named_parameters():
    if name.startswith("encoder."):
        p.requires_grad = False
    else:
        p.requires_grad = True

decoder_params = [p for p in model.parameters() if p.requires_grad]

n_trainable = sum(p.numel() for p in decoder_params)
print("Trainable parameter tensors in warmup:", len(decoder_params))
print("Trainable parameters in warmup:", n_trainable)

optimizer = optim.AdamW(decoder_params, lr=BRIDGE_LR)
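
The prefix-based freezing used in this cell can be tried on a toy module. The sketch below is an illustration only: `ToySeq2Seq`, `set_encoder_trainable`, and `count_trainable` are hypothetical names, and the two linear layers simply stand in for the encoder and decoder.

```python
from torch import nn


class ToySeq2Seq(nn.Module):
    """Minimal stand-in with `encoder.` and `decoder.` parameter prefixes."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 3)  # 4*3 weights + 3 biases = 15 values
        self.decoder = nn.Linear(3, 2)  # 3*2 weights + 2 biases = 8 values


def set_encoder_trainable(model, trainable):
    """Toggle requires_grad for every parameter whose name starts with 'encoder.'."""
    for name, param in model.named_parameters():
        if name.startswith("encoder."):
            param.requires_grad = trainable


def count_trainable(model):
    """Return (number of trainable tensors, number of trainable scalar values)."""
    params = [p for p in model.parameters() if p.requires_grad]
    return len(params), sum(p.numel() for p in params)


toy = ToySeq2Seq()

set_encoder_trainable(toy, False)   # warm-up: encoder frozen
warmup_counts = count_trainable(toy)

set_encoder_trainable(toy, True)    # fine-tune: everything trainable
finetune_counts = count_trainable(toy)

print("warm-up:", warmup_counts, "fine-tune:", finetune_counts)
```

Counting tensors and counting scalar values give very different numbers, so it helps to report both when checking what is actually trainable.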

## 8. Warm-Up Phase (Train Decoder Only, Encoder Frozen)

We warm up for `WARMUP_TARGET_BATCHES` batches (or fewer if the dataset is smaller), then unfreeze the whole model.

In [None]:
if RUN_TRAINING:

    # =================================================================
    # WARMUP PHASE: Train decoder only (encoder frozen)
    # =================================================================
    
    total_available_batches = len(train_loader)
    
    # Determine actual number of batches to process
    # We can't process more batches than exist in the dataset
    actual_warmup_batches = min(WARMUP_TARGET_BATCHES, total_available_batches)
    actual_warmup_updates = actual_warmup_batches // GRAD_ACCUM
    
    print("=" * 60)
    print("WARMUP PHASE")
    print("=" * 60)
    print(f"Dataset size:        {total_available_batches} batches")
    print(f"Target batches:      {WARMUP_TARGET_BATCHES}")
    print(f"Actual batches:      {actual_warmup_batches}", end="")
    
    if actual_warmup_batches < WARMUP_TARGET_BATCHES:
        print(f"  ⚠️  (limited by dataset size)")
    else:
        print()
    
    print(f"Gradient accum:      {GRAD_ACCUM}")
    print(f"Weight updates:      ~{actual_warmup_updates}")
    print("-" * 60)

    batch_count = 0          # How many batches we've processed
    gradient_updates = 0     # How many times we've updated weights
    accumulated_loss = 0.0
    
    model.train()
    loss_trace = []
    
    optimizer.zero_grad()

    for batch in train_loader:
        batch_count += 1

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        decoder_attention_mask = batch["decoder_attention_mask"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        
        # Scale loss for gradient accumulation
        scaled_loss = loss / GRAD_ACCUM
        scaled_loss.backward()
        
        accumulated_loss += loss.item()

        # Update weights every GRAD_ACCUM batches
        if batch_count % GRAD_ACCUM == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            optimizer.zero_grad()
            
            gradient_updates += 1
            avg_loss = accumulated_loss / GRAD_ACCUM
            loss_trace.append(avg_loss)
            
            # Progress update every 100 gradient updates
            if gradient_updates % 100 == 0:
                pct_complete = batch_count / actual_warmup_batches * 100
                print(f"  Update {gradient_updates:4d} | "
                      f"Batch {batch_count:5d}/{actual_warmup_batches} ({pct_complete:5.1f}%) | "
                      f"Loss: {avg_loss:.4f}")
            
            accumulated_loss = 0.0

        # Stop when we've processed enough batches
        if batch_count >= actual_warmup_batches:
            break
    
    # Handle any remaining gradients from incomplete accumulation
    remaining = batch_count % GRAD_ACCUM
    if remaining != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        optimizer.zero_grad()
        
        gradient_updates += 1
        avg_loss = accumulated_loss / remaining
        loss_trace.append(avg_loss)
        print(f"  Final partial update ({remaining} batches), Loss: {avg_loss:.4f}")

    print("-" * 60)
    print(f"Warmup complete!")
    print(f"  Batches processed:    {batch_count}")
    print(f"  Gradient updates:     {gradient_updates}")
    print(f"  Final loss:           {loss_trace[-1]:.4f}")
    print("=" * 60)
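
The loss scaling in the loop above (`loss / GRAD_ACCUM` before `backward()`) is what makes several small batches behave like one larger batch. The sketch below checks this on a toy least-squares problem, not on the notebook's model: the accumulated gradient from four scaled single-example losses matches the gradient of the mean loss over all four examples.

```python
import torch

torch.manual_seed(0)

grad_accum = 4
w = torch.zeros(3, requires_grad=True)
x = torch.randn(grad_accum, 3)
y = torch.randn(grad_accum)

# Reference: one backward pass on the mean loss over the full batch
full_loss = ((x @ w - y) ** 2).mean()
full_loss.backward()
full_grad = w.grad.clone()

# Accumulation: one example at a time, each loss scaled by 1 / grad_accum
w.grad = None
for i in range(grad_accum):
    micro_loss = (x[i] @ w - y[i]) ** 2
    (micro_loss / grad_accum).backward()  # gradients add up in w.grad
accumulated_grad = w.grad.clone()

print(torch.allclose(full_grad, accumulated_grad))
```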


## 9. Fine-Tune Phase (Unfreeze All Layers)

In [None]:
if RUN_TRAINING:

    # Unfreeze all parameters
    for p in model.parameters():
        p.requires_grad = True

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
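
A new optimizer is needed here because the warm-up optimizer only ever received the decoder's parameters; flipping `requires_grad` afterwards does not add the encoder to it. The sketch below illustrates this on two toy layers; `n_tensors` is a hypothetical helper.

```python
import torch
from torch import nn

encoder = nn.Linear(2, 2)
decoder = nn.Linear(2, 1)

# Warm-up: freeze the encoder and hand only decoder parameters to the optimizer
for p in encoder.parameters():
    p.requires_grad = False
warmup_opt = torch.optim.AdamW(decoder.parameters(), lr=1e-3)

# Unfreezing changes the flags but not the optimizer's parameter list
for p in encoder.parameters():
    p.requires_grad = True


def n_tensors(optimizer):
    """Count the parameter tensors an optimizer will update."""
    return sum(len(group["params"]) for group in optimizer.param_groups)


# Fine-tune: build a fresh optimizer over every parameter
finetune_opt = torch.optim.AdamW(
    list(encoder.parameters()) + list(decoder.parameters()), lr=1e-5
)

print("warm-up optimizer tensors:  ", n_tensors(warmup_opt))
print("fine-tune optimizer tensors:", n_tensors(finetune_opt))
```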

## 10. Full Training Loop  
This uses the shared `train_model()` from `src/train/trainer_seq2seq.py`  
which handles:
- training epochs  
- validation  
- ROUGE metrics  
- returns a summary DataFrame  
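
For intuition about the ROUGE metrics, the sketch below computes a simplified ROUGE-1 F1 score from unigram overlap. It is not the implementation used by `train_model()`: real ROUGE packages add their own tokenization and optional stemming, and `rouge1_f1` is a hypothetical helper.

```python
from collections import Counter


def rouge1_f1(candidate, reference):
    """Simplified ROUGE-1 F1: unigram overlap on lowercased whitespace tokens."""
    cand_counts = Counter(candidate.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both
    overlap = sum((cand_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1(
    candidate="Amanda baked cookies",
    reference="Amanda baked cookies for Jerry",
)
print(round(score, 3))
```

Here precision is 3/3 and recall is 3/5, so the F1 score is 0.75.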

In [None]:
if RUN_TRAINING:
        
    history_df = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        tokenizer=gpt_tokenizer,
        device=device,
        epochs=EPOCHS,
        max_target_len=MAX_TARGET_LEN,
        checkpoint_dir=str(BEST_DIR),
        patience=2,
        grad_accum_steps=GRAD_ACCUM, 
    )
    print("Best checkpoint saved to:", BEST_DIR)

    # --- SAVE HISTORY CSV ---
    
    HIST_PATH.parent.mkdir(parents=True, exist_ok=True)  # make sure the output folder exists
    history_df.to_csv(HIST_PATH, index=False)
    print("Saved training history to:", HIST_PATH)
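
The `patience=2` argument suggests early stopping. How `train_model()` implements it is not shown in this notebook, so the sketch below is an assumption: it stops once a validation score (higher is better) has failed to improve for `patience` consecutive epochs. `early_stopping_epoch` and the example scores are made up for illustration.

```python
def early_stopping_epoch(val_scores, patience):
    """Return (epoch at which training stops, epoch with the best score)."""
    best_score = float("-inf")
    best_epoch = None
    epochs_without_improvement = 0

    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            best_score = score
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch

    # Ran through every epoch without triggering the stop
    return len(val_scores), best_epoch


# Illustrative validation scores, one per epoch (not real results)
stop_epoch, best_epoch = early_stopping_epoch(
    val_scores=[0.30, 0.34, 0.33, 0.335, 0.36], patience=2
)
print(f"stopped after epoch {stop_epoch}, best epoch was {best_epoch}")
```

In this example the score peaks at epoch 2, fails to improve at epochs 3 and 4, and training stops before the later improvement at epoch 5 is ever seen.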

In [None]:
if not RUN_TRAINING:
    
    print("Skipping training and loading best saved model...")
    from transformers import EncoderDecoderModel, GenerationConfig
    import json

    # Load Model
    model = EncoderDecoderModel.from_pretrained(BEST_DIR).to(device)
    
    # Load Generation Config (if it was saved)
    try:
        saved_config = GenerationConfig.from_pretrained(BEST_DIR)
        print("Found saved generation config.")
    except Exception:
        saved_config = None
        print("No saved generation config found.")
    
    # ALWAYS set generation config to match current notebook settings
    gen_cfg = model.generation_config
    gen_cfg.pad_token_id = gpt_tokenizer.pad_token_id
    gen_cfg.bos_token_id = gpt_tokenizer.bos_token_id
    gen_cfg.max_length = MAX_TARGET_LEN
    gen_cfg.min_length = 5
    gen_cfg.no_repeat_ngram_size = 3
    gen_cfg.early_stopping = True
    gen_cfg.length_penalty = 2.0
    gen_cfg.num_beams = 4
    
    print(f"Generation config set: max_length={MAX_TARGET_LEN}")

    # Load training metadata (if available)
    metadata_path = BEST_DIR / "training_metadata.json"
    if metadata_path.exists():
        with open(metadata_path, 'r') as f:
            training_metadata = json.load(f)
        print(f"Loaded training metadata:")
        print(f"  Best epoch: {training_metadata.get('best_epoch')}")
        print(f"  Weights from epoch: {training_metadata.get('weights_epoch')}")
        print(f"  Note: {training_metadata.get('weights_note')}")
    else:
        print("No training metadata found (older checkpoint format).")

    # Load History
    history_df = pd.read_csv(HIST_PATH)
    print("Loaded saved training history from:", HIST_PATH)
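
Among the generation settings above, `no_repeat_ngram_size = 3` prevents any trigram from appearing twice in a summary. The idea can be sketched in plain Python: before each step, find every token that would complete an n-gram already present in the output and ban it. This is a simplified illustration, not the library's code, and `banned_next_tokens` is a hypothetical helper.

```python
def banned_next_tokens(generated, ngram_size):
    """Return tokens that would complete an n-gram already present in `generated`."""
    if ngram_size < 1 or len(generated) < ngram_size:
        return set()  # no complete n-gram exists yet, so nothing can repeat

    # The last (n - 1) tokens form the prefix of the next possible n-gram
    prefix = tuple(generated[len(generated) - (ngram_size - 1):])

    banned = set()
    for start in range(len(generated) - ngram_size + 1):
        # If an earlier n-gram begins with the same prefix, ban its final token
        if tuple(generated[start:start + ngram_size - 1]) == prefix:
            banned.add(generated[start + ngram_size - 1])
    return banned


tokens = ["the", "cat", "sat", "on", "the", "cat"]
banned = banned_next_tokens(tokens, ngram_size=3)
print(banned)
```

Because "the cat sat" already occurred, "sat" is banned as the next token after the second "the cat".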


## 11. Loss Curves  
(Optional small plot)

In [None]:
import matplotlib.pyplot as plt

plt.plot(history_df["epoch"], history_df["train_loss"], label="train")
plt.plot(history_df["epoch"], history_df["val_loss"], label="val")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Experiment 1 — Loss Curve")
plt.show()

## 12. Qualitative Examples  

Shows 5 model summaries vs human summaries.

In [None]:
qualitative_samples(
    df=val_df,
    model=model,
    encoder_tokenizer=bert_tokenizer,
    decoder_tokenizer=gpt_tokenizer,
    device=device,
    max_source_len=MAX_SOURCE_LEN,
    max_target_len=MAX_TARGET_LEN,
    source_prefix="",  # No prefix for BERT-GPT2
    seed=42,
)

## 13. Save Model + Tokenizers

The save location matches the project README:

In [None]:
SAVE_DIR = PROJECT_ROOT / "models" / "bert-gpt2"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

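model.save_pretrained(SAVE_DIR)
# Save each tokenizer in its own subfolder: both write a tokenizer_config.json,
# so sharing one folder would let the GPT-2 files overwrite BERT's.
bert_tokenizer.save_pretrained(SAVE_DIR / "encoder_tokenizer")
gpt_tokenizer.save_pretrained(SAVE_DIR / "decoder_tokenizer")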

print("Model saved to:", SAVE_DIR)

# Key Takeaways for Experiment-1

This section will be completed after training. Expected themes:

- The warm-up phase stabilizes early training  
- ROUGE improves slowly and plateaus early  
- The model tends to produce chatty, narrative summaries  
- The architecture is likely sub-optimal compared to pretrained seq2seq models such as BART and T5  

This notebook demonstrates the feasibility and limitations of a hand-assembled encoder–decoder system versus pretrained seq2seq models.