# Part 2: Model Training (All 8 Models)

## Installation Requirements

In [1]:
# Install PyTorch and other dependencies
!pip install torch torchvision numpy matplotlib tqdm scikit-learn




## Import Libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.amp import autocast, GradScaler
import numpy as np
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm
import os
import time
import glob

# Select the compute device: CUDA when an accelerator is visible, otherwise the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')
print(f"Using device: {device}")

# Report free/total GPU memory (bytes -> GB) so it is clear how much room the models have.
if _has_cuda:
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"GPU Memory - Free: {free_bytes/1e9:.1f} GB / Total: {total_bytes/1e9:.1f} GB")


Using device: cuda
GPU Memory - Free: 42.0 GB / Total: 42.4 GB


## Load Preprocessed Data

In [3]:
def _load_processed(path, label):
    """Announce and unpickle one preprocessed dataset file.

    NOTE(security): pickle.load can execute arbitrary code; only open files
    produced by this project's own preprocessing step.
    """
    print(f"Loading {label} dataset...")
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Wikipedia corpus
wiki_data = _load_processed('wikipedia_processed.pkl', 'Wikipedia')
wiki_X, wiki_y = wiki_data['X'], wiki_data['y']
wiki_vocab_size = wiki_data['vocab_size']
wiki_word_to_idx, wiki_idx_to_word = wiki_data['word_to_idx'], wiki_data['idx_to_word']
print(f"Wikipedia - Vocab size: {wiki_vocab_size}, Samples: {len(wiki_X)}")

# Linux kernel corpus
linux_data = _load_processed('linux_processed.pkl', 'Linux kernel')
linux_X, linux_y = linux_data['X'], linux_data['y']
linux_vocab_size = linux_data['vocab_size']
linux_word_to_idx, linux_idx_to_word = linux_data['word_to_idx'], linux_data['idx_to_word']
print(f"Linux - Vocab size: {linux_vocab_size}, Samples: {len(linux_X)}")


Loading Wikipedia dataset...
Wikipedia - Vocab size: 78683, Samples: 674773
Loading Linux kernel dataset...
Linux - Vocab size: 90882, Samples: 779603


## Checkpoint Verification Utility

In [4]:
def verify_checkpoints():
    """List and verify all model checkpoints"""
    print("\n" + "="*80)
    print("CHECKPOINT VERIFICATION")
    print("="*80)
    
    checkpoints = sorted(glob.glob("*_best.pth"))
    
    if not checkpoints:
        print("No checkpoints found. All models need training.")
        return {}
    
    verified = {}
    print(f"\nFound {len(checkpoints)} checkpoint files:\n")
    print(f"{'Checkpoint':<40} {'Size (MB)':<12} {'Status'}")
    print("-" * 70)
    
    for ckpt_path in checkpoints:
        size_mb = os.path.getsize(ckpt_path) / 1e6
        model_name = ckpt_path.replace('_best.pth', '')
        
        try:
            # Try loading
            ckpt = torch.load(ckpt_path, map_location='cpu')
            required_keys = ['epoch', 'model_state_dict', 'val_loss', 'val_accuracy']
            
            if all(k in ckpt for k in required_keys):
                status = f"‚úÖ OK (epoch {ckpt['epoch']+1})"
                verified[model_name] = {
                    'path': ckpt_path,
                    'epoch': ckpt['epoch'],
                    'val_loss': ckpt['val_loss'],
                    'val_accuracy': ckpt['val_accuracy']
                }
            else:
                status = "‚ö†Ô∏è Missing keys"
        except Exception as e:
            status = f"‚ùå Corrupted: {str(e)[:30]}"
        
        print(f"{ckpt_path:<40} {size_mb:>10.1f}  {status}")
    
    print(f"\n‚úÖ {len(verified)}/{8} models have valid checkpoints")
    return verified

# Run verification
verified_models = verify_checkpoints()



CHECKPOINT VERIFICATION

Found 8 checkpoint files:

Checkpoint                               Size (MB)    Status
----------------------------------------------------------------------
linux_32d-ReLU_best.pth                      1154.7  ✅ OK (epoch 1)
linux_32d-tanh_best.pth                      1154.7  ✅ OK (epoch 2)
linux_64d-ReLU_best.pth                      1191.6  ✅ OK (epoch 1)
linux_64d-tanh_best.pth                      1191.6  ✅ OK (epoch 2)
wiki_32d-ReLU_best.pth                       1000.0  ✅ OK (epoch 1)
wiki_32d-tanh_best.pth                       1000.0  ✅ OK (epoch 1)
wiki_64d-ReLU_best.pth                       1032.2  ✅ OK (epoch 1)
wiki_64d-tanh_best.pth                       1032.2  ✅ OK (epoch 1)

✅ 8/8 models have valid checkpoints


## Load or Initialize Results Dictionary

In [5]:
# Load previously saved results if they exist.
# NOTE(security): pickle.load can execute arbitrary code; only load result files
# that this notebook produced itself.
all_models = {}

if os.path.exists('all_models_results.pkl'):
    try:
        with open('all_models_results.pkl', 'rb') as f:
            loaded = pickle.load(f)
        if not isinstance(loaded, dict):
            raise pickle.UnpicklingError(f"expected a dict of results, got {type(loaded).__name__}")
        all_models = loaded
        print(f"\n✅ Loaded {len(all_models)} models from all_models_results.pkl")
        for name in all_models:
            print(f"   - {name}")
    # Besides truncation (EOFError / UnpicklingError), a stale pickle can reference
    # classes or modules that are not defined in this kernel yet (AttributeError /
    # ImportError); treat every such file as unusable and move it aside.
    except (EOFError, pickle.UnpicklingError, AttributeError, ImportError, ValueError) as e:
        print(f"\n⚠️ Corrupted pickle file found: {e}")
        print("Renaming to backup and starting fresh...")
        os.rename('all_models_results.pkl', f'all_models_results_corrupted_{int(time.time())}.pkl')
        all_models = {}
else:
    print("\nNo previous all_models_results.pkl found. Starting fresh.")


✅ Loaded 0 models from all_models_results.pkl


In [6]:
# Rebuild all_models entries from existing checkpoints when the summary file was empty/missing.
# - Entries already loaded from all_models_results.pkl are kept untouched; only models
#   that have a checkpoint on disk but no entry yet are added.
# - Vocab data comes from the dataset cell above (no second read of the processed pickles).
# - Metrics come from `verified_models`, which verify_checkpoints filled from the checkpoints.

def model_present(name):
    """Return True if `<name>_best.pth` exists in the working directory."""
    return os.path.exists(f"{name}_best.pth")

# (name prefix, dataset_type, vocab_size, word_to_idx, idx_to_word)
_rebuild_datasets = [
    ("wiki", "wikipedia", wiki_vocab_size, wiki_word_to_idx, wiki_idx_to_word),
    ("linux", "linux", linux_vocab_size, linux_word_to_idx, linux_idx_to_word),
]
# (embedding_dim, activation, name suffix), in the same order as the training configs
_rebuild_variants = [
    (64, 'relu', '64d-ReLU'),
    (32, 'relu', '32d-ReLU'),
    (64, 'tanh', '64d-tanh'),
    (32, 'tanh', '32d-tanh'),
]

rebuilt = {}
for prefix, dataset_type, vocab_size, word_to_idx, idx_to_word in _rebuild_datasets:
    for embedding_dim, activation, suffix in _rebuild_variants:
        name = f"{prefix}_{suffix}"
        if name in all_models or not model_present(name):
            continue
        ckpt_info = verified_models.get(name, {})
        rebuilt[name] = {
            'model': None,
            'final_val_loss': ckpt_info.get('val_loss'),
            'final_val_accuracy': ckpt_info.get('val_accuracy'),
            'config': {'name': name, 'embedding_dim': embedding_dim, 'activation': activation},
            'vocab_size': vocab_size,
            'word_to_idx': word_to_idx,
            'idx_to_word': idx_to_word,
            'dataset_type': dataset_type,
        }

# Merge into all_models (new names only, see above)
all_models.update(rebuilt)
print(f"Rebuilt {len(rebuilt)} entries from existing checkpoints. Total in memory: {len(all_models)}")

# Save immediately so future runs load correctly. Write to a temp file and swap it in,
# so an interrupted dump cannot leave a truncated results file behind.
_tmp_results_path = 'all_models_results.pkl.tmp'
with open(_tmp_results_path, 'wb') as f:
    pickle.dump(all_models, f)
os.replace(_tmp_results_path, 'all_models_results.pkl')
print("Saved rebuilt all_models_results.pkl")


Rebuilt 8 entries from existing checkpoints. Total in memory: 8
Saved rebuilt all_models_results.pkl


## Dataset Class

In [7]:
class WordPredictionDataset(Dataset):
    """PyTorch Dataset for word prediction.

    Pairs each context window ``X[i]`` with its next-word target ``y[i]``. Both
    are converted to int64 tensors once, at construction time.

    Raises:
        ValueError: if X and y have different lengths (otherwise the mismatch
            would only surface later as an IndexError inside a DataLoader worker).
    """
    def __init__(self, X, y):
        self.X = torch.LongTensor(X)
        self.y = torch.LongTensor(y)
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same length, got {len(self.X)} and {len(self.y)}"
            )

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


## MLP Model Definition

In [8]:
class MLPWordPredictor(nn.Module):
    """Single-hidden-layer MLP for next-word prediction.

    Embeds the context words, flattens the embeddings, applies one hidden layer
    with the chosen activation, and projects to vocabulary logits.

    Defined at module level (not inside create_mlp_model) so instances can be
    pickled / deep-copied: pickle locates classes by qualified name and cannot
    resolve classes that are local to a function.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, context_length, activation='relu'):
        super().__init__()

        # Submodules are registered in the same order as before, so state_dict keys
        # (and existing *_best.pth checkpoints) stay compatible.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Hidden layer (single layer as per requirement) over the flattened context
        self.fc1 = nn.Linear(context_length * embedding_dim, hidden_size)

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            raise ValueError("Activation must be 'relu' or 'tanh'")

        # Output layer: one logit per vocabulary entry
        self.fc2 = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        # x shape: (batch_size, context_length)
        embedded = self.embedding(x)  # (batch_size, context_length, embedding_dim)
        embedded = embedded.reshape(embedded.size(0), -1)  # (batch_size, context_length * embedding_dim)
        hidden = self.activation(self.fc1(embedded))
        return self.fc2(hidden)  # (batch_size, vocab_size)

    def get_embeddings(self):
        """Return the embedding weights as a (vocab_size, embedding_dim) NumPy array."""
        return self.embedding.weight.detach().cpu().numpy()


def create_mlp_model(vocab_size, embedding_dim, hidden_size, context_length, activation='relu'):
    """
    Create MLP model for next-word prediction

    Args:
        vocab_size: Size of vocabulary
        embedding_dim: Dimension of word embeddings (32 or 64)
        hidden_size: Number of neurons in hidden layer (1024)
        context_length: Number of context words (5)
        activation: 'relu' or 'tanh'

    Returns:
        A new MLPWordPredictor instance.

    Raises:
        ValueError: if activation is not 'relu' or 'tanh'.
    """
    return MLPWordPredictor(vocab_size, embedding_dim, hidden_size, context_length, activation)


## Training Function

In [9]:
def train_model(model, train_loader, val_loader, num_epochs, learning_rate, device, model_name):
    """
    Train the MLP model with mixed precision for speed
    
    Returns:
        model: Trained model
        train_losses: List of training losses
        val_losses: List of validation losses
        val_accuracies: List of validation accuracies
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Mixed precision scaler
    scaler = GradScaler()
    amp_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    
    train_losses = []
    val_losses = []
    val_accuracies = []
    
    best_val_loss = float('inf')
    
    print(f"\n{'='*80}")
    print(f"Training {model_name}")
    print(f"{'='*80}")
    
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        
        # Training phase
        model.train()
        train_loss = 0.0
        train_batches = 0
        
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]', leave=False)
        for batch_X, batch_y in train_pbar:
            batch_X = batch_X.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            
            # Zero gradients
            optimizer.zero_grad(set_to_none=True)
            
            # Forward pass with mixed precision
            with autocast(device_type="cuda", dtype=amp_dtype):
                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
            
            # Backward pass with scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            train_loss += loss.item()
            train_batches += 1
            train_pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        avg_train_loss = train_loss / train_batches
        train_losses.append(avg_train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_batches = 0
        
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                batch_X = batch_X.to(device, non_blocking=True)
                batch_y = batch_y.to(device, non_blocking=True)
                
                # Forward pass with mixed precision
                with autocast(device_type="cuda", dtype=amp_dtype):
                    outputs = model(batch_X)
                    loss = criterion(outputs, batch_y)
                
                val_loss += loss.item()
                val_batches += 1
                
                # Calculate accuracy
                _, predicted = torch.max(outputs, 1)
                val_total += batch_y.size(0)
                val_correct += (predicted == batch_y).sum().item()
        
        avg_val_loss = val_loss / val_batches
        val_accuracy = 100 * val_correct / val_total
        
        val_losses.append(avg_val_loss)
        val_accuracies.append(val_accuracy)
        
        epoch_time = time.time() - epoch_start_time
        
        print(f'Epoch [{epoch+1}/{num_epochs}] ({epoch_time:.1f}s) - '
              f'Train Loss: {avg_train_loss:.4f} | '
              f'Val Loss: {avg_val_loss:.4f} | '
              f'Val Acc: {val_accuracy:.2f}%')
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_val_loss,
                'val_accuracy': val_accuracy
            }, f'{model_name}_best.pth')
    
    print(f"\n‚úÖ Training complete! Best val loss: {best_val_loss:.4f}")
    return model, train_losses, val_losses, val_accuracies


## Prediction and Evaluation Functions

In [10]:
def generate_sample_predictions(model, dataset_X, idx_to_word, device, num_samples=5):
    """Generate sample predictions from the model"""
    model.eval()
    predictions = []
    
    with torch.no_grad():
        for i in range(num_samples):
            idx = np.random.randint(0, len(dataset_X))
            context = torch.LongTensor([dataset_X[idx]]).to(device)
            
            amp_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
            with autocast(device_type="cuda", dtype=amp_dtype):
                output = model(context)
            _, predicted_idx = torch.max(output, 1)
            
            context_words = [idx_to_word[int(idx)] for idx in dataset_X[idx]]
            predicted_word = idx_to_word[int(predicted_idx.item())]
            
            predictions.append({
                'context': ' '.join(context_words),
                'predicted': predicted_word
            })
    
    return predictions

def plot_training_curves(train_losses, val_losses, val_accuracies, model_name):
    """
    Plot training/validation loss and validation accuracy side by side.

    The figure is saved to ``{model_name}_training_curves.png`` (300 dpi), shown,
    and then closed, so figures do not pile up when called in a loop. Epochs are
    numbered from 1 to match the "Epoch [k/N]" lines printed by train_model. The
    epoch with the lowest validation loss is marked with a dashed line.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # 1-based epoch numbers; small markers keep single-epoch runs visible.
    train_epochs = range(1, len(train_losses) + 1)
    val_epochs = range(1, len(val_losses) + 1)
    acc_epochs = range(1, len(val_accuracies) + 1)
    line_style = dict(linewidth=2, marker='o', markersize=3)

    # Loss plot
    ax1.plot(train_epochs, train_losses, label='Train Loss', **line_style)
    ax1.plot(val_epochs, val_losses, label='Val Loss', **line_style)
    if val_losses:
        best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__) + 1
        ax1.axvline(best_epoch, color='gray', linestyle='--', alpha=0.6, label=f'Best epoch ({best_epoch})')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title(f'{model_name} - Training vs Validation Loss', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)

    # Accuracy plot
    ax2.plot(acc_epochs, val_accuracies, label='Val Accuracy', color='green', **line_style)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.set_title(f'{model_name} - Validation Accuracy', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=11)
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(f'{model_name}_training_curves.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    print(f"📊 Training curves saved as '{model_name}_training_curves.png'")


## Training Configuration

In [11]:
# Hyperparameters
CONTEXT_LENGTH = 5
HIDDEN_SIZE = 1024
BATCH_SIZE = 1024  # Optimized for speed
NUM_EPOCHS = 100  # Upper bound on epochs; train_model checkpoints the epoch with the best val loss
LEARNING_RATE = 0.001
VAL_SPLIT = 0.1
SEED = 42  # Makes weight init, random_split, DataLoader shuffling and sampled predictions reproducible
NUM_WORKERS = min(8, os.cpu_count() or 1)  # Parallel data loading, capped at the CPUs available
PIN_MEMORY = torch.cuda.is_available()  # Pinned host memory only speeds up host->GPU copies

# Seed torch (all devices) and NumPy's global RNG before any split or model is created.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Model configurations
model_configs = [
    {'embedding_dim': 64, 'activation': 'relu', 'name': '64d-ReLU'},
    {'embedding_dim': 32, 'activation': 'relu', 'name': '32d-ReLU'},
    {'embedding_dim': 64, 'activation': 'tanh', 'name': '64d-tanh'},
    {'embedding_dim': 32, 'activation': 'tanh', 'name': '32d-tanh'},
]

print("Training configuration:")
print(f"  Context length: {CONTEXT_LENGTH}")
print(f"  Hidden size: {HIDDEN_SIZE}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Num workers: {NUM_WORKERS}")
# Mixed precision is only used on CUDA, so report what will actually happen.
print(f"  Mixed precision: {'ENABLED' if torch.cuda.is_available() else 'DISABLED'}")


Training configuration:
  Context length: 5
  Hidden size: 1024
  Batch size: 1024
  Epochs: 100
  Learning rate: 0.001
  Num workers: 8
  Mixed precision: ENABLED


## Wikipedia Models Training (with Skip Logic)

In [12]:
print("\n" + "="*80)
print("TRAINING WIKIPEDIA MODELS")
print("="*80)

# Fixed seed so the train/val split is identical on every run and kernel restart.
# An unseeded split gives each re-run a different validation set.
SPLIT_SEED = 42

# Create datasets
wiki_dataset = WordPredictionDataset(wiki_X, wiki_y)
wiki_train_size = int((1 - VAL_SPLIT) * len(wiki_dataset))
wiki_val_size = len(wiki_dataset) - wiki_train_size
wiki_train_dataset, wiki_val_dataset = random_split(
    wiki_dataset, [wiki_train_size, wiki_val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

wiki_train_loader = DataLoader(wiki_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                               num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
wiki_val_loader = DataLoader(wiki_val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

for config in model_configs:
    model_name = f"wiki_{config['name']}"

    # Skip if already trained (checkpoint exists and is valid)
    if model_name in verified_models:
        ckpt_info = verified_models[model_name]
        print(f"\n⏭️  Skipping {model_name} - already trained!")
        print(f"   Checkpoint: {ckpt_info['path']}")
        print(f"   Val Loss: {ckpt_info['val_loss']:.4f}")
        print(f"   Val Accuracy: {ckpt_info['val_accuracy']:.2f}%")
        continue

    print(f"\n🚀 Starting training for {model_name}...")

    # Create model
    model = create_mlp_model(
        vocab_size=wiki_vocab_size,
        embedding_dim=config['embedding_dim'],
        hidden_size=HIDDEN_SIZE,
        context_length=CONTEXT_LENGTH,
        activation=config['activation']
    ).to(device)

    print(f"\n📊 Model: {model_name}")
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train model
    model, train_losses, val_losses, val_accuracies = train_model(
        model, wiki_train_loader, wiki_val_loader, NUM_EPOCHS, LEARNING_RATE, device, model_name
    )

    # Plot training curves
    plot_training_curves(train_losses, val_losses, val_accuracies, model_name)

    # Generate sample predictions
    print(f"\n🎯 Sample Predictions for {model_name}:")
    predictions = generate_sample_predictions(model, wiki_X, wiki_idx_to_word, device, num_samples=5)
    for i, pred in enumerate(predictions, 1):
        print(f"   {i}. Context: '{pred['context']}' → Predicted: '{pred['predicted']}'")

    # Record the metrics of the lowest-val-loss epoch. That is the epoch train_model wrote to
    # {model_name}_best.pth, and the value the checkpoint-based cells report for skipped
    # models, so every row of the summary table means the same thing.
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)  # 0-based, like ckpt['epoch']

    # Save results
    all_models[model_name] = {
        'model': model,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies,
        'config': config,
        'final_val_loss': val_losses[best_epoch],
        'final_val_accuracy': val_accuracies[best_epoch],
        'best_epoch': best_epoch,
        'vocab_size': wiki_vocab_size,
        'word_to_idx': wiki_word_to_idx,
        'idx_to_word': wiki_idx_to_word,
        'dataset_type': 'wikipedia'
    }

    # Incremental save after each model. The nn.Module is stored as None: its weights are
    # already in the checkpoint, pickling ~100M-parameter modules bloats the file, and
    # instances of a class defined inside a function cannot be pickled at all.
    # Write to a temp file and swap it in, so an interrupted dump can't corrupt the results.
    tmp_results_path = 'all_models_results.pkl.tmp'
    with open(tmp_results_path, 'wb') as f:
        pickle.dump({name: {**entry, 'model': None} for name, entry in all_models.items()}, f)
    os.replace(tmp_results_path, 'all_models_results.pkl')
    print(f"💾 Saved checkpoint with {len(all_models)} models so far")

    print(f"\n✅ {model_name} training complete!")
    print(f"   Best Val Loss: {val_losses[best_epoch]:.4f} (epoch {best_epoch + 1})")
    print(f"   Best Val Accuracy: {val_accuracies[best_epoch]:.2f}%")



TRAINING WIKIPEDIA MODELS

⏭️  Skipping wiki_64d-ReLU - already trained!
   Checkpoint: wiki_64d-ReLU_best.pth
   Val Loss: 7.3232
   Val Accuracy: 13.34%

⏭️  Skipping wiki_32d-ReLU - already trained!
   Checkpoint: wiki_32d-ReLU_best.pth
   Val Loss: 7.4249
   Val Accuracy: 12.37%

⏭️  Skipping wiki_64d-tanh - already trained!
   Checkpoint: wiki_64d-tanh_best.pth
   Val Loss: 7.4166
   Val Accuracy: 13.23%

⏭️  Skipping wiki_32d-tanh - already trained!
   Checkpoint: wiki_32d-tanh_best.pth
   Val Loss: 7.5366
   Val Accuracy: 12.30%


## Linux Kernel Models Training (with Skip Logic)

In [13]:
print("\n" + "="*80)
print("TRAINING LINUX KERNEL MODELS")
print("="*80)

# Fixed seed so the train/val split is identical on every run and kernel restart.
# An unseeded split gives each re-run a different validation set.
SPLIT_SEED = 42

# Create datasets
linux_dataset = WordPredictionDataset(linux_X, linux_y)
linux_train_size = int((1 - VAL_SPLIT) * len(linux_dataset))
linux_val_size = len(linux_dataset) - linux_train_size
linux_train_dataset, linux_val_dataset = random_split(
    linux_dataset, [linux_train_size, linux_val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

linux_train_loader = DataLoader(linux_train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
linux_val_loader = DataLoader(linux_val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

for config in model_configs:
    model_name = f"linux_{config['name']}"

    # Skip if already trained (checkpoint exists and is valid)
    if model_name in verified_models:
        ckpt_info = verified_models[model_name]
        print(f"\n⏭️  Skipping {model_name} - already trained!")
        print(f"   Checkpoint: {ckpt_info['path']}")
        print(f"   Val Loss: {ckpt_info['val_loss']:.4f}")
        print(f"   Val Accuracy: {ckpt_info['val_accuracy']:.2f}%")
        continue

    print(f"\n🚀 Starting training for {model_name}...")

    # Create model
    model = create_mlp_model(
        vocab_size=linux_vocab_size,
        embedding_dim=config['embedding_dim'],
        hidden_size=HIDDEN_SIZE,
        context_length=CONTEXT_LENGTH,
        activation=config['activation']
    ).to(device)

    print(f"\n📊 Model: {model_name}")
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train model
    model, train_losses, val_losses, val_accuracies = train_model(
        model, linux_train_loader, linux_val_loader, NUM_EPOCHS, LEARNING_RATE, device, model_name
    )

    # Plot training curves
    plot_training_curves(train_losses, val_losses, val_accuracies, model_name)

    # Generate sample predictions
    print(f"\n🎯 Sample Predictions for {model_name}:")
    predictions = generate_sample_predictions(model, linux_X, linux_idx_to_word, device, num_samples=5)
    for i, pred in enumerate(predictions, 1):
        print(f"   {i}. Context: '{pred['context']}' → Predicted: '{pred['predicted']}'")

    # Record the metrics of the lowest-val-loss epoch. That is the epoch train_model wrote to
    # {model_name}_best.pth, and the value the checkpoint-based cells report for skipped
    # models, so every row of the summary table means the same thing.
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)  # 0-based, like ckpt['epoch']

    # Save results
    all_models[model_name] = {
        'model': model,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies,
        'config': config,
        'final_val_loss': val_losses[best_epoch],
        'final_val_accuracy': val_accuracies[best_epoch],
        'best_epoch': best_epoch,
        'vocab_size': linux_vocab_size,
        'word_to_idx': linux_word_to_idx,
        'idx_to_word': linux_idx_to_word,
        'dataset_type': 'linux'
    }

    # Incremental save after each model. The nn.Module is stored as None: its weights are
    # already in the checkpoint, pickling ~100M-parameter modules bloats the file, and
    # instances of a class defined inside a function cannot be pickled at all.
    # Write to a temp file and swap it in, so an interrupted dump can't corrupt the results.
    tmp_results_path = 'all_models_results.pkl.tmp'
    with open(tmp_results_path, 'wb') as f:
        pickle.dump({name: {**entry, 'model': None} for name, entry in all_models.items()}, f)
    os.replace(tmp_results_path, 'all_models_results.pkl')
    print(f"💾 Saved checkpoint with {len(all_models)} models so far")

    print(f"\n✅ {model_name} training complete!")
    print(f"   Best Val Loss: {val_losses[best_epoch]:.4f} (epoch {best_epoch + 1})")
    print(f"   Best Val Accuracy: {val_accuracies[best_epoch]:.2f}%")



TRAINING LINUX KERNEL MODELS

⏭️  Skipping linux_64d-ReLU - already trained!
   Checkpoint: linux_64d-ReLU_best.pth
   Val Loss: 4.9966
   Val Accuracy: 35.33%

⏭️  Skipping linux_32d-ReLU - already trained!
   Checkpoint: linux_32d-ReLU_best.pth
   Val Loss: 5.0820
   Val Accuracy: 34.16%

⏭️  Skipping linux_64d-tanh - already trained!
   Checkpoint: linux_64d-tanh_best.pth
   Val Loss: 4.9388
   Val Accuracy: 37.62%

⏭️  Skipping linux_32d-tanh - already trained!
   Checkpoint: linux_32d-tanh_best.pth
   Val Loss: 5.0973
   Val Accuracy: 36.35%


## Final Save and Summary

In [15]:
# First: fill in metrics that are missing from all_models. verify_checkpoints has already
# read val_loss / val_accuracy from every valid checkpoint, so use that before re-opening
# (~1 GB) checkpoint files.
for model_name, results in all_models.items():
    if results.get('final_val_loss') is not None:
        continue
    metrics = verified_models.get(model_name)
    if metrics is None:
        ckpt_path = f"{model_name}_best.pth"
        if not os.path.exists(ckpt_path):
            continue
        try:
            # weights_only: no arbitrary code; mmap: don't read the whole file into RAM.
            ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True, mmap=True)
            metrics = {'val_loss': ckpt.get('val_loss'), 'val_accuracy': ckpt.get('val_accuracy')}
            del ckpt
        except Exception as e:
            # Best effort: keep the metrics as None, but report why instead of failing silently.
            print(f"⚠️ Could not read metrics from {ckpt_path}: {e}")
            continue
    results['final_val_loss'] = metrics.get('val_loss')
    results['final_val_accuracy'] = metrics.get('val_accuracy')

# Save all results with complete metrics
print("\n" + "="*80)
print("SAVING ALL MODELS AND RESULTS")
print("="*80)

# nn.Modules are stored as None: weights live in the *_best.pth checkpoints, and pickling
# the modules bloats the file (or fails outright for classes defined inside a function).
# Temp file + os.replace so an interrupted dump can't leave a truncated results file.
_tmp_results_path = 'all_models_results.pkl.tmp'
with open(_tmp_results_path, 'wb') as f:
    pickle.dump({name: {**entry, 'model': None} for name, entry in all_models.items()}, f)
os.replace(_tmp_results_path, 'all_models_results.pkl')

print(f"✅ All {len(all_models)} models saved to 'all_models_results.pkl'")

## Summary Report

def _count_parameters(model_name, results):
    """Return the parameter count for one summary row.

    Uses the in-memory model when present. Otherwise the architecture is rebuilt on
    the 'meta' device, which records tensor shapes without allocating any storage.
    """
    if results.get('model') is not None:
        return sum(p.numel() for p in results['model'].parameters())
    config = results.get('config') or {}
    vocab_size = results.get('vocab_size') or (wiki_vocab_size if 'wiki' in model_name else linux_vocab_size)
    embedding_dim = config.get('embedding_dim') or (64 if '64d' in model_name else 32)
    activation = config.get('activation') or ('tanh' if 'tanh' in model_name else 'relu')
    with torch.device('meta'):
        temp_model = create_mlp_model(vocab_size, embedding_dim, HIDDEN_SIZE, CONTEXT_LENGTH, activation)
    return sum(p.numel() for p in temp_model.parameters())

# Print summary table
print("\n" + "="*80)
print(f"TRAINING SUMMARY - ALL {len(all_models)} MODELS")
print("="*80)

print(f"\n{'Model':<25} {'Val Loss':<12} {'Val Acc (%)':<12} {'Parameters':<15}")
print("-" * 70)

for model_name, results in all_models.items():
    val_loss = results.get('final_val_loss')
    val_acc = results.get('final_val_accuracy')

    val_loss_str = f"{val_loss:.4f}" if val_loss is not None else "N/A"
    val_acc_str = f"{val_acc:.2f}" if val_acc is not None else "N/A"
    params = _count_parameters(model_name, results)

    print(f"{model_name:<25} {val_loss_str:<12} {val_acc_str:<12} {params:>14,}")

print("\n✨ All training complete! Ready for embedding visualization and Streamlit deployment.")



SAVING ALL MODELS AND RESULTS
✅ All 8 models saved to 'all_models_results.pkl'

TRAINING SUMMARY - ALL 8 MODELS

Model                     Val Loss     Val Acc (%)  Parameters     
----------------------------------------------------------------------
wiki_64d-ReLU             7.3232       13.34            86,014,491
wiki_32d-ReLU             7.4249       12.37            83,332,795
wiki_64d-tanh             7.4166       13.23            86,014,491
wiki_32d-tanh             7.5366       12.30            83,332,795
linux_64d-ReLU            4.9966       35.33            99,299,202
linux_32d-ReLU            5.0820       34.16            96,227,138
linux_64d-tanh            4.9388       37.62            99,299,202
linux_32d-tanh            5.0973       36.35            96,227,138

✨ All training complete! Ready for embedding visualization and Streamlit deployment.
