<a href="https://colab.research.google.com/github/YOUR_USERNAME/MNIST_COMP/blob/main/Evaluation_and_Visualization_Only.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MNIST Generative Models - Evaluation and Visualization Only

This notebook loads pre-trained model checkpoints and performs evaluation and visualization without training.

**Use this when you:**
- Have already trained models and saved checkpoints
- Want to create new visualizations
- Need to calculate additional metrics
- Want to generate more samples

**Requirements:**
- Epoch-40 checkpoint files, uploaded directly, stored in Google Drive, or hosted at a URL
- No GPU required, though generation is faster with one

## 1. Setup and Mount Checkpoints

Choose ONE of the following methods to load your checkpoints:

### Method 1: Upload Files Directly (Recommended for < 50 MB)

**Pros:** Simple, no Google Drive needed  
**Cons:** Need to re-upload every session  
**Best for:** Testing, one-time evaluation

In [None]:
# Method 1: Upload checkpoint files directly
from google.colab import files
import os

print("Please upload your checkpoint files (epoch 40):")
print("  - vae_model_epoch_40.pth")
print("  - gan_generator_epoch_40.pth")
print("  - cgan_generator_epoch_40.pth")
print("  - ddpm_model_epoch_40.pth")
print("\nClick 'Choose Files' below...\n")

uploaded = files.upload()

# Create checkpoints directory
os.makedirs('checkpoints', exist_ok=True)

# Move uploaded files
for filename in uploaded.keys():
    os.rename(filename, f'checkpoints/{filename}')
    print(f"Moved: {filename}")

print("\nCheckpoints ready in: checkpoints/")

### Method 2: Mount Google Drive (Recommended for Regular Use)

**Pros:** Persistent, no re-upload needed  
**Cons:** Requires organizing files in Drive  
**Best for:** Regular evaluation, multiple sessions

**Setup Steps:**
1. Upload checkpoints to Google Drive (e.g., `My Drive/MNIST_Checkpoints/`)
2. Run the cell below
3. Grant Drive access when prompted

In [None]:
# Method 2: Mount Google Drive
from google.colab import drive
import os

# Mount Drive
drive.mount('/content/drive')

# Set path to your checkpoints in Drive
# EDIT THIS PATH to match where you saved your checkpoints
DRIVE_CHECKPOINT_PATH = '/content/drive/MyDrive/MNIST_Checkpoints'

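# Create a symbolic link so later cells can use the relative path 'checkpoints/'
if not os.path.isdir(DRIVE_CHECKPOINT_PATH):
    raise FileNotFoundError(f"Drive folder not found: {DRIVE_CHECKPOINT_PATH} (edit DRIVE_CHECKPOINT_PATH above)")
if not os.path.lexists('checkpoints'):
    os.symlink(DRIVE_CHECKPOINT_PATH, 'checkpoints')
    print(f"Linked checkpoints from: {DRIVE_CHECKPOINT_PATH}")
else:
    print("'checkpoints' already exists; delete it first if you want to link the Drive folder instead")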

# Verify files
print("\nFiles found:")
for f in os.listdir('checkpoints'):
    if f.endswith('.pth'):
        size = os.path.getsize(f'checkpoints/{f}') / (1024*1024)
        print(f"  {f} ({size:.1f} MB)")

### Method 3: Download from URL (Advanced)

**Pros:** Automated, sharable  
**Cons:** Need to host files somewhere  
**Best for:** Shared workflows, GitHub releases

In [None]:
# Method 3: Download from a URL (e.g., GitHub releases or Dropbox)
import os

os.makedirs('checkpoints', exist_ok=True)

# Example: Download from a public URL
# Replace these URLs with your actual file locations
urls = {
    'vae_model_epoch_40.pth': 'YOUR_URL_HERE',
    'gan_generator_epoch_40.pth': 'YOUR_URL_HERE',
    'cgan_generator_epoch_40.pth': 'YOUR_URL_HERE',
    'ddpm_model_epoch_40.pth': 'YOUR_URL_HERE'
}

# Uncomment and edit URLs above, then run:
# for filename, url in urls.items():
#     !wget -O checkpoints/{filename} {url}
#     print(f"Downloaded: {filename}")

print("Note: Edit the URLs above and uncomment the code to use this method")
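If shell escapes such as `!wget` are unavailable (for example when running the code outside Colab or Jupyter), the same download can be done in plain Python with the standard library's `urllib.request.urlretrieve`. This is a minimal sketch: `download_checkpoints` is a hypothetical helper, and the URLs in `urls` are still placeholders that must be replaced before calling it.

```python
import os
import urllib.request


def download_checkpoints(urls, dest_dir='checkpoints', overwrite=False):
    """Download each {filename: url} pair into dest_dir (hypothetical helper).

    Files that already exist are skipped unless overwrite=True.
    Returns the list of local paths in the same order as `urls`.
    """
    os.makedirs(dest_dir, exist_ok=True)
    paths = []
    for filename, url in urls.items():
        path = os.path.join(dest_dir, filename)
        if overwrite or not os.path.exists(path):
            # urlretrieve streams the response body into `path`
            urllib.request.urlretrieve(url, path)
            print(f"Downloaded: {filename}")
        else:
            print(f"Skipped (already present): {filename}")
        paths.append(path)
    return paths
```

After filling in real URLs, `download_checkpoints(urls)` replaces the commented-out `!wget` loop; skipping existing files avoids re-downloading on every re-run.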

### Verify Checkpoints

Run this cell to verify your checkpoints are accessible:

In [None]:
import os

print("Checking for checkpoint files...\n")

required_files = [
    'vae_model_epoch_40.pth',
    'gan_generator_epoch_40.pth',
    'cgan_generator_epoch_40.pth',
    'ddpm_model_epoch_40.pth'
]

all_found = True
for filename in required_files:
    path = f'checkpoints/{filename}'
    if os.path.exists(path):
        size = os.path.getsize(path) / (1024*1024)
        print(f"Found: {filename} ({size:.1f} MB)")
    else:
        print(f"MISSING: {filename}")
        all_found = False

if all_found:
    print("\nAll checkpoints found! Ready to proceed.")
else:
    print("\nSome checkpoints are missing. Please upload them using one of the methods above.")
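File size alone will not catch a truncated or corrupted upload. Assuming you recorded SHA-256 digests when the checkpoints were saved (this notebook does not do that; `sha256sum` on Linux or the same function below can produce them), a stricter check compares hashes. The helper names and the `expected` dictionary are hypothetical.

```python
import hashlib
import os


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, reading it in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()


def verify_checksums(expected, directory='checkpoints'):
    """Compare files against {filename: expected_hex_digest}; True if all match."""
    all_ok = True
    for filename, digest in expected.items():
        path = os.path.join(directory, filename)
        if not os.path.exists(path):
            print(f"MISSING: {filename}")
            all_ok = False
            continue
        if sha256_of(path) == digest.lower():
            print(f"OK: {filename}")
        else:
            print(f"CHECKSUM MISMATCH: {filename}")
            all_ok = False
    return all_ok
```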

## 2. Import Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy import linalg
from scipy.stats import entropy
import os
import warnings

warnings.filterwarnings('ignore')

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Create output directory
os.makedirs('outputs/visualizations', exist_ok=True)

print("\nAll dependencies loaded!")
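The `scipy.linalg` and `scipy.stats.entropy` imports suggest metric computations such as FID and Inception Score, but no metric cell appears in this notebook. The sketch below shows the two formulas under the assumption that you supply features or class probabilities from some pretrained network (for example an MNIST classifier, which is not defined here); it will not reproduce official Inception-v3-based FID/IS numbers.

```python
import numpy as np
from scipy import linalg
from scipy.stats import entropy


def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two (N, D) feature arrays.

    ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2))
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    covmean = np.real(linalg.sqrtm(cov_a @ cov_b))  # drop tiny imaginary parts
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a + cov_b - 2.0 * covmean))


def inception_score(probs):
    """exp(mean_x KL(p(y|x) || p(y))) for an (N, C) array of class probabilities."""
    probs = np.asarray(probs, dtype=np.float64)
    p_y = probs.mean(axis=0)  # marginal class distribution
    kl = [entropy(p, p_y) for p in probs]  # natural-log KL divergence per sample
    return float(np.exp(np.mean(kl)))
```

For MNIST, an Inception Score close to 10 means a classifier sees confident, evenly spread digit classes; a score near 1 means the predictions carry no class information.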

## 3. Model Architectures

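The architectures below must match the training code exactly: with the default `strict=True`, `load_state_dict` raises an error on any missing, unexpected, or differently shaped parameter.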

In [None]:
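class VAE(nn.Module):
    """Variational Autoencoder (this notebook only needs decode() for sampling)"""
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        
        # Encoder: 1x28x28 -> 32x14x14 -> 64x7x7 -> 128x3x3, then flattened to 1152
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Flatten()
        )
        
        # Posterior heads: unused here, but needed so strict loading finds every checkpoint key
        self.fc_mu = nn.Linear(128 * 3 * 3, latent_dim)
        self.fc_logvar = nn.Linear(128 * 3 * 3, latent_dim)
        
        # Decoder: latent -> 128x3x3 -> 64x6x6 -> 32x14x14 -> 1x28x28 (Sigmoid, values in [0, 1])
        self.decoder_input = nn.Linear(latent_dim, 128 * 3 * 3)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 0),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Sigmoid()
        )
    
    def decode(self, z):
        """Map latent vectors (B, latent_dim) to images (B, 1, 28, 28)."""
        h = self.decoder_input(z)
        h = h.view(-1, 128, 3, 3)
        return self.decoder(h)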


class Generator(nn.Module):
    """GAN Generator"""
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    
    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)


class ConditionalGenerator(nn.Module):
    """Conditional GAN Generator"""
    def __init__(self, latent_dim=100, num_classes=10):
        super(ConditionalGenerator, self).__init__()
        
        self.label_emb = nn.Embedding(num_classes, num_classes)
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    
    def forward(self, noise, labels):
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        return self.model(gen_input).view(-1, 1, 28, 28)


class UNet(nn.Module):
    """DDPM UNet"""
    def __init__(self, in_channels=1, out_channels=1, time_emb_dim=32):
        super(UNet, self).__init__()
        
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256)
        )
        
        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        
        self.upconv3 = nn.ConvTranspose2d(256, 128, 3, padding=1)
        self.upconv2 = nn.ConvTranspose2d(256, 64, 3, padding=1)
        self.upconv1 = nn.ConvTranspose2d(128, out_channels, 3, padding=1)
        
        self.relu = nn.ReLU()
    
    def pos_encoding(self, t, channels):
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc
    
    def forward(self, x, timestep):
        t = self.pos_encoding(timestep.float().unsqueeze(-1), 32)
        t = self.time_mlp(t)
        
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        
        t = t.view(-1, 256, 1, 1).expand(-1, -1, x3.shape[2], x3.shape[3])
        x3 = x3 + t
        
        x = self.relu(self.upconv3(x3))
        x = torch.cat([x, x2], dim=1)
        x = self.relu(self.upconv2(x))
        x = torch.cat([x, x1], dim=1)
        x = self.upconv1(x)
        
        return x

print("Model architectures defined successfully!")
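The hard-coded `128 * 3 * 3` in the VAE, and the middle `ConvTranspose2d` with `padding=0`, follow from PyTorch's output-size formulas for `Conv2d` and `ConvTranspose2d`. This pure-Python check (dilation 1 assumed; the helper names are illustrative) traces a 28×28 input through both halves of the VAE. The same formulas show why the UNet's 3×3, stride-1, padding-1 layers keep every feature map at 28×28, which is what makes its skip concatenations shape-compatible.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d with dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# VAE encoder: three Conv2d(kernel=4, stride=2, padding=1) layers
h = 28
for _ in range(3):
    h = conv2d_out(h, kernel=4, stride=2, padding=1)
print("encoder output:", h)  # 28 -> 14 -> 7 -> 3, hence 128 * 3 * 3

# VAE decoder: ConvTranspose2d layers with padding 1, 0, 1 (kernel=4, stride=2)
d = 3
for p in (1, 0, 1):
    d = conv_transpose2d_out(d, kernel=4, stride=2, padding=p)
print("decoder output:", d)  # 3 -> 6 -> 14 -> 28
```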

## 4. Load Checkpoints

In [None]:
def load_checkpoint(model, checkpoint_path):
    """Load model from checkpoint"""
    if not os.path.exists(checkpoint_path):
        print(f"Warning: {checkpoint_path} not found")
        return None
    
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        
        model.to(device)
        model.eval()
        
        print(f"Loaded: {os.path.basename(checkpoint_path)}")
        return model
    except Exception as e:
        print(f"Error loading {os.path.basename(checkpoint_path)}: {e}")
        return None


print("Loading models from checkpoints...\n")

# Load all models
vae_model = load_checkpoint(VAE(latent_dim=20), 'checkpoints/vae_model_epoch_40.pth')
gan_model = load_checkpoint(Generator(latent_dim=100), 'checkpoints/gan_generator_epoch_40.pth')
cgan_model = load_checkpoint(ConditionalGenerator(latent_dim=100), 'checkpoints/cgan_generator_epoch_40.pth')
ddpm_model = load_checkpoint(UNet(), 'checkpoints/ddpm_model_epoch_40.pth')

models = {
    'vae': vae_model,
    'gan': gan_model,
    'cgan': cgan_model,
    'ddpm': ddpm_model
}

loaded_count = sum(1 for m in models.values() if m is not None)
print(f"\nLoaded {loaded_count}/4 models successfully")
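When `load_checkpoint` prints an error, the usual cause is an architecture mismatch. As a debugging aid (the helper `diagnose_state_dict` is hypothetical, not part of PyTorch), this sketch lists exactly which keys differ between a checkpoint and a model, without loading anything.

```python
import torch
import torch.nn as nn


def diagnose_state_dict(model, state_dict):
    """Report how a checkpoint's state dict differs from what `model` expects.

    Lists missing keys, unexpected keys, and keys whose tensor shapes differ,
    which are the cases that make a strict load_state_dict() call fail.
    """
    expected = model.state_dict()
    missing = sorted(k for k in expected if k not in state_dict)
    unexpected = sorted(k for k in state_dict if k not in expected)
    shape_mismatch = {
        k: (tuple(state_dict[k].shape), tuple(v.shape))  # (checkpoint, model)
        for k, v in expected.items()
        if k in state_dict and state_dict[k].shape != v.shape
    }
    return {'missing': missing, 'unexpected': unexpected, 'shape_mismatch': shape_mismatch}
```

For example: `ckpt = torch.load('checkpoints/vae_model_epoch_40.pth', map_location='cpu')`, then `diagnose_state_dict(VAE(latent_dim=20), ckpt.get('model_state_dict', ckpt))`. One common pattern in the output is every key prefixed with `module.`, which happens when the model was wrapped in `nn.DataParallel` during training.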

## 5. Generate Sample Images

In [None]:
print("Generating sample images...\n")

num_samples = 10

samples = {}

# VAE
if vae_model is not None:
    with torch.no_grad():
        z = torch.randn(num_samples, 20).to(device)
        samples['VAE'] = vae_model.decode(z).cpu()
    print(f"VAE: {num_samples} samples generated")

# GAN
if gan_model is not None:
    with torch.no_grad():
        z = torch.randn(num_samples, 100).to(device)
        samples['GAN'] = gan_model(z).cpu()
    print(f"GAN: {num_samples} samples generated")

# cGAN (one sample per digit)
if cgan_model is not None:
    with torch.no_grad():
        z = torch.randn(num_samples, 100).to(device)
        labels = torch.arange(num_samples).to(device)
        samples['cGAN'] = cgan_model(z, labels).cpu()
    print(f"cGAN: {num_samples} samples generated (digits 0-9)")

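# DDPM (simplified placeholder, NOT full DDPM sampling)
# One crude denoising step at t=0 applied to pure Gaussian noise. The real reverse
# process iterates over all T timesteps using the training noise schedule, so
# these outputs will largely look like noise rather than digits.
if ddpm_model is not None:
    with torch.no_grad():
        x = torch.randn(num_samples, 1, 28, 28).to(device)
        t = torch.zeros(num_samples).to(device)
        noise_pred = ddpm_model(x, t)
        samples['DDPM'] = (x - noise_pred * 0.1).cpu()
    print(f"DDPM: {num_samples} placeholder samples generated (single step, not full sampling)")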

print(f"\nTotal: {len(samples)} model(s) ready for visualization")
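The DDPM branch in the cell above applies one simplified step rather than the reverse diffusion process. A fuller sketch follows the DDPM paper's ancestral sampler with a linear beta schedule. **Assumption:** `T=1000` and betas from `1e-4` to `0.02` are the paper's common defaults, not values confirmed for this project; they must match whatever schedule the checkpoint was trained with, and `model(x, t)` must predict the added noise.

```python
import torch


@torch.no_grad()
def ddpm_sample(model, num_samples, T=1000, beta_start=1e-4, beta_end=0.02,
                img_shape=(1, 28, 28), device='cpu'):
    """Ancestral DDPM sampling with a linear beta schedule (sketch).

    ASSUMPTION: T, beta_start and beta_end equal the training values.
    """
    betas = torch.linspace(beta_start, beta_end, T, device=device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    x = torch.randn(num_samples, *img_shape, device=device)  # x_T ~ N(0, I)
    for i in reversed(range(T)):
        t = torch.full((num_samples,), i, device=device, dtype=torch.long)
        eps = model(x, t)  # predicted noise
        # Mean of p(x_{t-1} | x_t): (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        x = (x - betas[i] / torch.sqrt(1.0 - alpha_bars[i]) * eps) / torch.sqrt(alphas[i])
        if i > 0:
            # Add noise with variance sigma_t^2 = beta_t (one of the paper's two choices)
            x = x + torch.sqrt(betas[i]) * torch.randn_like(x)
    return x
```

Usage, assuming the training images were scaled to [-1, 1]: `samples['DDPM'] = ddpm_sample(ddpm_model, num_samples, device=device).clamp(-1, 1).cpu()`. With 1000 steps this is far slower than the GAN or VAE, especially on CPU.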

## 6. Visualization - Sample Grid

In [None]:
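# Create comparison grid: one row per model, one column per sample
if not samples:
    raise RuntimeError("No samples to plot; load at least one model first.")

n_cols = 10
# squeeze=False always returns a 2-D array of axes, even for a single model
fig, axes = plt.subplots(len(samples), n_cols, figsize=(20, 2 * len(samples)), squeeze=False)

for i, (model_name, images) in enumerate(samples.items()):
    for j in range(n_cols):
        ax = axes[i, j]
        # imshow scales each image to its own min/max, so both [0, 1] (VAE)
        # and [-1, 1] (GAN/cGAN) outputs display correctly
        ax.imshow(images[j].squeeze(), cmap='gray')
        # Hide ticks and frame instead of calling ax.axis('off'),
        # which would also hide the row label set below
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        
        if j == 0:
            ax.set_ylabel(model_name, fontsize=14, fontweight='bold')

plt.suptitle('Generated Samples Comparison', fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
plt.savefig('outputs/visualizations/sample_grid.png', dpi=150, bbox_inches='tight')
plt.show()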

print("Saved: outputs/visualizations/sample_grid.png")
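Matplotlib subplots suit labeled figures; for a single compact image per model (for slides, say), a minimal NumPy sketch can tile samples directly. `tile_images` is a hypothetical helper; `torchvision.utils.make_grid` offers similar functionality for tensors.

```python
import numpy as np


def tile_images(images, ncols, pad=2):
    """Tile an (N, H, W) array into one 2-D grid image.

    Each image is min-max scaled to [0, 1] independently, so outputs in
    [0, 1] (VAE) and [-1, 1] (GAN/cGAN) end up on the same scale.
    Unused trailing cells and padding stay 0 (black).
    """
    images = np.asarray(images, dtype=np.float32)
    n, h, w = images.shape
    nrows = -(-n // ncols)  # ceiling division
    grid = np.zeros((nrows * (h + pad) + pad, ncols * (w + pad) + pad), dtype=np.float32)
    for idx, img in enumerate(images):
        lo, hi = img.min(), img.max()
        img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
        r, c = divmod(idx, ncols)
        top, left = pad + r * (h + pad), pad + c * (w + pad)
        grid[top:top + h, left:left + w] = img
    return grid
```

For example: `plt.imsave('outputs/visualizations/gan_tile.png', tile_images(samples['GAN'].squeeze(1).numpy(), ncols=10), cmap='gray')`.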

## 7. Custom Generation Examples

### Generate Specific Digits with cGAN

In [None]:
if cgan_model is not None:
    # Generate 100 samples of digit "7"
    target_digit = 7
    num_digit_samples = 100  # separate name so Section 5's num_samples is not overwritten
    
    with torch.no_grad():
        z = torch.randn(num_digit_samples, 100).to(device)
        labels = torch.full((num_digit_samples,), target_digit, dtype=torch.long).to(device)
        digit_samples = cgan_model(z, labels).cpu()
    
    # Display the first 25 of the generated samples
    fig, axes = plt.subplots(5, 5, figsize=(10, 10))
    for i in range(25):
        axes[i//5, i%5].imshow(digit_samples[i].squeeze(), cmap='gray')
        axes[i//5, i%5].axis('off')
    
    plt.suptitle(f'First 25 of {num_digit_samples} Generated Samples of Digit {target_digit}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'outputs/visualizations/cgan_digit_{target_digit}.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"Generated {num_digit_samples} samples of digit {target_digit}")
else:
    print("cGAN model not loaded")
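Extending the single-digit example, a sketch (the helper name is hypothetical) that generates several samples for every digit, grouped by label, so all ten classes can be compared at a glance:

```python
import torch


@torch.no_grad()
def generate_class_grid(generator, latent_dim=100, num_classes=10, per_class=10, device='cpu'):
    """Generate `per_class` samples for every class, grouped by label.

    Returns images shaped (num_classes, per_class, C, H, W), so images[k]
    holds the samples conditioned on class k.
    """
    # labels = [0]*per_class + [1]*per_class + ... + [num_classes-1]*per_class
    labels = torch.arange(num_classes, device=device).repeat_interleave(per_class)
    z = torch.randn(len(labels), latent_dim, device=device)
    images = generator(z, labels)
    return images.reshape(num_classes, per_class, *images.shape[1:])
```

With the loaded model: `grid = generate_class_grid(cgan_model, device=device).cpu()`, then plot `grid[r, c].squeeze()` on `plt.subplots(10, 10)` so row `r` shows digit `r`.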

### Latent Space Interpolation (VAE)

In [None]:
if vae_model is not None:
    # Interpolate between two random latent vectors
    z1 = torch.randn(1, 20).to(device)
    z2 = torch.randn(1, 20).to(device)
    
    steps = 10
    alphas = torch.linspace(0, 1, steps)
    
    interpolated = []
    with torch.no_grad():
        for alpha in alphas:
            z = z1 * (1 - alpha) + z2 * alpha
            img = vae_model.decode(z).cpu()
            interpolated.append(img)
    
    # Display
    fig, axes = plt.subplots(1, steps, figsize=(20, 2))
    for i in range(steps):
        axes[i].imshow(interpolated[i].squeeze(), cmap='gray')
        axes[i].axis('off')
        axes[i].set_title(f'{alphas[i]:.1f}', fontsize=10)
    
    plt.suptitle('VAE Latent Space Interpolation', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('outputs/visualizations/vae_interpolation.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"Generated {steps} interpolation steps")
else:
    print("VAE model not loaded")
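The cell above interpolates linearly. For two independent standard-normal latents, the linear midpoint has a smaller norm than a typical sample (about √10 ≈ 3.2 versus √20 ≈ 4.5 in 20 dimensions), so midpoints may decode to less typical images. A commonly used alternative is spherical linear interpolation (slerp); this NumPy sketch is a swapped-in technique, not what the cell above does.

```python
import numpy as np


def slerp(z1, z2, alpha, eps=1e-8):
    """Spherical linear interpolation between vectors z1 and z2 at fraction alpha."""
    z1 = np.asarray(z1, dtype=np.float64)
    z2 = np.asarray(z2, dtype=np.float64)
    cos_omega = np.dot(z1, z2) / (np.linalg.norm(z1) * np.linalg.norm(z2) + eps)
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))  # angle between z1 and z2
    if np.sin(omega) < eps:
        # Nearly parallel vectors: fall back to linear interpolation
        return (1 - alpha) * z1 + alpha * z2
    return (np.sin((1 - alpha) * omega) * z1 + np.sin(alpha * omega) * z2) / np.sin(omega)
```

To use it with the VAE, convert per step, e.g. `z = torch.from_numpy(slerp(z1[0].cpu().numpy(), z2[0].cpu().numpy(), float(alpha))).float().unsqueeze(0).to(device)`, then call `vae_model.decode(z)` as before.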

## 8. Download Results

In [None]:
# Zip all visualizations
!zip -r visualizations.zip outputs/visualizations/

# Download
from google.colab import files
files.download('visualizations.zip')

print("Downloaded: visualizations.zip")
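`!zip` is a shell command available in Colab. A portable standard-library equivalent (a sketch; `zip_directory` is a hypothetical helper) uses `shutil.make_archive`. Only figures explicitly written with `plt.savefig` into `outputs/visualizations/` end up in the archive, so figures that are merely shown are not included.

```python
import os
import shutil
import zipfile


def zip_directory(src_dir, archive_base):
    """Zip everything under src_dir into archive_base + '.zip' and return its path.

    Paths inside the archive are stored relative to src_dir.
    """
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(src_dir)
    archive_path = shutil.make_archive(archive_base, 'zip', root_dir=src_dir)
    with zipfile.ZipFile(archive_path) as zf:
        print(f"{archive_path}: {len(zf.namelist())} entries")
    return archive_path
```

In Colab this combines with the existing download call: `files.download(zip_directory('outputs/visualizations', 'visualizations'))`.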

## Summary

This notebook provides:
- ✅ Multiple methods to load checkpoints (upload, Drive, URL)
- ✅ Checkpoint verification
- ✅ Model loading with error handling
- ✅ Sample generation from all four models (DDPM via a simplified single-step placeholder)
- ✅ Visualization grids
- ✅ Custom generation examples (specific digits, interpolation)
- ✅ Easy result download

**No training required** - just load checkpoints and generate!

### Next Steps:
- Modify generation parameters
- Create custom visualizations
- Generate samples for presentations
- Experiment with the VAE latent space