# Adversarial Attacks on Variational Autoencoders

This notebook demonstrates how to engineer adversarial attacks against a VAE trained on the MNIST dataset, using a LeNet-style encoder/decoder and a 2D latent space.

## Key Concepts:
- **FGSM (Fast Gradient Sign Method)**: Single-step attack using gradient sign
- **PGD (Projected Gradient Descent)**: Multi-step iterative attack
- **Latent Space Attack**: Attack in the encoded latent representation
- **VAE Vulnerabilities**: How reconstruction and regularization losses affect robustness
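
For reference, the two gradient-based attacks listed above can be written as update rules. This is a sketch of the standard formulation, assuming $\mathcal{L}$ denotes the VAE loss, $\epsilon$ the perturbation budget, $\alpha$ the PGD step size, and pixel values constrained to $[0, 1]$:

$$
x_{\text{adv}} = \operatorname{clip}_{[0,1]}\big(x + \epsilon \cdot \operatorname{sign}(\nabla_x \mathcal{L}(x))\big)
$$

$$
\delta_{t+1} = \operatorname{clip}_{[-\epsilon,\,\epsilon]}\big(\delta_t + \alpha \cdot \operatorname{sign}(\nabla_\delta \mathcal{L}(x + \delta_t))\big)
$$

In PGD, $x + \delta_{t+1}$ is additionally clipped to $[0, 1]$ after every step, so the final adversarial input is $x_{\text{adv}} = x + \delta_T$.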

In [None]:
# Install required packages
!pip install torch torchvision numpy matplotlib tqdm psutil

# Install GPU monitoring tools (optional - will fallback to nvidia-smi if not available)
# A failed `!pip install` does not raise a Python exception, so verify the import instead
!pip install nvidia-ml-py
try:
    import pynvml
    print("✓ nvidia-ml-py installed for efficient GPU monitoring")
except ImportError:
    print("⚠ nvidia-ml-py not available, will use nvidia-smi fallback")

# Check if nvidia-smi is available
import subprocess
try:
    result = subprocess.run(['nvidia-smi', '--version'], capture_output=True, text=True)
    if result.returncode == 0:
        print("✓ nvidia-smi available for GPU monitoring")
    else:
        print("⚠ nvidia-smi not available - GPU monitoring will show zeros")
except FileNotFoundError:
    print("⚠ nvidia-smi not found - GPU monitoring will show zeros")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# GPU Monitoring utilities
import subprocess
import time
import threading
from collections import deque
import psutil

# For real-time plotting
from IPython.display import clear_output
import matplotlib.animation as animation

class GPUMonitor:
    """Real-time GPU monitoring during training"""
    
    def __init__(self, max_points=100):
        self.max_points = max_points
        self.gpu_utilization = deque(maxlen=max_points)
        self.gpu_memory = deque(maxlen=max_points)
        self.gpu_temperature = deque(maxlen=max_points)
        self.timestamps = deque(maxlen=max_points)
        self.monitoring = False
        self.monitor_thread = None
        
    def get_gpu_stats(self):
        """Get GPU statistics using nvidia-ml-py or nvidia-smi"""
        try:
            # Try using nvidia-ml-py first (more efficient)
            import pynvml
            pynvml.nvmlInit()
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            
            # Get utilization
            util = pynvml.nvmlDeviceGetUtilizationRates(handle)
            gpu_util = util.gpu
            
            # Get memory info
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            memory_used = (mem_info.used / mem_info.total) * 100
            
            # Get temperature
            temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            
            return gpu_util, memory_used, temp
            
        except Exception:
            # pynvml is missing or an NVML call failed (e.g. no NVIDIA driver):
            # fall back to the nvidia-smi command
            try:
                result = subprocess.run([
                    'nvidia-smi', '--query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu',
                    '--format=csv,noheader,nounits'
                ], capture_output=True, text=True, timeout=5)
                
                if result.returncode == 0:
                    values = result.stdout.strip().split(',')
                    gpu_util = float(values[0])
                    memory_used = (float(values[1]) / float(values[2])) * 100
                    temp = float(values[3])
                    return gpu_util, memory_used, temp
                    
            except (subprocess.TimeoutExpired, FileNotFoundError, ValueError, IndexError):
                pass
                
        # Return zeros if no GPU monitoring available
        return 0, 0, 0
    
    def monitor_loop(self, interval=1.0):
        """Background monitoring loop"""
        while self.monitoring:
            gpu_util, memory_used, temp = self.get_gpu_stats()
            current_time = time.time()
            
            self.gpu_utilization.append(gpu_util)
            self.gpu_memory.append(memory_used)
            self.gpu_temperature.append(temp)
            self.timestamps.append(current_time)
            
            time.sleep(interval)
    
    def start_monitoring(self, interval=1.0):
        """Start GPU monitoring in background thread"""
        if not self.monitoring:
            self.monitoring = True
            self.monitor_thread = threading.Thread(target=self.monitor_loop, args=(interval,))
            self.monitor_thread.daemon = True
            self.monitor_thread.start()
            print("GPU monitoring started...")
    
    def stop_monitoring(self):
        """Stop GPU monitoring"""
        self.monitoring = False
        if self.monitor_thread:
            self.monitor_thread.join(timeout=2)
        print("GPU monitoring stopped.")
    
    def plot_stats(self, figsize=(15, 10)):
        """Plot current GPU statistics"""
        if len(self.timestamps) == 0:
            print("No monitoring data available. Start monitoring first.")
            return
            
        fig, axes = plt.subplots(2, 2, figsize=figsize)
        
        # Convert timestamps to relative time
        if len(self.timestamps) > 0:
            start_time = self.timestamps[0]
            times = [(t - start_time) / 60 for t in self.timestamps]  # Convert to minutes
        else:
            times = []
        
        # GPU Utilization
        axes[0, 0].plot(times, list(self.gpu_utilization), 'b-', linewidth=2)
        axes[0, 0].set_title('GPU Utilization (%)')
        axes[0, 0].set_xlabel('Time (minutes)')
        axes[0, 0].set_ylabel('Utilization (%)')
        axes[0, 0].grid(True, alpha=0.3)
        axes[0, 0].set_ylim(0, 100)
        
        # GPU Memory Usage
        axes[0, 1].plot(times, list(self.gpu_memory), 'r-', linewidth=2)
        axes[0, 1].set_title('GPU Memory Usage (%)')
        axes[0, 1].set_xlabel('Time (minutes)')
        axes[0, 1].set_ylabel('Memory (%)')
        axes[0, 1].grid(True, alpha=0.3)
        axes[0, 1].set_ylim(0, 100)
        
        # GPU Temperature
        axes[1, 0].plot(times, list(self.gpu_temperature), 'g-', linewidth=2)
        axes[1, 0].set_title('GPU Temperature (°C)')
        axes[1, 0].set_xlabel('Time (minutes)')
        axes[1, 0].set_ylabel('Temperature (°C)')
        axes[1, 0].grid(True, alpha=0.3)
        
        # Summary statistics
        if len(self.gpu_utilization) > 0:
            avg_util = sum(self.gpu_utilization) / len(self.gpu_utilization)
            max_util = max(self.gpu_utilization)
            avg_mem = sum(self.gpu_memory) / len(self.gpu_memory)
            max_mem = max(self.gpu_memory)
            avg_temp = sum(self.gpu_temperature) / len(self.gpu_temperature)
            max_temp = max(self.gpu_temperature)
            
            summary_text = f"""GPU Statistics Summary:
            
Average Utilization: {avg_util:.1f}%
Peak Utilization: {max_util:.1f}%

Average Memory: {avg_mem:.1f}%
Peak Memory: {max_mem:.1f}%

Average Temperature: {avg_temp:.1f}°C
Peak Temperature: {max_temp:.1f}°C

Monitoring Duration: {times[-1]:.1f} minutes"""
            
            axes[1, 1].text(0.05, 0.95, summary_text, transform=axes[1, 1].transAxes,
                           verticalalignment='top', fontfamily='monospace', fontsize=10)
            axes[1, 1].set_xlim(0, 1)
            axes[1, 1].set_ylim(0, 1)
            axes[1, 1].axis('off')
        
        plt.tight_layout()
        plt.show()

# Initialize GPU monitor
gpu_monitor = GPUMonitor()

# Check if GPU is available and print initial stats
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    
    # Get initial GPU stats
    initial_util, initial_mem, initial_temp = gpu_monitor.get_gpu_stats()
    print(f"Initial GPU Utilization: {initial_util}%")
    print(f"Initial GPU Memory Usage: {initial_mem:.1f}%")
    print(f"Initial GPU Temperature: {initial_temp}°C")
else:
    print("No GPU available - monitoring will show zeros")
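
The `nvidia-smi` fallback above relies on parsing a single CSV line of the form `utilization, memory used, memory total, temperature`. The following is a minimal standalone sketch of that parsing step, so it can be checked without a GPU. The helper name `parse_gpu_csv_line` and the sample line are hypothetical and not part of the notebook:

```python
def parse_gpu_csv_line(line):
    """Parse one CSV line of the form 'util, mem_used, mem_total, temp'.

    Returns (utilization %, memory used %, temperature in degrees C).
    """
    fields = [field.strip() for field in line.strip().split(",")]
    if len(fields) != 4:
        raise ValueError(f"expected 4 fields, got {len(fields)}: {line!r}")
    gpu_util, mem_used, mem_total, temp = (float(field) for field in fields)
    memory_percent = (mem_used / mem_total) * 100
    return gpu_util, memory_percent, temp


# Hypothetical sample line, not output captured from a real GPU
sample_line = "45, 2048, 8192, 61"
print(parse_gpu_csv_line(sample_line))
```

Raising `ValueError` on malformed input matches the exception that the fallback branch in `get_gpu_stats` already catches.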

## 1. Define VAE Architecture with LeNet-style Encoder/Decoder

In [None]:
class LeNetEncoder(nn.Module):
    """LeNet-style encoder for VAE"""
    def __init__(self, latent_dim=2):
        super(LeNetEncoder, self).__init__()
        self.latent_dim = latent_dim
        
        # Convolutional layers (LeNet-style)
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)  # 28x28 -> 28x28
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 14x14 -> 10x10
        self.pool2 = nn.MaxPool2d(2, 2)  # 10x10 -> 5x5
        
        # Fully connected layers
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        
        # Output layers for mean and log variance
        self.fc_mu = nn.Linear(84, latent_dim)
        self.fc_logvar = nn.Linear(84, latent_dim)
        
    def forward(self, x):
        # Convolutional layers
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        
        # Flatten
        x = x.view(-1, 16 * 5 * 5)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        # Output mean and log variance
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        
        return mu, logvar

In [None]:
class LeNetDecoder(nn.Module):
    """LeNet-style decoder for VAE"""
    def __init__(self, latent_dim=2):
        super(LeNetDecoder, self).__init__()
        self.latent_dim = latent_dim
        
        # Fully connected layers
        self.fc1 = nn.Linear(latent_dim, 84)
        self.fc2 = nn.Linear(84, 120)
        self.fc3 = nn.Linear(120, 16 * 5 * 5)
        
        # Transposed convolutional layers (reverse of encoder)
        self.deconv1 = nn.ConvTranspose2d(16, 6, kernel_size=5, stride=2, padding=2, output_padding=1)  # 5x5 -> 10x10
        self.deconv2 = nn.ConvTranspose2d(6, 1, kernel_size=5, stride=2, padding=2, output_padding=1)   # 10x10 -> 20x20
        # Final transposed convolution (kernel 9, stride 1) grows 20x20 to 28x28
        self.final_conv = nn.ConvTranspose2d(1, 1, kernel_size=9, stride=1, padding=0)  # 20x20 -> 28x28
        
    def forward(self, z):
        # Fully connected layers
        x = F.relu(self.fc1(z))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        
        # Reshape to feature maps
        x = x.view(-1, 16, 5, 5)
        
        # Transposed convolutional layers
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = torch.sigmoid(self.final_conv(x))
        
        return x
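
The shape comments in the encoder and decoder can be verified with the standard output-size formulas for convolution, pooling, and transposed convolution (dilation 1). The sketch below is plain Python; the helper names `conv_out` and `conv_transpose_out` are illustrative and not part of the notebook:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of Conv2d / MaxPool2d along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output size of ConvTranspose2d along one spatial dimension (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Encoder path: 28 -> 28 -> 14 -> 10 -> 5
s = conv_out(28, kernel=5, padding=2)            # conv1
s = conv_out(s, kernel=2, stride=2)              # pool1
s = conv_out(s, kernel=5)                        # conv2
encoder_size = conv_out(s, kernel=2, stride=2)   # pool2

# Decoder path: 5 -> 10 -> 20 -> 28
d = conv_transpose_out(encoder_size, kernel=5, stride=2, padding=2, output_padding=1)
d = conv_transpose_out(d, kernel=5, stride=2, padding=2, output_padding=1)
decoder_size = conv_transpose_out(d, kernel=9)

print(encoder_size, decoder_size)  # 5 28
```

The decoder is therefore not an exact mirror of the encoder: it reaches 28x28 through two stride-2 transposed convolutions followed by a 9x9 transposed convolution.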

In [None]:
class VAE(nn.Module):
    """Variational Autoencoder with LeNet-style architecture"""
    def __init__(self, latent_dim=2):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = LeNetEncoder(latent_dim)
        self.decoder = LeNetDecoder(latent_dim)
        
    def reparameterize(self, mu, logvar):
        """Reparameterization trick"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar
    
    def encode(self, x):
        """Encode input to latent space"""
        mu, logvar = self.encoder(x)
        return self.reparameterize(mu, logvar)
    
    def decode(self, z):
        """Decode from latent space"""
        return self.decoder(z)

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """VAE loss function with KL divergence"""
    # Reconstruction loss
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_loss + beta * kl_loss
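
The KL term in `vae_loss` is the closed-form divergence between the diagonal Gaussian posterior and a standard normal prior, summed over latent dimensions and batch elements. As a sanity check, the per-dimension expression can be compared against a numerical integral. This is a standalone sketch using only the standard library; `kl_to_standard_normal` and `kl_numeric` are hypothetical helper names:

```python
import math


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL( N(mu, exp(logvar)) || N(0, 1) ) for one latent dimension."""
    return -0.5 * (1 + logvar - mu ** 2 - math.exp(logvar))


def kl_numeric(mu, logvar, lo=-12.0, hi=12.0, steps=40000):
    """Midpoint-rule estimate of the same KL, used as an independent check."""
    var = math.exp(logvar)
    dx = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        x = lo + (i + 0.5) * dx
        log_q = -0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)
        log_p = -0.5 * (math.log(2 * math.pi) + x ** 2)
        total += math.exp(log_q) * (log_q - log_p) * dx
    return total


print(kl_to_standard_normal(0.0, 0.0))   # posterior equals the prior
print(kl_to_standard_normal(1.0, 0.0))   # unit variance, mean shifted by 1
print(kl_to_standard_normal(0.5, -1.0), kl_numeric(0.5, -1.0))
```

The divergence is zero only when `mu = 0` and `logvar = 0`, which is why a larger `beta` pulls encodings toward the origin of the latent space.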

## 2. Load MNIST Dataset

In [None]:
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Visualize some samples
def show_samples(loader, num_samples=8):
    data_iter = iter(loader)
    images, labels = next(data_iter)
    
    fig, axes = plt.subplots(1, num_samples, figsize=(12, 2))
    for i in range(num_samples):
        axes[i].imshow(images[i].squeeze(), cmap='gray')
        axes[i].set_title(f'Label: {labels[i]}')
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()

show_samples(train_loader)

## 3. Train the VAE Model

In [None]:
def train_vae(model, train_loader, epochs=10, lr=1e-3, beta=1.0, monitor_gpu=True):
    """Train the VAE model with optional GPU monitoring"""
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    model.train()
    train_losses = []
    
    # Start GPU monitoring if requested and GPU is available
    if monitor_gpu and torch.cuda.is_available():
        gpu_monitor.start_monitoring(interval=0.5)  # Monitor every 0.5 seconds
    
    print("Training VAE...")
    try:
        for epoch in tqdm(range(epochs)):
            epoch_loss = 0
            for batch_idx, (data, _) in enumerate(train_loader):
                data = data.to(device)
                optimizer.zero_grad()
                
                recon_batch, mu, logvar = model(data)
                loss = vae_loss(recon_batch, data, mu, logvar, beta)
                
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item()
                
                # Optional: Print GPU stats every 100 batches
                if monitor_gpu and torch.cuda.is_available() and batch_idx % 100 == 0:
                    gpu_util, gpu_mem, gpu_temp = gpu_monitor.get_gpu_stats()
                    if batch_idx == 0 and epoch == 0:  # Print header once
                        print(f"\nBatch | GPU Util | GPU Mem | GPU Temp | Loss")
                        print("-" * 50)
                    if epoch % 2 == 0:  # Only print for even epochs to reduce clutter
                        print(f"{batch_idx:5d} | {gpu_util:7.1f}% | {gpu_mem:6.1f}% | {gpu_temp:7.1f}°C | {loss.item():.4f}")
            
            avg_loss = epoch_loss / len(train_loader.dataset)
            train_losses.append(avg_loss)
            
            if epoch % 2 == 0:
                print(f'Epoch {epoch}, Average Loss: {avg_loss:.4f}')
                
    finally:
        # Stop monitoring when training is done or interrupted
        if monitor_gpu and torch.cuda.is_available():
            gpu_monitor.stop_monitoring()
    
    return train_losses

# Initialize the model
model = VAE(latent_dim=2)

# Optional: train without GPU monitoring (the next cell trains with monitoring enabled)
# train_losses = train_vae(model, train_loader, epochs=10, beta=1.0, monitor_gpu=False)

In [None]:
# Initialize and train model with GPU monitoring
model = VAE(latent_dim=2)

# Train with GPU monitoring enabled
train_losses = train_vae(model, train_loader, epochs=10, beta=1.0, monitor_gpu=True)

# Plot training loss and GPU statistics
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Training loss
axes[0, 0].plot(train_losses, 'b-', linewidth=2)
axes[0, 0].set_title('VAE Training Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].grid(True)

# Show GPU monitoring results if available
if torch.cuda.is_available() and len(gpu_monitor.timestamps) > 0:
    # Convert timestamps to relative time in minutes
    start_time = gpu_monitor.timestamps[0]
    times = [(t - start_time) / 60 for t in gpu_monitor.timestamps]
    
    # GPU Utilization
    axes[0, 1].plot(times, list(gpu_monitor.gpu_utilization), 'r-', linewidth=2)
    axes[0, 1].set_title('GPU Utilization During Training')
    axes[0, 1].set_xlabel('Time (minutes)')
    axes[0, 1].set_ylabel('GPU Utilization (%)')
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].set_ylim(0, 100)
    
    # GPU Memory Usage
    axes[1, 0].plot(times, list(gpu_monitor.gpu_memory), 'g-', linewidth=2)
    axes[1, 0].set_title('GPU Memory Usage During Training')
    axes[1, 0].set_xlabel('Time (minutes)')
    axes[1, 0].set_ylabel('GPU Memory (%)')
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].set_ylim(0, 100)
    
    # GPU Temperature
    axes[1, 1].plot(times, list(gpu_monitor.gpu_temperature), 'orange', linewidth=2)
    axes[1, 1].set_title('GPU Temperature During Training')
    axes[1, 1].set_xlabel('Time (minutes)')
    axes[1, 1].set_ylabel('Temperature (°C)')
    axes[1, 1].grid(True, alpha=0.3)
    
    # Print summary statistics
    if len(gpu_monitor.gpu_utilization) > 0:
        print(f"\n📊 GPU Training Statistics:")
        print(f"Average GPU Utilization: {sum(gpu_monitor.gpu_utilization)/len(gpu_monitor.gpu_utilization):.1f}%")
        print(f"Peak GPU Utilization: {max(gpu_monitor.gpu_utilization):.1f}%")
        print(f"Average GPU Memory: {sum(gpu_monitor.gpu_memory)/len(gpu_monitor.gpu_memory):.1f}%")
        print(f"Peak GPU Memory: {max(gpu_monitor.gpu_memory):.1f}%")
        print(f"Average GPU Temperature: {sum(gpu_monitor.gpu_temperature)/len(gpu_monitor.gpu_temperature):.1f}°C")
        print(f"Peak GPU Temperature: {max(gpu_monitor.gpu_temperature):.1f}°C")
        print(f"Training Duration: {times[-1]:.2f} minutes")
else:
    # Hide unused subplots if no GPU monitoring data
    for i in range(1, 4):
        axes.flat[i].set_visible(False)
    axes[0, 0].set_position([0.1, 0.1, 0.8, 0.8])  # Make loss plot larger
    print("⚠ No GPU monitoring data available")

plt.tight_layout()
plt.show()

In [None]:
# 🚀 Additional GPU Monitoring Utilities

def watch_gpu_realtime(duration_seconds=30, interval=1.0):
    """
    Real-time GPU monitoring for specified duration
    Useful for monitoring GPU during training in another cell
    """
    print(f"🔍 Monitoring GPU for {duration_seconds} seconds...")
    print("Time(s) | GPU Util | GPU Mem | GPU Temp")
    print("-" * 40)
    
    start_time = time.time()
    while time.time() - start_time < duration_seconds:
        current_time = time.time() - start_time
        gpu_util, gpu_mem, gpu_temp = gpu_monitor.get_gpu_stats()
        print(f"{current_time:6.1f} | {gpu_util:7.1f}% | {gpu_mem:6.1f}% | {gpu_temp:7.1f}°C")
        time.sleep(interval)

def get_gpu_memory_usage():
    """Get detailed GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9  # GB
        cached = torch.cuda.memory_reserved() / 1e9      # GB
        total = torch.cuda.get_device_properties(0).total_memory / 1e9  # GB
        
        print(f"💾 GPU Memory Usage:")
        print(f"  Allocated: {allocated:.2f} GB")
        print(f"  Cached:    {cached:.2f} GB") 
        print(f"  Total:     {total:.2f} GB")
        print(f"  Usage:     {(allocated/total)*100:.1f}%")
        
        return allocated, cached, total
    else:
        print("❌ No GPU available")
        return 0, 0, 0

def profile_model_memory(model, input_shape=(1, 1, 28, 28)):
    """Profile memory usage of model forward pass"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Clear cache
        torch.cuda.reset_peak_memory_stats()
        
        # Memory allocated before the forward pass (the model is already loaded)
        baseline = torch.cuda.memory_allocated()
        
        # Create dummy input
        dummy_input = torch.randn(input_shape).to(device)
        
        # Forward pass
        with torch.no_grad():
            _ = model(dummy_input)
        
        peak_memory = torch.cuda.max_memory_allocated()
        current_memory = torch.cuda.memory_allocated()
        
        print(f"🧠 Model Memory Profile:")
        print(f"  Baseline:     {baseline/1e6:.1f} MB")
        print(f"  After Forward: {current_memory/1e6:.1f} MB")
        print(f"  Peak Usage:   {peak_memory/1e6:.1f} MB")
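        # Memory retained after the forward pass (mostly the dummy input), not the model's parameter size
        print(f"  Net Increase: {(current_memory - baseline)/1e6:.1f} MB")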
        
        return (current_memory - baseline) / 1e6
    else:
        print("❌ No GPU available for profiling")
        return 0

# Show current GPU status
print("🖥️  Current GPU Status:")
get_gpu_memory_usage()

# Usage examples:
print("\n📖 GPU Monitoring Usage:")
print("1. gpu_monitor.start_monitoring()     # Start background monitoring")
print("2. # Run your training here...")
print("3. gpu_monitor.stop_monitoring()      # Stop monitoring") 
print("4. gpu_monitor.plot_stats()           # Plot monitoring results")
print("5. watch_gpu_realtime(30)             # Watch GPU for 30 seconds")
print("6. profile_model_memory(model)        # Profile model memory usage")

## 4. Visualize Latent Space (2D)

In [None]:
def plot_latent_space(model, test_loader, device, num_samples=2000):
    """Plot the 2D latent space representation"""
    model.eval()
    latents = []
    labels = []
    
    with torch.no_grad():
        for data, label in test_loader:
            data = data.to(device)
            mu, _ = model.encoder(data)
            latents.append(mu.cpu().numpy())
            labels.append(label.numpy())
            
            if len(latents) * data.size(0) >= num_samples:
                break
    
    latents = np.concatenate(latents, axis=0)[:num_samples]
    labels = np.concatenate(labels, axis=0)[:num_samples]
    
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(latents[:, 0], latents[:, 1], c=labels, cmap='tab10', alpha=0.6)
    plt.colorbar(scatter)
    plt.xlabel('Latent Dimension 1')
    plt.ylabel('Latent Dimension 2')
    plt.title('2D Latent Space Representation (Color = Digit Class)')
    plt.grid(True)
    plt.show()
    
    return latents, labels

latents, labels = plot_latent_space(model, test_loader, device)

## 5. Implement Adversarial Attacks

In [None]:
class AdversarialAttacks:
    """Class containing various adversarial attack methods for VAEs"""
    
    @staticmethod
    def fgsm_attack(model, data, target, epsilon):
        """
        Fast Gradient Sign Method (FGSM) attack
        
        Args:
            model: VAE model
            data: input data
            target: target data (for reconstruction loss)
            epsilon: perturbation magnitude
        """
        # Set model to evaluation mode
        model.eval()
        
        # Work on a detached copy so the caller's tensor is not modified
        # (target may be the same tensor and must stay outside the graph)
        data = data.clone().detach().requires_grad_(True)
        
        # Forward pass
        recon_data, mu, logvar = model(data)
        
        # Calculate loss
        loss = vae_loss(recon_data, target, mu, logvar)
        
        # Zero gradients
        model.zero_grad()
        
        # Calculate gradients
        loss.backward()
        
        # Get gradient sign
        sign_data_grad = data.grad.sign()
        
        # Create adversarial example, clipped to the valid pixel range
        perturbed_data = data + epsilon * sign_data_grad
        perturbed_data = torch.clamp(perturbed_data, 0, 1)
        
        return perturbed_data.detach()
    
    @staticmethod
    def pgd_attack(model, data, target, epsilon, alpha, num_iter):
        """
        Projected Gradient Descent (PGD) attack
        
        Args:
            model: VAE model
            data: input data
            target: target data
            epsilon: maximum perturbation
            alpha: step size
            num_iter: number of iterations
        """
        model.eval()
        
        # Detach inputs so gradients flow only into the perturbation
        data = data.detach()
        target = target.detach()
        
        # Initialize perturbation with a random start inside the epsilon ball
        delta = torch.zeros_like(data).uniform_(-epsilon, epsilon)
        delta.requires_grad = True
        
        for i in range(num_iter):
            # Forward pass with perturbation
            perturbed_data = data + delta
            recon_data, mu, logvar = model(perturbed_data)
            
            # Calculate loss
            loss = vae_loss(recon_data, target, mu, logvar)
            
            # Calculate gradients
            loss.backward()
            
            # Ascend the loss, then project onto the epsilon ball and the [0, 1] pixel range
            delta.data = delta.data + alpha * delta.grad.data.sign()
            delta.data = torch.clamp(delta.data, -epsilon, epsilon)
            delta.data = torch.clamp(data + delta.data, 0, 1) - data
            
            # Zero gradients
            delta.grad.zero_()
        
        return (data + delta).detach()
    
    @staticmethod
    def latent_space_attack(model, data, epsilon, target_latent=None):
        """
        Attack in latent space by perturbing encoded representations
        
        Args:
            model: VAE model
            data: input data
            epsilon: perturbation magnitude in latent space
            target_latent: target latent representation (optional)
        """
        model.eval()
        
        # Encode to latent space
        mu, logvar = model.encoder(data)
        z = model.reparameterize(mu, logvar)
        
        if target_latent is not None:
            # Move towards target latent representation
            direction = (target_latent - z).sign()
            perturbed_z = z + epsilon * direction
        else:
            # Random perturbation in latent space
            noise = torch.randn_like(z)
            perturbed_z = z + epsilon * noise
        
        # Decode back to image space
        adversarial_recon = model.decoder(perturbed_z)
        
        return adversarial_recon, z, perturbed_z
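
To see the FGSM and PGD update rules in isolation, the sketch below applies them to a toy convex loss with an analytic gradient, using NumPy only. It is an illustration under stated assumptions, not the notebook's attack on the VAE: the linear "model" `W`, the input `x`, and the target `t` are hypothetical, and PGD starts from a zero perturbation here rather than a random one.

```python
import numpy as np


def loss_and_grad(x, W, t):
    """Toy convex loss 0.5 * ||W x - t||^2 and its gradient with respect to x."""
    residual = W @ x - t
    return 0.5 * float(residual @ residual), W.T @ residual


def fgsm(x, W, t, epsilon):
    """Single signed-gradient step, clipped to the valid pixel range [0, 1]."""
    _, grad = loss_and_grad(x, W, t)
    return np.clip(x + epsilon * np.sign(grad), 0.0, 1.0)


def pgd(x, W, t, epsilon, alpha, num_iter):
    """Iterated signed-gradient steps, projected onto the L-infinity ball and [0, 1]."""
    delta = np.zeros_like(x)  # zero start (assumption for this sketch)
    for _ in range(num_iter):
        _, grad = loss_and_grad(x + delta, W, t)
        delta = np.clip(delta + alpha * np.sign(grad), -epsilon, epsilon)
        delta = np.clip(x + delta, 0.0, 1.0) - x
    return x + delta


rng = np.random.default_rng(0)
W = rng.normal(size=(5, 8))            # hypothetical linear "model"
x = rng.uniform(0.0, 1.0, size=8)      # hypothetical input in [0, 1]
t = W @ x + 0.1 * rng.normal(size=5)   # hypothetical target near the clean output

clean_loss, _ = loss_and_grad(x, W, t)
fgsm_loss, _ = loss_and_grad(fgsm(x, W, t, 0.1), W, t)
pgd_loss, _ = loss_and_grad(pgd(x, W, t, 0.1, 0.01, 20), W, t)
print(f"clean={clean_loss:.4f}  fgsm={fgsm_loss:.4f}  pgd={pgd_loss:.4f}")
```

Because this toy loss is convex, neither attack can decrease it. The VAE loss is not convex, so the same guarantee does not carry over to the real model.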

## 6. Demonstrate Adversarial Attacks

In [None]:
# Get test samples for attacks
test_iter = iter(test_loader)
test_data, test_labels = next(test_iter)
test_data = test_data[:8].to(device)  # Use first 8 samples

# Initialize attack methods
attacks = AdversarialAttacks()

# Show original images
fig, axes = plt.subplots(1, 8, figsize=(16, 2))
for i in range(8):
    axes[i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[i].set_title(f'Original {i}')
    axes[i].axis('off')
plt.suptitle('Original Test Images')
plt.tight_layout()
plt.show()

In [None]:
# FGSM Attack - Comprehensive Analysis
print("\n=== FGSM Attack - Complete Pipeline ===")
epsilon_fgsm = 0.1
fgsm_adversarial = attacks.fgsm_attack(model, test_data, test_data, epsilon=epsilon_fgsm)

# Get reconstructions
with torch.no_grad():
    original_recon, _, _ = model(test_data)
    adversarial_recon, _, _ = model(fgsm_adversarial)

# Create comprehensive visualization
fig, axes = plt.subplots(6, 8, figsize=(16, 12))
row_labels = [
    'Original Input',
    'Original Reconstruction', 
    'Adversarial Input',
    'Adversarial Reconstruction',
    'Input Difference (x10)',
    'Reconstruction Difference (x10)'
]

for i in range(8):
    # Row 1: Original inputs
    axes[0, i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[0, i].set_title(f'Sample {i}')
    axes[0, i].axis('off')
    
    # Row 2: Original reconstructions
    axes[1, i].imshow(original_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[1, i].axis('off')
    
    # Row 3: Adversarial inputs (epsilon-perturbed)
    axes[2, i].imshow(fgsm_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[2, i].axis('off')
    
    # Row 4: Adversarial reconstructions
    axes[3, i].imshow(adversarial_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[3, i].axis('off')
    
    # Row 5: Input differences (original vs adversarial, amplified)
    input_diff = (fgsm_adversarial[i] - test_data[i]).detach().cpu().squeeze()
    axes[4, i].imshow(input_diff * 10, cmap='RdBu', vmin=-1, vmax=1)
    axes[4, i].axis('off')
    
    # Row 6: Reconstruction differences (original recon vs adversarial recon, amplified)
    recon_diff = (adversarial_recon[i] - original_recon[i]).detach().cpu().squeeze()
    axes[5, i].imshow(recon_diff * 10, cmap='RdBu', vmin=-1, vmax=1)
    axes[5, i].axis('off')

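# Add row labels (axis('off') hides y-axis labels, so draw them as text instead)
for i, label in enumerate(row_labels):
    axes[i, 0].text(-0.1, 0.5, label, transform=axes[i, 0].transAxes,
                    rotation=90, fontsize=10, ha='right', va='center')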

plt.suptitle(f'FGSM Attack Analysis (ε={epsilon_fgsm})', fontsize=14, y=0.98)
plt.tight_layout()
plt.subplots_adjust(left=0.15)
plt.show()

# Calculate and display statistics
input_perturbation = (fgsm_adversarial - test_data).detach().cpu()
recon_perturbation = (adversarial_recon - original_recon).detach().cpu()

print(f"\n📊 FGSM Attack Statistics:")
print(f"Input Perturbation:")
print(f"  L2 norm: {torch.norm(input_perturbation).item():.6f}")
print(f"  L∞ norm: {torch.max(torch.abs(input_perturbation)).item():.6f}")
print(f"  Mean absolute: {torch.mean(torch.abs(input_perturbation)).item():.6f}")

print(f"\nReconstruction Perturbation:")
print(f"  L2 norm: {torch.norm(recon_perturbation).item():.6f}")
print(f"  L∞ norm: {torch.max(torch.abs(recon_perturbation)).item():.6f}")
print(f"  Mean absolute: {torch.mean(torch.abs(recon_perturbation)).item():.6f}")

print(f"\nAmplification Factor: {torch.norm(recon_perturbation).item() / torch.norm(input_perturbation).item():.2f}x")
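
The amplification factor above is a single ratio of norms taken over the whole batch. A per-sample view can show which digits are affected most. The following NumPy sketch assumes perturbation arrays of shape `(N, ...)`; the helper name `per_sample_amplification` and the example values are hypothetical:

```python
import numpy as np


def per_sample_amplification(input_delta, recon_delta, eps=1e-12):
    """Ratio of reconstruction-change L2 norm to input-change L2 norm, per sample.

    Both arrays have shape (N, ...); all non-batch dimensions are flattened.
    """
    n = input_delta.shape[0]
    input_norms = np.linalg.norm(input_delta.reshape(n, -1), axis=1)
    recon_norms = np.linalg.norm(recon_delta.reshape(n, -1), axis=1)
    # eps guards against division by zero for unperturbed samples
    return recon_norms / np.maximum(input_norms, eps)


# Hypothetical perturbations for two 2x2 "images"
input_delta = np.array([[[0.1, 0.0], [0.0, 0.0]],
                        [[0.0, 0.3], [0.4, 0.0]]])
recon_delta = np.array([[[0.2, 0.0], [0.0, 0.0]],
                        [[0.0, 0.0], [0.0, 0.25]]])
print(per_sample_amplification(input_delta, recon_delta))
```

With the tensors from the cell above, the call would be `per_sample_amplification(input_perturbation.numpy(), recon_perturbation.numpy())`. Note that `model(...)` samples the latent code randomly even in evaluation mode, so part of the reconstruction difference reflects sampling noise and not only the attack.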

In [None]:
# PGD Attack - Comprehensive Analysis
print("\n=== PGD Attack - Complete Pipeline ===")
epsilon_pgd = 0.1
alpha = 0.01
num_iter = 20

pgd_adversarial = attacks.pgd_attack(model, test_data, test_data, 
                                   epsilon=epsilon_pgd, alpha=alpha, num_iter=num_iter)

# Get reconstructions
with torch.no_grad():
    original_recon, _, _ = model(test_data)
    pgd_adversarial_recon, _, _ = model(pgd_adversarial)

# Create comprehensive visualization
fig, axes = plt.subplots(6, 8, figsize=(16, 12))
row_labels = [
    'Original Input',
    'Original Reconstruction', 
    'PGD Adversarial Input',
    'PGD Adversarial Reconstruction',
    'Input Difference (x10)',
    'Reconstruction Difference (x10)'
]

for i in range(8):
    # Row 1: Original inputs
    axes[0, i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[0, i].set_title(f'Sample {i}')
    axes[0, i].axis('off')
    
    # Row 2: Original reconstructions
    axes[1, i].imshow(original_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[1, i].axis('off')
    
    # Row 3: PGD adversarial inputs
    axes[2, i].imshow(pgd_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[2, i].axis('off')
    
    # Row 4: PGD adversarial reconstructions
    axes[3, i].imshow(pgd_adversarial_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[3, i].axis('off')
    
    # Row 5: Input differences (original vs PGD adversarial, amplified)
    input_diff = (pgd_adversarial[i] - test_data[i]).detach().cpu().squeeze()
    axes[4, i].imshow(input_diff * 10, cmap='RdBu', vmin=-1, vmax=1)
    axes[4, i].axis('off')
    
    # Row 6: Reconstruction differences (original recon vs PGD adversarial recon, amplified)
    recon_diff = (pgd_adversarial_recon[i] - original_recon[i]).detach().cpu().squeeze()
    axes[5, i].imshow(recon_diff * 10, cmap='RdBu', vmin=-1, vmax=1)
    axes[5, i].axis('off')

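# Add row labels (axis('off') hides y-axis labels, so draw them as text instead)
for i, label in enumerate(row_labels):
    axes[i, 0].text(-0.1, 0.5, label, transform=axes[i, 0].transAxes,
                    rotation=90, fontsize=10, ha='right', va='center')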

plt.suptitle(f'PGD Attack Analysis (ε={epsilon_pgd}, α={alpha}, iter={num_iter})', fontsize=14, y=0.98)
plt.tight_layout()
plt.subplots_adjust(left=0.15)
plt.show()

# Calculate and display statistics
input_perturbation_pgd = (pgd_adversarial - test_data).detach().cpu()
recon_perturbation_pgd = (pgd_adversarial_recon - original_recon).detach().cpu()

print(f"\n📊 PGD Attack Statistics:")
print(f"Input Perturbation:")
print(f"  L2 norm: {torch.norm(input_perturbation_pgd).item():.6f}")
print(f"  L∞ norm: {torch.max(torch.abs(input_perturbation_pgd)).item():.6f}")
print(f"  Mean absolute: {torch.mean(torch.abs(input_perturbation_pgd)).item():.6f}")

print(f"\nReconstruction Perturbation:")
print(f"  L2 norm: {torch.norm(recon_perturbation_pgd).item():.6f}")
print(f"  L∞ norm: {torch.max(torch.abs(recon_perturbation_pgd)).item():.6f}")
print(f"  Mean absolute: {torch.mean(torch.abs(recon_perturbation_pgd)).item():.6f}")

print(f"\nAmplification Factor: {torch.norm(recon_perturbation_pgd).item() / torch.norm(input_perturbation_pgd).item():.2f}x")
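
The statistics above take `torch.norm` over the whole batch, so the amplification factor is a single batch-level ratio. A minimal sketch of a per-sample version follows, using toy tensors in place of the real perturbations. The helper names `per_sample_l2` and `per_sample_amplification` are hypothetical and not part of the notebook.

```python
import torch

def per_sample_l2(delta: torch.Tensor) -> torch.Tensor:
    """L2 norm of each sample's perturbation; delta has shape (N, ...)."""
    return delta.flatten(start_dim=1).norm(p=2, dim=1)

def per_sample_amplification(input_delta, recon_delta, eps=1e-12):
    """Reconstruction change divided by input change, one value per sample."""
    # clamp_min guards against division by zero when a sample is unperturbed
    return per_sample_l2(recon_delta) / per_sample_l2(input_delta).clamp_min(eps)

# Toy tensors standing in for (adversarial - clean) differences
input_delta = torch.full((2, 1, 28, 28), 0.1)
recon_delta = torch.stack([
    torch.full((1, 28, 28), 0.2),   # reconstruction moved twice as far as the input
    torch.full((1, 28, 28), 0.05),  # reconstruction moved half as far as the input
])
ratios = per_sample_amplification(input_delta, recon_delta)
print(ratios)
```

With the notebook's tensors this would be called as `per_sample_amplification(input_perturbation_pgd, recon_perturbation_pgd)`, which shows whether a few samples dominate the batch-level figure.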

In [None]:
# Compare FGSM vs PGD Attack Effects
print("\n=== FGSM vs PGD Comparison ===")

# Select first 4 samples for detailed comparison
num_samples = 4
fig, axes = plt.subplots(7, num_samples, figsize=(12, 14))

comparison_labels = [
    'Original Input',
    'Original Reconstruction',
    'FGSM Adversarial',
    'FGSM Reconstruction', 
    'PGD Adversarial',
    'PGD Reconstruction',
    'FGSM vs PGD Diff (x5)'
]

for i in range(num_samples):
    # Row 1: Original inputs
    axes[0, i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[0, i].set_title(f'Sample {i}')
    axes[0, i].axis('off')
    
    # Row 2: Original reconstructions
    axes[1, i].imshow(original_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[1, i].axis('off')
    
    # Row 3: FGSM adversarial inputs
    axes[2, i].imshow(fgsm_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[2, i].axis('off')
    
    # Row 4: FGSM adversarial reconstructions
    axes[3, i].imshow(adversarial_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[3, i].axis('off')
    
    # Row 5: PGD adversarial inputs
    axes[4, i].imshow(pgd_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[4, i].axis('off')
    
    # Row 6: PGD adversarial reconstructions
    axes[5, i].imshow(pgd_adversarial_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[5, i].axis('off')
    
    # Row 7: Difference between FGSM and PGD adversarial inputs
    fgsm_vs_pgd_diff = (fgsm_adversarial[i] - pgd_adversarial[i]).detach().cpu().squeeze()
    axes[6, i].imshow(fgsm_vs_pgd_diff * 5, cmap='RdBu', vmin=-1, vmax=1)
    axes[6, i].axis('off')

# Add row labels (drawn as text, because axis('off') also hides y-axis labels)
for i, label in enumerate(comparison_labels):
    axes[i, 0].text(-0.1, 0.5, label, transform=axes[i, 0].transAxes,
                    rotation=90, fontsize=9, ha='right', va='center')

plt.suptitle('FGSM vs PGD Attack Comparison', fontsize=14, y=0.98)
plt.tight_layout()
plt.subplots_adjust(left=0.18)
plt.show()

# Quantitative comparison
print(f"\n📈 Attack Method Comparison:")
print(f"{'Metric':<25} {'FGSM':<12} {'PGD':<12} {'Ratio (PGD/FGSM)':<15}")
print("-" * 70)

fgsm_input_l2 = torch.norm(input_perturbation).item()
pgd_input_l2 = torch.norm(input_perturbation_pgd).item()
print(f"{'Input L2 Perturbation':<25} {fgsm_input_l2:<12.6f} {pgd_input_l2:<12.6f} {pgd_input_l2/fgsm_input_l2:<15.2f}")

fgsm_recon_l2 = torch.norm(recon_perturbation).item()
pgd_recon_l2 = torch.norm(recon_perturbation_pgd).item()
print(f"{'Recon L2 Perturbation':<25} {fgsm_recon_l2:<12.6f} {pgd_recon_l2:<12.6f} {pgd_recon_l2/fgsm_recon_l2:<15.2f}")

fgsm_input_linf = torch.max(torch.abs(input_perturbation)).item()
pgd_input_linf = torch.max(torch.abs(input_perturbation_pgd)).item()
print(f"{'Input L∞ Perturbation':<25} {fgsm_input_linf:<12.6f} {pgd_input_linf:<12.6f} {pgd_input_linf/fgsm_input_linf:<15.2f}")

fgsm_recon_linf = torch.max(torch.abs(recon_perturbation)).item()
pgd_recon_linf = torch.max(torch.abs(recon_perturbation_pgd)).item()
print(f"{'Recon L∞ Perturbation':<25} {fgsm_recon_linf:<12.6f} {pgd_recon_linf:<12.6f} {pgd_recon_linf/fgsm_recon_linf:<15.2f}")

# Check if attacks are different
attack_similarity = F.mse_loss(fgsm_adversarial, pgd_adversarial).item()
print(f"\n🔍 Attack Similarity (MSE between FGSM and PGD adversarial inputs): {attack_similarity:.6f}")
if attack_similarity < 1e-6:
    print("⚠️  Warning: FGSM and PGD attacks produced nearly identical results!")
else:
    print("✓ FGSM and PGD attacks produced different adversarial examples.")

In [None]:
# Latent Space Attack
print("\n=== Latent Space Attack ===")
epsilon_latent = 2.0

latent_adversarial, orig_latent, perturbed_latent = attacks.latent_space_attack(
    model, test_data, epsilon=epsilon_latent)

# Get original reconstructions for comparison
with torch.no_grad():
    original_recon, _, _ = model(test_data)

# Visualize latent space attack results
fig, axes = plt.subplots(3, 8, figsize=(16, 6))
for i in range(8):
    # Original
    axes[0, i].imshow(test_data[i].detach().cpu().squeeze(), cmap='gray')
    axes[0, i].set_title(f'Original {i}')
    axes[0, i].axis('off')
    
    # Original reconstruction
    axes[1, i].imshow(original_recon[i].detach().cpu().squeeze(), cmap='gray')
    axes[1, i].set_title(f'Original Recon')
    axes[1, i].axis('off')
    
    # Latent attack result
    axes[2, i].imshow(latent_adversarial[i].detach().cpu().squeeze(), cmap='gray')
    axes[2, i].set_title(f'Latent Attack')
    axes[2, i].axis('off')

plt.suptitle(f'Latent Space Attack Results (ε={epsilon_latent})')
plt.tight_layout()
plt.show()

# Show latent space perturbations
print(f"Latent space perturbation magnitude: {torch.norm(perturbed_latent - orig_latent).item():.6f}")
print(f"Original latent mean: {orig_latent.mean(dim=0).detach().cpu().numpy()}")
print(f"Perturbed latent mean: {perturbed_latent.mean(dim=0).detach().cpu().numpy()}")
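
How `attacks.latent_space_attack` perturbs the latent code is defined earlier in the notebook. As a rough illustration of the general idea only, the sketch below encodes an input, shifts the latent mean by a random direction of L2 length `epsilon`, and decodes the result. The `toy_encoder`, `toy_decoder`, and `random_latent_perturbation` names are hypothetical stand-ins, not the notebook's model or attack.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the VAE's encoder mean and decoder
toy_encoder = nn.Linear(784, 2)
toy_decoder = nn.Sequential(nn.Linear(2, 784), nn.Sigmoid())

def random_latent_perturbation(x, epsilon):
    """Shift each latent code by a random direction of L2 length epsilon."""
    with torch.no_grad():
        z = toy_encoder(x.flatten(start_dim=1))
        direction = torch.randn_like(z)
        # Normalize so every sample moves exactly epsilon in latent space
        direction = direction / direction.norm(dim=1, keepdim=True).clamp_min(1e-12)
        z_adv = z + epsilon * direction
        recon_adv = toy_decoder(z_adv).view_as(x)
    return recon_adv, z, z_adv

x = torch.rand(4, 1, 28, 28)
recon_adv, z, z_adv = random_latent_perturbation(x, epsilon=2.0)
shift = (z_adv - z).norm(dim=1)
print(shift)
```

Because the budget here is measured in latent units, an `epsilon` of 2.0 is not comparable to a pixel-space ε of 0.1.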

## 7. Evaluate Attack Effectiveness

In [None]:
def evaluate_attack_effectiveness(model, original, adversarial, attack_name):
    """Evaluate the effectiveness of adversarial attacks"""
    model.eval()
    
    with torch.no_grad():
        # Reconstruct original
        recon_orig, mu_orig, logvar_orig = model(original)
        
        # Reconstruct adversarial
        recon_adv, mu_adv, logvar_adv = model(adversarial)
        
        # Calculate reconstruction errors
        orig_error = F.mse_loss(recon_orig, original).item()
        adv_error = F.mse_loss(recon_adv, adversarial).item()
        
        # Calculate latent space distances
        latent_distance = F.mse_loss(mu_orig, mu_adv).item()
        
        # Calculate input perturbation
        input_perturbation = F.mse_loss(original, adversarial).item()
        
        print(f"\n=== {attack_name} Effectiveness ===")
        print(f"Original Reconstruction Error: {orig_error:.6f}")
        print(f"Adversarial Reconstruction Error: {adv_error:.6f}")
        print(f"Latent Space Distance: {latent_distance:.6f}")
        print(f"Input Perturbation (MSE): {input_perturbation:.6f}")
        
        return orig_error, adv_error, latent_distance, input_perturbation

# Evaluate all attacks
fgsm_results = evaluate_attack_effectiveness(model, test_data, fgsm_adversarial, "FGSM")
pgd_results = evaluate_attack_effectiveness(model, test_data, pgd_adversarial, "PGD")

# For latent attack, compare original reconstruction vs latent attack result
with torch.no_grad():
    orig_recon, orig_mu, orig_logvar = model(test_data)
    latent_mse = F.mse_loss(orig_recon, latent_adversarial).item()
    print(f"\n=== Latent Space Attack Effectiveness ===")
    print(f"Original vs Latent Attack Reconstruction MSE: {latent_mse:.6f}")
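
`evaluate_attack_effectiveness` measures the adversarial reconstruction against the adversarial input. Depending on the question, it can also help to measure it against the clean input or the clean reconstruction. The sketch below reports all three; `reconstruction_drift` and `toy_reconstruct` are hypothetical names, and the toy function simply halves pixel intensity so the numbers are easy to check by hand.

```python
import torch
import torch.nn.functional as F

def reconstruction_drift(reconstruct, original, adversarial):
    """Compare the adversarial reconstruction against three reference points."""
    with torch.no_grad():
        recon_orig = reconstruct(original)
        recon_adv = reconstruct(adversarial)
    return {
        "adv_recon_vs_adv_input": F.mse_loss(recon_adv, adversarial).item(),
        "adv_recon_vs_clean_input": F.mse_loss(recon_adv, original).item(),
        "adv_recon_vs_clean_recon": F.mse_loss(recon_adv, recon_orig).item(),
    }

# Toy "model" that halves intensity, in place of a trained VAE
toy_reconstruct = lambda x: 0.5 * x

original = torch.ones(2, 1, 4, 4)
adversarial = original + 0.2
metrics = reconstruction_drift(toy_reconstruct, original, adversarial)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

For the notebook's VAE, which returns a tuple, the first argument could be `lambda x: model(x)[0]`.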

## 8. Robustness Analysis

In [None]:
# Test robustness across different epsilon values
print("\n=== Robustness Analysis ===")
epsilons = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
fgsm_errors = []
pgd_errors = []
num_iter_sweep = 10

for eps in tqdm(epsilons, desc="Testing epsilon values"):
    # FGSM
    fgsm_adv = attacks.fgsm_attack(model, test_data, test_data, epsilon=eps)
    _, fgsm_error, _, _ = evaluate_attack_effectiveness(model, test_data, fgsm_adv, f"FGSM-{eps}")
    fgsm_errors.append(fgsm_error)
    
    # PGD: scale the step size with eps so that alpha * num_iter >= eps.
    # With a fixed alpha=0.01 and 10 steps, each pixel can move by at most 0.1
    # (unless the attack uses a random start), so larger eps values would have no effect.
    pgd_adv = attacks.pgd_attack(model, test_data, test_data, epsilon=eps,
                                 alpha=2.5 * eps / num_iter_sweep, num_iter=num_iter_sweep)
    _, pgd_error, _, _ = evaluate_attack_effectiveness(model, test_data, pgd_adv, f"PGD-{eps}")
    pgd_errors.append(pgd_error)

# Plot robustness curves
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(epsilons, fgsm_errors, 'o-', label='FGSM', linewidth=2)
plt.plot(epsilons, pgd_errors, 's-', label='PGD', linewidth=2)
plt.xlabel('Epsilon (Perturbation Magnitude)')
plt.ylabel('Reconstruction Error')
plt.title('VAE Robustness vs Perturbation Magnitude')
plt.legend()
plt.grid(True)

# Test different latent space perturbation magnitudes
latent_epsilons = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
latent_errors = []

with torch.no_grad():
    orig_recon, _, _ = model(test_data)

for eps in latent_epsilons:
    latent_adv, _, _ = attacks.latent_space_attack(model, test_data, epsilon=eps)
    error = F.mse_loss(orig_recon, latent_adv).item()
    latent_errors.append(error)

plt.subplot(1, 3, 2)
plt.plot(latent_epsilons, latent_errors, '^-', color='green', linewidth=2)
plt.xlabel('Latent Space Perturbation Magnitude')
plt.ylabel('Reconstruction Difference (MSE)')
plt.title('Latent Space Attack Effectiveness')
plt.grid(True)

# Compare attack methods
plt.subplot(1, 3, 3)
methods = ['Original', 'FGSM\n(ε=0.1)', 'PGD\n(ε=0.1)', 'Latent\n(ε=2.0)']
errors = [fgsm_results[0], fgsm_results[1], pgd_results[1], latent_mse]
colors = ['blue', 'red', 'orange', 'green']

bars = plt.bar(methods, errors, color=colors, alpha=0.7)
plt.ylabel('Reconstruction Error / MSE')
plt.title('Attack Method Comparison')
plt.xticks(rotation=45)

# Add value labels on bars
for bar, error in zip(bars, errors):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.0001, 
             f'{error:.4f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()
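
When sweeping ε for PGD, the step size and iteration count bound how far the attack can travel: without a random start, each pixel moves by at most `alpha * num_iter`. The sketch below illustrates this with a small L∞ PGD loop on a toy linear autoencoder. It is an assumption-laden illustration, not the notebook's `attacks.pgd_attack`; `toy_autoencoder` and `linf_pgd` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy autoencoder standing in for the notebook's VAE
toy_autoencoder = nn.Sequential(nn.Linear(16, 2), nn.Linear(2, 16))

def linf_pgd(model, x, epsilon, alpha, num_iter):
    """L-infinity PGD without a random start, maximizing reconstruction MSE."""
    x_adv = x.clone().detach()
    for _ in range(num_iter):
        x_adv.requires_grad_(True)
        loss = F.mse_loss(model(x_adv), x)
        (grad,) = torch.autograd.grad(loss, x_adv)
        with torch.no_grad():
            x_adv = x_adv + alpha * grad.sign()                # gradient ascent step
            x_adv = x + (x_adv - x).clamp(-epsilon, epsilon)   # project onto the eps-ball
            x_adv = x_adv.clamp(0.0, 1.0)                      # keep a valid pixel range
    return x_adv.detach()

x = torch.rand(8, 16)
epsilon, num_iter = 0.3, 10

fixed = linf_pgd(toy_autoencoder, x, epsilon, alpha=0.01, num_iter=num_iter)
scaled = linf_pgd(toy_autoencoder, x, epsilon, alpha=2.5 * epsilon / num_iter, num_iter=num_iter)

fixed_max = (fixed - x).abs().max().item()
scaled_max = (scaled - x).abs().max().item()
print(f"fixed alpha=0.01: max |delta| = {fixed_max:.3f} (cannot exceed 0.1)")
print(f"scaled alpha:     max |delta| = {scaled_max:.3f} (cannot exceed {epsilon})")
```

The factor 2.5 is a common heuristic that lets the iterate reach the boundary of the ε-ball and still adjust afterwards; other choices that satisfy `alpha * num_iter >= epsilon` work as well.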

## 9. Visualize Latent Space Perturbations

In [None]:
# Visualize how attacks affect the latent space
model.eval()
with torch.no_grad():
    # Get latent representations
    mu_orig, _ = model.encoder(test_data[:4])
    mu_fgsm, _ = model.encoder(fgsm_adversarial[:4])
    mu_pgd, _ = model.encoder(pgd_adversarial[:4])

# Plot latent space movements
plt.figure(figsize=(12, 8))
colors = ['red', 'blue', 'green', 'orange']

for i in range(4):
    # Convert the 2D latent means to plain floats once, for matplotlib
    x0, y0 = mu_orig[i, 0].item(), mu_orig[i, 1].item()
    xf, yf = mu_fgsm[i, 0].item(), mu_fgsm[i, 1].item()
    xp, yp = mu_pgd[i, 0].item(), mu_pgd[i, 1].item()
    
    # Original, FGSM, and PGD points (the legend is built separately below)
    plt.scatter(x0, y0, color=colors[i], s=100, marker='o')
    plt.scatter(xf, yf, color=colors[i], s=100, marker='x', alpha=0.7)
    plt.scatter(xp, yp, color=colors[i], s=100, marker='^', alpha=0.7)
    
    # Draw arrows showing movement: dashed for FGSM, solid for PGD
    plt.arrow(x0, y0, xf - x0, yf - y0,
              color=colors[i], alpha=0.5, head_width=0.1, linestyle='--')
    plt.arrow(x0, y0, xp - x0, yp - y0,
              color=colors[i], alpha=0.5, head_width=0.1, linestyle='-')

# Create custom legend
from matplotlib.lines import Line2D
# Use a black line color with no line drawn: the unfilled 'x' marker takes its
# color from the line, so color='w' would make it invisible in the legend
legend_elements = [
    Line2D([0], [0], marker='o', color='black', linestyle='None', markersize=8, label='Original'),
    Line2D([0], [0], marker='x', color='black', linestyle='None', markersize=8, label='FGSM'),
    Line2D([0], [0], marker='^', color='black', linestyle='None', markersize=8, label='PGD'),
    Line2D([0], [0], color='black', linestyle='--', label='FGSM Movement'),
    Line2D([0], [0], color='black', linestyle='-', label='PGD Movement')
]

plt.legend(handles=legend_elements, loc='upper right')
plt.xlabel('Latent Dimension 1')
plt.ylabel('Latent Dimension 2')
plt.title('Latent Space Perturbations from Adversarial Attacks')
plt.grid(True, alpha=0.3)
plt.show()
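
The plot shows the direction of each shift; a per-sample distance puts a number on it. A minimal sketch, using toy latent means rather than the notebook's `mu_orig`, `mu_fgsm`, and `mu_pgd` (the helper name `latent_displacement` is hypothetical):

```python
import torch

def latent_displacement(mu_clean: torch.Tensor, mu_attacked: torch.Tensor) -> torch.Tensor:
    """Euclidean distance each sample's latent mean moved under attack."""
    return (mu_attacked - mu_clean).norm(dim=1)

# Toy 2D latent means: the first sample moves by (3, 4), the second not at all
mu_clean = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
mu_attacked = torch.tensor([[3.0, 4.0], [1.0, 1.0]])

displacement = latent_displacement(mu_clean, mu_attacked)
print(displacement)  # tensor([5., 0.])
```

Calling it as `latent_displacement(mu_orig, mu_fgsm)` and `latent_displacement(mu_orig, mu_pgd)` would give one distance per plotted arrow.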

## 10. Defense Mechanisms (Bonus)

Here are some strategies to improve VAE robustness against adversarial attacks:

In [None]:
def adversarial_training_step(model, data, optimizer, epsilon=0.1, alpha=0.01):
    """
    Single step of adversarial training on an equal mix of clean and FGSM examples.

    `alpha` is not used by the single-step FGSM attack; it is kept only for
    signature compatibility.
    """
    # Generate adversarial examples in eval mode, then detach them so the
    # training loss does not backpropagate through the attack
    model.eval()
    adv_data = attacks.fgsm_attack(model, data, data, epsilon=epsilon).detach()
    model.train()
    
    # Clear any gradients left over from the attack before the training pass
    optimizer.zero_grad()
    
    # Clean loss
    recon_clean, mu_clean, logvar_clean = model(data)
    clean_loss = vae_loss(recon_clean, data, mu_clean, logvar_clean)
    
    # Adversarial loss
    recon_adv, mu_adv, logvar_adv = model(adv_data)
    adv_loss = vae_loss(recon_adv, adv_data, mu_adv, logvar_adv)
    
    # Combined loss
    total_loss = 0.5 * clean_loss + 0.5 * adv_loss
    
    total_loss.backward()
    optimizer.step()
    
    return total_loss.item()

print("Defense Strategies for VAEs:")
print("1. Adversarial Training: Train on both clean and adversarial examples")
print("2. Input Preprocessing: Add noise or apply transformations")
print("3. Regularization: Increase β in β-VAE to enforce stronger regularization")
print("4. Ensemble Methods: Use multiple VAE models and average their reconstructions")
print("5. Certified Defenses: Use techniques like randomized smoothing")

# Example: train a β-VAE with a larger β (stronger KL regularization) to test its robustness
robust_model = VAE(latent_dim=2).to(device)
print("\nTraining a β-VAE with β=5.0 (stronger regularization)...")
robust_losses = train_vae(robust_model, train_loader, epochs=5, beta=5.0)

# Test robustness of the new model
print("\nTesting robustness of β-VAE:")
robust_fgsm = attacks.fgsm_attack(robust_model, test_data, test_data, epsilon=0.1)
evaluate_attack_effectiveness(robust_model, test_data, robust_fgsm, "Robust β-VAE FGSM")
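
The "Input Preprocessing" strategy listed above can be sketched as averaging reconstructions over several noisy copies of the input. This is a heuristic in the spirit of randomized smoothing, not a certified defense, and it is not taken from the notebook: `smoothed_reconstruction`, `sigma`, `num_samples`, and `toy_reconstruct` are all hypothetical.

```python
import torch

torch.manual_seed(0)

def smoothed_reconstruction(reconstruct, x, sigma=0.1, num_samples=16):
    """Average the reconstructions of several Gaussian-noised copies of x."""
    recons = []
    with torch.no_grad():
        for _ in range(num_samples):
            noisy = (x + sigma * torch.randn_like(x)).clamp(0.0, 1.0)
            recons.append(reconstruct(noisy))
    return torch.stack(recons).mean(dim=0)

# Toy "model" that halves intensity, in place of a trained VAE
toy_reconstruct = lambda x: 0.5 * x

x = torch.rand(4, 1, 28, 28)
smoothed = smoothed_reconstruction(toy_reconstruct, x, sigma=0.1, num_samples=16)
no_noise = smoothed_reconstruction(toy_reconstruct, x, sigma=0.0, num_samples=4)
print(smoothed.shape)
```

With `sigma=0.0` the function reduces to a plain reconstruction. Whether a given `sigma` actually reduces the reconstruction error under FGSM or PGD would need to be checked with `evaluate_attack_effectiveness`, and an attacker aware of the noise can adapt to it.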

## Summary

This notebook demonstrated several key concepts in adversarial attacks on VAEs:

### Attack Methods:
1. **FGSM (Fast Gradient Sign Method)**: Single-step attack using gradient sign
2. **PGD (Projected Gradient Descent)**: Multi-step iterative attack
3. **Latent Space Attack**: Perturbations in the encoded latent representation

### Key Findings:
- VAEs are vulnerable to adversarial perturbations in both input and latent spaces
- Small input perturbations can cause significant changes in latent representations
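- The 2D latent space lets attack effects be plotted directly, without dimensionality reduction
- Attack strength depends on the method and its budget: FGSM and PGD should be compared at the same ε, and the latent attack's ε is measured in latent units rather than pixel intensities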

### Defense Strategies:
- Adversarial training with mixed clean/adversarial data
- Stronger regularization (higher β in β-VAE)
- Input preprocessing and ensemble methods
- Certified defense techniques

This framework can be extended to other autoencoder architectures and datasets to study adversarial robustness in generative models.