# Adversarial Attack Analysis on VAE

This notebook loads a pre-trained VAE model and performs comprehensive adversarial attack analysis using functions from the `attack_analysis` module.

## Analysis Components:
1. **Visual Attack Analysis**: 6-row visualization showing original, adversarial, and difference images
2. **Attack Method Comparison**: Effectiveness comparison across FGSM, PGD, and Custom attacks
3. **Efficiency Scaling**: How attack success scales with the epsilon (perturbation strength) parameter
4. **Latent Space Analysis**: How attacks affect digit clustering in latent space

## Attack Methods:
- **FGSM** (Fast Gradient Sign Method)
- **PGD** (Projected Gradient Descent)
- **Custom Iterative Attack**
- **Latent Space Attack**

In [None]:
# Import necessary libraries
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import glob
from datetime import datetime

# Import VAE and attack classes
from adversarial_vae_attack import (
    VAE, 
    AdversarialAttacks,
    get_mnist_loaders, 
    get_device,
    vae_loss
)

# Import attack analysis functions
from attack_analysis import (
    visualize_attack_analysis,
    custom_vae_attack,
    compare_attack_effectiveness,
    analyze_attack_success_rates,
    comprehensive_attack_efficiency,
    plot_attack_efficiency_scaling,
    plot_simple_efficiency_summary,
    sample_digits_by_class,
    analyze_latent_clustering_under_attack,
    plot_latent_clustering_comparison,
    compare_latent_disruption_across_attacks,
    save_analysis_results
)

# Sanity check: reaching this point means every import above resolved.
print("✅ All imports successful!")
# Report the runtime environment so saved outputs record which PyTorch build produced them.
print(f"PyTorch version: {torch.__version__}")
# Informational only — the device actually used is whatever get_device() returns.
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Setup configuration
# Device selection is delegated to the project helper; every later cell reads `device`.
device = get_device()

# Central place for every tunable used by the analysis cells below.
CONFIG = {
    'latent_dim': 2,             # Passed to VAE(latent_dim=...) before the checkpoint is loaded
    'batch_size_test': 64,
    'data_dir': './data',
    'n_attack_samples': 8,       # Number of samples for visualization
    'attack_epsilons': [0.05, 0.1, 0.15, 0.2],  # Perturbation strengths
    'efficiency_samples': 1000,  # Samples for efficiency analysis
    'clustering_samples': 12,    # Samples per digit for clustering analysis
    'seed': 42                   # RNG seed so a Restart & Run All reproduces the same numbers
}

# Seed the NumPy and PyTorch global generators up front so that any random
# sampling done by the attacks or analysis helpers is repeatable across runs.
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])

print("📋 Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

In [None]:
# Load the most recent trained model
def load_latest_model():
    """Load the most recently saved VAE checkpoint and its optional metadata.

    Searches the current working directory for ``vae_model_<timestamp>.pth``
    files and picks the lexicographically greatest name (this is the newest
    checkpoint only when the timestamps sort chronologically as text). The
    weights are loaded into ``VAE(latent_dim=CONFIG['latent_dim'])``, moved to
    the notebook-level ``device`` and switched to eval mode.

    Returns:
        tuple: ``(model, metadata, timestamp)`` where ``metadata`` is the parsed
        ``vae_metadata_<timestamp>.json`` content, or ``None`` when that file
        does not exist, and ``timestamp`` is the ``<timestamp>`` part of the
        checkpoint file name.

    Raises:
        FileNotFoundError: If no matching checkpoint file is found.
    """
    prefix, suffix = "vae_model_", ".pth"
    model_files = glob.glob(f"{prefix}*{suffix}")

    if not model_files:
        raise FileNotFoundError("No trained VAE models found. Please run vae_trainer1.ipynb first.")

    latest_model_file = max(model_files)
    # Strip the known prefix/suffix instead of splitting on '_' and '.', so a
    # timestamp such as "20240101_120000" is kept whole rather than truncated.
    timestamp = latest_model_file[len(prefix):-len(suffix)]

    # Load metadata (optional companion file written next to the checkpoint)
    metadata_file = f"vae_metadata_{timestamp}.json"
    metadata = None
    if os.path.exists(metadata_file):
        with open(metadata_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    else:
        print(f"⚠️ No metadata file found ({metadata_file}); continuing without it.")

    # Initialize and load model
    model = VAE(latent_dim=CONFIG['latent_dim'])
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    model.load_state_dict(torch.load(latest_model_file, map_location=device))
    model.to(device)
    model.eval()

    print(f"📁 Loaded model: {latest_model_file}")
    if metadata:
        # The summary is informational, so a missing key must not abort loading.
        training_info = metadata.get('training_info', {})
        if 'timestamp' in metadata:
            print(f"🕒 Training timestamp: {metadata['timestamp']}")
        if 'final_loss' in training_info:
            print(f"📊 Final training loss: {training_info['final_loss']:.6f}")
        if 'loss_reduction_percent' in training_info:
            print(f"🎯 Loss reduction: {training_info['loss_reduction_percent']:.2f}%")

    return model, metadata, timestamp

# Load model and data
# load_latest_model() returns (model, metadata, timestamp); the timestamp is
# kept as model_timestamp and later passed to save_analysis_results.
model, metadata, model_timestamp = load_latest_model()

print("\n📊 Loading MNIST test dataset...")
# Only the second returned loader is kept; the first is discarded as `_`
# (presumably the training loader, which batch_size_train would size — confirm).
_, test_loader = get_mnist_loaders(
    batch_size_train=128,
    batch_size_test=CONFIG['batch_size_test'],
    data_dir=CONFIG['data_dir']
)

# Get sample data
# Pull one batch from the test loader and move it to the analysis device.
test_iter = iter(test_loader)
test_images, test_labels = next(test_iter)
test_images = test_images.to(device)
test_labels = test_labels.to(device)

# The first n_attack_samples items of that batch are the fixed examples used by
# the per-attack visualization cells. Slicing caps at the batch length, so
# n_attack_samples should not exceed batch_size_test.
sample_images = test_images[:CONFIG['n_attack_samples']]
sample_labels = test_labels[:CONFIG['n_attack_samples']]

print(f"✅ Dataset loaded: {len(test_loader.dataset)} test samples")
print(f"📝 Sample labels: {sample_labels.cpu().numpy()}")

## 🔥 FGSM Attack Analysis

Fast Gradient Sign Method (FGSM): a simple but effective single-step attack.

In [None]:
# FGSM Attack Analysis
print("🔥 Performing FGSM Attack Analysis...")

# Attack helper object; the PGD and latent-space cells reuse this same instance.
attacks = AdversarialAttacks()

# Get original reconstructions
# The model returns three values; only the first is kept, as the clean-input
# baseline. Later attack cells also read `original_recons`.
with torch.no_grad():
    original_recons, _, _ = model(sample_images)

# Test different epsilon values
for epsilon in CONFIG['attack_epsilons']:
    print(f"\n🎯 Testing FGSM with epsilon = {epsilon}")
    
    # Perform FGSM attack
    # sample_images is passed twice — presumably as both the input and the
    # attack target; TODO confirm against AdversarialAttacks.fgsm_attack.
    # The call is made outside torch.no_grad().
    fgsm_images = attacks.fgsm_attack(model, sample_images, sample_images, epsilon)
    
    # Get adversarial reconstructions
    with torch.no_grad():
        fgsm_recons, _, _ = model(fgsm_images)
    
    # Visualize attack analysis
    # One figure per epsilon: clean inputs/reconstructions vs. adversarial ones.
    visualize_attack_analysis(
        sample_images, original_recons, fgsm_images, fgsm_recons,
        "FGSM", epsilon, sample_labels
    )

## 🔥 PGD Attack Analysis

Projected Gradient Descent - iterative attack with multiple refinement steps.

In [None]:
# PGD Attack Analysis
# Relies on `attacks` and `original_recons` created in the FGSM cell, so that
# cell must have been run first.
print("🔥 Performing PGD Attack Analysis...")

for epsilon in CONFIG['attack_epsilons']:
    print(f"\n🎯 Testing PGD with epsilon = {epsilon}")
    
    # Perform PGD attack
    # Step size is tied to the budget: 20 steps of epsilon/10 add up to 2*epsilon
    # of total travel. sample_images is passed twice — presumably input and
    # target; TODO confirm against AdversarialAttacks.pgd_attack.
    pgd_images = attacks.pgd_attack(
        model, sample_images, sample_images, 
        epsilon=epsilon, alpha=epsilon/10, num_iter=20
    )
    
    # Get adversarial reconstructions
    with torch.no_grad():
        pgd_recons, _, _ = model(pgd_images)
    
    # Visualize attack analysis
    visualize_attack_analysis(
        sample_images, original_recons, pgd_images, pgd_recons,
        "PGD", epsilon, sample_labels
    )

## 🔥 Latent Space Attack Analysis

Attack that directly manipulates latent space representations.

In [None]:
# Latent Space Attack Analysis
# Relies on `attacks` and `original_recons` created in the FGSM cell.
print("🔥 Performing Latent Space Attack Analysis...")

# Separate epsilon list on a larger scale than CONFIG['attack_epsilons']
# (presumably because the perturbation is applied to latent codes, not pixels).
latent_epsilons = [0.5, 1.0, 2.0, 3.0]

for epsilon in latent_epsilons:
    print(f"\n🎯 Testing Latent Attack with epsilon = {epsilon}")
    
    # Perform latent space attack
    # Returns three values; orig_latent and perturbed_latent are not used in this cell.
    latent_images, orig_latent, perturbed_latent = attacks.latent_space_attack(
        model, sample_images, epsilon
    )
    
    # Visualize attack analysis
    # NOTE(review): latent_images is passed for BOTH the adversarial-image and the
    # adversarial-reconstruction arguments, whereas the FGSM/PGD cells pass two
    # distinct tensors. Presumably intentional because this attack yields decoded
    # images rather than perturbed inputs — confirm.
    visualize_attack_analysis(
        sample_images, original_recons, latent_images, latent_images,
        "Latent Space", epsilon, sample_labels
    )

## 🔥 Custom Iterative Attack Analysis

Custom attack combining reconstruction and KL divergence targeting.

In [None]:
# Custom Iterative Attack Analysis
# Relies on `original_recons` created in the FGSM cell for the clean baseline.
print("🔥 Performing Custom Iterative Attack Analysis...")

for epsilon in CONFIG['attack_epsilons']:
    print(f"\n🎯 Testing Custom Attack with epsilon = {epsilon}")
    
    # Perform custom attack
    # Uses the module-level custom_vae_attack function (not the `attacks` object),
    # with a fixed budget of 15 iterations for every epsilon.
    custom_images = custom_vae_attack(model, sample_images, epsilon, num_iter=15)
    
    # Get adversarial reconstructions
    with torch.no_grad():
        custom_recons, _, _ = model(custom_images)
    
    # Visualize attack analysis
    visualize_attack_analysis(
        sample_images, original_recons, custom_images, custom_recons,
        "Custom Iterative", epsilon, sample_labels
    )

## 📊 Attack Effectiveness Comparison

Compare different attack methods across various metrics.

In [None]:
# Attack Effectiveness Comparison
print("📊 Comparing Attack Effectiveness...")

# Single source of truth for the perturbation budget, so the table header can
# never disagree with the epsilon actually passed to the comparison.
comparison_epsilon = 0.1
comparison_results = compare_attack_effectiveness(
    model, sample_images, epsilon=comparison_epsilon
)

# Display results in table format
table_rule = "=" * 80
print("\n" + table_rule)
print(f"ATTACK EFFECTIVENESS COMPARISON (ε={comparison_epsilon})")
print(table_rule)
print(f"{'Attack':<15} {'Img MSE':<10} {'Img L∞':<10} {'Recon MSE':<12} {'Latent MSE':<12}")
print("-" * 80)

# One row per attack; each value is read from that attack's metrics dict.
for attack_name, metrics in comparison_results.items():
    print(f"{attack_name:<15} {metrics['img_mse']:<10.6f} {metrics['img_linf']:<10.6f} "
          f"{metrics['recon_mse']:<12.6f} {metrics['latent_mse']:<12.6f}")

print(table_rule)
print("Metrics:")
print("- Img MSE: Mean squared error between original and adversarial images")
print("- Img L∞: Maximum absolute perturbation (should be ≤ ε)")
print("- Recon MSE: Difference in VAE reconstructions")
print("- Latent MSE: Difference in latent space representations")

## 📈 Attack Success Rate Analysis

Analyze how attack success varies with perturbation strength and across different digits.

In [None]:
# Attack Success Rate Analysis
print("📈 Analyzing Attack Success Rates...")

success_results, total_tested = analyze_attack_success_rates(
    model, test_loader, CONFIG['attack_epsilons'], device, max_batches=5
)


def _percent_or_zero(outcomes):
    """Return the mean of `outcomes` as a percentage, or 0 when there are none."""
    return np.mean(outcomes) * 100 if len(outcomes) else 0


# Overall success rates, computed once and shared by the table and the plot.
# success_results is indexed by epsilon value, then by attack name.
epsilons = CONFIG['attack_epsilons']
fgsm_rates = [np.mean(success_results[eps]['fgsm']) * 100 for eps in epsilons]
pgd_rates = [np.mean(success_results[eps]['pgd']) * 100 for eps in epsilons]

print(f"\nTested on {total_tested} samples")
print("\n" + "="*60)
print("ATTACK SUCCESS RATES BY EPSILON")
print("="*60)
print(f"{'Epsilon':<10} {'FGSM Success':<15} {'PGD Success':<15}")
print("-"*60)

for eps, fgsm_rate, pgd_rate in zip(epsilons, fgsm_rates, pgd_rates):
    # Attach the percent sign before padding so it stays next to the number
    # and the columns line up with the header.
    fgsm_cell = f"{fgsm_rate:.1f}%"
    pgd_cell = f"{pgd_rate:.1f}%"
    print(f"{eps:<10} {fgsm_cell:<15} {pgd_cell:<15}")

# Plot success rates
fig, (ax_overall, ax_digit) = plt.subplots(1, 2, figsize=(12, 5))

# Overall success rates
ax_overall.plot(epsilons, fgsm_rates, 'o-', label='FGSM', linewidth=2, markersize=8)
ax_overall.plot(epsilons, pgd_rates, 's-', label='PGD', linewidth=2, markersize=8)
ax_overall.set_xlabel('Epsilon (Perturbation Strength)')
ax_overall.set_ylabel('Attack Success Rate (%)')
ax_overall.set_title('Attack Success Rate vs Epsilon')
ax_overall.legend()
ax_overall.grid(True, alpha=0.3)

# Success rates by digit: prefer epsilon=0.1, but fall back to the middle
# configured value when 0.1 was not part of the tested epsilons.
eps = 0.1 if 0.1 in success_results else epsilons[len(epsilons) // 2]
by_digit = success_results[eps]['by_digit']
digits = list(range(10))
fgsm_by_digit = [_percent_or_zero(by_digit[d]['fgsm']) for d in digits]
pgd_by_digit = [_percent_or_zero(by_digit[d]['pgd']) for d in digits]

x = np.arange(len(digits))
width = 0.35  # bar width; the two series are offset by half of it on each side

ax_digit.bar(x - width/2, fgsm_by_digit, width, label='FGSM', alpha=0.8)
ax_digit.bar(x + width/2, pgd_by_digit, width, label='PGD', alpha=0.8)
ax_digit.set_xlabel('Digit Class')
ax_digit.set_ylabel('Attack Success Rate (%)')
ax_digit.set_title(f'Attack Success by Digit Class (ε={eps})')
ax_digit.set_xticks(x)
ax_digit.set_xticklabels(digits)
ax_digit.legend()
ax_digit.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

## 🚀 Comprehensive Attack Efficiency Analysis

Evaluate attack efficiency across the entire test dataset with detailed epsilon scaling.

In [None]:
# Comprehensive Attack Efficiency Analysis
# Finer epsilon grid than CONFIG['attack_epsilons']; this array is also read by
# the final summary cell via len(epsilon_range).
epsilon_range = np.linspace(0.01, 0.3, 20)  # 20 points from 0.01 to 0.3

print("🚀 Starting comprehensive attack efficiency analysis...")
print(f"Testing {len(epsilon_range)} epsilon values: {epsilon_range[0]:.3f} to {epsilon_range[-1]:.3f}")

# Run the analysis
# max_samples is taken from CONFIG['efficiency_samples']; the result is indexed
# below and in the save cell by 'sample_count' and 'epsilons'.
efficiency_results = comprehensive_attack_efficiency(
    model, test_loader, epsilon_range, device, max_samples=CONFIG['efficiency_samples']
)

print(f"\n✅ Analysis complete! Tested on {efficiency_results['sample_count']} samples")

# Create efficiency scaling plots
print("\n📈 Creating efficiency scaling visualizations...")
plot_attack_efficiency_scaling(efficiency_results)

# Create simple summary plot
print("\n🎯 Creating simple efficiency summary plot...")
plot_simple_efficiency_summary(efficiency_results)

## 🎯 Latent Space Clustering Analysis

Analyze how adversarial attacks affect digit clustering in the 2D latent space.

In [None]:
# Latent Space Clustering Analysis
print("🎯 Starting Latent Space Clustering Analysis")
print("=" * 60)

# Sample digits from test dataset
# Requests CONFIG['clustering_samples'] examples per class, scanning at most
# 15 batches of the test loader.
digit_samples = sample_digits_by_class(
    test_loader, device, 
    samples_per_class=CONFIG['clustering_samples'], 
    max_batches=15
)

# Analyze clustering under different attacks
# attack_methods, epsilon_test and clustering_analyses are all read again by the
# save/summary cell, so their names must stay stable.
attack_methods = ['fgsm', 'pgd', 'custom']
epsilon_test = 0.15  # Use moderate epsilon for clear visualization
clustering_analyses = {}

for attack_method in attack_methods:
    print(f"\n{'='*50}")
    print(f"ANALYZING {attack_method.upper()} ATTACK")
    print(f"{'='*50}")
    
    # Perform clustering analysis
    # The same digit_samples and epsilon are used for every method so the
    # results are directly comparable.
    clustering_results = analyze_latent_clustering_under_attack(
        model, digit_samples, attack_method=attack_method, epsilon=epsilon_test
    )
    
    # Store results
    clustering_analyses[attack_method] = clustering_results
    
    # Plot the results
    plot_latent_clustering_comparison(clustering_results)

# Comparative analysis
# Receives the per-method results keyed by attack name.
print("\n🔬 Creating comparative analysis...")
compare_latent_disruption_across_attacks(clustering_analyses)

## 💾 Save Analysis Results

Save all analysis results for future reference and comparison.

In [None]:
# Persist the effectiveness, success-rate and efficiency results
print("💾 Saving analysis results...")

# Collects whatever save_analysis_results returns for each artefact
# (presumably the output path); the summary section prints this list.
analysis_files = []

# 1) Attack effectiveness comparison — saved as-is
effectiveness_file = save_analysis_results(
    comparison_results, model_timestamp, 'attack_effectiveness'
)
analysis_files.append(effectiveness_file)

# 2) Success rates — epsilon keys are stringified, and the run configuration
#    is stored alongside the results
success_data = dict(
    config=CONFIG,
    success_rates={str(eps_key): rates for eps_key, rates in success_results.items()},
    total_samples_tested=total_tested,
)
success_file = save_analysis_results(
    success_data, model_timestamp, 'success_rates'
)
analysis_files.append(success_file)

# 3) Efficiency analysis — everything except the two bookkeeping entries is
#    treated as per-attack data
non_attack_keys = ('epsilons', 'sample_count')
efficiency_data = dict(
    sample_count=efficiency_results['sample_count'],
    epsilon_range=efficiency_results['epsilons'].tolist(),
    attacks={
        name: data
        for name, data in efficiency_results.items()
        if name not in non_attack_keys
    },
)
efficiency_file = save_analysis_results(
    efficiency_data, model_timestamp, 'attack_efficiency'
)
analysis_files.append(efficiency_file)

# Save clustering analysis
clustering_data = {
    'epsilon_tested': epsilon_test,
    'attacks': {}
}

for method, results in clustering_analyses.items():
    clustering_data['attacks'][method] = {
        'attack_info': results['attack_info'],
        'centroid_movements': {},
        'scatter_changes': {}
    }

    # Calculate summary statistics, one entry per digit that has latent points
    # in both the original and the adversarial results.
    for digit in range(10):
        if (digit not in results['original']['latent'] or
                digit not in results['adversarial']['latent']):
            continue

        # Coerce to arrays so the arithmetic below works for list inputs too.
        orig_points = np.asarray(results['original']['latent'][digit])
        adv_points = np.asarray(results['adversarial']['latent'][digit])

        # Centroid movement: distance between the mean latent point before and
        # after the attack.
        orig_centroid = np.mean(orig_points, axis=0)
        adv_centroid = np.mean(adv_points, axis=0)
        movement = np.linalg.norm(adv_centroid - orig_centroid)

        # Scatter: mean distance of the points from their own centroid.
        orig_scatter = np.mean(np.linalg.norm(orig_points - orig_centroid, axis=1))
        adv_scatter = np.mean(np.linalg.norm(adv_points - adv_centroid, axis=1))

        # Relative change in percent. A cluster with zero original scatter
        # (e.g. a single point) has no meaningful relative change, so avoid the
        # division: store 0.0 when nothing changed and None otherwise, rather
        # than writing inf/NaN into the results file.
        if orig_scatter > 0:
            scatter_change = float((adv_scatter - orig_scatter) / orig_scatter * 100)
        elif adv_scatter == 0:
            scatter_change = 0.0
        else:
            scatter_change = None

        clustering_data['attacks'][method]['centroid_movements'][str(digit)] = float(movement)
        clustering_data['attacks'][method]['scatter_changes'][str(digit)] = scatter_change

clustering_file = save_analysis_results(
    clustering_data, model_timestamp, 'latent_clustering'
)
analysis_files.append(clustering_file)

# Closing banner: recap what was run and where the results were written
banner = "=" * 80
print("\n" + banner)
print("COMPREHENSIVE ADVERSARIAL ATTACK ANALYSIS COMPLETE")
print(banner)

# Bullet lines are assembled first and emitted together; the counts come from
# the configuration and results produced by the earlier cells.
summary_points = (
    "   • Visual attack analysis with 6-row comparative visualization",
    "   • Attack effectiveness comparison across FGSM, PGD, and Custom methods",
    f"   • Success rate analysis across {len(CONFIG['attack_epsilons'])} epsilon values",
    f"   • Comprehensive efficiency scaling with {len(epsilon_range)} epsilon points",
    f"   • Latent space clustering analysis for {len(attack_methods)} attack methods",
    f"   • Tested on {efficiency_results['sample_count']} samples for efficiency analysis",
)
print("\n🔍 Analysis Summary:")
print("\n".join(summary_points))

# One line per saved artefact, in the order they were written
print("\n💾 Analysis files saved:")
for file in analysis_files:
    print(f"   📁 {file}")

print(f"\n🔗 Model analyzed: {model_timestamp}")
print("\n✅ Complete adversarial vulnerability assessment finished!")
print("🛡️ Ready for defense strategy development!")
print("🛡️ Ready for defense strategy development!")