##  **T5 from Scratch**

In [74]:
import torch 
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
    print("Memory Allocated:", round(torch.cuda.memory_allocated(0) / 1e9, 2), "GB")
    print("Memory Cached:", round(torch.cuda.memory_reserved(0) / 1e9, 2), "GB")
    print("CUDA Version:", torch.version.cuda)
    print("GPU Count:", torch.cuda.device_count())
else:
    print("No GPU detected")

GPU Name: Tesla P100-PCIE-16GB
Memory Allocated: 1.2 GB
Memory Cached: 4.39 GB
CUDA Version: 12.1
GPU Count: 1


In [75]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


In [76]:
import math
import time
import os
import random
import numpy as np
from tqdm.notebook import tqdm 


In [77]:

class Config:
    # Data
    data_file = "/kaggle/input/taylorswift/taylorswift.txt"     # Input text file path (MUST EXIST)

    # --- SentencePiece Tokenizer Settings ---
    sp_model_prefix = 'ts_spm_bpe' # Prefix for saving the trained SentencePiece model files (.model, .vocab)
    sp_vocab_size = 4000          # Desired vocabulary size (tune this)
    sp_model_type = 'bpe'         # Model type: 'bpe', 'unigram', 'char', or 'word'
    sp_pad_id = 3                 # Ensure this matches the ID used in SP training
    sp_pad_piece = "<pad>"        # Representation of the padding token

    vocab_size = None 
    d_model = 512
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048
    dropout = 0.1
    max_seq_len = 128
    relative_attention_num_buckets = 32

    batch_size = 32
    learning_rate = 1e-4
    epochs = 30        # << INCREASED EPOCHS
    clip_grad_norm = 1.0

    seed = 42
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [78]:
import sentencepiece as spm

In [79]:
config = Config()

random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
if config.device.type == 'cuda':
    torch.cuda.manual_seed_all(config.seed) # Seed all GPUs
    # Ensure deterministic algorithms for reproducibility (can impact performance slightly)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print(f"Using device: {config.device}")
print(f"Using data file: {config.data_file}")
print(f"Tokenizer: SentencePiece ({config.sp_model_type.upper()}), Target Vocab Size: {config.sp_vocab_size}")
print(f"Current Config:")
print(f"  d_model       = {config.d_model}")
print(f"  nhead         = {config.nhead}")
print(f"  num_layers    = {config.num_encoder_layers} (enc/dec)")
print(f"  batch_size    = {config.batch_size}")
print(f"  max_seq_len   = {config.max_seq_len} (tokens)")
print(f"  learning_rate = {config.learning_rate}")
print(f"  epochs        = {config.epochs}")

Using device: cuda
Using data file: /kaggle/input/taylorswift/taylorswift.txt
Tokenizer: SentencePiece (BPE), Target Vocab Size: 4000
Current Config:
  d_model       = 512
  nhead         = 8
  num_layers    = 6 (enc/dec)
  batch_size    = 32
  max_seq_len   = 128 (tokens)
  learning_rate = 0.0001
  epochs        = 30


### 4.0 Train SentencePiece Tokenizer (if model doesn't exist)

In [80]:
data_filename = config.data_file
if not os.path.exists(data_filename):
    print(f"ERROR: Data file '{data_filename}' not found!")
    print("Please ensure the file exists in the same directory as this notebook.")
    # Stop execution if file is missing
    raise FileNotFoundError(f"Required data file not found: {data_filename}")
else:
    print(f"Data file '{data_filename}' found. Size: {os.path.getsize(data_filename):,} bytes.")

Data file '/kaggle/input/taylorswift/taylorswift.txt' found. Size: 186,754 bytes.


In [81]:

sp_model_file = f"{config.sp_model_prefix}.model"

if not os.path.exists(sp_model_file):
    print(f"SentencePiece model not found at '{sp_model_file}'. Training...")
    if not os.path.exists(config.data_file):
         raise FileNotFoundError(f"Data file '{config.data_file}' not found. Cannot train tokenizer.")

    # --unk_id=0, --bos_id=1, --eos_id=2 are defaults but explicit here
    # --pad_id=3 and --pad_piece make padding explicit
    # --character_coverage=1.0 ensures all characters are representable
    spm_command = (
        f"--input={config.data_file} --model_prefix={config.sp_model_prefix} "
        f"--vocab_size={config.sp_vocab_size} --model_type={config.sp_model_type} "
        f"--character_coverage=1.0 --unk_id=0 --bos_id=1 --eos_id=2 "
        f"--pad_id={config.sp_pad_id} --pad_piece={config.sp_pad_piece} "
        f"--unk_piece=<unk> --bos_piece=<s> --eos_piece=</s> " # Optional: Define string representations
        f"--hard_vocab_limit=false " # Allows vocab slightly larger if needed for special tokens
        f"--shuffle_input_sentence=true --input_sentence_size=10000000" # Shuffle input for better training
    )

    print(f"Running SentencePiece Trainer with command:\n{spm_command}\n")
    try:
        spm.SentencePieceTrainer.train(spm_command)
        print(f"SentencePiece model trained and saved with prefix '{config.sp_model_prefix}'")
    except Exception as e:
        print(f"Error training SentencePiece: {e}")
        raise SystemExit("SentencePiece training failed.")

else:
    print(f"SentencePiece model found at '{sp_model_file}'. Skipping training.")

sp_vocab_file = f"{config.sp_model_prefix}.vocab"
if not os.path.exists(sp_model_file) or not os.path.exists(sp_vocab_file):
     raise FileNotFoundError(f"SentencePiece model/vocab files missing after training attempt.")

SentencePiece model found at 'ts_spm_bpe.model'. Skipping training.


In [82]:

print(f"Loading text corpus from: {config.data_file}")
try:
    
    if not os.path.exists(config.data_file):
         raise FileNotFoundError(f"Data file '{config.data_file}' was expected but not found.")

    with open(config.data_file, 'r', encoding='utf-8') as f:
        text_corpus = f.read()
    print(f"Corpus loaded. Length: {len(text_corpus):,} characters")
    

except FileNotFoundError as e:
    print(f"ERROR: {e}")
    raise SystemExit("Data file loading failed. Stopping execution.")
except Exception as e:
    print(f"An error occurred while reading the file: {e}")
    raise SystemExit("File reading error. Stopping execution.")

Loading text corpus from: /kaggle/input/taylorswift/taylorswift.txt
Corpus loaded. Length: 185,560 characters


In [83]:

if 'text_corpus' in locals() and text_corpus:
    print("-" * 30)
    print("Start of loaded text corpus (first 500 characters):")
    print("-" * 30)
    print(text_corpus[:500]) # Print the first 500 characters
    print("\n" + "-" * 30)
    print("[...] (rest of the corpus follows)")
    print("-" * 30 + "\n")
else:
    print("Text corpus variable 'text_corpus' not found or empty, skipping display.")

------------------------------
Start of loaded text corpus (first 500 characters):
------------------------------
Copy paste of the Wikipedia article on Taylor Swift, as of Feb 16, 2024.
---

Main menu

WikipediaThe Free Encyclopedia

Search
Create account
Log in

Personal tools
Contents  hide
(Top)
Life and career
Toggle Life and career subsection
Artistry
Toggle Artistry subsection
Accolades and achievements
Cultural status
Toggle Cultural status subsection
Wealth
Toggle Wealth subsection
Discography
Filmography
Tours
See also
Footnotes
References
Toggle References subsection
External links
Taylor Swift



------------------------------
[...] (rest of the corpus follows)
------------------------------



In [84]:

class SentencePieceTokenizer:
    def __init__(self, model_path):
        print(f"Loading SentencePiece model from: {model_path}")
        self.processor = spm.SentencePieceProcessor()
        self.processor.load(model_path)

        self._vocab_size = self.processor.get_piece_size()
        self._pad_id = self.processor.pad_id()
        self._unk_id = self.processor.unk_id()
        self._bos_id = self.processor.bos_id()
        self._eos_id = self.processor.eos_id()

        print(f"SentencePiece model loaded.")
        print(f"  Vocabulary size: {self._vocab_size}")
        print(f"  PAD ID: {self._pad_id} ('{self.processor.id_to_piece(self._pad_id)}')")
        print(f"  UNK ID: {self._unk_id} ('{self.processor.id_to_piece(self._unk_id)}')")
        print(f"  BOS ID: {self._bos_id} ('{self.processor.id_to_piece(self._bos_id)}')")
        print(f"  EOS ID: {self._eos_id} ('{self.processor.id_to_piece(self._eos_id)}')")

        # Verify our configured PAD ID matches the loaded model's PAD ID
        if self._pad_id != config.sp_pad_id:
             print(f"Warning: Configured PAD ID ({config.sp_pad_id}) does not match loaded model PAD ID ({self._pad_id}). Using loaded model PAD ID.")
             # Update config pad_id to match the actual loaded model for safety
             config.sp_pad_id = self._pad_id
        # --------------------


    def encode(self, text, add_bos=False, add_eos=False):
        """Encodes text into token IDs."""
        encoded = self.processor.encode_as_ids(text)
        if add_bos:
            encoded = [self._bos_id] + encoded
        if add_eos:
            encoded = encoded + [self._eos_id]
        return encoded

    def decode(self, token_ids, skip_special_tokens=True):
        """Decodes token IDs back to text."""
        if skip_special_tokens:
            ids_to_decode = [id for id in token_ids if id not in [self._pad_id, self._bos_id, self._eos_id, self._unk_id]]
        else:
            ids_to_decode = token_ids

        return self.processor.decode_ids(ids_to_decode)

    def vocab_size(self):
        """Returns the vocabulary size."""
        return self._vocab_size

    @property
    def pad_token_id(self):
        return self._pad_id

    @property
    def unk_token_id(self):
        return self._unk_id

    @property
    def bos_token_id(self):
        return self._bos_id

    @property
    def eos_token_id(self):
        return self._eos_id

### 4.2 Text Dataset

In [85]:

class TextDataset(Dataset):
    def __init__(self, text_data, tokenizer, max_seq_len):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len # Max length in TOKENS
        print("Tokenizing data using SentencePiece...")
        start_time = time.time()

        self.tokenized_data = tokenizer.encode(text_data, add_bos=True, add_eos=True)
        self.data_len = len(self.tokenized_data) # Length is now number of tokens
        print(f"Tokenization complete ({time.time() - start_time:.2f}s). Total tokens: {self.data_len:,}")

        # Sequence creation logic remains the same, but operates on tokens
        # Example: [BOS, t1a, t1b, t2, t3a, t3b, t3c, EOS], max_seq_len=4 (tokens)
        # Chunk 1: [BOS, t1a, t1b, t2, t3a] -> Input: [BOS, t1a, t1b, t2], Target: [t1a, t1b, t2, t3a]
        # Chunk 2: [t1a, t1b, t2, t3a, t3b] -> Input: [t1a, t1b, t2, t3a], Target: [t1b, t2, t3a, t3b]
        # ...

    def __len__(self):
        return max(0, self.data_len - self.max_seq_len - 1)

    def __getitem__(self, idx):

        start_idx = idx
        end_idx = idx + self.max_seq_len + 1
        chunk = self.tokenized_data[start_idx : end_idx]
        current_chunk_len = len(chunk)
        if current_chunk_len < self.max_seq_len + 1:
             print(f"Warning: Chunk at index {idx} is shorter than expected ({current_chunk_len} vs {self.max_seq_len + 1}). Padding.")
             padding_needed = (self.max_seq_len + 1) - current_chunk_len
             # Pad with PAD token ID
             chunk.extend([self.tokenizer.pad_token_id] * padding_needed)


        input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
        target_ids = torch.tensor(chunk[1:], dtype=torch.long)

        # Double-check lengths after potential padding or slicing issues
        if len(input_ids) != self.max_seq_len:
             raise RuntimeError(f"Input length mismatch after processing: expected {self.max_seq_len}, got {len(input_ids)} at index {idx}")
        if len(target_ids) != self.max_seq_len:
             raise RuntimeError(f"Target length mismatch after processing: expected {self.max_seq_len}, got {len(target_ids)} at index {idx}")


        return input_ids, target_ids

### 4.6 Instantiate Tokenizer, Dataset, and DataLoader

In [86]:
# --- Create SentencePiece Tokenizer ---
# Load the model trained/verified in the previous step
sp_model_path = f"{config.sp_model_prefix}.model"
try:
    # Ensure the text_corpus variable exists from the loading step
    if 'text_corpus' not in locals() or not text_corpus:
         raise RuntimeError("Text corpus not loaded. Cannot proceed.")
    tokenizer = SentencePieceTokenizer(sp_model_path)
except Exception as e:
     print(f"Error loading SentencePiece model from '{sp_model_path}': {e}")
     print("Ensure the SentencePiece training cell ran successfully and the .model file exists.")
     raise SystemExit("Tokenizer loading failed.")

# Update config with the ACTUAL vocabulary size from the loaded SentencePiece model
config.vocab_size = tokenizer.vocab_size()
print(f"Updated config.vocab_size to actual SentencePiece vocab size: {config.vocab_size}")


# --- Create Dataset ---
try:
    dataset = TextDataset(text_corpus, tokenizer, config.max_seq_len)
except Exception as e:
    print(f"Error creating TextDataset: {e}")
    raise SystemExit("Dataset creation failed.")


# --- Create DataLoader ---
# Use multiple workers for data loading if on GPU and OS supports it well
num_workers = 2 if config.device.type == 'cuda' and os.name == 'posix' else 0 # Use workers mainly on Linux/macOS with CUDA
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True, # Shuffle data each epoch for better training
    num_workers=num_workers,
    pin_memory=(config.device.type == 'cuda') # Helps speed up CPU->GPU transfer
)

print(f"\nDataset size: {len(dataset):,} sequences")
print(f"Final Vocab size used in model: {config.vocab_size}")

# --- Print Example to Verify ---
if len(dataset) > 0:
    print("\n--- Verifying Dataloader Output ---")
    try:
        example_batch_input, example_batch_target = next(iter(dataloader))
        print(f"Sample batch shapes: input={example_batch_input.shape}, target={example_batch_target.shape}")

        # Take the first item from the batch for detailed view
        example_input = example_batch_input[0]
        example_target = example_batch_target[0]

        print(f"\nExample item from first batch:")
        print(f"Input tokens : {example_input[:30].tolist()}...")
        print(f"Target tokens: {example_target[:30].tolist()}...")
        print("-" * 20)
        # Decode example tokens back to text, skipping special tokens for readability
        decoded_input_sample = tokenizer.decode(example_input[:50].tolist(), skip_special_tokens=True)
        decoded_target_sample = tokenizer.decode(example_target[:50].tolist(), skip_special_tokens=True)
        print(f"Decoded input sample : '{decoded_input_sample}'...")
        print(f"Decoded target sample: '{decoded_target_sample}'...")
        print("-" * 20)
    except StopIteration:
        print("Could not get a batch from the dataloader (dataset might be smaller than batch size).")
    except Exception as e:
        print(f"Error verifying dataloader output: {e}")
else:
    print("\nWarning: Dataset is empty. Cannot verify dataloader. Check data file, tokenization, and max_seq_len.")
    if 'text_corpus' in locals() and len(text_corpus) > 0:
         print(f"Text corpus length ({len(text_corpus)} chars) resulted in {len(dataset)} sequences of {config.max_seq_len} tokens.")

# Clean up large text variable if no longer needed
if 'text_corpus' in locals():
    del text_corpus
    print("Cleaned up text_corpus variable from memory.")

Loading SentencePiece model from: ts_spm_bpe.model
SentencePiece model loaded.
  Vocabulary size: 4000
  PAD ID: 3 ('<pad>')
  UNK ID: 0 ('<unk>')
  BOS ID: 1 ('<s>')
  EOS ID: 2 ('</s>')
Updated config.vocab_size to actual SentencePiece vocab size: 4000
Tokenizing data using SentencePiece...
Tokenization complete (0.13s). Total tokens: 43,533

Dataset size: 43,404 sequences
Final Vocab size used in model: 4000

--- Verifying Dataloader Output ---
Sample batch shapes: input=torch.Size([32, 128]), target=torch.Size([32, 128])

Example item from first batch:
Input tokens : [3770, 1418, 472, 79, 112, 32, 3942, 3916, 1173, 1417, 38, 626, 495, 3158, 81, 3391, 149, 3921, 78, 66, 17, 76, 48, 269, 434, 160, 54, 190, 397, 195]...
Target tokens: [1418, 472, 79, 112, 32, 3942, 3916, 1173, 1417, 38, 626, 495, 3158, 81, 3391, 149, 3921, 78, 66, 17, 76, 48, 269, 434, 160, 54, 190, 397, 195, 984]...
--------------------
Decoded input sample : 'Here Are All of Taylor Swift's Biggest Accomplishments in

### Relative Position Bias

In [87]:
class RelativePositionBias(nn.Module):
    """
    Simplified Relative Position Bias module, inspired by T5.
    Learns embeddings for relative distances between keys and queries.
    Each head gets its own set of biases.
    """
    def __init__(self, num_buckets, num_heads, max_distance=128, bidirectional=True):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.max_distance = max_distance
        self.bidirectional = bidirectional
        # Learnable embeddings: one vector of size num_heads for each bucket index
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.num_heads)
        # Initialize biases to zeros (optional, but common)
        nn.init.zeros_(self.relative_attention_bias.weight)


    def _relative_position_bucket(self, relative_position):
        """ Maps relative position values to integer bucket indices. """
        ret = 0
        n = -relative_position # Positive values mean query is ahead of key
        num_buckets = self.num_buckets
        max_dist = self.max_distance

        if self.bidirectional:
            # Split buckets for positive and negative positions
            num_buckets //= 2
            # Add offset for negative positions (if n < 0)
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n) # Consider absolute distance now
        else:
            # For unidirectional (decoder self-attention), clip non-positive positions to 0
            n = torch.max(n, torch.zeros_like(n))

        # Now n is non-negative
        # Bucketing logic: linear up to max_exact, logarithmic beyond
        max_exact = num_buckets // 2
        is_small = (n < max_exact)

        # Calculate bucket index for large distances logarithmically
        # Add epsilon for numerical stability with log
        # Clamp log argument to avoid log(0)
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact + 1e-6)
            / math.log(max_dist / max_exact + 1e-6) # Add epsilon here too
            * (num_buckets - max_exact)
        ).to(torch.long)

        # Ensure bucket indices stay within bounds [0, num_buckets-1]
        val_if_large = torch.min(val_if_large, torch.full_like(n, num_buckets - 1))
        # Also ensure indices are non-negative if calculation goes wrong
        val_if_large = torch.max(val_if_large, torch.zeros_like(n))


        # Combine buckets for small and large distances based on the `is_small` mask
        ret += torch.where(is_small, n, val_if_large)
        # Final check to ensure all bucket indices are within the valid range
        ret = torch.clamp(ret, 0, self.num_buckets - 1 if self.bidirectional else self.num_buckets // 2 * 2 -1 )

        return ret

    def forward(self, qlen, klen, device):
        """ Calculate relative position bias matrix for given sequence lengths. """
        # Generate tensors representing positions in the query and key sequences
        context_position = torch.arange(qlen, dtype=torch.long, device=device)[:, None] # Shape (qlen, 1)
        memory_position = torch.arange(klen, dtype=torch.long, device=device)[None, :]  # Shape (1, klen)

        # Calculate pairwise relative positions: shape (qlen, klen)
        # E.g., result[i, j] = memory_position[j] - context_position[i]
        relative_position = memory_position - context_position

        # Compute bucket indices for each relative position
        rp_bucket = self._relative_position_bucket(relative_position) # Shape (qlen, klen)

        # Look up the learned bias vectors from the embedding table using the bucket indices
        # Shape: (qlen, klen, num_heads)
        values = self.relative_attention_bias(rp_bucket)

        # Reshape for broadcasting with attention scores: (num_heads, qlen, klen)
        values = values.permute(2, 0, 1)
        return values

### mutli head attention

In [88]:

class MultiHeadAttention(nn.Module):
    """ Multi-Head Attention Layer incorporating Relative Position Bias. """
    def __init__(self, d_model, nhead, dropout=0.1, is_decoder=False, relative_attention_num_buckets=32):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        self.d_model = d_model
        self.nhead = nhead
        self.d_k = d_model // nhead # Dimension of each attention head
        self.is_decoder = is_decoder # Flag for relative bias directionality

        # Linear projections for Query, Key, Value, and Output. Bias is often False.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

        # Instantiate the Relative Position Bias module
        self.relative_position_bias = RelativePositionBias(
            num_buckets=relative_attention_num_buckets,
            num_heads=nhead,
            bidirectional=(not is_decoder) # Encoder is bidirectional, Decoder self-attn is unidirectional
        )

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None):
        """
        Args:
            query (Tensor): Query tensor, shape (batch_size, target_len, d_model)
            key (Tensor): Key tensor, shape (batch_size, source_len, d_model)
            value (Tensor): Value tensor, shape (batch_size, source_len, d_model)
            key_padding_mask (BoolTensor, optional): Mask for padding tokens in key/value. Shape (batch_size, source_len). True where padded.
            attn_mask (Tensor, optional): Mask to prevent attention to certain positions (e.g., future tokens). Shape (target_len, source_len) or broadcastable. Can be bool (True where masked) or float (-inf where masked).

        Returns:
            Tensor: Output tensor, shape (batch_size, target_len, d_model)
        """
        batch_size = query.size(0)
        target_len = query.size(1)
        source_len = key.size(1)
        device = query.device

        # 1. Linear projections: (batch_size, seq_len, d_model)
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        # 2. Reshape for multi-head attention:
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, nhead, d_k) -> (batch_size, nhead, seq_len, d_k)
        q = q.view(batch_size, target_len, self.nhead, self.d_k).transpose(1, 2)
        k = k.view(batch_size, source_len, self.nhead, self.d_k).transpose(1, 2)
        v = v.view(batch_size, source_len, self.nhead, self.d_k).transpose(1, 2)

        # 3. Scaled dot-product attention scores:
        # (batch_size, nhead, target_len, d_k) @ (batch_size, nhead, d_k, source_len) -> (batch_size, nhead, target_len, source_len)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 4. Add Relative Position Bias:
        # Compute bias matrix: (nhead, target_len, source_len)
        rel_pos_bias = self.relative_position_bias(target_len, source_len, device=device)
        # Add bias (broadcasts across batch dim): -> (batch_size, nhead, target_len, source_len)
        attn_scores = attn_scores + rel_pos_bias.unsqueeze(0)

        # 5. Apply masks:
        # Apply attn_mask (e.g., causal mask)
        if attn_mask is not None:
             # Ensure attn_mask is broadcastable: e.g., (target_len, source_len) -> (1, 1, target_len, source_len)
             if attn_mask.dim() == 2:
                 attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
             elif attn_mask.dim() == 3: # Should not happen for standard masks
                 attn_mask = attn_mask.unsqueeze(1)

             # Apply mask where mask is True (bool) or non-zero/negative (float)
             if attn_mask.dtype == torch.bool:
                # Ensure mask is on the same device and has correct shape for broadcasting
                attn_scores = attn_scores.masked_fill(attn_mask.to(device), float('-inf'))
             else: # Assuming float mask with large negative values
                attn_scores = attn_scores + attn_mask.to(device) # Broadcasting

        # Apply key_padding_mask (masks padding tokens in K/V)
        if key_padding_mask is not None:
            # Reshape mask for broadcasting: (batch_size, source_len) -> (batch_size, 1, 1, source_len)
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2).to(device)
            attn_scores = attn_scores.masked_fill(mask == True, float('-inf'))

        # 6. Softmax to get attention probabilities:
        attn_probs = F.softmax(attn_scores, dim=-1) # Apply softmax over source length dimension
        attn_probs = self.dropout(attn_probs) # Apply dropout to attention weights

        # 7. Weighted sum of values:
        # (batch_size, nhead, target_len, source_len) @ (batch_size, nhead, source_len, d_k) -> (batch_size, nhead, target_len, d_k)
        output = torch.matmul(attn_probs, v)

        # 8. Reshape and final linear projection:
        # (batch_size, nhead, target_len, d_k) -> (batch_size, target_len, nhead, d_k) -> (batch_size, target_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, target_len, self.d_model)
        output = self.out_proj(output) # Final linear layer

        return output

### Position-wise Feed-Forward Network

In [89]:

class PositionwiseFeedForward(nn.Module):
    """ Implements the Position-wise Feed-Forward layer of a Transformer. """
    def __init__(self, d_model, dim_feedforward, dropout=0.1):
        super().__init__()
        # Two linear layers with an activation function and dropout in between
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Activation: ReLU is simple, GELU is common in modern Transformers
        self.activation = nn.ReLU()
        # self.activation = nn.GELU() # Option: Use GELU

    def forward(self, x):
        """ Applies the FFN transformation element-wise. """
        # x shape: (batch_size, seq_len, d_model)
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x) # Dropout often applied after activation
        x = self.linear2(x)
        # Output shape: (batch_size, seq_len, d_model)
        return x

### encoder layer

In [90]:

class EncoderLayer(nn.Module):
    """ Single Encoder layer combining Self-Attention and FFN. """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, relative_attention_num_buckets):
        super().__init__()
        # Sub-layers: Multi-Head Self-Attention and Position-wise Feed-Forward
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout, is_decoder=False, relative_attention_num_buckets=relative_attention_num_buckets)
        self.feed_forward = PositionwiseFeedForward(d_model, dim_feedforward, dropout)

        # Layer Normalization and Dropout for residual connections
        self.norm1 = nn.LayerNorm(d_model) # T5 might use RMSNorm instead
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """ Forward pass for the encoder layer. """
        # src shape: (batch_size, src_seq_len, d_model)

        # --- Self-attention block (Sublayer 1) ---
        # Calculate attention output
        attn_output = self.self_attn(src, src, src, # Query, Key, Value are all 'src'
                                      key_padding_mask=src_key_padding_mask,
                                      attn_mask=src_mask) # src_mask is rarely used unless specific masking needed
        # Residual connection + Dropout + LayerNorm (Post-LN style)
        src = src + self.dropout1(attn_output)
        src = self.norm1(src)
        # --- End Self-attention ---

        # --- Feed-forward block (Sublayer 2) ---
        # Pass through FFN
        ff_output = self.feed_forward(src)
        # Residual connection + Dropout + LayerNorm (Post-LN style)
        src = src + self.dropout2(ff_output)
        src = self.norm2(src)
        # --- End Feed-forward ---

        # Output shape: (batch_size, src_seq_len, d_model)
        return src

### decoder layer

In [91]:

class DecoderLayer(nn.Module):
    """ Single Decoder layer combining Masked Self-Attention, Cross-Attention, and FFN. """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, relative_attention_num_buckets):
        super().__init__()
        # Sub-layers: Masked Self-Attention, Cross-Attention, Position-wise Feed-Forward
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout, is_decoder=True, relative_attention_num_buckets=relative_attention_num_buckets)
        self.cross_attn = MultiHeadAttention(d_model, nhead, dropout, is_decoder=True, relative_attention_num_buckets=relative_attention_num_buckets)
        self.feed_forward = PositionwiseFeedForward(d_model, dim_feedforward, dropout)

        # Layer Normalization (one after each sublayer) and Dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory,
                tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, # <--- ADDED THIS ARGUMENT
                memory_key_padding_mask=None,
                return_attention_weights=False):
        """ Forward pass for the decoder layer. """
        # tgt shape: (batch_size, tgt_seq_len, d_model)
        # memory shape: (batch_size, src_seq_len, d_model) - Output from Encoder

        # --- Masked Self-attention block (Sublayer 1) ---
        # Attends to the decoder input sequence (`tgt`) itself, using causal mask.
        self_attn_output = self.self_attn(tgt, tgt, tgt, # Q, K, V = tgt
                                           key_padding_mask=tgt_key_padding_mask, # Mask padding in tgt
                                           attn_mask=tgt_mask) # Apply causal mask
        # Residual connection + Dropout + LayerNorm
        tgt = tgt + self.dropout1(self_attn_output)
        tgt = self.norm1(tgt)
        # --- End Masked Self-attention ---

        # --- Cross-attention block (Sublayer 2) ---
        # Attends to the encoder output (`memory`) using the decoder state (`tgt`) as query.
        cross_attn_output = self.cross_attn(tgt, memory, memory, # Query=tgt, Key=memory, Value=memory
                                             key_padding_mask=memory_key_padding_mask, # Mask padding in memory
                                             attn_mask=memory_mask) # Usually None
        # Residual connection + Dropout + LayerNorm
        tgt = tgt + self.dropout2(cross_attn_output)
        tgt = self.norm2(tgt)
        # --- End Cross-attention ---

        # --- Feed-forward block (Sublayer 3) ---
        ff_output = self.feed_forward(tgt)
        # Residual connection + Dropout + LayerNorm
        tgt = tgt + self.dropout3(ff_output)
        tgt = self.norm3(tgt)
        # --- End Feed-forward ---

        # Output shape: (batch_size, tgt_seq_len, d_model)
        return tgt

In [92]:
class Encoder(nn.Module):
    """ Stack of EncoderLayers. """
    def __init__(self, encoder_layer_config, num_layers, d_model, nhead, dim_feedforward, dropout, relative_attention_num_buckets, norm=True):
        super().__init__()
        # Create a list of independent EncoderLayer instances
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, nhead, dim_feedforward, dropout, relative_attention_num_buckets)
            for _ in range(num_layers)
        ])
        self.num_layers = num_layers
        # Optional final Layer Normalization after the stack
        self.norm = nn.LayerNorm(d_model) if norm else None

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """ Pass input through the stack of encoder layers. """
        output = src
        # Sequentially apply each layer
        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        # Apply final normalization if specified
        if self.norm is not None:
            output = self.norm(output)
        return output

In [93]:

class Decoder(nn.Module):
    """ Stack of DecoderLayers. """
    def __init__(self, decoder_layer_config, num_layers, d_model, nhead, dim_feedforward, dropout, relative_attention_num_buckets, norm=True):
        super().__init__()
         # Create a list of independent DecoderLayer instances
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, nhead, dim_feedforward, dropout, relative_attention_num_buckets)
            for _ in range(num_layers)
        ])
        self.num_layers = num_layers
        # Optional final Layer Normalization after the stack
        self.norm = nn.LayerNorm(d_model) if norm else None

    def forward(self, tgt, memory,
                tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """ Pass input and memory through the stack of decoder layers. """
        output = tgt
        # Sequentially apply each layer, passing memory and all masks
        for layer in self.layers:
            output = layer(output, memory,
                           tgt_mask=tgt_mask, memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask)

        # Apply final normalization if specified
        if self.norm is not None:
            output = self.norm(output)
        return output

### T5 Model (Putting it all together)

In [94]:

class T5Model(nn.Module):
    """ The main T5 model combining Encoder, Decoder, and Shared Embeddings. """
    def __init__(self, config: Config, tokenizer: SentencePieceTokenizer): # Pass tokenizer to get pad_id
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        # Get vocab size from config (set after tokenizer loaded)
        self.vocab_size = config.vocab_size
        # Get PAD ID directly from the tokenizer instance
        self.pad_token_id = tokenizer.pad_token_id

        if self.vocab_size is None or self.pad_token_id is None:
             raise ValueError("Vocab size or Pad token ID is not set in config/tokenizer.")

        print(f"Initializing T5Model with vocab_size={self.vocab_size} and pad_token_id={self.pad_token_id}")

        # --- Shared Embedding Layer ---
        # Uses the actual vocab size and the correct padding index from the tokenizer
        self.shared_embedding = nn.Embedding(self.vocab_size, config.d_model, padding_idx=self.pad_token_id)
        # Scale embeddings (common practice, helps stabilize)
        self.scale_emb = math.sqrt(self.d_model)

        # --- Dropout for embeddings ---
        self.dropout = nn.Dropout(config.dropout)

        # --- Encoder ---
        # Pass config values directly to the Encoder constructor
        encoder_norm = True # T5 typically uses a final norm in the encoder
        self.encoder = Encoder(
            encoder_layer_config={}, # Config passed directly below
            num_layers=config.num_encoder_layers,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            relative_attention_num_buckets=config.relative_attention_num_buckets,
            norm=encoder_norm
        )

        # --- Decoder ---
        decoder_norm = True # T5 typically uses a final norm in the decoder
        self.decoder = Decoder(
            decoder_layer_config={}, # Config passed directly below
            num_layers=config.num_decoder_layers,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            relative_attention_num_buckets=config.relative_attention_num_buckets,
            norm=decoder_norm
        )

        # --- Final Linear Layer (Output Head) ---
        # Projects decoder output (d_model) to vocabulary scores (vocab_size)
        self.lm_head = nn.Linear(config.d_model, self.vocab_size, bias=False)

        # --- Weight Tying (Crucial T5 Feature) ---
        # Share weights between input embeddings and final output layer
        self.lm_head.weight = self.shared_embedding.weight
        self._init_weights()
        print("T5 Model components initialized.")

    def _init_weights(self):
        """ Initializes model weights. """
        # Simple initialization scheme. More sophisticated methods exist (e.g., T5 uses truncated normal).
        initrange = 0.1
        # Initialize shared embeddings uniformly
        self.shared_embedding.weight.data.uniform_(-initrange, initrange)
        # lm_head weights are automatically initialized due to tying

        # Initialize linear layers (e.g., in Attention, FFN) using Xavier/Glorot Uniform
        # Initialize LayerNorm weights to 1, biases to 0
        for module in self.modules():
            if isinstance(module, nn.Linear):
                 nn.init.xavier_uniform_(module.weight)
                 if module.bias is not None: # Check if bias exists (we set bias=False in some layers)
                     nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                 nn.init.ones_(module.weight)
                 nn.init.zeros_(module.bias)
        print("Model weights initialized (Embeddings: Uniform, Linear: Xavier, LayerNorm: Default).")


    def forward(self, src_ids, tgt_ids, src_padding_mask=None, tgt_padding_mask=None, memory_key_padding_mask=None, tgt_mask=None):
        """
        Performs the forward pass of the T5 model.
        Args:
            src_ids (Tensor): Input token IDs for the encoder. Shape (batch_size, src_seq_len)
            tgt_ids (Tensor): Input token IDs for the decoder (shifted right). Shape (batch_size, tgt_seq_len)
            ... masks ...
        Returns:
            Tensor: Logits output from the decoder head. Shape (batch_size, tgt_seq_len, vocab_size)
        """
        # 1. Embeddings (Source and Target use the same embedding layer)
        src_emb = self.shared_embedding(src_ids) * self.scale_emb
        tgt_emb = self.shared_embedding(tgt_ids) * self.scale_emb

        # Apply dropout to embeddings
        src_emb = self.dropout(src_emb)
        tgt_emb = self.dropout(tgt_emb)

        # Note: Relative position bias is handled within the attention layers.

        # 2. Encoder Pass
        # Encodes the source sequence.
        # memory shape: (batch_size, src_seq_len, d_model)
        memory = self.encoder(src_emb, src_key_padding_mask=src_padding_mask)

        # 3. Decoder Pass
        # Takes encoder output (memory) and the shifted target sequence embeddings.
        # decoder_output shape: (batch_size, tgt_seq_len, d_model)
        decoder_output = self.decoder(tgt_emb, memory,
                                      tgt_mask=tgt_mask,                # Causal mask for decoder self-attention
                                      tgt_key_padding_mask=tgt_padding_mask, # Padding mask for decoder input tokens
                                      memory_key_padding_mask=memory_key_padding_mask) # Padding mask for encoder output (memory)

        # 4. Final Linear Layer (Logits)
        # Project decoder output to vocabulary space.
        # logits shape: (batch_size, tgt_seq_len, vocab_size)
        logits = self.lm_head(decoder_output)

        return logits

    # Helper function to create all necessary masks based on input token IDs
    def create_masks(self, src_ids, tgt_ids):
        """
        Generates padding masks and the target causal mask.
        Uses the model's internal pad_token_id.
        """
        device = src_ids.device
        # Source padding mask: True where src_ids is padding (ID = self.pad_token_id)
        src_padding_mask = (src_ids == self.pad_token_id) # (batch_size, src_seq_len)

        # Target padding mask: True where tgt_ids (decoder input) is padding
        tgt_padding_mask = (tgt_ids == self.pad_token_id) # (batch_size, tgt_seq_len)

        # Memory key padding mask (for cross-attention): Based on source padding.
        # True where the *memory* comes from a padded source token.
        memory_key_padding_mask = src_padding_mask # (batch_size, src_seq_len)

        # Target causal (look-ahead) mask: Prevents attention to future tokens in the target sequence.
        tgt_seq_len = tgt_ids.size(1)
        # Generates a square matrix where the upper triangle (future positions) is masked.
        # PyTorch's function returns float tensor with -inf for masked positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_seq_len, device=device)
        # Shape: (tgt_seq_len, tgt_seq_len)

        return src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask

In [70]:

if config.vocab_size is None or 'tokenizer' not in locals() or tokenizer is None:
    raise ValueError("Config vocab_size not set or tokenizer not loaded. Run data loading cells first.")

print(f"Initializing model with vocab size: {config.vocab_size}")
model = T5Model(config, tokenizer).to(config.device) # Pass tokenizer to get pad_id

# Print model parameter count for verification
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel created on {config.device}.")
print(f"Total Trainable Parameters: {num_params:,}")

# Optional: Print model structure summary (can be very long)
print("\nModel Structure:")
print(model)

Initializing model with vocab size: 4000
Initializing T5Model with vocab_size=4000 and pad_token_id=3
Model weights initialized (Embeddings: Uniform, Linear: Xavier, LayerNorm: Default).
T5 Model components initialized.

Model created on cuda.
Total Trainable Parameters: 46,156,288

Model Structure:
T5Model(
  (shared_embedding): Embedding(4000, 512, padding_idx=3)
  (dropout): Dropout(p=0.1, inplace=False)
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attn): MultiHeadAttention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=512, bias=False)
          (v_proj): Linear(in_features=512, out_features=512, bias=False)
          (out_proj): Linear(in_features=512, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
          (relative_position_bias): RelativePositionBias(
            (relative_attention_bias): Embedding(32, 8)
      

In [95]:
optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)
print(f"Optimizer: AdamW (lr={config.learning_rate}, weight_decay=0.01)")

# Loss Function: Cross Entropy Loss for classification (predicting next token)
# Crucially, ignore the padding token ID during loss calculation.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) # Use pad_id from tokenizer
print(f"Loss Function: CrossEntropyLoss (ignoring padding token ID: {tokenizer.pad_token_id})")

Optimizer: AdamW (lr=0.0001, weight_decay=0.01)
Loss Function: CrossEntropyLoss (ignoring padding token ID: 3)


In [96]:
# %% [markdown]
# ### 7.1 Training Epoch Function
# Defines the function to run one epoch of training.

# %%
def train_epoch(model, dataloader, optimizer, criterion, config, epoch_num):
    """Runs one training epoch."""
    model.train() # Set model to training mode (enables dropout, etc.)
    total_loss = 0.0
    num_batches = len(dataloader)

    # Wrap dataloader with tqdm for a progress bar
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch_num+1}/{config.epochs} Training", leave=False, unit="batch")

    for batch_idx, (input_ids, target_ids) in enumerate(progress_bar):
        # Move data to the configured device
        # input_ids = context, target_ids = sequence shifted left by one
        input_ids = input_ids.to(config.device)   # Shape: (batch_size, max_seq_len)
        target_ids = target_ids.to(config.device) # Shape: (batch_size, max_seq_len)

        # --- Prepare inputs for T5 LM ---
        # Encoder input = `input_ids`
        # Decoder input = `input_ids` (acts as the shifted-right input for prediction)
        # Target for loss = `target_ids` (the actual next tokens)

        # Create masks based on the *input* sequences using the model's helper method
        # Note: We pass input_ids for both src and tgt mask generation in this LM setup
        src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask = \
            model.create_masks(src_ids=input_ids, tgt_ids=input_ids)

        # --- Training Step ---
        optimizer.zero_grad() # Reset gradients from previous batch

        # Forward pass: Get model predictions (logits)
        logits = model(src_ids=input_ids, tgt_ids=input_ids, # Use same sequence for enc/dec input
                       src_padding_mask=src_padding_mask,
                       tgt_padding_mask=tgt_padding_mask, # Mask padding in decoder input
                       memory_key_padding_mask=memory_key_padding_mask, # Mask padding in encoder output
                       tgt_mask=tgt_mask) # Apply causal mask in decoder
        # Logits shape: (batch_size, max_seq_len, vocab_size)

        # Calculate loss
        # Reshape logits to (N, C) and targets to (N) for CrossEntropyLoss
        loss = criterion(logits.view(-1, config.vocab_size), target_ids.view(-1))

        # Backward pass: Compute gradients
        loss.backward()

        # Gradient clipping: Prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad_norm)

        # Optimizer step: Update model weights
        optimizer.step()
        # --- End Training Step ---

        # --- Logging ---
        current_loss = loss.item() # Get scalar loss value
        total_loss += current_loss
        # Update progress bar postfix with current and average loss
        progress_bar.set_postfix(loss=f"{current_loss:.4f}", avg_loss=f"{total_loss / (batch_idx + 1):.4f}")

    # Return average loss for the epoch
    return total_loss / num_batches

### training

In [99]:

# Check if dataset is populated before starting training
if 'dataset' not in locals() or len(dataset) == 0:
    print("ERROR: Dataset not created or is empty. Cannot start training.")
    print("Please ensure the data loading and tokenization steps ran successfully.")
elif 'model' not in locals() or model is None:
    print("ERROR: Model not initialized. Cannot start training.")
else:
    print("Starting training...")
    print(f"  Epochs: {config.epochs}")
    print(f"  Batch size: {config.batch_size}")
    print(f"  Max sequence length (tokens): {config.max_seq_len}")
    print(f"  Number of batches per epoch: {len(dataloader)}")
    print(f"  Device: {config.device}")
    train_start_time = time.time()

    training_losses = [] # Store average loss per epoch

    # --- Main Training Loop ---
    for epoch in range(config.epochs):
        epoch_start_time = time.time()

        # Run one epoch of training
        avg_train_loss = train_epoch(model, dataloader, optimizer, criterion, config, epoch)
        training_losses.append(avg_train_loss)

        epoch_duration = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}/{config.epochs} | Train Loss: {avg_train_loss:.4f} | Time: {epoch_duration:.2f}s")

        # Save model periodically to avoid losing progress on long runs
        if (epoch + 1) % 5 == 0 or epoch == config.epochs - 1: # Save every 5 epochs and at the end
            chkpt_path = f"t5_{config.sp_model_prefix}_epoch_{epoch+1}.pt"
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_train_loss,
                'config': vars(config), # Save config for reproducibility
                'sp_model_prefix': config.sp_model_prefix # Need this to load tokenizer later
            }, chkpt_path)
            print(f"Checkpoint saved to {chkpt_path}")

    # --- Training Finished ---
    train_duration = time.time() - train_start_time
    print(f"\nTraining finished.")
    print(f"Total Training Time: {train_duration // 60:.0f}m {train_duration % 60:.0f}s")
    if training_losses:
        print(f"Final Average Training Loss: {training_losses[-1]:.4f}")
    final_model_path = f"t5_{config.sp_model_prefix}_final.pt"
    try:
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': vars(config), 
            'sp_model_prefix': config.sp_model_prefix 
        }, final_model_path)
        print(f"\nFinal model state dictionary saved to '{final_model_path}'")
        print(f"-> IMPORTANT: Keep SentencePiece files '{config.sp_model_prefix}.model' and '{config.sp_model_prefix}.vocab' alongside this checkpoint file to be able to reload the model and tokenizer later.")
    except Exception as e:
        print(f"\nError saving final model: {e}")

Starting training...
  Epochs: 30
  Batch size: 32
  Max sequence length (tokens): 128
  Number of batches per epoch: 1357
  Device: cuda


Epoch 1/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 1/30 | Train Loss: 0.6501 | Time: 342.31s


Epoch 2/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 2/30 | Train Loss: 0.1604 | Time: 342.24s


Epoch 3/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 3/30 | Train Loss: 0.0788 | Time: 342.30s


Epoch 4/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 4/30 | Train Loss: 0.0530 | Time: 342.13s


Epoch 5/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 5/30 | Train Loss: 0.0401 | Time: 342.54s
Checkpoint saved to t5_ts_spm_bpe_epoch_5.pt


Epoch 6/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 6/30 | Train Loss: 0.0319 | Time: 341.87s


Epoch 7/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 7/30 | Train Loss: 0.0269 | Time: 341.94s


Epoch 8/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 8/30 | Train Loss: 0.0231 | Time: 341.90s


Epoch 9/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 9/30 | Train Loss: 0.0200 | Time: 341.95s


Epoch 10/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 10/30 | Train Loss: 0.0178 | Time: 341.98s
Checkpoint saved to t5_ts_spm_bpe_epoch_10.pt


Epoch 11/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 11/30 | Train Loss: 0.0162 | Time: 343.16s


Epoch 12/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 12/30 | Train Loss: 0.0145 | Time: 343.55s


Epoch 13/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 13/30 | Train Loss: 0.0133 | Time: 343.51s


Epoch 14/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 14/30 | Train Loss: 0.0121 | Time: 342.71s


Epoch 15/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 15/30 | Train Loss: 0.0113 | Time: 342.44s
Checkpoint saved to t5_ts_spm_bpe_epoch_15.pt


Epoch 16/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 16/30 | Train Loss: 0.0103 | Time: 342.71s


Epoch 17/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 17/30 | Train Loss: 0.0096 | Time: 342.71s


Epoch 18/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 18/30 | Train Loss: 0.0091 | Time: 342.57s


Epoch 19/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 19/30 | Train Loss: 0.0084 | Time: 343.31s


Epoch 20/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 20/30 | Train Loss: 0.0079 | Time: 343.29s
Checkpoint saved to t5_ts_spm_bpe_epoch_20.pt


Epoch 21/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 21/30 | Train Loss: 0.0071 | Time: 343.66s


Epoch 22/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 22/30 | Train Loss: 0.0067 | Time: 343.30s


Epoch 23/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 23/30 | Train Loss: 0.0062 | Time: 342.26s


Epoch 24/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 24/30 | Train Loss: 0.0058 | Time: 341.99s


Epoch 25/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 25/30 | Train Loss: 0.0055 | Time: 341.70s
Checkpoint saved to t5_ts_spm_bpe_epoch_25.pt


Epoch 26/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 26/30 | Train Loss: 0.0052 | Time: 341.54s


Epoch 27/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 27/30 | Train Loss: 0.0048 | Time: 341.99s


Epoch 28/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 28/30 | Train Loss: 0.0048 | Time: 341.70s


Epoch 29/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 29/30 | Train Loss: 0.0045 | Time: 342.14s


Epoch 30/30 Training:   0%|          | 0/1357 [00:00<?, ?batch/s]

Epoch 30/30 | Train Loss: 0.0041 | Time: 342.31s
Checkpoint saved to t5_ts_spm_bpe_epoch_30.pt

Training finished.
Total Training Time: 171m 18s
Final Average Training Loss: 0.0041

Final model state dictionary saved to 't5_ts_spm_bpe_final.pt'
-> IMPORTANT: Keep SentencePiece files 'ts_spm_bpe.model' and 'ts_spm_bpe.vocab' alongside this checkpoint file to be able to reload the model and tokenizer later.


### generate examples

In [100]:

def generate(model, tokenizer: SentencePieceTokenizer, prompt, max_length=50, config=config,
             temperature=1.0, top_p=0.9): # Added temp and top_p
    """Generates text autoregressively using nucleus sampling."""
    model.eval()
    device = config.device
    pad_token_id = tokenizer.pad_token_id
    bos_token_id = tokenizer.bos_token_id
    eos_token_id = tokenizer.eos_token_id
    d_model = config.d_model
    scale_emb = math.sqrt(d_model)

    print(f"\n--- Generating from prompt: '{prompt}' (temp={temperature}, top_p={top_p}) ---")

    # Prepare prompt
    prompt_tokens_ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    input_ids = torch.tensor([prompt_tokens_ids], dtype=torch.long).to(device)

    # --- Encoder Pass (Once) ---
    with torch.no_grad():
        src_padding_mask = (input_ids == pad_token_id)
        src_emb = model.shared_embedding(input_ids) * scale_emb
        memory = model.encoder(src_emb, src_key_padding_mask=src_padding_mask)
        memory_key_padding_mask = src_padding_mask

    # --- Decoder Autoregressive Loop ---
    decoder_input_ids = torch.tensor([[bos_token_id]], dtype=torch.long).to(device) # Start with BOS
    generated_token_ids = []

    print(f"Generating (max_length={max_length}): ", end='')
    for i in range(max_length):
        with torch.no_grad():
            tgt_seq_len = decoder_input_ids.size(1)
            _, tgt_padding_mask, _, tgt_mask = model.create_masks(decoder_input_ids, decoder_input_ids)

            # --- Get Logits ---
            # Use the simple model forward (no attention weights needed)
            logits = model(src_ids=input_ids, tgt_ids=decoder_input_ids,
                           src_padding_mask=src_padding_mask,
                           tgt_padding_mask=tgt_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           tgt_mask=tgt_mask)

            last_token_logits = logits[:, -1, :] # Get logits for the very last position

            # Apply temperature
            if temperature != 1.0:
                last_token_logits = last_token_logits / temperature

            # Calculate probabilities
            probs = F.softmax(last_token_logits, dim=-1) # Shape: (1, vocab_size)

            # Sort probabilities and indices
            sorted_probs, sorted_indices = torch.sort(probs, descending=True) # Shape: (1, vocab_size)

            # Calculate cumulative probabilities
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # Shape: (1, vocab_size)

            # Create mask for probabilities NOT in the nucleus
            sorted_indices_to_remove = cumulative_probs > top_p # Shape: (1, vocab_size)
            # Shift the mask to the right to keep the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0 # Always keep the most likely token

            # Create a mask in the original indices space
            indices_to_remove = torch.zeros_like(probs, dtype=torch.bool).scatter_(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)

            # Zero out probabilities of tokens to remove (outside the nucleus)
            probs[indices_to_remove] = 0
            # torch.multinomial requires non-negative weights summing to non-zero.
            # If all probabilities become zero (e.g., top_p=0), multinomial will error.
            # We can handle this edge case, but with reasonable top_p it's unlikely.
            # Also, multinomial works with unnormalized weights, so explicit re-normalization isn't strictly required.
            if torch.sum(probs) == 0: # Handle case where all probs got zeroed (e.g. top_p=0)
                 next_token_id = torch.argmax(last_token_logits, dim=-1).unsqueeze(0) # Fallback to greedy
                 print("[Warning: All probs zeroed, falling back to greedy] ", end='')
            else:
                 next_token_id = torch.multinomial(probs, num_samples=1) # Shape: (1, 1)


            # --- Stopping Condition ---
            if next_token_id.item() == eos_token_id:
                print(" (EOS)", end='')
                break

            # --- Store and Update ---
            generated_token_ids.append(next_token_id.item())
            # Detach is not strictly needed here as we are in no_grad context, but good practice
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.detach()], dim=1)
            try:
                 piece = tokenizer.processor.id_to_piece(next_token_id.item())
                 print(piece.replace(' ', ' '), end='', flush=True)
            except IndexError:
                 print(f" [Invalid ID: {next_token_id.item()}] ", end='', flush=True)

    print() 

    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
    return generated_text

In [101]:

if 'model' in locals() and model is not None and \
   'tokenizer' in locals() and tokenizer is not None:
    print("\n--- Running Generation Examples (with Nucleus Sampling) ---")

    prompts = [
        "Taylor Swift was born in",
        "The album Fearless featured the song",
        "Folklore and Evermore explored",
        "Her songwriting is known for",
        "The Eras Tour became the"
    ]
    generation_max_len = 75 # Number of tokens to generate

    sampling_temp = 0.7   # Lower -> more deterministic; Higher -> more random
    sampling_top_p = 0.9  # Probability mass to keep (e.g., 0.9 means keep top 90%)

    for prompt in prompts:
        start_gen_time = time.time()
        # Call generate with sampling parameters
        generated_text = generate(model, tokenizer, prompt,
                                  max_length=generation_max_len,
                                  config=config,
                                  temperature=sampling_temp,
                                  top_p=sampling_top_p)
        gen_duration = time.time() - start_gen_time
        print(f"-> Generated ({gen_duration:.2f}s): '{generated_text}'")
        print(f"-> Full Text: '{prompt}{generated_text}'")
        print("-" * 30)

else:
    print("\nModel or tokenizer not found or not initialized.")
    print("Please ensure the previous cells, including training, have been run successfully.")


--- Running Generation Examples (with Nucleus Sampling) ---

--- Generating from prompt: 'Taylor Swift was born in' (temp=0.7, top_p=0.9) ---
Generating (max_length=75): ▁Casticles:▁Taylor▁Swift's▁'Miss▁Americana'▁Is▁What▁Youth▁Annual▁Grammy▁Awards".▁Billboard.▁Archived▁from▁the▁original▁on▁May▁13,▁2020.▁Retrieved▁May▁13,▁2020.▁"9▁Things▁You▁Might▁Have▁Missed▁in▁Taylor▁Swift's▁Netflix▁Concert▁Film".▁E!▁News.▁December▁28,▁2021.▁Retrieved▁May▁13,▁2020.▁Hiatt,▁Brian▁(September▁30,▁2019).▁"
-> Generated (1.71s): 'Casticles: Taylor Swift's 'Miss Americana' Is What Youth Annual Grammy Awards". Billboard. Archived from the original on May 13, 2020. Retrieved May 13, 2020. "9 Things You Might Have Missed in Taylor Swift's Netflix Concert Film". E! News. December 28, 2021. Retrieved May 13, 2020. Hiatt, Brian (September 30, 2019). "'
-> Full Text: 'Taylor Swift was born inCasticles: Taylor Swift's 'Miss Americana' Is What Youth Annual Grammy Awards". Billboard. Archived from the original on Ma