# End-to-End Demo: Fine-Tuning BERT for Extractive Summarization

This notebook provides a complete walkthrough of fine-tuning a pre-trained BERT-family model (DistilBERT in this demo) for extractive summarization. In a real workflow you would first run `preprocessing.ipynb` to generate a dataset of scripts and their corresponding sentence labels; this demo simulates that output (step 2), so it also runs on its own.

**Workflow:**
1.  **Setup**: Install dependencies and define the classes and functions copied from the training script (`train_bert_summarizer.py`).
2.  **Data Simulation**: We'll create a small, representative dataset on the fly to make this notebook self-contained. This simulates the output of the `preprocessing.ipynb`.
3.  **Configuration**: Set up training parameters like model name, learning rate, and epochs.
4.  **Training**: Run the complete training and validation loop.
5.  **Inference**: Load the best-performing model and use it to summarize a new, unseen movie script snippet.

In [None]:
!pip install transformers torch tqdm -q

In [None]:
import os
import json
import random
import re
from typing import List

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AdamW
from tqdm.notebook import tqdm

# --- Code from train_bert_summarizer.py ---

def sentence_split(text: str) -> List[str]:
    """
    Split text into sentences with a simple punctuation/newline heuristic.

    Carriage returns are treated as newlines. A split happens at any run of
    whitespace that directly follows '.', '!', '?' or a newline. Pieces are
    stripped, and empty pieces are discarded.
    """
    normalized = text.replace('\r', '\n')
    pieces = re.split(r'(?<=[.!?\n])\s+', normalized)
    stripped = (piece.strip() for piece in pieces)
    return [piece for piece in stripped if piece]

class BertSummarizer(nn.Module):
    """
    Sentence scorer for extractive summarization.

    A pre-trained encoder loaded with ``AutoModel.from_pretrained`` produces
    contextual token embeddings. A single linear layer then maps the embedding
    at each requested position to one logit per sentence.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.scorer = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask, cls_indices):
        """
        Return one logit per entry of ``cls_indices``.

        ``cls_indices`` is a 2-D index tensor (batch, sentences). For each row
        ``b`` and column ``j``, the hidden state at token position
        ``cls_indices[b, j]`` of sequence ``b`` is scored. The result has the
        same shape as ``cls_indices``.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        n_rows, _, _ = hidden_states.shape
        n_sents = cls_indices.size(1)
        # Row index for every gathered position, so advanced indexing picks
        # hidden_states[b, cls_indices[b, j]] for each (b, j).
        row_ids = torch.arange(n_rows, device=hidden_states.device).view(-1, 1).expand(n_rows, n_sents)
        picked = hidden_states[row_ids, cls_indices]
        return self.scorer(picked).squeeze(-1)

class SummarizationDataset(Dataset):
    """
    PyTorch Dataset for loading scripts and their summarization labels.

    Each ``<name>.txt`` file in ``data_dir`` is one sample. Its labels are read
    from the sibling ``<name>.labels.json`` file, whose
    ``important_sentence_indices`` list gives the positive sentences (indices
    into the output of ``sentence_split``). A missing label file means every
    sentence is labelled 0.

    NOTE(review): ``cls_indices`` point at the first word-piece of each
    sentence; no per-sentence [CLS] token is inserted. ``infer_summary`` packs
    its inputs the same way, so the two must be changed together.
    """

    def __init__(self, data_dir: str, tokenizer, max_len: int = 512):
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Sorted so the sample order (and hence a seeded random_split) does not
        # depend on the filesystem's arbitrary listing order.
        self.files = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.txt')
        )

    @staticmethod
    def _label_path_for(filepath: str) -> str:
        """Return the ``.labels.json`` path paired with a ``.txt`` script path."""
        # splitext only strips the final extension; str.replace('.txt', ...)
        # would also rewrite '.txt' appearing in directory names.
        return os.path.splitext(filepath)[0] + '.labels.json'

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """
        Tokenize one script and return its model inputs and sentence labels.

        Returns a dict of 1-D tensors:
            ``input_ids`` / ``attention_mask``: the packed sequence
                ``[CLS] sent_0 sent_1 ... [SEP]``, at most ``max_len`` tokens.
            ``cls_indices``: position of the first word-piece of each kept
                sentence.
            ``labels``: 1.0 for important sentences, 0.0 otherwise, aligned
                with ``cls_indices``.
        Sentences that no longer fit before ``max_len`` are dropped.
        """
        filepath = self.files[idx]
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            text = f.read()

        sentences = sentence_split(text)
        try:
            with open(self._label_path_for(filepath), 'r', encoding='utf-8') as f:
                label_data = json.load(f)
        except FileNotFoundError:
            # Unlabelled script: every sentence counts as unimportant.
            label_data = {}
        important_indices = set(label_data.get('important_sentence_indices', []))

        input_tokens = [self.tokenizer.cls_token]
        cls_indices = []
        labels = []

        for i, sentence in enumerate(sentences):
            cls_indices.append(len(input_tokens))
            labels.append(1.0 if i in important_indices else 0.0)
            input_tokens.extend(self.tokenizer.tokenize(sentence))

            # Reserve one slot for the trailing [SEP].
            if len(input_tokens) >= self.max_len - 1:
                input_tokens = input_tokens[:self.max_len - 1]
                # Drop any sentence whose start position was cut off.
                while cls_indices and cls_indices[-1] >= len(input_tokens):
                    cls_indices.pop()
                    labels.pop()
                break

        input_tokens.append(self.tokenizer.sep_token)
        input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
        attention_mask = [1] * len(input_ids)

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'cls_indices': torch.tensor(cls_indices, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.float)
        }

def collate_fn(batch):
    """
    Right-pad a list of dataset samples into batched 2-D tensors.

    ``input_ids``, ``attention_mask`` and ``cls_indices`` are padded with 0.
    ``labels`` is padded with -1.0, the sentinel the training loss uses to
    ignore padded sentence slots.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    def stack(key, fill):
        return pad([sample[key] for sample in batch], batch_first=True, padding_value=fill)

    return {
        'input_ids': stack('input_ids', 0),
        'attention_mask': stack('attention_mask', 0),
        'cls_indices': stack('cls_indices', 0),
        'labels': stack('labels', -1.0),
    }

def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        cls_indices = batch['cls_indices'].to(device)
        labels = batch['labels'].to(device)
        scores = model(input_ids, attention_mask, cls_indices)
        label_mask = (labels != -1.0).float()
        max_sents = label_mask.sum(dim=1).max().int()
        scores = scores[:, :max_sents]
        labels = labels[:, :max_sents]
        label_mask = label_mask[:, :max_sents]
        loss = nn.BCEWithLogitsLoss(weight=label_mask)(scores, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def validate_epoch(model, dataloader, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            cls_indices = batch['cls_indices'].to(device)
            labels = batch['labels'].to(device)
            scores = model(input_ids, attention_mask, cls_indices)
            label_mask = (labels != -1.0).float()
            max_sents = label_mask.sum(dim=1).max().int()
            scores = scores[:, :max_sents]
            labels = labels[:, :max_sents]
            label_mask = label_mask[:, :max_sents]
            loss = nn.BCEWithLogitsLoss(weight=label_mask)(scores, labels)
            total_loss += loss.item()
    return total_loss / len(dataloader)

### 2. Data Simulation

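Your `preprocessing.ipynb` notebook generates a directory of labeled data: each script is a `<name>.txt` file paired with a `<name>.labels.json` file whose `important_sentence_indices` list marks the sentences that belong in the summary. For this demo, we'll create a temporary directory (`./temp_data`) and populate it with a few examples that mimic that structure. This lets the notebook run from start to finish without any preprocessed data.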

In [None]:
DATA_DIR = './temp_data'
os.makedirs(DATA_DIR, exist_ok=True)

# Sample 1: A short dialogue
script1_text = "Amanda: Did you get the files? John: Yes, I have them right here. Amanda: Great. Let's get to work then. We don't have much time."
script1_labels = {"important_sentence_indices": [1, 3]}

# Sample 2: A longer scene description
script2_text = "The sun sets over the city. A lone figure stands on a rooftop, looking down at the bustling streets. Sirens wail in the distance. The figure pulls up their hood, their face obscured by shadow. A decision has to be made tonight."
script2_labels = {"important_sentence_indices": [1, 4]}

# (file stem, script text, label dict) for every sample to write.
samples = [
    ('script1', script1_text, script1_labels),
    ('script2', script2_text, script2_labels),
]

# 18 more templated scripts (script3 ... script20) for a slightly larger dataset
for i in range(3, 21):
    samples.append((
        f'script{i}',
        f"This is sentence one. This is the crucial second sentence. And this is the final sentence of script {i}.",
        {"important_sentence_indices": [1]},
    ))

# Each sample becomes <stem>.txt plus <stem>.labels.json, the layout SummarizationDataset reads.
for stem, text, labels in samples:
    with open(os.path.join(DATA_DIR, f'{stem}.txt'), 'w', encoding='utf-8') as f:
        f.write(text)
    with open(os.path.join(DATA_DIR, f'{stem}.labels.json'), 'w', encoding='utf-8') as f:
        json.dump(labels, f)

# Count what this run wrote; halving os.listdir(DATA_DIR) would also count
# stale files left over from earlier runs.
print(f"Created {len(samples)} dummy script samples in '{DATA_DIR}'")

### 3. Configuration & Data Loading

Here, we'll define our training configuration and prepare the datasets and dataloaders for training.

In [None]:
# --- CONFIGURATION ---
MODEL_NAME = 'distilbert-base-uncased'  # Using a smaller model for a quick demo
EPOCHS = 3
BATCH_SIZE = 4
LR = 2e-5
VAL_SPLIT = 0.2
SAVE_PATH = './bert_summarizer_model'
SEED = 42  # fixes the train/val split, shuffling and weight init so reruns are comparable

random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- DATA PREPARATION ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

full_dataset = SummarizationDataset(DATA_DIR, tokenizer)
val_size = int(len(full_dataset) * VAL_SPLIT)
if len(full_dataset) >= 2:
    # Keep at least one validation and one training sample; an empty
    # validation loader would make validate_epoch divide by zero.
    val_size = min(max(val_size, 1), len(full_dataset) - 1)
train_size = len(full_dataset) - val_size
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples.")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

### 4. Model Training

Now we'll initialize the model and optimizer and run the main training loop. After each epoch, the model is saved to `SAVE_PATH` whenever it reaches the lowest validation loss seen so far.

In [None]:
model = BertSummarizer(MODEL_NAME).to(device)
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps and
# weight_decay are set to the old transformers defaults (1e-6, 0.0), because
# torch's defaults (1e-8, 0.01) would change the optimisation.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, eps=1e-6, weight_decay=0.0)

best_val_loss = float('inf')

for epoch in range(EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{EPOCHS} ---")
    train_loss = train_epoch(model, train_loader, optimizer, device)
    val_loss = validate_epoch(model, val_loader, device)

    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        print("New best validation loss. Saving model...")
        os.makedirs(SAVE_PATH, exist_ok=True)
        # Encoder + tokenizer use the Hugging Face format; the linear head is saved separately.
        model.bert.save_pretrained(SAVE_PATH)
        torch.save(model.scorer.state_dict(), os.path.join(SAVE_PATH, 'scorer.pt'))
        tokenizer.save_pretrained(SAVE_PATH)

print("\nTraining complete!")
print(f"Best validation loss: {best_val_loss:.4f}")
if best_val_loss < float('inf'):
    print(f"Model saved to {SAVE_PATH}")
else:
    # e.g. EPOCHS == 0 or every validation loss was NaN
    print(f"No checkpoint was written to {SAVE_PATH} (validation loss never improved).")

### 5. Inference

With the model trained, let's test its summarization capabilities. We'll define an `infer_summary` function, load our best model from disk, and feed it a new script snippet it has never seen before.

In [None]:
def infer_summary(script_text: str, model: "BertSummarizer", tokenizer, device, top_k: int = 2, max_len: int = 512):
    """
    Extract a ``top_k``-sentence summary from ``script_text``.

    The text is split with ``sentence_split`` and packed into one sequence the
    same way ``SummarizationDataset`` does: a leading [CLS], each sentence's
    word-pieces, and a final [SEP]. The sequence is truncated to at most
    ``max_len`` tokens. Sentences that start past the truncation point are not
    scored. Without truncation, long scripts would exceed the encoder's
    position-embedding limit.

    Args:
        script_text: Raw script text.
        model: Trained sentence scorer.
        tokenizer: Tokenizer matching ``model``.
        device: Device to run the model on.
        top_k: Maximum number of sentences to keep.
        max_len: Maximum packed sequence length, including [CLS] and [SEP].

    Returns:
        The selected sentences in their original order, joined by single
        spaces. Returns an empty string when no sentence can be scored.
    """
    model.eval()
    sentences = sentence_split(script_text)

    # Room for content tokens once the trailing [SEP] is appended.
    budget = max_len - 1
    input_tokens = [tokenizer.cls_token]
    cls_indices = []
    for sentence in sentences:
        if len(input_tokens) >= budget:
            break  # no room left: this and all later sentences are dropped
        cls_indices.append(len(input_tokens))
        input_tokens.extend(tokenizer.tokenize(sentence))
    input_tokens = input_tokens[:budget]
    input_tokens.append(tokenizer.sep_token)

    if not cls_indices:
        return ""

    input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(input_tokens), dtype=torch.long).unsqueeze(0).to(device)
    attention_mask = torch.ones_like(input_ids)
    cls_tensor = torch.tensor(cls_indices, dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        scores = model(input_ids, attention_mask, cls_tensor).squeeze(0)
        sigmoid_scores = torch.sigmoid(scores)

    # Highest-scoring sentences, then restored to reading order.
    top_indices = torch.argsort(sigmoid_scores, descending=True)[:top_k]
    sorted_indices = sorted(top_indices.tolist())

    # Kept sentences are a prefix of `sentences`, so indices map directly.
    return " ".join(sentences[i] for i in sorted_indices)

# --- Load the fine-tuned model ---
print(f"Loading model from {SAVE_PATH}")
inference_tokenizer = AutoTokenizer.from_pretrained(SAVE_PATH)
inference_model = BertSummarizer(SAVE_PATH).to(device) # AutoModel loads from the directory
inference_model.scorer.load_state_dict(torch.load(os.path.join(SAVE_PATH, 'scorer.pt')))

# --- Run Inference ---
# --- Run Inference ---
# An unseen screenplay-style snippet (scene heading, action lines, character cues).
test_script = """
INT. WAREHOUSE - NIGHT

DETECTIVE MILLER shines his flashlight across dusty crates. This place hasn't been touched in years. His partner, DETECTIVE SANTIAGO, kicks at a loose floorboard. It's a dead end. 

SANTIAGO
Nothing. We've been had.

MILLER
Maybe not. Look at this.

Miller points his light to a small, almost invisible symbol carved into a crate. The key symbol. Santiago's eyes widen. This changes everything.
"""

# Keep the three highest-scoring sentences, printed in their original order.
generated_summary = infer_summary(test_script, inference_model, inference_tokenizer, device, top_k=3)

print("--- ORIGINAL SCRIPT SNIPPET ---")
print(test_script)
print("\n--- GENERATED SUMMARY ---")
print(generated_summary)