# <center> </center>
# <center> **Computer Vision** </center>
# <center> **Portfolio Exam 3**</center>
# <center>**Implement and analyze a Variational Autoencoder on CelebA**</center>


**Submitted by:**
****

*   **Riya Biju - 10000742**
*   **Harsha Sathish - 10001000**
*   **Harshith Babu Prakash Babu - 10001191**


In [None]:
"""
VAE Implementation for CelebA Dataset - Portfolio 3 Phase 1
Baseline model with latent_dim = 128
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
from datetime import datetime

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# ============================================================================
# Dataset Class
# ============================================================================

class CelebADataset(Dataset):
    """
    Custom Dataset for CelebA aligned and cropped images
    """
    def __init__(self, root_dir, partition_file, split='train', transform=None, subset_size=None):
        """
        Args:
            root_dir: Path to CelebA directory
            partition_file: Path to list_eval_partition.txt
            split: 'train', 'val', or 'test' (0, 1, 2 in partition file)
            transform: Transformations to apply
            subset_size: If specified, use only this many images (for faster training)
        """
        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, 'Img', 'img_align_celeba')
        self.transform = transform
        
        # Load partition information
        self.image_files = []
        split_map = {'train': 0, 'val': 1, 'test': 2}
        split_id = split_map[split]
        
        with open(partition_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 2:
                    img_name, partition = parts
                    if int(partition) == split_id:
                        self.image_files.append(img_name)
        
        # Apply subset if specified
        if subset_size is not None and subset_size < len(self.image_files):
            np.random.seed(42)
            indices = np.random.choice(len(self.image_files), subset_size, replace=False)
            self.image_files = [self.image_files[i] for i in sorted(indices)]
        
        print(f"{split.upper()} set: {len(self.image_files)} images")
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image


# ============================================================================
# VAE Model Architecture
# ============================================================================

class VAE(nn.Module):
    """
    Convolutional Variational Autoencoder for CelebA (64x64 images)
    """
    def __init__(self, latent_dim=128, img_channels=3):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        
        # ====== ENCODER ======
        # Input: (3, 64, 64)
        self.encoder = nn.Sequential(
            # Layer 1: (3, 64, 64) -> (32, 32, 32)
            nn.Conv2d(img_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Layer 2: (32, 32, 32) -> (64, 16, 16)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Layer 3: (64, 16, 16) -> (128, 8, 8)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # Layer 4: (128, 8, 8) -> (256, 4, 4)
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        # Flattened feature size: 256 * 4 * 4 = 4096
        self.fc_encoder = nn.Sequential(
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True)
        )
        
        # Latent space parameters
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)
        
        # ====== DECODER ======
        self.fc_decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256 * 4 * 4),
            nn.ReLU(inplace=True)
        )
        
        # Decoder: Mirror of encoder
        self.decoder = nn.Sequential(
            # Layer 1: (256, 4, 4) -> (128, 8, 8)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            # Layer 2: (128, 8, 8) -> (64, 16, 16)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            # Layer 3: (64, 16, 16) -> (32, 32, 32)
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            # Layer 4: (32, 32, 32) -> (3, 64, 64)
            nn.ConvTranspose2d(32, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # Output in [0, 1]
        )
    
    def encode(self, x):
        """Encode input to latent parameters"""
        h = self.encoder(x)
        h = h.view(h.size(0), -1)  # Flatten
        h = self.fc_encoder(h)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """Reparameterization trick: z = mu + std * epsilon"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        """Decode latent vector to image"""
        h = self.fc_decoder(z)
        h = h.view(h.size(0), 256, 4, 4)  # Reshape
        return self.decoder(h)
    
    def forward(self, x):
        """Full forward pass"""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar


# ============================================================================
# Loss Function
# ============================================================================
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """
    VAE Loss - Production Ready Version
    """
    batch_size = x.size(0)
    
    # Reconstruction: Binary cross-entropy
    # Sum over pixels, mean over batch
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum') / batch_size
    # Per-pixel for display
    recon_loss_display = recon_loss / (x.size(1) * x.size(2) * x.size(3))
    
    # KL Divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    
    total_loss = recon_loss + beta * kl_loss
    
    return total_loss, recon_loss_display, kl_loss



# ============================================================================
# Training Function
# ============================================================================

def train_epoch(model, train_loader, optimizer, epoch, beta=1.0):
    """Train for one epoch"""
    model.train()
    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
    for batch_idx, data in enumerate(pbar):
        data = data.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        recon_batch, mu, logvar = model(data)
        
        # Compute loss
        loss, recon_loss, kl_loss = vae_loss(recon_batch, data, mu, logvar, beta)
        
        # SAFETY CHECK: Detect posterior collapse
        if kl_loss.item() < 1.0 and epoch > 1:
            print(f"\n‚ö†Ô∏è WARNING: KL collapsed at epoch {epoch}, batch {batch_idx}")
            print(f"   KL: {kl_loss.item():.6f}")
            print(f"   mu range: [{mu.min():.4f}, {mu.max():.4f}]")
            print(f"   logvar range: [{logvar.min():.4f}, {logvar.max():.4f}]")
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Accumulate losses
        train_loss += loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        
        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Recon': f'{recon_loss.item():.4f}',
            'KL': f'{kl_loss.item():.4f}'
        })
    
    # Average losses
    n_batches = len(train_loader)
    return train_loss / n_batches, train_recon_loss / n_batches, train_kl_loss / n_batches


def validate(model, val_loader, beta=1.0):
    """Validate the model"""
    model.eval()
    val_loss = 0
    val_recon_loss = 0
    val_kl_loss = 0
    
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss, recon_loss, kl_loss = vae_loss(recon_batch, data, mu, logvar, beta)
            
            val_loss += loss.item()
            val_recon_loss += recon_loss.item()
            val_kl_loss += kl_loss.item()
    
    # Average losses
    n_batches = len(val_loader)
    return val_loss / n_batches, val_recon_loss / n_batches, val_kl_loss / n_batches


# ============================================================================
# Visualization Functions
# ============================================================================

def save_sample_images(model, val_loader, epoch, save_dir, num_samples=8):
    """Save sample reconstructions"""
    model.eval()
    
    # Get a batch
    data = next(iter(val_loader))[:num_samples].to(device)
    
    with torch.no_grad():
        recon, _, _ = model(data)
    
    # Move to CPU and denormalize
    data = data.cpu()
    recon = recon.cpu()
    
    # Create figure
    fig, axes = plt.subplots(2, num_samples, figsize=(num_samples * 2, 4))
    
    for i in range(num_samples):
        # Original
        axes[0, i].imshow(data[i].permute(1, 2, 0))
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=10)
        
        # Reconstruction
        axes[1, i].imshow(recon[i].permute(1, 2, 0))
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=10)
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'reconstruction_epoch_{epoch}.png'), dpi=100, bbox_inches='tight')
    plt.close()


def plot_training_curves(history, save_dir):
    """Plot and save training curves"""
    epochs = range(1, len(history['train_loss']) + 1)
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Total Loss
    axes[0].plot(epochs, history['train_loss'], label='Train', linewidth=2)
    axes[0].plot(epochs, history['val_loss'], label='Validation', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Total Loss')
    axes[0].set_title('Total Loss (ELBO)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Reconstruction Loss
    axes[1].plot(epochs, history['train_recon'], label='Train', linewidth=2)
    axes[1].plot(epochs, history['val_recon'], label='Validation', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Reconstruction Loss')
    axes[1].set_title('Reconstruction Loss (BCE)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # KL Divergence
    axes[2].plot(epochs, history['train_kl'], label='Train', linewidth=2)
    axes[2].plot(epochs, history['val_kl'], label='Validation', linewidth=2)
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('KL Divergence')
    axes[2].set_title('KL Divergence')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'training_curves.png'), dpi=150, bbox_inches='tight')
    plt.close()


# ============================================================================
# Main Training Script
# ============================================================================

def main():
    # ========== Configuration ==========
    config = {
        'latent_dim': 128,
        'beta': 1.0,  
        'img_size': 64,
        'batch_size': 128,
        'num_epochs': 100,
        'learning_rate': 1e-3,
        'subset_train': 50000,  # Use 50k training images 
        'subset_val': 5000,     # Use 5k validation images
        'num_workers': 0,       # DataLoader workers
        'save_interval': 5,     # Save checkpoint every N epochs
    }
    
    print("=" * 60)
    print("VAE Training Configuration")
    print("=" * 60)
    for key, value in config.items():
        print(f"{key:20s}: {value}")
    print("=" * 60)
    
    # ========== Paths ==========
    celeba_root = './CelebA'  # Adjust if needed
    partition_file = os.path.join(celeba_root, 'Eval', 'list_eval_partition.txt')
    
    # Create output directories
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f'./outputs/vae_baseline_latent{config["latent_dim"]}_{timestamp}'
    checkpoint_dir = os.path.join(output_dir, 'checkpoints')
    sample_dir = os.path.join(output_dir, 'samples')
    
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)
    
    # Save configuration
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)
    
    # ========== Data Transforms ==========
    transform = transforms.Compose([
        transforms.Resize(config['img_size']),
        transforms.CenterCrop(config['img_size']),
        transforms.ToTensor(),
        # Images are already in [0, 1] after ToTensor()
    ])
    
    # ========== Datasets and DataLoaders ==========
    print("\nLoading datasets...")
    train_dataset = CelebADataset(
        root_dir=celeba_root,
        partition_file=partition_file,
        split='train',
        transform=transform,
        subset_size=config['subset_train']
    )
    
    val_dataset = CelebADataset(
        root_dir=celeba_root,
        partition_file=partition_file,
        split='val',
        transform=transform,
        subset_size=config['subset_val']
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )
    
    # ========== Model Setup ==========
    print("\nInitializing model...")
    model = VAE(latent_dim=config['latent_dim']).to(device)
    
    # Print model summary
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    
    # ========== Training Loop ==========
    print("\n" + "=" * 60)
    print("Starting Training")
    print("=" * 60)
    
    history = {
        'train_loss': [], 'train_recon': [], 'train_kl': [],
        'val_loss': [], 'val_recon': [], 'val_kl': []
    }
    
    best_val_loss = float('inf')
    
    for epoch in range(1, config['num_epochs'] + 1):
        # Train
        train_loss, train_recon, train_kl = train_epoch(
            model, train_loader, optimizer, epoch, beta=config['beta']
        )
        
        # Validate
        val_loss, val_recon, val_kl = validate(model, val_loader, beta=config['beta'])
        
        # Update learning rate
        scheduler.step(val_loss)
        
        # Record history
        history['train_loss'].append(train_loss)
        history['train_recon'].append(train_recon)
        history['train_kl'].append(train_kl)
        history['val_loss'].append(val_loss)
        history['val_recon'].append(val_recon)
        history['val_kl'].append(val_kl)
        
        # Print epoch summary
        print(f"\nEpoch {epoch}/{config['num_epochs']}:")
        print(f"  Train - Loss: {train_loss:.4f}, Recon: {train_recon:.4f}, KL: {train_kl:.4f}")
        print(f"  Val   - Loss: {val_loss:.4f}, Recon: {val_recon:.4f}, KL: {val_kl:.4f}")
        
        # HEALTH CHECK
        if val_kl < 10.0:
            print(f"  ‚ö†Ô∏è  WARNING: KL very low ({val_kl:.4f}) - monitor for collapse")
        elif val_kl > 200.0:
            print(f"  ‚ö†Ô∏è  WARNING: KL very high ({val_kl:.4f}) - model may ignore inputs")
        
        # Save sample reconstructions
        if epoch % config['save_interval'] == 0:
            save_sample_images(model, val_loader, epoch, sample_dir)
        
        # Save checkpoint
        if epoch % config['save_interval'] == 0 or epoch == config['num_epochs']:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'config': config,
                'history': history,
            }
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
            torch.save(checkpoint, checkpoint_path)
            print(f"  Saved checkpoint: {checkpoint_path}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'config': config,
                'history': history,
                'val_loss': val_loss,
            }
            best_path = os.path.join(checkpoint_dir, 'best_model.pth')
            torch.save(best_checkpoint, best_path)
            print(f"  ‚≠ê New best model saved! Val Loss: {val_loss:.4f}")
    
    # ========== Final Steps ==========
    print("\n" + "=" * 60)
    print("Training Complete!")
    print("=" * 60)
    
    # Plot training curves
    plot_training_curves(history, output_dir)
    print(f"\nTraining curves saved to: {output_dir}/training_curves.png")
    
    # Save final model
    final_checkpoint = {
        'epoch': config['num_epochs'],
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config,
        'history': history,
    }
    final_path = os.path.join(checkpoint_dir, 'final_model.pth')
    torch.save(final_checkpoint, final_path)
    print(f"Final model saved to: {final_path}")
    
    # Save training history
    with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=4)
    
    print(f"\nAll outputs saved to: {output_dir}")
    print(f"Best validation loss: {best_val_loss:.4f}")


if __name__ == '__main__':
    main()

Using device: cpu
VAE Training Configuration
latent_dim          : 128
beta                : 1.0
img_size            : 64
batch_size          : 128
num_epochs          : 50
learning_rate       : 0.001
subset_train        : 50000
subset_val          : 5000
num_workers         : 0
save_interval       : 5

Loading datasets...
TRAIN set: 50000 images
VAL set: 5000 images

Initializing model...
Total parameters: 5,777,731
Trainable parameters: 5,777,731

Starting Training


Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6792.6978, Recon=0.5463, KL=79.9877]



Epoch 1/50:
  Train - Loss: 6878.3105, Recon: 0.5542, KL: 68.1596
  Val   - Loss: 6639.6001, Recon: 0.5341, KL: 76.7717
  ‚≠ê New best model saved! Val Loss: 6639.6001


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6434.8579, Recon=0.5173, KL=78.1094]



Epoch 2/50:
  Train - Loss: 6557.9734, Recon: 0.5273, KL: 77.9249
  Val   - Loss: 6523.7637, Recon: 0.5243, KL: 80.6387
  ‚≠ê New best model saved! Val Loss: 6523.7637


Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6311.2368, Recon=0.5072, KL=78.4920]



Epoch 3/50:
  Train - Loss: 6478.5579, Recon: 0.5208, KL: 78.3853
  Val   - Loss: 6481.7223, Recon: 0.5213, KL: 76.4383
  ‚≠ê New best model saved! Val Loss: 6481.7223


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6243.8496, Recon=0.5020, KL=75.1776]



Epoch 4/50:
  Train - Loss: 6443.6789, Recon: 0.5181, KL: 76.8124
  Val   - Loss: 6452.3402, Recon: 0.5189, KL: 75.6709
  ‚≠ê New best model saved! Val Loss: 6452.3402


Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:03<00:00,  3.17it/s, Loss=6400.8022, Recon=0.5147, KL=76.4400]



Epoch 5/50:
  Train - Loss: 6421.0561, Recon: 0.5163, KL: 76.5419
  Val   - Loss: 6422.8854, Recon: 0.5165, KL: 75.8616
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_5.pth
  ‚≠ê New best model saved! Val Loss: 6422.8854


Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:12<00:00,  2.94it/s, Loss=6340.9531, Recon=0.5096, KL=79.1084]



Epoch 6/50:
  Train - Loss: 6406.0550, Recon: 0.5151, KL: 76.3173
  Val   - Loss: 6416.3697, Recon: 0.5158, KL: 78.7179
  ‚≠ê New best model saved! Val Loss: 6416.3697


Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6448.1924, Recon=0.5185, KL=76.6198]



Epoch 7/50:
  Train - Loss: 6393.4945, Recon: 0.5141, KL: 76.8150
  Val   - Loss: 6403.3099, Recon: 0.5148, KL: 77.5588
  ‚≠ê New best model saved! Val Loss: 6403.3099


Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:04<00:00,  3.15it/s, Loss=6357.6821, Recon=0.5110, KL=79.0492]



Epoch 8/50:
  Train - Loss: 6387.0419, Recon: 0.5135, KL: 77.5937
  Val   - Loss: 6397.5758, Recon: 0.5142, KL: 78.6530
  ‚≠ê New best model saved! Val Loss: 6397.5758


Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:06<00:00,  3.08it/s, Loss=6331.4570, Recon=0.5089, KL=77.5847]



Epoch 9/50:
  Train - Loss: 6380.1822, Recon: 0.5129, KL: 78.1107
  Val   - Loss: 6398.4593, Recon: 0.5143, KL: 78.4119


Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:09<00:00,  3.01it/s, Loss=6318.4966, Recon=0.5076, KL=81.3236]



Epoch 10/50:
  Train - Loss: 6373.8249, Recon: 0.5123, KL: 78.9392
  Val   - Loss: 6400.8673, Recon: 0.5143, KL: 80.8845
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_10.pth


Epoch 11: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:12<00:00,  2.94it/s, Loss=6481.8784, Recon=0.5211, KL=78.0648]



Epoch 11/50:
  Train - Loss: 6370.1122, Recon: 0.5119, KL: 79.5038
  Val   - Loss: 6389.9251, Recon: 0.5135, KL: 79.7429
  ‚≠ê New best model saved! Val Loss: 6389.9251


Epoch 12: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:06<00:00,  3.10it/s, Loss=6162.7178, Recon=0.4950, KL=80.1881]



Epoch 12/50:
  Train - Loss: 6365.3154, Recon: 0.5115, KL: 79.7906
  Val   - Loss: 6393.3338, Recon: 0.5136, KL: 81.6352


Epoch 13: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:03<00:00,  3.17it/s, Loss=6268.4878, Recon=0.5035, KL=81.7440]



Epoch 13/50:
  Train - Loss: 6362.1873, Recon: 0.5112, KL: 79.9884
  Val   - Loss: 6379.9578, Recon: 0.5127, KL: 80.1790
  ‚≠ê New best model saved! Val Loss: 6379.9578


Epoch 14: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:03<00:00,  3.17it/s, Loss=6289.6118, Recon=0.5053, KL=80.1810]



Epoch 14/50:
  Train - Loss: 6360.0893, Recon: 0.5111, KL: 80.1438
  Val   - Loss: 6383.0636, Recon: 0.5127, KL: 83.4185


Epoch 15: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:08<00:00,  3.04it/s, Loss=6440.8608, Recon=0.5177, KL=79.9152]



Epoch 15/50:
  Train - Loss: 6355.8921, Recon: 0.5107, KL: 80.4546
  Val   - Loss: 6374.5516, Recon: 0.5124, KL: 78.3633
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_15.pth
  ‚≠ê New best model saved! Val Loss: 6374.5516


Epoch 16: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:09<00:00,  3.02it/s, Loss=6210.6250, Recon=0.4989, KL=79.6375]



Epoch 16/50:
  Train - Loss: 6354.3516, Recon: 0.5105, KL: 80.7783
  Val   - Loss: 6374.2113, Recon: 0.5121, KL: 81.2469
  ‚≠ê New best model saved! Val Loss: 6374.2113


Epoch 17: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:05<00:00,  3.11it/s, Loss=6506.7666, Recon=0.5230, KL=80.1570]



Epoch 17/50:
  Train - Loss: 6351.1403, Recon: 0.5103, KL: 81.0663
  Val   - Loss: 6378.2788, Recon: 0.5127, KL: 78.5737


Epoch 18: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:04<00:00,  3.14it/s, Loss=6325.9888, Recon=0.5080, KL=83.4667]



Epoch 18/50:
  Train - Loss: 6349.8823, Recon: 0.5101, KL: 81.2761
  Val   - Loss: 6374.5412, Recon: 0.5120, KL: 83.4033


Epoch 19: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6458.6348, Recon=0.5190, KL=81.2497]



Epoch 19/50:
  Train - Loss: 6347.5917, Recon: 0.5099, KL: 81.4595
  Val   - Loss: 6371.7635, Recon: 0.5118, KL: 82.2805
  ‚≠ê New best model saved! Val Loss: 6371.7635


Epoch 20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6320.8408, Recon=0.5076, KL=83.4030]



Epoch 20/50:
  Train - Loss: 6345.8173, Recon: 0.5098, KL: 81.6166
  Val   - Loss: 6369.5498, Recon: 0.5118, KL: 80.6988
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_20.pth
  ‚≠ê New best model saved! Val Loss: 6369.5498


Epoch 21: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:05<00:00,  3.12it/s, Loss=6577.0825, Recon=0.5287, KL=80.2349]



Epoch 21/50:
  Train - Loss: 6345.1591, Recon: 0.5097, KL: 81.7293
  Val   - Loss: 6369.9380, Recon: 0.5117, KL: 81.6618


Epoch 22: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:04<00:00,  3.14it/s, Loss=6315.1074, Recon=0.5073, KL=81.3754]



Epoch 22/50:
  Train - Loss: 6342.4308, Recon: 0.5095, KL: 81.7900
  Val   - Loss: 6368.3759, Recon: 0.5116, KL: 81.9110
  ‚≠ê New best model saved! Val Loss: 6368.3759


Epoch 23: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:05<00:00,  3.12it/s, Loss=6284.9302, Recon=0.5048, KL=81.7567]



Epoch 23/50:
  Train - Loss: 6341.1791, Recon: 0.5094, KL: 81.9361
  Val   - Loss: 6364.1654, Recon: 0.5113, KL: 81.6104
  ‚≠ê New best model saved! Val Loss: 6364.1654


Epoch 24: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6351.8550, Recon=0.5103, KL=80.9809]



Epoch 24/50:
  Train - Loss: 6340.4405, Recon: 0.5093, KL: 82.0723
  Val   - Loss: 6364.6839, Recon: 0.5114, KL: 80.3254


Epoch 25: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6183.3999, Recon=0.4964, KL=83.7705]



Epoch 25/50:
  Train - Loss: 6336.9413, Recon: 0.5090, KL: 82.1642
  Val   - Loss: 6370.0347, Recon: 0.5117, KL: 81.9173
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_25.pth


Epoch 26: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6372.5830, Recon=0.5119, KL=82.5838]



Epoch 26/50:
  Train - Loss: 6336.9243, Recon: 0.5090, KL: 82.3237
  Val   - Loss: 6370.4287, Recon: 0.5117, KL: 82.6841


Epoch 27: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6243.2446, Recon=0.5011, KL=85.5626]



Epoch 27/50:
  Train - Loss: 6335.1962, Recon: 0.5089, KL: 82.4300
  Val   - Loss: 6361.6811, Recon: 0.5109, KL: 83.5818
  ‚≠ê New best model saved! Val Loss: 6361.6811


Epoch 28: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6181.1880, Recon=0.4964, KL=81.0723]



Epoch 28/50:
  Train - Loss: 6332.9836, Recon: 0.5087, KL: 82.5227
  Val   - Loss: 6360.1173, Recon: 0.5109, KL: 81.6714
  ‚≠ê New best model saved! Val Loss: 6360.1173


Epoch 29: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:03<00:00,  3.17it/s, Loss=6394.1719, Recon=0.5136, KL=83.1226]



Epoch 29/50:
  Train - Loss: 6333.3137, Recon: 0.5087, KL: 82.5367
  Val   - Loss: 6361.0769, Recon: 0.5110, KL: 81.7419


Epoch 30: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:07<00:00,  3.07it/s, Loss=6355.7808, Recon=0.5105, KL=83.2858]



Epoch 30/50:
  Train - Loss: 6331.2050, Recon: 0.5085, KL: 82.6543
  Val   - Loss: 6363.3299, Recon: 0.5111, KL: 83.0373
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_30.pth


Epoch 31: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6376.1016, Recon=0.5120, KL=84.0521]



Epoch 31/50:
  Train - Loss: 6330.0346, Recon: 0.5084, KL: 82.7081
  Val   - Loss: 6361.6011, Recon: 0.5109, KL: 83.4094


Epoch 32: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6374.1709, Recon=0.5118, KL=84.8977]



Epoch 32/50:
  Train - Loss: 6330.1492, Recon: 0.5084, KL: 82.9407
  Val   - Loss: 6358.6386, Recon: 0.5107, KL: 83.2992
  ‚≠ê New best model saved! Val Loss: 6358.6386


Epoch 33: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6442.0059, Recon=0.5175, KL=82.8335]



Epoch 33/50:
  Train - Loss: 6328.2123, Recon: 0.5082, KL: 82.9641
  Val   - Loss: 6357.0402, Recon: 0.5106, KL: 82.2747
  ‚≠ê New best model saved! Val Loss: 6357.0402


Epoch 34: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6344.3442, Recon=0.5094, KL=84.4618]



Epoch 34/50:
  Train - Loss: 6326.5197, Recon: 0.5081, KL: 83.1365
  Val   - Loss: 6356.1923, Recon: 0.5105, KL: 83.7558
  ‚≠ê New best model saved! Val Loss: 6356.1923


Epoch 35: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6266.0044, Recon=0.5033, KL=81.6552]



Epoch 35/50:
  Train - Loss: 6325.2870, Recon: 0.5080, KL: 83.1755
  Val   - Loss: 6361.8120, Recon: 0.5109, KL: 84.4153
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_35.pth


Epoch 36: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6366.0308, Recon=0.5110, KL=86.5999]



Epoch 36/50:
  Train - Loss: 6324.7551, Recon: 0.5079, KL: 83.4057
  Val   - Loss: 6360.9611, Recon: 0.5109, KL: 83.5519


Epoch 37: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6291.0034, Recon=0.5052, KL=83.0476]



Epoch 37/50:
  Train - Loss: 6323.9490, Recon: 0.5079, KL: 83.4119
  Val   - Loss: 6355.2878, Recon: 0.5104, KL: 83.6642
  ‚≠ê New best model saved! Val Loss: 6355.2878


Epoch 38: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6387.6929, Recon=0.5131, KL=82.8664]



Epoch 38/50:
  Train - Loss: 6323.2962, Recon: 0.5078, KL: 83.5085
  Val   - Loss: 6360.7877, Recon: 0.5109, KL: 83.1566


Epoch 39: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6446.3022, Recon=0.5175, KL=87.1290]



Epoch 39/50:
  Train - Loss: 6322.2919, Recon: 0.5077, KL: 83.6065
  Val   - Loss: 6359.8396, Recon: 0.5108, KL: 82.7322


Epoch 40: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6272.9976, Recon=0.5038, KL=82.4172]



Epoch 40/50:
  Train - Loss: 6321.3256, Recon: 0.5076, KL: 83.6573
  Val   - Loss: 6356.1699, Recon: 0.5105, KL: 82.8961
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_40.pth


Epoch 41: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6271.7998, Recon=0.5037, KL=82.7789]



Epoch 41/50:
  Train - Loss: 6319.4581, Recon: 0.5075, KL: 83.7897
  Val   - Loss: 6355.3707, Recon: 0.5104, KL: 83.5589


Epoch 42: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6294.2231, Recon=0.5054, KL=83.8687]



Epoch 42/50:
  Train - Loss: 6319.3494, Recon: 0.5075, KL: 83.7799
  Val   - Loss: 6356.1206, Recon: 0.5105, KL: 82.6730


Epoch 43: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6407.7256, Recon=0.5147, KL=82.7574]



Epoch 43/50:
  Train - Loss: 6319.0847, Recon: 0.5074, KL: 83.7517
  Val   - Loss: 6356.8568, Recon: 0.5104, KL: 85.3786


Epoch 44: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6184.9902, Recon=0.4966, KL=83.0815]



Epoch 44/50:
  Train - Loss: 6309.3049, Recon: 0.5066, KL: 84.0141
  Val   - Loss: 6348.0464, Recon: 0.5098, KL: 83.5108
  ‚≠ê New best model saved! Val Loss: 6348.0464


Epoch 45: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6358.9531, Recon=0.5107, KL=82.9790]



Epoch 45/50:
  Train - Loss: 6308.0440, Recon: 0.5065, KL: 84.0641
  Val   - Loss: 6348.3501, Recon: 0.5098, KL: 83.4201
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_45.pth


Epoch 46: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.19it/s, Loss=6259.4468, Recon=0.5027, KL=81.8462]



Epoch 46/50:
  Train - Loss: 6308.1152, Recon: 0.5065, KL: 84.1240
  Val   - Loss: 6348.1996, Recon: 0.5098, KL: 83.5715


Epoch 47: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6534.5088, Recon=0.5247, KL=87.4725]



Epoch 47/50:
  Train - Loss: 6307.6595, Recon: 0.5065, KL: 84.2165
  Val   - Loss: 6348.8349, Recon: 0.5098, KL: 84.3211


Epoch 48: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6133.4321, Recon=0.4923, KL=83.5972]



Epoch 48/50:
  Train - Loss: 6306.8524, Recon: 0.5064, KL: 84.3207
  Val   - Loss: 6350.4367, Recon: 0.5100, KL: 84.0097


Epoch 49: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:01<00:00,  3.21it/s, Loss=6216.7100, Recon=0.4991, KL=83.5762]



Epoch 49/50:
  Train - Loss: 6306.6552, Recon: 0.5064, KL: 84.3302
  Val   - Loss: 6348.7984, Recon: 0.5098, KL: 84.1223


Epoch 50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [02:02<00:00,  3.20it/s, Loss=6209.8389, Recon=0.4986, KL=82.5892]



Epoch 50/50:
  Train - Loss: 6306.1519, Recon: 0.5063, KL: 84.3299
  Val   - Loss: 6348.4335, Recon: 0.5098, KL: 84.1483
  Saved checkpoint: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/checkpoint_epoch_50.pth

Training Complete!

Training curves saved to: ./outputs/vae_baseline_latent128_20251231_170110/training_curves.png
Final model saved to: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/final_model.pth

All outputs saved to: ./outputs/vae_baseline_latent128_20251231_170110
Best validation loss: 6348.0464


In [3]:
"""
SECOND CELL - VAE Testing and Visualization
"""
device = 'cpu'
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import glob

# Reuse the VAE and Dataset classes from first cell (already defined)
# No need to redefine them

print("=" * 60)
print("VAE Model Testing and Visualization")
print("=" * 60)

# ============================================================================
# Configuration
# ============================================================================

# Find the most recent output directory
output_dirs = glob.glob('./outputs/vae_baseline_latent128_*')
if not output_dirs:
    raise FileNotFoundError("No output directories found! Make sure training completed.")

output_dir = sorted(output_dirs)[-1]  # Get most recent
checkpoint_dir = os.path.join(output_dir, 'checkpoints')
print(f"\nUsing output directory: {output_dir}")

# Check which checkpoint to use
best_model_path = os.path.join(checkpoint_dir, 'best_model.pth')
final_model_path = os.path.join(checkpoint_dir, 'final_model.pth')

if os.path.exists(best_model_path):
    checkpoint_path = best_model_path
    print(f"Loading BEST model: {checkpoint_path}")
elif os.path.exists(final_model_path):
    checkpoint_path = final_model_path
    print(f"Loading FINAL model: {checkpoint_path}")
else:
    # Find latest checkpoint
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pth'))
    if checkpoints:
        checkpoint_path = sorted(checkpoints)[-1]
        print(f"Loading latest checkpoint: {checkpoint_path}")
    else:
        raise FileNotFoundError("No checkpoints found!")

# ============================================================================
# Load Model
# ============================================================================

print("\n" + "=" * 60)
print("Loading Trained Model")
print("=" * 60)

checkpoint = torch.load(checkpoint_path, map_location='cpu')
config = checkpoint['config']

# Create model
model = VAE(latent_dim=config['latent_dim']).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"‚úÖ Model loaded successfully!")
print(f"   Epoch: {checkpoint['epoch']}")
print(f"   Latent dimension: {config['latent_dim']}")
if 'val_loss' in checkpoint:
    print(f"   Validation loss: {checkpoint['val_loss']:.4f}")

# ============================================================================
# Load Test Dataset
# ============================================================================

print("\n" + "=" * 60)
print("Loading Test Dataset")
print("=" * 60)

celeba_root = './CelebA'
partition_file = os.path.join(celeba_root, 'Eval', 'list_eval_partition.txt')

transform = transforms.Compose([
    transforms.Resize(config['img_size']),
    transforms.CenterCrop(config['img_size']),
    transforms.ToTensor(),
])

test_dataset = CelebADataset(
    root_dir=celeba_root,
    partition_file=partition_file,
    split='test',
    transform=transform,
    subset_size=2000  # Use 2000 test images
)

test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0
)

print(f"‚úÖ Test dataset loaded: {len(test_dataset)} images")

# ============================================================================
# Visualization 1: Reconstructions
# ============================================================================

print("\n" + "=" * 60)
print("Generating Reconstructions")
print("=" * 60)

def visualize_reconstructions(model, dataset, num_samples=10):
    """Visualize original and reconstructed images"""
    model.eval()
    
    # Get random samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    
    fig, axes = plt.subplots(2, num_samples, figsize=(num_samples * 2, 4))
    
    with torch.no_grad():
        for i, idx in enumerate(indices):
            img = dataset[idx].unsqueeze(0).to(device)
            recon, _, _ = model(img)
            
            # Original
            img_np = img[0].cpu().permute(1, 2, 0).numpy()
            axes[0, i].imshow(img_np)
            axes[0, i].axis('off')
            if i == 0:
                axes[0, i].set_ylabel('Original', fontsize=12, fontweight='bold')
            
            # Reconstruction
            recon_np = recon[0].cpu().permute(1, 2, 0).numpy()
            axes[1, i].imshow(recon_np)
            axes[1, i].axis('off')
            if i == 0:
                axes[1, i].set_ylabel('Reconstructed', fontsize=12, fontweight='bold')
    
    plt.suptitle('VAE Reconstructions on Test Set', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    return fig

fig1 = visualize_reconstructions(model, test_dataset, num_samples=10)
print("‚úÖ Reconstructions displayed above")

# ============================================================================
# Visualization 2: Random Samples from Latent Space
# ============================================================================

print("\n" + "=" * 60)
print("Generating Random Samples from Latent Space")
print("=" * 60)

def generate_random_samples(model, num_samples=16):
    """Generate random samples from the latent space"""
    model.eval()
    
    with torch.no_grad():
        # Sample from standard normal
        z = torch.randn(num_samples, model.latent_dim).to(device)
        samples = model.decode(z).cpu()
    
    # Plot
    grid_size = int(np.sqrt(num_samples))
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(grid_size * 2.5, grid_size * 2.5))
    axes = axes.flatten()
    
    for i in range(num_samples):
        sample_np = samples[i].permute(1, 2, 0).numpy()
        axes[i].imshow(sample_np)
        axes[i].axis('off')
    
    plt.suptitle('Random Samples from Latent Space (z ~ N(0,I))', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    return fig

fig2 = generate_random_samples(model, num_samples=16)
print("‚úÖ Random samples displayed above")

# ============================================================================
# Visualization 3: Latent Space Interpolation
# ============================================================================

print("\n" + "=" * 60)
print("Generating Latent Space Interpolations")
print("=" * 60)

def interpolate_latent(model, dataset, num_pairs=3, num_steps=10):
    """Interpolate between pairs of images in latent space"""
    model.eval()
    
    fig, axes = plt.subplots(num_pairs, num_steps, figsize=(num_steps * 2, num_pairs * 2))
    
    with torch.no_grad():
        for pair_idx in range(num_pairs):
            # Get two random images
            idx1, idx2 = np.random.choice(len(dataset), 2, replace=False)
            img1 = dataset[idx1].unsqueeze(0).to(device)
            img2 = dataset[idx2].unsqueeze(0).to(device)
            
            # Encode to latent space
            mu1, _ = model.encode(img1)
            mu2, _ = model.encode(img2)
            
            # Interpolate
            alphas = np.linspace(0, 1, num_steps)
            
            for step_idx, alpha in enumerate(alphas):
                # Linear interpolation in latent space
                z_interp = (1 - alpha) * mu1 + alpha * mu2
                
                # Decode
                img_interp = model.decode(z_interp)
                
                # Display
                img_np = img_interp[0].cpu().permute(1, 2, 0).numpy()
                
                if num_pairs == 1:
                    axes[step_idx].imshow(img_np)
                    axes[step_idx].axis('off')
                    if step_idx == 0:
                        axes[step_idx].set_title('Start', fontsize=10)
                    elif step_idx == num_steps - 1:
                        axes[step_idx].set_title('End', fontsize=10)
                else:
                    axes[pair_idx, step_idx].imshow(img_np)
                    axes[pair_idx, step_idx].axis('off')
                    
                    if step_idx == 0:
                        axes[pair_idx, step_idx].set_ylabel(f'Pair {pair_idx+1}', fontsize=10, fontweight='bold')
                    if pair_idx == 0:
                        if step_idx == 0:
                            axes[pair_idx, step_idx].set_title('Start', fontsize=10)
                        elif step_idx == num_steps - 1:
                            axes[pair_idx, step_idx].set_title('End', fontsize=10)
    
    plt.suptitle('Latent Space Interpolations (Smooth Transitions)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    return fig

fig3 = interpolate_latent(model, test_dataset, num_pairs=3, num_steps=10)
print("‚úÖ Interpolations displayed above")

# ============================================================================
# Visualization 4: Latent Dimension Traversals
# ============================================================================

print("\n" + "=" * 60)
print("Generating Latent Dimension Traversals")
print("=" * 60)

def traverse_latent_dims(model, dataset, num_dims=10, num_steps=7, traversal_range=3.0):
    model.eval()
    
    with torch.no_grad():
        # Get a random image and encode it
        idx = np.random.choice(len(dataset))
        img = dataset[idx].unsqueeze(0).to(device)
        mu, _ = model.encode(img)
        
        # Select random dimensions to traverse
        dims_to_traverse = np.random.choice(model.latent_dim, num_dims, replace=False)
        
        fig, axes = plt.subplots(num_dims, num_steps, figsize=(num_steps * 1.8, num_dims * 1.8))
        
        traversal_values = np.linspace(-traversal_range, traversal_range, num_steps)
        
        for dim_idx, dim in enumerate(dims_to_traverse):
            for step_idx, value in enumerate(traversal_values):
                # Copy the latent vector and modify one dimension
                z_modified = mu.clone()
                z_modified[0, dim] = value
                
                # Decode
                img_generated = model.decode(z_modified)
                img_np = img_generated[0].cpu().permute(1, 2, 0).numpy()
                
                # Display
                axes[dim_idx, step_idx].imshow(img_np)
                axes[dim_idx, step_idx].axis('off')
                
                if step_idx == 0:
                    axes[dim_idx, step_idx].set_ylabel(f'Dim {dim}', fontsize=9, fontweight='bold')
                
                if dim_idx == 0:
                    axes[dim_idx, step_idx].set_title(f'{value:.1f}', fontsize=9)
        
        plt.suptitle(f'Latent Dimension Traversals (Range: [{-traversal_range:.1f}, {traversal_range:.1f}])', 
                     fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()
    
    return fig

fig4 = traverse_latent_dims(model, test_dataset, num_dims=10, num_steps=7, traversal_range=7.0) #TOBECHANGEDANDCHECKED - abcd
print("‚úÖ Latent traversals displayed above")
print("   Each row shows how varying a single latent dimension affects the generated face")

# ============================================================================
# Compute Reconstruction Quality Metrics
# ============================================================================

print("\n" + "=" * 60)
print("Computing Reconstruction Quality Metrics")
print("=" * 60)

def compute_reconstruction_metrics(model, test_loader, num_batches=10):
    """Compute average reconstruction metrics on test set"""
    model.eval()
    
    mse_total = 0
    mae_total = 0
    n_samples = 0
    
    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            if batch_idx >= num_batches:
                break
            
            data = data.to(device)
            recon, _, _ = model(data)
            
            # MSE (Mean Squared Error)
            mse = torch.mean((recon - data) ** 2, dim=[1, 2, 3])
            mse_total += mse.sum().item()
            
            # MAE (Mean Absolute Error)
            mae = torch.mean(torch.abs(recon - data), dim=[1, 2, 3])
            mae_total += mae.sum().item()
            
            n_samples += data.size(0)
    
    avg_mse = mse_total / n_samples
    avg_mae = mae_total / n_samples
    
    return avg_mse, avg_mae

avg_mse, avg_mae = compute_reconstruction_metrics(model, test_loader, num_batches=10)

print(f"‚úÖ Reconstruction Quality Metrics (on {10} test batches):")
print(f"   Mean Squared Error (MSE): {avg_mse:.6f}")
print(f"   Mean Absolute Error (MAE): {avg_mae:.6f}")

# ============================================================================
# Summary
# ============================================================================

print("\n" + "=" * 60)
print("Testing Complete - Summary")
print("=" * 60)

print(f"\nüìä Model Information:")
print(f"   Latent Dimension: {config['latent_dim']}")
print(f"   Image Size: {config['img_size']}√ó{config['img_size']}")
print(f"   Training Epochs: {checkpoint['epoch']}")
print(f"   Beta (Œ≤): {config['beta']}")

print(f"\nüìà Performance:")
if 'val_loss' in checkpoint:
    print(f"   Best Validation Loss: {checkpoint['val_loss']:.4f}")
print(f"   Test MSE: {avg_mse:.6f}")
print(f"   Test MAE: {avg_mae:.6f}")

print(f"\n‚úÖ Generated Visualizations:")
print(f"   1. ‚úì Original vs. Reconstructed (10 samples)")
print(f"   2. ‚úì Random Samples from Latent Space (16 samples)")
print(f"   3. ‚úì Latent Space Interpolations (3 pairs)")
print(f"   4. ‚úì Latent Dimension Traversals (10 dimensions)")

print(f"\nüíæ Model checkpoint location:")
print(f"   {checkpoint_path}")

print("\n" + "=" * 60)

VAE Model Testing and Visualization

Using output directory: ./outputs/vae_baseline_latent128_20251231_170110
Loading BEST model: ./outputs/vae_baseline_latent128_20251231_170110/checkpoints/best_model.pth

Loading Trained Model


NameError: name 'VAE' is not defined