In [1]:
%pip install evaluate gensim

Note: you may need to restart the kernel to use updated packages.




In [2]:
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os , re , json
import pickle
import hashlib , math
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import gensim.downloader as api
import time
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict , Counter
from pathlib import Path


In [46]:
from Dataset.Vocabulary import Vocabulary
from Dataset.TinyStories import TinyStoriesDataset
from Dataset.load_fasttext_model import load_fasttext_model
from model.PositionalEncoding import PositionalEncoding
from model.MultiHeadAttention import MultiHeadAttention
from model.FeedForward import FeedForward
from model.LayerNorm import LayerNorm



In [47]:

class CheckpointFunction(torch.autograd.Function):
    """Manual gradient checkpointing as a custom autograd function.

    ``forward`` runs ``run_function(*args)`` under ``torch.no_grad()`` so no
    intermediate activations are kept. ``backward`` re-runs the function with
    gradients enabled and back-propagates through that recomputation.

    The random-number-generator state seen by the forward pass is recorded and
    restored for the recomputation, so stochastic layers (e.g. dropout) draw
    the same random numbers both times. Without this the gradients would be
    computed for a different dropout mask than the one used in the forward.

    Note: as with any ``autograd.Function``, ``backward`` is only invoked when
    at least one tensor in ``args`` requires grad.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        """Run ``run_function(*args)`` without building an autograd graph.

        Args:
            ctx: Autograd context used to stash what ``backward`` needs.
            run_function: Callable to checkpoint.
            *args: Positional arguments for ``run_function``; tensors and
                non-tensors (including ``None``) may be mixed.

        Returns:
            Whatever ``run_function`` returns.
        """
        ctx.run_function = run_function

        # Tensors must go through save_for_backward; everything else is kept
        # on ctx with a placeholder so positions can be rebuilt in backward.
        ctx.tensor_indices = []
        ctx.other_args = []
        tensor_args = []
        for index, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensor_args.append(arg)
                ctx.tensor_indices.append(index)
                ctx.other_args.append(None)
            else:
                ctx.other_args.append(arg)
        ctx.save_for_backward(*tensor_args)

        # Snapshot the RNG state *before* running the function so the
        # recomputation can replay exactly the same random draws.
        ctx.cpu_rng_state = torch.get_rng_state()
        ctx.cuda_devices = sorted({t.get_device() for t in tensor_args if t.is_cuda})
        ctx.cuda_rng_states = [torch.cuda.get_rng_state(d) for d in ctx.cuda_devices]

        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Recompute the forward pass with gradients and back-propagate.

        Args:
            ctx: Context populated by ``forward``.
            *grad_outputs: One incoming gradient per output of ``forward``.

        Returns:
            A tuple with ``None`` for ``run_function`` followed by one entry
            per original argument: its gradient, or ``None`` when the argument
            is not a tensor or received no gradient.
        """
        # Rebuild the argument list with detached copies of the saved tensors
        # so the recomputed graph is independent of the outer graph.
        inputs = list(ctx.other_args)
        for index, saved in zip(ctx.tensor_indices, ctx.saved_tensors):
            detached = saved.detach()
            # Mirror the original flag: forcing True would fail on integer /
            # bool tensors and track gradients nobody asked for.
            detached.requires_grad_(saved.requires_grad)
            inputs[index] = detached

        # fork_rng restores the ambient RNG state afterwards, so replaying the
        # forward's random stream does not disturb later random draws.
        with torch.random.fork_rng(devices=ctx.cuda_devices):
            torch.set_rng_state(ctx.cpu_rng_state)
            for cuda_device, cuda_state in zip(ctx.cuda_devices, ctx.cuda_rng_states):
                torch.cuda.set_rng_state(cuda_state, cuda_device)
            with torch.enable_grad():
                outputs = ctx.run_function(*inputs)

        # Handle both single tensor and tuple outputs
        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # Pair each output with the gradient at the SAME position; slicing the
        # first k gradients would misalign them when an earlier output does
        # not require grad.
        outputs_with_grad = []
        matching_grads = []
        for out, grad in zip(outputs, grad_outputs):
            if isinstance(out, torch.Tensor) and out.requires_grad:
                outputs_with_grad.append(out)
                matching_grads.append(grad)

        if outputs_with_grad:
            # Also accumulates into any parameters used inside run_function.
            torch.autograd.backward(outputs_with_grad, matching_grads)

        # Collect gradients (``.grad`` stays None for inputs that got none).
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in inputs)

        return (None,) + grads



In [48]:

def checkpoint_function(function, *args):
    """Evaluate ``function(*args)`` through ``CheckpointFunction``.

    Convenience wrapper so call sites do not have to spell out
    ``CheckpointFunction.apply``; the callable is forwarded as the first
    argument, followed by its positional arguments unchanged.
    """
    checkpointed_apply = CheckpointFunction.apply
    return checkpointed_apply(function, *args)



In [53]:
class CheckpointedTransformerBlock(nn.Module):
    """Transformer block with manual gradient checkpointing.

    Applies an attention sub-layer and a feed-forward sub-layer, each followed
    by dropout, a residual connection and a layer norm (post-norm layout).
    When checkpointing is requested during training, each sub-layer is run via
    ``checkpoint_function`` and recomputed in the backward pass.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """Build the sub-layers.

        Args:
            d_model: Model width passed to the attention, feed-forward and
                layer-norm modules.
            num_heads: Number of attention heads.
            d_ff: Hidden width of the feed-forward module.
            dropout: Dropout probability used by the sub-layers and on the
                residual branches.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout_p = dropout

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def attention_forward(self, x, mask):
        """Attention output only; single-tensor form used for checkpointing."""
        attn_output, _ = self.attention(x, mask)
        return attn_output

    def feedforward_forward(self, x):
        """Feed-forward output; separate method so it can be checkpointed."""
        return self.feed_forward(x)

    def forward(self, x, mask=None, use_checkpointing=False):
        """Run the block.

        Args:
            x: Input activations.
            mask: Optional attention mask forwarded to the attention module.
            use_checkpointing: Recompute sub-layers in backward instead of
                storing their activations. Only honoured in training mode.

        Returns:
            ``(output, attn_weights)``. ``attn_weights`` is the second value
            returned by the attention module on the regular path, and ``None``
            on the checkpointed path (the weights are not kept there).
        """
        checkpoint_active = use_checkpointing and self.training

        if checkpoint_active:
            # Manual checkpointing: recompute in backward pass
            attn_output = checkpoint_function(self.attention_forward, x, mask)
            attn_weights = None
        else:
            # Keep the weights so callers asking for attention maps get them
            # instead of an unconditional None.
            attn_output, attn_weights = self.attention(x, mask)

        x = self.norm1(x + self.dropout(attn_output))

        if checkpoint_active:
            ff_output = checkpoint_function(self.feedforward_forward, x)
        else:
            ff_output = self.feedforward_forward(x)

        x = self.norm2(x + self.dropout(ff_output))

        return x, attn_weights


In [54]:
class DecoderTransformer(nn.Module):
    """Decoder-only transformer language model built from checkpointable blocks.

    Token ids are embedded, optionally projected to ``d_model``, given
    positional encodings, passed through ``num_layers`` transformer blocks
    with a causal mask, normalised and projected to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads,
                 d_ff, max_seq_len, dropout=0.1, pretrained_embeddings=None,
                 embedding_dim=None):
        """Build the model.

        Args:
            vocab_size: Number of rows in the embedding table and number of
                output logits.
            d_model: Width of the transformer blocks.
            num_layers: Number of transformer blocks.
            num_heads: Attention heads per block.
            d_ff: Feed-forward hidden width per block.
            max_seq_len: Maximum sequence length given to the positional
                encoding.
            dropout: Dropout probability.
            pretrained_embeddings: Optional 2-D tensor copied into the
                embedding table; its second dimension sets the embedding
                width and a linear projection to ``d_model`` is added.
            embedding_dim: Optional embedding width to use *without*
                pretrained weights. When given, the same embedding + linear
                projection layout is built with fresh weights, which lets
                ``get_config()`` reproduce the architecture of a model that
                was created from pretrained embeddings.

        Raises:
            ValueError: If both ``pretrained_embeddings`` and ``embedding_dim``
                are given and disagree on the embedding width.
        """
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout_p = dropout

        # Resolve the embedding width: pretrained matrix > explicit > d_model.
        if pretrained_embeddings is not None:
            pretrained_dim = pretrained_embeddings.shape[1]
            if embedding_dim is not None and embedding_dim != pretrained_dim:
                raise ValueError(
                    f"embedding_dim={embedding_dim} does not match "
                    f"pretrained_embeddings width {pretrained_dim}"
                )
            self.embedding_dim = pretrained_dim
        elif embedding_dim is not None:
            self.embedding_dim = embedding_dim
        else:
            self.embedding_dim = d_model

        # Create embedding layer with the resolved dimension
        self.embedding = nn.Embedding(vocab_size, self.embedding_dim)

        # Load pretrained embeddings if provided
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(pretrained_embeddings)

        # A projection is used whenever the embedding width was specified
        # (via pretrained weights or explicitly); otherwise pass through.
        if pretrained_embeddings is not None or embedding_dim is not None:
            self.embedding_proj = nn.Linear(self.embedding_dim, d_model)
        else:
            self.embedding_proj = nn.Identity()

        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
        self.layers = nn.ModuleList([
            CheckpointedTransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_causal_mask(self, seq_len, device):
        """Return a ``(1, 1, seq_len, seq_len)`` lower-triangular mask of ones."""
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        return mask

    def forward(self, x, return_attention=False, use_checkpointing=False):
        """
        Forward pass with optional gradient checkpointing

        Args:
            x: Input tensor [batch_size, seq_len]
            return_attention: Whether to return attention weights
            use_checkpointing: Whether to use gradient checkpointing in transformer blocks

        Returns:
            ``logits``, or ``(logits, attention_weights)`` when
            ``return_attention`` is true, where ``attention_weights`` holds
            the second value returned by each block, in layer order.
        """
        # Unpacking also enforces a 2-D input.
        _, seq_len = x.shape
        mask = self.create_causal_mask(seq_len, x.device)

        # Embed and scale by sqrt of the embedding width
        x = self.embedding(x) * math.sqrt(self.embedding_dim)

        # Project to d_model (identity when no projection is configured)
        x = self.embedding_proj(x)

        x = self.pos_encoding(x)
        x = self.dropout(x)

        attention_weights = []

        # Pass use_checkpointing to each transformer block
        for layer in self.layers:
            x, attn_weights = layer(x, mask, use_checkpointing=use_checkpointing)
            if return_attention:
                attention_weights.append(attn_weights)

        x = self.norm(x)
        logits = self.output_projection(x)

        if return_attention:
            return logits, attention_weights
        return logits

    def get_config(self):
        """Return constructor kwargs that rebuild this architecture.

        Pretrained weights are deliberately not carried over, but when the
        model has an embedding projection the embedding width is included so
        a re-instantiated model has the same layers and parameter count.
        """
        config = {
            'vocab_size': self.vocab_size,
            'd_model': self.d_model,
            'num_layers': self.num_layers,
            'num_heads': self.num_heads,
            'd_ff': self.d_ff,
            'max_seq_len': self.max_seq_len,
            'dropout': self.dropout_p,
            'pretrained_embeddings': None  # Don't reuse pretrained embeddings in experiments
        }
        if not isinstance(self.embedding_proj, nn.Identity):
            # Keeps the embedding + projection layout without the weights.
            config['embedding_dim'] = self.embedding_dim
        return config

In [57]:
# Hyper-parameters and output locations for the baseline run.
CONFIG = dict(
    # Run identification
    name='baseline',
    description='Standard baseline configuration from assignment',
    # Model shape (d_model = 8 heads * 37; d_ff = 4 * d_model)
    context_length=64,
    num_layers=3,
    num_heads=8,
    d_model=296,
    d_ff=1184,
    dropout=0.1,
    # Optimisation
    batch_size=32,
    learning_rate=3e-4,
    num_epochs=10,
    # Data subset sizes
    max_train_samples=15000,
    max_val_samples=5000,
    # Output directories
    save_dir='checkpoints/baseline',
    plot_dir='plots/baseline',
)

In [None]:
# Select the compute device and echo the run configuration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"\nConfiguration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# Load FastText
print("Loading FastText embeddings...")
fasttext_model = load_fasttext_model()

# Load Dataset
print("Loading TinyStories dataset...")
dataset = load_dataset("roneneldan/TinyStories")

# Number of stories actually used from each split (capped by the split size).
num_train_samples = min(CONFIG['max_train_samples'], len(dataset['train']))
num_val_samples = min(CONFIG['max_val_samples'], len(dataset['validation']))

print("Building vocabulary...")
# Make sure the target directory exists before the vocabulary is written;
# on a fresh checkout 'checkpoints/baseline' is not there yet.
os.makedirs(CONFIG['save_dir'], exist_ok=True)
vocab_path = f"{CONFIG['save_dir']}/vocab.json"

# NOTE(review): a cached vocabulary is reused even if max_train_samples has
# changed since it was built — delete vocab.json to force a rebuild.
if os.path.exists(vocab_path):
    print("Loading existing vocabulary...")
    vocab = Vocabulary.load(vocab_path, fasttext_model)
else:
    vocab = Vocabulary(fasttext_model)
    # Build vocabulary from training data
    for i in tqdm(range(num_train_samples), desc="Building vocabulary"):
        text = dataset['train'][i]['text']
        for word in vocab.tokenize(text):
            vocab.add_word(word)
    vocab.save(vocab_path)

print(f"Vocabulary size: {len(vocab)}")

# Create Datasets
print("Creating datasets...")

# Prepare train / validation texts
train_texts = [dataset['train'][i]['text'] for i in range(num_train_samples)]
val_texts = [dataset['validation'][i]['text'] for i in range(num_val_samples)]

train_dataset = TinyStoriesDataset(
    train_texts,
    vocab,
    CONFIG['context_length'],
    CONFIG['max_train_samples']
)

val_dataset = TinyStoriesDataset(
    val_texts,
    vocab,
    CONFIG['context_length'],
    CONFIG['max_val_samples']
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                          shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                        shuffle=False, num_workers=0)

# Initialize Model
print("Initializing model...")
embedding_matrix = vocab.create_embedding_matrix()


Using device: cuda

Configuration:
  name: baseline
  description: Standard baseline configuration from assignment
  context_length: 64
  num_layers: 3
  num_heads: 8
  d_model: 296
  d_ff: 1184
  dropout: 0.1
  batch_size: 32
  learning_rate: 0.0003
  num_epochs: 10
  max_train_samples: 15000
  max_val_samples: 5000
  save_dir: checkpoints/baseline
  plot_dir: plots/baseline

Loading FastText embeddings...
Loading FastText model from cache...
Model loaded successfully!

Loading TinyStories dataset...

Building vocabulary...
Loading existing vocabulary...
Vocabulary size: 10598

Creating datasets...
Preparing dataset...


100%|██████████| 15000/15000 [00:07<00:00, 1969.88it/s]


Created 3083375 sequences
Preparing dataset...


100%|██████████| 5000/5000 [00:01<00:00, 3857.89it/s]


Created 925828 sequences

Initializing model...
Found 9972/10598 words in FastText


In [59]:

# Build the decoder from CONFIG, seeded with the FastText embedding matrix,
# then move it to the selected device.
model = DecoderTransformer(
    vocab_size=len(vocab),
    max_seq_len=CONFIG['context_length'],
    pretrained_embeddings=embedding_matrix,
    **{key: CONFIG[key] for key in ('d_model', 'num_layers', 'num_heads', 'd_ff', 'dropout')},
)
model = model.to(device)

# Adam over all model parameters at the configured learning rate.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CONFIG['learning_rate'],
)

# Cross-entropy that skips positions whose target is the PAD token.
criterion = nn.CrossEntropyLoss(
    ignore_index=vocab.word2idx[vocab.PAD_TOKEN],
)

In [60]:
def train_with_checkpointing(
    model, dataloader, optimizer, criterion, device, use_checkpointing=False, epoch=1
):
    """Train ``model`` for one pass over ``dataloader`` and report statistics.

    Each batch is split into inputs ``batch[:, :-1]`` and next-token targets
    ``batch[:, 1:]``. Gradients are clipped to a total norm of 1.0 before the
    optimizer step.

    Args:
        model: Model called as ``model(inputs, use_checkpointing=...)``.
        dataloader: Iterable of batches that support ``.to(device)`` and
            2-D slicing.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss taking ``(logits_2d, targets_1d)``.
        device: Device (``torch.device`` or string) the batches are moved to;
            also decides whether and where CUDA memory is measured.
        use_checkpointing: Forwarded to the model's forward pass.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        dict with ``loss`` (mean per-batch loss), ``time`` (seconds),
        ``peak_memory_gb`` (peak minus starting allocated CUDA memory on
        ``device``, 0 when not training on CUDA) and ``batches``.
    """
    model.train()
    total_loss = 0
    num_batches = 0

    # Measure memory on the device actually used for training, not merely
    # whenever some CUDA device happens to be available.
    device = torch.device(device)
    track_cuda_memory = device.type == "cuda" and torch.cuda.is_available()

    if track_cuda_memory:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
        start_memory = torch.cuda.memory_allocated(device)

    start_time = time.time()
    progress_bar = tqdm(
        dataloader, desc=f"Epoch {epoch} ({'CP' if use_checkpointing else 'No CP'})"
    )

    for batch in progress_bar:
        batch = batch.to(device)
        inputs = batch[:, :-1]
        targets = batch[:, 1:]

        optimizer.zero_grad()

        # Forward with checkpointing if enabled
        logits = model(inputs, use_checkpointing=use_checkpointing)

        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Read the loss once: every .item() forces a host/device sync.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})

    if track_cuda_memory:
        # Wait for queued kernels so the timing covers the whole epoch.
        torch.cuda.synchronize(device)

    epoch_time = time.time() - start_time

    if track_cuda_memory:
        peak_memory = torch.cuda.max_memory_allocated(device)
        memory_used = (peak_memory - start_memory) / (1024**3)  # Convert to GB
        torch.cuda.empty_cache()
    else:
        memory_used = 0

    avg_loss = total_loss / max(num_batches, 1)

    return {
        "loss": avg_loss,
        "time": epoch_time,
        "peak_memory_gb": memory_used,
        "batches": num_batches
    }



In [61]:
def experiment_gradient_checkpointing(
    model, train_loader, optimizer, criterion, device, num_epochs=1
):
    """Compare training with and without gradient checkpointing.

    Trains the same model twice — first without, then with checkpointing —
    starting both runs from the same model and optimizer state, and restores
    that state again before returning.

    Args:
        model: Model to train (restored to its initial weights afterwards).
        train_loader: Training data loader.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss function.
        device: Device to train on.
        num_epochs: Epochs to run per variant.

    Returns:
        dict with keys ``"without_cp"`` and ``"with_cp"``, each a list of the
        per-epoch result dicts from ``train_with_checkpointing``.
    """
    import copy

    print("\nGradient Checkpointing Experiment...")

    results = {"without_cp": [], "with_cp": []}

    # state_dict() hands back references to the live tensors, so it must be
    # deep-copied; otherwise training mutates the "snapshot" and restoring it
    # is a no-op (the second variant would start from trained weights).
    initial_model_state = copy.deepcopy(model.state_dict())
    initial_opt_state = copy.deepcopy(optimizer.state_dict())

    for use_cp in [False, True]:
        cp_str = "with_cp" if use_cp else "without_cp"
        print(f"\n{'='*50}")
        print(f"{'With' if use_cp else 'Without'} Gradient Checkpointing")
        print(f"{'='*50}")

        epoch_results = []

        for epoch in range(1, num_epochs + 1):
            result = train_with_checkpointing(
                model,
                train_loader,
                optimizer,
                criterion,
                device,
                use_checkpointing=use_cp,
                epoch=epoch,
            )

            epoch_results.append(result)

            print(f"Epoch {epoch}:")
            print(f"  Loss: {result['loss']:.4f}")
            print(f"  Time: {result['time']:.2f}s")
            print(f"  Peak Memory: {result['peak_memory_gb']:.2f} GB")

        results[cp_str] = epoch_results

        # Restore state. The optimizer gets its own copy so later training
        # cannot alias and corrupt the stored snapshot.
        model.load_state_dict(initial_model_state)
        optimizer.load_state_dict(copy.deepcopy(initial_opt_state))

    return results



In [62]:

def _evaluate_mean_loss(model, dataloader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader`` in eval mode."""
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            inputs = batch[:, :-1]
            targets = batch[:, 1:]
            logits = model(inputs, use_checkpointing=False)
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            total_loss += loss.item()
            num_batches += 1
    # Guard against an empty loader instead of dividing by zero.
    return total_loss / max(num_batches, 1)


def _summarize_epochs(epochs):
    """Average the per-epoch result dicts into a single summary dict."""
    count = max(len(epochs), 1)  # avoids ZeroDivisionError when num_epochs == 0
    total_time = sum(r["time_seconds"] for r in epochs)
    return {
        "avg_train_loss": sum(r["train_loss"] for r in epochs) / count,
        "avg_val_loss": sum(r["val_loss"] for r in epochs) / count,
        "avg_time_seconds": total_time / count,
        "avg_memory_gb": sum(r["peak_memory_gb"] for r in epochs) / count,
        "total_time_seconds": total_time,
    }


def _train_and_evaluate_variant(
    model, train_dataloader, val_dataloader, criterion, device,
    num_epochs, learning_rate, use_checkpointing, config
):
    """Train one variant (with or without checkpointing) and collect its results."""
    tag = "CP" if use_checkpointing else "No CP"
    variant_results = {
        "config": config,
        "checkpointing_enabled": use_checkpointing,
        "epochs": []
    }
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(1, num_epochs + 1):
        train_stats = train_with_checkpointing(
            model,
            train_dataloader,
            optimizer,
            criterion,
            device,
            use_checkpointing=use_checkpointing,
            epoch=epoch
        )

        # Validation
        val_loss = _evaluate_mean_loss(model, val_dataloader, criterion, device)

        variant_results["epochs"].append({
            "epoch": epoch,
            "train_loss": train_stats["loss"],
            "val_loss": val_loss,
            "time_seconds": train_stats["time"],
            "peak_memory_gb": train_stats["peak_memory_gb"]
        })

        print(f"\nEpoch {epoch} ({tag}): Train Loss={train_stats['loss']:.4f}, "
              f"Val Loss={val_loss:.4f}, Time={train_stats['time']:.2f}s, "
              f"Memory={train_stats['peak_memory_gb']:.2f}GB\n")

    variant_results["summary"] = _summarize_epochs(variant_results["epochs"])
    return variant_results


def run_checkpoint_experiment(
    model,
    train_dataloader,
    val_dataloader,
    device,
    num_epochs=3,
    learning_rate=1e-4,
    output_dir="result",
    criterion=None
):
    """
    Run checkpoint comparison experiment and save results

    The given ``model`` is trained in place without checkpointing; a second
    model, re-created from ``model.get_config()`` when available, is trained
    with checkpointing. Per-variant results and a comparison are written as
    JSON into ``output_dir``, and the checkpointed model's weights are saved
    to ``model_with_cp.pt`` in the working directory.

    Args:
        model: The transformer model to train
        train_dataloader: Training data loader
        val_dataloader: Validation data loader
        device: torch device
        num_epochs: Number of epochs to train
        learning_rate: Learning rate for optimizer
        output_dir: Directory to save results
        criterion: Optional loss function. Defaults to a plain
            ``nn.CrossEntropyLoss()``; pass a criterion with ``ignore_index``
            set to the padding id to keep PAD targets out of the loss.

    Returns:
        The comparison dict that is also written to ``result.json``.
    """

    # Create output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    config = {
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "device": str(device),
        "model_params": sum(p.numel() for p in model.parameters()),
    }

    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    print("=" * 60)
    print("Running Training WITHOUT Gradient Checkpointing")
    print("=" * 60)

    # Train without checkpointing (uses the caller's model directly)
    results_no_cp = _train_and_evaluate_variant(
        model, train_dataloader, val_dataloader, criterion, device,
        num_epochs, learning_rate, use_checkpointing=False, config=config
    )

    # Save no checkpoint results
    no_cp_path = os.path.join(output_dir, "no_checkpoint.json")
    with open(no_cp_path, 'w') as f:
        json.dump(results_no_cp, f, indent=2)
    print(f"\n✓ No checkpoint results saved to: {no_cp_path}\n")

    print("\n" + "=" * 60)
    print("Running Training WITH Gradient Checkpointing")
    print("=" * 60)

    # Reinitialize model for fair comparison.
    # NOTE(review): without get_config() the already-trained model is reused,
    # so losses are not comparable in that case (memory/time still are).
    model_cp = type(model)(**model.get_config()) if hasattr(model, 'get_config') else model
    model_cp = model_cp.to(device)

    # Train with checkpointing
    results_cp = _train_and_evaluate_variant(
        model_cp, train_dataloader, val_dataloader, criterion, device,
        num_epochs, learning_rate, use_checkpointing=True, config=config
    )

    # Save checkpoint results
    cp_path = os.path.join(output_dir, "checkpoint.json")
    with open(cp_path, 'w') as f:
        json.dump(results_cp, f, indent=2)
    print(f"\n✓ Checkpoint results saved to: {cp_path}\n")
    # Persist the model that was actually trained with checkpointing.
    torch.save(model_cp.state_dict(), 'model_with_cp.pt')

    avg_memory_no_cp = results_no_cp["summary"]["avg_memory_gb"]
    avg_time_no_cp = results_no_cp["summary"]["avg_time_seconds"]
    avg_memory_cp = results_cp["summary"]["avg_memory_gb"]
    avg_time_cp = results_cp["summary"]["avg_time_seconds"]

    # Create comparison summary
    comparison = {
        "config": config,
        "without_checkpointing": results_no_cp["summary"],
        "with_checkpointing": results_cp["summary"],
        "comparison": {
            "memory_savings_gb": avg_memory_no_cp - avg_memory_cp,
            "memory_savings_percent": ((avg_memory_no_cp - avg_memory_cp) / avg_memory_no_cp * 100) if avg_memory_no_cp > 0 else 0,
            "time_overhead_seconds": avg_time_cp - avg_time_no_cp,
            "time_overhead_percent": ((avg_time_cp - avg_time_no_cp) / avg_time_no_cp * 100) if avg_time_no_cp > 0 else 0
        }
    }

    # Save comparison results
    comparison_path = os.path.join(output_dir, "result.json")
    with open(comparison_path, 'w') as f:
        json.dump(comparison, f, indent=2)

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Memory Savings: {comparison['comparison']['memory_savings_gb']:.2f} GB "
          f"({comparison['comparison']['memory_savings_percent']:.1f}%)")
    print(f"Time Overhead: {comparison['comparison']['time_overhead_seconds']:.2f}s "
          f"({comparison['comparison']['time_overhead_percent']:.1f}%)")
    print(f"\n✓ Comparison summary saved to: {comparison_path}")
    print(f"\nAll results saved in: {output_dir}/")
    print(f"  - no_checkpoint.json")
    print(f"  - checkpoint.json")
    print(f"  - result.json (comparison)")

    return comparison


# Single-epoch comparison run; JSON reports are written under "result/".
results = run_checkpoint_experiment(
    model,
    train_loader,
    val_loader,
    device,
    num_epochs=1,
    learning_rate=1e-4,
    output_dir="result",
)


Running Training WITHOUT Gradient Checkpointing


Epoch 1 (No CP):   0%|          | 0/96356 [00:00<?, ?it/s]

Epoch 1 (No CP):   1%|          | 669/96356 [00:21<51:16, 31.10it/s, loss=3.5400] 


KeyboardInterrupt: 