# Gradient Inversion Attack on CIFAR-10

This notebook demonstrates how to:
1. Train a ResNet20-4 model on CIFAR-10
2. Perform a gradient inversion attack that reconstructs private input images from the gradients they produce
3. Evaluate the privacy leakage in federated learning scenarios

## 1. Setup and Imports

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import time
import datetime
from collections import defaultdict

import inversefed
from inversefed.data.loss import Classification

# Set style for better plots
plt.style.use('seaborn-v0_8-darkgrid')
%matplotlib inline

## 2. Configuration

In [None]:
# Configuration
DATASET_DIR = './datasets'  # Local dataset directory
CIFAR10_PATH = os.path.join(DATASET_DIR, 'cifar10')
MODEL_NAME = 'ResNet20-4'  # Options: 'ConvNet', 'ResNet20-4', 'ResNet32-10'
MODEL_SAVE_PATH = os.path.join('models', f'{MODEL_NAME.lower()}_cifar10.pth')
EPOCHS = 20  # Number of training epochs
SEED = 42  # Global RNG seed for torch and numpy
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global RNGs so randomness drawn through torch/numpy in this notebook
# (weight init, the attack's random restarts, ...) is repeatable on Restart & Run All.
# Library helpers may still pick their own seeds (e.g. construct_model returns a model_seed).
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Using device: {DEVICE}")
print(f"Model: {MODEL_NAME}")
print(f"Training epochs: {EPOCHS}")
print(f"Random seed: {SEED}")

## 3. Dataset Preparation

In [None]:
def download_cifar10_if_needed(root=None):
    """Make sure both CIFAR-10 splits are available locally, downloading them if not.

    Args:
        root: Dataset directory handed to ``torchvision.datasets.CIFAR10``.
            Defaults to ``CIFAR10_PATH``.

    The presence check looks for torchvision's extracted archive folder rather
    than ``root`` itself: ``root`` is created before the download starts, so a
    download that failed part-way would otherwise leave an empty directory that
    is mistaken for a complete dataset on every later run.
    """
    root = CIFAR10_PATH if root is None else root
    extracted_dir = os.path.join(root, torchvision.datasets.CIFAR10.base_folder)
    if os.path.isdir(extracted_dir):
        print(f"CIFAR-10 dataset found at {root}")
        return

    print(f"CIFAR-10 dataset not found at {root}")
    print("Downloading CIFAR-10 dataset...")
    os.makedirs(root, exist_ok=True)
    # Instantiating each split with download=True fetches and extracts the archive
    for is_train in (True, False):
        torchvision.datasets.CIFAR10(root=root, train=is_train, download=True)
    print("CIFAR-10 dataset downloaded successfully!")


# Download dataset if needed
download_cifar10_if_needed()

## 4. Visualization Utilities

In [None]:
# Per-channel normalization statistics of the dataset, given two trailing singleton
# dimensions so they broadcast over (N, C, H, W) batches when images are denormalized.
dm = torch.as_tensor(inversefed.consts.cifar10_mean).unsqueeze(1).unsqueeze(1)
ds = torch.as_tensor(inversefed.consts.cifar10_std).unsqueeze(1).unsqueeze(1)

def plot_images(tensor, title="Images", save_path=None, nrow=8):
    """Denormalize and display a batch of images, optionally saving the figure.

    Args:
        tensor: Images in the dataset's normalized space, shaped (N, C, H, W);
            a single (C, H, W) image is also accepted.
        title: Figure title.
        save_path: If given, the figure is also saved to this path (150 dpi).
        nrow: Maximum number of images per row; larger batches wrap onto extra
            rows instead of being silently dropped.

    Raises:
        ValueError: If ``tensor`` contains no images.
    """
    images = tensor.detach().cpu()
    if images.dim() == 3:
        images = images.unsqueeze(0)
    n_images = images.shape[0]
    if n_images == 0:
        raise ValueError("plot_images needs at least one image")
    # Undo the dataset normalization (ds/dm broadcast over channels) for display
    images = torch.clamp(images * ds + dm, 0, 1)

    if n_images == 1:
        plt.figure(figsize=(5, 5))
        plt.imshow(images[0].permute(1, 2, 0))
        plt.title(title)
        plt.axis('off')
    else:
        n_cols = max(1, min(n_images, nrow))
        n_rows = -(-n_images // n_cols)  # ceiling division
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(2 * n_cols, 2.5 * n_rows), squeeze=False
        )
        for i, ax in enumerate(axes.flat):
            if i < n_images:
                ax.imshow(images[i].permute(1, 2, 0))
                ax.set_title(f"Image {i}")
            ax.axis('off')  # also hides unused panels in the last row
        fig.suptitle(title)
        plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

def plot_comparison(original, reconstructed, title="Comparison"):
    """Show ground-truth images (top row) above their reconstructions (bottom row).

    Args:
        original: Ground-truth images in the normalized space, shaped (N, C, H, W).
        reconstructed: Reconstructions in the same space; column ``k`` pairs
            ``original[k]`` with ``reconstructed[k]``.
        title: Figure title.
    """
    def _denormalize(batch):
        # Back to display range: undo mean/std normalization, then clip to [0, 1]
        return torch.clamp(batch.clone().detach().cpu() * ds + dm, 0, 1)

    rows = (
        ("Original", _denormalize(original)),
        ("Reconstructed", _denormalize(reconstructed)),
    )
    n_cols = rows[0][1].shape[0]
    # squeeze=False keeps a 2-D axes grid even when there is only one column
    fig, grid = plt.subplots(2, n_cols, figsize=(2 * n_cols, 4), squeeze=False)

    for row_axes, (caption, batch) in zip(grid, rows):
        for col, ax in enumerate(row_axes):
            ax.imshow(batch[col].permute(1, 2, 0))
            ax.set_title(f"{caption} {col}")
            ax.axis('off')

    fig.suptitle(title)
    plt.tight_layout()
    plt.show()

## 5. Model Training

### 5.1 Setup Training

In [None]:
# System setup: keyword settings (including 'device') that later cells pass to `.to(**setup)`
setup = inversefed.utils.system_startup()

# Training configuration: inversefed's 'conservative' recipe with our epoch count
defs = inversefed.training_strategy('conservative')
defs.epochs = EPOCHS

# Load data: loss function plus train/validation loaders built from the local CIFAR-10 copy
loss_fn, trainloader, validloader = inversefed.construct_dataloaders(
    'CIFAR10', defs, data_path=CIFAR10_PATH,
)

print(
    f"Training samples: {len(trainloader.dataset):,}",
    f"Validation samples: {len(validloader.dataset):,}",
    f"Batch size: {defs.batch_size}",
    f"Number of batches: {len(trainloader)}",
    sep="\n",
)

### 5.2 Create Model

In [None]:
# Build the architecture named by MODEL_NAME (10 classes, 3 input channels) and move it to the device
model, model_seed = inversefed.construct_model(MODEL_NAME, num_classes=10, num_channels=3)
model.to(**setup)

# Parameter counts: every parameter, and only those with requires_grad set
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)

print(
    f"Model: {MODEL_NAME}",
    f"Total parameters: {total_params:,}",
    f"Trainable parameters: {trainable_params:,}",
    f"Model seed: {model_seed}",
    sep="\n",
)

### 5.3 Train or Load Model

In [None]:
# Reuse a cached checkpoint when one exists; otherwise train from scratch and cache it.
if os.path.exists(MODEL_SAVE_PATH):
    print(f"Loading existing model from {MODEL_SAVE_PATH}")
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint file cannot run arbitrary code when it is loaded.
    state_dict = torch.load(MODEL_SAVE_PATH, map_location=setup['device'], weights_only=True)
    model.load_state_dict(state_dict)
else:
    print(f"Training new model for {EPOCHS} epochs...")
    print("This may take a while...")

    # Train the model
    training_stats = inversefed.train(
        model, loss_fn, trainloader, validloader, defs, setup=setup
    )

    # Save the model; makedirs('') would fail when the path has no directory part
    save_dir = os.path.dirname(MODEL_SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"\nModel saved to {MODEL_SAVE_PATH}")

    # Plot training curves.
    # NOTE(review): validation metrics may be recorded less often than every epoch
    # (depends on inversefed's validation schedule); if so, the 'Valid' curve's x-axis
    # is not in epochs — confirm before comparing the two curves point by point.
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    for ax, metric, ylabel, panel_title in (
        (ax_loss, 'losses', 'Loss', 'Training Loss'),
        (ax_acc, 'Accuracy', 'Accuracy', 'Training Accuracy'),
    ):
        ax.plot(training_stats[f'train_{metric}'], label='Train')
        ax.plot(training_stats[f'valid_{metric}'], label='Valid')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(panel_title)
        ax.legend()

    plt.tight_layout()
    plt.show()

### 5.4 Evaluate Model

In [None]:
# Evaluate the model on the validation split: overall and per-class accuracy.
model.eval()

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)

correct = 0
total = 0
# Per-class hit/sample counts accumulated on the device with bincount, instead of a
# Python loop that calls `.item()` (a host sync) for every validation image.
class_correct_counts = torch.zeros(num_classes, dtype=torch.long, device=setup['device'])
class_total_counts = torch.zeros(num_classes, dtype=torch.long, device=setup['device'])

with torch.no_grad():
    for images, labels in validloader:
        images = images.to(**setup)
        labels = labels.to(device=setup['device'], dtype=torch.long)
        predicted = model(images).argmax(dim=1)
        hits = predicted == labels
        total += labels.size(0)
        correct += hits.sum().item()
        class_total_counts += torch.bincount(labels, minlength=num_classes)
        class_correct_counts += torch.bincount(labels[hits], minlength=num_classes)

# Plain Python lists of floats, one entry per class
class_correct = [float(c) for c in class_correct_counts.tolist()]
class_total = [float(c) for c in class_total_counts.tolist()]

if total == 0:
    raise RuntimeError("Validation loader yielded no samples; cannot compute accuracy.")
print(f"Overall accuracy: {100 * correct / total:.2f}%\n")

# Exactly one entry per class (NaN when a class has no samples) so that class_accs[k]
# always refers to classes[k]; later cells index class_accs by label.
class_accs = []
for name, n_correct, n_total in zip(classes, class_correct, class_total):
    if n_total > 0:
        acc = 100 * n_correct / n_total
        print(f"{name:8s}: {acc:5.2f}%")
    else:
        acc = float('nan')
    class_accs.append(acc)

# Plot class accuracies
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(classes, class_accs)
ax.set_xlabel('Class')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Per-Class Accuracy')
ax.tick_params(axis='x', labelrotation=45)
plt.tight_layout()
plt.show()

## 6. Gradient Inversion Attack

### 6.1 Select Target Image

In [None]:
# Preview a fixed handful of validation images (and their labels) before picking a target
sample_indices = [42, 100, 200, 300, 400, 500, 600, 700]
sample_pairs = [validloader.dataset[i] for i in sample_indices]
sample_images = torch.stack([image for image, _ in sample_pairs])
sample_labels = [target for _, target in sample_pairs]

print("Sample images from validation set:")
print(f"Labels: {[classes[k] for k in sample_labels]}")
plot_images(sample_images, title="Sample Validation Images")

In [None]:
# Choose the single image whose gradient the attacker will try to invert.
# You can change this to any index from 0-9999
TARGET_ID = 42  # Change this to select different images

# For better results, you might want to pick from well-classified classes
# Look at the per-class accuracy above and choose accordingly

ground_truth, true_label = validloader.dataset[TARGET_ID]
# Prepend a batch dimension: the model and the reconstructor work on (1, 3, 32, 32) batches
ground_truth = ground_truth[None].to(**setup)
labels = torch.tensor([true_label], device=setup['device'])

print(f"Target image index: {TARGET_ID}")
print(f"Target class: {classes[true_label]}")
print(f"Class accuracy: {class_accs[true_label]:.2f}%")

# Show the target image
plot_images(ground_truth, title=f"Target Image - {classes[true_label]}")

### 6.2 Compute Gradients

In [None]:
# Simulate the update a federated-learning client would share: the gradient of the
# loss on its private example with respect to every model parameter.
model.zero_grad()
loss_fn = Classification()
outputs = model(ground_truth)
loss, _, _ = loss_fn(outputs, labels)
input_gradient = [grad.detach() for grad in torch.autograd.grad(loss, model.parameters())]

# Norm of each gradient tensor (one entry per parameter tensor, not per layer)
per_tensor_norms = torch.stack([g.norm() for g in input_gradient])
gradient_norms = per_tensor_norms.tolist()
mean_tensor_norm = per_tensor_norms.mean().item()
# True L2 norm of the whole flattened gradient: sqrt of the sum of squared tensor norms
global_norm = per_tensor_norms.norm().item()

print(f"Loss value: {loss.item():.4f}")
print(f"Average per-tensor gradient norm: {mean_tensor_norm:.4f}")
print(f"Global gradient L2 norm: {global_norm:.4f}")
print(f"Number of gradient tensors: {len(input_gradient)}")
print(f"Total gradient parameters: {sum(g.numel() for g in input_gradient):,}")

# Plot gradient norms by parameter tensor
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(range(len(gradient_norms)), gradient_norms)
ax.set_xlabel('Parameter tensor index')
ax.set_ylabel('Gradient Norm')
ax.set_title('Gradient Norms by Parameter Tensor')
plt.tight_layout()
plt.show()

### 6.3 Configure and Run Attack

In [None]:
# Hyper-parameters for the single-image reconstruction (consumed by GradientReconstructor)
config = {
    'signed': True,
    'boxed': True,
    'cost_fn': 'sim',  # Options: 'sim', 'l2'
    'indices': 'def',  # Options: 'def', 'top10', 'top50', 'all'
    'weights': 'equal',
    'lr': 0.1,
    'optim': 'adam',  # Options: 'adam', 'sgd', 'LBFGS'
    'restarts': 8,  # Number of random restarts
    'max_iterations': 8000,
    'total_variation': 1e-6,
    'init': 'randn',
    'filter': 'none',
    'lr_decay': True,
    'scoring_choice': 'loss',
}

print("Attack Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in config.items()))

In [None]:
# Launch the single-image gradient inversion and time it.
print(f"\nStarting gradient inversion attack with {config['restarts']} restarts...")
print("This may take several minutes...\n")

start_time = time.time()

# Normalization statistics on the attack device (the "_cuda" suffix is historical:
# they live on whatever device `setup` selects), passed to the reconstructor as (mean, std)
dm_cuda, ds_cuda = (stat.to(**setup) for stat in (dm, ds))

rec_machine = inversefed.GradientReconstructor(
    model, (dm_cuda, ds_cuda), config, num_images=1
)
output, stats = rec_machine.reconstruct(input_gradient, labels, img_shape=(3, 32, 32))

attack_time = time.time() - start_time

print(f"\nAttack completed in {attack_time:.2f} seconds ({attack_time/60:.2f} minutes)")

### 6.4 Evaluate Attack Results

In [None]:
# Calculate metrics.
# output and ground_truth are both in the dataset's normalized space (the plotting helpers
# undo that with ds/dm), so this MSE is in normalized units, not [0, 1] pixel units.
test_mse = (output - ground_truth).pow(2).mean().item()
# NOTE(review): factor=1/ds presumably rescales the peak value so the PSNR refers to the
# un-normalized [0, 1] range — confirm against inversefed.metrics.psnr.
test_psnr = inversefed.metrics.psnr(output, ground_truth, factor=1/ds_cuda)

# Distance between the model's outputs on the reconstruction and on the true image
with torch.no_grad():
    feat_mse = (model(output) - model(ground_truth)).pow(2).mean().item()

print("Attack Results:")
print(f"Reconstruction loss: {stats['opt']:.4f}")
print(f"Pixel MSE (normalized space): {test_mse:.4f}")
print(f"PSNR: {test_psnr:.2f} dB")
print(f"Feature MSE: {feat_mse:.4e}")

# Interpret results: rough PSNR bands checked from best to worst; the first threshold
# exceeded wins, and anything else (including NaN) falls through to "Poor".
PSNR_QUALITY_BANDS = [
    (30, "Excellent reconstruction - privacy severely compromised!"),
    (20, "Good reconstruction - significant privacy leakage"),
    (15, "Moderate reconstruction - some privacy leakage"),
]
quality = next(
    (verdict for threshold, verdict in PSNR_QUALITY_BANDS if test_psnr > threshold),
    "Poor reconstruction - limited privacy leakage",
)

print(f"\nQuality assessment: {quality}")

In [None]:
# Side-by-side view: top row is the true image, bottom row the reconstruction
plot_comparison(
    ground_truth,
    output,
    title=f"Original vs Reconstructed (PSNR: {test_psnr:.2f} dB)",
)

# Save both images as one strip: index 0 is the original, index 1 the reconstruction
comparison = torch.cat((ground_truth, output))
plot_images(
    comparison,
    title="Attack Results",
    save_path=f"attack_result_{MODEL_NAME}_{TARGET_ID}.png",
)

## 7. Batch Attack (Multiple Images)

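Let's try attacking multiple images at once. This is more realistic for federated learning, where a client usually shares a single gradient computed over a whole mini-batch. It is also harder, because the contributions of the individual images are mixed together in that one gradient.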

In [None]:
# Select multiple images with different labels: scan the validation set forward from
# BATCH_START_IDX and keep the first image of each not-yet-seen class.
BATCH_SIZE = 8
BATCH_START_IDX = 100  # first validation index to scan

# At most one image per class can be chosen, so the batch cannot exceed the class count
if not 1 <= BATCH_SIZE <= len(classes):
    raise ValueError(
        f"BATCH_SIZE must be between 1 and {len(classes)} (one image per class), "
        f"got {BATCH_SIZE}"
    )

batch_ground_truth, batch_labels = [], []
for idx in range(BATCH_START_IDX, len(validloader.dataset)):
    img, label = validloader.dataset[idx]
    if label in batch_labels:
        continue
    batch_labels.append(label)
    batch_ground_truth.append(img.to(**setup))
    if len(batch_labels) == BATCH_SIZE:
        break
else:
    # Reached the end of the dataset without finding enough distinct labels
    raise RuntimeError(
        f"Found only {len(batch_labels)} distinct labels from index {BATCH_START_IDX}; "
        f"need {BATCH_SIZE}"
    )

batch_ground_truth = torch.stack(batch_ground_truth)
batch_labels = torch.tensor(batch_labels, device=setup['device'])

print(f"Batch attack on {BATCH_SIZE} images")
print(f"Classes: {[classes[k] for k in batch_labels.tolist()]}")
plot_images(batch_ground_truth, title="Original Images for Batch Attack")

In [None]:
# The gradient a client would share after one step on the whole batch
model.zero_grad()
outputs = model(batch_ground_truth)
loss, _, _ = loss_fn(outputs, batch_labels)
batch_gradient = [g.detach() for g in torch.autograd.grad(loss, model.parameters())]

print(f"Batch loss: {loss.item():.4f}")
print(f"Batch gradient norm: {torch.stack([g.norm() for g in batch_gradient]).mean():.4f}")

In [None]:
# Batch reconstruction settings. Relative to the single-image attack this uses a lower
# learning rate, fewer restarts, more iterations and a much stronger TV prior.
batch_config = {
    'signed': True,
    'boxed': True,
    'cost_fn': 'sim',
    'indices': 'def',
    'weights': 'equal',
    'lr': 0.01,  # Lower learning rate for batch
    'optim': 'adam',
    'restarts': 4,  # Fewer restarts due to complexity
    'max_iterations': 12000,  # More iterations needed
    'total_variation': 1e-2,  # Higher TV for batch
    'init': 'randn',
    'filter': 'none',
    'lr_decay': True,
    'scoring_choice': 'loss',
}

# Run batch attack and time it
print("Starting batch reconstruction...")
start_time = time.time()

batch_rec_machine = inversefed.GradientReconstructor(
    model, (dm_cuda, ds_cuda), batch_config, num_images=BATCH_SIZE
)
batch_output, batch_stats = batch_rec_machine.reconstruct(
    batch_gradient, batch_labels, img_shape=(3, 32, 32)
)

batch_time = time.time() - start_time
print(f"\nBatch attack completed in {batch_time:.2f} seconds")

In [None]:
# Evaluate batch results: whole-batch metrics first, then a PSNR for each image
batch_mse = (batch_output - batch_ground_truth).pow(2).mean().item()
batch_psnr = inversefed.metrics.psnr(batch_output, batch_ground_truth, factor=1/ds_cuda)

print(f"Batch Reconstruction loss: {batch_stats['opt']:.4f}")
print(f"Average MSE: {batch_mse:.4f}")
print(f"Average PSNR: {batch_psnr:.2f} dB")

# Each reconstruction scored against the ground-truth image in the same batch slot
individual_psnrs = [
    inversefed.metrics.psnr(
        batch_output[k:k + 1], batch_ground_truth[k:k + 1], factor=1/ds_cuda
    )
    for k in range(BATCH_SIZE)
]
for label, psnr in zip(batch_labels.tolist(), individual_psnrs):
    print(f"  {classes[label]}: {psnr:.2f} dB")

# Visualize only the first four pairs to keep the figure readable
plot_comparison(batch_ground_truth[:4], batch_output[:4],
                title=f"Batch Attack Results (Avg PSNR: {batch_psnr:.2f} dB)")

## 8. Analysis and Conclusions

### Key Findings

1. **Model Performance**: How well the model is trained can affect reconstruction quality
2. **Single vs Batch**: Single images are generally easier to reconstruct than batches
3. **Class Dependencies**: The per-image PSNRs of the batch attack may hint that some classes are more vulnerable than others (a single run is not conclusive)
4. **Privacy Implications**: Even when reconstructions are imperfect, some information leaks through gradients

### Factors Affecting Reconstruction Quality

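- **Model accuracy**: Training changes the gradients the attacker observes, so trained and untrained models can differ noticeably in how easily they are inverted
- **Batch size**: Smaller batches are easier to invert
- **Image complexity**: Images with simple structure (smooth regions, few fine details) tend to reconstruct better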
- **Attack parameters**: Hyperparameter tuning is crucial
- **Model architecture**: Different architectures have different vulnerabilities

In [None]:
# Summary statistics for the single-image and batch attacks
print(
    "Summary of Results",
    "=" * 50,
    f"Model: {MODEL_NAME}",
    f"Model Accuracy: {100 * correct / total:.2f}%",
    "\nSingle Image Attack:",
    f"  Target: {classes[true_label]} (index {TARGET_ID})",
    f"  PSNR: {test_psnr:.2f} dB",
    f"  Time: {attack_time:.2f} seconds",
    f"\nBatch Attack ({BATCH_SIZE} images):",
    f"  Average PSNR: {batch_psnr:.2f} dB",
    f"  Best PSNR: {max(individual_psnrs):.2f} dB",
    f"  Worst PSNR: {min(individual_psnrs):.2f} dB",
    f"  Time: {batch_time:.2f} seconds",
    sep="\n",
)

## 9. Experiment with Different Settings

Try modifying these parameters to see how they affect reconstruction:

1. **Different target images**: Change `TARGET_ID` to attack different images
2. **Attack parameters**: Modify `config` dictionary (learning rate, iterations, cost function)
3. **Model architecture**: Try 'ConvNet' or 'ResNet32-10' instead of 'ResNet20-4'
4. **Batch size**: Try different batch sizes for batch attack
5. **Untrained model**: Skip the train/load cell in section 5.3 (it automatically loads a saved checkpoint if one exists) to see how a randomly initialized model behaves