# Lab 8b: Federated Learning and Adversarial Training

## Learning Objectives

By the end of this lab, you will understand:

1. **Federated Learning (FL):** Training without centralizing data
2. **FedAvg Algorithm:** Aggregating model updates across clients
3. **Privacy & Security Benefits:** Data locality, attack surface reduction
4. **Adversarial Training:** Robustness against evasion attacks
5. **Combined Defenses:** Layering FL and adversarial training against different threats
6. **Practical Challenges:** Non-IID data, client drift, communication costs

## Table of Contents

1. [Threat Model & Rationale](#threat-model)
2. [Federated Learning (FedAvg)](#fl)
3. [Adversarial Training (FGSM)](#adv-training)
4. [Robustness Evaluation](#evaluation)
5. [Exercises](#exercises)

---

## Threat Model & Rationale <a id="threat-model"></a>

**Why FL?** Training data is sensitive, so it stays on client devices; only model updates are sent to the server.

### Threats Addressed:

| Threat | Defense | Mechanism |
|--------|---------|-----------|
| Data Exfiltration | FL | No raw data leaves device |
| Membership Inference | FL + DP | Clip and noise client updates (DP); FL alone does not prevent it |
| Evasion Attacks | Adversarial Training | Train on adversarial examples |
| Data Poisoning | Robust Aggregation (optional) | Trimmed mean, Krum |
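Robust aggregation replaces the plain average with a statistic that a few outlying clients cannot drag far. Below is a minimal sketch of a coordinate-wise trimmed mean. It assumes each client update has already been flattened into a list of floats; the `trimmed_mean` helper and the toy numbers are hypothetical and not part of this lab's code.

```python
def trimmed_mean(updates, trim):
    """Coordinate-wise trimmed mean: for each coordinate, drop the `trim`
    largest and `trim` smallest client values, then average the rest."""
    n = len(updates)
    if 2 * trim >= n:
        raise ValueError("need more than 2 * trim clients")
    result = []
    for column in zip(*updates):               # one coordinate across all clients
        kept = sorted(column)[trim:n - trim]   # discard the extremes
        result.append(sum(kept) / len(kept))
    return result

# Four honest clients and one poisoned client sending a huge update.
updates = [[1.0, 2.0], [1.1, 1.9], [0.9, 2.1], [1.0, 2.0], [100.0, -100.0]]
print(trimmed_mean(updates, trim=1))
```

With one poisoned client out of five and `trim=1`, the extreme value in each coordinate is discarded, so the result stays close to `[1.0, 2.0]`; a plain mean of the first coordinate would be 20.8. To exclude up to b malicious values per coordinate, `trim` must be at least b.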

### FL + Adversarial Training:
- FL reduces data exposure
- Adversarial training increases robustness to input attacks
- Combined defenses create layered security

---

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

np.random.seed(42)
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)

# Subset for speed
train_indices = np.random.choice(len(train_dataset), 8000, replace=False)
test_indices = np.random.choice(len(test_dataset), 2000, replace=False)

train_data = Subset(train_dataset, train_indices)
test_data = Subset(test_dataset, test_indices)

test_loader = DataLoader(test_data, batch_size=128, shuffle=False)

print(f"Train: {len(train_data)}, Test: {len(test_data)}")
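
`transforms.Normalize((0.1307,), (0.3081,))` maps each raw pixel p in [0, 1] to (p - 0.1307) / 0.3081, so any perturbation budget or clamp applied to the normalized tensors lives on that scale. A small arithmetic check in plain Python (the variable names are illustrative):

```python
# Normalization constants used by transforms.Normalize above (MNIST mean/std).
MEAN, STD = 0.1307, 0.3081

def normalize(p):
    """Map a raw pixel value p in [0, 1] to the normalized scale."""
    return (p - MEAN) / STD

# Valid range of a normalized MNIST pixel.
lo, hi = normalize(0.0), normalize(1.0)

# Normalize is affine with scale 1/STD, so a budget of eps in normalized
# units equals eps * STD in raw pixel units.
eps_normalized = 0.1
eps_pixel = eps_normalized * STD

print(f"normalized range: [{lo:.4f}, {hi:.4f}]")
print(f"eps={eps_normalized} (normalized) = {eps_pixel:.4f} (pixel scale)")
```

The valid normalized range is roughly [-0.424, 2.821], and an epsilon of 0.1 in normalized units corresponds to about 0.031 in [0, 1] pixel units.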

In [None]:
# ============================================================================
# Model Definition
# ============================================================================

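class SmallCNN(nn.Module):
    """Two conv + max-pool blocks followed by two fully connected layers."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)   # 1x28x28 -> 16x28x28
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)  # 16x14x14 -> 32x14x14
        self.pool = nn.MaxPool2d(2, 2)                # halves height and width
        self.fc1 = nn.Linear(32 * 7 * 7, 128)         # 32x7x7 after two pools
        self.fc2 = nn.Linear(128, 10)                 # one logit per digit class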
    
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(x)
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total

print("Model ready.")
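
The `32 * 7 * 7` input size of `fc1` comes from two 2x2 max-pools halving the 28x28 input twice, while the padded 3x3 convolutions keep the spatial size unchanged. The plain-Python sketch below redoes that arithmetic with a hypothetical `conv_out` helper (the standard output-size formula, assuming dilation 1) and counts the model's parameters, which is roughly what each client uploads per FedAvg round, assuming uncompressed float32 weights.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of a Conv2d or MaxPool2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                              # MNIST images are 28x28
size = conv_out(size, 3, padding=1)    # conv1 (3x3, padding=1): 28 -> 28
size = conv_out(size, 2, stride=2)     # max-pool 2x2: 28 -> 14
size = conv_out(size, 3, padding=1)    # conv2 (3x3, padding=1): 14 -> 14
size = conv_out(size, 2, stride=2)     # max-pool 2x2: 14 -> 7
flat = 32 * size * size                # 32 channels after conv2

# Parameters = weights + biases for each layer of SmallCNN.
layer_params = {
    "conv1": 16 * 1 * 3 * 3 + 16,
    "conv2": 32 * 16 * 3 * 3 + 32,
    "fc1": flat * 128 + 128,
    "fc2": 128 * 10 + 10,
}
total_params = sum(layer_params.values())
upload_bytes = 4 * total_params        # float32 = 4 bytes per parameter

print(f"flattened features: {flat}")
print(f"parameters: {total_params} (~{upload_bytes / 1e6:.2f} MB per upload)")
```

This gives 1,568 flattened features and 206,922 parameters (about 0.83 MB per client per round). `fc1` alone accounts for over 97% of them, which is why the dense layers dominate FL communication cost for this model.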

In [None]:
# ============================================================================
# PART 1: Federated Learning (FedAvg)
# ============================================================================

print("\n" + "="*70)
print("PART 1: Federated Learning (FedAvg)")
print("="*70)

def create_clients(dataset, n_clients=5, non_iid=False):
    """Split dataset into client subsets (IID or non-IID)."""
    indices = np.arange(len(dataset))
    if non_iid:
        # Pathological non-IID split: sorting indices by label means each
        # contiguous chunk (one client) holds only a few digit classes.
        labels = np.array([dataset[i][1] for i in indices])
        indices = indices[np.argsort(labels)]
    
    client_splits = np.array_split(indices, n_clients)
    clients = [Subset(dataset, split) for split in client_splits]
    return clients

def local_train(model, loader, epochs=1, lr=0.01):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

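def federated_average(models, weights=None):
    """FedAvg aggregation: average client parameters, weighted by `weights`.

    FedAvg weights client k by its sample count n_k. With weights=None every
    client counts equally, which matches FedAvg here because np.array_split
    gives all clients (nearly) the same number of samples.
    """
    if weights is None:
        weights = [1.0] * len(models)
    total = float(sum(weights))
    states = [m.state_dict() for m in models]
    avg_model = SmallCNN().to(device)
    # Weighted average of every tensor in the state dict
    avg_state = {
        key: sum(w * s[key] for w, s in zip(weights, states)) / total
        for key in states[0]
    }
    avg_model.load_state_dict(avg_state)
    return avg_model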

# Create clients
clients = create_clients(train_data, n_clients=5, non_iid=True)
client_loaders = [DataLoader(c, batch_size=64, shuffle=True) for c in clients]

print(f"Created {len(clients)} clients.")

# Federated training
global_model = SmallCNN().to(device)
rounds = 5
global_accuracies = []

for r in range(rounds):
    local_models = []
    for loader in client_loaders:
        local_model = SmallCNN().to(device)
        local_model.load_state_dict(global_model.state_dict())
        local_train(local_model, loader, epochs=1)
        local_models.append(local_model)
    
    global_model = federated_average(local_models)
    acc = evaluate(global_model, test_loader)
    global_accuracies.append(acc)
    print(f"Round {r+1}/{rounds}: Global Test Accuracy = {acc:.2f}%")
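
FedAvg's server step is a weighted average, w = sum over k of (n_k / n) * w_k, where n_k is client k's number of training samples and n is the total. A framework-free sketch on flattened parameter lists (the `fedavg` helper and the toy values are hypothetical):

```python
def fedavg(client_weights, client_sizes):
    """Sample-size-weighted average of flattened client parameter vectors."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
        for i in range(dim)
    ]

# Three hypothetical clients with 2-parameter "models" and unequal data sizes.
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
sizes = [100, 300, 600]
print(fedavg(weights, sizes))  # [0.7, 0.9]
```

When all clients hold the same number of samples, as with the equal-sized `np.array_split` chunks above, the weighted and unweighted averages coincide; with unequal sizes, a uniform average over-weights small clients.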

In [None]:
# ============================================================================
# PART 2: Adversarial Training (FGSM)
# ============================================================================

print("\n" + "="*70)
print("PART 2: Adversarial Training (FGSM)")
print("="*70)

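# Valid range of a normalized MNIST pixel: (0 - mean) / std and (1 - mean) / std
MNIST_MIN = (0.0 - 0.1307) / 0.3081
MNIST_MAX = (1.0 - 0.1307) / 0.3081

def fgsm_attack(model, data, target, epsilon=0.1):
    """FGSM: x_adv = x + epsilon * sign(grad_x loss).

    epsilon is in normalized units (0.1 here is about 0.031 in [0, 1] pixel units).
    """
    data = data.clone().detach().requires_grad_(True)  # don't mutate the caller's batch
    output = model(data)
    loss = nn.CrossEntropyLoss()(output, target)
    model.zero_grad()
    loss.backward()
    perturbed = data + epsilon * data.grad.sign()
    # Clamp to the valid normalized pixel range and detach from the attack graph
    return torch.clamp(perturbed, MNIST_MIN, MNIST_MAX).detach()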

def adversarial_train(model, loader, epochs=3, epsilon=0.1, alpha=0.5):
    """Train on a mix of clean and FGSM examples: alpha * clean + (1 - alpha) * adv."""
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            # Create adversarial examples against the current model
            adv_data = fgsm_attack(model, data, target, epsilon=epsilon)
            # Train on clean + adversarial mixed (the attack leaves gradients
            # on the parameters, so zero_grad must come after it)
            optimizer.zero_grad()
            loss = (alpha * criterion(model(data), target)
                    + (1 - alpha) * criterion(model(adv_data), target))
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}: adversarial training complete")

def evaluate_adversarial(model, loader, epsilon=0.1):
    """Accuracy (%) on white-box FGSM examples crafted against `model` itself."""
    model.eval()
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        adv_data = fgsm_attack(model, data, target, epsilon=epsilon)  # needs gradients
        with torch.no_grad():
            pred = model(adv_data).argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)
    return 100.0 * correct / total

# Train adversarial model (centralized for comparison)
adv_model = SmallCNN().to(device)
adv_train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
adversarial_train(adv_model, adv_train_loader, epochs=3, epsilon=0.1)

clean_acc = evaluate(adv_model, test_loader)
adv_acc = evaluate_adversarial(adv_model, test_loader, epsilon=0.1)

print(f"\nAdversarially Trained Model:")
print(f"  Clean Accuracy: {clean_acc:.2f}%")
print(f"  Adversarial Accuracy (FGSM): {adv_acc:.2f}%")
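
FGSM moves every input feature by epsilon in the direction of the sign of the loss gradient with respect to the input. To show this without PyTorch, the sketch below assumes a hypothetical logistic-regression model with hand-picked weights, whose input gradient has the closed form (p - y) * w; the clamp bounds are the normalized MNIST pixel range (about [-0.424, 2.821]).

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict(w, b, x):
    """Probability of class 1 under a logistic-regression model."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)

def bce_loss(w, b, x, y):
    """Binary cross-entropy loss on a single example."""
    p = predict(w, b, x)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def input_gradient(w, b, x, y):
    """d(loss)/dx for logistic regression has the closed form (p - y) * w."""
    p = predict(w, b, x)
    return [(p - y) * wi for wi in w]

def sign(v):
    return (v > 0) - (v < 0)

def fgsm(w, b, x, y, epsilon, lo, hi):
    """One FGSM step: x + epsilon * sign(grad), clamped to [lo, hi]."""
    grad = input_gradient(w, b, x, y)
    return [min(max(xi + epsilon * sign(gi), lo), hi) for xi, gi in zip(x, grad)]

# Hypothetical model and a correctly classified positive example.
w, b = [2.0, -1.0, 0.5], 0.0
x, y = [0.5, 0.2, 0.1], 1
x_adv = fgsm(w, b, x, y, epsilon=0.1, lo=-0.4242, hi=2.8215)

print("clean loss:", round(bce_loss(w, b, x, y), 3))
print("adversarial loss:", round(bce_loss(w, b, x_adv, y), 3))
```

Here each feature moves by epsilon in the direction that increases the loss, so the loss on `x_adv` is higher than on `x`. The clamp keeps the perturbed input inside the valid normalized range, which is the role the clamp plays in `fgsm_attack`.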

In [None]:
# ============================================================================
# PART 3: Defense Comparison
# ============================================================================

print("\n" + "="*70)
print("PART 3: Defense Comparison")
print("="*70)

# Baseline model (centralized training for comparison)
baseline_model = SmallCNN().to(device)
baseline_loader = DataLoader(train_data, batch_size=64, shuffle=True)
optimizer = optim.SGD(baseline_model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for _ in range(3):
    for data, target in baseline_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = baseline_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

baseline_clean = evaluate(baseline_model, test_loader)
baseline_adv = evaluate_adversarial(baseline_model, test_loader, epsilon=0.1)

fedavg_clean = evaluate(global_model, test_loader)
fedavg_adv = evaluate_adversarial(global_model, test_loader, epsilon=0.1)

adv_clean = clean_acc
adv_adv = adv_acc

summary = pd.DataFrame({
    'Model': ['Baseline', 'FedAvg', 'Adv Training'],
    'Clean Acc': [baseline_clean, fedavg_clean, adv_clean],
    'FGSM Acc': [baseline_adv, fedavg_adv, adv_adv]
})

print(summary.to_string(index=False))

In [None]:
# ============================================================================
# PART 4: Visualization
# ============================================================================

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# FedAvg convergence
ax = axes[0]
ax.plot(range(1, len(global_accuracies) + 1), global_accuracies, marker='o', linewidth=2, color='#3498db')  # rounds numbered from 1
ax.set_xlabel('Federated Round')
ax.set_ylabel('Test Accuracy (%)')
ax.set_title('FedAvg Convergence')
ax.grid(alpha=0.3)

# Robustness comparison
ax = axes[1]
x_pos = np.arange(len(summary))
width = 0.35
ax.bar(x_pos - width/2, summary['Clean Acc'], width, label='Clean', color='#2ecc71', alpha=0.8)
ax.bar(x_pos + width/2, summary['FGSM Acc'], width, label='FGSM', color='#e74c3c', alpha=0.8)
ax.set_xticks(x_pos)
ax.set_xticklabels(summary['Model'])
ax.set_ylabel('Accuracy (%)')
ax.set_title('Robustness Comparison')
ax.legend()
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('federated_adversarial.png', dpi=150, bbox_inches='tight')
plt.show()

print('✓ Visualization complete.')

---

## Summary: Federated Learning and Adversarial Training

### Key Findings:

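1. **FedAvg reaches useful accuracy without sharing raw data:** compare its clean accuracy with the centralized baseline in the summary table.
2. **Non-IID data slows convergence (client drift):** with the label-sorted split, each client sees only a few digit classes, so local updates pull the global model in conflicting directions.
3. **Adversarial training improves robustness:** the adversarially trained model should score higher FGSM accuracy than the baseline, often at some cost in clean accuracy.
4. **The defenses are complementary:** FL limits raw-data exposure and adversarial training limits evasion, so they can be layered (Exercise 4 combines them).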
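To see why the label-sorted split leaves each client with so few classes, the sketch below mimics `create_clients(..., non_iid=True)` in plain Python. It assumes 8,000 uniformly random labels as a stand-in for the sampled MNIST labels and reproduces `np.array_split`'s chunking into nearly equal contiguous pieces; `label_sorted_split` is a hypothetical helper.

```python
import random
from collections import Counter

random.seed(0)
# Stand-in for the 8,000 sampled MNIST labels (10 roughly balanced classes).
labels = [random.randrange(10) for _ in range(8000)]

def label_sorted_split(labels, n_clients):
    """Sort indices by label, then cut them into n_clients contiguous,
    nearly equal chunks (the first len % n_clients chunks get one extra)."""
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    size, extra = divmod(len(order), n_clients)
    chunks, start = [], 0
    for k in range(n_clients):
        end = start + size + (1 if k < extra else 0)
        chunks.append(order[start:end])
        start = end
    return chunks

for k, chunk in enumerate(label_sorted_split(labels, 5)):
    counts = Counter(labels[i] for i in chunk)
    print(f"client {k}: {len(chunk)} samples, labels {sorted(counts)}")
```

Each client ends up with only a few digit classes (typically two or three, since one class can straddle a chunk boundary), so its local optimum differs sharply from the global one; this mismatch is the source of client drift.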

### Practical Guidance:
- **FL for privacy:** Use when data cannot leave devices
- **Adversarial training for robustness:** Use when evasion attacks are a threat
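- **Combine with DP:** FL alone gives no formal privacy guarantee, since model updates can still leak information about local data; add differential privacy for one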
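In DP-style federated averaging, the server bounds each client's influence by clipping its update to an L2 norm of at most C and then adds Gaussian noise scaled to C. A minimal sketch in plain Python, assuming a fixed number of clients per round and hypothetical helper names; it omits the privacy accounting needed to turn `noise_multiplier` into a formal (epsilon, delta) guarantee.

```python
import math
import random

def clip_update(update, clip_norm):
    """Scale an update vector so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(u * u for u in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [u * scale for u in update]

def dp_aggregate(updates, clip_norm, noise_multiplier, rng):
    """Clip each client update, average, and add Gaussian noise with
    standard deviation noise_multiplier * clip_norm / n_clients."""
    n = len(updates)
    clipped = [clip_update(u, clip_norm) for u in updates]
    mean = [sum(col) / n for col in zip(*clipped)]
    sigma = noise_multiplier * clip_norm / n
    return [m + rng.gauss(0.0, sigma) for m in mean]

rng = random.Random(0)
updates = [[3.0, 4.0], [0.3, 0.4], [-0.6, 0.8]]   # hypothetical client deltas
print(dp_aggregate(updates, clip_norm=1.0, noise_multiplier=1.0, rng=rng))
```

With `noise_multiplier=0` this reduces to averaging the clipped updates; larger multipliers give stronger privacy at the cost of a noisier global model, which is the privacy-utility trade-off explored in Exercise 6.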

---

## Exercises

### Exercise 1: Non-IID Severity (Medium)
Vary the degree of non-IID skew (for example, the number of distinct labels per client) and measure how FedAvg convergence degrades as skew grows.

### Exercise 2: Client Participation (Medium)
Sample only a fraction of clients each round (e.g., 50%) and compare global accuracy with full participation.
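Partial participation only changes which clients train in a given round. A minimal sketch of per-round sampling, assuming a fixed fraction of clients drawn uniformly without replacement (`sample_clients` is a hypothetical helper):

```python
import random

def sample_clients(n_clients, fraction, rng):
    """Return sorted indices of max(1, floor(fraction * n_clients)) distinct clients."""
    m = max(1, int(fraction * n_clients))
    return sorted(rng.sample(range(n_clients), m))

rng = random.Random(42)
for r in range(3):
    print(f"round {r + 1}: clients {sample_clients(5, 0.5, rng)}")
```

In the training loop you would iterate over `[client_loaders[k] for k in selected]` instead of all loaders and average only the selected clients' models. With 5 clients and `fraction=0.5`, the floor means 2 clients train per round.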

### Exercise 3: Robust Aggregation (Hard)
Implement trimmed mean or Krum to defend against malicious clients.

### Exercise 4: Adversarial Training in FL (Hard)
Train each client with FGSM examples and compare global robustness.

### Exercise 5: Communication Efficiency (Hard)
Simulate update compression (top-k sparsification or quantization of each client's model update) and measure the accuracy loss.
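Top-k sparsification is one common way to shrink uploads: each client keeps only the k largest-magnitude entries of its update and sends those values with their indices. A plain-Python sketch, assuming the update is the flattened difference between local and global weights (`top_k_sparsify` is a hypothetical helper):

```python
def top_k_sparsify(update, k):
    """Keep the k largest-magnitude entries of an update and zero the rest.
    Returns the sparse vector and the kept indices (what a client would send)."""
    keep = sorted(range(len(update)), key=lambda i: abs(update[i]), reverse=True)[:k]
    keep_set = set(keep)
    sparse = [u if i in keep_set else 0.0 for i, u in enumerate(update)]
    return sparse, sorted(keep)

update = [0.05, -0.9, 0.3, 0.0, -0.2, 0.7]    # hypothetical client delta
sparse, idx = top_k_sparsify(update, k=2)
print(sparse, idx)   # [0.0, -0.9, 0.0, 0.0, 0.0, 0.7] [1, 5]
```

Assuming 4-byte values and 4-byte indices, sending k values plus k indices instead of all n parameters shrinks the payload by roughly a factor of n / (2k); the discarded entries are where the accuracy loss comes from.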

### Exercise 6: Privacy-Utility Trade-off (Hard)
Combine FL + DP + adversarial training. Evaluate privacy and robustness together.