In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
from pytorch_msssim import ssim
import os
from datetime import datetime
import time
import numpy as np

# Select the compute device: use the CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# =============================================
# 1. DAE Model Definitions
# =============================================

class Encoder(nn.Module):
    """Fully connected encoder mapping an image batch to latent vectors.

    The input is flattened (784 features per sample are expected, i.e. a
    28x28 image) and passed through two ReLU hidden layers of 512 and 256
    units, followed by a linear projection to ``latent_dim`` features.

    Args:
        latent_dim: Size of the latent vector produced for each sample.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        # The layer order fixes the indices used in the state dict
        # (encoder.1, encoder.3, encoder.5), so it must not be rearranged.
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode batch ``x`` and return a ``(batch, latent_dim)`` tensor."""
        latent = self.encoder(x)
        return latent

class Decoder(nn.Module):
    """Fully connected decoder mapping latent vectors back to images.

    A latent vector is expanded through two ReLU hidden layers of 256 and
    512 units to 784 values, squashed into (0, 1) by a sigmoid and reshaped
    to ``(1, 28, 28)`` per sample.

    Args:
        latent_dim: Size of the latent vector expected for each sample.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        # The layer order fixes the indices used in the state dict
        # (decoder.0, decoder.2, decoder.4), so it must not be rearranged.
        layers = [
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid(),
            nn.Unflatten(1, (1, 28, 28)),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent batch ``z`` into a ``(batch, 1, 28, 28)`` tensor."""
        image = self.decoder(z)
        return image

# =============================================
# 2. Noise Functions
# =============================================

def add_gaussian_noise(images, noise_factor=0.5):
    """Return ``images`` corrupted by additive Gaussian noise, clipped to [0, 1].

    Args:
        images: Tensor of pixel values; the input is not modified.
        noise_factor: Scale applied to the standard-normal noise sample.

    Returns:
        A new tensor with the same shape as ``images`` and values in [0, 1].
    """
    perturbation = noise_factor * torch.randn_like(images)
    corrupted = (images + perturbation).clamp(min=0., max=1.)
    return corrupted

def add_salt_pepper_noise(images, prob=0.1):
    """Return a copy of ``images`` corrupted with salt-and-pepper noise.

    Every element is independently selected for corruption with probability
    ``prob``. A selected element is replaced by 1 ("salt") or 0 ("pepper"),
    chosen by a second independent uniform draw thresholded at 0.5. The
    input tensor is left untouched.

    Args:
        images: Floating-point tensor of any shape.
        prob: Probability that an individual element is corrupted.

    Returns:
        A new tensor with the same shape and dtype as ``images``.
    """
    corrupt_mask = torch.rand_like(images) < prob
    salt_mask = torch.rand_like(images) > 0.5
    # Cast the 0/1 replacement values to the input dtype rather than forcing
    # float32: a masked assignment with mismatched dtypes raises for
    # float64/half inputs, whereas this works for any floating dtype.
    replacement = salt_mask.to(images.dtype)
    return torch.where(corrupt_mask, replacement, images)

# =============================================
# 3. DAE Training Function
# =============================================

def train_dae(epochs=50, batch_size=128, latent_dim=32, noise_type='gaussian', seed=None):
    """Train a denoising autoencoder on MNIST and record per-epoch metrics.

    Each training batch is corrupted with the selected noise, passed through
    the encoder/decoder pair, and optimized with MSE against the clean
    images. After every epoch the model is evaluated on the MNIST test split
    with ``evaluate_denoising``.

    Side effects (all under ``results/``, which is created if missing):
      * ``dae_metrics_{noise_type}_{timestamp}.xlsx`` with the sheets
        ``training_metrics`` (one row per epoch) and ``final_metrics``;
      * the figure written by ``plot_training_progress``;
      * ``dae_{noise_type}_latent{latent_dim}_model.pth`` holding both state
        dicts plus ``latent_dim`` and ``noise_type``.
    MNIST is downloaded to ``./data`` when it is not already present.

    Args:
        epochs: Number of passes over the training set (at least 1).
        batch_size: Training batch size. Evaluation always uses 128.
        latent_dim: Size of the latent vector.
        noise_type: ``'gaussian'`` or ``'salt_pepper'``.
        seed: Optional integer passed to ``torch.manual_seed`` before the
            loaders and models are built. ``None`` leaves the global RNG
            state untouched.

    Returns:
        Tuple ``(encoder, decoder, metrics_history, final_metrics)`` where
        ``metrics_history`` maps metric names to per-epoch lists and
        ``final_metrics`` is the dict written to the ``final_metrics`` sheet.

    Raises:
        ValueError: If ``noise_type`` is not recognized or ``epochs < 1``.
    """
    # Validate arguments first so a typo fails immediately instead of after
    # the download / file creation, and so epochs=0 cannot reach the
    # "last epoch" lookups and the per-epoch averages below.
    if noise_type not in ('gaussian', 'salt_pepper'):
        raise ValueError(f"Unknown noise type: {noise_type}")
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    if seed is not None:
        torch.manual_seed(seed)

    # Create results directory
    os.makedirs('results', exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    excel_filename = f'results/dae_metrics_{noise_type}_{timestamp}.xlsx'

    # Load data
    transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_data = datasets.MNIST(root="./data", train=False, transform=transform)
    # Evaluation order is fixed: PSNR is derived from per-batch MSE, so
    # reshuffling would change the batch grouping from epoch to epoch.
    test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

    # Initialize models
    encoder = Encoder(latent_dim).to(device)
    decoder = Decoder(latent_dim).to(device)

    # Optimizer and loss
    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
    criterion = nn.MSELoss()

    # Data storage
    epoch_data = []
    metrics_history = {
        'train_loss': [],
        'test_mse': [],
        'test_psnr': [],
        'test_ssim': []
    }

    # Start training timer
    train_start_time = time.time()

    for epoch in range(epochs):
        epoch_start_time = time.time()
        epoch_loss = 0

        # Training phase (evaluate_denoising switches the modules to eval
        # mode, so train mode has to be restored at the top of every epoch)
        encoder.train()
        decoder.train()

        for clean_images, _ in train_loader:
            clean_images = clean_images.to(device)

            # Add noise based on selected type (validated above)
            if noise_type == 'gaussian':
                noisy_images = add_gaussian_noise(clean_images)
            else:
                noisy_images = add_salt_pepper_noise(clean_images)

            # Forward pass
            z = encoder(noisy_images)
            reconstructions = decoder(z)

            # Loss calculation (compare to clean images)
            loss = criterion(reconstructions, clean_images)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate loss
            epoch_loss += loss.item()

        # Store training loss (mean of the per-batch losses)
        metrics_history['train_loss'].append(epoch_loss / len(train_loader))

        # Evaluation phase
        test_start_time = time.time()
        test_metrics = evaluate_denoising(encoder, decoder, test_loader, noise_type)
        test_time = time.time() - test_start_time

        metrics_history['test_mse'].append(test_metrics['mse'])
        metrics_history['test_psnr'].append(test_metrics['psnr'])
        metrics_history['test_ssim'].append(test_metrics['ssim'])

        # Calculate epoch time (training plus evaluation)
        epoch_time = time.time() - epoch_start_time

        # Store epoch info
        epoch_info = {
            'epoch': epoch+1,
            'train_loss': metrics_history['train_loss'][-1],
            'test_mse': metrics_history['test_mse'][-1],
            'test_psnr': metrics_history['test_psnr'][-1],
            'test_ssim': metrics_history['test_ssim'][-1],
            'epoch_time_sec': epoch_time,
            'test_time_sec': test_time,
            'latent_dim': latent_dim,
            'noise_type': noise_type
        }
        epoch_data.append(epoch_info)

        print(f"Epoch [{epoch+1}/{epochs}]  Time: {epoch_time:.2f}s  "
              f"Loss: {epoch_info['train_loss']:.4f}  "
              f"Test MSE: {epoch_info['test_mse']:.4f}  "
              f"PSNR: {epoch_info['test_psnr']:.2f}dB")

    # Calculate total training time
    total_train_time = time.time() - train_start_time

    # Final metrics (values of the last epoch plus timing and compression)
    compression = evaluate_compression(latent_dim)
    final_metrics = {
        'latent_dim': latent_dim,
        'noise_type': noise_type,
        'final_mse': metrics_history['test_mse'][-1],
        'final_psnr': metrics_history['test_psnr'][-1],
        'final_ssim': metrics_history['test_ssim'][-1],
        'final_train_loss': metrics_history['train_loss'][-1],
        'compression_ratio': compression['compression_ratio'],
        'bits_per_pixel': compression['bits_per_pixel'],
        'total_train_time_sec': total_train_time,
        'avg_epoch_time_sec': total_train_time/epochs,
        'avg_test_time_sec': sum(x['test_time_sec'] for x in epoch_data)/epochs
    }

    # Save to Excel. The workbook is opened only once the data exists and
    # inside a context manager, so it is always closed and a failed run does
    # not leave a dangling writer behind.
    with pd.ExcelWriter(excel_filename, engine='openpyxl') as writer:
        pd.DataFrame(epoch_data).to_excel(writer, sheet_name='training_metrics', index=False)
        pd.DataFrame([final_metrics]).to_excel(writer, sheet_name='final_metrics', index=False)

    plot_training_progress(metrics_history, latent_dim, noise_type)

    # Save model
    torch.save({
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'latent_dim': latent_dim,
        'noise_type': noise_type
    }, f'results/dae_{noise_type}_latent{latent_dim}_model.pth')

    return encoder, decoder, metrics_history, final_metrics

# =============================================
# 4. DAE Evaluation Functions
# =============================================

def evaluate_denoising(encoder, decoder, test_loader, noise_type):
    """Measure denoising quality of an encoder/decoder pair on a data loader.

    Every batch is corrupted with the selected noise, reconstructed, and
    compared against the clean images. Both modules are switched to eval
    mode and are left in eval mode on return.

    Args:
        encoder: Module mapping noisy images to latent vectors.
        decoder: Module mapping latent vectors to reconstructed images.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        noise_type: ``'gaussian'`` or ``'salt_pepper'``.

    Returns:
        Dict with the keys ``'mse'``, ``'psnr'`` and ``'ssim'``:
          * ``mse``: batch-mean MSE, weighted by batch size;
          * ``psnr``: PSNR in dB computed from each batch's mean MSE with a
            peak value of 1.0, weighted by batch size (this is not the mean
            of per-image PSNR values);
          * ``ssim``: sum of the values returned by
            ``ssim(..., size_average=False)`` divided by the sample count
            (presumably one SSIM value per image - confirm against
            pytorch_msssim).

    Raises:
        ValueError: If ``noise_type`` is not recognized, or if
            ``test_loader`` yields no samples.
    """
    # Reject unknown noise types explicitly; otherwise neither branch below
    # would assign the noisy batch and the failure would be an obscure
    # unbound-variable error.
    if noise_type not in ('gaussian', 'salt_pepper'):
        raise ValueError(f"Unknown noise type: {noise_type}")

    encoder.eval()
    decoder.eval()

    total_mse = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    total_samples = 0
    max_pixel = 1.0

    with torch.no_grad():
        for clean_images, _ in test_loader:
            clean_images = clean_images.to(device)
            batch_count = clean_images.size(0)

            # Add noise
            if noise_type == 'gaussian':
                noisy_images = add_gaussian_noise(clean_images)
            else:
                noisy_images = add_salt_pepper_noise(clean_images)

            # Forward pass
            z = encoder(noisy_images)
            reconstructions = decoder(z)

            # Metrics (compare to clean images)
            mse = nn.functional.mse_loss(reconstructions, clean_images)
            total_mse += mse.item() * batch_count

            # A perfect reconstruction (mse == 0) yields +inf here.
            psnr = 20 * torch.log10(max_pixel / torch.sqrt(mse))
            total_psnr += psnr.item() * batch_count

            ssim_val = ssim(reconstructions, clean_images, data_range=1.0, size_average=False)
            total_ssim += ssim_val.sum().item()

            total_samples += batch_count

    if total_samples == 0:
        raise ValueError("test_loader yielded no samples; cannot compute metrics")

    return {
        'mse': total_mse / total_samples,
        'psnr': total_psnr / total_samples,
        'ssim': total_ssim / total_samples
    }

def visualize_denoising(encoder, decoder, test_loader, noise_type, num_samples=5):
    """Save a 3-row figure of clean, noisy and reconstructed images.

    The first batch of ``test_loader`` supplies the images. The figure is
    written to ``results/denoising_{noise_type}_examples.png`` (the directory
    is created if missing) and then closed. Both modules are switched to
    eval mode.

    Args:
        encoder: Module mapping noisy images to latent vectors.
        decoder: Module mapping latent vectors to reconstructed images.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        noise_type: ``'gaussian'`` or ``'salt_pepper'``.
        num_samples: Maximum number of images (columns) to show; fewer are
            drawn when the first batch is smaller.

    Raises:
        ValueError: If ``noise_type`` is not recognized, or if there are no
            images to draw.
    """
    # Reject unknown noise types explicitly; otherwise neither branch below
    # would assign the noisy batch.
    if noise_type not in ('gaussian', 'salt_pepper'):
        raise ValueError(f"Unknown noise type: {noise_type}")

    encoder.eval()
    decoder.eval()

    clean_images, _ = next(iter(test_loader))
    clean_images = clean_images[:num_samples].to(device)

    # The slice may return fewer images than requested (short first batch);
    # size the grid by what is actually available to avoid indexing past it.
    num_samples = clean_images.size(0)
    if num_samples == 0:
        raise ValueError("No images available to visualize")

    # Add noise
    if noise_type == 'gaussian':
        noisy_images = add_gaussian_noise(clean_images)
    else:
        noisy_images = add_salt_pepper_noise(clean_images)

    with torch.no_grad():
        z = encoder(noisy_images)
        reconstructions = decoder(z)

    plt.figure(figsize=(12, 6))
    for i in range(num_samples):
        # Original
        plt.subplot(3, num_samples, i+1)
        plt.imshow(clean_images[i].cpu().squeeze(), cmap='gray')
        plt.title("Original" if i == 0 else "")
        plt.axis('off')

        # Noisy
        plt.subplot(3, num_samples, num_samples+i+1)
        plt.imshow(noisy_images[i].cpu().squeeze(), cmap='gray')
        plt.title("Noisy" if i == 0 else "")
        plt.axis('off')

        # Reconstructed
        plt.subplot(3, num_samples, 2*num_samples+i+1)
        plt.imshow(reconstructions[i].cpu().squeeze(), cmap='gray')
        plt.title("Reconstructed" if i == 0 else "")
        plt.axis('off')

    plt.tight_layout()
    # Make the function usable on its own, without a prior training run
    # having created the output directory.
    os.makedirs('results', exist_ok=True)
    plt.savefig(f'results/denoising_{noise_type}_examples.png', dpi=300)
    plt.close()

def plot_training_progress(metrics, latent_dim, noise_type):
    """Save a 1x3 figure of training loss, test MSE and test PSNR per epoch.

    The figure is written to
    ``results/dae_training_progress_{noise_type}.png`` (the directory is
    created if missing) and then closed.

    Args:
        metrics: Mapping with the keys ``'train_loss'``, ``'test_mse'`` and
            ``'test_psnr'``, each a sequence with one value per epoch.
        latent_dim: Latent size, shown in the title of the loss panel.
        noise_type: Noise name, used in the loss title and the file name.
    """
    # (metrics key, panel title, y-axis label) for each panel, left to right.
    panels = [
        ('train_loss', f'Training Loss ({noise_type} noise)\nLatent Dim: {latent_dim}', 'Loss'),
        ('test_mse', 'Test MSE', 'MSE'),
        ('test_psnr', 'Test PSNR (dB)', 'PSNR (dB)'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (key, title, ylabel) in zip(axes, panels):
        values = metrics[key]
        # Epochs are reported 1-based elsewhere, so plot against 1..N rather
        # than the default 0-based sample index.
        ax.plot(range(1, len(values) + 1), values)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.grid(True)

    fig.tight_layout()
    # Make the function usable on its own, without a prior training run
    # having created the output directory.
    os.makedirs('results', exist_ok=True)
    fig.savefig(f'results/dae_training_progress_{noise_type}.png', dpi=300)
    plt.close(fig)

def evaluate_compression(latent_dim, image_pixels=28 * 28, pixel_bits=8, latent_bits=32):
    """Compute the nominal compression achieved by the latent representation.

    The original size is ``image_pixels * pixel_bits`` bits and the
    compressed size is ``latent_dim * latent_bits`` bits. The defaults
    describe a 28x28 image with 8 bits per pixel and 32-bit latent values.

    Args:
        latent_dim: Number of latent values per image (must be positive).
        image_pixels: Number of pixels per image (must be positive).
        pixel_bits: Bits used to store one pixel of the original image.
        latent_bits: Bits used to store one latent value.

    Returns:
        Dict with ``'compression_ratio'`` (original bits / compressed bits)
        and ``'bits_per_pixel'`` (compressed bits / pixel count).

    Raises:
        ValueError: If ``latent_dim``, ``image_pixels`` or ``latent_bits``
            is not positive, since both ratios would be undefined.
    """
    if latent_dim <= 0:
        raise ValueError(f"latent_dim must be positive, got {latent_dim}")
    if image_pixels <= 0:
        raise ValueError(f"image_pixels must be positive, got {image_pixels}")
    if latent_bits <= 0:
        raise ValueError(f"latent_bits must be positive, got {latent_bits}")

    original_size = image_pixels * pixel_bits
    compressed_size = latent_dim * latent_bits

    return {
        'compression_ratio': original_size / compressed_size,
        'bits_per_pixel': compressed_size / image_pixels
    }

# =============================================
# 5. Main Execution
# =============================================

if __name__ == "__main__":
    # Configuration
    latent_dims_to_test = [32]  # Typically DAEs use larger latent spaces
    noise_types = ['gaussian', 'salt_pepper']  # Test both noise types
    epochs = 50  # DAEs often need more epochs

    # The visualization loader does not depend on the experiment settings,
    # so it is built once instead of once per configuration. download=True
    # removes the hidden dependency on train_dae having fetched the data.
    # shuffle=True is kept: each configuration visualizes a fresh random batch.
    test_data = datasets.MNIST(root="./data", train=False, download=True,
                               transform=transforms.ToTensor())
    test_loader = DataLoader(test_data, batch_size=128, shuffle=True)

    all_metrics = {}

    for noise_type in noise_types:
        for latent_dim in latent_dims_to_test:
            print(f"\n{'='*50}")
            print(f"Training DAE with {noise_type} noise, Latent Dim: {latent_dim}")
            print(f"{'='*50}")

            # Train and evaluate
            encoder, decoder, _, final_metrics = train_dae(
                epochs=epochs,
                latent_dim=latent_dim,
                noise_type=noise_type
            )
            all_metrics[f'{noise_type}_latent{latent_dim}'] = final_metrics

            # Visualizations
            visualize_denoising(encoder, decoder, test_loader, noise_type)

    # Save comparison (one row per configuration, indexed by its key).
    # The directory is ensured here so the save also works when the
    # configuration lists above are empty and no training run created it.
    os.makedirs('results', exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    df_comparison = pd.DataFrame.from_dict(all_metrics, orient='index')
    df_comparison.to_excel(f'results/dae_comparison_{timestamp}.xlsx', index=True)

    print("\nDAE Training and evaluation complete!")
    print("Results saved in the 'results' directory")

Using device: cuda

Training DAE with gaussian noise, Latent Dim: 32


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 427kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.96MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 12.9MB/s]


Epoch [1/50]  Time: 9.58s  Loss: 0.0547  Test MSE: 0.0353  PSNR: 14.53dB
Epoch [2/50]  Time: 8.67s  Loss: 0.0306  Test MSE: 0.0269  PSNR: 15.70dB
Epoch [3/50]  Time: 8.73s  Loss: 0.0247  Test MSE: 0.0226  PSNR: 16.46dB
Epoch [4/50]  Time: 8.81s  Loss: 0.0215  Test MSE: 0.0202  PSNR: 16.96dB
Epoch [5/50]  Time: 8.98s  Loss: 0.0197  Test MSE: 0.0189  PSNR: 17.24dB
Epoch [6/50]  Time: 8.71s  Loss: 0.0185  Test MSE: 0.0175  PSNR: 17.56dB
Epoch [7/50]  Time: 8.69s  Loss: 0.0176  Test MSE: 0.0168  PSNR: 17.75dB
Epoch [8/50]  Time: 8.43s  Loss: 0.0168  Test MSE: 0.0165  PSNR: 17.83dB
Epoch [9/50]  Time: 8.46s  Loss: 0.0163  Test MSE: 0.0159  PSNR: 17.99dB
Epoch [10/50]  Time: 8.67s  Loss: 0.0157  Test MSE: 0.0154  PSNR: 18.14dB
Epoch [11/50]  Time: 8.74s  Loss: 0.0154  Test MSE: 0.0150  PSNR: 18.24dB
Epoch [12/50]  Time: 8.03s  Loss: 0.0150  Test MSE: 0.0149  PSNR: 18.27dB
Epoch [13/50]  Time: 8.77s  Loss: 0.0147  Test MSE: 0.0144  PSNR: 18.41dB
Epoch [14/50]  Time: 8.67s  Loss: 0.0145  Test 

In [2]:
# Install the SSIM dependency before running the training cell, which imports it.
# %pip (rather than !pip) installs into the environment of the running kernel;
# the version is pinned so a re-run resolves the same package.
%pip install pytorch_msssim==1.0.0

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch_msssim
Successfully installed pytorch_msssim-1.0.0
