In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
from pytorch_msssim import ssim
from sklearn.manifold import TSNE
import os
from datetime import datetime
import time
import numpy as np

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# =============================================
# 1. Model Definitions (VAE and DAE)
# =============================================

class VAE_Encoder(nn.Module):
    def __init__(self, latent_dim=10):
        super(VAE_Encoder, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 400),
            nn.ReLU()
        )
        self.fc_mean = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)

    def forward(self, x):
        h = self.encoder(x)
        return self.fc_mean(h), self.fc_logvar(h)

class VAE_Decoder(nn.Module):
    def __init__(self, latent_dim=10):
        super(VAE_Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 28*28),
            nn.Sigmoid(),
            nn.Unflatten(1, (1, 28, 28))
        )

    def forward(self, z):
        return self.decoder(z)

class DAE_Encoder(nn.Module):
    def __init__(self, latent_dim=32):
        super(DAE_Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim)
        )

    def forward(self, x):
        return self.encoder(x)

class DAE_Decoder(nn.Module):
    def __init__(self, latent_dim=32):
        super(DAE_Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Sigmoid(),
            nn.Unflatten(1, (1, 28, 28))
        )

    def forward(self, z):
        return self.decoder(z)

# =============================================
# 2. Helper Functions
# =============================================

def dice_loss(preds, targets, epsilon=1e-6):
    preds = preds.view(preds.size(0), -1)
    targets = targets.view(targets.size(0), -1)
    intersection = (preds * targets).sum(dim=1)
    union = preds.sum(dim=1) + targets.sum(dim=1)
    dice = (2. * intersection + epsilon) / (union + epsilon)
    return 1 - dice.mean()

def add_gaussian_noise(images, noise_factor=0.5):
    noisy = images + torch.randn_like(images) * noise_factor
    return torch.clamp(noisy, 0., 1.)

def add_salt_pepper_noise(images, prob=0.1):
    mask = torch.rand_like(images) < prob
    salt = torch.rand_like(images) > 0.5
    noisy = images.clone()
    noisy[mask] = salt[mask].float()
    return noisy

def calculate_psnr(mse):
    max_pixel = 1.0
    if mse == 0:
        return 100  # Perfect reconstruction
    return 20 * np.log10(max_pixel / np.sqrt(mse))

def evaluate_compression(latent_dim):
    original_size = 28 * 28 * 8  # MNIST pixels * 8 bits
    compressed_size = latent_dim * 32  # 32-bit floats
    compression_ratio = original_size / compressed_size
    bpp = compressed_size / (28 * 28)
    return {'compression_ratio': compression_ratio, 'bits_per_pixel': bpp}

# =============================================
# 3. Training Functions
# =============================================

def train_vae(epochs=20, batch_size=128, latent_dim=10):
    os.makedirs('results', exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    excel_filename = f'results/vae_metrics_{timestamp}.xlsx'
    writer = pd.ExcelWriter(excel_filename, engine='openpyxl')

    transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_data = datasets.MNIST(root="./data", train=False, transform=transform)
    test_loader = DataLoader(test_data, batch_size=128, shuffle=True)

    encoder = VAE_Encoder(latent_dim).to(device)
    decoder = VAE_Decoder(latent_dim).to(device)
    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

    metrics_history = {
        'total_loss': [], 'recon_loss': [], 'kl_loss': [], 'dice_loss': [],
        'train_mse': [], 'train_psnr': [],  # Added training PSNR
        'test_mse': [], 'test_psnr': [], 'test_ssim': [], 'test_dice': []
    }
    epoch_data = []

    def reparameterize(mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    train_start_time = time.time()
    
    for epoch in range(epochs):
        epoch_start_time = time.time()
        epoch_recon, epoch_kl, epoch_total, epoch_dice = 0, 0, 0, 0
        epoch_train_mse, epoch_train_psnr = 0, 0  # Track training PSNR
        
        encoder.train()
        decoder.train()
        
        for images, _ in train_loader:
            images = images.to(device)
            mu, logvar = encoder(images)
            z = reparameterize(mu, logvar)
            recon = decoder(z)
            
            # Calculate losses
            recon_loss = nn.MSELoss(reduction='sum')(recon, images)
            kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            dice = dice_loss(recon, images)
            total_loss = recon_loss + kl_div + dice
            
            # Calculate training PSNR
            mse = nn.MSELoss(reduction='mean')(recon, images).item()
            psnr = calculate_psnr(mse)
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            # Accumulate metrics
            epoch_recon += recon_loss.item()
            epoch_kl += kl_div.item()
            epoch_dice += dice.item()
            epoch_total += total_loss.item()
            epoch_train_mse += mse * images.size(0)
            epoch_train_psnr += psnr * images.size(0)

        num_samples = len(train_loader.dataset)
        metrics_history['recon_loss'].append(epoch_recon/num_samples)
        metrics_history['kl_loss'].append(epoch_kl/num_samples)
        metrics_history['dice_loss'].append(epoch_dice/num_samples)
        metrics_history['total_loss'].append(epoch_total/num_samples)
        metrics_history['train_mse'].append(epoch_train_mse/num_samples)
        metrics_history['train_psnr'].append(epoch_train_psnr/num_samples)

        # Evaluation phase
        test_metrics = evaluate_reconstruction(encoder, decoder, test_loader, reparameterize)
        for k, v in test_metrics.items():
            metrics_history[f'test_{k}'].append(v)

        epoch_time = time.time() - epoch_start_time
        epoch_info = {
            'epoch': epoch+1,
            'total_loss': metrics_history['total_loss'][-1],
            'recon_loss': metrics_history['recon_loss'][-1],
            'kl_loss': metrics_history['kl_loss'][-1],
            'dice_loss': metrics_history['dice_loss'][-1],
            'train_mse': metrics_history['train_mse'][-1],
            'train_psnr': metrics_history['train_psnr'][-1],
            'test_mse': metrics_history['test_mse'][-1],
            'test_psnr': metrics_history['test_psnr'][-1],
            'test_ssim': metrics_history['test_ssim'][-1],
            'test_dice': metrics_history['test_dice'][-1],
            'epoch_time_sec': epoch_time,
            'latent_dim': latent_dim
        }
        epoch_data.append(epoch_info)
        
        print(f"Epoch [{epoch+1}/{epochs}]  Time: {epoch_time:.2f}s  "
              f"Loss: {epoch_info['total_loss']:.2f} (Recon: {epoch_info['recon_loss']:.2f}, "
              f"KL: {epoch_info['kl_loss']:.2f}, Dice: {epoch_info['dice_loss']:.4f})\n"
              f"Train PSNR: {epoch_info['train_psnr']:.2f}dB  Test PSNR: {epoch_info['test_psnr']:.2f}dB")

    total_train_time = time.time() - train_start_time

    df_epochs = pd.DataFrame(epoch_data)
    df_epochs.to_excel(writer, sheet_name=f'LatentDim_{latent_dim}_epochs', index=False)
    
    final_metrics = {
        'latent_dim': latent_dim,
        'final_recon_loss': metrics_history['recon_loss'][-1],
        'final_kl_loss': metrics_history['kl_loss'][-1],
        'final_dice_loss': metrics_history['dice_loss'][-1],
        'final_train_mse': metrics_history['train_mse'][-1],
        'final_train_psnr': metrics_history['train_psnr'][-1],
        'final_test_mse': metrics_history['test_mse'][-1],
        'final_test_psnr': metrics_history['test_psnr'][-1],
        'final_test_ssim': metrics_history['test_ssim'][-1],
        'final_test_dice': metrics_history['test_dice'][-1],
        'total_train_time_sec': total_train_time
    }
    pd.DataFrame([final_metrics]).to_excel(writer, sheet_name=f'LatentDim_{latent_dim}_final', index=False)
    
    writer.close()
    plot_training_progress(metrics_history, latent_dim, model_type='vae')
    visualize_latent_space(encoder, test_loader, reparameterize, latent_dim, model_type='vae')
    
    torch.save({
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'latent_dim': latent_dim
    }, f'results/vae_latent{latent_dim}_model.pth')
    
    return encoder, decoder, metrics_history, final_metrics

def train_dae(epochs=50, batch_size=128, latent_dim=32, noise_type='gaussian'):
    os.makedirs('results', exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    excel_filename = f'results/dae_metrics_{noise_type}_{timestamp}.xlsx'
    writer = pd.ExcelWriter(excel_filename, engine='openpyxl')

    transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_data = datasets.MNIST(root="./data", train=False, transform=transform)
    test_loader = DataLoader(test_data, batch_size=128, shuffle=True)

    encoder = DAE_Encoder(latent_dim).to(device)
    decoder = DAE_Decoder(latent_dim).to(device)
    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
    criterion = nn.MSELoss()

    metrics_history = {
        'train_loss': [], 'train_psnr': [],  # Added training PSNR
        'test_mse': [], 'test_psnr': [], 'test_ssim': []
    }
    epoch_data = []

    train_start_time = time.time()
    
    for epoch in range(epochs):
        epoch_start_time = time.time()
        epoch_loss, epoch_psnr = 0, 0  # Track training PSNR
        
        encoder.train()
        decoder.train()
        
        for clean_images, _ in train_loader:
            clean_images = clean_images.to(device)
            
            if noise_type == 'gaussian':
                noisy_images = add_gaussian_noise(clean_images)
            elif noise_type == 'salt_pepper':
                noisy_images = add_salt_pepper_noise(clean_images)
            
            z = encoder(noisy_images)
            reconstructions = decoder(z)
            loss = criterion(reconstructions, clean_images)
            
            # Calculate training PSNR
            mse = loss.item()
            psnr = calculate_psnr(mse)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            epoch_psnr += psnr * clean_images.size(0)

        metrics_history['train_loss'].append(epoch_loss/len(train_loader))
        metrics_history['train_psnr'].append(epoch_psnr/len(train_loader.dataset))
        
        test_metrics = evaluate_denoising(encoder, decoder, test_loader, noise_type)
        metrics_history['test_mse'].append(test_metrics['mse'])
        metrics_history['test_psnr'].append(test_metrics['psnr'])
        metrics_history['test_ssim'].append(test_metrics['ssim'])

        epoch_time = time.time() - epoch_start_time
        epoch_info = {
            'epoch': epoch+1,
            'train_loss': metrics_history['train_loss'][-1],
            'train_psnr': metrics_history['train_psnr'][-1],
            'test_mse': metrics_history['test_mse'][-1],
            'test_psnr': metrics_history['test_psnr'][-1],
            'test_ssim': metrics_history['test_ssim'][-1],
            'epoch_time_sec': epoch_time,
            'latent_dim': latent_dim,
            'noise_type': noise_type
        }
        epoch_data.append(epoch_info)
        
        print(f"Epoch [{epoch+1}/{epochs}]  Time: {epoch_time:.2f}s  "
              f"Loss: {epoch_info['train_loss']:.4f}  "
              f"Train PSNR: {epoch_info['train_psnr']:.2f}dB  "
              f"Test PSNR: {epoch_info['test_psnr']:.2f}dB")

    total_train_time = time.time() - train_start_time

    df_epochs = pd.DataFrame(epoch_data)
    df_epochs.to_excel(writer, sheet_name='training_metrics', index=False)
    
    final_metrics = {
        'latent_dim': latent_dim,
        'noise_type': noise_type,
        'final_train_loss': metrics_history['train_loss'][-1],
        'final_train_psnr': metrics_history['train_psnr'][-1],
        'final_test_mse': metrics_history['test_mse'][-1],
        'final_test_psnr': metrics_history['test_psnr'][-1],
        'final_test_ssim': metrics_history['test_ssim'][-1],
        'compression_ratio': evaluate_compression(latent_dim)['compression_ratio'],
        'bits_per_pixel': evaluate_compression(latent_dim)['bits_per_pixel'],
        'total_train_time_sec': total_train_time
    }
    pd.DataFrame([final_metrics]).to_excel(writer, sheet_name='final_metrics', index=False)
    
    writer.close()
    plot_training_progress(metrics_history, latent_dim, model_type='dae')
    visualize_latent_space(encoder, test_loader, lambda x: x, latent_dim, model_type='dae')
    
    torch.save({
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'latent_dim': latent_dim,
        'noise_type': noise_type
    }, f'results/dae_{noise_type}_latent{latent_dim}_model.pth')
    
    return encoder, decoder, metrics_history, final_metrics

# =============================================
# 4. Evaluation and Visualization Functions
# =============================================

def evaluate_reconstruction(encoder, decoder, test_loader, reparameterize_fn):
    encoder.eval()
    decoder.eval()
    
    total_mse, total_psnr, total_ssim, total_dice = 0.0, 0.0, 0.0, 0.0
    total_samples = 0
    
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            mu, logvar = encoder(images)
            z = reparameterize_fn(mu, logvar)
            reconstructions = decoder(z)
            
            mse = nn.MSELoss()(reconstructions, images)
            psnr = calculate_psnr(mse.item())
            ssim_val = ssim(reconstructions, images, data_range=1.0, size_average=False)
            dice = dice_loss(reconstructions, images)
            
            total_mse += mse.item() * images.size(0)
            total_psnr += psnr * images.size(0)
            total_ssim += ssim_val.sum().item()
            total_dice += dice.item() * images.size(0)
            total_samples += images.size(0)
    
    return {
        'mse': total_mse / total_samples,
        'psnr': total_psnr / total_samples,
        'ssim': total_ssim / total_samples,
        'dice': total_dice / total_samples
    }

def evaluate_denoising(encoder, decoder, test_loader, noise_type):
    encoder.eval()
    decoder.eval()
    
    total_mse, total_psnr, total_ssim = 0.0, 0.0, 0.0
    total_samples = 0
    
    with torch.no_grad():
        for clean_images, _ in test_loader:
            clean_images = clean_images.to(device)
            
            if noise_type == 'gaussian':
                noisy_images = add_gaussian_noise(clean_images)
            elif noise_type == 'salt_pepper':
                noisy_images = add_salt_pepper_noise(clean_images)
            
            z = encoder(noisy_images)
            reconstructions = decoder(z)
            
            mse = nn.MSELoss()(reconstructions, clean_images)
            psnr = calculate_psnr(mse.item())
            ssim_val = ssim(reconstructions, clean_images, data_range=1.0, size_average=False)
            
            total_mse += mse.item() * clean_images.size(0)
            total_psnr += psnr * clean_images.size(0)
            total_ssim += ssim_val.sum().item()
            total_samples += clean_images.size(0)
    
    return {
        'mse': total_mse / total_samples,
        'psnr': total_psnr / total_samples,
        'ssim': total_ssim / total_samples
    }

def visualize_latent_space(encoder, test_loader, reparameterize_fn, latent_dim, model_type='vae'):
    encoder.eval()
    latents = []
    labels = []
    
    with torch.no_grad():
        for images, label in test_loader:
            images = images.to(device)
            if model_type == 'vae':
                mu, logvar = encoder(images)
                z = reparameterize_fn(mu, logvar)
            else:  # DAE
                z = encoder(images)
            latents.append(z.cpu())
            labels.append(label)
    
    latents = torch.cat(latents).numpy()
    labels = torch.cat(labels).numpy()
    
    plt.figure(figsize=(15, 5))
    
    # t-SNE visualization
    if latent_dim >= 2:
        tsne = TSNE(n_components=2)
        latents_2d = tsne.fit_transform(latents)
        plt.subplot(1, 3, 1)
        plt.scatter(latents_2d[:, 0], latents_2d[:, 1], c=labels, cmap='tab10', alpha=0.6)
        plt.colorbar(ticks=range(10))
        plt.title('t-SNE Projection')
    
    # Dimension distributions
    plt.subplot(1, 3, 2)
    for dim in range(min(3, latent_dim)):
        plt.hist(latents[:, dim], bins=50, alpha=0.5, label=f'Dim {dim+1}')
    plt.title('Latent Dimension Distributions')
    plt.legend()
    
    # Statistics
    plt.subplot(1, 3, 3)
    dim_means = np.mean(latents, axis=0)
    dim_vars = np.var(latents, axis=0)
    plt.bar(range(latent_dim), dim_means, alpha=0.5, label='Mean')
    plt.bar(range(latent_dim), dim_vars, alpha=0.5, label='Variance')
    plt.title('Latent Dimension Statistics')
    plt.legend()
    
    plt.suptitle(f'{model_type.upper()} Latent Space (Dim={latent_dim})')
    plt.tight_layout()
    plt.savefig(f'results/{model_type}_latent_space_latent{latent_dim}.png', dpi=300)
    plt.close()

def plot_training_progress(metrics, latent_dim, model_type='vae'):
    plt.figure(figsize=(15, 5))
    
    if model_type == 'vae':
        plt.subplot(1, 4, 1)
        plt.plot(metrics['total_loss'], label='Total Loss')
        plt.plot(metrics['recon_loss'], label='Recon Loss')
        plt.plot(metrics['kl_loss'], label='KL Loss')
        plt.plot(metrics['dice_loss'], label='Dice Loss')
        plt.title('Training Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        
        plt.subplot(1, 4, 2)
        plt.plot(metrics['train_psnr'], label='Train PSNR')
        plt.plot(metrics['test_psnr'], label='Test PSNR')
        plt.title('PSNR (dB)')
        plt.xlabel('Epoch')
        plt.legend()
        plt.grid(True)
        
        plt.subplot(1, 4, 3)
        plt.plot(metrics['test_mse'], label='Test MSE')
        plt.title('Test MSE')
        plt.xlabel('Epoch')
        plt.grid(True)
        
        plt.subplot(1, 4, 4)
        plt.plot(metrics['test_ssim'], label='Test SSIM')
        plt.title('Test SSIM')
        plt.xlabel('Epoch')
        plt.grid(True)
    else:  # DAE
        plt.subplot(1, 4, 1)
        plt.plot(metrics['train_loss'], label='Training Loss')
        plt.title('Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        
        plt.subplot(1, 4, 2)
        plt.plot(metrics['train_psnr'], label='Train PSNR')
        plt.plot(metrics['test_psnr'], label='Test PSNR')
        plt.title('PSNR (dB)')
        plt.xlabel('Epoch')
        plt.legend()
        plt.grid(True)
        
        plt.subplot(1, 4, 3)
        plt.plot(metrics['test_mse'], label='Test MSE')
        plt.title('Test MSE')
        plt.xlabel('Epoch')
        plt.grid(True)
        
        plt.subplot(1, 4, 4)
        plt.plot(metrics['test_ssim'], label='Test SSIM')
        plt.title('Test SSIM')
        plt.xlabel('Epoch')
        plt.grid(True)
    
    plt.suptitle(f'{model_type.upper()} Training Progress (Latent Dim={latent_dim})')
    plt.tight_layout()
    plt.savefig(f'results/{model_type}_training_progress_latent{latent_dim}.png', dpi=300)
    plt.close()

def visualize_reconstructions(encoder, decoder, test_loader, reparameterize_fn, num_samples=5, model_type='vae'):
    encoder.eval()
    decoder.eval()
    
    images, _ = next(iter(test_loader))
    images = images[:num_samples].to(device)
    
    with torch.no_grad():
        if model_type == 'vae':
            mu, logvar = encoder(images)
            z = reparameterize_fn(mu, logvar)
        else:  # DAE
            z = encoder(images)
        reconstructions = decoder(z)
    
    plt.figure(figsize=(10, 4))
    for i in range(num_samples):
        plt.subplot(2, num_samples, i+1)
        plt.imshow(images[i].cpu().squeeze(), cmap='gray')
        plt.title("Original" if i == 0 else "")
        plt.axis('off')
        plt.subplot(2, num_samples, num_samples+i+1)
        plt.imshow(reconstructions[i].cpu().squeeze(), cmap='gray')
        plt.title("Reconstructed" if i == 0 else "")
        plt.axis('off')
    
    plt.suptitle(f'{model_type.upper()} Reconstructions')
    plt.tight_layout()
    plt.savefig(f'results/{model_type}_reconstructions_latent{latent_dim}.png', dpi=300)
    plt.close()

# =============================================
# 5. Main Execution
# =============================================

if __name__ == "__main__":
    # VAE Configuration
    vae_latent_dims = [5, 10, 20]
    vae_metrics = {}
    
    # DAE Configuration
    dae_latent_dims = [32]
    noise_types = ['gaussian', 'salt_pepper']
    dae_metrics = {}
    
    # Train VAEs
    for latent_dim in vae_latent_dims:
        print(f"\n{'='*50}")
        print(f"Training VAE with Latent Dimension: {latent_dim}")
        print(f"{'='*50}")
        
        encoder, decoder, _, final_metrics = train_vae(latent_dim=latent_dim)
        vae_metrics[f'LatentDim={latent_dim}'] = final_metrics
        
        test_data = datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor())
        test_loader = DataLoader(test_data, batch_size=128, shuffle=True)
        visualize_reconstructions(encoder, decoder, test_loader, 
                                lambda mu, logvar: mu + torch.exp(0.5*logvar) * torch.randn_like(logvar),
                                model_type='vae')
    
    # Train DAEs
    for noise_type in noise_types:
        for latent_dim in dae_latent_dims:
            print(f"\n{'='*50}")
            print(f"Training DAE with {noise_type} noise, Latent Dim: {latent_dim}")
            print(f"{'='*50}")
            
            encoder, decoder, _, final_metrics = train_dae(latent_dim=latent_dim, noise_type=noise_type)
            dae_metrics[f'{noise_type}_latent{latent_dim}'] = final_metrics
            
            test_data = datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor())
            test_loader = DataLoader(test_data, batch_size=128, shuffle=True)
            visualize_reconstructions(encoder, decoder, test_loader, 
                                    lambda x: x,  # Identity function for DAE
                                    model_type='dae')
    
    # Save comparison results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    with pd.ExcelWriter(f'results/model_comparison_{timestamp}.xlsx', engine='openpyxl') as writer:
        pd.DataFrame.from_dict(vae_metrics, orient='index').to_excel(writer, sheet_name='VAE_Results')
        pd.DataFrame.from_dict(dae_metrics, orient='index').to_excel(writer, sheet_name='DAE_Results')
    
    print("\nTraining complete! All results saved in the 'results' directory")

Using device: cuda

Training VAE with Latent Dimension: 5
Epoch [1/20]  Time: 9.17s  Loss: 44.29 (Recon: 38.93, KL: 5.35, Dice: 0.0042)
Train PSNR: 13.24dB  Test PSNR: 14.23dB
Epoch [2/20]  Time: 8.65s  Loss: 35.72 (Recon: 28.90, KL: 6.81, Dice: 0.0035)
Train PSNR: 14.34dB  Test PSNR: 14.55dB
Epoch [3/20]  Time: 9.70s  Loss: 34.49 (Recon: 27.32, KL: 7.16, Dice: 0.0034)
Train PSNR: 14.58dB  Test PSNR: 14.69dB
Epoch [4/20]  Time: 9.32s  Loss: 33.80 (Recon: 26.44, KL: 7.35, Dice: 0.0033)
Train PSNR: 14.72dB  Test PSNR: 14.77dB
Epoch [5/20]  Time: 9.22s  Loss: 33.27 (Recon: 25.80, KL: 7.47, Dice: 0.0032)
Train PSNR: 14.83dB  Test PSNR: 14.89dB
Epoch [6/20]  Time: 9.31s  Loss: 32.91 (Recon: 25.30, KL: 7.61, Dice: 0.0032)
Train PSNR: 14.92dB  Test PSNR: 15.02dB
Epoch [7/20]  Time: 9.47s  Loss: 32.60 (Recon: 24.92, KL: 7.68, Dice: 0.0032)
Train PSNR: 14.98dB  Test PSNR: 15.01dB
Epoch [8/20]  Time: 9.19s  Loss: 32.33 (Recon: 24.58, KL: 7.75, Dice: 0.0031)
Train PSNR: 15.04dB  Test PSNR: 15.13d