# Adversarial Attacks Analysis

## EE4745 Neural Networks Final Project

This notebook provides a comprehensive analysis of adversarial attacks against sports-image classification models.

### Objectives:
- Implement and demonstrate FGSM and PGD adversarial attacks
- Analyze attack effectiveness across different models
- Visualize adversarial examples and perturbations
- Study transferability of adversarial examples
- Evaluate model robustness and interpretability

---

## 1. Setup and Configuration

Import libraries and set up the environment for adversarial attack analysis.

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import time
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

from sklearn.metrics import accuracy_score, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Add src directory to path
sys.path.append('../src')

# Import custom modules
from dataset.sports_dataset import SportsDataset, get_dataloaders
from models.simple_cnn import create_simple_cnn
from models.resnet_small import create_resnet_small
from training.utils import set_seed, get_device, load_checkpoint
from attacks.fgsm import FGSM
from attacks.pgd import PGD
from attacks.transferability import TransferabilityAnalyzer
from attacks.interpretability import AdversarialInterpretabilityAnalyzer
from attacks import utils as attack_utils
from interpretability.saliency import SaliencyMap
from interpretability.gradcam import GradCAM, get_target_layer

# Plotting defaults: matplotlib's default style, then the seaborn palette
# and the figure/font sizes used throughout the notebook.
plt.style.use('default')
sns.set_palette('Set1')
plt.rcParams.update({'figure.figsize': (14, 8), 'font.size': 10})

# Seed the RNGs for reproducibility (project helper).
set_seed(42)

setup_banner = "Adversarial Attacks Analysis Setup"
print(setup_banner)
print("=" * 40)
print(f"PyTorch version: {torch.__version__}")
device = get_device()  # shared by every later cell
print(f"Device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
print("Setup complete!")

## 2. Load Data and Pre-trained Models

Load the dataset and pre-trained models for attack analysis.

In [None]:
# Dataset configuration
DATA_DIR = '../data'
IMAGE_SIZE = 32
BATCH_SIZE = 16  # Smaller batch size for attack analysis
NUM_WORKERS = 2

# Load datasets
print("Loading datasets for attack analysis...")
train_loader, val_loader, num_classes = get_dataloaders(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE,
    num_workers=NUM_WORKERS
)

class_names = SportsDataset.CLASSES
print(f"Dataset loaded: {num_classes} classes")
print(f"Classes: {class_names}")
print(f"Validation samples: {len(val_loader.dataset)}")

# Create a small random subset of the validation set for detailed analysis.
# Clamp to the dataset size so the reported count is accurate even when the
# validation split holds fewer samples than requested.
SUBSET_REQUESTED = 100
subset_size = min(SUBSET_REQUESTED, len(val_loader.dataset))
subset_indices = torch.randperm(len(val_loader.dataset))[:subset_size]
# Subset expects a sequence of ints; pass plain Python indices rather than
# 0-d tensors so the underlying dataset's __getitem__ receives ints.
subset_dataset = torch.utils.data.Subset(val_loader.dataset, subset_indices.tolist())
subset_loader = DataLoader(subset_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Created subset with {subset_size} samples for detailed analysis")

### Load Pre-trained Models

Load the models trained in previous experiments; if a checkpoint is missing or fails to load, fall back to a randomly initialized model.

In [None]:
def load_or_create_model(model_type, checkpoint_path=None):
    """Build a model of the given type and load weights if a checkpoint exists.

    Args:
        model_type: Either ``'SimpleCNN'`` or ``'ResNetSmall'``.
        checkpoint_path: Optional checkpoint path. If it is falsy or does not
            exist on disk, the randomly initialized model is returned.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is not recognized.

    Relies on the notebook globals ``num_classes``, ``IMAGE_SIZE`` and ``device``.
    """
    builders = {
        'SimpleCNN': create_simple_cnn,
        'ResNetSmall': create_resnet_small,
    }
    if model_type not in builders:
        raise ValueError(f"Unknown model type: {model_type}")

    def _build():
        # Fresh, randomly initialized instance on the target device.
        fresh = builders[model_type](num_classes=num_classes, input_size=IMAGE_SIZE)
        fresh.to(device)
        return fresh

    model = _build()

    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint: {checkpoint_path}")
        try:
            load_checkpoint(checkpoint_path, model)
            print("✅ Checkpoint loaded successfully")
        except Exception as e:
            # Deliberate best-effort fallback. A failed load may already have
            # copied some tensors into `model`, so rebuild it to really get the
            # randomly initialized model the message promises.
            print(f"⚠️  Failed to load checkpoint: {e}")
            print("Using randomly initialized model")
            model = _build()
    else:
        print(f"⚠️  No checkpoint found, using randomly initialized {model_type}")

    model.eval()
    return model

# Candidate checkpoint locations for each architecture
checkpoint_paths = {
    'SimpleCNN': '../checkpoints/SimpleCNN-best.pt',
    'ResNetSmall': '../checkpoints/ResNetSmall-best.pt'
}

# Load models
print("\nLoading models...")
models = {}

for model_type in ['SimpleCNN', 'ResNetSmall']:
    print(f"\nLoading {model_type}...")
    current_model = load_or_create_model(model_type, checkpoint_paths.get(model_type))
    models[model_type] = current_model

    # Sanity check: clean accuracy on the analysis subset
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in subset_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = current_model(images)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)

    clean_accuracy = 100.0 * correct / total
    print(f"  Clean accuracy on subset: {clean_accuracy:.2f}%")

print(f"\nModels loaded: {list(models.keys())}")

## 3. FGSM Attack Implementation and Analysis

Demonstrate the Fast Gradient Sign Method (FGSM) attack.

In [None]:
def evaluate_attack(model, attack_method, dataloader, attack_name, epsilon_values):
    """Measure clean and adversarial accuracy of ``model`` for several epsilons.

    For every epsilon the whole ``dataloader`` is attacked with
    ``attack_method.attack(images, labels, epsilon=epsilon)``. An epsilon of
    0 (or less) skips the attack and reuses the clean predictions.

    Args:
        model: Classifier to evaluate; switched to eval mode for each epsilon.
        attack_method: Object exposing ``attack(images, labels, epsilon=...)``.
        dataloader: Yields ``(images, labels)`` batches.
        attack_name: Label used in the log output.
        epsilon_values: Iterable of perturbation budgets to test.

    Returns:
        A list with one dict per epsilon, containing:
        ``epsilon``; ``clean_accuracy`` and ``adversarial_accuracy`` (percent);
        ``attack_success_rate``, the percent of clean-correct samples that the
        attack flips to a wrong prediction; and the per-sample averages
        ``avg_l2_perturbation`` and ``avg_linf_perturbation``.

    Uses the notebook global ``device``.
    """
    results = []

    for epsilon in epsilon_values:
        print(f"\nEvaluating {attack_name} with epsilon={epsilon:.4f}...")

        correct_clean = 0
        correct_adv = 0
        flipped = 0  # clean-correct samples misclassified after the attack
        total = 0
        l2_sum = 0.0
        linf_sum = 0.0

        model.eval()

        for images, labels in tqdm(dataloader, desc=f"ε={epsilon:.4f}"):
            images, labels = images.to(device), labels.to(device)

            # Clean prediction
            with torch.no_grad():
                clean_pred = model(images).argmax(dim=1)

            if epsilon > 0:
                adv_images = attack_method.attack(images, labels, epsilon=epsilon)

                # Per-sample norms are summed here. The final average is then
                # weighted by samples rather than by (possibly uneven) batches.
                perturbation = (adv_images - images).detach().reshape(images.size(0), -1)
                l2_sum += torch.norm(perturbation, p=2, dim=1).sum().item()
                linf_sum += torch.norm(perturbation, p=float('inf'), dim=1).sum().item()

                with torch.no_grad():
                    adv_pred = model(adv_images).argmax(dim=1)
            else:
                # No budget: the "adversarial" prediction is the clean one.
                adv_pred = clean_pred

            clean_ok = clean_pred.eq(labels)
            adv_ok = adv_pred.eq(labels)
            correct_clean += clean_ok.sum().item()
            correct_adv += adv_ok.sum().item()
            flipped += (clean_ok & ~adv_ok).sum().item()
            total += labels.size(0)

        if total > 0:
            clean_acc = 100.0 * correct_clean / total
            adv_acc = 100.0 * correct_adv / total
            avg_l2_norm = l2_sum / total
            avg_linf_norm = linf_sum / total
        else:
            # Empty loader: report zeros instead of dividing by zero.
            clean_acc = adv_acc = avg_l2_norm = avg_linf_norm = 0.0

        # Success = fraction of originally-correct samples that the attack broke.
        # Samples that happened to become correct do not offset real successes.
        success_rate = 100.0 * flipped / correct_clean if correct_clean > 0 else 0

        result = {
            'epsilon': epsilon,
            'clean_accuracy': clean_acc,
            'adversarial_accuracy': adv_acc,
            'attack_success_rate': success_rate,
            'avg_l2_perturbation': avg_l2_norm,
            'avg_linf_perturbation': avg_linf_norm
        }
        results.append(result)

        print(f"  Clean Acc: {clean_acc:.2f}%")
        print(f"  Adversarial Acc: {adv_acc:.2f}%")
        print(f"  Attack Success Rate: {success_rate:.2f}%")
        print(f"  Avg L2 Perturbation: {avg_l2_norm:.4f}")
        print(f"  Avg L∞ Perturbation: {avg_linf_norm:.4f}")

    return results

# FGSM Attack Analysis
print("\nFGSM ATTACK ANALYSIS")
print("=" * 40)

# Perturbation budgets to sweep (0.0 gives the clean baseline)
epsilon_values = [0.0, 0.01, 0.03, 0.05, 0.1, 0.2, 0.3]

fgsm_results = {}

for model_name, model in models.items():
    print(f"\nAnalyzing FGSM attacks on {model_name}...")
    fgsm_attack = FGSM(model, device=device)
    fgsm_results[model_name] = evaluate_attack(
        model, fgsm_attack, subset_loader,
        f"FGSM-{model_name}", epsilon_values
    )

# Flatten to one row per (model, epsilon). Each result dict is tagged in
# place with its model and attack names.
fgsm_df_data = []
for model_name, results in fgsm_results.items():
    for result in results:
        result.update(model=model_name, attack='FGSM')
        fgsm_df_data.append(result)

fgsm_df = pd.DataFrame(fgsm_df_data)
print("\nFGSM attack evaluation complete!")

### FGSM Attack Visualization

In [None]:
def visualize_attack_results(df, attack_name, epsilon_threshold=0.1):
    """Plot how an attack's effectiveness varies with epsilon, per model.

    Draws a 2x2 figure with these panels:
    - adversarial accuracy vs epsilon (one line per model);
    - attack success rate vs epsilon (one line per model);
    - average L2 perturbation vs epsilon (one line per model);
    - a bar chart of adversarial accuracy at ``epsilon_threshold``.

    Args:
        df: DataFrame with columns ``model``, ``epsilon``,
            ``adversarial_accuracy``, ``attack_success_rate`` and
            ``avg_l2_perturbation``.
        attack_name: Name used in the figure title.
        epsilon_threshold: Epsilon at which the bar chart compares robustness
            (default 0.1, the previously hard-coded value).
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'{attack_name} Attack Analysis', fontsize=16, fontweight='bold')

    model_names = df['model'].unique()
    palette = ['blue', 'red', 'green', 'orange']

    def color_for(idx):
        # Cycle the palette so more than four models no longer raise IndexError.
        return palette[idx % len(palette)]

    # (axis, column, line format, title, y-label) for the three line panels
    line_panels = [
        (axes[0, 0], 'adversarial_accuracy', 'o-',
         'Adversarial Accuracy vs Epsilon', 'Adversarial Accuracy (%)'),
        (axes[0, 1], 'attack_success_rate', 's-',
         'Attack Success Rate vs Epsilon', 'Attack Success Rate (%)'),
        (axes[1, 0], 'avg_l2_perturbation', '^-',
         'L2 Perturbation Norm vs Epsilon', 'Average L2 Perturbation'),
    ]
    for ax, column, fmt, title, ylabel in line_panels:
        for i, name in enumerate(model_names):
            model_data = df[df['model'] == name]
            ax.plot(model_data['epsilon'], model_data[column], fmt,
                    color=color_for(i), label=name, linewidth=2, markersize=6)
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Epsilon')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Plot 4: robustness at a single epsilon. np.isclose avoids relying on
    # exact float equality with the stored epsilon values.
    robustness_data = df[np.isclose(df['epsilon'], epsilon_threshold)]

    if not robustness_data.empty:
        ax = axes[1, 1]
        models_rob = list(robustness_data['model'])
        adv_accs = robustness_data['adversarial_accuracy']

        bars = ax.bar(range(len(models_rob)), adv_accs,
                      color=[color_for(i) for i in range(len(models_rob))], alpha=0.7)
        ax.set_title(f'Model Robustness (ε={epsilon_threshold})', fontweight='bold')
        ax.set_xlabel('Models')
        ax.set_ylabel('Adversarial Accuracy (%)')
        ax.set_xticks(range(len(models_rob)))
        ax.set_xticklabels(models_rob, rotation=45)
        ax.grid(True, alpha=0.3)

        # Value labels just above each bar
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height + 1,
                    f'{height:.1f}%', ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

def visualize_adversarial_examples(model, attack_method, dataloader, class_names,
                                  epsilon_values=(0.05, 0.1, 0.2), num_samples=3):
    """Show clean images next to their adversarial versions for several epsilons.

    Takes the first batch of ``dataloader`` and picks ``num_samples`` random
    images from it. Each image gets one row: the original, then one column
    per epsilon. Each adversarial column is titled with the prediction, its
    confidence, the L2 size of the perturbation, and ✓ when the prediction
    differs from the true label.

    Args:
        model: Classifier under attack (switched to eval mode).
        attack_method: Object exposing ``attack(images, labels, epsilon=...)``.
        dataloader: Yields ``(images, labels)`` batches.
        class_names: Sequence mapping class index to display name.
        epsilon_values: Iterable of perturbation budgets to display.
        num_samples: Number of images (rows); capped at the batch size.

    Uses the notebook global ``device``.
    """
    model.eval()

    # Get a batch of data
    images, labels = next(iter(dataloader))
    images, labels = images.to(device), labels.to(device)

    # Never ask for more rows than the batch holds.
    num_samples = min(num_samples, images.size(0))
    if num_samples <= 0:
        print("No samples available to visualize.")
        return

    sample_indices = torch.randperm(images.size(0))[:num_samples]
    sample_images = images[sample_indices]
    sample_labels = labels[sample_indices]

    epsilon_values = list(epsilon_values)
    num_eps = len(epsilon_values)
    # squeeze=False keeps `axes` 2-D for any grid shape (single row or column).
    fig, axes = plt.subplots(num_samples, num_eps + 1,
                             figsize=(4 * (num_eps + 1), 4 * num_samples),
                             squeeze=False)

    # Undo input normalization for display.
    # NOTE(review): assumes the dataset normalizes with these ImageNet
    # statistics; confirm against the dataset transforms.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(device)

    def to_display(batch_of_one):
        """Convert a single-image (1, 3, H, W) batch to an HxWx3 array in [0, 1]."""
        img = (batch_of_one[0].detach() * std + mean).clamp(0, 1)
        return img.permute(1, 2, 0).cpu().numpy()

    for i in range(num_samples):
        original = sample_images[i:i+1]
        target = sample_labels[i:i+1]
        true_label = target.item()

        with torch.no_grad():
            orig_output = model(original)
            orig_pred = orig_output.argmax(dim=1).item()
            orig_conf = F.softmax(orig_output, dim=1).max().item()

        axes[i, 0].imshow(to_display(original))
        axes[i, 0].set_title(f'Original\nTrue: {class_names[true_label]}\n' +
                             f'Pred: {class_names[orig_pred]}\nConf: {orig_conf:.3f}',
                             fontsize=10)
        axes[i, 0].axis('off')

        # Adversarial examples for each epsilon
        for col, epsilon in enumerate(epsilon_values, start=1):
            adv_image = attack_method.attack(original, target, epsilon=epsilon)

            with torch.no_grad():
                adv_output = model(adv_image)
                adv_pred = adv_output.argmax(dim=1).item()
                adv_conf = F.softmax(adv_output, dim=1).max().item()

            l2_norm = torch.norm((adv_image - original).detach().reshape(-1), p=2).item()

            axes[i, col].imshow(to_display(adv_image))

            # ✓ marks a successful attack (prediction differs from true label).
            success = '✓' if adv_pred != true_label else '✗'
            axes[i, col].set_title(f'ε={epsilon}\nPred: {class_names[adv_pred]} {success}\n' +
                                   f'Conf: {adv_conf:.3f}\nL2: {l2_norm:.3f}',
                                   fontsize=10)
            axes[i, col].axis('off')

    plt.tight_layout()
    plt.show()

# Plot FGSM effectiveness curves across models
visualize_attack_results(fgsm_df, "FGSM")

# Show concrete FGSM adversarial examples for SimpleCNN, if it was loaded
if 'SimpleCNN' in models:
    print("\nVisualizing FGSM adversarial examples on SimpleCNN...")
    simple_cnn = models['SimpleCNN']
    fgsm_attack = FGSM(simple_cnn, device=device)
    visualize_adversarial_examples(simple_cnn, fgsm_attack, subset_loader, class_names)

## 4. PGD Attack Implementation and Analysis

Demonstrate the Projected Gradient Descent (PGD) attack.

In [None]:
print("PGD ATTACK ANALYSIS")
print("=" * 40)

# PGD configuration: budgets to sweep plus the iterative-attack settings
pgd_epsilon_values = [0.0, 0.03, 0.05, 0.1, 0.2]
pgd_steps = 10
pgd_step_size = 0.01

pgd_results = {}

for model_name, model in models.items():
    print(f"\nAnalyzing PGD attacks on {model_name}...")
    pgd_attack = PGD(model, device=device, steps=pgd_steps, step_size=pgd_step_size)
    pgd_results[model_name] = evaluate_attack(
        model, pgd_attack, subset_loader,
        f"PGD-{model_name}", pgd_epsilon_values
    )

# Flatten to one row per (model, epsilon); result dicts are tagged in place.
pgd_df_data = []
for model_name, results in pgd_results.items():
    for result in results:
        result.update(model=model_name, attack='PGD')
        pgd_df_data.append(result)

pgd_df = pd.DataFrame(pgd_df_data)

# Plot PGD effectiveness curves across models
visualize_attack_results(pgd_df, "PGD")

# Show concrete PGD adversarial examples for SimpleCNN, if it was loaded
if 'SimpleCNN' in models:
    print("\nVisualizing PGD adversarial examples on SimpleCNN...")
    simple_cnn = models['SimpleCNN']
    pgd_attack = PGD(simple_cnn, device=device, steps=pgd_steps, step_size=pgd_step_size)
    visualize_adversarial_examples(
        simple_cnn, pgd_attack, subset_loader, class_names,
        epsilon_values=[0.05, 0.1, 0.2]
    )

print("\nPGD attack evaluation complete!")

## 5. Attack Comparison and Analysis

Compare FGSM and PGD attacks side by side.

In [None]:
def compare_attacks(fgsm_df, pgd_df, common_epsilons=(0.03, 0.05, 0.1, 0.2),
                    epsilon_compare=0.1):
    """Plot FGSM and PGD results side by side and return the combined table.

    The 2x2 figure shows success rate, adversarial accuracy and average L2
    perturbation vs epsilon. It draws one line per (model, attack) pair,
    restricted to ``common_epsilons``. It also shows a grouped bar chart of
    adversarial accuracy at ``epsilon_compare``.

    Args:
        fgsm_df: Per-(model, epsilon) FGSM results with ``model``/``attack`` columns.
        pgd_df: Same layout for PGD.
        common_epsilons: Epsilons shown in the line plots (previous hard-coded set).
        epsilon_compare: Epsilon used for the bar chart (previous hard-coded 0.1).

    Returns:
        The concatenation of ``fgsm_df`` and ``pgd_df`` (fresh index).
    """
    combined_df = pd.concat([fgsm_df, pgd_df], ignore_index=True)

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('FGSM vs PGD Attack Comparison', fontsize=16, fontweight='bold')

    model_names = combined_df['model'].unique()
    attack_names = combined_df['attack'].unique()
    common = combined_df[combined_df['epsilon'].isin(list(common_epsilons))]

    def _plot_metric(ax, column, title, ylabel):
        """Draw one line per (model, attack) pair of ``column`` vs epsilon."""
        for model in model_names:
            for attack in attack_names:
                data = common[(common['model'] == model) & (common['attack'] == attack)]
                if data.empty:
                    continue
                ax.plot(data['epsilon'], data[column],
                        linestyle='-' if attack == 'FGSM' else '--',
                        marker='o' if model == 'SimpleCNN' else 's',
                        label=f'{model}-{attack}', linewidth=2)
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Epsilon')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Plot 1 and 2: success rate and adversarial accuracy
    _plot_metric(axes[0, 0], 'attack_success_rate',
                 'Attack Success Rate Comparison', 'Attack Success Rate (%)')
    _plot_metric(axes[0, 1], 'adversarial_accuracy',
                 'Adversarial Accuracy Comparison', 'Adversarial Accuracy (%)')

    # Plot 3: model robustness at a fixed epsilon (tolerant float match)
    comparison_data = combined_df[np.isclose(combined_df['epsilon'], epsilon_compare)]

    if not comparison_data.empty:
        x_pos = np.arange(len(model_names))
        # Split a fixed group width among however many attacks are present.
        # For the usual two attacks this gives 0.35 per bar, as before.
        n_attacks = max(len(attack_names), 1)
        bar_width = 0.7 / n_attacks

        for i, attack in enumerate(attack_names):
            attack_data = comparison_data[comparison_data['attack'] == attack]
            values = []
            for model in model_names:
                model_data = attack_data[attack_data['model'] == model]
                values.append(model_data['adversarial_accuracy'].iloc[0]
                              if not model_data.empty else 0)

            axes[1, 0].bar(x_pos + i * bar_width, values, bar_width,
                           label=attack, alpha=0.8)

        axes[1, 0].set_title(f'Model Robustness Comparison (ε={epsilon_compare})',
                             fontweight='bold')
        axes[1, 0].set_xlabel('Models')
        axes[1, 0].set_ylabel('Adversarial Accuracy (%)')
        # Center each tick under its group of bars.
        axes[1, 0].set_xticks(x_pos + bar_width * (n_attacks - 1) / 2)
        axes[1, 0].set_xticklabels(model_names)
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

    # Plot 4: perturbation size
    _plot_metric(axes[1, 1], 'avg_l2_perturbation',
                 'L2 Perturbation Comparison', 'Average L2 Perturbation')

    plt.tight_layout()
    plt.show()

    return combined_df

# Compare attacks
print("\nCOMPARING FGSM AND PGD ATTACKS")
print("=" * 40)

combined_results = compare_attacks(fgsm_df, pgd_df)

# Statistical analysis: per-model FGSM vs PGD success rate at selected budgets
print("\nAttack Effectiveness Analysis:")
print("-" * 30)

for epsilon in [0.05, 0.1, 0.2]:
    print(f"\nAt ε = {epsilon}:")

    epsilon_data = combined_results[combined_results['epsilon'] == epsilon]

    for model in models:
        model_data = epsilon_data[epsilon_data['model'] == model]
        fgsm_data = model_data[model_data['attack'] == 'FGSM']
        pgd_data = model_data[model_data['attack'] == 'PGD']

        # Need both attacks evaluated at this epsilon to compare them
        if fgsm_data.empty or pgd_data.empty:
            continue

        fgsm_success = fgsm_data['attack_success_rate'].iloc[0]
        pgd_success = pgd_data['attack_success_rate'].iloc[0]

        print(f"  {model}:")
        print(f"    FGSM Success Rate: {fgsm_success:.2f}%")
        print(f"    PGD Success Rate: {pgd_success:.2f}%")
        # A difference of two percentages is expressed in percentage points.
        print(f"    PGD Improvement: {pgd_success - fgsm_success:+.2f} pp")

print("\nAttack comparison complete!")

## 6. Transferability Analysis

Analyze the transferability of adversarial examples between models.

In [None]:
print("TRANSFERABILITY ANALYSIS")
print("=" * 40)

if len(models) >= 2:
    # Create transferability analyzer
    transferability_analyzer = TransferabilityAnalyzer(
        models=models,
        device=device
    )

    # Test transferability with both FGSM and PGD at the same budget
    attack_configs = {
        'FGSM': {'method': 'fgsm', 'epsilon': 0.1},
        'PGD': {'method': 'pgd', 'epsilon': 0.1, 'steps': 10, 'step_size': 0.01}
    }

    transferability_results = {}

    for attack_name, attack_config in attack_configs.items():
        print(f"\nAnalyzing {attack_name} transferability...")

        # Run transferability analysis
        transfer_results = transferability_analyzer.analyze_transferability(
            dataloader=subset_loader,
            attack_config=attack_config,
            num_samples=50  # Smaller sample for faster computation
        )

        transferability_results[attack_name] = transfer_results

        # Print cross-model results (native same-model entries are skipped here)
        print(f"\n{attack_name} Transferability Matrix:")
        print("Source → Target | Success Rate")
        print("-" * 30)

        for source_model, target_results in transfer_results.items():
            for target_model, metrics in target_results.items():
                if source_model != target_model:
                    success_rate = metrics['attack_success_rate']
                    print(f"{source_model} → {target_model}: {success_rate:.2f}%")

    def visualize_transferability(transfer_results, attack_name):
        """Plot a source-by-target heatmap of attack success rates.

        Rows are source models (where the adversarial examples were crafted)
        and columns are target models. The diagonal holds the native
        same-model rate and off-diagonal cells hold transfer rates.

        Returns:
            The success-rate matrix as a float NumPy array.
        """
        model_names = list(transfer_results.keys())
        n_models = len(model_names)

        # Native and cross-model cells are read the same way, so one lookup suffices.
        transfer_matrix = np.array(
            [[transfer_results[source][target]['attack_success_rate']
              for target in model_names]
             for source in model_names],
            dtype=float,
        ).reshape(n_models, n_models)

        plt.figure(figsize=(10, 8))
        sns.heatmap(transfer_matrix, annot=True, fmt='.1f', cmap='Reds',
                    xticklabels=[f'Target: {m}' for m in model_names],
                    yticklabels=[f'Source: {m}' for m in model_names],
                    cbar_kws={'label': 'Attack Success Rate (%)'})

        plt.title(f'{attack_name} Transferability Matrix', fontsize=14, fontweight='bold')
        plt.xlabel('Target Model')
        plt.ylabel('Source Model')
        plt.tight_layout()
        plt.show()

        return transfer_matrix

    # Visualize transferability for each attack
    for attack_name, transfer_results in transferability_results.items():
        print(f"\nVisualizing {attack_name} transferability...")
        matrix = visualize_transferability(transfer_results, attack_name)

    # Analyze transferability patterns
    print("\nTransferability Analysis Summary:")
    print("=" * 40)

    for attack_name, transfer_results in transferability_results.items():
        print(f"\n{attack_name} Attack:")

        cross_transfer_rates = []
        native_rates = []

        for source_model, target_results in transfer_results.items():
            for target_model, metrics in target_results.items():
                bucket = native_rates if source_model == target_model else cross_transfer_rates
                bucket.append(metrics['attack_success_rate'])

        if cross_transfer_rates and native_rates:
            avg_cross_transfer = np.mean(cross_transfer_rates)
            avg_native = np.mean(native_rates)

            print(f"  Average native attack success: {avg_native:.2f}%")
            print(f"  Average cross-model transfer: {avg_cross_transfer:.2f}%")
            if avg_native > 0:
                print(f"  Transfer efficiency: {avg_cross_transfer/avg_native:.2f}")
            else:
                # Avoid dividing by zero when the native attack never succeeds.
                print("  Transfer efficiency: n/a (native success rate is 0)")

            if avg_cross_transfer > 50:
                print("  ✅ High transferability - attacks transfer well between models")
            elif avg_cross_transfer > 20:
                print("  ⚠️  Moderate transferability")
            else:
                print("  ❌ Low transferability")

else:
    print("\n⚠️  Need at least 2 models for transferability analysis")
    print("Skipping transferability analysis.")

print("\nTransferability analysis complete!")

## 7. Adversarial Interpretability Analysis

Analyze how adversarial attacks affect model interpretability.

In [None]:
print("ADVERSARIAL INTERPRETABILITY ANALYSIS")
print("=" * 45)

# Prefer SimpleCNN; otherwise fall back to the first loaded model
if 'SimpleCNN' in models:
    analysis_model_name = 'SimpleCNN'
else:
    analysis_model_name = list(models.keys())[0]
analysis_model = models[analysis_model_name]

print(f"Analyzing interpretability for: {analysis_model_name}")

interp_analyzer = AdversarialInterpretabilityAnalyzer(
    model=analysis_model,
    device=device
)

# Take the first three images of one subset batch for detailed analysis
first_images, first_labels = next(iter(subset_loader))
sample_images = first_images[:3].to(device)
sample_labels = first_labels[:3].to(device)

# Mild and strong budgets per attack; PGD step sizes are epsilon / 10
attack_configs = {
    'FGSM_mild': {'method': 'fgsm', 'epsilon': 0.05},
    'FGSM_strong': {'method': 'fgsm', 'epsilon': 0.15},
    'PGD_mild': {'method': 'pgd', 'epsilon': 0.05, 'steps': 10, 'step_size': 0.005},
    'PGD_strong': {'method': 'pgd', 'epsilon': 0.15, 'steps': 10, 'step_size': 0.015}
}

print("\nAnalyzing saliency map changes under adversarial attacks...")

interpretability_results = interp_analyzer.analyze_saliency_changes(
    images=sample_images,
    labels=sample_labels,
    attack_configs=attack_configs,
    class_names=class_names
)

print("Interpretability analysis complete!")

# Create custom visualization for interpretability
def visualize_interpretability_changes(images, labels, results, class_names,
                                       model=None, attack_names=None):
    """Visualize how adversarial attacks affect model interpretability.

    For up to two samples, draw a pair of rows. The top row shows the clean
    image followed by one adversarial image per attack. Each adversarial
    title has the model's prediction and softmax confidence: ✓ means the
    attack changed the prediction away from the true label, ✗ means it did
    not. The bottom row shows the matching saliency maps. Each adversarial
    saliency map is titled with its mean absolute difference from the clean
    saliency map.

    Args:
        images: batch of normalized images; assumed (N, 3, H, W) and
            normalized with ImageNet mean/std (TODO confirm against the
            dataset transforms).
        labels: ground-truth class indices, one per image.
        results: mapping attack name -> dict holding per-sample
            'adversarial_images' and 'adversarial_saliencies'.
        class_names: sequence mapping class index -> display name.
        model: model to explain; defaults to the notebook-level
            ``analysis_model``.
        attack_names: attack columns to draw, in order; defaults to the keys
            of the notebook-level ``attack_configs``.
    """
    if model is None:
        model = analysis_model
    if attack_names is None:
        attack_names = list(attack_configs)

    n_samples = min(2, images.size(0))  # limit to 2 samples for readability
    if n_samples == 0:
        print("No samples to visualize.")
        return
    n_cols = len(attack_names) + 1  # +1 for the clean image

    # squeeze=False keeps `axes` 2-D for every grid size, so axes[row, col]
    # indexing is always valid.
    _, axes = plt.subplots(n_samples * 2, n_cols,
                           figsize=(4 * n_cols, 8 * n_samples), squeeze=False)

    # ImageNet normalization statistics shaped (3, 1, 1) so they broadcast over
    # a (3, H, W) image. They must NOT be squeezed: a (3,) tensor broadcasts
    # over the width axis instead of the channel axis.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)

    def _to_display(img):
        # Undo normalization of one (3, H, W) image -> (H, W, 3) numpy array.
        img = img.detach().to(device)
        return (img * std + mean).clamp(0, 1).permute(1, 2, 0).cpu().numpy()

    def _predict(batch):
        # Return (predicted class index, softmax confidence) for a batch of one.
        with torch.no_grad():
            logits = model(batch.to(device))
        return logits.argmax(dim=1).item(), F.softmax(logits, dim=1).max().item()

    # The explained model is the same for every sample, so build this once.
    saliency_map = SaliencyMap(model, device=device)

    for sample_idx in range(n_samples):
        img_row, sal_row = 2 * sample_idx, 2 * sample_idx + 1
        true_label = labels[sample_idx].item()
        original_img = images[sample_idx:sample_idx + 1]  # keep batch dim

        # Column 0: clean image and its saliency map.
        orig_pred, orig_conf = _predict(original_img)
        axes[img_row, 0].imshow(_to_display(original_img[0]))
        axes[img_row, 0].set_title(
            f'Original\nTrue: {class_names[true_label]}\n'
            f'Pred: {class_names[orig_pred]}\nConf: {orig_conf:.3f}',
            fontsize=10,
        )
        axes[img_row, 0].axis('off')

        orig_saliency, _, _ = saliency_map.generate(original_img)
        axes[sal_row, 0].imshow(orig_saliency, cmap='hot')
        axes[sal_row, 0].set_title('Original Saliency', fontsize=10)
        axes[sal_row, 0].axis('off')

        # Remaining columns: one adversarial example (+ saliency) per attack.
        for col_idx, attack_name in enumerate(attack_names, start=1):
            attack_result = results.get(attack_name)
            if attack_result is None or sample_idx >= len(attack_result['adversarial_images']):
                # No data for this cell: hide the empty axes instead of bare ticks.
                axes[img_row, col_idx].axis('off')
                axes[sal_row, col_idx].axis('off')
                continue

            adv_img = attack_result['adversarial_images'][sample_idx]
            adv_saliency = attack_result['adversarial_saliencies'][sample_idx]
            adv_pred, adv_conf = _predict(adv_img.unsqueeze(0))

            success = '✓' if adv_pred != true_label else '✗'
            axes[img_row, col_idx].imshow(_to_display(adv_img))
            axes[img_row, col_idx].set_title(
                f'{attack_name}\nPred: {class_names[adv_pred]} {success}\nConf: {adv_conf:.3f}',
                fontsize=10,
            )
            axes[img_row, col_idx].axis('off')

            axes[sal_row, col_idx].imshow(adv_saliency, cmap='hot')
            # Mean absolute per-pixel change relative to the clean saliency map.
            saliency_diff = np.abs(adv_saliency - orig_saliency).mean()
            axes[sal_row, col_idx].set_title(
                f'{attack_name} Saliency\nDiff: {saliency_diff:.3f}',
                fontsize=10,
            )
            axes[sal_row, col_idx].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize interpretability changes
if interpretability_results:
    print("\nVisualizing interpretability changes...")
    visualize_interpretability_changes(
        sample_images, sample_labels, interpretability_results, class_names
    )

# Quantitative analysis of interpretability changes
print("\nQuantitative Interpretability Analysis:")
print("-" * 40)

if interpretability_results:
    for attack_name, attack_results in interpretability_results.items():
        saliency_diffs = attack_results.get('saliency_differences')
        # Skip attacks with no recorded differences: np.max / np.min raise
        # ValueError on an empty sequence.
        if saliency_diffs is None or len(saliency_diffs) == 0:
            continue

        avg_diff = np.mean(saliency_diffs)
        std_diff = np.std(saliency_diffs)

        print(f"\n{attack_name}:")
        print(f"  Average saliency change: {avg_diff:.4f} ± {std_diff:.4f}")
        print(f"  Max saliency change: {np.max(saliency_diffs):.4f}")
        print(f"  Min saliency change: {np.min(saliency_diffs):.4f}")

        # Heuristic bands for the mean saliency change.
        if avg_diff > 0.1:
            print("  ⚠️  High interpretability change")
        elif avg_diff > 0.05:
            print("  ⚠️  Moderate interpretability change")
        else:
            print("  ✅ Low interpretability change")

print("\nInterpretability analysis complete!")

## 8. Robustness Evaluation and Defense Strategies

Evaluate model robustness and discuss potential defense strategies.

In [None]:
# Section banner: title, then an underline of 40 '=' characters.
print("MODEL ROBUSTNESS EVALUATION", "=" * 40, sep="\n")

def comprehensive_robustness_evaluation(models, dataloader, class_names):
    """Comprehensive evaluation of model robustness under FGSM and PGD.

    Each model is evaluated on ``dataloader`` without perturbation. It is then
    evaluated under six attack configurations: weak/medium/strong FGSM and PGD
    with epsilon 0.03 / 0.1 / 0.2. PGD uses 10 steps of size epsilon / 10.
    Tensors are moved to the notebook-level ``device``.

    Args:
        models: mapping model name -> model.
        dataloader: iterable of ``(images, labels)`` batches.
        class_names: accepted for interface compatibility; not used.

    Returns:
        dict mapping model name -> dict with ``'clean_accuracy'`` (percent).
        It also has one dict per attack configuration holding:
        ``'adversarial_accuracy'`` (percent), ``'robustness_score'``
        (adversarial accuracy / clean accuracy) and ``'attack_success_rate'``
        (percent of originally-correct samples that the attack misclassifies).
    """
    robustness_results = {}

    # Test multiple attack configurations
    test_configs = {
        'weak_fgsm': {'attack': 'fgsm', 'epsilon': 0.03},
        'medium_fgsm': {'attack': 'fgsm', 'epsilon': 0.1},
        'strong_fgsm': {'attack': 'fgsm', 'epsilon': 0.2},
        'weak_pgd': {'attack': 'pgd', 'epsilon': 0.03, 'steps': 10, 'step_size': 0.003},
        'medium_pgd': {'attack': 'pgd', 'epsilon': 0.1, 'steps': 10, 'step_size': 0.01},
        'strong_pgd': {'attack': 'pgd', 'epsilon': 0.2, 'steps': 10, 'step_size': 0.02},
    }

    for model_name, model in models.items():
        print(f"\nEvaluating robustness of {model_name}...")

        model_results = {}
        model.eval()

        # Clean accuracy over the whole loader.
        clean_correct = 0
        total_samples = 0
        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)
                predictions = model(images).argmax(dim=1)
                clean_correct += (predictions == labels).sum().item()
                total_samples += labels.size(0)

        # An empty loader yields 0% instead of raising ZeroDivisionError.
        clean_accuracy = 100.0 * clean_correct / total_samples if total_samples > 0 else 0.0
        model_results['clean_accuracy'] = clean_accuracy

        print(f"  Clean accuracy: {clean_accuracy:.2f}%")

        # Test each attack configuration
        for config_name, config in test_configs.items():
            print(f"    Testing {config_name}...")

            epsilon = config['epsilon']
            if config['attack'] == 'fgsm':
                attack_method = FGSM(model, device=device)
            else:  # pgd
                attack_method = PGD(model, device=device,
                                    steps=config['steps'],
                                    step_size=config['step_size'])

            adv_correct = 0
            # Per-sample bookkeeping for the success rate. Clean and
            # adversarial predictions are compared on the *same* batch, so the
            # result stays correct even if the loader reshuffles between passes.
            clean_hits_total = 0  # samples classified correctly before the attack
            flipped = 0           # ...of those, samples the attack misclassifies

            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)

                with torch.no_grad():
                    clean_hits = model(images).argmax(dim=1) == labels

                # Generate adversarial examples
                adv_images = attack_method.attack(images, labels, epsilon=epsilon)

                with torch.no_grad():
                    adv_hits = model(adv_images).argmax(dim=1) == labels

                adv_correct += adv_hits.sum().item()
                clean_hits_total += clean_hits.sum().item()
                flipped += (clean_hits & ~adv_hits).sum().item()

            adv_accuracy = 100.0 * adv_correct / total_samples if total_samples > 0 else 0.0
            robustness_score = adv_accuracy / clean_accuracy if clean_accuracy > 0 else 0
            # Count only samples that were right before and wrong after. A
            # net-accuracy-drop formula lets clean-misclassified samples that
            # the attack happens to "fix" offset real successes, and can even
            # go negative.
            attack_success_rate = 100.0 * flipped / clean_hits_total if clean_hits_total > 0 else 0

            model_results[config_name] = {
                'adversarial_accuracy': adv_accuracy,
                'robustness_score': robustness_score,
                'attack_success_rate': attack_success_rate,
            }

            print(f"      Adv accuracy: {adv_accuracy:.2f}%")
            print(f"      Robustness score: {robustness_score:.3f}")

        robustness_results[model_name] = model_results

    return robustness_results

# Run the full FGSM/PGD robustness sweep for every loaded model on the
# evaluation subset.
robustness_results = comprehensive_robustness_evaluation(models, subset_loader, class_names)

# Visualize robustness comparison
def visualize_robustness_comparison(robustness_results):
    """Draw a 2x2 dashboard comparing model robustness across attacks.

    Panels, in row-major order:
      1. heatmap of the robustness score per (model, attack configuration);
         configurations missing from a model's results are drawn as 0;
      2. bars of each model's robustness score averaged over the
         configurations it has;
      3. grouped bars of mean FGSM vs. mean PGD attack success rate;
      4. scatter of average robustness score against clean accuracy.

    Args:
        robustness_results: mapping model name -> dict with a
            'clean_accuracy' entry plus one dict per attack configuration
            holding 'robustness_score' and 'attack_success_rate'.
    """
    fgsm_keys = ['weak_fgsm', 'medium_fgsm', 'strong_fgsm']
    pgd_keys = ['weak_pgd', 'medium_pgd', 'strong_pgd']
    config_keys = fgsm_keys + pgd_keys
    names = list(robustness_results.keys())

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Model Robustness Comparison', fontsize=16, fontweight='bold')
    (ax_heat, ax_avg), (ax_rate, ax_tradeoff) = axes

    # --- Panel 1: robustness-score heatmap (missing configs shown as 0) ---
    score_grid = np.array([
        [robustness_results[name][key]['robustness_score']
         if key in robustness_results[name] else 0
         for key in config_keys]
        for name in names
    ])
    heat = ax_heat.imshow(score_grid, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
    ax_heat.set_title('Robustness Score Matrix', fontweight='bold')
    ax_heat.set_xticks(range(len(config_keys)))
    ax_heat.set_xticklabels(config_keys, rotation=45, ha='right')
    ax_heat.set_yticks(range(len(names)))
    ax_heat.set_yticklabels(names)
    for (row, col), value in np.ndenumerate(score_grid):
        ax_heat.text(col, row, f'{value:.2f}', ha='center', va='center', fontsize=8)
    plt.colorbar(heat, ax=ax_heat, label='Robustness Score')

    # --- Panel 2: mean robustness per model over the configs it has ---
    mean_scores = [
        np.mean([robustness_results[name][key]['robustness_score']
                 for key in config_keys
                 if key in robustness_results[name]])
        for name in names
    ]
    bars = ax_avg.bar(names, mean_scores, alpha=0.7)
    ax_avg.set_title('Average Robustness Score', fontweight='bold')
    ax_avg.set_ylabel('Average Robustness Score')
    ax_avg.set_ylim(0, 1)
    ax_avg.grid(True, alpha=0.3)
    for bar, value in zip(bars, mean_scores):
        ax_avg.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')

    # --- Panel 3: mean attack success rate, FGSM vs PGD ---
    def _mean_rate(name, keys):
        # Average success rate over the listed configs present for `name` (0 if none).
        rates = [robustness_results[name][key]['attack_success_rate']
                 for key in keys
                 if key in robustness_results[name]]
        return np.mean(rates) if rates else 0

    fgsm_rates = [_mean_rate(name, fgsm_keys) for name in names]
    pgd_rates = [_mean_rate(name, pgd_keys) for name in names]
    positions = np.arange(len(names))
    width = 0.35

    ax_rate.bar(positions - width / 2, fgsm_rates, width, label='FGSM', alpha=0.8)
    ax_rate.bar(positions + width / 2, pgd_rates, width, label='PGD', alpha=0.8)
    ax_rate.set_title('Average Attack Success Rate', fontweight='bold')
    ax_rate.set_xlabel('Models')
    ax_rate.set_ylabel('Attack Success Rate (%)')
    ax_rate.set_xticks(positions)
    ax_rate.set_xticklabels(names)
    ax_rate.legend()
    ax_rate.grid(True, alpha=0.3)

    # --- Panel 4: robustness vs clean accuracy trade-off ---
    clean_accs = [robustness_results[name]['clean_accuracy'] for name in names]
    ax_tradeoff.scatter(clean_accs, mean_scores, s=100, alpha=0.7)
    for name, acc, score in zip(names, clean_accs, mean_scores):
        ax_tradeoff.annotate(name, (acc, score),
                             xytext=(5, 5), textcoords='offset points')
    ax_tradeoff.set_title('Robustness vs Clean Accuracy', fontweight='bold')
    ax_tradeoff.set_xlabel('Clean Accuracy (%)')
    ax_tradeoff.set_ylabel('Average Robustness Score')
    ax_tradeoff.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Visualize robustness comparison
visualize_robustness_comparison(robustness_results)

# Print robustness summary
print("\nROBUSTNESS EVALUATION SUMMARY")
print("=" * 40)

for model_name, results in robustness_results.items():
    print(f"\n{model_name}:")
    print(f"  Clean accuracy: {results['clean_accuracy']:.2f}%")

    # Average robustness over every attack configuration (every key except
    # the clean-accuracy entry).
    robustness_scores = [results[config]['robustness_score']
                         for config in results if config != 'clean_accuracy']
    avg_robustness = np.mean(robustness_scores)

    print(f"  Average robustness score: {avg_robustness:.3f}")

    if avg_robustness > 0.7:
        print("  ✅ High robustness")
    elif avg_robustness > 0.4:
        print("  ⚠️  Moderate robustness")
    else:
        print("  ❌ Low robustness")

print("\nRobustness evaluation complete!")

## 9. Summary and Conclusions

Summarize findings and provide recommendations for improving adversarial robustness.

In [None]:
print("\n" + "=" * 80)
print("ADVERSARIAL ATTACKS ANALYSIS - SUMMARY AND CONCLUSIONS")
print("=" * 80)

# Key findings from the analysis
print("\n🔍 KEY FINDINGS:")
print("-" * 20)

# 1. Attack Effectiveness
print("\n1. Attack Effectiveness:")
# `combined_results` comes from an earlier cell and may be missing if that
# cell was skipped (at notebook top level, locals() is the global namespace).
if 'combined_results' in locals():
    # Find the most effective attack among the larger budgets (epsilon >= 0.1).
    high_eps_data = combined_results[combined_results['epsilon'] >= 0.1]
    if not high_eps_data.empty:
        best_attack = high_eps_data.loc[high_eps_data['attack_success_rate'].idxmax()]
        print(f"   • Most effective attack: {best_attack['attack']} on {best_attack['model']}")
        print(f"   • Success rate: {best_attack['attack_success_rate']:.1f}% at ε={best_attack['epsilon']}")

        # Compare FGSM vs PGD, averaged over every epsilon and model.
        fgsm_avg = combined_results[combined_results['attack'] == 'FGSM']['attack_success_rate'].mean()
        pgd_avg = combined_results[combined_results['attack'] == 'PGD']['attack_success_rate'].mean()

        print(f"   • Average FGSM success rate: {fgsm_avg:.1f}%")
        print(f"   • Average PGD success rate: {pgd_avg:.1f}%")

        # Both averages are percentages, so their gap is in percentage points.
        if pgd_avg > fgsm_avg:
            print(f"   ✓ PGD success rate is {pgd_avg - fgsm_avg:.1f} percentage points higher than FGSM")
        else:
            print("   ⚠️ FGSM performs similarly to or better than PGD")

# 2. Model Robustness
print("\n2. Model Robustness Ranking:")
if robustness_results:
    # Overall robustness per model = mean robustness score over all attack
    # configurations (every key except 'clean_accuracy').
    model_robustness = {}
    for model_name, results in robustness_results.items():
        scores = [results[config]['robustness_score']
                  for config in results if config != 'clean_accuracy']
        model_robustness[model_name] = np.mean(scores)

    # Sort by robustness, most robust first
    sorted_models = sorted(model_robustness.items(), key=lambda x: x[1], reverse=True)

    for i, (model, score) in enumerate(sorted_models, 1):
        clean_acc = robustness_results[model]['clean_accuracy']
        print(f"   {i}. {model}: {score:.3f} robustness (Clean: {clean_acc:.1f}%)")

    # Balance score = clean accuracy (as a fraction) x average robustness.
    balance_scores = {}
    for model_name, results in robustness_results.items():
        clean_acc = results['clean_accuracy']
        robust_score = model_robustness[model_name]
        balance_scores[model_name] = (clean_acc / 100) * robust_score

    best_balance = max(balance_scores.items(), key=lambda x: x[1])
    print(f"   ✓ Best accuracy-robustness balance: {best_balance[0]} (score: {best_balance[1]:.3f})")

# 3. Transferability
if 'transferability_results' in locals() and transferability_results:
    print("\n3. Transferability Insights:")

    for attack_name, transfer_results in transferability_results.items():
        # Success rates of adversarial examples crafted on one model and
        # evaluated on a different one (same-model pairs are excluded).
        cross_transfer_rates = []

        for source_model, target_results in transfer_results.items():
            for target_model, metrics in target_results.items():
                if source_model != target_model:
                    cross_transfer_rates.append(metrics['attack_success_rate'])

        if cross_transfer_rates:
            avg_transfer = np.mean(cross_transfer_rates)
            print(f"   • {attack_name} average transferability: {avg_transfer:.1f}%")

            if avg_transfer > 50:
                print("     ⚠️ High transferability - models share vulnerabilities")
            else:
                print("     ✓ Limited transferability")

# 4. Interpretability Impact
if 'interpretability_results' in locals() and interpretability_results:
    print("\n4. Interpretability Impact:")

    for attack_name, attack_results in interpretability_results.items():
        if 'saliency_differences' in attack_results:
            avg_diff = np.mean(attack_results['saliency_differences'])
            print(f"   • {attack_name} saliency change: {avg_diff:.3f}")

            if avg_diff > 0.1:
                print("     ⚠️ Significant interpretability disruption")
            else:
                print("     ✓ Moderate interpretability impact")

# Security Implications
print("\n🚨 SECURITY IMPLICATIONS:")
print("-" * 30)

print("\n1. Attack Feasibility:")
print("   • FGSM attacks can be generated quickly with minimal computation")
print("   • PGD attacks are more powerful but require more computational resources")
print("   • Small perturbations (ε < 0.1) can significantly reduce accuracy")
print("   • Attacks can fool models while remaining visually imperceptible")

print("\n2. Real-World Risks:")
print("   • Sports classification systems could misclassify images")
print("   • Potential for bypassing content filtering systems")
print("   • Risk in automated sports analysis and broadcasting")

if 'transferability_results' in locals() and transferability_results:
    print("\n3. Cross-Model Vulnerabilities:")
    # True when, for any attack, the mean cross-model (source != target)
    # success rate exceeds 30%. An attack with no cross-model pairs gives
    # np.mean([]) == nan, which compares False and so never triggers.
    if any(np.mean([transfer_results[source][target]['attack_success_rate']
                    for source in transfer_results
                    for target in transfer_results[source]
                    if source != target]) > 30
           for transfer_results in transferability_results.values()):
        print("   ⚠️ Models share common vulnerabilities")
        print("   • Black-box attacks possible using surrogate models")
        print("   • Need for diverse training approaches")
    else:
        print("   ✓ Limited cross-model vulnerability transfer")

# Defense Recommendations
print("\n🛡️  DEFENSE RECOMMENDATIONS:")
print("-" * 35)

print("\n1. Adversarial Training:")
print("   • Train models with adversarial examples in the training set")
print("   • Use multiple attack types (FGSM, PGD, C&W) during training")
print("   • Gradually increase attack strength during training")
print("   • Expected improvement: 20-40% robustness increase")

print("\n2. Data Augmentation:")
print("   • Add Gaussian noise during training")
print("   • Use random transformations beyond current augmentations")
print("   • Implement mixup and cutmix techniques")
print("   • Expected improvement: 10-20% robustness increase")

print("\n3. Model Architecture Improvements:")
print("   • Use certified defense layers")
print("   • Implement defensive distillation")
print("   • Add batch normalization and dropout for regularization")
print("   • Consider ensemble methods for improved robustness")

print("\n4. Input Preprocessing:")
print("   • Apply image denoising filters")
print("   • Use JPEG compression to remove small perturbations")
print("   • Implement randomized smoothing")
print("   • Note: May slightly reduce clean accuracy")

print("\n5. Detection Mechanisms:")
print("   • Monitor prediction confidence scores")
print("   • Detect unusual activation patterns")
print("   • Use statistical tests on model outputs")
print("   • Implement uncertainty quantification")

# Future Work
print("\n🔬 FUTURE RESEARCH DIRECTIONS:")
print("-" * 35)

print("\n1. Advanced Attacks:")
print("   • Semantic adversarial attacks")
print("   • Physical adversarial examples")
print("   • Adaptive attacks against specific defenses")
print("   • Universal adversarial perturbations")

print("\n2. Robust Training Methods:")
print("   • Certified adversarial training")
print("   • Distributionally robust optimization")
print("   • Meta-learning for adversarial robustness")
print("   • Self-supervised robust pretraining")

print("\n3. Evaluation Metrics:")
print("   • Develop better robustness metrics")
print("   • Study robustness-accuracy trade-offs")
print("   • Analyze robustness across different domains")

# Implementation Priority
print("\n📋 IMPLEMENTATION PRIORITY:")
print("-" * 30)

if robustness_results:
    # Mean over models of each model's mean robustness score across all
    # attack configurations (every key except 'clean_accuracy').
    avg_robustness = np.mean([np.mean([results[config]['robustness_score']
                                       for config in results if config != 'clean_accuracy'])
                              for results in robustness_results.values()])

    print(f"\nCurrent average robustness: {avg_robustness:.3f}")

    if avg_robustness < 0.3:
        print("\n🚨 HIGH PRIORITY - Low robustness detected:")
        print("   1. Implement adversarial training immediately")
        print("   2. Add input preprocessing defenses")
        print("   3. Implement attack detection")
    elif avg_robustness < 0.6:
        print("\n⚠️ MEDIUM PRIORITY - Moderate robustness:")
        print("   1. Enhance data augmentation")
        print("   2. Experiment with adversarial training")
        print("   3. Consider ensemble methods")
    else:
        print("\n✅ LOW PRIORITY - Good robustness:")
        print("   1. Fine-tune existing defenses")
        print("   2. Test against stronger attacks")
        print("   3. Focus on maintaining clean accuracy")

print("\n" + "=" * 80)
print("📊 Analysis complete! Use these insights to improve model robustness.")
print("🔒 Remember: Security is an ongoing process, not a one-time fix.")
print("=" * 80)