In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Architecture

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()  # Python 2.0 & above applicable
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim =  embed_size // heads

        assert(self.head_dim * heads == embed_size)

        # Where are we using these ones in the `forward function`?
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)

        self.fc_out = nn.Linear(self.heads * self.head_dim, embed_size)

    def forward(self, query, keys, values, mask: bool):
        N = query.shape[0]  # Training batch size or number of batches?
        query_len, key_len, value_len = query.shape[1], keys.shape[1], values.shape[1]

        queries_proj = self.queries(query)  # (N, Seq, embed_size)
        keys_proj = self.keys(keys)
        values_proj = self.values(values)

        # Split embedding into self.heads pieces
        queries = queries_proj.reshape(N, query_len, self.heads, self.head_dim)  # What is the difference between key_len and N in the first place?
        keys = keys_proj.reshape(N, key_len, self.heads, self.head_dim)
        values = values_proj.reshape(N, value_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd, nkhd->nhqk", [queries, keys])  # Why an array? & Why this arrangement?
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        attention = torch.softmax(energy/ self.embed_size ** (1/2), dim=3)

        out = torch.einsum("nhqk,nkhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )  # The shape of the whole must be looked at & how does it reshape this (are the vectors going to the correct place)??

        out = self.fc_out(out)
        return out


In [3]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        """
        The forward expansion term basically takes the vector from a lower dimensional space to a 
        higher dimensional space for better representation capacity. The value used in the paper is
        equal to 4.
        """
        super().__init__()  # Python 3.0 & above applicable
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)  # Why embedding size?
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask):
        attention = self.attention(query, key, value, mask) # How is this calling the forward method?

        x = self.dropout(self.norm1(attention + query))  # Adding Skip Connections
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

In [4]:
class Encoder(nn.Module):
    def __init__(  # Why isn't any input in the initialization? How is the input fed to the Encoder?
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length  # Maximum possible length of the sentence
    ):  # The above parameters are hyperparameters for the model that we are training
        super().__init__()
        self.embed_size = embed_size  # Do we really need another instance variable embed_size??
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)  # Is this good enough for big models?
        self.positional_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion) for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)  # Why do we need this again? We don't seem to be using this in `forward`

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)  # What does .expand() do here?
        out = self.word_embedding(x) + self.positional_embedding(positions)  # Adding or concatenation?

        for layer in self.layers:
            out = layer(out, out, out, mask)  # What is the use of having separate matrices defined when we pass the same thing here?
            # Also why this? Can't we directly use the TransformerBlock instead of self.layers?
        
        return out


In [5]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, target_mask):  # Read about src and target masks
        attention = self.attention(x, x, x, target_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(query, key, value, src_mask)
        return out

In [6]:
class Decoder(nn.Module):
    def __init__(
        self,
        target_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length
    ):
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(target_vocab_size, embed_size)  # If we get the value embeddings from the previous layer, do we need this? (for seq2seq) - is my question sensible in the first place?
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device) for _ in range(num_layers)
            ]
        )

        self.fc_out = nn.Linear(embed_size, target_vocab_size)
        self.dropout = nn.Dropout(dropout)


    def forward(self, x, enc_out, src_mask, target_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))  # Why dropout here?

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, target_mask)  # Why are we giving the same values here when we have separate matrices for keys and values?

        out = self.fc_out(x)
        return out

In [7]:
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        target_vocab_size,
        src_pad_idx,  # Necessary to compute mask we are going to use
        target_pad_idx,
        embed_size = 256,
        num_layers = 6,
        forward_expansion = 4,
        heads = 8,
        dropout = 0,
        device = "cuda",
        max_length = 100
    ):
        super().__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            target_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length  # What is the purpose of this variable??
        )

        self.src_pad_idx = src_pad_idx
        self.target_pad_idx = target_pad_idx
        self.device =  device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)  # Look into the .unsqueeze() function
        # (N, 1, 1, src_len)
        return src_mask.to(self.device)

    def make_target_mask(self, target):
        N, target_len = target.shape
        target_mask = torch.tril(torch.ones((target_len, target_len))).expand(  # Triangular lower for masking
            N, 1, target_len, target_len  # To expand it to the training examples I guess?
        )
        return target_mask.to(self.device)  # Don't forget to return AND move to device!

    def forward(self, src, target):
        src_mask = self.make_src_mask(src)
        target_mask = self.make_target_mask(target)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(target, enc_src, src_mask, target_mask)
        return out

### Example

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
    device
)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(
    device
)
out = model(x, trg[:, :-1])
print(out)
print(out.shape)

cuda
tensor([[[-0.5408,  0.3239, -0.7158,  0.2044, -1.6625,  0.4420,  0.7030,
          -0.2096, -0.7952, -0.2962],
         [-0.1051,  0.0705, -0.3269,  0.2955,  0.2590,  0.8724, -0.1583,
          -0.6905, -0.4783,  0.2919],
         [-0.2632,  1.1045,  0.4545, -0.3623, -0.8390,  1.5524,  0.2520,
          -0.7388,  0.2325,  0.8423],
         [-1.2933,  0.0230, -0.0591,  0.7313, -1.0477,  0.7772, -0.5674,
          -1.0164, -1.0040,  0.7085],
         [-0.0625,  0.0576, -0.4781,  0.3183, -0.6959, -0.4410, -0.7745,
          -0.6779,  0.2478, -0.8871],
         [ 0.4092,  0.2617, -0.3576, -0.0663, -0.4656,  0.3951, -0.7548,
          -0.4140,  0.2158,  0.1852],
         [-0.7933,  0.3552,  0.1669, -0.6749, -0.8392,  0.4564,  0.2412,
           0.3939,  0.2433,  0.0156]],

        [[-0.6584,  0.4709, -0.8047,  0.3188, -1.6987,  0.3926,  0.8262,
          -0.3585, -0.8416, -0.2737],
         [-0.0743,  0.2823, -0.4805,  0.2492, -0.8091,  0.4418, -0.4845,
          -0.7410, -0.1551,  0.0

# Training

## Loading the Dataset

We're using **SAMSum** — a dialogue summarization dataset with chat conversations and their summaries.

**What's in SAMSum?**
- `dialogue`: The chat conversation (what we feed to the **encoder**)
- `summary`: Short summary of the conversation (what we train the **decoder** to produce)
- `id`: Unique identifier

We'll take the first 2000 samples to keep training fast while learning.

In [9]:
from datasets import load_dataset

# Load SAMSum dataset (~2.5 MB download)
# We only load the 'train' split and take first 2000 samples
dataset = load_dataset("knkarthick/samsum", split="train[:2000]")

print(f"Dataset size: {len(dataset)} samples")
print(f"Features: {dataset.features}")

Dataset size: 2000 samples
Features: {'id': Value('string'), 'dialogue': Value('string'), 'summary': Value('string')}


In [10]:
# Let's look at ONE example to understand what we're working with
sample = dataset[0]

print("=" * 60)
print("DIALOGUE (Input to Encoder):")
print("=" * 60)
print(sample['dialogue'])
print(f"\n[Length: {len(sample['dialogue'])} characters]")

print("\n" + "=" * 60)
print("SUMMARY (Target for Decoder):")
print("=" * 60)
print(sample['summary'])
print(f"\n[Length: {len(sample['summary'])} characters]")

DIALOGUE (Input to Encoder):
Amanda: I baked  cookies. Do you want some?
Jerry: Sure!
Amanda: I'll bring you tomorrow :-)

[Length: 92 characters]

SUMMARY (Target for Decoder):
Amanda baked cookies and will bring Jerry some tomorrow.

[Length: 56 characters]


In [11]:
# Let's understand the length distribution - this helps us set max_length later
dialogue_lengths = [len(d.split()) for d in dataset['dialogue']]  # Word count
summary_lengths = [len(s.split()) for s in dataset['summary']]

print("DIALOGUE lengths (in words):")
print(f"  Min: {min(dialogue_lengths)}, Max: {max(dialogue_lengths)}, Avg: {sum(dialogue_lengths)/len(dialogue_lengths):.0f}")

print("\nSUMMARY lengths (in words):")
print(f"  Min: {min(summary_lengths)}, Max: {max(summary_lengths)}, Avg: {sum(summary_lengths)/len(summary_lengths):.0f}")

DIALOGUE lengths (in words):
  Min: 7, Max: 471, Avg: 95

SUMMARY lengths (in words):
  Min: 1, Max: 60, Avg: 21


### Byte Pair Encoding

The BPE algorithm operates by *iteratively* replacing the most frequent pair of bytes (or characters) in a dataset with a new token. This process continues until a predefined vocabulary size is reached or no more pairs can be merged. This tokenization method allows you to more densely represent the incoming text data, allowing more information to be captured in fewer tokens (and thus fit into a fixed finite-sized context window), at the expense of larger vocabulary sizes.

**Properties:**
- **Reversible and lossless**, so you can convert tokens back into the original text
- **Works on arbitrary text**, even text that is not in the tokeniser's training data
- **Compresses the text:** the token sequence is shorter than the bytes corresponding to the original text. On average, in practice, each token corresponds to about 4 bytes.
- **Attempts to let the model see common subwords.** For instance, "ing" is a common subword in English, so BPE encodings will often split "encoding" into tokens like "encod" and "ing" (instead of e.g. "enc" and "oding"). Because the model will then see the "ing" token again and again in different contexts, it helps models generalise and better understand grammar.

**Benefits:**
- Fixed vocabulary size (no infinite words)
- Handles rare/new words by breaking them down
- Balances between character-level and word-level

In [12]:
# Install tiktoken if needed: !pip install tiktoken
import tiktoken

# Load GPT-2's BPE tokenizer using tiktoken (OpenAI's official tokenizer)
# This bypasses the buggy transformers library code
# Has only one special token: <|endoftext|>
enc = tiktoken.get_encoding("gpt2")

# Create a simple wrapper class to make tiktoken compatible with our code
class TiktokenWrapper:
    def __init__(self, encoding):
        self.encoding = encoding
        self.pad_token_id = encoding.eot_token  # End of text token as pad
        self.eos_token_id = encoding.eot_token
        self.bos_token_id = encoding.eot_token  # GPT-2 uses same token
        self.pad_token = "<|endoftext|>"
        self.eos_token = "<|endoftext|>"
        self.bos_token = "<|endoftext|>"
    
    def __len__(self):
        return self.encoding.n_vocab
    
    def encode(self, text):
        return self.encoding.encode(text)
    
    def decode(self, token_ids):
        return self.encoding.decode(token_ids)
    
    def convert_ids_to_tokens(self, token_ids):
        return [self.encoding.decode([tid]) for tid in token_ids]
    
    def __call__(self, text, max_length=None, padding=None, truncation=None, return_tensors=None):
        """Tokenize with padding/truncation support"""
        import torch
        
        token_ids = self.encoding.encode(text)
        
        # Truncate if needed
        if truncation and max_length and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        
        # Pad if needed
        if padding == 'max_length' and max_length:
            pad_length = max_length - len(token_ids)
            token_ids = token_ids + [self.pad_token_id] * pad_length
        
        result = {'input_ids': token_ids}
        
        if return_tensors == 'pt':
            result['input_ids'] = torch.tensor([result['input_ids']])
        
        return result

tokenizer = TiktokenWrapper(enc)

print(f"Vocabulary size: {len(tokenizer)}")
print(f"Special tokens: pad={tokenizer.pad_token!r}, eos={tokenizer.eos_token!r}, bos={tokenizer.bos_token!r}")
print(f"Pad token ID: {tokenizer.pad_token_id}")

Vocabulary size: 50257
Special tokens: pad='<|endoftext|>', eos='<|endoftext|>', bos='<|endoftext|>'
Pad token ID: 50256


In [13]:
# Let's see BPE in action!
# Notice how it breaks words into subword units

test_text = "Amanda baked cookies and will bring Jerry some tomorrow."

# Tokenize: text → token IDs
token_ids = tokenizer.encode(test_text)
print(f"Original text: {test_text}")
print(f"Token IDs: {token_ids}")
print(f"Number of tokens: {len(token_ids)}")

# Decode: token IDs → text (to verify)
decoded = tokenizer.decode(token_ids)
print(f"Decoded back: {decoded}")

# See the actual subword tokens
tokens = tokenizer.convert_ids_to_tokens(token_ids)
print(f"\nSubword tokens: {tokens}")
print("(Ġ means 'starts with a space' in GPT-2's encoding)")

Original text: Amanda baked cookies and will bring Jerry some tomorrow.
Token IDs: [5840, 5282, 22979, 14746, 290, 481, 2222, 13075, 617, 9439, 13]
Number of tokens: 11
Decoded back: Amanda baked cookies and will bring Jerry some tomorrow.

Subword tokens: ['Am', 'anda', ' baked', ' cookies', ' and', ' will', ' bring', ' Jerry', ' some', ' tomorrow', '.']
(Ġ means 'starts with a space' in GPT-2's encoding)


In [14]:
# Now let's see how our dataset looks when tokenized
# This helps us decide max_length for the model

dialogue_token_lengths = [len(tokenizer.encode(d)) for d in dataset['dialogue']]
summary_token_lengths = [len(tokenizer.encode(s)) for s in dataset['summary']]

print("DIALOGUE token lengths (after BPE):")
print(f"  Min: {min(dialogue_token_lengths)}, Max: {max(dialogue_token_lengths)}, Avg: {sum(dialogue_token_lengths)/len(dialogue_token_lengths):.0f}")

print("\nSUMMARY token lengths (after BPE):")
print(f"  Min: {min(summary_token_lengths)}, Max: {max(summary_token_lengths)}, Avg: {sum(summary_token_lengths)/len(summary_token_lengths):.0f}")

# Percentiles help us choose max_length
import numpy as np
print(f"\nDIALOGUE 95th percentile: {np.percentile(dialogue_token_lengths, 95):.0f} tokens")
print(f"SUMMARY 95th percentile: {np.percentile(summary_token_lengths, 95):.0f} tokens")
print('Length of datset: ', len(dataset))

DIALOGUE token lengths (after BPE):
  Min: 15, Max: 679, Avg: 150

SUMMARY token lengths (after BPE):
  Min: 1, Max: 73, Avg: 26

DIALOGUE 95th percentile: 367 tokens
SUMMARY 95th percentile: 54 tokens
Length of datset:  2000


### Setting max_length

Based on the 95th percentile, we'll set:
- **Source (dialogue) max_length**: 512 tokens (covers 95%+ of samples)
- **Target (summary) max_length**: 64 tokens

Sequences longer than this will be **truncated**, shorter ones will be **padded**.

In [15]:
# Hyperparameters for tokenization
SRC_MAX_LENGTH = 512  # Max tokens for dialogue (encoder input)
TRG_MAX_LENGTH = 64   # Max tokens for summary (decoder input/output)

def tokenize_sample(dialogue, summary):
    """
    Tokenize a single dialogue-summary pair.
    
    Returns:
        src_ids: Token IDs for dialogue (encoder input)
        trg_ids: Token IDs for summary (decoder input/output)
    """
    # Tokenize with padding and truncation
    src_encoded = tokenizer(
        dialogue,
        max_length=SRC_MAX_LENGTH,
        padding='max_length',      # Pad to max_length
        truncation=True,           # Truncate if longer
        return_tensors='pt'        # Return PyTorch tensors
    )
    
    trg_encoded = tokenizer(
        summary,
        max_length=TRG_MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    return src_encoded['input_ids'].squeeze(), trg_encoded['input_ids'].squeeze()

# Test on one sample
src_ids, trg_ids = tokenize_sample(dataset[0]['dialogue'], dataset[0]['summary'])
print(f"Source shape: {src_ids.shape}")  # Should be [512]
print(f"Target shape: {trg_ids.shape}")  # Should be [64]
print(f"\nFirst 20 source tokens: {src_ids[:20].tolist()}")
print(f"First 20 target tokens: {trg_ids[:20].tolist()}")

Source shape: torch.Size([512])
Target shape: torch.Size([64])

First 20 source tokens: [5840, 5282, 25, 314, 22979, 220, 14746, 13, 2141, 345, 765, 617, 30, 198, 43462, 25, 10889, 0, 198, 5840]
First 20 target tokens: [5840, 5282, 22979, 14746, 290, 481, 2222, 13075, 617, 9439, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]


## PyTorch Dataset & DataLoader

Now we create the data pipeline that feeds batches to our model:

1. **Dataset**: Wraps our data, tokenizes on-the-fly, returns (source, target) pairs
2. **DataLoader**: Batches samples, shuffles for training, handles parallel loading

**Key concept**: During training, the decoder gets the target sequence **shifted by one**:
- **Decoder input**: `<sos> token1 token2 ... tokenN`
- **Decoder target**: `token1 token2 ... tokenN <eos>`

This teaches the model to predict the **next token** given previous tokens.

In [16]:
from torch.utils.data import Dataset, DataLoader

class SummarizationDataset(Dataset):
    """
    Custom Dataset for dialogue summarization.
    
    For each sample, returns:
    - src: tokenized dialogue (encoder input)
    - trg: tokenized summary (decoder input/output)
    """
    def __init__(self, hf_dataset, tokenizer, src_max_len, trg_max_len):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.src_max_len = src_max_len
        self.trg_max_len = trg_max_len
        
        # Special token IDs
        self.pad_id = tokenizer.pad_token_id
        self.eos_id = tokenizer.eos_token_id
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        dialogue = sample['dialogue']
        summary = sample['summary']
        
        # Tokenize source (dialogue)
        src_tokens = self.tokenizer.encode(dialogue)
        
        # Tokenize target (summary) - add EOS at the end
        trg_tokens = self.tokenizer.encode(summary) + [self.eos_id]
        
        # Truncate if needed
        src_tokens = src_tokens[:self.src_max_len]
        trg_tokens = trg_tokens[:self.trg_max_len]
        
        # Pad to fixed length
        src_padded = src_tokens + [self.pad_id] * (self.src_max_len - len(src_tokens))
        trg_padded = trg_tokens + [self.pad_id] * (self.trg_max_len - len(trg_tokens))
        
        return {
            'src': torch.tensor(src_padded, dtype=torch.long),
            'trg': torch.tensor(trg_padded, dtype=torch.long)
        }

# Test the dataset
test_dataset = SummarizationDataset(dataset, tokenizer, SRC_MAX_LENGTH, TRG_MAX_LENGTH)
sample = test_dataset[0]

print(f"Source shape: {sample['src'].shape}")
print(f"Target shape: {sample['trg'].shape}")
print(f"\nFirst 20 source tokens: {sample['src'][:20].tolist()}")
print(f"First 20 target tokens: {sample['trg'][:20].tolist()}")

Source shape: torch.Size([512])
Target shape: torch.Size([64])

First 20 source tokens: [5840, 5282, 25, 314, 22979, 220, 14746, 13, 2141, 345, 765, 617, 30, 198, 43462, 25, 10889, 0, 198, 5840]
First 20 target tokens: [5840, 5282, 22979, 14746, 290, 481, 2222, 13075, 617, 9439, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]


In [17]:
# Split into train and validation sets
from torch.utils.data import random_split

# Use 90% for training, 10% for validation
full_dataset = SummarizationDataset(dataset, tokenizer, SRC_MAX_LENGTH, TRG_MAX_LENGTH)
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 1800
Validation samples: 200


In [18]:
# Create DataLoaders
BATCH_SIZE = 64  # Batch size for learning (increase on GPU with more memory)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,      # Shuffle for training
    num_workers=0,     # Set to 0 for Windows compatibility
    drop_last=True     # Drop incomplete batches
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,     # No shuffle for validation
    num_workers=0,
    drop_last=False
)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Let's look at one batch
batch = next(iter(train_loader))
print(f"\nBatch source shape: {batch['src'].shape}")   # [batch_size, src_max_len]
print(f"Batch target shape: {batch['trg'].shape}")     # [batch_size, trg_max_len]

Training batches: 28
Validation batches: 4

Batch source shape: torch.Size([64, 512])
Batch target shape: torch.Size([64, 64])


## Training Loop

Now we train the model! Key components:

1. **Loss Function**: CrossEntropyLoss — compares predicted token probabilities with actual tokens
2. **Optimizer**: Adam — adaptive learning rate optimizer
3. **Teacher Forcing**: During training, we feed the **ground truth** previous token to predict the next one

**The training flow:**
```
Source (dialogue) → Encoder → Context
Target[:-1] (shifted) → Decoder + Context → Predictions
Predictions vs Target[1:] → Loss → Backprop → Update weights
```

In [19]:
# Initialize model with our actual vocabulary size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model hyperparameters
VOCAB_SIZE = len(tokenizer)  # 50257 for GPT-2
PAD_IDX = tokenizer.pad_token_id
EMBED_SIZE = 256
NUM_LAYERS = 3      # Reduce for faster training (original paper uses 6)
HEADS = 8
FORWARD_EXPANSION = 4
DROPOUT = 0.1
MAX_LENGTH = max(SRC_MAX_LENGTH, TRG_MAX_LENGTH)

# Create model
model = Transformer(
    src_vocab_size=VOCAB_SIZE,
    target_vocab_size=VOCAB_SIZE,  # Same vocab for src and target
    src_pad_idx=PAD_IDX,
    target_pad_idx=PAD_IDX,
    embed_size=EMBED_SIZE,
    num_layers=NUM_LAYERS,
    forward_expansion=FORWARD_EXPANSION,
    heads=HEADS,
    dropout=DROPOUT,
    device=device,
    max_length=MAX_LENGTH
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Using device: cuda
Total parameters: 44,432,465
Trainable parameters: 44,432,465


In [20]:
# Loss function and optimizer
import torch.optim as optim

# CrossEntropyLoss ignores padding tokens when computing loss
# Also why this specific loss function?
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)  # What is the need for PAD_IDX?

# Adam optimizer with learning rate
LEARNING_RATE = 5e-4  # Why this value?
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"Loss function: CrossEntropyLoss (ignoring pad_idx={PAD_IDX})")
print(f"Optimizer: Adam (lr={LEARNING_RATE})")

Loss function: CrossEntropyLoss (ignoring pad_idx=50256)
Optimizer: Adam (lr=0.0005)


In [21]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    
    for batch in loader:
        src = batch['src'].to(device)    # [batch, src_len]
        trg = batch['trg'].to(device)    # [batch, trg_len]
        
        # Decoder input: all tokens except last (teacher forcing)
        # Decoder target: all tokens except first (what we predict)
        # ----- Are we teaching the model to predict the next token or the next entire sequence? -----
        trg_input = trg[:, :-1]   # [batch, trg_len-1]
        trg_target = trg[:, 1:]   # [batch, trg_len-1]
        
        # Forward pass
        optimizer.zero_grad()  # We reset the gradients to zero as in PyTorch autograd accumulates gradients
        output = model(src, trg_input)  # [batch, trg_len-1, vocab_size]
        
        # Reshape for loss: [batch * seq_len, vocab_size] vs [batch * seq_len]
        output = output.reshape(-1, output.shape[-1])
        trg_target = trg_target.reshape(-1)
        
        # Compute loss
        loss = criterion(output, trg_target)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping (prevents exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        total_loss += loss.item()  # Why .item() is needed here?
    
    return total_loss / len(loader)

In [22]:
def evaluate(model, loader, criterion, device):
    """Evaluate on validation set"""
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in loader:
            src = batch['src'].to(device)
            trg = batch['trg'].to(device)
            
            trg_input = trg[:, :-1]
            trg_target = trg[:, 1:]
            
            output = model(src, trg_input)
            output = output.reshape(-1, output.shape[-1])
            trg_target = trg_target.reshape(-1)
            
            loss = criterion(output, trg_target)
            total_loss += loss.item()
    
    return total_loss / len(loader)

print("Training functions defined! ✓")

Training functions defined! ✓


In [23]:
# Training loop
NUM_EPOCHS = 20
best_val_loss = float('inf')

print(f"Starting training for {NUM_EPOCHS} epochs...")
print("=" * 60)

for epoch in range(NUM_EPOCHS):
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    
    # Evaluate
    val_loss = evaluate(model, val_loader, criterion, device)
    
    # Track best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')
        marker = " ← Best!"
    else:
        marker = ""
    
    print(f"Epoch {epoch+1:2d}/{NUM_EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}{marker}")

print("=" * 60)
print(f"Training complete! Best validation loss: {best_val_loss:.4f}")

Starting training for 20 epochs...
Epoch  1/20 | Train Loss: 8.5541 | Val Loss: 7.1284 ← Best!
Epoch  2/20 | Train Loss: 6.7001 | Val Loss: 7.0484 ← Best!
Epoch  3/20 | Train Loss: 6.6062 | Val Loss: 7.0663
Epoch  4/20 | Train Loss: 6.5866 | Val Loss: 7.1036
Epoch  5/20 | Train Loss: 6.5779 | Val Loss: 7.1214
Epoch  6/20 | Train Loss: 6.5736 | Val Loss: 7.1373
Epoch  7/20 | Train Loss: 6.5497 | Val Loss: 7.0845
Epoch  8/20 | Train Loss: 6.4252 | Val Loss: 6.7760 ← Best!
Epoch  9/20 | Train Loss: 6.1566 | Val Loss: 6.5521 ← Best!
Epoch 10/20 | Train Loss: 5.8850 | Val Loss: 6.3673 ← Best!
Epoch 11/20 | Train Loss: 5.6580 | Val Loss: 6.2703 ← Best!
Epoch 12/20 | Train Loss: 5.4775 | Val Loss: 6.2025 ← Best!
Epoch 13/20 | Train Loss: 5.3316 | Val Loss: 6.1694 ← Best!
Epoch 14/20 | Train Loss: 5.1918 | Val Loss: 6.1307 ← Best!
Epoch 15/20 | Train Loss: 5.0639 | Val Loss: 6.1175 ← Best!
Epoch 16/20 | Train Loss: 4.9524 | Val Loss: 6.1172 ← Best!
Epoch 17/20 | Train Loss: 4.8397 | Val Loss: 

## Inference (Autoregressive Generation)

Now the fun part — generating summaries! 

**How autoregressive generation works:**
1. Encode the source (dialogue) once
2. Start with just `<eos>` token (or empty)
3. Predict next token, append to sequence
4. Repeat until `<eos>` or max length reached

```
Step 1: [] → predict "Amanda" 
Step 2: [Amanda] → predict "baked"
Step 3: [Amanda, baked] → predict "cookies"
...until <eos>
```

In [24]:
def generate_summary(model, dialogue, tokenizer, device, max_len=64):
    """
    Generate a summary for a given dialogue using autoregressive decoding.
    
    Args:
        model: Trained Transformer model
        dialogue: Input dialogue string
        tokenizer: Tokenizer
        device: cuda or cpu
        max_len: Maximum tokens to generate
    
    Returns:
        Generated summary string
    """
    model.eval()
    
    # Tokenize source
    src_tokens = tokenizer.encode(dialogue)
    src_tokens = src_tokens[:SRC_MAX_LENGTH]  # Truncate if needed
    src_padded = src_tokens + [tokenizer.pad_token_id] * (SRC_MAX_LENGTH - len(src_tokens))
    src_tensor = torch.tensor([src_padded], dtype=torch.long).to(device)
    
    # Start with empty target (we'll build it token by token)
    # Using EOS as the starting token
    generated = [tokenizer.eos_token_id]
    
    with torch.no_grad():
        for _ in range(max_len):
            # Prepare target tensor
            trg_tensor = torch.tensor([generated], dtype=torch.long).to(device)
            
            # Forward pass
            output = model(src_tensor, trg_tensor)  # [1, seq_len, vocab_size]
            
            # Get the last token prediction
            next_token_logits = output[0, -1, :]  # [vocab_size]
            
            # Greedy decoding: pick the token with highest probability
            next_token = next_token_logits.argmax().item()
            
            # Stop if EOS token
            if next_token == tokenizer.eos_token_id:
                break
            
            # Append to generated sequence
            generated.append(next_token)
    
    # Decode tokens to text (skip the initial EOS token)
    summary = tokenizer.decode(generated[1:])
    return summary

print("Generation function defined! ✓")

Generation function defined! ✓


In [25]:
# Test on a few examples from the dataset
print("=" * 70)
print("TESTING ON TRAINING EXAMPLES")
print("=" * 70)

for i in [0, 5, 10]:
    sample = dataset[i]
    dialogue = sample['dialogue']
    actual_summary = sample['summary']
    
    # Generate summary
    generated_summary = generate_summary(model, dialogue, tokenizer, device)
    
    print(f"\n--- Example {i} ---")
    print(f"DIALOGUE:\n{dialogue[:200]}..." if len(dialogue) > 200 else f"DIALOGUE:\n{dialogue}")
    print(f"\nACTUAL SUMMARY:\n{actual_summary}")
    print(f"\nGENERATED SUMMARY:\n{generated_summary}")
    print("-" * 70)

TESTING ON TRAINING EXAMPLES

--- Example 0 ---
DIALOGUE:
Amanda: I baked  cookies. Do you want some?
Jerry: Sure!
Amanda: I'll bring you tomorrow :-)

ACTUAL SUMMARY:
Amanda baked cookies and will bring Jerry some tomorrow.

GENERATED SUMMARY:
 is going to buy a new job. She will come to go to the weekend.                                               a
----------------------------------------------------------------------

--- Example 5 ---
DIALOGUE:
Neville: Hi there, does anyone remember what date I got married on?
Don: Are you serious?
Neville: Dead serious. We're on vacation, and Tina's mad at me about something. I have a strange suspicion tha...

ACTUAL SUMMARY:
Wyatt reminds Neville his wedding anniversary is on the 17th of September. Neville's wife is upset and it might be because Neville forgot about their anniversary.

GENERATED SUMMARY:
 is going to buy a new job. She will come to go to the weekend.                                               a
---------------------------

In [28]:
# Try with a completely NEW dialogue!
print("=" * 70)
print("TESTING ON NEW DIALOGUE")
print("=" * 70)

new_dialogue = """
John: Hey, are you coming to the party tonight?
Sarah: What party?
John: Mike's birthday party at 8pm!
Sarah: Oh I totally forgot! Where is it?
John: At his place. I can pick you up if you want.
Sarah: That would be great! Thanks!
John: No problem. See you at 7:30.
"""

generated = generate_summary(model, new_dialogue, tokenizer, device)

print(f"DIALOGUE:\n{new_dialogue}")
print(f"\nGENERATED SUMMARY:\n{generated}")

TESTING ON NEW DIALOGUE
DIALOGUE:

John: Hey, are you coming to the party tonight?
Sarah: What party?
John: Mike's birthday party at 8pm!
Sarah: Oh I totally forgot! Where is it?
John: At his place. I can pick you up if you want.
Sarah: That would be great! Thanks!
John: No problem. See you at 7:30.


GENERATED SUMMARY:
 is going to buy a new job. She will come to go to the weekend.                                               a
