In [None]:
 #!/usr/bin/env python3
"""
Foundational AI Project 2 – Language Modeling with RNNs, LSTMs, and Transformer (Graduate Version)

This script trains three language models (RNN, LSTM, Transformer) for text generation.
It uses a SentencePiece BPE tokenizer (vocab size=10000) to tokenize text from JSONL files,
builds fixed-length sequences via a sliding window approach, and trains the models
using early stopping with a cosine annealing learning rate scheduler.
Evaluation metrics include perplexity (exp(cross-entropy loss)), token accuracy,
and BLEU score (computed with nltk). Sample outputs and loss curves with detailed plots
are generated, and model performance is compared.
Graduate-level requirements such as temperature-based decoding are supported in the prompt methods.
"""

import os
import math
import json
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR  # Cosine learning rate scheduler
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm  # for subword tokenization
import matplotlib.pyplot as plt
import nltk
from nltk.translate.bleu_score import sentence_bleu

# Fix the seed of every random number generator used below so runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"[✓] Using device: {device}")

###############################################################################
# Positional Encoding Module
###############################################################################
class PositionalEncoding(nn.Module):
    """
    Computes sinusoidal positional encodings and adds them to token embeddings.
    This implementation follows 'Attention is All You Need' (Vaswani et al., 2017).

    Both even and odd embedding dimensions are supported.

    Args:
        embed_dim (int): Embedding dimension (d_model).
        max_len (int): Maximum sequence length (number of positions).
        dropout (float): Dropout probability applied after adding positional encodings.
    """
    def __init__(self, embed_dim, max_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Table of encodings, one row per position: shape [max_len, embed_dim].
        pos_enc = torch.zeros(max_len, embed_dim)
        # Column vector of positions [0, 1, ..., max_len-1].
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per (sin, cos) pair; ceil(embed_dim / 2) entries.
        div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float) * -(math.log(10000.0) / embed_dim))
        angles = position * div_term
        # Sine fills the even columns.
        pos_enc[:, 0::2] = torch.sin(angles)
        # Cosine fills the odd columns. When embed_dim is odd there is one fewer
        # odd column than even column, so the last frequency is dropped here.
        pos_enc[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])
        # Add a leading batch dimension: [1, max_len, embed_dim].
        pos_enc = pos_enc.unsqueeze(0)
        # A buffer is saved with the model and follows it across devices,
        # but is not a trainable parameter.
        self.register_buffer("pos_enc", pos_enc)

    def forward(self, x):
        """
        Adds positional encoding to input tensor x.

        Args:
            x (Tensor): Input embeddings of shape [batch_size, seq_length, embed_dim].
        Returns:
            Tensor: Output embeddings with added positional encoding.
        Raises:
            ValueError: If seq_length exceeds the max_len the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_enc.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {self.pos_enc.size(1)}."
            )
        # Slice the table to the input length; broadcasting covers the batch dimension.
        x = x + self.pos_enc[:, :seq_len]
        return self.dropout(x)

###############################################################################
# Data Preparation Functions
###############################################################################
def train_tokenizer_if_needed(tokenizer_model_prefix="tokenizer", vocab_size=10000, training_text_file="train.txt"):
    """
    Trains a SentencePiece tokenizer if the model file is not found.
    The tokenizer is used to encode text into subword tokens.

    The training corpus is handed to SentencePiece directly by path, so no
    intermediate copy of the corpus is created on disk or held in memory.

    Args:
        tokenizer_model_prefix (str): Prefix for the tokenizer model filename.
        vocab_size (int): Subword vocabulary size.
        training_text_file (str): Plain text file used to train the tokenizer.

    Returns:
        SentencePieceProcessor: The trained tokenizer.

    Raises:
        FileNotFoundError: If training is required but training_text_file does not exist.
    """
    model_path = f"{tokenizer_model_prefix}.model"
    if not os.path.exists(model_path):
        # Fail early with a clear error instead of an opaque trainer failure.
        if not os.path.isfile(training_text_file):
            raise FileNotFoundError(f"Tokenizer training file not found: {training_text_file}")
        print("Training tokenizer...")
        # Train the SentencePiece model using BPE.
        spm.SentencePieceTrainer.train(
            input=training_text_file,
            model_prefix=tokenizer_model_prefix,
            vocab_size=vocab_size,
            model_type="bpe",
            character_coverage=1.0
        )
    # Load the (existing or freshly trained) tokenizer model.
    return spm.SentencePieceProcessor(model_file=model_path)

def load_and_tokenize(file_path, sp):
    """
    Loads text data from a JSONL file and tokenizes it.
    Each JSON object should have "prompt" and "completion" fields.

    Lines that are not valid JSON, or that parse to something other than a
    JSON object, are skipped. Missing or null fields are treated as empty
    strings, and non-string field values are converted with str().

    Args:
        file_path (str): Path to the JSONL file.
        sp (SentencePieceProcessor): Trained SentencePiece tokenizer.

    Returns:
        list[int]: List of token IDs from the combined text.
    """
    texts = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue  # Skip lines that are not valid JSON
            if not isinstance(obj, dict):
                continue  # Valid JSON but not an object (e.g. a list or a number)
            prompt = obj.get("prompt")
            completion = obj.get("completion")
            prompt = "" if prompt is None else str(prompt)
            completion = "" if completion is None else str(completion)
            # Combine prompt and completion into a single training text.
            text = (prompt + " " + completion).strip()
            if text:
                texts.append(text)
    # Combine all texts into one large text separated by newlines.
    combined = "\n".join(texts)
    print(f"Loaded {len(texts)} text entries from {file_path}. Total length: {len(combined)} characters")
    # Tokenize the combined text using SentencePiece.
    return sp.encode(combined, out_type=int)

def build_sequences(token_ids, seq_length, stride=1):
    """
    Creates overlapping sequences from token IDs using a sliding window.
    Each sequence will have (seq_length + 1) tokens to allow input/target pairing.

    With the default stride of 1 one window is produced per token, so the
    result holds roughly len(token_ids) * (seq_length + 1) elements. A larger
    stride reduces the overlap between windows and the memory used.

    Args:
        token_ids (list[int]): List of token IDs.
        seq_length (int): Desired input sequence length.
        stride (int): Step between the start positions of consecutive windows.

    Returns:
        list[list[int]]: List where each sublist is a token sequence of length seq_length+1.
        Empty when token_ids has fewer than seq_length+1 tokens.

    Raises:
        ValueError: If stride is smaller than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    window = seq_length + 1
    # The last valid start index is len(token_ids) - window, hence the exclusive bound.
    last_start = len(token_ids) - seq_length
    return [token_ids[start:start + window] for start in range(0, last_start, stride)]

###############################################################################
# Custom Dataset Class for Language Modeling
###############################################################################
class LanguageModelDataset(Dataset):
    """
    Custom dataset for language modeling.
    Each sample is a tuple: (input_tokens, target_tokens) where target_tokens is the input shifted by one.

    Only sequences with exactly seq_length + 1 tokens are kept. The sequences
    are stored once and the input/target slices are taken when an item is
    requested, rather than keeping two extra copies of every sequence.
    The kept sequences are referenced, not copied, so callers should not
    mutate them after constructing the dataset.
    """
    def __init__(self, sequences, seq_length):
        window = seq_length + 1
        # Filter sequences to ensure they have the exact required length.
        self._sequences = [seq for seq in sequences if len(seq) == window]

    @property
    def samples(self):
        # Backward-compatible view of all (input, target) pairs, built on demand.
        return [(seq[:-1], seq[1:]) for seq in self._sequences]

    def __len__(self):
        return len(self._sequences)

    def __getitem__(self, idx):
        seq = self._sequences[idx]
        # Input drops the last token; target drops the first (next-token prediction).
        return torch.tensor(seq[:-1], dtype=torch.long), torch.tensor(seq[1:], dtype=torch.long)

###############################################################################
# Text Generation and Model Definitions
###############################################################################
def generate_text(model, tokenizer, prompt_text, max_length=128, temperature=1.0, max_context=None):
    """
    Generates text by autoregressively sampling from the model.

    The model is switched to evaluation mode for the duration of the call and
    its previous training/eval mode is restored afterwards.

    Args:
        model (nn.Module): The trained language model.
        tokenizer (SentencePieceProcessor): Tokenizer to encode/decode text.
        prompt_text (str): The initial text prompt.
        max_length (int): Maximum number of tokens to generate (excluding the prompt tokens).
        temperature (float): Sampling temperature, where values near 0 imply greedy decoding.
        max_context (int, optional): Maximum number of most recent tokens fed to the
            model at each step. When None, the limit is taken from the size of the
            model's `pos_encoder.pos_enc` table if the model has one; otherwise the
            full history is used.

    Returns:
        str: The generated text (decoded), including the prompt.

    Raises:
        ValueError: If the prompt encodes to zero tokens or max_context is smaller than 1.
    """
    # Encode prompt text to token IDs.
    token_ids = tokenizer.encode(prompt_text, out_type=int)
    if not token_ids:
        raise ValueError("prompt_text produced no tokens; cannot start generation.")
    generated = list(token_ids)  # Copy prompt tokens to start generation
    device = next(model.parameters()).device

    if max_context is None:
        # Models with a fixed positional table cannot accept inputs longer than
        # that table, so use its length as the context limit.
        pos_table = getattr(getattr(model, "pos_encoder", None), "pos_enc", None)
        if pos_table is not None:
            max_context = pos_table.size(1)
    elif max_context < 1:
        raise ValueError(f"max_context must be >= 1, got {max_context}")

    eos_id = tokenizer.eos_id()
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_length):
                # Feed only the most recent tokens when a context limit applies.
                context = generated if max_context is None else generated[-max_context:]
                # Input tensor of shape [1, current_length].
                input_ids = torch.tensor([context], dtype=torch.long, device=device)
                logits = model(input_ids)  # shape: [1, seq_len, vocab_size]
                next_logits = logits[0, -1, :]  # Logits for the last time step
                if temperature < 1e-5:
                    # Greedy: take the token with maximum probability.
                    next_token = torch.argmax(next_logits).item()
                else:
                    # Temperature-scaled sampling from the softmax distribution.
                    probs = torch.softmax(next_logits / temperature, dim=0)
                    next_token = torch.multinomial(probs, num_samples=1).item()
                generated.append(next_token)
                # Stop if end-of-sequence token is produced.
                if next_token == eos_id:
                    break
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    # Decode the generated token IDs back into text.
    return tokenizer.decode(generated)

class RNNLanguageModel(nn.Module):
    """
    Vanilla RNN-based language model.

    Architecture:
      - Embedding layer: maps token IDs to a continuous representation.
      - RNN layers: processes the sequence.
      - Fully connected layer: projects RNN outputs to vocabulary logits.

    Includes a prompt method that supports temperature-based generation.

    Args:
        vocab_size (int): Number of tokens in the vocabulary.
        embed_dim (int): Embedding dimension.
        hidden_dim (int): Hidden state size of the RNN.
        num_layers (int): Number of stacked RNN layers.
        dropout (float): Dropout between stacked RNN layers. Only used when
            num_layers > 1, since nn.RNN applies dropout between layers and
            warns when it is non-zero for a single layer.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)  # Map token indices to embeddings
        # Inter-layer dropout has no effect with a single layer; pass 0 to avoid the warning.
        rnn_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = nn.RNN(embed_dim, hidden_dim, num_layers, batch_first=True, dropout=rnn_dropout)  # RNN layer(s)
        self.fc = nn.Linear(hidden_dim, vocab_size)  # Output projection layer

    def forward(self, x):
        # x has shape [batch_size, seq_length]
        embeds = self.embedding(x)  # [batch, seq_length, embed_dim]
        output, _ = self.rnn(embeds)  # [batch, seq_length, hidden_dim]
        logits = self.fc(output)  # [batch, seq_length, vocab_size]
        return logits

    def prompt(self, tokenizer, prompt_text, max_length=128, temperature=1.0):
        # Generate text using the generate_text function
        return generate_text(self, tokenizer, prompt_text, max_length, temperature)

class LSTMLanguageModel(nn.Module):
    """
    Language model built on stacked nn.LSTM layers.

    Token IDs are embedded, passed through the LSTM stack, and each hidden
    state is projected to vocabulary logits. The layout mirrors
    RNNLanguageModel with nn.LSTM in place of nn.RNN.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        # Token ID -> dense vector.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        # Batch-first LSTM stack over the embedded sequence.
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        # Hidden state -> vocabulary logits.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        # x: [batch, seq_length] -> logits: [batch, seq_length, vocab_size]
        hidden_states, _final_state = self.lstm(self.embedding(x))
        return self.fc(hidden_states)

    def prompt(self, tokenizer, prompt_text, max_length=128, temperature=1.0):
        # Delegate sampling to the shared generation routine.
        return generate_text(self, tokenizer, prompt_text, max_length, temperature)

class TransformerLanguageModel(nn.Module):
    """
    Transformer-based language model.

    Architecture:
      - Embedding layer followed by positional encoding.
      - Transformer encoder with a causal (look-ahead) attention mask.
      - Fully-connected layer to output vocabulary logits.

    Positional encoding ensures that token order is captured. The causal mask
    prevents each position from attending to later positions; without it the
    model could read the very token it is trained to predict.

    Args:
        vocab_size (int): Number of tokens in the vocabulary.
        embed_dim (int): Embedding dimension (d_model).
        num_heads (int): Number of attention heads.
        hidden_dim (int): Feed-forward dimension inside each encoder layer.
        num_layers (int): Number of stacked encoder layers.
        max_seq_length (int): Maximum sequence length supported by the positional encoding.
        dropout (float): Dropout used in the positional encoding and encoder layers.
    """
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, max_seq_length, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Incorporate positional encoding to add sequence order information.
        self.pos_encoder = PositionalEncoding(embed_dim, max_len=max_seq_length, dropout=dropout)
        # Define a single Transformer encoder layer and stack multiple layers.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                                   dim_feedforward=hidden_dim, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        # Convert token indices to embeddings.
        embeds = self.embedding(x)  # [batch, seq_length, embed_dim]
        # Add positional encodings.
        encoded = self.pos_encoder(embeds)
        # The encoder layers were built without batch_first, so they expect
        # [seq_length, batch, embed_dim]; transpose accordingly.
        encoded = encoded.transpose(0, 1)
        # Boolean causal mask: True above the diagonal marks positions that
        # must NOT be attended to (i.e. future tokens).
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        # Pass through Transformer encoder with the causal mask applied.
        transformer_out = self.transformer_encoder(encoded, mask=causal_mask)
        # Transpose back to [batch, seq_length, embed_dim].
        transformer_out = transformer_out.transpose(0, 1)
        # Project to vocabulary logits.
        logits = self.fc(transformer_out)
        return logits

    def prompt(self, tokenizer, prompt_text, max_length=128, temperature=1.0):
        return generate_text(self, tokenizer, prompt_text, max_length, temperature)

###############################################################################
# Training, Evaluation, and Plotting Functions
###############################################################################
def train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer, scheduler, device, patience=5):
    """
    Trains the model for a maximum number of epochs with early stopping.
    
    Uses mini-batch gradient descent with gradient clipping and a cosine annealing scheduler.
    
    Args:
        model (nn.Module): The language model.
        train_loader (DataLoader): Training data loader.
        val_loader (DataLoader): Validation data loader.
        num_epochs (int): Maximum epochs.
        criterion: Loss function (CrossEntropyLoss).
        optimizer: Optimizer (e.g., AdamW).
        scheduler: CosineAnnealingLR scheduler.
        device: Computation device.
        patience (int): Number of epochs to wait without improvement for early stopping.
        
    Returns:
        tuple: Lists of training and validation losses per epoch.
    """
    model.to(device)
    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    best_model_state = None
    epochs_no_improve = 0

    for epoch in range(1, num_epochs + 1):
        model.train()  # Set model to training mode
        total_train_loss = 0
        epoch_start = time.time()

        # Training loop over mini-batches
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()  # Zero the gradients for this batch
            outputs = model(inputs)  # Forward pass: shape [batch, seq_length, vocab_size]
            # Compute loss by flattening outputs and targets
            loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
            loss.backward()  # Backpropagate loss
            # Clip gradients to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()  # Update model parameters
            total_train_loss += loss.item()
        
        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation evaluation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        # Step the LR scheduler after each epoch
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        epoch_duration = time.time() - epoch_start
        print(f"Epoch {epoch} | LR: {current_lr:.6f} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Duration: {epoch_duration:.2f}s")

        # Early stopping mechanism: if validation loss doesn't improve, increment counter
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            best_model_state = {k: v.cpu() for k, v in model.state_dict().items()}
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping triggered. Restoring best model state.")
                model.load_state_dict(best_model_state)
                break

    return train_losses, val_losses

def evaluate_model(model, data_loader, criterion, device):
    """
    Computes the mean per-batch loss of a model over a dataset.

    The model is put into evaluation mode and gradients are disabled while
    the batches are processed.

    Args:
        model (nn.Module): Trained language model.
        data_loader (DataLoader): DataLoader for evaluation.
        criterion: Loss function.
        device: Computation device.

    Returns:
        float: Average loss computed over all batches.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch_inputs, batch_targets in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            logits = model(batch_inputs)
            # Flatten to [batch * seq_length, vocab_size] vs [batch * seq_length].
            flat_logits = logits.view(-1, logits.size(-1))
            flat_targets = batch_targets.view(-1)
            batch_losses.append(criterion(flat_logits, flat_targets).item())
    return sum(batch_losses) / len(data_loader)

def compute_perplexity(loss):
    """
    Converts an average cross-entropy loss into perplexity.

    Args:
        loss (float): Average cross-entropy loss.

    Returns:
        float: Perplexity, i.e. the exponential of the loss.
    """
    perplexity = np.exp(loss)
    return perplexity

def compute_token_accuracy(model, data_loader, device, ignore_index=None):
    """
    Computes token-level accuracy over the entire dataset.

    Args:
        model (nn.Module): Trained language model.
        data_loader (DataLoader): DataLoader for the dataset.
        device: Computation device.
        ignore_index (int, optional): Target token ID to exclude from both the
            numerator and the denominator (e.g. a padding ID). When None,
            every target token is counted.

    Returns:
        float: Fraction of counted tokens correctly predicted, or 0.0 when no
        tokens were counted.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Highest-scoring token at every position.
            predictions = torch.argmax(outputs, dim=-1)
            matches = predictions == targets
            if ignore_index is None:
                total += targets.numel()
            else:
                counted = targets != ignore_index
                matches = matches & counted
                total += counted.sum().item()
            correct += matches.sum().item()
    # Avoid dividing by zero for an empty loader or fully ignored targets.
    return correct / total if total else 0.0

def compute_bleu(reference, candidate):
    """
    Computes BLEU score comparing a reference sentence to a generated candidate sentence.

    Smoothing is applied so that a missing higher-order n-gram match does not
    collapse the sentence-level score to zero. If the NLTK tokenizer data is
    not installed, whitespace tokenization is used instead.

    Args:
        reference (str): Ground truth sentence.
        candidate (str): Generated sentence.

    Returns:
        float: BLEU score (0 to 1). 0.0 when either text is empty or whitespace only.
    """
    # Nothing to compare: no n-grams can match.
    if not reference.strip() or not candidate.strip():
        return 0.0

    from nltk.translate.bleu_score import SmoothingFunction

    try:
        reference_tokens = nltk.word_tokenize(reference)
        candidate_tokens = nltk.word_tokenize(candidate)
    except LookupError:
        # word_tokenize needs downloadable tokenizer data; fall back to a plain split.
        reference_tokens = reference.split()
        candidate_tokens = candidate.split()

    return sentence_bleu(
        [reference_tokens],
        candidate_tokens,
        smoothing_function=SmoothingFunction().method1,
    )

def plot_loss_curve(train_losses, val_losses, model_name):
    """
    Plots the training and validation loss curves with annotations.

    The figure is saved as "<model_name>_loss.png", shown, and then closed so
    that repeated calls do not accumulate open figures.

    Args:
        train_losses (list[float]): Training losses per epoch.
        val_losses (list[float]): Validation losses per epoch.
        model_name (str): Name of the model (used for title and saving the plot).

    Raises:
        ValueError: If either loss list is empty.
    """
    if not train_losses or not val_losses:
        raise ValueError("train_losses and val_losses must each contain at least one value.")

    fig = plt.figure(figsize=(10, 6))
    epochs = range(1, len(train_losses) + 1)
    # Plot training loss with marker circle
    plt.plot(epochs, train_losses, label="Train Loss", marker='o', linestyle='-', linewidth=2)
    # Plot validation loss with square marker
    plt.plot(epochs, val_losses, label="Validation Loss", marker='s', linestyle='-', linewidth=2)
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("Loss", fontsize=12)
    plt.title(f"{model_name} Loss Curve", fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    # Place the labels a few epochs left of the final point, but never left of
    # the first epoch (short runs would otherwise push the text off the axes).
    text_x = max(epochs[0], epochs[-1] - 5)
    # Annotate final loss values with arrows
    plt.annotate(f"Final Train: {train_losses[-1]:.4f}", xy=(epochs[-1], train_losses[-1]),
                 xytext=(text_x, train_losses[-1]+0.05),
                 arrowprops=dict(facecolor='blue', shrink=0.05),
                 fontsize=10, color='blue')
    plt.annotate(f"Final Val: {val_losses[-1]:.4f}", xy=(epochs[-1], val_losses[-1]),
                 xytext=(text_x, val_losses[-1]+0.05),
                 arrowprops=dict(facecolor='red', shrink=0.05),
                 fontsize=10, color='red')
    plt.tight_layout()
    plt.savefig(f"{model_name}_loss.png", dpi=300)
    plt.show()
    # Release the figure; one is created per model.
    plt.close(fig)

###############################################################################
# Main Training and Evaluation Pipeline
###############################################################################
def main():
    """
    Main function to train, evaluate, and compare the language models.

    This function performs the following steps:
      1. Sets hyperparameters and file paths.
      2. Prepares the tokenizer (training it if needed).
      3. Loads and tokenizes training and validation datasets.
      4. Builds fixed-length sequences and creates DataLoaders.
      5. Initializes the models listed in the `models` dict (the LSTM and
         Transformer entries are currently commented out).
      6. Trains each model with early stopping, gradient clipping, and a cosine scheduler.
      7. Evaluates each model using perplexity, token accuracy, and BLEU score.
      8. Generates sample text for a fixed prompt.
      9. Saves model parameters and displays performance comparisons through plots.

    Raises:
        FileNotFoundError: If train.jsonl or test.jsonl is missing.
        ValueError: If either dataset is too short to yield a single sequence.
    """
    global_start = time.time()

    # ------------------ Hyperparameters ------------------
    vocab_size = 10000          # Vocabulary size requested when training the tokenizer
    embed_dim = 256             # Embedding dimension for tokens
    hidden_dim = 256            # Hidden dimension for RNN/LSTM and transformer feedforward
    num_layers = 2              # Number of layers in RNN/LSTM/Transformer encoder
    num_heads = 8               # Number of attention heads in Transformer
    max_seq_length = 128        # Input sequence length (without EOS token)
    batch_size = 256            # Training batch size
    num_epochs = 30             # Maximum number of epochs for training
    learning_rate = 1e-4        # Initial learning rate
    dropout_rate = 0.2          # Dropout rate for Transformer positional encoding and encoder layers
    weight_decay = 1e-4         # Weight decay for optimizer regularization

    # ------------------ File Paths ------------------
    # Update these paths as needed.
    train_file = "train.jsonl"
    test_file = "test.jsonl"

    if not os.path.exists(train_file) or not os.path.exists(test_file):
        raise FileNotFoundError("train.jsonl and/or test.jsonl not found.")

    # ------------------ Prepare Training Text for Tokenizer ------------------
    tokenizer_training_file = "train.txt"
    if not os.path.exists(tokenizer_training_file):
        texts = []
        with open(train_file, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    continue
                prompt = obj.get("prompt", "")
                completion = obj.get("completion", "")
                text = (prompt + " " + completion).strip()
                if text:
                    texts.append(text)
        # Combine all text entries into one large file.
        combined_text = "\n".join(texts)
        with open(tokenizer_training_file, "w", encoding="utf-8") as f:
            f.write(combined_text)

    # ------------------ Tokenizer Preparation ------------------
    sp = train_tokenizer_if_needed(tokenizer_model_prefix="tokenizer", vocab_size=vocab_size, training_text_file=tokenizer_training_file)

    # Size the models from the tokenizer that was actually loaded: a
    # pre-existing tokenizer.model may have a different vocabulary size.
    model_vocab_size = sp.get_piece_size()
    if model_vocab_size != vocab_size:
        print(f"Tokenizer vocabulary size is {model_vocab_size} (requested {vocab_size}); using {model_vocab_size}.")

    # Only ignore a padding ID if the tokenizer really defines one. A negative
    # pad_id means padding is disabled, in which case a hard-coded ID would
    # silently drop a real vocabulary token from the loss. -100 is the
    # CrossEntropyLoss default and never matches a token ID.
    pad_token_id = sp.pad_id()
    ignore_index = pad_token_id if pad_token_id >= 0 else -100

    # ------------------ Load and Tokenize Datasets ------------------
    train_tokens = load_and_tokenize(train_file, sp)
    val_tokens = load_and_tokenize(test_file, sp)

    # ------------------ Build Fixed-Length Token Sequences ------------------
    train_seqs = build_sequences(train_tokens, max_seq_length)
    val_seqs = build_sequences(val_tokens, max_seq_length)
    print(f"Number of train tokens: {len(train_tokens)}")
    print(f"Number of val tokens: {len(val_tokens)}")
    print(f"Number of training sequences: {len(train_seqs)}")
    print(f"Number of validation sequences: {len(val_seqs)}")

    # ------------------ Create Dataset Objects and DataLoaders ------------------
    train_dataset = LanguageModelDataset(train_seqs, max_seq_length)
    val_dataset = LanguageModelDataset(val_seqs, max_seq_length)
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"Not enough tokens to build sequences of length {max_seq_length + 1} "
            "from the training and/or validation data."
        )
    use_pinned_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=0, pin_memory=use_pinned_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            num_workers=0, pin_memory=use_pinned_memory)

    # ------------------ Initialize Models ------------------
    # Uncomment the entries below to train and compare the other architectures.
    models = {
        "RNN": RNNLanguageModel(model_vocab_size, embed_dim, hidden_dim, num_layers),
        #"LSTM": LSTMLanguageModel(model_vocab_size, embed_dim, hidden_dim, num_layers),
        #"Transformer": TransformerLanguageModel(model_vocab_size, embed_dim, num_heads, hidden_dim, num_layers, max_seq_length, dropout=dropout_rate)
    }

    model_results = {}  # Dictionary to hold evaluation metrics for each model

    # ------------------ Train, Evaluate and Compare Models ------------------
    for name, model in models.items():
        print(f"\n--- Training {name} Model ---")
        model.to(device)
        # Use CrossEntropyLoss, ignoring the padding token when the tokenizer has one.
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        # Using a cosine annealing scheduler that adjusts the learning rate over epochs.
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)

        # Train model and record time.
        model_start_time = time.time()
        train_losses, val_losses = train_model(model, train_loader, val_loader, num_epochs,
                                                criterion, optimizer, scheduler, device, patience=5)
        training_time = time.time() - model_start_time
        print(f"Total training time for {name}: {training_time:.2f} seconds")

        # Plot loss curves with detailed annotations.
        plot_loss_curve(train_losses, val_losses, name)

        # Evaluate model on validation set.
        val_loss = evaluate_model(model, val_loader, criterion, device)
        perplexity = compute_perplexity(val_loss)
        token_accuracy = compute_token_accuracy(model, val_loader, device) * 100

        # Compute BLEU score on a random sample from validation dataset.
        # The sample window is split into a prompt half and a held-out half;
        # the model continues the prompt and only that continuation is scored
        # against the held-out text. Prompt plus continuation never exceeds
        # max_seq_length tokens.
        sample_idx = random.randint(0, len(val_dataset) - 1)
        sample_input, sample_target = val_dataset[sample_idx]
        window_ids = sample_input.tolist() + sample_target.tolist()[-1:]
        split = max(1, max_seq_length // 2)
        continuation_budget = max(1, max_seq_length - split)
        prompt_ids = window_ids[:split]
        reference_ids = window_ids[split:split + continuation_budget]
        prompt_text = sp.decode(prompt_ids)
        reference_text = sp.decode(reference_ids)
        generated_text = model.prompt(sp, prompt_text, max_length=len(reference_ids), temperature=1.0)
        # The generated text starts with the prompt; strip it when possible so
        # the prompt itself does not inflate the score.
        if generated_text.startswith(prompt_text):
            continuation_text = generated_text[len(prompt_text):]
        else:
            continuation_text = generated_text
        bleu = compute_bleu(reference_text, continuation_text)

        print(f"{name} | Perplexity: {perplexity:.2f} | Token Accuracy: {token_accuracy:.2f}% | BLEU: {bleu:.4f}")

        # Generate sample output for a fixed prompt.
        fixed_prompt = "Which do you prefer? Dogs or cats?"
        sample_output = model.prompt(sp, fixed_prompt, max_length=128, temperature=1.0)
        print("Sample generated output:", sample_output)

        # Save the trained model state.
        torch.save(model.state_dict(), f"{name}_model.pt")
        print(f"{name} model saved as {name}_model.pt")

        # Store evaluation metrics for later comparison.
        model_results[name] = {
            "Perplexity": perplexity,
            "Token Accuracy (%)": token_accuracy,
            "BLEU Score": bleu,
            "Training Time (s)": training_time
        }

    # ------------------ Compare Model Performance ------------------
    print("\n--- Model Performance Summary ---")
    header = f"{'Model':<15} {'Perplexity':<12} {'Token Accuracy (%)':<20} {'BLEU Score':<12} {'Train Time (s)':<15}"
    print(header)
    for model_name, metrics in model_results.items():
        print(f"{model_name:<15} {metrics['Perplexity']:<12.2f} {metrics['Token Accuracy (%)']:<20.2f} "
              f"{metrics['BLEU Score']:<12.4f} {metrics['Training Time (s)']:<15.2f}")

    # Create comparative bar plots: one panel per metric.
    model_names = list(model_results.keys())
    x = np.arange(len(model_names))
    width = 0.2
    fig, axs = plt.subplots(2, 2, figsize=(12, 10))

    panels = [
        (axs[0, 0], "Perplexity", "Perplexity", "skyblue"),
        (axs[0, 1], "Token Accuracy (%)", "Token Accuracy (%)", "lightgreen"),
        (axs[1, 0], "BLEU Score", "BLEU Score", "salmon"),
        (axs[1, 1], "Training Time (s)", "Training Time (s)", "plum"),
    ]
    for ax, metric_key, title, color in panels:
        values = [model_results[m][metric_key] for m in model_names]
        ax.bar(x, values, width, color=color)
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(model_names)

    plt.suptitle("Model Performance and Computational Requirements Comparison", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig("model_comparison.png", dpi=300)
    plt.show()

    total_time = time.time() - global_start
    print(f"\nTotal elapsed time: {total_time:.2f} seconds")

# Run the full training/evaluation pipeline only when this file is executed
# as a script, not when it is imported as a module.
if __name__ == "__main__":
    main()


[✓] Using device: mps
Loaded 39557 text entries from train.jsonl. Total length: 14673121 characters
Loaded 9890 text entries from test.jsonl. Total length: 3684340 characters
Number of train tokens: 3402656
Number of val tokens: 855555
Number of training sequences: 3402536
Number of validation sequences: 855435


KeyboardInterrupt: 