In [1]:
!pip install matplotlib




In [5]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer
import math
from tqdm import tqdm
import numpy as np 

# GPT CLASS Inspired by GPT architecture. GPT 2 in particular
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, activation_fn):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head, activation_fn) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        
        for block in self.blocks:
            x = block(x)
        
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        loss = None
        if targets is not None:
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), 
                targets.view(-1),
                ignore_index=-1
            )
        
        return logits, loss

class Block(nn.Module):
    def __init__(self, n_embd, n_head, activation_fn):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd, activation_fn)
        
    def forward(self, x):
        x = x + self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0]
        x = x + self.mlp(self.ln2(x))
        return x

class MLP(nn.Module):
    def __init__(self, n_embd, activation_fn):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.act = activation_fn
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, block_size):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.examples = []
        
        for text in tqdm(texts):
            if text.strip():
                tokens = tokenizer.encode(text, truncation=True, max_length=block_size+1)
                if len(tokens) > 1:
                    self.examples.append(tokens)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        tokens = self.examples[idx]
        if len(tokens) < self.block_size + 1:
            tokens = tokens + [self.tokenizer.eos_token_id] * (self.block_size + 1 - len(tokens))
        else:
            tokens = tokens[:self.block_size + 1]
        
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        return x, y

def analyze_gradient_flow(model):
    layers = []
    avg_grads = []
    max_grads = []
    
    for name, param in model.named_parameters():
        if param.requires_grad and "weight" in name and param.grad is not None:
            layers.append(name)
            avg_grads.append(param.grad.abs().mean().item())
            max_grads.append(param.grad.abs().max().item())
    
    return layers, avg_grads, max_grads

def train_model_with_gradient_analysis(model, train_loader, val_loader, device, epochs=5, lr=3e-4):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    
    train_losses = []
    val_losses = []
    gradient_stats = []
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch_idx, (x, y) in enumerate(progress_bar):
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            logits, loss = model(x, y)
            loss.backward()
            
            if epoch == 0 and batch_idx < 10:
                layers, avg_grads, max_grads = analyze_gradient_flow(model)
                gradient_stats.append({
                    'epoch': epoch,
                    'batch': batch_idx,
                    'avg_grads': avg_grads.copy()
                })
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})
            
        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, loss = model(x, y)
                total_val_loss += loss.item()
        
        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")
    
    return train_losses, val_losses, gradient_stats

if __name__ == "__main__":
    device = 'cpu'
    
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    
    block_size = 256
    train_dataset = TextDataset(dataset['train']['text'][:1000], tokenizer, block_size)
    val_dataset = TextDataset(dataset['validation']['text'][:200], tokenizer, block_size)
    
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    
    activations = {
        'ReLU': nn.ReLU(),
        'GELU': nn.GELU(),
        'SiLU': nn.SiLU(),
        'Mish': nn.Mish()
    }
    
    results = {}

    
    for name, act_fn in activations.items():

        
        print("Current model is ", name)
        
        model = MiniGPT(
            vocab_size=50257,
            n_embd=256,
            n_head=4,
            n_layer=4,
            block_size=256,
            activation_fn=act_fn
        )
        
        train_losses, val_losses, gradient_stats = train_model_with_gradient_analysis(model, train_loader, val_loader, device, epochs=5)
        
        results[name] = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'gradient_stats': gradient_stats,
            'model': model
        }
        
        torch.save(model.state_dict(), f"model_{name}.pt")
    
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    for name, data in results.items():
        plt.plot(data['train_losses'], label=name, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Train Loss')
    plt.title('Training Loss Comparison')
    plt.legend()
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    for name, data in results.items():
        plt.plot(data['val_losses'], label=name, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Validation Loss')
    plt.title('Validation Loss Comparison')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig('activation_comparison.png', dpi=150)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    for idx, (name, model_data) in enumerate(results.items()):
        ax = axes[idx // 2, idx % 2]
        
        grad_stats = model_data['gradient_stats']
        
        if len(grad_stats) > 0:
            avg_by_layer = np.mean([s['avg_grads'] for s in grad_stats], axis=0)
            
            ax.bar(range(len(avg_by_layer)), avg_by_layer)
            ax.set_title(f'{name} - Gradient Flow')
            ax.set_xlabel('Layer Index')
            ax.set_ylabel('Average Gradient Magnitude')
            ax.set_yscale('log')
            ax.grid(True, alpha=0.3)
        else:
            ax.text(0.5, 0.5, 'No gradient data', ha='center', va='center')
    
    plt.tight_layout()
    plt.savefig('gradient_flow_comparison.png', dpi=150)

100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 3207.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 4157.12it/s]


Current model is  ReLU


Epoch 1/5: 100%|██████████████████████████████████████████████████████████| 162/162 [00:43<00:00,  3.69it/s, loss=3.97]


Epoch 1: Train Loss = 3.0651, Val Loss = 2.8021


Epoch 2/5:  51%|█████████████████████████████▊                             | 82/162 [00:23<00:22,  3.54it/s, loss=1.17]


KeyboardInterrupt: 