In [None]:
%pip install evaluate gensim

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting gensim
  Downloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy>=1.17 (from evaluate)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.w

In [None]:
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os , re , json
import pickle
import hashlib , math
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import gensim.downloader as api
import time
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict , Counter
from pathlib import Path


In [None]:

class CheckpointFunction(torch.autograd.Function):
    """Manual gradient checkpointing implemented as a custom autograd function.

    The forward pass runs ``run_function`` without building an autograd graph,
    so intermediate activations are not kept. The backward pass re-runs
    ``run_function`` with gradients enabled and back-propagates through that
    recomputed graph.

    The RNG state (CPU, plus every CUDA device an input tensor lives on) is
    captured in forward and restored for the recomputation, so stochastic
    layers such as dropout draw the same mask both times.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        """Run ``run_function(*args)`` under ``no_grad`` and stash what backward needs.

        Args:
            ctx: autograd context object.
            run_function: callable to checkpoint.
            *args: positional inputs; tensors are saved for backward, every
                other value (``None``, ints, ...) is kept on ``ctx`` as-is.

        Returns:
            Whatever ``run_function`` returns.
        """
        ctx.run_function = run_function

        # Snapshot RNG state so the recomputation reproduces the same randomness.
        ctx.cpu_rng_state = torch.get_rng_state()
        ctx.cuda_devices = sorted(
            {a.get_device() for a in args if isinstance(a, torch.Tensor) and a.is_cuda}
        )
        ctx.cuda_rng_states = [torch.cuda.get_rng_state(d) for d in ctx.cuda_devices]

        # save_for_backward only accepts tensors (or None); keep non-tensor
        # arguments separately and remember where each tensor belongs.
        ctx.tensor_positions = []
        ctx.static_inputs = []
        tensors = []
        for position, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                ctx.tensor_positions.append(position)
                ctx.static_inputs.append(None)
                tensors.append(arg)
            else:
                ctx.static_inputs.append(arg)
        ctx.save_for_backward(*tensors)

        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Recompute the forward pass with gradients and back-propagate through it.

        Returns:
            A tuple with ``None`` for ``run_function`` followed by one entry per
            forward argument: its gradient, or ``None`` when it has none.
        """
        # Rebuild the argument list from saved tensors and stored non-tensors.
        inputs = list(ctx.static_inputs)
        for saved, position in zip(ctx.saved_tensors, ctx.tensor_positions):
            detached = saved.detach()
            # needs_input_grad[0] belongs to run_function, hence the +1 offset.
            # Only inputs that required grad originally are tracked again; this
            # also avoids calling requires_grad_(True) on integer/bool tensors.
            detached.requires_grad_(ctx.needs_input_grad[position + 1])
            inputs[position] = detached

        # Replay the forward RNG state inside fork_rng so the caller's RNG
        # stream is left untouched once the recomputation is done.
        with torch.random.fork_rng(devices=ctx.cuda_devices):
            torch.set_rng_state(ctx.cpu_rng_state)
            for device_index, state in zip(ctx.cuda_devices, ctx.cuda_rng_states):
                torch.cuda.set_rng_state(state, device_index)
            with torch.enable_grad():
                outputs = ctx.run_function(*inputs)

        # Handle both single tensor and tuple outputs
        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # Pair each output with its own incoming gradient *before* filtering, so
        # an output that does not require grad cannot shift the pairing.
        pairs = [
            (out, grad)
            for out, grad in zip(outputs, grad_outputs)
            if isinstance(out, torch.Tensor) and out.requires_grad and grad is not None
        ]
        if pairs:
            tensors_with_grad, grad_tensors = zip(*pairs)
            torch.autograd.backward(tensors_with_grad, grad_tensors)

        # Collect gradients (None for non-tensors and tensors without a grad)
        grads = tuple(x.grad if isinstance(x, torch.Tensor) else None for x in inputs)

        return (None,) + grads



In [None]:

def checkpoint_function(function, *args):
    """Call ``function(*args)`` through ``CheckpointFunction``.

    Thin convenience wrapper around ``CheckpointFunction.apply``; the return
    value is whatever ``apply`` produces for the given callable and arguments.
    """
    checkpointed_output = CheckpointFunction.apply(function, *args)
    return checkpointed_output



In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    Args:
        normalized_shape: shape of the ``gamma`` / ``beta`` parameters.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps
        # gamma starts at one and beta at zero, i.e. an identity affine map.
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        """Normalise ``x`` along its last axis, then apply ``gamma`` and ``beta``."""
        reduce_last = {"dim": -1, "keepdim": True}
        centered = x - x.mean(**reduce_last)
        # Biased (population) variance, matching unbiased=False.
        std = torch.sqrt(x.var(unbiased=False, **reduce_last) + self.eps)
        return self.gamma * (centered / std) + self.beta


In [None]:

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a ``[batch, seq, d_model]`` input.

    Even feature columns hold sines and odd columns hold cosines of
    ``position * div_term``. Both even and odd ``d_model`` are supported.

    Args:
        d_model: feature dimension of the encoding.
        max_len: number of positions pre-computed in the ``pe`` buffer.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer: follows the module across devices but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the encoding table to ``x``."""
        return x + self.pe[:x.size(1), :]


In [None]:

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: model (input and output) feature dimension.
        num_heads: number of attention heads; must divide ``d_model``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Fail at construction time with a clear message instead of an opaque
        # ``.view`` shape error on the first forward pass.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape ``[batch, seq, d_model]`` to ``[batch, heads, seq, d_k]``."""
        return projected.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: input of shape ``[batch, seq, d_model]``.
            mask: optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            ``(output, attn_weights)`` where ``output`` is ``[batch, seq, d_model]``
            and ``attn_weights`` is ``[batch, heads, seq, seq]`` (after dropout).
        """
        batch_size, seq_len, d_model = x.shape
        Q = self._split_heads(self.W_q(x), batch_size, seq_len)
        K = self._split_heads(self.W_k(x), batch_size, seq_len)
        V = self._split_heads(self.W_v(x), batch_size, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf scores become exactly zero after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge the heads back into a single feature dimension.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        output = self.W_o(context)
        return output, attn_weights



In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output feature dimension.
        d_ff: hidden feature dimension.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, then project back to ``d_model``."""
        hidden = self.dropout(torch.relu(self.linear1(x)))
        return self.linear2(hidden)


In [None]:
class CheckpointedTransformerBlock(nn.Module):
    """Post-norm transformer block with optional manual gradient checkpointing.

    The block is self-attention followed by a feed-forward network, each
    wrapped in dropout, a residual connection and a LayerNorm. When
    checkpointing is requested (and the module is in training mode) the
    attention and feed-forward sub-layers are run through
    ``checkpoint_function``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout_p = dropout

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def attention_forward(self, x, mask):
        """Attention sub-layer returning only the output tensor (used for checkpointing)."""
        attn_output, attn_weights = self.attention(x, mask)
        return attn_output

    def feedforward_forward(self, x):
        """Feed-forward sub-layer (used for checkpointing)."""
        return self.feed_forward(x)

    def forward(self, x, mask=None, use_checkpointing=False):
        """Run the block.

        Args:
            x: input tensor passed to the attention sub-layer.
            mask: optional attention mask forwarded to the attention module.
            use_checkpointing: recompute sub-layers in backward instead of
                storing their activations; only honoured in training mode.

        Returns:
            ``(x, attn_weights)``. ``attn_weights`` is the attention module's
            weight tensor on the regular path and ``None`` on the checkpointed
            path, where the weights are deliberately not kept alive.
        """
        checkpoint_active = use_checkpointing and self.training

        if checkpoint_active:
            # Manual checkpointing: recompute in backward pass
            attn_output = checkpoint_function(self.attention_forward, x, mask)
            attn_weights = None
        else:
            # Keep the weights so callers asking for attention maps get them
            # instead of a placeholder None.
            attn_output, attn_weights = self.attention(x, mask)

        x = self.norm1(x + self.dropout(attn_output))

        if checkpoint_active:
            ff_output = checkpoint_function(self.feedforward_forward, x)
        else:
            ff_output = self.feedforward_forward(x)

        x = self.norm2(x + self.dropout(ff_output))

        return x, attn_weights


In [None]:
class DecoderTransformer(nn.Module):
    """Decoder-only transformer language model with optional pretrained embeddings.

    When ``pretrained_embeddings`` is given, the embedding table takes its
    width from that matrix and a linear layer projects it to ``d_model``;
    otherwise the table is ``d_model`` wide and the projection is an identity.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads,
                 d_ff, max_seq_len, dropout=0.1, pretrained_embeddings=None):
        super().__init__()
        # Hyper-parameters are kept on the instance so get_config() can report them.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len
        self.dropout_p = dropout

        has_pretrained = pretrained_embeddings is not None

        # Embedding width follows the pretrained matrix when one is supplied.
        self.embedding_dim = pretrained_embeddings.shape[1] if has_pretrained else d_model
        self.embedding = nn.Embedding(vocab_size, self.embedding_dim)

        if has_pretrained:
            self.embedding.weight.data.copy_(pretrained_embeddings)
            # Bridge from the pretrained width to the model width.
            self.embedding_proj = nn.Linear(self.embedding_dim, d_model)
        else:
            self.embedding_proj = nn.Identity()

        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)

        blocks = []
        for _ in range(num_layers):
            blocks.append(CheckpointedTransformerBlock(d_model, num_heads, d_ff, dropout))
        self.layers = nn.ModuleList(blocks)

        self.norm = LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_causal_mask(self, seq_len, device):
        """Return a ``[1, 1, seq_len, seq_len]`` lower-triangular mask of ones."""
        lower_triangle = torch.ones(seq_len, seq_len, device=device).tril()
        return lower_triangle[None, None]

    def forward(self, x, return_attention=False, use_checkpointing=False):
        """
        Compute next-token logits for a batch of token ids.

        Args:
            x: Input tensor [batch_size, seq_len]
            return_attention: Whether to also return the per-layer attention entries
            use_checkpointing: Forwarded to every transformer block

        Returns:
            ``logits``, or ``(logits, attention_weights)`` when
            ``return_attention`` is true; ``attention_weights`` holds whatever
            each block returned as its second value.
        """
        _, seq_len = x.shape
        mask = self.create_causal_mask(seq_len, x.device)

        # Embed, scale by sqrt(embedding width), then map to the model width.
        hidden = self.embedding(x) * math.sqrt(self.embedding_dim)
        hidden = self.embedding_proj(hidden)
        hidden = self.dropout(self.pos_encoding(hidden))

        attention_weights = []
        for block in self.layers:
            hidden, block_attention = block(hidden, mask, use_checkpointing=use_checkpointing)
            if return_attention:
                attention_weights.append(block_attention)

        logits = self.output_projection(self.norm(hidden))

        return (logits, attention_weights) if return_attention else logits

    def get_config(self):
        """Return the constructor keyword arguments needed to rebuild this model.

        ``pretrained_embeddings`` is always reported as ``None``, so a model
        rebuilt from this config does not reuse pretrained embeddings.
        """
        return dict(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            d_ff=self.d_ff,
            max_seq_len=self.max_seq_len,
            dropout=self.dropout_p,
            pretrained_embeddings=None,  # Don't reuse pretrained embeddings in experiments
        )

In [None]:

class Vocabulary:
    """Word-level vocabulary with special tokens and optional FastText lookup.

    The four special tokens are registered first, in the order
    ``<pad>``, ``<sos>``, ``<eos>``, ``<unk>``, so they receive ids 0-3.

    Args:
        fasttext_model: optional mapping-like object supporting ``word in model``
            and ``model[word]``; used by ``create_embedding_matrix``.
    """

    def __init__(self, fasttext_model=None):
        self.word2idx = {}
        self.idx2word = {}
        self.word_counts = Counter()
        self.PAD_TOKEN = '<pad>'
        self.SOS_TOKEN = '<sos>'
        self.EOS_TOKEN = '<eos>'
        self.UNK_TOKEN = '<unk>'
        self.add_word(self.PAD_TOKEN)
        self.add_word(self.SOS_TOKEN)
        self.add_word(self.EOS_TOKEN)
        self.add_word(self.UNK_TOKEN)
        self.fasttext_model = fasttext_model

    def add_word(self, word):
        """Register ``word`` (if new) and increment its occurrence count."""
        if word not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word
        self.word_counts[word] += 1

    def __len__(self):
        return len(self.word2idx)

    def encode(self, text):
        """Tokenize ``text`` and map each token to its id (``<unk>`` id if unseen)."""
        unk_id = self.word2idx[self.UNK_TOKEN]
        return [self.word2idx.get(token, unk_id) for token in self.tokenize(text)]

    def decode(self, indices):
        """Turn a sequence of ids back into a space-joined string.

        ``<pad>`` and ``<sos>`` ids are skipped and decoding stops at the first
        ``<eos>``. ``indices`` may contain plain ints or integer-like scalars
        such as 0-d tensors; each id is converted with ``int()`` first, because
        a tensor scalar does not hash like an int and would otherwise miss
        every ``idx2word`` lookup.
        """
        skip_ids = {self.word2idx[self.PAD_TOKEN], self.word2idx[self.SOS_TOKEN]}
        eos_id = self.word2idx[self.EOS_TOKEN]
        words = []
        for idx in indices:
            idx = int(idx)
            if idx in skip_ids:
                continue
            if idx == eos_id:
                break
            words.append(self.idx2word.get(idx, self.UNK_TOKEN))
        return ' '.join(words)

    def tokenize(self, text):
        """Lower-case ``text`` and split it into word tokens and ``.,!?;`` marks."""
        return re.findall(r'\b\w+\b|[.,!?;]', text.lower())

    def create_embedding_matrix(self, embedding_dim=300):
        """Build a ``[len(vocab), embedding_dim]`` embedding matrix.

        Rows start as small Gaussian noise; rows of words found in
        ``fasttext_model`` are overwritten with that model's vector.

        Args:
            embedding_dim: width of the matrix; must match the width of the
                vectors returned by ``fasttext_model`` (default 300).
        """
        embedding_matrix = torch.randn(len(self.word2idx), embedding_dim) * 0.01
        if self.fasttext_model is not None:
            found = 0
            for word, idx in self.word2idx.items():
                if word in self.fasttext_model:
                    embedding_matrix[idx] = torch.tensor(self.fasttext_model[word])
                    found += 1
            print(f"Found {found}/{len(self.word2idx)} words in FastText")
        return embedding_matrix

    def save(self, path):
        """Write the mappings and counts to ``path`` as JSON."""
        with open(path, 'w') as f:
            json.dump({
                'word2idx': self.word2idx,
                'idx2word': {int(k): v for k, v in self.idx2word.items()},
                'word_counts': dict(self.word_counts)
            }, f)

    @classmethod
    def load(cls, path, fasttext_model=None):
        """Rebuild a vocabulary from a JSON file written by ``save``."""
        vocab = cls(fasttext_model)
        with open(path, 'r') as f:
            data = json.load(f)
        vocab.word2idx = data['word2idx']
        # JSON object keys are strings; restore integer ids.
        vocab.idx2word = {int(k): v for k, v in data['idx2word'].items()}
        vocab.word_counts = Counter(data['word_counts'])
        return vocab


In [None]:
class TinyStoriesDataset(Dataset):
    """Sliding-window token sequences built from raw story texts.

    Each text is wrapped in ``<sos>`` / ``<eos>`` ids. A window of
    ``context_length + 1`` ids is taken at every start position except the
    last one, and windows running past the end are right-padded with the
    ``<pad>`` id.

    Args:
        texts: iterable of strings.
        vocab: object providing ``word2idx``, ``encode`` and the special-token names.
        context_length: number of input positions per sequence.
        max_samples: if truthy, only the first ``max_samples`` texts are used.
    """

    def __init__(self, texts, vocab, context_length, max_samples=None):
        self.vocab = vocab
        self.context_length = context_length
        self.sequences = []

        window = context_length + 1  # inputs plus the shifted-by-one target

        print("Preparing dataset...")
        for text_idx, text in enumerate(tqdm(texts)):
            if max_samples and text_idx >= max_samples:
                break

            ids = [
                vocab.word2idx[vocab.SOS_TOKEN],
                *vocab.encode(text),
                vocab.word2idx[vocab.EOS_TOKEN],
            ]

            for start in range(len(ids) - 1):
                # Slicing clamps at the end of the list, so short tails are fine.
                chunk = ids[start:start + window]
                shortfall = window - len(chunk)
                if shortfall > 0:
                    chunk = chunk + [vocab.word2idx[vocab.PAD_TOKEN]] * shortfall
                self.sequences.append(chunk)

        print(f"Created {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sequence ``idx`` as a ``torch.long`` tensor."""
        return torch.tensor(self.sequences[idx], dtype=torch.long)



In [None]:
# Single source of truth for the baseline run; every later cell reads from here.
CONFIG = dict(
    # Run identity
    name='baseline',
    description='Standard baseline configuration from assignment',
    # Model shape
    context_length=64,
    num_layers=3,
    num_heads=8,
    d_model=296,
    d_ff=1184,
    dropout=0.1,
    # Optimisation
    batch_size=32,
    learning_rate=3e-4,
    num_epochs=10,
    # Data subset sizes (number of stories)
    max_train_samples=15000,
    max_val_samples=5000,
    # Output locations
    save_dir='checkpoints/baseline',
    plot_dir='plots/baseline',
)

In [None]:
# Output directories and device
os.makedirs(CONFIG['save_dir'], exist_ok=True)
os.makedirs(CONFIG['plot_dir'], exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"\nConfiguration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# Load FastText
print("\n" + "="*50)
print("Loading FastText embeddings...")
print("="*50)
fasttext_model = api.load('fasttext-wiki-news-subwords-300')

# Load Dataset
print("\n" + "="*50)
print("Loading TinyStories dataset...")
print("="*50)
dataset = load_dataset("roneneldan/TinyStories")

# Materialise the text subsets once with a slice instead of indexing the
# dataset row by row; the same lists feed both the vocabulary and the datasets.
num_samples = min(CONFIG['max_train_samples'], len(dataset['train']))
num_val_samples = min(CONFIG['max_val_samples'], len(dataset['validation']))
train_texts = dataset['train'][:num_samples]['text']
val_texts = dataset['validation'][:num_val_samples]['text']

print("\n" + "="*50)
print("Building vocabulary...")
print("="*50)
vocab_path = f"{CONFIG['save_dir']}/vocab.json"

# NOTE: a cached vocabulary is reused as-is; delete vocab.json after changing
# max_train_samples, otherwise the cache will not reflect the new subset.
if os.path.exists(vocab_path):
    print("Loading existing vocabulary...")
    vocab = Vocabulary.load(vocab_path, fasttext_model)
else:
    vocab = Vocabulary(fasttext_model)
    # Build vocabulary from training data
    for text in tqdm(train_texts, desc="Building vocabulary"):
        for word in vocab.tokenize(text):
            vocab.add_word(word)
    vocab.save(vocab_path)

print(f"Vocabulary size: {len(vocab)}")

# Create Datasets
print("\n" + "="*50)
print("Creating datasets...")
print("="*50)

train_dataset = TinyStoriesDataset(
    train_texts,
    vocab,
    CONFIG['context_length'],
    CONFIG['max_train_samples']
)

val_dataset = TinyStoriesDataset(
    val_texts,
    vocab,
    CONFIG['context_length'],
    CONFIG['max_val_samples']
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                          shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                        shuffle=False, num_workers=0)

# Initialize Model
print("\n" + "="*50)
print("Initializing model...")
print("="*50)
embedding_matrix = vocab.create_embedding_matrix()


Using device: cuda

Configuration:
  name: baseline
  description: Standard baseline configuration from assignment
  context_length: 64
  num_layers: 3
  num_heads: 8
  d_model: 296
  d_ff: 1184
  dropout: 0.1
  batch_size: 32
  learning_rate: 0.0003
  num_epochs: 10
  max_train_samples: 15000
  max_val_samples: 5000
  save_dir: checkpoints/baseline
  plot_dir: plots/baseline

Loading FastText embeddings...

Loading TinyStories dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]


Building vocabulary...


Building vocabulary: 100%|██████████| 15000/15000 [00:02<00:00, 5373.11it/s]


Vocabulary size: 10598

Creating datasets...
Preparing dataset...


100%|██████████| 15000/15000 [00:11<00:00, 1333.26it/s]


Created 3083375 sequences
Preparing dataset...


100%|██████████| 5000/5000 [00:02<00:00, 1677.80it/s]


Created 925828 sequences

Initializing model...
Found 9972/10598 words in FastText


In [None]:

# Build the baseline model from CONFIG, seeded with the FastText-based
# embedding matrix, and move it to the selected device.
model = DecoderTransformer(
    vocab_size=len(vocab),
    max_seq_len=CONFIG['context_length'],
    pretrained_embeddings=embedding_matrix,
    **{key: CONFIG[key] for key in ('d_model', 'num_layers', 'num_heads', 'd_ff', 'dropout')},
)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=CONFIG['learning_rate'])
# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx[vocab.PAD_TOKEN])

In [None]:
def train_with_checkpointing(
    model, dataloader, optimizer, criterion, device, use_checkpointing=False, epoch=1
):
    """Train ``model`` for one pass over ``dataloader`` and report basic statistics.

    Each batch is split into inputs (all but the last position) and targets
    (all but the first), gradients are clipped to a norm of 1.0, and the
    ``use_checkpointing`` flag is forwarded to the model.

    Returns:
        dict with keys ``loss`` (mean batch loss), ``time`` (seconds),
        ``peak_memory_gb`` (peak minus starting CUDA allocation, 0 without
        CUDA) and ``batches`` (number of batches processed).
    """
    model.train()

    # Reset CUDA memory counters so the peak reflects this epoch only.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        start_memory = torch.cuda.memory_allocated()

    mode_tag = 'CP' if use_checkpointing else 'No CP'
    start_time = time.time()
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch} ({mode_tag})")

    batch_losses = []
    for batch in progress_bar:
        batch = batch.to(device)
        inputs, targets = batch[:, :-1], batch[:, 1:]

        optimizer.zero_grad()

        # The flag is handed to the model, which passes it on to its blocks.
        logits = model(inputs, use_checkpointing=use_checkpointing)

        flat_logits = logits.reshape(-1, logits.size(-1))
        loss = criterion(flat_logits, targets.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        step_loss = loss.item()
        batch_losses.append(step_loss)
        progress_bar.set_postfix({"loss": f"{step_loss:.4f}"})

    epoch_time = time.time() - start_time

    memory_used = 0
    if cuda_present:
        peak_memory = torch.cuda.max_memory_allocated()
        memory_used = (peak_memory - start_memory) / (1024**3)  # bytes -> GB
        torch.cuda.empty_cache()

    num_batches = len(batch_losses)
    return {
        "loss": sum(batch_losses) / max(num_batches, 1),
        "time": epoch_time,
        "peak_memory_gb": memory_used,
        "batches": num_batches
    }



In [None]:
def experiment_gradient_checkpointing(
    model, train_loader, optimizer, criterion, device, num_epochs=1
):
    """Compare training with and without gradient checkpointing.

    Both runs start from the same model and optimizer state: the state is
    snapshotted once up front and restored after each run, so the caller's
    model and optimizer end up back in their initial state.

    Args:
        model: model to train; restored to its initial weights afterwards.
        train_loader: training batches passed to ``train_with_checkpointing``.
        optimizer: optimizer bound to ``model``; its state is restored as well.
        criterion: loss function.
        device: torch device.
        num_epochs: number of epochs per run.

    Returns:
        dict with keys ``"without_cp"`` and ``"with_cp"``, each a list of the
        per-epoch result dicts from ``train_with_checkpointing``.
    """
    import copy  # local import: only this experiment needs it

    print("\nGradient Checkpointing Experiment...")

    results = {"without_cp": [], "with_cp": []}

    # state_dict() hands back references to the live tensors, so training would
    # silently change an un-copied "snapshot" and make the restore a no-op.
    initial_model_state = copy.deepcopy(model.state_dict())
    initial_opt_state = copy.deepcopy(optimizer.state_dict())

    for use_cp in [False, True]:
        cp_str = "with_cp" if use_cp else "without_cp"
        print(f"\n{'='*50}")
        print(f"{'With' if use_cp else 'Without'} Gradient Checkpointing")
        print(f"{'='*50}")

        epoch_results = []

        for epoch in range(1, num_epochs + 1):
            result = train_with_checkpointing(
                model,
                train_loader,
                optimizer,
                criterion,
                device,
                use_checkpointing=use_cp,
                epoch=epoch,
            )

            epoch_results.append(result)

            print(f"Epoch {epoch}:")
            print(f"  Loss: {result['loss']:.4f}")
            print(f"  Time: {result['time']:.2f}s")
            print(f"  Peak Memory: {result['peak_memory_gb']:.2f} GB")

        results[cp_str] = epoch_results

        # Restore state so the next run (and the caller) sees the initial weights.
        model.load_state_dict(initial_model_state)
        # Pass a fresh copy so the optimizer can never alias the snapshot.
        optimizer.load_state_dict(copy.deepcopy(initial_opt_state))

    return results



In [None]:

def _run_training_phase(model, optimizer, criterion, train_dataloader, val_dataloader,
                        device, num_epochs, use_checkpointing, config):
    """Train and validate ``model`` for ``num_epochs`` in one checkpointing mode.

    Returns a dict with ``config``, ``checkpointing_enabled``, the per-epoch
    records under ``epochs`` and their averages under ``summary``.
    """
    tag = "CP" if use_checkpointing else "No CP"
    phase_results = {
        "config": config,
        "checkpointing_enabled": use_checkpointing,
        "epochs": []
    }

    for epoch in range(1, num_epochs + 1):
        train_stats = train_with_checkpointing(
            model,
            train_dataloader,
            optimizer,
            criterion,
            device,
            use_checkpointing=use_checkpointing,
            epoch=epoch
        )

        # Validation (checkpointing is never needed without a backward pass)
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_dataloader:
                batch = batch.to(device)
                inputs = batch[:, :-1]
                targets = batch[:, 1:]
                logits = model(inputs, use_checkpointing=False)
                loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
                val_loss += loss.item()

        # max(..., 1) keeps an empty validation loader from dividing by zero.
        val_loss /= max(len(val_dataloader), 1)

        phase_results["epochs"].append({
            "epoch": epoch,
            "train_loss": train_stats["loss"],
            "val_loss": val_loss,
            "time_seconds": train_stats["time"],
            "peak_memory_gb": train_stats["peak_memory_gb"]
        })

        print(f"\nEpoch {epoch} ({tag}): Train Loss={train_stats['loss']:.4f}, "
              f"Val Loss={val_loss:.4f}, Time={train_stats['time']:.2f}s, "
              f"Memory={train_stats['peak_memory_gb']:.2f}GB\n")

    epochs = phase_results["epochs"]
    denom = max(len(epochs), 1)  # guards num_epochs == 0
    total_time = sum(r["time_seconds"] for r in epochs)
    phase_results["summary"] = {
        "avg_train_loss": sum(r["train_loss"] for r in epochs) / denom,
        "avg_val_loss": sum(r["val_loss"] for r in epochs) / denom,
        "avg_time_seconds": total_time / denom,
        "avg_memory_gb": sum(r["peak_memory_gb"] for r in epochs) / denom,
        "total_time_seconds": total_time
    }
    return phase_results


def run_checkpoint_experiment(
    model,
    train_dataloader,
    val_dataloader,
    device,
    num_epochs=3,
    learning_rate=1e-4,
    output_dir="result",
    pad_token_id=0
):
    """
    Run checkpoint comparison experiment and save results

    The passed ``model`` is trained in place without checkpointing. A deep copy
    taken *before* any training is then trained with checkpointing, so both
    runs share the same architecture and initial weights.

    Args:
        model: The transformer model to train
        train_dataloader: Training data loader
        val_dataloader: Validation data loader
        device: torch device
        num_epochs: Number of epochs to train
        learning_rate: Learning rate for optimizer
        output_dir: Directory to save results
        pad_token_id: Target id excluded from the loss (the vocabulary's
            ``<pad>`` id, 0 by default); pass ``None`` to score every position

    Returns:
        The comparison dict that is also written to ``result.json``.
    """
    import copy  # local import: only needed for the pre-training model clone

    # Create output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    config = {
        "num_epochs": num_epochs,
        "learning_rate": learning_rate,
        "device": str(device),
        "model_params": sum(p.numel() for p in model.parameters()),
    }

    # Padding targets must not contribute to the loss.
    if pad_token_id is None:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    # Clone before training so the checkpointed run starts from identical
    # weights and the same architecture as the baseline run.
    model_cp = copy.deepcopy(model).to(device)

    print("=" * 60)
    print("Running Training WITHOUT Gradient Checkpointing")
    print("=" * 60)

    model_no_cp = model
    optimizer_no_cp = torch.optim.Adam(model_no_cp.parameters(), lr=learning_rate)
    results_no_cp = _run_training_phase(
        model_no_cp, optimizer_no_cp, criterion, train_dataloader, val_dataloader,
        device, num_epochs, False, config
    )
    avg_memory_no_cp = results_no_cp["summary"]["avg_memory_gb"]
    avg_time_no_cp = results_no_cp["summary"]["avg_time_seconds"]

    # Save no checkpoint results
    no_cp_path = os.path.join(output_dir, "no_checkpoint.json")
    with open(no_cp_path, 'w') as f:
        json.dump(results_no_cp, f, indent=2)
    print(f"\n✓ No checkpoint results saved to: {no_cp_path}\n")

    print("\n" + "=" * 60)
    print("Running Training WITH Gradient Checkpointing")
    print("=" * 60)

    optimizer_cp = torch.optim.Adam(model_cp.parameters(), lr=learning_rate)
    results_cp = _run_training_phase(
        model_cp, optimizer_cp, criterion, train_dataloader, val_dataloader,
        device, num_epochs, True, config
    )
    avg_memory_cp = results_cp["summary"]["avg_memory_gb"]
    avg_time_cp = results_cp["summary"]["avg_time_seconds"]

    # Save checkpoint results
    cp_path = os.path.join(output_dir, "checkpoint.json")
    with open(cp_path, 'w') as f:
        json.dump(results_cp, f, indent=2)
    print(f"\n✓ Checkpoint results saved to: {cp_path}\n")
    # Persist the model that was actually trained with checkpointing.
    torch.save(model_cp.state_dict(), 'model_with_cp.pt')

    # Create comparison summary
    comparison = {
        "config": config,
        "without_checkpointing": results_no_cp["summary"],
        "with_checkpointing": results_cp["summary"],
        "comparison": {
            "memory_savings_gb": avg_memory_no_cp - avg_memory_cp,
            "memory_savings_percent": ((avg_memory_no_cp - avg_memory_cp) / avg_memory_no_cp * 100) if avg_memory_no_cp > 0 else 0,
            "time_overhead_seconds": avg_time_cp - avg_time_no_cp,
            "time_overhead_percent": ((avg_time_cp - avg_time_no_cp) / avg_time_no_cp * 100) if avg_time_no_cp > 0 else 0
        }
    }

    # Save comparison results
    comparison_path = os.path.join(output_dir, "result.json")
    with open(comparison_path, 'w') as f:
        json.dump(comparison, f, indent=2)

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Memory Savings: {comparison['comparison']['memory_savings_gb']:.2f} GB "
          f"({comparison['comparison']['memory_savings_percent']:.1f}%)")
    print(f"Time Overhead: {comparison['comparison']['time_overhead_seconds']:.2f}s "
          f"({comparison['comparison']['time_overhead_percent']:.1f}%)")
    print(f"\n✓ Comparison summary saved to: {comparison_path}")
    print(f"\nAll results saved in: {output_dir}/")
    print(f"  - no_checkpoint.json")
    print(f"  - checkpoint.json")
    print(f"  - result.json (comparison)")

    return comparison


# Launch the with/without-checkpointing comparison for a single epoch;
# the JSON reports are written under "result/".
results = run_checkpoint_experiment(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    device=device,
    num_epochs=1,
    learning_rate=1e-4,
    output_dir="result",
)


Running Training WITHOUT Gradient Checkpointing


Epoch 1 (No CP): 100%|██████████| 96356/96356 [1:04:56<00:00, 24.73it/s, loss=1.6331]



Epoch 1 (No CP): Train Loss=2.0136, Val Loss=2.2524, Time=3896.02s, Memory=0.60GB


✓ No checkpoint results saved to: result/no_checkpoint.json


Running Training WITH Gradient Checkpointing
Running Training WITH Gradient Checkpointing


Epoch 1 (CP):   7%|▋         | 7144/96356 [05:29<1:08:20, 21.75it/s, loss=3.2971]