In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
from pytorch_msssim import ssim
import time
import numpy as np
from sklearn.manifold import TSNE

# Select the compute device: use the CUDA GPU when one is available, else the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# =============================================
# 1. Model Definitions (encoder, decoder, latent discriminator)
# =============================================

class Encoder(nn.Module):
    """Fully connected encoder that maps an image batch to latent codes.

    Each sample is flattened to 28*28 = 784 features, passed through a
    400-unit ReLU hidden layer, and projected linearly to ``latent_dim``
    values (no activation on the output).
    """

    def __init__(self, latent_dim=10):
        super().__init__()
        flat_pixels = 28 * 28
        hidden_units = 400
        layers = [
            nn.Flatten(),
            nn.Linear(flat_pixels, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, latent_dim),
        ]
        # All layers live in one Sequential registered as ``model``
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the latent code for input batch ``x``."""
        latent = self.model(x)
        return latent

class Decoder(nn.Module):
    """Fully connected decoder that maps latent codes back to images.

    A ``latent_dim``-sized code goes through a 400-unit ReLU hidden layer and
    a 784-unit sigmoid output layer, then is reshaped to (1, 28, 28) per
    sample, so every output value lies in [0, 1].
    """

    def __init__(self, latent_dim=10):
        super().__init__()
        hidden_units = 400
        image_shape = (1, 28, 28)
        layers = [
            nn.Linear(latent_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 28 * 28),
            nn.Sigmoid(),
            # Restore the image layout on dimension 1 (after the batch dimension)
            nn.Unflatten(1, image_shape),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return the reconstructed image batch for latent codes ``z``."""
        reconstruction = self.model(z)
        return reconstruction

class Discriminator(nn.Module):
    """Small MLP that scores latent codes with a value in (0, 1).

    A ``latent_dim``-sized code goes through a 100-unit ReLU hidden layer and
    a single sigmoid output unit.
    """

    def __init__(self, latent_dim=10):
        super().__init__()
        hidden_units = 100
        layers = [
            nn.Linear(latent_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return one score per latent code in ``z``."""
        score = self.model(z)
        return score

# =============================================
# 2. Training Function with Time Measurement
# =============================================

def train_aae(epochs=20, batch_size=128, latent_dim=10, seed=None):
    """Train an adversarial autoencoder on MNIST and report reconstruction quality.

    Every mini-batch goes through three parameter updates:
      1. Reconstruction: encoder + decoder minimise the MSE between input and output.
      2. Discriminator: learns to label standard-normal samples as 1 and encoder
         codes as 0 (binary cross-entropy).
      3. Adversarial: the encoder alone is updated so that the discriminator
         labels its codes as 1.

    After each epoch the test set is scored with ``evaluate_reconstruction``.
    After the last epoch sample reconstructions and a t-SNE plot of the latent
    space are displayed.

    Args:
        epochs: Number of passes over the training set; must be >= 1.
        batch_size: Mini-batch size of the training loader (the test loader
            always uses batches of 128).
        latent_dim: Size of the latent code.
        seed: Optional integer passed to ``torch.manual_seed`` before the models
            and loaders are created. ``None`` (default) leaves the RNG untouched.

    Returns:
        Tuple ``(encoder, decoder, metrics_history, final_metrics)``.
        ``metrics_history`` maps 'train_recon', 'train_disc', 'train_adv',
        'test_mse', 'test_psnr' and 'test_ssim' to per-epoch lists.
        ``final_metrics`` holds 'latent_dim', the last epoch's test metrics
        ('final_mse', 'final_psnr', 'final_ssim') and the timing entries
        'total_train_time_sec' and 'avg_epoch_time_sec'. Both timings include
        the per-epoch evaluation, because it runs inside the timed loop.

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
    """
    # Fail fast: with no epochs the metric lists stay empty and the final
    # summary (last-epoch metrics, average epoch time) cannot be computed.
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if seed is not None:
        torch.manual_seed(seed)

    # Load data
    transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_data = datasets.MNIST(root="./data", train=False, transform=transform)
    # Read evaluation data in a fixed order: the batch-level PSNR and the samples
    # picked for visualisation then no longer change from run to run.
    test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

    # Initialize models
    encoder = Encoder(latent_dim).to(device)
    decoder = Decoder(latent_dim).to(device)
    discriminator = Discriminator(latent_dim).to(device)

    # Loss functions and optimizers
    reconstruction_loss = nn.MSELoss()
    adversarial_loss = nn.BCELoss()
    opt_enc_dec = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
    opt_disc = optim.Adam(discriminator.parameters(), lr=1e-3)
    # Separate optimizer (and Adam state) for the encoder's adversarial update
    opt_enc_adv = optim.Adam(encoder.parameters(), lr=1e-3)

    # Per-epoch history of training losses and test metrics
    metrics_history = {
        'train_recon': [],
        'train_disc': [],
        'train_adv': [],
        'test_mse': [],
        'test_psnr': [],
        'test_ssim': []
    }

    # Start training timer
    train_start_time = time.time()

    for epoch in range(epochs):
        epoch_start_time = time.time()
        epoch_recon, epoch_disc, epoch_adv = 0, 0, 0

        # Training phase (the evaluation below switches encoder/decoder to eval mode)
        encoder.train()
        decoder.train()
        discriminator.train()

        for images, _ in train_loader:
            images = images.to(device)

            # Reconstruction phase
            z = encoder(images)
            recon = decoder(z)
            loss_recon = reconstruction_loss(recon, images)

            opt_enc_dec.zero_grad()
            loss_recon.backward()
            opt_enc_dec.step()

            # Discriminator phase: prior samples are "real", encoder codes are "fake"
            real_z = torch.randn(z.size(), device=device)
            # The encoder is not updated in this phase, so skip building its graph
            with torch.no_grad():
                fake_z = encoder(images)

            pred_real = discriminator(real_z)
            pred_fake = discriminator(fake_z)

            loss_disc = adversarial_loss(pred_real, torch.ones_like(pred_real)) + \
                        adversarial_loss(pred_fake, torch.zeros_like(pred_fake))

            opt_disc.zero_grad()
            loss_disc.backward()
            opt_disc.step()

            # Adversarial phase: push the encoder's codes towards the "real" label.
            # backward() also fills discriminator gradients, but only the encoder
            # optimizer steps here and opt_disc.zero_grad() clears them before the
            # next discriminator update.
            z = encoder(images)
            pred = discriminator(z)
            loss_enc_adv = adversarial_loss(pred, torch.ones_like(pred))

            opt_enc_adv.zero_grad()
            loss_enc_adv.backward()
            opt_enc_adv.step()

            # Accumulate losses
            epoch_recon += loss_recon.item()
            epoch_disc += loss_disc.item()
            epoch_adv += loss_enc_adv.item()

        # Store training losses as the mean over the epoch's batches
        num_batches = len(train_loader)
        metrics_history['train_recon'].append(epoch_recon/num_batches)
        metrics_history['train_disc'].append(epoch_disc/num_batches)
        metrics_history['train_adv'].append(epoch_adv/num_batches)

        # Evaluation phase
        test_metrics = evaluate_reconstruction(encoder, decoder, test_loader)

        metrics_history['test_mse'].append(test_metrics['mse'])
        metrics_history['test_psnr'].append(test_metrics['psnr'])
        metrics_history['test_ssim'].append(test_metrics['ssim'])

        # Calculate epoch time (training plus evaluation)
        epoch_time = time.time() - epoch_start_time

        # Display epoch info
        print(f"Epoch [{epoch+1}/{epochs}]  Time: {epoch_time:.2f}s  "
              f"Recon: {metrics_history['train_recon'][-1]:.4f}  Disc: {metrics_history['train_disc'][-1]:.4f}  "
              f"Test MSE: {metrics_history['test_mse'][-1]:.4f}  PSNR: {metrics_history['test_psnr'][-1]:.2f}dB  SSIM: {metrics_history['test_ssim'][-1]:.4f}")

    # Calculate total training time
    total_train_time = time.time() - train_start_time

    # Display final metrics
    final_metrics = {
        'latent_dim': latent_dim,
        'final_mse': metrics_history['test_mse'][-1],
        'final_psnr': metrics_history['test_psnr'][-1],
        'final_ssim': metrics_history['test_ssim'][-1],
        'total_train_time_sec': total_train_time,
        'avg_epoch_time_sec': total_train_time/epochs
    }

    print("\nFinal Metrics:")
    print(final_metrics)

    # Visualizations
    visualize_reconstructions(encoder, decoder, test_loader)
    visualize_latent_space(encoder, test_loader)

    return encoder, decoder, metrics_history, final_metrics

# =============================================
# 3. Evaluation Functions (MSE, PSNR, SSIM on the test set)
# =============================================

def evaluate_reconstruction(encoder, decoder, test_loader):
    """Score autoencoder reconstructions on a data loader.

    Both models are switched to eval mode (and left in it), and every batch is
    reconstructed without gradient tracking.

    Args:
        encoder: Module mapping image batches to latent codes.
        decoder: Module mapping latent codes back to image batches.
        test_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
            Pixel values are treated as lying in [0, 1] (``max_pixel`` and
            ``data_range`` are both 1.0).

    Returns:
        Dict with sample-weighted averages:
          'mse'  - mean squared error over all samples;
          'psnr' - PSNR in dB computed from each batch's mean MSE and weighted
                   by batch size (a batch-level PSNR, not a per-image one);
          'ssim' - mean of the SSIM values returned by ``ssim`` with
                   ``size_average=False`` (expected to be one value per image,
                   per the pytorch_msssim documentation).

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    encoder.eval()
    decoder.eval()

    # Evaluate on the device the encoder's weights live on, so the function does
    # not depend on notebook-level state; the global `device` is only a fallback
    # for an encoder without parameters.
    first_param = next(encoder.parameters(), None)
    eval_device = first_param.device if first_param is not None else device

    mse_criterion = nn.MSELoss()  # built once instead of once per batch
    max_pixel = 1.0

    total_mse = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(eval_device)
            num_images = images.size(0)
            reconstructions = decoder(encoder(images))

            # MSE: batch mean, re-weighted by batch size so a smaller last batch
            # does not skew the overall average
            mse = mse_criterion(reconstructions, images)
            total_mse += mse.item() * num_images

            # PSNR from the batch-mean MSE
            psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
            total_psnr += psnr.item() * num_images

            # SSIM: keep the unreduced values and sum them
            ssim_val = ssim(reconstructions, images, data_range=1.0, size_average=False)
            total_ssim += ssim_val.sum().item()

            total_samples += num_images

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute reconstruction metrics")

    return {
        'mse': total_mse / total_samples,
        'psnr': total_psnr / total_samples,
        'ssim': total_ssim / total_samples
    }

# =============================================
# 4. Visualization Functions
# =============================================

def visualize_reconstructions(encoder, decoder, test_loader, num_images=10):
    """Show original images (top row) above their reconstructions (bottom row).

    Only the first batch of ``test_loader`` is used. Both models are switched
    to eval mode (and left in it).

    Args:
        encoder: Module mapping image batches to latent codes.
        decoder: Module mapping latent codes back to image batches.
        test_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        num_images: Maximum number of image pairs to draw (default 10). Fewer
            are drawn when the first batch is smaller.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    encoder.eval()
    decoder.eval()

    with torch.no_grad():
        images, _ = next(iter(test_loader))
        images = images.to(device)

        # Get reconstructions
        reconstructions = decoder(encoder(images))

    # Never index past the end of the batch
    n_cols = min(num_images, images.size(0))

    # squeeze=False keeps `axes` two-dimensional even for a single column;
    # 1.5 inches per column reproduces the 15-inch width for 10 images
    fig, axes = plt.subplots(2, n_cols, figsize=(1.5 * n_cols, 4), squeeze=False)
    for col in range(n_cols):
        for row, batch in enumerate((images, reconstructions)):
            axes[row, col].imshow(batch[col].cpu().squeeze(), cmap='gray')
            axes[row, col].axis('off')

    plt.show()

def visualize_latent_space(encoder, test_loader):
    """Scatter-plot a 2-D t-SNE embedding of the encoder's latent codes.

    Every batch of ``test_loader`` is encoded without gradient tracking; the
    codes are embedded with t-SNE (``random_state=42``) and the points are
    colored by their labels. The encoder is switched to eval mode.
    """
    encoder.eval()

    code_batches, label_batches = [], []
    with torch.no_grad():
        for batch_images, batch_targets in test_loader:
            codes = encoder(batch_images.to(device))
            code_batches.append(codes.cpu().numpy())
            label_batches.append(batch_targets.cpu().numpy())

    # Stack the per-batch arrays into one array of codes and one of labels
    all_codes = np.concatenate(code_batches, axis=0)
    all_labels = np.concatenate(label_batches, axis=0)

    # Project the latent codes to two dimensions
    embedding = TSNE(n_components=2, random_state=42).fit_transform(all_codes)
    xs, ys = embedding[:, 0], embedding[:, 1]

    plt.figure(figsize=(8, 6))
    plt.scatter(xs, ys, c=all_labels, cmap='tab10', alpha=0.6)
    plt.colorbar()
    plt.title("t-SNE visualization of Latent Space")
    plt.show()

# =============================================
# 5. Main Execution with Time Tracking
# =============================================

if __name__ == "__main__":
    all_metrics = {}
    latent_dims_to_test = [5, 10, 20]

    # One independent training run per latent size; only the final metrics are kept
    for latent_dim in latent_dims_to_test:
        rule = '=' * 50
        print(f"\n{rule}")
        print(f"Training with Latent Dimension: {latent_dim}")
        print(f"{rule}")

        # train_aae returns (encoder, decoder, metrics_history, final_metrics);
        # the per-epoch history is discarded here
        encoder, decoder, _, final_metrics = train_aae(latent_dim=latent_dim)
        all_metrics[f'LatentDim={latent_dim}'] = final_metrics

    print("\nTraining and evaluation complete!")
    print(f"Final metrics: {all_metrics}")


Using device: cuda

Training with Latent Dimension: 5


100%|██████████| 9.91M/9.91M [00:00<00:00, 20.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 498kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.47MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.66MB/s]


Epoch [1/20]  Time: 11.00s  Recon: 0.0512  Disc: 1.3533  Test MSE: 0.0391  PSNR: 14.08dB  SSIM: 0.4480
Epoch [2/20]  Time: 9.25s  Recon: 0.0384  Disc: 1.2429  Test MSE: 0.0380  PSNR: 14.20dB  SSIM: 0.4719
Epoch [3/20]  Time: 9.91s  Recon: 0.0370  Disc: 1.3800  Test MSE: 0.0355  PSNR: 14.51dB  SSIM: 0.5035
Epoch [4/20]  Time: 9.89s  Recon: 0.0352  Disc: 1.3691  Test MSE: 0.0349  PSNR: 14.58dB  SSIM: 0.5267
Epoch [5/20]  Time: 10.46s  Recon: 0.0340  Disc: 1.3869  Test MSE: 0.0329  PSNR: 14.84dB  SSIM: 0.5530
Epoch [6/20]  Time: 9.96s  Recon: 0.0330  Disc: 1.3654  Test MSE: 0.0344  PSNR: 14.64dB  SSIM: 0.5502
Epoch [7/20]  Time: 9.67s  Recon: 0.0358  Disc: 1.3774  Test MSE: 0.0337  PSNR: 14.73dB  SSIM: 0.5499


In [2]:
!pip install pytorch_msssim

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch_msssim
Successfully installed pytorch_msssim-1.0.0
