# Adversarial Attacks on Variational Autoencoders

This notebook demonstrates how to engineer adversarial attacks against a VAE trained on the MNIST dataset, using a LeNet-style encoder/decoder and a 2D latent space.

## Key Concepts:
- **FGSM (Fast Gradient Sign Method)**: Single-step attack using gradient sign
- **PGD (Projected Gradient Descent)**: Multi-step iterative attack
- **Latent Space Attack**: Attack in the encoded latent representation
- **VAE Vulnerabilities**: How reconstruction and regularization losses affect robustness
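
The first two attacks in the list above can be sketched as follows. This is a minimal illustration, assuming an untargeted attack that maximizes the MSE reconstruction error under an L-infinity budget. `TinyAE` is a hypothetical stand-in autoencoder (not the VAE defined later in this notebook), and `fgsm_attack`, `pgd_attack`, `epsilon`, `alpha`, and `steps` are illustrative names and values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAE(nn.Module):
    """Hypothetical stand-in autoencoder; not the notebook's VAE."""
    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(28 * 28, 2)
        self.dec = nn.Linear(2, 28 * 28)

    def forward(self, x):
        z = self.enc(x.flatten(1))
        return torch.sigmoid(self.dec(z)).view_as(x)


def fgsm_attack(model, x, epsilon=0.1):
    """Single step along the sign of the reconstruction-error gradient."""
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.mse_loss(model(x_adv), x)
    grad, = torch.autograd.grad(loss, x_adv)
    return (x_adv + epsilon * grad.sign()).clamp(0, 1).detach()


def pgd_attack(model, x, epsilon=0.1, alpha=0.02, steps=10):
    """Repeated small sign-gradient steps, projected back into the L-inf ball."""
    x_adv = x.clone().detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.mse_loss(model(x_adv), x)
        grad, = torch.autograd.grad(loss, x_adv)
        x_adv = x_adv.detach() + alpha * grad.sign()
        # Project onto the epsilon-ball around x, then onto the valid pixel range
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
        x_adv = x_adv.clamp(0, 1)
    return x_adv.detach()


torch.manual_seed(0)
ae = TinyAE()
x = torch.rand(4, 1, 28, 28)
x_fgsm = fgsm_attack(ae, x)
x_pgd = pgd_attack(ae, x)
print("max |delta| FGSM:", (x_fgsm - x).abs().max().item())
print("max |delta| PGD: ", (x_pgd - x).abs().max().item())
```

Both functions keep the perturbation within `epsilon` of the clean input and within the valid pixel range [0, 1]. A latent-space attack would replace the reconstruction loss with a distance between latent codes.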

In [None]:
# Install required packages
!pip install torch torchvision numpy matplotlib tqdm psutil

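# Install GPU monitoring tools (optional; falls back to nvidia-smi if unavailable).
# A failed `!pip` command does not raise a Python exception, so verify the import instead.
!pip install nvidia-ml-py
try:
    import pynvml  # module provided by the nvidia-ml-py package
    print("✓ nvidia-ml-py installed for efficient GPU monitoring")
except ImportError:
    print("⚠ nvidia-ml-py not available, will use nvidia-smi fallback")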

# Check if nvidia-smi is available
import subprocess
try:
    result = subprocess.run(['nvidia-smi', '--version'], capture_output=True, text=True)
    if result.returncode == 0:
        print("✓ nvidia-smi available for GPU monitoring")
    else:
        print("⚠ nvidia-smi not available - GPU monitoring will show zeros")
except FileNotFoundError:
    print("⚠ nvidia-smi not found - GPU monitoring will show zeros")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Check if GPU is available and print basic info
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("No GPU available - using CPU")

## 1. Define VAE Architecture with LeNet-style Encoder/Decoder

In [None]:
class LeNetEncoder(nn.Module):
    """LeNet-style encoder for VAE"""
    def __init__(self, latent_dim=2):
        super(LeNetEncoder, self).__init__()
        self.latent_dim = latent_dim
        
        # Convolutional layers (LeNet-style)
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)  # 28x28 -> 28x28
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 14x14 -> 10x10
        self.pool2 = nn.MaxPool2d(2, 2)  # 10x10 -> 5x5
        
        # Fully connected layers
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        
        # Output layers for mean and log variance
        self.fc_mu = nn.Linear(84, latent_dim)
        self.fc_logvar = nn.Linear(84, latent_dim)
        
    def forward(self, x):
        # Convolutional layers
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        
        # Flatten
        x = x.view(-1, 16 * 5 * 5)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        # Output mean and log variance
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        
        return mu, logvar

In [None]:
class LeNetDecoder(nn.Module):
    """LeNet-style decoder for VAE"""
    def __init__(self, latent_dim=2):
        super(LeNetDecoder, self).__init__()
        self.latent_dim = latent_dim
        
        # Fully connected layers
        self.fc1 = nn.Linear(latent_dim, 84)
        self.fc2 = nn.Linear(84, 120)
        self.fc3 = nn.Linear(120, 16 * 5 * 5)
        
        # Transposed convolutional layers (reverse of encoder)
        self.deconv1 = nn.ConvTranspose2d(16, 6, kernel_size=5, stride=2, padding=2, output_padding=1)  # 5x5 -> 10x10
        self.deconv2 = nn.ConvTranspose2d(6, 1, kernel_size=5, stride=2, padding=2, output_padding=1)   # 10x10 -> 20x20
        # A 9x9 transposed convolution (stride 1, no padding) grows 20x20 to 28x28
        self.final_conv = nn.ConvTranspose2d(1, 1, kernel_size=9, stride=1, padding=0)  # 20x20 -> 28x28
        
    def forward(self, z):
        # Fully connected layers
        x = F.relu(self.fc1(z))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        
        # Reshape to feature maps
        x = x.view(-1, 16, 5, 5)
        
        # Transposed convolutional layers
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = torch.sigmoid(self.final_conv(x))
        
        return x
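
The shape comments in the decoder can be checked with the output-size formula for transposed convolutions. The sketch below uses a hypothetical helper, `conv_transpose_out`, that implements the formula documented for `nn.ConvTranspose2d` along one spatial dimension, and applies it to the three decoder layers.

```python
def conv_transpose_out(size, kernel_size, stride=1, padding=0,
                       output_padding=0, dilation=1):
    """Output length of ConvTranspose2d along one spatial dimension."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


sizes = [5]
# deconv1: kernel 5, stride 2, padding 2, output_padding 1
sizes.append(conv_transpose_out(sizes[-1], 5, stride=2, padding=2, output_padding=1))
# deconv2: same hyperparameters
sizes.append(conv_transpose_out(sizes[-1], 5, stride=2, padding=2, output_padding=1))
# final_conv: kernel 9, stride 1, no padding
sizes.append(conv_transpose_out(sizes[-1], 9))
print(sizes)  # [5, 10, 20, 28]
```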

In [None]:
class VAE(nn.Module):
    """Variational Autoencoder with LeNet-style architecture"""
    def __init__(self, latent_dim=2):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = LeNetEncoder(latent_dim)
        self.decoder = LeNetDecoder(latent_dim)
        
    def reparameterize(self, mu, logvar):
        """Reparameterization trick"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar
    
    def encode(self, x):
        """Encode input to latent space"""
        mu, logvar = self.encoder(x)
        return self.reparameterize(mu, logvar)
    
    def decode(self, z):
        """Decode from latent space"""
        return self.decoder(z)

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """
    VAE loss with batch-size-independent scaling.

    - The reconstruction term is BCE averaged over the batch and all pixels.
    - The KL term is summed over latent dims and averaged over the batch.

    A sum-reduced loss grows linearly with batch size; averaging keeps the
    magnitude stable. Note that averaging BCE over the 784 pixels while summing
    KL per sample weights the KL term 784x more heavily than the standard
    per-sample ELBO at the same beta, so a much smaller beta may be needed.
    """
    # Reconstruction loss: mean over batch and pixels
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='mean')
    
    # KL divergence: sum over latent dims, mean over batch
    kl_loss = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    
    return recon_loss + beta * kl_loss

# 🚨 Why the reduction mode matters for loss scaling
print("🚨 VAE LOSS SCALING: SUM VS. MEAN REDUCTION")
print("="*60)
print("❌ SUM REDUCTION:")
print("   • Reconstruction loss: reduction='sum' → scales with batch_size * pixels")
print("   • KL loss: torch.sum(...) → scales with batch_size * latent_dims")
print("   • With batch_size=2048 the raw loss magnitude becomes very large")
print("   • Reported losses are not comparable across batch sizes")
print("")
print("✅ MEAN REDUCTION (used in vae_loss):")
print("   • Reconstruction loss: reduction='mean' → averaged over batch and pixels")
print("   • KL loss: torch.mean(torch.sum(..., dim=1)) → averaged over batch") 
print("   • Loss magnitude stays the same regardless of batch size")
print("   • Note: the KL term is weighted more heavily relative to a per-sample ELBO")
print("")
print("📊 IMPACT DEMONSTRATION:")

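# Show the scaling difference on a dummy batch with non-zero loss terms
# (a perfect reconstruction with mu = 0 would give 0 / 0 = nan for the ratio)
dummy_batch_size = [128, 512, 1024, 2048]
dummy_recon = torch.full((1, 1, 28, 28), 0.5)  # Uninformative reconstruction (BCE = ln 2 per pixel)
dummy_x = torch.ones(1, 1, 28, 28)
dummy_mu = torch.ones(1, 2)  # Non-zero mean so the KL term is non-zero
dummy_logvar = torch.zeros(1, 2)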

print(f"{'Batch Size':<12} {'Old Loss':<15} {'New Loss':<15} {'Ratio':<10}")
print("-" * 55)

for bs in dummy_batch_size:
    # Simulate larger batch
    recon_batch = dummy_recon.repeat(bs, 1, 1, 1)
    x_batch = dummy_x.repeat(bs, 1, 1, 1)
    mu_batch = dummy_mu.repeat(bs, 1)
    logvar_batch = dummy_logvar.repeat(bs, 1)
    
    # Old loss (broken)
    old_recon = F.binary_cross_entropy(recon_batch, x_batch, reduction='sum')
    old_kl = -0.5 * torch.sum(1 + logvar_batch - mu_batch.pow(2) - logvar_batch.exp())
    old_total = old_recon + old_kl
    
    # New loss (fixed) 
    new_recon = F.binary_cross_entropy(recon_batch, x_batch, reduction='mean')
    new_kl = -0.5 * torch.mean(torch.sum(1 + logvar_batch - mu_batch.pow(2) - logvar_batch.exp(), dim=1))
    new_total = new_recon + new_kl
    
    ratio = old_total / new_total
    print(f"{bs:<12} {old_total.item():<15.3f} {new_total.item():<15.3f} {ratio.item():<10.1f}x")

print("\n💡 The sum-reduced loss grows linearly with batch size,")
print("   while the mean-reduced loss stays constant.")
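
The closed-form KL expression used in the loss can be cross-checked against `torch.distributions`. The sketch below is a sanity check on random inputs (the tensor shapes and seed are arbitrary choices); it also shows that the sum-reduced KL equals the batch size times the mean-reduced KL, which is the linear scaling discussed above.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 2)
logvar = torch.randn(8, 2)

# Closed form used in the loss: one KL value per sample (summed over latent dims)
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# Reference: KL( N(mu, sigma^2) || N(0, 1) ) from torch.distributions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum(dim=1)

matches = torch.allclose(kl_closed, kl_ref, atol=1e-5)
print("closed form matches reference:", matches)

# Sum reduction is batch_size times mean reduction
batch_size = mu.size(0)
scales_linearly = torch.allclose(kl_closed.sum(), batch_size * kl_closed.mean(), atol=1e-4)
print("sum == batch_size * mean:", scales_linearly)
```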

## 2. Load MNIST Dataset

In [None]:
# Load MNIST dataset with optimizations for A100
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Use large batches: an A100 (40 GB or 80 GB of memory) can hold them easily
batch_size_train = 2048  # Much larger for A100
batch_size_test = 1024

# Add num_workers for faster data loading and pin_memory for GPU transfer
num_workers = 4  # Adjust based on your CPU cores

train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size_train, 
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,  # Faster GPU transfer
    persistent_workers=True  # Keep workers alive between epochs
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=batch_size_test, 
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True
)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Training batch size: {batch_size_train} (optimized for A100)")
print(f"Test batch size: {batch_size_test}")
print(f"Batches per epoch: {len(train_loader)}")

# Visualize some samples
def show_samples(loader, num_samples=8):
    data_iter = iter(loader)
    images, labels = next(data_iter)
    
    fig, axes = plt.subplots(1, num_samples, figsize=(12, 2))
    for i in range(num_samples):
        axes[i].imshow(images[i].squeeze(), cmap='gray')
        axes[i].set_title(f'Label: {labels[i]}')
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()

show_samples(train_loader)

## 3. Train the VAE Model

In [None]:
def train_vae(model, train_loader, epochs=10, lr=1e-3, beta=1.0, use_amp=True):
    """
    Train the VAE model with A100 optimizations
    
    Args:
        use_amp: Use Automatic Mixed Precision for A100 efficiency
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Initialize mixed precision training for A100
    scaler = torch.cuda.amp.GradScaler() if use_amp and torch.cuda.is_available() else None
    
    model.train()
    train_losses = []
    
    # Compile model for A100 (PyTorch 2.0+)
    try:
        model = torch.compile(model)
        print("✅ Model compiled for A100 optimization")
    except Exception:
        print("⚠️  Model compilation not available (PyTorch < 2.0)")
    
    print(f"Training VAE with {'mixed precision' if use_amp else 'full precision'}...")
    print(f"Batch size: {train_loader.batch_size} (A100 optimized)")
    
    for epoch in tqdm(range(epochs)):
        epoch_loss = 0
        batch_count = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device, non_blocking=True)  # Async GPU transfer
            optimizer.zero_grad()
            
            if use_amp and scaler is not None:
                # Mixed precision forward pass
                with torch.cuda.amp.autocast():
                    recon_batch, mu, logvar = model(data)
                
                # binary_cross_entropy is unsafe under autocast, so compute the
                # loss in float32 outside the autocast region
                loss = vae_loss(recon_batch.float(), data, mu.float(), logvar.float(), beta)
                
                # Mixed precision backward pass
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                # Standard precision training
                recon_batch, mu, logvar = model(data)
                loss = vae_loss(recon_batch, data, mu, logvar, beta)
                loss.backward()
                optimizer.step()
            
            epoch_loss += loss.item()
            batch_count += 1
        
        # vae_loss is already a per-batch mean, so average over batches
        avg_loss = epoch_loss / batch_count
        train_losses.append(avg_loss)
        
        if epoch % 2 == 0:
            print(f'Epoch {epoch}, Average Loss: {avg_loss:.4f}, Batches: {batch_count}')
    
    return train_losses

# Initialize the model (a small LeNet-style VAE with a 2D latent space)
model = VAE(latent_dim=2)

# Check model size
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"📊 Model Statistics:")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1e6:.2f} MB (float32)")

# Check GPU memory before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"A100 Total Memory: {total_memory:.1f} GB")
    print(f"Current usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

In [None]:
# 🔍 VAE Training Diagnostics - Let's identify the reconstruction problem

def diagnose_vae_issues(model, train_loader, test_loader):
    """Comprehensive VAE diagnostics to identify training issues"""
    
    print("🔍 Running VAE Diagnostics...")
    print("="*60)
    
    # 1. Check model architecture
    total_params = sum(p.numel() for p in model.parameters())
    encoder_params = sum(p.numel() for p in model.encoder.parameters())
    decoder_params = sum(p.numel() for p in model.decoder.parameters())
    
    print(f"📊 Model Architecture:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Encoder parameters: {encoder_params:,}")
    print(f"  Decoder parameters: {decoder_params:,}")
    print(f"  Latent dimensions: {model.latent_dim}")
    
    # 2. Test forward pass
    model.to(device)  # inputs below are moved to `device`, so the model must be too
    model.eval()
    with torch.no_grad():
        test_batch = next(iter(test_loader))[0][:8].to(device)
        
        # Check input range
        print(f"\n📈 Data Diagnostics:")
        print(f"  Input shape: {test_batch.shape}")
        print(f"  Input range: [{test_batch.min():.4f}, {test_batch.max():.4f}]")
        print(f"  Input mean: {test_batch.mean():.4f}")
        print(f"  Input std: {test_batch.std():.4f}")
        
        # Forward pass
        recon, mu, logvar = model(test_batch)
        
        # Check outputs
        print(f"\n🧠 Model Outputs:")
        print(f"  Reconstruction shape: {recon.shape}")
        print(f"  Reconstruction range: [{recon.min():.4f}, {recon.max():.4f}]")
        print(f"  Reconstruction mean: {recon.mean():.4f}")
        print(f"  Latent mu range: [{mu.min():.4f}, {mu.max():.4f}]")
        print(f"  Latent mu mean: {mu.mean():.4f}")
        print(f"  Latent logvar range: [{logvar.min():.4f}, {logvar.max():.4f}]")
        print(f"  Latent logvar mean: {logvar.mean():.4f}")
        
        # Check for common issues
        issues_found = []
        
        # Issue 1: Near-constant reconstruction (MNIST backgrounds are 0, so
        # values near 0 or 1 alone are expected and not a problem)
        if recon.max() - recon.min() < 0.05:
            issues_found.append("❌ Near-constant reconstruction (output barely varies)")
        
        # Issue 2: Latent collapse
        if mu.std() < 0.1:
            issues_found.append("❌ Potential latent collapse (mu std too low)")
        
        # Issue 3: KL divergence issues
        if logvar.mean() < -10:
            issues_found.append("❌ KL collapse (logvar too negative)")
        elif logvar.mean() > 5:
            issues_found.append("❌ KL explosion (logvar too positive)")
        
        # Issue 4: Reconstruction loss
        recon_loss = F.binary_cross_entropy(recon, test_batch, reduction='mean')
        if recon_loss > 0.5:
            issues_found.append(f"❌ High reconstruction loss ({recon_loss:.4f})")
        
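        # Issue 5: Dead neurons
        # The model applies F.relu functionally, so there are no nn.ReLU modules
        # to hook. Instead, hook the layers whose outputs feed a ReLU and measure
        # the fraction of non-positive pre-activations.
        relu_activations = []
        non_relu_layers = ('fc_mu', 'fc_logvar', 'final_conv')
        
        def hook_fn(module, input, output):
            relu_activations.append((output <= 0).float().mean().item())
        
        hooks = []
        for name, module in model.named_modules():
            if (isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear))
                    and not name.endswith(non_relu_layers)):
                hooks.append(module.register_forward_hook(hook_fn))
        
        _ = model(test_batch)
        
        for hook in hooks:
            hook.remove()
        
        if relu_activations and max(relu_activations) > 0.9:
            issues_found.append("❌ Dead ReLU neurons detected")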
        
        print(f"\n🚨 Issues Detected:")
        if issues_found:
            for issue in issues_found:
                print(f"  {issue}")
        else:
            print("  ✅ No obvious issues detected")
        
        # 3. Visual inspection
        print(f"\n🖼️  Visual Diagnostics:")
        fig, axes = plt.subplots(2, 8, figsize=(16, 4))
        
        for i in range(8):
            # Original
            axes[0, i].imshow(test_batch[i].cpu().squeeze(), cmap='gray')
            axes[0, i].set_title(f'Original {i}')
            axes[0, i].axis('off')
            
            # Reconstruction
            axes[1, i].imshow(recon[i].cpu().squeeze(), cmap='gray')
            axes[1, i].set_title(f'Recon {i}')
            axes[1, i].axis('off')
        
        plt.suptitle('Original vs Reconstruction Comparison')
        plt.tight_layout()
        plt.show()
        
        # Calculate reconstruction quality
        mse_loss = F.mse_loss(recon, test_batch, reduction='mean')
        print(f"  MSE Loss: {mse_loss:.6f}")
        print(f"  BCE Loss: {recon_loss:.6f}")
        
        return issues_found, recon_loss.item(), mse_loss.item()

# Run diagnostics
print("🔧 Diagnosing VAE training issues...")
issues, bce_loss, mse_loss = diagnose_vae_issues(model, train_loader, test_loader)

In [None]:
# 🛠️ Fixed VAE Implementation - Addressing Common Issues

def improved_vae_loss(recon_x, x, mu, logvar, beta=1.0, recon_loss_type='mse'):
    """
    Improved VAE loss function with better numerical stability
    
    Args:
        recon_loss_type: 'mse' or 'bce' - MSE often works better than BCE
    """
    # Reconstruction loss - try MSE instead of BCE
    if recon_loss_type == 'mse':
        recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    else:
        # Add epsilon for numerical stability in BCE
        epsilon = 1e-8
        recon_x_stable = torch.clamp(recon_x, epsilon, 1 - epsilon)
        recon_loss = F.binary_cross_entropy(recon_x_stable, x, reduction='sum')
    
    # KL divergence loss (summed over batch and latent dims)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_loss + beta * kl_loss, recon_loss, kl_loss

def train_vae_improved(model, train_loader, epochs=10, lr=1e-3, beta=1.0, 
                       use_amp=False, recon_loss_type='mse', warmup_epochs=5):
    """
    Improved VAE training with better practices
    
    Args:
        use_amp: Disabled by default as it can cause issues with VAE training
        recon_loss_type: 'mse' often works better than 'bce' for MNIST
        warmup_epochs: Gradually increase beta for KL annealing
    """
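    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    
    # Gradient scaler for the optional mixed-precision path (CUDA only)
    use_amp = use_amp and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler() if use_amp else None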
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
    
    model.train()
    train_losses = []
    recon_losses = []
    kl_losses = []
    
    print(f"Training improved VAE...")
    print(f"Reconstruction loss: {recon_loss_type}")
    print(f"KL warmup epochs: {warmup_epochs}")
    print(f"Mixed precision: {'Enabled' if use_amp else 'Disabled (recommended for VAE)'}")
    
    for epoch in tqdm(range(epochs)):
        epoch_loss = 0
        epoch_recon_loss = 0
        epoch_kl_loss = 0
        batch_count = 0
        
        # KL annealing - gradually increase beta
        if epoch < warmup_epochs:
            current_beta = beta * (epoch / warmup_epochs)
        else:
            current_beta = beta
        
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device, non_blocking=True)
            optimizer.zero_grad()
            
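            if use_amp:
                with torch.cuda.amp.autocast():
                    recon_batch, mu, logvar = model(data)
                
                # Compute the loss in float32 outside autocast
                # (binary_cross_entropy is unsafe under autocast)
                loss, recon_loss, kl_loss = improved_vae_loss(
                    recon_batch.float(), data, mu.float(), logvar.float(),
                    current_beta, recon_loss_type)
                
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the true gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()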
            else:
                # Standard precision (recommended for VAE)
                recon_batch, mu, logvar = model(data)
                loss, recon_loss, kl_loss = improved_vae_loss(
                    recon_batch, data, mu, logvar, current_beta, recon_loss_type)
                
                loss.backward()
                
                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
            
            epoch_loss += loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_kl_loss += kl_loss.item()
            batch_count += 1
        
        avg_loss = epoch_loss / len(train_loader.dataset)
        avg_recon_loss = epoch_recon_loss / len(train_loader.dataset)
        avg_kl_loss = epoch_kl_loss / len(train_loader.dataset)
        
        train_losses.append(avg_loss)
        recon_losses.append(avg_recon_loss)
        kl_losses.append(avg_kl_loss)
        
        # Learning rate scheduling
        scheduler.step(avg_loss)
        
        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f'Epoch {epoch:3d}, Total Loss: {avg_loss:.4f}, '
                  f'Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, '
                  f'Beta: {current_beta:.3f}, LR: {optimizer.param_groups[0]["lr"]:.6f}')
    
    return train_losses, recon_losses, kl_losses

# Create a fresh model instance
print("🔄 Creating fresh VAE model...")
model_improved = VAE(latent_dim=2)

# Check the model
total_params = sum(p.numel() for p in model_improved.parameters())
print(f"Model parameters: {total_params:,}")

# Run quick diagnostic on untrained model
print(f"\n📋 Testing untrained model...")
model_improved.to(device)  # keep the model on the same device as the inputs
with torch.no_grad():
    test_batch = next(iter(test_loader))[0][:4].to(device)
    recon, mu, logvar = model_improved(test_batch)
    print(f"Untrained reconstruction range: [{recon.min():.4f}, {recon.max():.4f}]")
    print(f"Untrained latent mu range: [{mu.min():.4f}, {mu.max():.4f}]")
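
The KL annealing used in `train_vae_improved` is a linear warmup of beta. The sketch below isolates that schedule in a hypothetical helper, `kl_beta`, using the same formula as the training loop; the guard against `warmup_epochs=0` is an addition for illustration.

```python
def kl_beta(epoch, beta=1.0, warmup_epochs=5):
    """Linear KL warmup: beta grows from 0 to its target over warmup_epochs."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        return beta * (epoch / warmup_epochs)
    return beta


schedule = [kl_beta(e) for e in range(8)]
print(schedule)  # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0]
```

With this schedule the first epoch trains as a plain autoencoder (beta = 0), and the full KL weight is reached at epoch `warmup_epochs`.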

In [None]:
# 🚀 Train VAE with Improved Settings

print("🔧 Training VAE with improved settings...")
print("Key improvements:")
print("• Using MSE loss instead of BCE (often better for MNIST)")
print("• KL annealing (gradual beta increase)")
print("• Gradient clipping")
print("• Learning rate scheduling")
print("• Disabled mixed precision (can cause VAE issues)")
print("• Reduced batch size for better gradients")

# Use smaller batch size for better VAE training
smaller_train_loader = DataLoader(
    train_dataset, 
    batch_size=512,  # Smaller batch size for VAE
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

smaller_test_loader = DataLoader(
    test_dataset, 
    batch_size=256,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

# Train with improved settings
import time  # required for timing in this cell
start_time = time.time()

train_losses_improved, recon_losses, kl_losses = train_vae_improved(
    model_improved, 
    smaller_train_loader, 
    epochs=20,  # Start with fewer epochs to check
    lr=1e-3,
    beta=1.0,
    use_amp=False,  # Disable mixed precision for VAE stability
    recon_loss_type='mse',  # MSE often works better than BCE
    warmup_epochs=5
)

training_time = time.time() - start_time

# Plot comprehensive training metrics
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Total loss
axes[0, 0].plot(train_losses_improved, 'b-', linewidth=2, marker='o')
axes[0, 0].set_title('Total VAE Loss (Improved)')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].grid(True, alpha=0.3)

# Reconstruction loss
axes[0, 1].plot(recon_losses, 'r-', linewidth=2, marker='s')
axes[0, 1].set_title('Reconstruction Loss (MSE)')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Reconstruction Loss')
axes[0, 1].grid(True, alpha=0.3)

# KL loss
axes[1, 0].plot(kl_losses, 'g-', linewidth=2, marker='^')
axes[1, 0].set_title('KL Divergence Loss')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('KL Loss')
axes[1, 0].grid(True, alpha=0.3)

# Loss components ratio
recon_ratio = [r / (r + k) for r, k in zip(recon_losses, kl_losses)]
kl_ratio = [k / (r + k) for r, k in zip(recon_losses, kl_losses)]

axes[1, 1].plot(recon_ratio, 'r-', linewidth=2, label='Reconstruction %')
axes[1, 1].plot(kl_ratio, 'g-', linewidth=2, label='KL %')
axes[1, 1].set_title('Loss Component Ratios')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Proportion')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print training summary
print(f"\n✅ Improved Training Complete!")
print(f"Training time: {training_time:.2f} seconds")
print(f"Final total loss: {train_losses_improved[-1]:.6f}")
print(f"Final reconstruction loss: {recon_losses[-1]:.6f}")
print(f"Final KL loss: {kl_losses[-1]:.6f}")

# Test the improved model
print(f"\n🧪 Testing improved model reconstructions...")
model_improved.eval()
with torch.no_grad():
    test_batch = next(iter(smaller_test_loader))[0][:8].to(device)
    recon_improved, mu_improved, logvar_improved = model_improved(test_batch)
    
    # Calculate reconstruction quality
    mse_improved = F.mse_loss(recon_improved, test_batch, reduction='mean')
    bce_improved = F.binary_cross_entropy(torch.clamp(recon_improved, 1e-8, 1-1e-8), 
                                         test_batch, reduction='mean')
    
    print(f"Improved model MSE: {mse_improved:.6f}")
    print(f"Improved model BCE: {bce_improved:.6f}")
    print(f"Latent mu std: {mu_improved.std():.4f}")
    print(f"Reconstruction range: [{recon_improved.min():.4f}, {recon_improved.max():.4f}]")
    
    # Visual comparison
    fig, axes = plt.subplots(2, 8, figsize=(16, 4))
    
    for i in range(8):
        # Original
        axes[0, i].imshow(test_batch[i].cpu().squeeze(), cmap='gray')
        axes[0, i].set_title(f'Original {i}')
        axes[0, i].axis('off')
        
        # Improved reconstruction
        axes[1, i].imshow(recon_improved[i].cpu().squeeze(), cmap='gray')
        axes[1, i].set_title(f'Improved {i}')
        axes[1, i].axis('off')
    
    plt.suptitle('Improved VAE Reconstructions', fontsize=14)
    plt.tight_layout()
    plt.show()

# Compare with the original model if it exists
# (it is only trained once the A100 training cell has been run)
if 'model' in locals():
    print(f"\n📊 Comparison with original model:")
    model.to(device)
    model.eval()
    with torch.no_grad():
        recon_original, _, _ = model(test_batch)
        mse_original = F.mse_loss(recon_original, test_batch, reduction='mean')
        print(f"Original model MSE: {mse_original:.6f}")
        print(f"Improved model MSE: {mse_improved:.6f}")
        print(f"Improvement: {((mse_original - mse_improved) / mse_original * 100):.1f}%")

In [None]:
# Train the VAE model with A100 optimizations
import time

print("🚀 Starting A100-optimized training...")
start_time = time.time()

# Use mixed precision training for A100 efficiency
train_losses = train_vae(model, train_loader, epochs=10, beta=1.0, use_amp=True)

end_time = time.time()
training_time = end_time - start_time

# Plot training loss
plt.figure(figsize=(12, 8))

# Main loss plot
plt.subplot(2, 2, 1)
plt.plot(train_losses, 'b-', linewidth=2, marker='o')
plt.title('VAE Training Loss (A100 Optimized)', fontsize=14)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)

# Training speed analysis
plt.subplot(2, 2, 2)
epochs = list(range(len(train_losses)))
samples_processed = [(epoch + 1) * len(train_loader.dataset) for epoch in epochs]
plt.plot(epochs, samples_processed, 'g-', linewidth=2, marker='s')
plt.title('Samples Processed')
plt.xlabel('Epoch')
plt.ylabel('Total Samples')
plt.grid(True, alpha=0.3)

# Illustrative GPU utilization heuristic (a rough guess from batch size, not a measurement)
plt.subplot(2, 2, 3)
batch_sizes = [128, 512, 1024, 2048, 4096]  # Different batch sizes
estimated_util = [min(100, (bs / 2048) * 80) for bs in batch_sizes]  # Rough estimate
current_idx = batch_sizes.index(train_loader.batch_size) if train_loader.batch_size in batch_sizes else -1

plt.bar(range(len(batch_sizes)), estimated_util, alpha=0.7, color='orange')
if current_idx >= 0:
    plt.bar(current_idx, estimated_util[current_idx], color='red', alpha=0.8, label='Current')
plt.title('Estimated GPU Utilization by Batch Size')
plt.xlabel('Batch Size')
plt.ylabel('GPU Utilization (%)')
plt.xticks(range(len(batch_sizes)), batch_sizes)
plt.legend()
plt.grid(True, alpha=0.3)

# Performance metrics
plt.subplot(2, 2, 4)
metrics = ['Samples/sec', 'Batches/sec', 'Time/epoch (s)']
total_samples = len(train_loader.dataset) * len(train_losses)
samples_per_sec = total_samples / training_time
batches_per_sec = len(train_loader) * len(train_losses) / training_time
time_per_epoch = training_time / len(train_losses)

values = [samples_per_sec, batches_per_sec, time_per_epoch]
colors = ['blue', 'green', 'red']

bars = plt.bar(metrics, values, color=colors, alpha=0.7)
plt.title('Training Performance Metrics')
plt.ylabel('Rate')
for bar, value in zip(bars, values):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{value:.1f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Print comprehensive training summary
final_loss = train_losses[-1]
initial_loss = train_losses[0]
improvement = initial_loss - final_loss

print(f"\n🎯 A100 Training Performance Summary:")
print(f"{'='*50}")
print(f"Training Time: {training_time:.2f} seconds ({training_time/60:.2f} minutes)")
print(f"Samples/second: {samples_per_sec:.1f}")
print(f"Batches/second: {batches_per_sec:.1f}")
print(f"Time per epoch: {time_per_epoch:.2f} seconds")
print(f"Batch size used: {train_loader.batch_size}")
print(f"Total batches per epoch: {len(train_loader)}")

print(f"\n📊 Training Quality:")
print(f"Initial Loss: {initial_loss:.6f}")
print(f"Final Loss: {final_loss:.6f}")
print(f"Total Improvement: {improvement:.6f}")
print(f"Improvement %: {(improvement/initial_loss)*100:.2f}%")

# GPU memory usage
if torch.cuda.is_available():
    memory_used = torch.cuda.memory_allocated() / 1e9
    memory_cached = torch.cuda.memory_reserved() / 1e9
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    
    print(f"\n💾 A100 Memory Usage:")
    print(f"Memory allocated: {memory_used:.2f} GB")
    print(f"Memory cached: {memory_cached:.2f} GB")
    print(f"Total memory: {total_memory:.1f} GB")
    print(f"Memory utilization: {(memory_used/total_memory)*100:.1f}%")

# Recommendations for better GPU utilization
print(f"\n💡 A100 Optimization Recommendations:")
if train_loader.batch_size < 1024:
    print(f"   • ⚠️  Batch size ({train_loader.batch_size}) is small for A100 - try 2048-8192")
else:
    print(f"   • ✅ Batch size ({train_loader.batch_size}) is good for A100")

if training_time / len(train_losses) < 5:
    print(f"   • ⚠️  Training very fast ({time_per_epoch:.1f}s/epoch) - consider larger model or more data")
else:
    print(f"   • ✅ Training time per epoch is reasonable")

print(f"   • Consider increasing model size (more layers/channels)")
print(f"   • Try larger datasets (CIFAR-10, ImageNet)")
print(f"   • Use gradient accumulation for even larger effective batch sizes")
print(f"   • Enable model compilation (torch.compile) if PyTorch 2.0+")

## 3.5 VAE Quality Assessment

Before proceeding with adversarial attacks, let's thoroughly evaluate the trained VAE to ensure it has learned meaningful representations.

In [None]:
def evaluate_vae_reconstructions(model, test_loader, device, num_samples=10):
    """
    Evaluate VAE reconstruction quality on random test samples
    Shows original vs reconstruction side by side with reconstruction errors
    """
    model.eval()
    
    # Get random test samples
    test_iter = iter(test_loader)
    test_data, test_labels = next(test_iter)
    
    # Select random samples
    indices = torch.randperm(test_data.size(0))[:num_samples]
    sample_data = test_data[indices].to(device)
    sample_labels = test_labels[indices]
    
    with torch.no_grad():
        # Get reconstructions
        recon_data, mu, logvar = model(sample_data)
        
        # Calculate reconstruction errors
        recon_errors = F.mse_loss(recon_data, sample_data, reduction='none')
        recon_errors = recon_errors.view(recon_errors.size(0), -1).mean(dim=1)
    
    # Create visualization
    fig, axes = plt.subplots(3, num_samples, figsize=(num_samples*2, 6))
    
    for i in range(num_samples):
        # Original image
        axes[0, i].imshow(sample_data[i].detach().cpu().squeeze(), cmap='gray')
        axes[0, i].set_title(f'Original\nDigit: {sample_labels[i].item()}')
        axes[0, i].axis('off')
        
        # Reconstructed image
        axes[1, i].imshow(recon_data[i].detach().cpu().squeeze(), cmap='gray')
        axes[1, i].set_title(f'Reconstruction\nMSE: {recon_errors[i].item():.4f}')
        axes[1, i].axis('off')
        
        # Difference (amplified)
        diff = (sample_data[i] - recon_data[i]).detach().cpu().squeeze()
        axes[2, i].imshow(diff * 5, cmap='RdBu', vmin=-1, vmax=1)
        axes[2, i].set_title(f'Difference (×5)')
        axes[2, i].axis('off')
    
    # Add row labels (drawn as text because axis('off') hides y-labels)
    for row, label in enumerate(['Original', 'Reconstruction', 'Difference']):
        axes[row, 0].text(-0.15, 0.5, label, transform=axes[row, 0].transAxes,
                          rotation=90, fontsize=12, ha='right', va='center')
    
    plt.suptitle(f'VAE Reconstruction Quality Assessment ({num_samples} Random Samples)', fontsize=14)
    plt.tight_layout()
    plt.subplots_adjust(left=0.1)
    plt.show()
    
    # Print statistics
    avg_error = recon_errors.mean().item()
    std_error = recon_errors.std().item()
    min_error = recon_errors.min().item()
    max_error = recon_errors.max().item()
    
    print(f"\n📊 Reconstruction Statistics:")
    print(f"Average MSE: {avg_error:.6f}")
    print(f"Std Dev MSE: {std_error:.6f}")
    print(f"Min MSE: {min_error:.6f}")
    print(f"Max MSE: {max_error:.6f}")
    
    # Analyze latent space statistics
    with torch.no_grad():
        mu_mean = mu.mean(dim=0)
        mu_std = mu.std(dim=0)
        latent_norm = torch.norm(mu, dim=1).mean()
    
    print(f"\n🧠 Latent Space Statistics:")
    print(f"Latent dimensions: {mu.size(1)}")
    print(f"Mean latent values: [{mu_mean[0].item():.4f}, {mu_mean[1].item():.4f}]")
    print(f"Std latent values: [{mu_std[0].item():.4f}, {mu_std[1].item():.4f}]")
    print(f"Average latent norm: {latent_norm.item():.4f}")
    
    return recon_errors, mu
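
The per-sample error reported by `evaluate_vae_reconstructions` comes from `reduction='none'` followed by a mean over the pixel dimensions. A minimal sketch of that step on synthetic tensors follows; the helper name `per_sample_mse` and the tensor shapes are assumptions for illustration, not part of the notebook.

```python
import torch
import torch.nn.functional as F

def per_sample_mse(recon, target):
    """Return one MSE value per sample (hypothetical helper for illustration)."""
    # reduction='none' keeps the element-wise squared errors, shape (N, C, H, W)
    errors = F.mse_loss(recon, target, reduction='none')
    # Flatten everything except the batch dimension, then average per sample
    return errors.flatten(start_dim=1).mean(dim=1)

torch.manual_seed(0)
target = torch.rand(4, 1, 28, 28)
recon = target.clone()
recon[0] += 0.1  # only the first sample is reconstructed imperfectly

errors = per_sample_mse(recon, target)
print(errors.shape)
print(errors)
```

Keeping one value per sample, rather than a single batch mean, is what makes the min/max/std statistics in the function meaningful.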

In [None]:
def plot_latent_space_grid(model, device, grid_size=15, latent_range=3):
    """
    Generate and visualize a grid of samples from the latent space
    This shows what the VAE has learned to generate across the latent space
    """
    model.eval()
    
    # Create a grid of points in latent space
    x = np.linspace(-latent_range, latent_range, grid_size)
    y = np.linspace(-latent_range, latent_range, grid_size)
    
    # Create meshgrid
    xx, yy = np.meshgrid(x, y)
    
    # Flatten the grid
    grid_points = np.column_stack([xx.ravel(), yy.ravel()])
    
    # Convert to tensor
    latent_samples = torch.FloatTensor(grid_points).to(device)
    
    # Generate images from latent samples
    with torch.no_grad():
        generated_images = model.decoder(latent_samples)
    
    # Create the visualization
    fig, ax = plt.subplots(figsize=(12, 12))
    
    # Create a large image by concatenating generated samples
    img_size = 28  # MNIST image size
    full_image = np.zeros((grid_size * img_size, grid_size * img_size))
    
    for i in range(grid_size):
        for j in range(grid_size):
            idx = i * grid_size + j
            img = generated_images[idx].cpu().squeeze().numpy()
            
            # Place the image in the correct position
            start_row = i * img_size
            end_row = start_row + img_size
            start_col = j * img_size
            end_col = start_col + img_size
            
            full_image[start_row:end_row, start_col:end_col] = img
    
    # Display the full image
    ax.imshow(full_image, cmap='gray')
    ax.set_title(f'Latent Space Manifold Visualization ({grid_size}×{grid_size} grid)\n'
                f'Range: [{-latent_range}, {latent_range}] in both dimensions', fontsize=14)
    
    # Add coordinate labels
    tick_positions = np.arange(0, grid_size * img_size, img_size) + img_size // 2
    tick_labels = [f'{val:.1f}' for val in x]
    
    ax.set_xticks(tick_positions[::3])  # Show every 3rd tick to avoid crowding
    ax.set_yticks(tick_positions[::3])
    ax.set_xticklabels(tick_labels[::3])
    ax.set_yticklabels(tick_labels[::3])
    
    ax.set_xlabel('Latent Dimension 1', fontsize=12)
    ax.set_ylabel('Latent Dimension 2', fontsize=12)
    
    # Add grid lines for better visualization
    for i in range(1, grid_size):
        ax.axhline(y=i * img_size, color='red', alpha=0.3, linewidth=0.5)
        ax.axvline(x=i * img_size, color='red', alpha=0.3, linewidth=0.5)
    
    plt.tight_layout()
    plt.show()
    
    # Also create a smaller focused view around the center
    center_range = 2
    center_grid = 10
    
    # Create centered grid
    x_center = np.linspace(-center_range, center_range, center_grid)
    y_center = np.linspace(-center_range, center_range, center_grid)
    xx_center, yy_center = np.meshgrid(x_center, y_center)
    grid_points_center = np.column_stack([xx_center.ravel(), yy_center.ravel()])
    latent_samples_center = torch.FloatTensor(grid_points_center).to(device)
    
    with torch.no_grad():
        generated_images_center = model.decoder(latent_samples_center)
    
    # Create focused visualization
    fig, axes = plt.subplots(center_grid, center_grid, figsize=(10, 10))
    
    for i in range(center_grid):
        for j in range(center_grid):
            idx = i * center_grid + j
            img = generated_images_center[idx].cpu().squeeze().numpy()
            
            axes[i, j].imshow(img, cmap='gray')
            axes[i, j].axis('off')
            
            # Add coordinate labels on border
            if i == 0:  # Top row
                axes[i, j].set_title(f'{x_center[j]:.1f}', fontsize=8)
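            if j == 0:  # Left column (text, since axis('off') hides y-labels)
                axes[i, j].text(-0.1, 0.5, f'{y_center[i]:.1f}', fontsize=8,
                                transform=axes[i, j].transAxes, ha='right', va='center')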
    
    plt.suptitle(f'Focused Latent Space View (Center Region ±{center_range})', fontsize=14)
    plt.tight_layout()
    plt.show()
    
    print(f"🎨 Generated {grid_size*grid_size} images from latent space grid")
    print(f"📍 Latent range: [{-latent_range}, {latent_range}]")
    print(f"🔍 Focused view range: [{-center_range}, {center_range}]")
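
The grid code relies on a specific ordering of the flattened `meshgrid` output: tile `(i, j)` of the mosaic decodes the latent point `(x[j], y[i])`. The small sketch below checks that ordering on a 3×3 grid (the grid size and range are chosen for illustration only).

```python
import numpy as np

grid_size = 3          # small grid so the ordering is easy to inspect
latent_range = 1.0

x = np.linspace(-latent_range, latent_range, grid_size)
y = np.linspace(-latent_range, latent_range, grid_size)
xx, yy = np.meshgrid(x, y)                      # default 'xy' indexing
grid_points = np.column_stack([xx.ravel(), yy.ravel()])

# Row i, column j of the image mosaic uses flat index i * grid_size + j
i, j = 2, 0
idx = i * grid_size + j
print(grid_points[idx])   # latent point decoded into the bottom-left tile
```

Because `imshow` draws row 0 at the top by default, latent dimension 2 increases downward in the mosaic, which matches the tick labels used in `plot_latent_space_grid`.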

In [None]:
# Evaluate VAE Reconstruction Quality
print("🔍 Evaluating VAE Reconstruction Quality...")
recon_errors, latent_samples = evaluate_vae_reconstructions(model, test_loader, device, num_samples=10)

# Check if the model has learned reasonable reconstructions
avg_error = recon_errors.mean().item()
if avg_error > 0.1:
    print(f"⚠️  WARNING: High reconstruction error ({avg_error:.4f})")
    print("   Consider training for more epochs or adjusting hyperparameters")
elif avg_error > 0.05:
    print(f"⚠️  MODERATE: Reconstruction error is moderate ({avg_error:.4f})")
    print("   VAE quality is acceptable but could be improved")
else:
    print(f"✅ GOOD: Low reconstruction error ({avg_error:.4f})")
    print("   VAE has learned good reconstructions")

In [None]:
# Generate Latent Space Grid Visualization
print("🎨 Generating latent space manifold visualization...")
plot_latent_space_grid(model, device, grid_size=15, latent_range=3)

# Additional analysis: Check latent space coverage
with torch.no_grad():
    # Sample from test set to see latent distribution
    test_iter = iter(test_loader)
    test_batch, test_batch_labels = next(test_iter)
    test_batch = test_batch[:100].to(device)  # Use 100 samples
    
    mu_batch, _ = model.encoder(test_batch)
    
    # Calculate latent space statistics
    latent_mean = mu_batch.mean(dim=0)
    latent_std = mu_batch.std(dim=0)
    latent_range_actual = [
        (mu_batch[:, 0].min().item(), mu_batch[:, 0].max().item()),
        (mu_batch[:, 1].min().item(), mu_batch[:, 1].max().item())
    ]
    
    print(f"\n📊 Latent Space Coverage Analysis:")
    print(f"Dimension 1 - Mean: {latent_mean[0].item():.3f}, Std: {latent_std[0].item():.3f}")
    print(f"Dimension 1 - Range: [{latent_range_actual[0][0]:.3f}, {latent_range_actual[0][1]:.3f}]")
    print(f"Dimension 2 - Mean: {latent_mean[1].item():.3f}, Std: {latent_std[1].item():.3f}")
    print(f"Dimension 2 - Range: [{latent_range_actual[1][0]:.3f}, {latent_range_actual[1][1]:.3f}]")
    
    # Check if latent space is being utilized effectively
    max_range = max(latent_range_actual[0][1] - latent_range_actual[0][0],
                   latent_range_actual[1][1] - latent_range_actual[1][0])
    
    if max_range < 2:
        print("⚠️  Latent space utilization is limited - consider reducing regularization (β)")
    elif max_range > 6:
        print("⚠️  Latent space is very spread out - consider increasing regularization (β)")
    else:
        print("✅ Latent space utilization looks good")

In [None]:
# Training Sufficiency Assessment
print("\n🎯 VAE Training Sufficiency Assessment:")

# Calculate final training loss
final_loss = train_losses[-1] if train_losses else float('inf')
print(f"Final training loss: {final_loss:.6f}")

# Assess training quality based on multiple metrics
assessment_score = 0
recommendations = []

# 1. Reconstruction error assessment
avg_recon_error = recon_errors.mean().item()
if avg_recon_error < 0.05:
    assessment_score += 3
    print("✅ Reconstruction quality: EXCELLENT")
elif avg_recon_error < 0.1:
    assessment_score += 2
    print("✅ Reconstruction quality: GOOD")
elif avg_recon_error < 0.15:
    assessment_score += 1
    print("⚠️  Reconstruction quality: MODERATE")
    recommendations.append("Consider training for more epochs")
else:
    print("❌ Reconstruction quality: POOR")
    recommendations.append("Increase training epochs significantly")

# 2. Training loss convergence
if len(train_losses) >= 3:
    loss_improvement = train_losses[0] - train_losses[-1]
    loss_stability = abs(train_losses[-1] - train_losses[-2])
    
    if loss_improvement > 0.01 and loss_stability < 0.001:
        assessment_score += 3
        print("✅ Training convergence: EXCELLENT")
    elif loss_improvement > 0.005:
        assessment_score += 2
        print("✅ Training convergence: GOOD")
    elif loss_improvement > 0.001:
        assessment_score += 1
        print("⚠️  Training convergence: MODERATE")
        recommendations.append("Train for more epochs for better convergence")
    else:
        print("❌ Training convergence: POOR")
        recommendations.append("Training may need more epochs or different hyperparameters")

# 3. Latent space utilization
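if 1.5 < max_range < 5:
    assessment_score += 2
    print("✅ Latent space utilization: GOOD")
elif max_range >= 5:
    assessment_score += 1
    print("⚠️  Latent space utilization: MODERATE (very spread out)")
    recommendations.append("Consider increasing β to tighten the latent space")
elif max_range > 0.8:
    assessment_score += 1
    print("⚠️  Latent space utilization: MODERATE")
    recommendations.append("Consider reducing β to use more of the latent space")
else:
    print("❌ Latent space utilization: POOR")
    recommendations.append("Adjust β parameter for better latent space usage")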

# Overall assessment
total_score = assessment_score
max_score = 8

print(f"\n📊 Overall VAE Quality Score: {total_score}/{max_score}")

if total_score >= 7:
    print("🎉 EXCELLENT: VAE is well-trained and ready for adversarial attacks!")
elif total_score >= 5:
    print("👍 GOOD: VAE quality is acceptable for adversarial attack analysis")
elif total_score >= 3:
    print("⚠️  MODERATE: VAE could benefit from additional training")
else:
    print("❌ POOR: Consider retraining the VAE with different parameters")

if recommendations:
    print(f"\n💡 Recommendations:")
    for i, rec in enumerate(recommendations, 1):
        print(f"   {i}. {rec}")

print(f"\n{'='*60}")
print("📈 If you want to improve the VAE, consider:")
print("   • Increasing epochs to 20-50")
print("   • Adjusting β (try 0.5 for less regularization)")
print("   • Using a different learning rate (try 1e-4)")
print("   • Adding batch normalization or dropout")
print(f"{'='*60}")

## 4. Visualize Latent Space (2D)

In [None]:
def plot_latent_space(model, test_loader, device, num_samples=2000):
    """Plot the 2D latent space representation"""
    model.eval()
    latents = []
    labels = []
    
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)
            mu, _ = model.encoder(data)
            latents.append(mu.cpu().numpy())
            labels.append(label.numpy())
            
            if len(latents) * data.size(0) >= num_samples:
                break
    
    latents = np.concatenate(latents, axis=0)[:num_samples]
    labels = np.concatenate(labels, axis=0)[:num_samples]
    
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(latents[:, 0], latents[:, 1], c=labels, cmap='tab10', alpha=0.6)
    plt.colorbar(scatter)
    plt.xlabel('Latent Dimension 1')
    plt.ylabel('Latent Dimension 2')
    plt.title('2D Latent Space Representation (Color = Digit Class)')
    plt.grid(True)
    plt.show()
    
    return latents, labels

latents, labels = plot_latent_space(model, test_loader, device)

## 5. Implement Adversarial Attacks

In [None]:
class AdversarialAttacks:
    """Class containing various adversarial attack methods for VAEs"""
    
    @staticmethod
    def fgsm_attack(model, data, target, epsilon):
        """
        Fast Gradient Sign Method (FGSM) attack
        
        Args:
            model: VAE model
            data: input data
            target: target data (for reconstruction loss)
            epsilon: perturbation magnitude
        """
        # Set model to evaluation mode
        model.eval()
        
        # Work on a detached copy so the caller's tensor is not modified and
        # input gradients do not accumulate across repeated calls
        data = data.clone().detach().requires_grad_(True)
        target = target.detach()
        
        # Forward pass
        recon_data, mu, logvar = model(data)
        
        # Calculate loss
        loss = vae_loss(recon_data, target, mu, logvar)
        
        # Zero gradients
        model.zero_grad()
        
        # Calculate gradients
        loss.backward()
        
        # Get gradient sign
        sign_data_grad = data.grad.sign()
        
        # Create adversarial example
        perturbed_data = data + epsilon * sign_data_grad
        perturbed_data = torch.clamp(perturbed_data, 0, 1)
        
        return perturbed_data.detach()
    
    @staticmethod
    def pgd_attack(model, data, target, epsilon, alpha, num_iter):
        """
        Projected Gradient Descent (PGD) attack
        
        Args:
            model: VAE model
            data: input data
            target: target data
            epsilon: maximum perturbation
            alpha: step size
            num_iter: number of iterations
        """
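        model.eval()
        
        # Detach inputs so no gradient reaches (or accumulates on) the caller's tensors
        data = data.detach()
        target = target.detach()
        
        # Initialize perturbation inside the epsilon ball and the valid pixel range
        delta = torch.zeros_like(data).uniform_(-epsilon, epsilon)
        delta = torch.clamp(data + delta, 0, 1) - data
        delta.requires_grad = True
        
        for i in range(num_iter):
            # Forward pass with perturbation
            perturbed_data = data + delta
            recon_data, mu, logvar = model(perturbed_data)
            
            # Calculate loss
            loss = vae_loss(recon_data, target, mu, logvar)
            
            # Calculate gradients
            loss.backward()
            
            # Gradient-ascent step, then project back onto the L∞ ball and [0, 1]
            delta.data = delta.data + alpha * delta.grad.data.sign()
            delta.data = torch.clamp(delta.data, -epsilon, epsilon)
            delta.data = torch.clamp(data + delta.data, 0, 1) - data
            
            # Zero gradients
            delta.grad.zero_()
        
        # Return a plain tensor with no autograd graph attached
        return (data + delta).detach()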
    
    @staticmethod
    def latent_space_attack(model, data, epsilon, target_latent=None):
        """
        Attack in latent space by perturbing encoded representations
        
        Args:
            model: VAE model
            data: input data
            epsilon: perturbation magnitude in latent space
            target_latent: target latent representation (optional)
        """
        model.eval()
        
        # Encode to latent space
        mu, logvar = model.encoder(data)
        z = model.reparameterize(mu, logvar)
        
        if target_latent is not None:
            # Move towards target latent representation
            direction = (target_latent - z).sign()
            perturbed_z = z + epsilon * direction
        else:
            # Random perturbation in latent space
            noise = torch.randn_like(z)
            perturbed_z = z + epsilon * noise
        
        # Decode back to image space
        adversarial_recon = model.decoder(perturbed_z)
        
        return adversarial_recon, z, perturbed_z
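
A minimal standalone sketch of the same FGSM and PGD updates on a toy model can help check the projection logic in isolation. Everything here is an assumption for illustration: `toy_model` is a tiny deterministic autoencoder standing in for the VAE, and `toy_loss` is reconstruction-only, whereas the notebook's `vae_loss` also includes a KL term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the VAE: a tiny deterministic autoencoder
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(16, 2), nn.Linear(2, 16), nn.Sigmoid())

def toy_loss(x, target):
    # Reconstruction-only loss (assumption; no KL term here)
    return F.mse_loss(toy_model(x), target.flatten(start_dim=1))

def fgsm(x, epsilon):
    x_adv = x.clone().detach().requires_grad_(True)   # work on a copy
    loss = toy_loss(x_adv, x.detach())                # clean input is a fixed target
    grad, = torch.autograd.grad(loss, x_adv)
    return (x_adv + epsilon * grad.sign()).clamp(0, 1).detach()

def pgd(x, epsilon, alpha, num_iter):
    x = x.detach()
    delta = torch.empty_like(x).uniform_(-epsilon, epsilon)
    delta = (x + delta).clamp(0, 1) - x               # start inside the valid pixel range
    for _ in range(num_iter):
        delta.requires_grad_(True)
        loss = toy_loss(x + delta, x)
        grad, = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            # Ascent step followed by projection onto the L-infinity ball
            delta = (delta + alpha * grad.sign()).clamp(-epsilon, epsilon)
            # Keep the perturbed pixels in [0, 1]
            delta = (x + delta).clamp(0, 1) - x
    return (x + delta).detach()

x = torch.rand(8, 1, 4, 4)
x_fgsm = fgsm(x, epsilon=0.1)
x_pgd = pgd(x, epsilon=0.1, alpha=0.01, num_iter=20)
print((x_fgsm - x).abs().max().item(), (x_pgd - x).abs().max().item())
```

Using `torch.autograd.grad` instead of `loss.backward()` avoids touching `.grad` buffers at all, so repeated calls (for example in an epsilon sweep) cannot interfere with each other.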

## 6. Demonstrate Adversarial Attacks

In [None]:
# Get test samples for attacks
test_iter = iter(test_loader)
test_data, test_labels = next(test_iter)
test_data = test_data[:8].to(device)  # Use first 8 samples

# Initialize attack methods
attacks = AdversarialAttacks()

# Show original images
fig, axes = plt.subplots(1, 8, figsize=(16, 2))
for i in range(8):
    axes[i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[i].set_title(f'Original {i}')
    axes[i].axis('off')
plt.suptitle('Original Test Images')
plt.tight_layout()
plt.show()

In [None]:
# FGSM Attack - Comprehensive Analysis
print("\n=== FGSM Attack - Complete Pipeline ===")
epsilon_fgsm = 0.1
fgsm_adversarial = attacks.fgsm_attack(model, test_data, test_data, epsilon=epsilon_fgsm)

# Get reconstructions
with torch.no_grad():
    original_recon, _, _ = model(test_data)
    adversarial_recon, _, _ = model(fgsm_adversarial)

# Create comprehensive visualization
fig, axes = plt.subplots(6, 8, figsize=(16, 12))
row_labels = [
    'Original Input',
    'Original Reconstruction', 
    'Adversarial Input',
    'Adversarial Reconstruction',
    'Input Difference (x10)',
    'Reconstruction Difference (x10)'
]

for i in range(8):
    # Row 1: Original inputs
    axes[0, i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[0, i].set_title(f'Sample {i}')
    axes[0, i].axis('off')
    
    # Row 2: Original reconstructions
    axes[1, i].imshow(original_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[1, i].axis('off')
    
    # Row 3: Adversarial inputs (epsilon-perturbed)
    axes[2, i].imshow(fgsm_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[2, i].axis('off')
    
    # Row 4: Adversarial reconstructions
    axes[3, i].imshow(adversarial_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[3, i].axis('off')
    
    # Row 5: Input differences (original vs adversarial, amplified)
    input_diff = (fgsm_adversarial[i] - test_data[i]).detach().cpu().squeeze()
    axes[4, i].imshow(input_diff * 10, cmap='RdBu', vmin=-1, vmax=1)
    axes[4, i].axis('off')
    
    # Row 6: Reconstruction differences (original recon vs adversarial recon, amplified)
    recon_diff = (adversarial_recon[i] - original_recon[i]).detach().cpu().squeeze()
    axes[5, i].imshow(recon_diff * 10, cmap='RdBu', vmin=-1, vmax=1)
    axes[5, i].axis('off')

# Add row labels (drawn as text because axis('off') hides y-labels)
for i, label in enumerate(row_labels):
    axes[i, 0].text(-0.1, 0.5, label, transform=axes[i, 0].transAxes,
                    rotation=90, fontsize=10, ha='right', va='center')

plt.suptitle(f'FGSM Attack Analysis (ε={epsilon_fgsm})', fontsize=14, y=0.98)
plt.tight_layout()
plt.subplots_adjust(left=0.15)
plt.show()

# Calculate and display statistics
input_perturbation = (fgsm_adversarial - test_data).detach().cpu()
recon_perturbation = (adversarial_recon - original_recon).detach().cpu()

print(f"\n📊 FGSM Attack Statistics:")
print(f"Input Perturbation:")
print(f"  L2 norm: {torch.norm(input_perturbation).item():.6f}")
print(f"  L∞ norm: {torch.max(torch.abs(input_perturbation)).item():.6f}")
print(f"  Mean absolute: {torch.mean(torch.abs(input_perturbation)).item():.6f}")

print(f"\nReconstruction Perturbation:")
print(f"  L2 norm: {torch.norm(recon_perturbation).item():.6f}")
print(f"  L∞ norm: {torch.max(torch.abs(recon_perturbation)).item():.6f}")
print(f"  Mean absolute: {torch.mean(torch.abs(recon_perturbation)).item():.6f}")

print(f"\nAmplification Factor: {torch.norm(recon_perturbation).item() / torch.norm(input_perturbation).item():.2f}x")
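
The amplification factor above is computed over the whole batch at once, so a single strongly affected sample can dominate it. A per-sample variant is sketched below on synthetic tensors; `per_sample_amplification` is a hypothetical helper, not part of the notebook.

```python
import torch

def per_sample_amplification(clean, adv, recon_clean, recon_adv, eps=1e-12):
    """Ratio of reconstruction change to input change, per sample (hypothetical helper)."""
    input_norm = (adv - clean).flatten(start_dim=1).norm(dim=1)
    recon_norm = (recon_adv - recon_clean).flatten(start_dim=1).norm(dim=1)
    # clamp_min guards against division by zero when epsilon is 0
    return recon_norm / input_norm.clamp_min(eps)

clean = torch.zeros(2, 1, 2, 2)
adv = clean + 0.1
recon_clean = torch.zeros(2, 1, 2, 2)
recon_adv = torch.stack([torch.full((1, 2, 2), 0.2), torch.full((1, 2, 2), 0.05)])

ratios = per_sample_amplification(clean, adv, recon_clean, recon_adv)
print(ratios)
```

In the notebook this could be called with `test_data`, `fgsm_adversarial`, `original_recon`, and `adversarial_recon`; a ratio above 1 means the decoder output moved more than the input did.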



In [None]:
# PGD Attack - Comprehensive Analysis
print("\n=== PGD Attack - Complete Pipeline ===")
epsilon_pgd = 0.1
alpha = 0.01
num_iter = 20

pgd_adversarial = attacks.pgd_attack(model, test_data, test_data, 
                                   epsilon=epsilon_pgd, alpha=alpha, num_iter=num_iter)

# Get reconstructions
with torch.no_grad():
    original_recon, _, _ = model(test_data)
    pgd_adversarial_recon, _, _ = model(pgd_adversarial)

# Create comprehensive visualization
fig, axes = plt.subplots(6, 8, figsize=(16, 12))
row_labels = [
    'Original Input',
    'Original Reconstruction', 
    'PGD Adversarial Input',
    'PGD Adversarial Reconstruction',
    'Input Difference (x10)',
    'Reconstruction Difference (x10)'
]

for i in range(8):
    # Row 1: Original inputs
    axes[0, i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[0, i].set_title(f'Sample {i}')
    axes[0, i].axis('off')
    
    # Row 2: Original reconstructions
    axes[1, i].imshow(original_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[1, i].axis('off')
    
    # Row 3: PGD adversarial inputs
    axes[2, i].imshow(pgd_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[2, i].axis('off')
    
    # Row 4: PGD adversarial reconstructions
    axes[3, i].imshow(pgd_adversarial_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[3, i].axis('off')
    
    # Row 5: Input differences (original vs PGD adversarial, amplified)
    input_diff = (pgd_adversarial[i] - test_data[i]).detach().cpu().squeeze()
    axes[4, i].imshow(input_diff * 10, cmap='RdBu', vmin=-1, vmax=1)
    axes[4, i].axis('off')
    
    # Row 6: Reconstruction differences (original recon vs PGD adversarial recon, amplified)
    recon_diff = (pgd_adversarial_recon[i] - original_recon[i]).detach().cpu().squeeze()
    axes[5, i].imshow(recon_diff * 10, cmap='RdBu', vmin=-1, vmax=1)
    axes[5, i].axis('off')

# Add row labels (drawn as text because axis('off') hides y-labels)
for i, label in enumerate(row_labels):
    axes[i, 0].text(-0.1, 0.5, label, transform=axes[i, 0].transAxes,
                    rotation=90, fontsize=10, ha='right', va='center')

plt.suptitle(f'PGD Attack Analysis (ε={epsilon_pgd}, α={alpha}, iter={num_iter})', fontsize=14, y=0.98)
plt.tight_layout()
plt.subplots_adjust(left=0.15)
plt.show()

# Calculate and display statistics
input_perturbation_pgd = (pgd_adversarial - test_data).detach().cpu()
recon_perturbation_pgd = (pgd_adversarial_recon - original_recon).detach().cpu()

print(f"\n📊 PGD Attack Statistics:")
print(f"Input Perturbation:")
print(f"  L2 norm: {torch.norm(input_perturbation_pgd).item():.6f}")
print(f"  L∞ norm: {torch.max(torch.abs(input_perturbation_pgd)).item():.6f}")
print(f"  Mean absolute: {torch.mean(torch.abs(input_perturbation_pgd)).item():.6f}")

print(f"\nReconstruction Perturbation:")
print(f"  L2 norm: {torch.norm(recon_perturbation_pgd).item():.6f}")
print(f"  L∞ norm: {torch.max(torch.abs(recon_perturbation_pgd)).item():.6f}")
print(f"  Mean absolute: {torch.mean(torch.abs(recon_perturbation_pgd)).item():.6f}")

print(f"\nAmplification Factor: {torch.norm(recon_perturbation_pgd).item() / torch.norm(input_perturbation_pgd).item():.2f}x")

In [None]:
# Compare FGSM vs PGD Attack Effects
print("\n=== FGSM vs PGD Comparison ===")

# Select first 4 samples for detailed comparison
num_samples = 4
fig, axes = plt.subplots(7, num_samples, figsize=(12, 14))

comparison_labels = [
    'Original Input',
    'Original Reconstruction',
    'FGSM Adversarial',
    'FGSM Reconstruction', 
    'PGD Adversarial',
    'PGD Reconstruction',
    'FGSM vs PGD Diff (x5)'
]

for i in range(num_samples):
    # Row 1: Original inputs
    axes[0, i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[0, i].set_title(f'Sample {i}')
    axes[0, i].axis('off')
    
    # Row 2: Original reconstructions
    axes[1, i].imshow(original_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[1, i].axis('off')
    
    # Row 3: FGSM adversarial inputs
    axes[2, i].imshow(fgsm_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[2, i].axis('off')
    
    # Row 4: FGSM adversarial reconstructions
    axes[3, i].imshow(adversarial_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[3, i].axis('off')
    
    # Row 5: PGD adversarial inputs
    axes[4, i].imshow(pgd_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[4, i].axis('off')
    
    # Row 6: PGD adversarial reconstructions
    axes[5, i].imshow(pgd_adversarial_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[5, i].axis('off')
    
    # Row 7: Difference between FGSM and PGD adversarial inputs
    fgsm_vs_pgd_diff = (fgsm_adversarial[i] - pgd_adversarial[i]).detach().cpu().squeeze()
    axes[6, i].imshow(fgsm_vs_pgd_diff * 5, cmap='RdBu', vmin=-1, vmax=1)
    axes[6, i].axis('off')

# Add row labels (drawn as text because axis('off') hides y-labels)
for i, label in enumerate(comparison_labels):
    axes[i, 0].text(-0.1, 0.5, label, transform=axes[i, 0].transAxes,
                    rotation=90, fontsize=9, ha='right', va='center')

plt.suptitle('FGSM vs PGD Attack Comparison', fontsize=14, y=0.98)
plt.tight_layout()
plt.subplots_adjust(left=0.18)
plt.show()

# Quantitative comparison
print(f"\n📈 Attack Method Comparison:")
print(f"{'Metric':<25} {'FGSM':<12} {'PGD':<12} {'Ratio (PGD/FGSM)':<15}")
print("-" * 70)

fgsm_input_l2 = torch.norm(input_perturbation).item()
pgd_input_l2 = torch.norm(input_perturbation_pgd).item()
print(f"{'Input L2 Perturbation':<25} {fgsm_input_l2:<12.6f} {pgd_input_l2:<12.6f} {pgd_input_l2/fgsm_input_l2:<15.2f}")

fgsm_recon_l2 = torch.norm(recon_perturbation).item()
pgd_recon_l2 = torch.norm(recon_perturbation_pgd).item()
print(f"{'Recon L2 Perturbation':<25} {fgsm_recon_l2:<12.6f} {pgd_recon_l2:<12.6f} {pgd_recon_l2/fgsm_recon_l2:<15.2f}")

fgsm_input_linf = torch.max(torch.abs(input_perturbation)).item()
pgd_input_linf = torch.max(torch.abs(input_perturbation_pgd)).item()
print(f"{'Input L∞ Perturbation':<25} {fgsm_input_linf:<12.6f} {pgd_input_linf:<12.6f} {pgd_input_linf/fgsm_input_linf:<15.2f}")

fgsm_recon_linf = torch.max(torch.abs(recon_perturbation)).item()
pgd_recon_linf = torch.max(torch.abs(recon_perturbation_pgd)).item()
print(f"{'Recon L∞ Perturbation':<25} {fgsm_recon_linf:<12.6f} {pgd_recon_linf:<12.6f} {pgd_recon_linf/fgsm_recon_linf:<15.2f}")

# Check if attacks are different
attack_similarity = F.mse_loss(fgsm_adversarial, pgd_adversarial).item()
print(f"\n🔍 Attack Similarity (MSE between FGSM and PGD adversarial inputs): {attack_similarity:.6f}")
if attack_similarity < 1e-6:
    print("⚠️  Warning: FGSM and PGD attacks produced nearly identical results!")
else:
    print("✓ FGSM and PGD attacks produced different adversarial examples.")

In [None]:
# Latent Space Attack
print("\n=== Latent Space Attack ===")
epsilon_latent = 2.0

latent_adversarial, orig_latent, perturbed_latent = attacks.latent_space_attack(
    model, test_data, epsilon=epsilon_latent)

# Get original reconstructions for comparison
with torch.no_grad():
    original_recon, _, _ = model(test_data)

# Visualize latent space attack results
fig, axes = plt.subplots(3, 8, figsize=(16, 6))
for i in range(8):
    # Original
    axes[0, i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[0, i].set_title(f'Original {i}')
    axes[0, i].axis('off')
    
    # Original reconstruction
    axes[1, i].imshow(original_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[1, i].set_title(f'Original Recon')
    axes[1, i].axis('off')
    
    # Latent attack result
    axes[2, i].imshow(latent_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[2, i].set_title(f'Latent Attack')
    axes[2, i].axis('off')

plt.suptitle(f'Latent Space Attack Results (ε={epsilon_latent})')
plt.tight_layout()
plt.show()

# Show latent space perturbations
print(f"Latent space perturbation magnitude: {torch.norm(perturbed_latent - orig_latent).item():.6f}")
print(f"Original latent mean: {orig_latent.mean(dim=0).detach().cpu().numpy()}")
print(f"Perturbed latent mean: {perturbed_latent.mean(dim=0).detach().cpu().numpy()}")
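
The cell above exercises only the random branch of `latent_space_attack`. The targeted branch takes a sign step from `z` toward `target_latent`; the sketch below isolates that update on hand-picked 2D codes (the values and the helper name `step_toward` are illustrative assumptions).

```python
import torch

def step_toward(z, target_latent, epsilon):
    """Sign step toward a target code, mirroring the targeted branch of latent_space_attack."""
    return z + epsilon * (target_latent - z).sign()

z = torch.tensor([[0.0, 0.0], [1.0, -1.0]])
target = torch.tensor([[2.0, -2.0], [1.0, 3.0]])

stepped = step_toward(z, target, epsilon=0.5)
print(stepped)
```

Each coordinate moves by exactly `epsilon` toward the target, or stays put when it already matches. One consequence is that a coordinate closer than `epsilon` to the target will overshoot it, so a small `epsilon` or a clamp may be preferable when the target is nearby.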

## 7. Evaluate Attack Effectiveness

In [None]:
def evaluate_attack_effectiveness(model, original, adversarial, attack_name):
    """Evaluate the effectiveness of adversarial attacks"""
    model.eval()
    
    with torch.no_grad():
        # Reconstruct original
        recon_orig, mu_orig, logvar_orig = model(original)
        
        # Reconstruct adversarial
        recon_adv, mu_adv, logvar_adv = model(adversarial)
        
        # Calculate reconstruction errors
        orig_error = F.mse_loss(recon_orig, original).item()
        adv_error = F.mse_loss(recon_adv, adversarial).item()
        
        # Calculate latent space distances
        latent_distance = F.mse_loss(mu_orig, mu_adv).item()
        
        # Calculate input perturbation
        input_perturbation = F.mse_loss(original, adversarial).item()
        
        print(f"\n=== {attack_name} Effectiveness ===")
        print(f"Original Reconstruction Error: {orig_error:.6f}")
        print(f"Adversarial Reconstruction Error: {adv_error:.6f}")
        print(f"Latent Space Distance: {latent_distance:.6f}")
        print(f"Input Perturbation (MSE): {input_perturbation:.6f}")
        
        return orig_error, adv_error, latent_distance, input_perturbation

# Evaluate all attacks
fgsm_results = evaluate_attack_effectiveness(model, test_data, fgsm_adversarial, "FGSM")
pgd_results = evaluate_attack_effectiveness(model, test_data, pgd_adversarial, "PGD")

# For latent attack, compare original reconstruction vs latent attack result
with torch.no_grad():
    orig_recon, orig_mu, orig_logvar = model(test_data)
    latent_mse = F.mse_loss(orig_recon, latent_adversarial).item()
    print(f"\n=== Latent Space Attack Effectiveness ===")
    print(f"Original vs Latent Attack Reconstruction MSE: {latent_mse:.6f}")
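
Note that `adv_error` in `evaluate_attack_effectiveness` measures how well the model reconstructs the adversarial input itself. Depending on the question being asked, it may be more informative to compare the adversarial reconstruction against the clean input or the clean reconstruction. The sketch below computes all three on synthetic tensors; `attack_metrics` and its key names are hypothetical.

```python
import torch
import torch.nn.functional as F

def attack_metrics(original, adversarial, recon_orig, recon_adv):
    """Hypothetical helper: compare adversarial reconstructions to several references."""
    return {
        # Error against the adversarial input itself (what adv_error reports)
        "recon_vs_adv_input": F.mse_loss(recon_adv, adversarial).item(),
        # Error against the clean input the attacker is trying to corrupt
        "recon_vs_clean_input": F.mse_loss(recon_adv, original).item(),
        # Shift in the model output caused by the perturbation
        "recon_shift": F.mse_loss(recon_adv, recon_orig).item(),
    }

original = torch.zeros(2, 1, 2, 2)
adversarial = original + 0.1
recon_orig = original.clone()
recon_adv = torch.full_like(original, 0.3)

metrics = attack_metrics(original, adversarial, recon_orig, recon_adv)
print(metrics)
```

In this toy case the reconstruction lies closer to the adversarial input than to the clean one, so the first metric understates how far the output has drifted from the clean image.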

## 8. Robustness Analysis

In [None]:
# Test robustness across different epsilon values
print("\n=== Robustness Analysis ===")
epsilons = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
fgsm_errors = []
pgd_errors = []
perturbation_magnitudes = []

for eps in tqdm(epsilons, desc="Testing epsilon values"):
    # FGSM
    fgsm_adv = attacks.fgsm_attack(model, test_data, test_data, epsilon=eps)
    _, fgsm_error, _, fgsm_pert = evaluate_attack_effectiveness(model, test_data, fgsm_adv, f"FGSM-{eps}")
    fgsm_errors.append(fgsm_error)
    
    # PGD
    pgd_adv = attacks.pgd_attack(model, test_data, test_data, epsilon=eps, alpha=0.01, num_iter=10)
    _, pgd_error, _, pgd_pert = evaluate_attack_effectiveness(model, test_data, pgd_adv, f"PGD-{eps}")
    pgd_errors.append(pgd_error)
    
    perturbation_magnitudes.append(eps)

# Plot robustness curves
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(epsilons, fgsm_errors, 'o-', label='FGSM', linewidth=2)
plt.plot(epsilons, pgd_errors, 's-', label='PGD', linewidth=2)
plt.xlabel('Epsilon (Perturbation Magnitude)')
plt.ylabel('Reconstruction Error')
plt.title('VAE Robustness vs Perturbation Magnitude')
plt.legend()
plt.grid(True)

# Test different latent space perturbation magnitudes
latent_epsilons = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
latent_errors = []

with torch.no_grad():
    orig_recon, _, _ = model(test_data)

for eps in latent_epsilons:
    latent_adv, _, _ = attacks.latent_space_attack(model, test_data, epsilon=eps)
    error = F.mse_loss(orig_recon, latent_adv).item()
    latent_errors.append(error)

plt.subplot(1, 3, 2)
plt.plot(latent_epsilons, latent_errors, '^-', color='green', linewidth=2)
plt.xlabel('Latent Space Perturbation Magnitude')
plt.ylabel('Reconstruction Difference (MSE)')
plt.title('Latent Space Attack Effectiveness')
plt.grid(True)

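# Compare attack methods
# Note: the first three bars are reconstruction errors against the model input,
# while the latent bar is the MSE between the clean and latent-attacked reconstructions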
plt.subplot(1, 3, 3)
methods = ['Original', 'FGSM\n(ε=0.1)', 'PGD\n(ε=0.1)', 'Latent\n(ε=2.0)']
errors = [fgsm_results[0], fgsm_results[1], pgd_results[1], latent_mse]
colors = ['blue', 'red', 'orange', 'green']

bars = plt.bar(methods, errors, color=colors, alpha=0.7)
plt.ylabel('Reconstruction Error / MSE')
plt.title('Attack Method Comparison')
plt.xticks(rotation=45)

# Add value labels on bars
for bar, error in zip(bars, errors):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.0001, 
             f'{error:.4f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()
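
A PGD iterate that starts at the clean input can move at most `alpha * num_iter` per pixel, so the step size has to grow with ε for the larger budgets in a sweep to be reachable. The sketch below illustrates this on a toy autoencoder. `pgd_linf` and `toy_model` are hypothetical stand-ins (the notebook's `attacks.pgd_attack` and VAE may differ, for example by using a random start), and `2.5 * eps / num_iter` is one commonly used heuristic rather than a required choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def pgd_linf(model, x, epsilon, alpha, num_iter):
    """Untargeted L-inf PGD that maximizes reconstruction MSE (hypothetical helper)."""
    x_adv = x.clone().detach()
    for _ in range(num_iter):
        x_adv.requires_grad_(True)
        loss = F.mse_loss(model(x_adv), x)
        grad, = torch.autograd.grad(loss, x_adv)
        # Ascent step along the gradient sign
        x_adv = x_adv.detach() + alpha * grad.sign()
        # Project back into the epsilon-ball around x and the valid pixel range
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon).clamp(0.0, 1.0)
    return x_adv.detach()

torch.manual_seed(0)
# Toy autoencoder that returns only the reconstruction (assumption for this sketch)
toy_model = nn.Sequential(nn.Linear(16, 4), nn.Tanh(), nn.Linear(4, 16), nn.Sigmoid())
x = torch.rand(8, 16) * 0.4 + 0.3  # keep pixels in [0.3, 0.7] so clamping does not interfere

eps, num_iter = 0.3, 10
fixed = pgd_linf(toy_model, x, eps, alpha=0.01, num_iter=num_iter)
scaled = pgd_linf(toy_model, x, eps, alpha=2.5 * eps / num_iter, num_iter=num_iter)

print(f"fixed alpha:  max |delta| = {(fixed - x).abs().max().item():.3f}")
print(f"scaled alpha: max |delta| = {(scaled - x).abs().max().item():.3f}")
```

With the fixed step size the perturbation is capped at 0.1 regardless of ε, which can flatten a PGD robustness curve for reasons unrelated to the model.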

## 9. Visualize Latent Space Perturbations

In [None]:
# Visualize how attacks affect the latent space
model.eval()
with torch.no_grad():
    # Encode the first four test images and their adversarial counterparts;
    # convert to NumPy so matplotlib receives plain floats
    mu_orig = model.encoder(test_data[:4])[0].cpu().numpy()
    mu_fgsm = model.encoder(fgsm_adversarial[:4])[0].cpu().numpy()
    mu_pgd = model.encoder(pgd_adversarial[:4])[0].cpu().numpy()

# Plot latent space movements (one color per sample)
plt.figure(figsize=(12, 8))
colors = ['red', 'blue', 'green', 'orange']

for i in range(4):
    # Original, FGSM, and PGD latent means for sample i
    plt.scatter(mu_orig[i, 0], mu_orig[i, 1], color=colors[i], s=100, marker='o')
    plt.scatter(mu_fgsm[i, 0], mu_fgsm[i, 1], color=colors[i], s=100, marker='x', alpha=0.7)
    plt.scatter(mu_pgd[i, 0], mu_pgd[i, 1], color=colors[i], s=100, marker='^', alpha=0.7)

    # Arrows from the original point: dashed for FGSM, solid for PGD
    for mu_adv, style in ((mu_fgsm, '--'), (mu_pgd, '-')):
        plt.arrow(mu_orig[i, 0], mu_orig[i, 1],
                  mu_adv[i, 0] - mu_orig[i, 0], mu_adv[i, 1] - mu_orig[i, 1],
                  color=colors[i], alpha=0.5, head_width=0.1, linestyle=style,
                  length_includes_head=True)

# Create custom legend
from matplotlib.lines import Line2D
legend_elements = [
    # linestyle='None' shows only the marker; color='black' also sets the edge color,
    # which unfilled markers such as 'x' need in order to be visible
    Line2D([0], [0], marker='o', color='black', linestyle='None', markersize=8, label='Original'),
    Line2D([0], [0], marker='x', color='black', linestyle='None', markersize=8, label='FGSM'),
    Line2D([0], [0], marker='^', color='black', linestyle='None', markersize=8, label='PGD'),
    Line2D([0], [0], color='black', linestyle='--', label='FGSM Movement'),
    Line2D([0], [0], color='black', linestyle='-', label='PGD Movement')
]

plt.legend(handles=legend_elements, loc='upper right')
plt.xlabel('Latent Dimension 1')
plt.ylabel('Latent Dimension 2')
plt.title('Latent Space Perturbations from Adversarial Attacks')
plt.grid(True, alpha=0.3)
plt.show()
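
The plot shows the direction of each shift; a per-sample distance puts a number on it. A minimal sketch, assuming latent means are stored as `(batch, latent_dim)` tensors as in the cell above; `latent_shift` is a hypothetical helper and the values below are synthetic.

```python
import torch

def latent_shift(mu_clean, mu_attacked):
    """Per-sample Euclidean distance between clean and attacked latent means."""
    return torch.linalg.norm(mu_attacked - mu_clean, dim=1)

# Synthetic 2D latent means standing in for mu_orig and mu_fgsm
mu_clean = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
mu_attacked = torch.tensor([[3.0, 4.0], [1.0, 1.0]])

shifts = latent_shift(mu_clean, mu_attacked)
print(shifts)  # tensor([5., 0.])
```

Unlike the batch-level MSE printed by `evaluate_attack_effectiveness`, this keeps one value per image, which makes it easier to spot individual samples that an attack moves unusually far.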

## 10. Defense Mechanisms (Bonus)

Here are some strategies to improve VAE robustness against adversarial attacks:

In [None]:
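def adversarial_training_step(model, data, optimizer, epsilon=0.1, alpha=0.01):
    """
    Single step of adversarial training: one optimizer update on an equal mix
    of clean and FGSM-perturbed data. Returns the combined loss as a float.
    Note: alpha is not used by the single-step FGSM attack below.
    """
    # Generate adversarial examples in eval mode, then detach them so the
    # training loss does not backpropagate through the attack
    model.eval()
    adv_data = attacks.fgsm_attack(model, data, data, epsilon).detach()
    model.train()
    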
    # Train on both clean and adversarial data
    optimizer.zero_grad()
    
    # Clean loss
    recon_clean, mu_clean, logvar_clean = model(data)
    clean_loss = vae_loss(recon_clean, data, mu_clean, logvar_clean)
    
    # Adversarial loss
    recon_adv, mu_adv, logvar_adv = model(adv_data)
    adv_loss = vae_loss(recon_adv, adv_data, mu_adv, logvar_adv)
    
    # Combined loss
    total_loss = 0.5 * clean_loss + 0.5 * adv_loss
    
    total_loss.backward()
    optimizer.step()
    
    return total_loss.item()

print("Defense Strategies for VAEs:")
print("1. Adversarial Training: Train on both clean and adversarial examples")
print("2. Input Preprocessing: Add noise or apply transformations")
print("3. Regularization: Increase β in β-VAE to enforce stronger regularization")
print("4. Ensemble Methods: Use multiple VAE models and average their reconstructions")
print("5. Certified Defenses: Use techniques like randomized smoothing")

# Example: train a model with higher β (stronger KL regularization) and check
# whether it is more robust than the baseline
robust_model = VAE(latent_dim=2).to(device)
print("\nTraining a VAE with β=5.0...")
robust_losses = train_vae(robust_model, train_loader, epochs=5, beta=5.0)

# Attack the new model with the same FGSM budget and compare the printed
# metrics against the baseline FGSM results from earlier
print("\nTesting robustness of β-VAE:")
robust_fgsm = attacks.fgsm_attack(robust_model, test_data, test_data, epsilon=0.1)
evaluate_attack_effectiveness(robust_model, test_data, robust_fgsm, "β-VAE (β=5.0) FGSM")
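
The `beta` argument passed to `train_vae` presumably weights the KL term of the loss. A minimal sketch of that objective, assuming a binary cross-entropy reconstruction term and a diagonal Gaussian posterior; the notebook's own `vae_loss` may use a different reconstruction term or reduction, and `beta_vae_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Return (total, reconstruction, KL) with total = reconstruction + beta * KL."""
    recon = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl, recon, kl

# Synthetic batch: 2 samples, 4 "pixels", 2 latent dimensions
x = torch.full((2, 4), 0.5)
recon_x = torch.full((2, 4), 0.5)
mu = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
logvar = torch.zeros(2, 2)

total_1, recon, kl = beta_vae_loss(recon_x, x, mu, logvar, beta=1.0)
total_5, _, _ = beta_vae_loss(recon_x, x, mu, logvar, beta=5.0)
print(f"KL = {kl.item():.2f}, extra penalty at beta=5: {(total_5 - total_1).item():.2f}")
```

Raising β leaves the reconstruction term unchanged and multiplies the penalty for latent codes that drift from the prior, which is the mechanism the β=5.0 experiment relies on; whether that actually improves robustness has to be read from the attack metrics.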

## Summary

This notebook demonstrated several key concepts in adversarial attacks on VAEs:

### Attack Methods:
1. **FGSM (Fast Gradient Sign Method)**: Single-step attack using gradient sign
2. **PGD (Projected Gradient Descent)**: Multi-step iterative attack
3. **Latent Space Attack**: Perturbations in the encoded latent representation

### Key Findings:
- VAEs are vulnerable to adversarial perturbations in both input and latent spaces
- Small input perturbations can cause significant changes in latent representations
- The 2D latent space makes visualization of attack effects possible
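- Attack strength depends on the method: at the same ε, iterative PGD is typically at least as damaging as single-step FGSM, and the robustness curves in Section 8 show how large the gap is for this model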

### Defense Strategies:
- Adversarial training with mixed clean/adversarial data
- Stronger regularization (higher β in β-VAE)
- Input preprocessing and ensemble methods
- Certified defense techniques
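
Of the defenses listed, input preprocessing with noise is simple to prototype. A minimal sketch, assuming the defended model maps a batch in [0, 1] to a reconstruction of the same shape: average the reconstructions of several Gaussian-noised copies of the input. `noise_averaged_reconstruction`, `sigma`, and `n_samples` are hypothetical names and values, the toy autoencoder stands in for the trained VAE, and unlike true randomized smoothing this provides no certificate.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def noise_averaged_reconstruction(reconstruct, x, sigma=0.1, n_samples=16):
    """Average reconstructions over Gaussian-noised copies of x (hypothetical helper)."""
    total = torch.zeros_like(x)
    for _ in range(n_samples):
        # Perturb the input with fresh noise and keep it in the valid pixel range
        noisy = (x + sigma * torch.randn_like(x)).clamp(0.0, 1.0)
        total += reconstruct(noisy)
    return total / n_samples

torch.manual_seed(0)
# Toy autoencoder with a 2D bottleneck, standing in for the trained VAE
toy_autoencoder = nn.Sequential(nn.Linear(16, 2), nn.Tanh(), nn.Linear(2, 16), nn.Sigmoid())
toy_autoencoder.eval()

x = torch.rand(4, 16)
smoothed = noise_averaged_reconstruction(toy_autoencoder, x)
print(smoothed.shape)
```

With the notebook's VAE, whose forward pass returns `(recon, mu, logvar)`, the first argument could be `lambda batch: model(batch)[0]`.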

This framework can be extended to other autoencoder architectures and datasets to study adversarial robustness in generative models.