# Lesson 91: Federated Learning Fundamentals

## Introduction

**Federated Learning** is a machine learning paradigm that trains models across multiple decentralized devices or servers holding local data samples, without exchanging the raw data itself. This approach addresses several challenges in modern machine learning:

### Why Federated Learning?

1. **Privacy Preservation**: Raw data never leaves the device, protecting sensitive user information
2. **Regulatory Compliance**: Helps meet GDPR, HIPAA, and other data protection requirements by keeping raw data local
3. **Data Silos**: Enables collaboration across organizations without data sharing
4. **Edge Computing**: Leverages computational power of edge devices (smartphones, IoT devices)
5. **Network Efficiency**: Can reduce bandwidth when the transmitted model updates are smaller than the raw data they replace

### Real-World Applications

- **Mobile Keyboards**: Google's Gboard uses federated learning to improve next-word prediction
- **Healthcare**: Training diagnostic models across hospitals without sharing patient data
- **Finance**: Fraud detection across banks while maintaining customer privacy
- **IoT**: Smart home devices learning user preferences locally

### Learning Objectives

By the end of this lesson, you will:
- Understand the architecture and workflow of federated learning systems
- Implement the Federated Averaging (FedAvg) algorithm
- Compare centralized vs. federated learning performance
- Recognize challenges: communication costs, non-IID data, and convergence
- Build a practical federated learning simulation with multiple clients

## Core Concepts and Theory

### Federated Learning Architecture

Federated learning involves three key components:

1. **Central Server**: Coordinates training and aggregates model updates
2. **Clients (Participants)**: Local devices/servers with private data
3. **Communication Protocol**: Secure channels for model parameter exchange

### The Federated Averaging (FedAvg) Algorithm

Proposed by McMahan et al. (2017), FedAvg is the foundational algorithm for federated learning.

#### Algorithm Steps:

**Server Side:**
1. Initialize global model with parameters $w_0$
2. For each round $t = 0, 1, \ldots, T-1$:
   - Select a subset of $K$ clients (out of $N$ total)
   - Send current global model $w_t$ to selected clients
   - Wait for client updates
   - Aggregate updates using weighted average

**Client Side (each client $k$):**
1. Receive global model $w_t$ from server
2. Train on local data for $E$ epochs
3. Compute updated model $w_k^{t+1}$
4. Send $w_k^{t+1}$ back to server
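The server and client steps above can be sketched end to end on a toy problem. This is a minimal illustration, not the lesson's PyTorch implementation: it assumes a one-parameter model $y \approx w x$, full-batch gradient descent on each client instead of mini-batch SGD, and hypothetical helper names (`local_train`, `fedavg`).

```python
import random

def local_train(w, data, epochs=5, lr=0.05):
    """Client side: a few full-batch gradient steps on the squared error of y ~ w * x."""
    for _ in range(epochs):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w

def fedavg(num_rounds, clients, clients_per_round, seed=0):
    """Server side: select clients, collect their models, average weighted by sample count."""
    rng = random.Random(seed)
    w = 0.0  # initial global model w_0
    for _ in range(num_rounds):
        selected = rng.sample(range(len(clients)), clients_per_round)
        updates = [(local_train(w, clients[k]), len(clients[k])) for k in selected]
        n = sum(n_k for _, n_k in updates)
        w = sum(w_k * n_k / n for w_k, n_k in updates)
    return w

# Three clients whose private data all follow y = 3x (an IID-like toy setting)
clients = [
    [(1.0, 3.0), (2.0, 6.0)],
    [(0.5, 1.5), (1.5, 4.5), (2.5, 7.5)],
    [(3.0, 9.0)],
]
w_final = fedavg(num_rounds=30, clients=clients, clients_per_round=2)
print(f"learned w = {w_final:.4f}")
```

Only the scalar `w` and the sample counts cross the client boundary; the `(x, y)` pairs are read solely inside `local_train`.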

#### Mathematical Foundation

The global model update is computed as a weighted average of client models:

$$
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} w_k^{t+1}
$$

Where:
- $w_{t+1}$ = new global model parameters
- $w_k^{t+1}$ = updated model from client $k$
- $n_k$ = number of training samples on client $k$
- $n = \sum_{k=1}^{K} n_k$ = total samples across selected clients
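A small worked example may make the weighting concrete. The sketch below assumes three selected clients with two-parameter models and made-up sample counts; `weighted_average` is an illustrative helper, not a library function.

```python
def weighted_average(client_params, client_sizes):
    """Return sum_k (n_k / n) * w_k, computed coordinate by coordinate."""
    n = sum(client_sizes)
    dim = len(client_params[0])
    return [
        sum(params[i] * n_k / n for params, n_k in zip(client_params, client_sizes))
        for i in range(dim)
    ]

client_params = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0]]  # w_k^{t+1} for k = 1, 2, 3
client_sizes = [100, 300, 600]                        # n_k

w_next = weighted_average(client_params, client_sizes)
unweighted = [sum(p[i] for p in client_params) / len(client_params) for i in range(2)]
print("weighted:  ", w_next)
print("unweighted:", unweighted)
```

The weighted result is pulled toward the third client because it holds 60% of the samples, whereas a plain mean would treat all three clients equally.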

#### Objective Function

Federated learning minimizes the global loss:

$$
\min_{w} f(w) = \sum_{k=1}^{N} \frac{n_k}{n_{\text{total}}} F_k(w)
$$

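Where:
- $F_k(w) = \frac{1}{n_k} \sum_{i \in \mathcal{D}_k} \ell(x_i, y_i; w)$ is the local loss on client $k$'s data $\mathcal{D}_k$
- $\ell(x_i, y_i; w)$ is the loss function for sample $(x_i, y_i)$
- $n_{\text{total}} = \sum_{k=1}^{N} n_k$ is the total number of samples across all $N$ clients (in contrast to $n$ above, which sums over only the $K$ clients selected in a round)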
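One consequence of this weighting is that the global objective equals the average loss over the pooled data, which is what centralized training would minimize. The check below is a sketch under assumed toy data and a squared-error loss; `local_loss` and `global_objective` are illustrative names.

```python
def local_loss(w, data):
    """F_k(w): mean squared error of the model y ~ w * x on one client's samples."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)

def global_objective(w, clients):
    """f(w): local losses weighted by each client's share of the total sample count."""
    n_total = sum(len(data) for data in clients)
    return sum(len(data) / n_total * local_loss(w, data) for data in clients)

clients = [
    [(1.0, 2.0), (2.0, 4.5)],
    [(3.0, 5.0)],
    [(0.5, 1.0), (1.5, 2.5), (2.5, 6.0)],
]
pooled = [sample for data in clients for sample in data]

w = 1.7
f_w = global_objective(w, clients)
pooled_loss = local_loss(w, pooled)
print(f"federated objective: {f_w:.6f}")
print(f"pooled mean loss:    {pooled_loss:.6f}")
```

The two values agree up to floating-point rounding, because multiplying each $F_k$ by $n_k / n_{\text{total}}$ cancels the $1/n_k$ inside it.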

### Key Challenges

1. **Communication Costs**: Frequent model exchanges can be expensive
2. **Non-IID Data**: Client data distributions may differ significantly
3. **System Heterogeneity**: Varying computational capabilities across clients
4. **Privacy Guarantees**: Model updates can still leak information
5. **Convergence**: Can be slower and less stable than centralized training, especially with non-IID data

## Practical Implementation

Let's implement a federated learning system from scratch using PyTorch. We'll simulate multiple clients training on the MNIST dataset.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import random

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

### Define the Neural Network Model

We'll use a simple CNN for MNIST digit classification.

In [2]:
class SimpleCNN(nn.Module):
    """Simple CNN for MNIST classification"""
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)  # channel-wise dropout on conv feature maps
        self.dropout2 = nn.Dropout(0.5)     # element-wise dropout on the flattened features
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = torch.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
print(f"Model architecture:\n{model}")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters())}")
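The `9216` input size of `fc1` and the parameter total printed above both follow from shape arithmetic. The sketch below assumes 28×28 MNIST inputs and the layer sizes used in `SimpleCNN`; the helpers (`conv_out`, `conv_params`, `linear_params`) are illustrative and not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(c_in, c_out, kernel):
    """Weights plus one bias per output channel."""
    return c_in * c_out * kernel * kernel + c_out

def linear_params(n_in, n_out):
    """Weights plus one bias per output unit."""
    return n_in * n_out + n_out

side = 28
side = conv_out(side, kernel=3)   # after conv1: 26
side = conv_out(side, kernel=3)   # after conv2: 24
side = side // 2                  # after max_pool2d(x, 2): 12
flat_features = 64 * side * side  # 64 channels of 12 x 12

total_params = (
    conv_params(1, 32, 3)
    + conv_params(32, 64, 3)
    + linear_params(flat_features, 128)
    + linear_params(128, 10)
)
print(f"fc1 input features: {flat_features}")
print(f"total parameters:   {total_params}")
```

Almost all of the parameters sit in `fc1`, which matters later because every parameter is transmitted in each communication round.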

### Load and Partition MNIST Dataset

We'll split the training data among multiple clients to simulate federated learning.

In [3]:
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print(f"Total training samples: {len(train_dataset)}")
print(f"Total test samples: {len(test_dataset)}")

In [4]:
def create_client_datasets(dataset, num_clients, iid=True):
    """
    Partition dataset among clients.
    
    Args:
        dataset: PyTorch dataset
        num_clients: Number of federated clients
        iid: If True, distribute data uniformly (IID)
             If False, create non-IID distribution (each client gets only 2 digits)
    
    Returns:
        List of client datasets
    """
    if iid:
        # IID: Randomly shuffle and split equally
        num_items = len(dataset) // num_clients
        client_datasets = []
        all_indices = list(range(len(dataset)))
        random.shuffle(all_indices)
        
        for i in range(num_clients):
            start = i * num_items
            end = start + num_items if i < num_clients - 1 else len(dataset)
            indices = all_indices[start:end]
            client_datasets.append(Subset(dataset, indices))
    else:
        # Non-IID: Each client gets data from only 2 digit classes
        # Group indices by label
        label_indices = {i: [] for i in range(10)}
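        # Read labels from the dataset's `targets` tensor (present on torchvision's
        # MNIST) instead of loading and transforming every image just for its label
        for idx, label in enumerate(dataset.targets.tolist()):
            label_indices[label].append(idx)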
        
        client_datasets = []
        digits_per_client = 2
        
        for i in range(num_clients):
            # Assign 2 consecutive digits to each client (with wrap-around)
            client_labels = [(i * digits_per_client + j) % 10 for j in range(digits_per_client)]
            client_indices = []
            
            for label in client_labels:
                client_indices.extend(label_indices[label])
            
            random.shuffle(client_indices)
            client_datasets.append(Subset(dataset, client_indices))
    
    return client_datasets

# Create client datasets (IID distribution)
NUM_CLIENTS = 5
client_datasets = create_client_datasets(train_dataset, NUM_CLIENTS, iid=True)

print(f"\nCreated {NUM_CLIENTS} client datasets (IID):")
for i, dataset in enumerate(client_datasets):
    print(f"  Client {i+1}: {len(dataset)} samples")
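The non-IID branch of `create_client_datasets` assigns digits with a modulo rule, and the outcome depends on the number of clients. The sketch below reproduces only that label-assignment formula in plain Python, with no dataset involved; `labels_for_client` is an illustrative name.

```python
def labels_for_client(client_id, digits_per_client=2, num_classes=10):
    """Same rule as the non-IID branch: consecutive digits with wrap-around."""
    return [(client_id * digits_per_client + j) % num_classes
            for j in range(digits_per_client)]

assignments = {
    num_clients: [labels_for_client(i) for i in range(num_clients)]
    for num_clients in (3, 5, 7)
}
for num_clients, labels in assignments.items():
    covered = sorted({d for pair in labels for d in pair})
    print(f"{num_clients} clients -> {labels} (digits covered: {covered})")
```

With 5 clients every digit is held by exactly one client. With fewer clients some digits are never seen during training, and with more than 5 the wrap-around hands later clients the same digits, and therefore the same samples, as earlier ones.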

### Implement Client Training Function

In [5]:
def client_update(client_model, client_dataset, epochs=1, batch_size=32, lr=0.01):
    """
    Train model on client's local data.
    
    Args:
        client_model: Model to train
        client_dataset: Client's local dataset
        epochs: Number of local training epochs
        batch_size: Batch size for training
        lr: Learning rate
    
    Returns:
        Updated model state dict and number of samples
    """
    client_model.train()
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(client_model.parameters(), lr=lr)
    
    dataloader = DataLoader(client_dataset, batch_size=batch_size, shuffle=True)
    
    for epoch in range(epochs):
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = client_model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    
    return client_model.state_dict(), len(client_dataset)

print("Client training function defined successfully!")

### Implement Federated Averaging (FedAvg)

In [6]:
def federated_averaging(global_model, client_models, client_weights):
    """
    Aggregate client models using weighted averaging.
    
    Args:
        global_model: Global model to update
        client_models: List of client model state dicts
        client_weights: List of client dataset sizes
    
    Returns:
        Updated global model
    """
    global_dict = global_model.state_dict()
    total_samples = sum(client_weights)
    
    # Initialize with zeros
    for key in global_dict.keys():
        global_dict[key] = torch.zeros_like(global_dict[key], dtype=torch.float32)
    
    # Weighted average of client parameters
    for client_model, weight in zip(client_models, client_weights):
        for key in global_dict.keys():
            global_dict[key] += client_model[key] * (weight / total_samples)
    
    global_model.load_state_dict(global_dict)
    return global_model

print("Federated averaging function defined successfully!")

### Evaluation Function

In [7]:
def evaluate(model, test_loader):
    """
    Evaluate model on test dataset.
    
    Returns:
        Test accuracy and loss
    """
    model.eval()
    criterion = nn.NLLLoss()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    return accuracy, test_loss

print("Evaluation function defined successfully!")
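One detail of `evaluate` is worth checking: dividing the summed batch losses by `len(test_loader)` averages per-batch means. That matches the per-sample mean only when all batches have the same size, as they do here (10,000 test samples in batches of 1,000). The sketch below uses made-up loss values to show the difference when the last batch is smaller.

```python
def mean(values):
    return sum(values) / len(values)

# Hypothetical per-sample losses; the last batch will hold a single sample
per_sample_losses = [0.2, 0.4, 0.6, 0.8, 5.0]
batch_size = 2
batches = [per_sample_losses[i:i + batch_size]
           for i in range(0, len(per_sample_losses), batch_size)]

# What summing batch means and dividing by the number of batches computes
mean_of_batch_means = mean([mean(batch) for batch in batches])

# Weighting each batch mean by its size recovers the per-sample mean
size_weighted = sum(mean(batch) * len(batch) for batch in batches) / len(per_sample_losses)

print(f"mean of batch means: {mean_of_batch_means:.2f}")
print(f"per-sample mean:     {size_weighted:.2f}")
```

If the test batch size is changed to a value that does not divide the dataset size, weighting each batch loss by its length keeps the reported loss exact.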

### Run Federated Learning Training

In [8]:
def federated_learning(num_rounds=10, local_epochs=1, client_fraction=1.0):
    """
    Execute federated learning training.
    
    Args:
        num_rounds: Number of communication rounds
        local_epochs: Number of local training epochs per client
        client_fraction: Fraction of clients to select each round
    
    Returns:
        Training history (accuracies and losses)
    """
    # Initialize global model
    global_model = SimpleCNN().to(device)
    
    history = {'accuracy': [], 'loss': []}
    
    for round_num in range(num_rounds):
        print(f"\n{'='*50}")
        print(f"Communication Round {round_num + 1}/{num_rounds}")
        print(f"{'='*50}")
        
        # Select clients for this round
        num_selected = max(1, int(NUM_CLIENTS * client_fraction))
        selected_clients = random.sample(range(NUM_CLIENTS), num_selected)
        print(f"Selected clients: {[c+1 for c in selected_clients]}")
        
        # Train on selected clients
        client_models = []
        client_weights = []
        
        for client_id in selected_clients:
            # Create client model copy
            client_model = SimpleCNN().to(device)
            client_model.load_state_dict(global_model.state_dict())
            
            # Train on client's data
            state_dict, num_samples = client_update(
                client_model, 
                client_datasets[client_id],
                epochs=local_epochs
            )
            
            client_models.append(state_dict)
            client_weights.append(num_samples)
            print(f"  Client {client_id+1} trained on {num_samples} samples")
        
        # Aggregate client models
        global_model = federated_averaging(global_model, client_models, client_weights)
        print("  Global model updated via FedAvg")
        
        # Evaluate global model
        accuracy, loss = evaluate(global_model, test_loader)
        history['accuracy'].append(accuracy)
        history['loss'].append(loss)
        
        print(f"\n  Test Accuracy: {accuracy:.2f}%")
        print(f"  Test Loss: {loss:.4f}")
    
    return global_model, history

# Run federated learning
NUM_ROUNDS = 10
LOCAL_EPOCHS = 2

print(f"\nStarting Federated Learning...")
print(f"Configuration:")
print(f"  - Number of clients: {NUM_CLIENTS}")
print(f"  - Communication rounds: {NUM_ROUNDS}")
print(f"  - Local epochs per round: {LOCAL_EPOCHS}")

fed_model, fed_history = federated_learning(
    num_rounds=NUM_ROUNDS,
    local_epochs=LOCAL_EPOCHS,
    client_fraction=1.0
)

print(f"\n{'='*50}")
print("Federated Learning Complete!")
print(f"Final Test Accuracy: {fed_history['accuracy'][-1]:.2f}%")
print(f"{'='*50}")
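The run above uses `client_fraction=1.0`, so every client participates in every round. For smaller fractions, the selection size follows the `max(1, int(...))` rule inside `federated_learning`. The sketch below isolates that rule for 5 clients; `num_selected` is an illustrative helper name.

```python
def num_selected(num_clients, client_fraction):
    """Truncate num_clients * client_fraction, but never select zero clients."""
    return max(1, int(num_clients * client_fraction))

selection_sizes = {fraction: num_selected(5, fraction)
                   for fraction in (1.0, 0.5, 0.3, 0.1)}
for fraction, size in selection_sizes.items():
    print(f"client_fraction={fraction}: {size} of 5 clients per round")
```

Because `int` truncates, a fraction of 0.5 selects 2 clients rather than 3, and any fraction below 0.2 still trains on one client per round.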

## Comparison: Centralized vs. Federated Learning

Let's compare federated learning with traditional centralized training.

In [9]:
def centralized_training(num_epochs=10, batch_size=32, lr=0.01):
    """
    Traditional centralized training on entire dataset.
    """
    central_model = SimpleCNN().to(device)
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(central_model.parameters(), lr=lr)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    history = {'accuracy': [], 'loss': []}
    
    for epoch in range(num_epochs):
        central_model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = central_model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        
        # Evaluate
        accuracy, loss = evaluate(central_model, test_loader)
        history['accuracy'].append(accuracy)
        history['loss'].append(loss)
        
        print(f"Epoch {epoch+1}/{num_epochs} - Accuracy: {accuracy:.2f}%, Loss: {loss:.4f}")
    
    return central_model, history

print("\nStarting Centralized Training...\n")
central_model, central_history = centralized_training(
    num_epochs=NUM_ROUNDS,
    batch_size=32,
    lr=0.01
)

print(f"\nCentralized Training Complete!")
print(f"Final Test Accuracy: {central_history['accuracy'][-1]:.2f}%")
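When reading the comparison, note that one federated round and one centralized epoch are not the same amount of computation. The arithmetic below assumes the settings used in this lesson (5 IID clients with 12,000 samples each, batch size 32, `LOCAL_EPOCHS = 2`) and the `DataLoader` default of keeping the final partial batch; `sgd_steps` is an illustrative helper.

```python
import math

def sgd_steps(num_samples, batch_size, epochs=1):
    """Number of optimizer steps for `epochs` passes with a final partial batch kept."""
    return math.ceil(num_samples / batch_size) * epochs

num_clients = 5
samples_per_client = 60_000 // num_clients

fed_steps_per_round = num_clients * sgd_steps(samples_per_client, 32, epochs=2)
central_steps_per_epoch = sgd_steps(60_000, 32)

print(f"federated steps per round:   {fed_steps_per_round}")
print(f"centralized steps per epoch: {central_steps_per_epoch}")
print(f"ratio: {fed_steps_per_round / central_steps_per_epoch:.1f}x")
```

Each federated round therefore processes the full training set twice, spread across clients, while each centralized epoch processes it once. Setting `LOCAL_EPOCHS = 1` would align the two x-axes more closely.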

### Visualize Training Comparison

In [10]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot accuracy
ax1.plot(range(1, NUM_ROUNDS + 1), fed_history['accuracy'], 
         marker='o', label='Federated Learning', linewidth=2)
ax1.plot(range(1, NUM_ROUNDS + 1), central_history['accuracy'], 
         marker='s', label='Centralized Learning', linewidth=2)
ax1.set_xlabel('Round / Epoch', fontsize=12)
ax1.set_ylabel('Test Accuracy (%)', fontsize=12)
ax1.set_title('Accuracy Comparison', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Plot loss
ax2.plot(range(1, NUM_ROUNDS + 1), fed_history['loss'], 
         marker='o', label='Federated Learning', linewidth=2)
ax2.plot(range(1, NUM_ROUNDS + 1), central_history['loss'], 
         marker='s', label='Centralized Learning', linewidth=2)
ax2.set_xlabel('Round / Epoch', fontsize=12)
ax2.set_ylabel('Test Loss', fontsize=12)
ax2.set_title('Loss Comparison', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('federated_vs_centralized.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nKey Observations:")
print(f"  - Federated final accuracy: {fed_history['accuracy'][-1]:.2f}%")
print(f"  - Centralized final accuracy: {central_history['accuracy'][-1]:.2f}%")
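print(f"  - Accuracy gap: {abs(fed_history['accuracy'][-1] - central_history['accuracy'][-1]):.2f} percentage points")
print("\n  With IID clients, FedAvg usually lands close to centralized accuracy while raw data stays on the clients.")
print("  Note: each round runs LOCAL_EPOCHS local passes, so one round is not the same compute as one epoch.")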

## Hands-On Activity: Exploring Non-IID Data

**Challenge**: Investigate how non-IID data (data that is not independent and identically distributed) affects federated learning performance.

In real-world scenarios, client data is often non-IID: each client may have a different data distribution. For example:
- Mobile keyboard users type different types of text
- Hospital patients have varying demographics
- Smart home devices have different usage patterns

### Task: Compare IID vs. Non-IID Performance

Run the code below to train with a non-IID data distribution (each client gets only 2 digit classes), then compare the results against the IID run.

In [11]:
# Create non-IID client datasets
client_datasets_noniid = create_client_datasets(train_dataset, NUM_CLIENTS, iid=False)

print(f"Created {NUM_CLIENTS} client datasets (Non-IID):")
for i, dataset in enumerate(client_datasets_noniid):
    print(f"  Client {i+1}: {len(dataset)} samples")
    # List the distinct labels found among the first 100 samples
    labels = [dataset[j][1] for j in range(min(100, len(dataset)))]
    unique_labels = sorted(set(labels))
    print(f"    Sample labels: {unique_labels}")

In [12]:
# Modified federated learning function for non-IID experiment
def federated_learning_noniid(client_datasets, num_rounds=10, local_epochs=1):
    """Federated learning with specified client datasets"""
    global_model = SimpleCNN().to(device)
    history = {'accuracy': [], 'loss': []}
    
    for round_num in range(num_rounds):
        print(f"Round {round_num + 1}/{num_rounds}", end=' ')
        
        client_models = []
        client_weights = []
        
        for client_id in range(NUM_CLIENTS):
            client_model = SimpleCNN().to(device)
            client_model.load_state_dict(global_model.state_dict())
            
            state_dict, num_samples = client_update(
                client_model, 
                client_datasets[client_id],
                epochs=local_epochs
            )
            
            client_models.append(state_dict)
            client_weights.append(num_samples)
        
        global_model = federated_averaging(global_model, client_models, client_weights)
        accuracy, loss = evaluate(global_model, test_loader)
        history['accuracy'].append(accuracy)
        history['loss'].append(loss)
        
        print(f"- Accuracy: {accuracy:.2f}%, Loss: {loss:.4f}")
    
    return global_model, history

print("\nTraining with Non-IID data...\n")
fed_model_noniid, fed_history_noniid = federated_learning_noniid(
    client_datasets_noniid,
    num_rounds=NUM_ROUNDS,
    local_epochs=LOCAL_EPOCHS
)

print(f"\nNon-IID Federated Learning Complete!")
print(f"Final Test Accuracy: {fed_history_noniid['accuracy'][-1]:.2f}%")

### Compare IID vs. Non-IID Performance

In [13]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot accuracy comparison
ax1.plot(range(1, NUM_ROUNDS + 1), fed_history['accuracy'], 
         marker='o', label='IID Data', linewidth=2, color='blue')
ax1.plot(range(1, NUM_ROUNDS + 1), fed_history_noniid['accuracy'], 
         marker='^', label='Non-IID Data', linewidth=2, color='red')
ax1.plot(range(1, NUM_ROUNDS + 1), central_history['accuracy'], 
         marker='s', label='Centralized (baseline)', linewidth=2, 
         color='green', linestyle='--', alpha=0.7)
ax1.set_xlabel('Communication Round', fontsize=12)
ax1.set_ylabel('Test Accuracy (%)', fontsize=12)
ax1.set_title('IID vs Non-IID Performance', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Plot loss comparison
ax2.plot(range(1, NUM_ROUNDS + 1), fed_history['loss'], 
         marker='o', label='IID Data', linewidth=2, color='blue')
ax2.plot(range(1, NUM_ROUNDS + 1), fed_history_noniid['loss'], 
         marker='^', label='Non-IID Data', linewidth=2, color='red')
ax2.plot(range(1, NUM_ROUNDS + 1), central_history['loss'], 
         marker='s', label='Centralized (baseline)', linewidth=2, 
         color='green', linestyle='--', alpha=0.7)
ax2.set_xlabel('Communication Round', fontsize=12)
ax2.set_ylabel('Test Loss', fontsize=12)
ax2.set_title('Loss Curves', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('iid_vs_noniid_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n" + "="*60)
print("PERFORMANCE SUMMARY")
print("="*60)
print(f"\nCentralized Learning:  {central_history['accuracy'][-1]:.2f}%")
print(f"Federated (IID):       {fed_history['accuracy'][-1]:.2f}%")
print(f"Federated (Non-IID):   {fed_history_noniid['accuracy'][-1]:.2f}%")
print(f"\nPerformance degradation (Non-IID vs IID): "
      f"{fed_history['accuracy'][-1] - fed_history_noniid['accuracy'][-1]:.2f} percentage points")
print("\nKey Insight: Non-IID data distributions typically cause slower convergence and")
print("lower final accuracy, a major challenge in federated learning.")
print("="*60)
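One way to see why non-IID data hurts is a toy model of "client drift". The sketch below is not the MNIST experiment: it assumes two equally weighted clients with one-dimensional quadratic losses $F_k(w) = \tfrac{1}{2} a_k (w - c_k)^2$ that have different minimizers and curvatures, and all names are illustrative.

```python
def local_gd(w, curvature, center, steps, lr):
    """Run `steps` gradient steps on F(w) = 0.5 * curvature * (w - center) ** 2."""
    for _ in range(steps):
        w -= lr * curvature * (w - center)
    return w

def fedavg_limit(local_steps, rounds=500, lr=0.1):
    """Iterate FedAvg rounds with equal client weights and return the final global w."""
    clients = [(1.0, 0.0), (4.0, 10.0)]  # (curvature a_k, minimizer c_k)
    w = 0.0
    for _ in range(rounds):
        w = sum(local_gd(w, a, c, local_steps, lr) for a, c in clients) / len(clients)
    return w

# Minimizer of the average loss: sum(a_k * c_k) / sum(a_k)
global_optimum = (1.0 * 0.0 + 4.0 * 10.0) / (1.0 + 4.0)

w_one_step = fedavg_limit(local_steps=1)
w_many_steps = fedavg_limit(local_steps=50)
print(f"global optimum:         {global_optimum:.3f}")
print(f"FedAvg, 1 local step:   {w_one_step:.3f}")
print(f"FedAvg, 50 local steps: {w_many_steps:.3f}")
```

With a single local step, averaging the client models is the same as taking one step on the average gradient, so FedAvg reaches the global optimum. With many local steps each client moves most of the way to its own minimizer, and the average settles near the midpoint of the two minimizers instead.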

## Key Takeaways

### What We Learned

1. **Federated Learning Workflow**:
   - Clients train locally on private data
   - Server aggregates model updates using FedAvg
   - Raw data never leaves devices, which limits (but does not eliminate) privacy risk

2. **Performance Characteristics**:
   - Federated learning can achieve near-centralized performance with IID data
   - Non-IID data can significantly degrade convergence and accuracy
   - In real deployments, communication is often the bottleneck rather than local compute

3. **Practical Challenges**:
   - **Statistical Heterogeneity**: Non-IID data is common in real deployments
   - **Communication Costs**: Each round requires sending full model parameters
   - **System Heterogeneity**: Clients have varying computational capabilities
   - **Privacy Leakage**: Model updates can still reveal sensitive information
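The communication cost of the experiment in this lesson can be estimated with back-of-the-envelope arithmetic. The sketch below assumes uncompressed float32 parameters, one download and one upload per client per round, and no protocol overhead; the parameter count is derived from the `SimpleCNN` layer shapes, and `fedavg_traffic_bytes` is an illustrative helper.

```python
def fedavg_traffic_bytes(num_params, clients_per_round, rounds, bytes_per_param=4):
    """Total bytes moved: each selected client downloads and uploads the full model."""
    one_way = num_params * bytes_per_param
    return 2 * one_way * clients_per_round * rounds

# SimpleCNN: conv1 + conv2 + fc1 + fc2 (weights and biases)
num_params = (1 * 32 * 9 + 32) + (32 * 64 * 9 + 64) + (9216 * 128 + 128) + (128 * 10 + 10)

traffic = fedavg_traffic_bytes(num_params, clients_per_round=5, rounds=10)
raw_mnist_bytes = 60_000 * 28 * 28  # training images as uint8 pixels, labels ignored

print(f"model size (one way): {num_params * 4 / 1e6:.1f} MB")
print(f"FedAvg traffic:       {traffic / 1e6:.1f} MB")
print(f"raw MNIST images:     {raw_mnist_bytes / 1e6:.1f} MB")
```

Under these assumptions the simulated run moves roughly ten times more bytes than uploading the raw MNIST images once. In this small setting privacy, not bandwidth, is the reason to use federated learning, and bandwidth savings depend on the ratio of model size to data size.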

### Advanced Techniques (Beyond This Lesson)

1. **Communication Efficiency**:
   - Gradient compression and quantization
   - Sparse updates (only send changed parameters)
   - Model distillation

2. **Privacy Enhancement**:
   - Differential privacy (adding noise to updates)
   - Secure aggregation (the server learns only the aggregate, not individual client updates)
   - Homomorphic encryption

3. **Non-IID Mitigation**:
   - Personalized federated learning
   - FedProx (proximal term for stability)
   - Multi-task learning approaches

4. **Scalability**:
   - Hierarchical federated learning
   - Asynchronous updates
   - Client selection strategies
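Of the techniques listed above, the FedProx proximal term is small enough to sketch. FedProx adds $\tfrac{\mu}{2}\lVert w - w_t \rVert^2$ to each client's local objective, which contributes $\mu (w - w_t)$ to the local gradient. The toy below assumes two equally weighted clients with one-dimensional quadratic losses and illustrative names; it is not a full FedProx implementation.

```python
def local_prox_gd(w_global, curvature, center, mu, steps=50, lr=0.1):
    """Local gradient descent on 0.5*a*(w - c)^2 + 0.5*mu*(w - w_global)^2."""
    w = w_global
    for _ in range(steps):
        grad = curvature * (w - center) + mu * (w - w_global)
        w -= lr * grad
    return w

def run(mu, rounds=200):
    """Average the two client models each round; mu = 0 reduces to plain FedAvg."""
    clients = [(1.0, 0.0), (4.0, 10.0)]  # (curvature a_k, minimizer c_k)
    w = 0.0
    for _ in range(rounds):
        w = sum(local_prox_gd(w, a, c, mu) for a, c in clients) / len(clients)
    return w

global_optimum = 8.0  # sum(a_k * c_k) / sum(a_k) for the two clients above
w_fedavg = run(mu=0.0)
w_fedprox = run(mu=4.0)
print(f"global optimum:    {global_optimum:.3f}")
print(f"FedAvg  (mu = 0):  {w_fedavg:.3f}")
print(f"FedProx (mu = 4):  {w_fedprox:.3f}")
```

In this toy the proximal term keeps each client closer to the global model during its 50 local steps, so the averaged model ends nearer the global optimum. The value of `mu` is a tuning choice: larger values limit drift but also limit how much each client can learn per round.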

### When to Use Federated Learning

✅ **Good Use Cases**:
- Sensitive data (healthcare, finance)
- Regulatory constraints (GDPR compliance)
- Data cannot be centralized (edge devices, mobile)
- Large-scale distributed data sources

❌ **Poor Use Cases**:
- Data can be easily centralized
- No privacy concerns
- Extremely heterogeneous data distributions
- High communication costs unacceptable

## Resources and Further Learning

### Foundational Papers

1. **McMahan et al. (2017)** - "Communication-Efficient Learning of Deep Networks from Decentralized Data"
   - Original FedAvg paper
   - [arXiv:1602.05629](https://arxiv.org/abs/1602.05629)

2. **Li et al. (2020)** - "Federated Optimization in Heterogeneous Networks"
   - Introduces FedProx for non-IID data
   - [arXiv:1812.06127](https://arxiv.org/abs/1812.06127)

3. **Kairouz et al. (2021)** - "Advances and Open Problems in Federated Learning"
   - Comprehensive survey of the field
   - [arXiv:1912.04977](https://arxiv.org/abs/1912.04977)

### Frameworks and Tools

1. **TensorFlow Federated**: https://www.tensorflow.org/federated
2. **PySyft**: https://github.com/OpenMined/PySyft
3. **Flower (flwr)**: https://flower.dev/
4. **FedML**: https://fedml.ai/

### Additional Topics to Explore

- Differential privacy in federated learning
- Secure aggregation protocols
- Federated learning on edge devices
- Cross-silo vs. cross-device federated learning
- Personalized federated learning
- Vertical federated learning

### Datasets for Experimentation

- LEAF Benchmark: https://leaf.cmu.edu/
- FEMNIST (Federated EMNIST)
- Shakespeare text dataset
- Reddit and StackOverflow datasets

---

## Next Steps

1. **Experiment**: Modify hyperparameters (local epochs, learning rate, client selection)
2. **Explore**: Try different datasets and model architectures
3. **Extend**: Implement differential privacy or secure aggregation
4. **Read**: Study the foundational papers listed above
5. **Build**: Create a federated learning application for a real-world use case

**Continue to Lesson 92**: Communication-Efficient Learning techniques!