# 🎓 Advanced Federated Learning with Scaffold and ResNet on CIFAR-100

Welcome to the next level of federated learning! This tutorial will guide you through implementing Scaffold (Stochastic Controlled Averaging for Federated Learning), a powerful and robust algorithm designed to handle non-IID data. We'll use a more advanced ResNet model and the challenging CIFAR-100 dataset.

## Step 1: Setup and Data Preparation (Non-IID CIFAR-100)

First, let's prepare our environment and the more complex CIFAR-100 dataset. We will again simulate non-IID data conditions to show why Scaffold is necessary.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms
import copy
import random
import numpy as np

In [2]:
# Set a device for training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data transformations for CIFAR-100
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

### Download and load the CIFAR100 training and test datasets


In [3]:
train_dataset = datasets.CIFAR100('./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100('./data', train=False, download=True, transform=transform)

In [4]:
NUM_CLIENTS = 10
CLIENT_BATCH_SIZE = 32

# Function to create non-IID data partitions for CIFAR-100
def create_non_iid_partitions_cifar100(dataset, num_clients):
    # Group sample indices by class using the stored labels (avoids decoding every image)
    class_indices = [[] for _ in range(100)]
    for i, label in enumerate(dataset.targets):
        class_indices[label].append(i)
    
    client_indices = [[] for _ in range(num_clients)]
    
    # Assign a non-overlapping set of classes to each client
    classes_per_client_list = np.array_split(np.arange(100), num_clients)
    
    for client_idx, client_classes in enumerate(classes_per_client_list):
        for class_idx in client_classes:
            client_indices[client_idx].extend(class_indices[class_idx])
    
    partitions = [data.Subset(dataset, random.sample(indices, len(indices))) for indices in client_indices]
    return partitions

client_data = create_non_iid_partitions_cifar100(train_dataset, NUM_CLIENTS)
client_trainloaders = [data.DataLoader(d, batch_size=CLIENT_BATCH_SIZE, shuffle=True) for d in client_data]
test_dataloader = data.DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f"Data has been partitioned among {NUM_CLIENTS} clients, each with a disjoint set of classes.")

Data has been partitioned among 10 clients, each with a disjoint set of classes.


### Code Explanation:

We're using CIFAR-100, which has 100 classes. The normalization constants are the per-channel mean and standard deviation of the CIFAR-100 training set.

create_non_iid_partitions_cifar100: This function creates a stark non-IID distribution by giving each client a completely separate set of classes (e.g., client 1 gets classes 0-9, client 2 gets classes 10-19, and so on). This is a strong test for Scaffold.
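
To make the partitioning scheme concrete, here is a minimal sketch on a toy label list that does not need torchvision. The helper name `partition_by_class` and the toy sizes are assumptions made for illustration. It mirrors the idea described above (one contiguous, non-overlapping block of classes per client) rather than reproducing the tutorial's function.

```python
def partition_by_class(labels, num_clients, num_classes):
    """Give each client the sample indices of a contiguous, disjoint block of classes."""
    # Group sample indices by class label
    by_class = [[] for _ in range(num_classes)]
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Assumes num_classes is divisible by num_clients (as with 100 classes and 10 clients)
    per_client = num_classes // num_clients
    partitions = []
    for client in range(num_clients):
        first = client * per_client
        indices = []
        for cls in range(first, first + per_client):
            indices.extend(by_class[cls])
        partitions.append(indices)
    return partitions


# Toy data: 6 classes with 4 samples each, split across 3 clients
toy_labels = [i % 6 for i in range(24)]
parts = partition_by_class(toy_labels, num_clients=3, num_classes=6)

# The set of labels each client actually holds
label_sets = [sorted({toy_labels[i] for i in p}) for p in parts]
print(label_sets)  # [[0, 1], [2, 3], [4, 5]]
```

No label appears on more than one client, which is the extreme form of label skew that the tutorial sets up.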

## Step 2: Defining the Advanced ResNet Model

We'll use a simplified ResNet architecture, which is a powerful CNN that uses "residual blocks" to improve training in deep networks.

In [5]:
# A simple ResNet block
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out


In [6]:
# A simplified ResNet model
class ResNet(nn.Module):
    def __init__(self, num_blocks, num_classes=100):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        
        self.layer1 = self._make_layer(ResNetBlock, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(ResNetBlock, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(ResNetBlock, 64, num_blocks[2], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

# Instantiate our ResNet model
# A small CIFAR-style ResNet: three stages of two blocks each (14 weighted layers),
# lighter than a full ResNet-18
resnet_model = ResNet([2, 2, 2])

### Code Explanation:

ResNetBlock: This class defines the core building block of a ResNet. It contains two convolutional layers and a "shortcut" connection that adds the input x to the output of the convolutional layers.

ResNet: This class stacks multiple ResNetBlock layers to create the full network. The _make_layer function builds each stage, letting only the first block of a stage change the stride or channel count.

This model is significantly deeper and more powerful than the simple CNN from the last tutorial.
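
As a rough check on the architecture, the sketch below traces the tensor shape through the stages using the standard convolution output-size formula. It is plain arithmetic and does not run the model. The helper `conv_out` is a name introduced here, and the stage settings are read off the ResNet class above.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a square convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                      # CIFAR images are 3 x 32 x 32
size = conv_out(size)          # stem: 3x3 conv, stride 1, padding 1
trace = [("stem", 16, size)]

# (name, output channels, stride of the first block) for the three stages
stages = [("layer1", 16, 1), ("layer2", 32, 2), ("layer3", 64, 2)]
for name, out_channels, stride in stages:
    # Only the first block of a stage uses the given stride; the rest keep the size
    size = conv_out(size, stride=stride)
    trace.append((name, out_channels, size))

for name, channels, s in trace:
    print(f"{name}: {channels} x {s} x {s}")

# Weighted layers on the main path: stem conv + 3 stages * 2 blocks * 2 convs + final linear
depth = 1 + 3 * 2 * 2 + 1
print("depth:", depth)  # 14
```

The final 64 x 8 x 8 feature map is reduced to a 64-dimensional vector by the adaptive average pool, which matches the input size of the fc layer.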

## Step 3: The Client-side Training with Scaffold

Scaffold's key innovation is the use of "control variates" to correct for client drift. Each client maintains a local control variate c_k and the server maintains a global control variate c. The client's local gradient update is corrected using these variates.

The client's training step is modified as follows:
$$w_k \leftarrow w_k - \eta \left( \nabla F_k(w_k) - c_k + c \right)$$
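
A toy example can show what the correction term does. The sketch below assumes two clients with scalar quadratic objectives F_k(w) = 0.5 (w - a_k)^2, and it assumes each control variate is set to that client's gradient at the current global model. Both are idealizations chosen for illustration, not part of the tutorial's setup. All names are hypothetical.

```python
# Hypothetical setup: client k minimizes F_k(w) = 0.5 * (w - a_k)**2, so grad F_k(w) = w - a_k
targets = [-1.0, 3.0]                      # per-client optima a_k (made up for illustration)
global_optimum = sum(targets) / len(targets)


def grad(w, a):
    return w - a


w_global = 0.0
# Assumption: control variates equal each client's gradient at the current global model
c_local = [grad(w_global, a) for a in targets]
c_global = sum(c_local) / len(c_local)


def local_steps(w, a, c_k, c, eta=0.1, steps=200):
    """Run corrected local updates: w <- w - eta * (grad F_k(w) - c_k + c)."""
    for _ in range(steps):
        w = w - eta * (grad(w, a) - c_k + c)
    return w


# Without correction (c_k = c = 0) each client drifts to its own optimum
plain = [local_steps(w_global, a, 0.0, 0.0) for a in targets]
# With correction each client is pulled toward the optimum of the average objective
corrected = [local_steps(w_global, a, c_k, c_global) for a, c_k in zip(targets, c_local)]

print("plain local SGD ends near:", [round(w, 3) for w in plain])
print("corrected updates end near:", [round(w, 3) for w in corrected])
print("global optimum:", global_optimum)
```

In this idealized case the corrected gradient on every client equals the gradient of the average objective, so both clients end near 1.0 instead of near -1.0 and 3.0. Real networks will not cancel drift this cleanly.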

In [7]:
# A helper function to get model parameters as a flat vector
def get_params_vector(model):
    return torch.cat([p.data.view(-1) for p in model.parameters()])

# A helper function to set model parameters from a flat vector
def set_params_vector(model, param_vec):
    start = 0
    for p in model.parameters():
        num_params = p.numel()
        p.data.copy_(param_vec[start:start+num_params].view(p.size()))
        start += num_params

In [8]:
# Client training function for Scaffold
def client_training_scaffold(model, global_model_state_dict, c_global, c_local, trainloader, lr=0.01, epochs=1):
    local_model = copy.deepcopy(model).to(DEVICE)
    local_model.train()
    
    # Load global model state and initialize optimizer
    local_model.load_state_dict(global_model_state_dict)
    optimizer = optim.SGD(local_model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Get initial local model parameters
    w_initial = get_params_vector(local_model)

    # Work on device copies of this client's control variate and the server's global one
    c_local_client = c_local.clone().to(DEVICE)
    c_global_server = c_global.to(DEVICE)

    num_steps = 0  # number of local SGD steps actually taken
    for epoch in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = local_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            
            # --- Scaffold's Gradient Correction ---
            # Get local gradients as one flat vector
            local_grad = torch.cat([p.grad.data.view(-1) for p in local_model.parameters()])
            
            # Compute the corrected gradient: grad - c_local + c_global
            corrected_grad = local_grad - c_local_client + c_global_server

            # Update each parameter with its own slice of the corrected gradient
            start = 0
            for p in local_model.parameters():
                n = p.numel()
                p.data.add_(corrected_grad[start:start + n].view_as(p), alpha=-lr)
                start += n
            num_steps += 1
            
    # Calculate the new client control variate from the net parameter change:
    # c_new_local = c_local - c_global + (w_initial - w_final) / (num_steps * lr)
    w_final = get_params_vector(local_model)
    c_new_local = c_local_client - c_global_server + (w_initial - w_final) / (num_steps * lr)
    
    return local_model.state_dict(), c_new_local.cpu()

### Code Explanation:

The client function now takes c_global and c_local as inputs.

The most important part is the gradient correction line: corrected_grad = local_grad - c_local_client + c_global_server. This is where Scaffold's magic happens, preventing client drift.

The client's local parameters are updated using this corrected gradient instead of the raw local gradient.

The function returns both the updated model parameters and the new client control variate.
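
In the cheaper of the two control-variate updates described in the SCAFFOLD paper, the divisor is the number of local steps K times the learning rate. With that choice the new control variate equals the mean of the gradients seen along the local trajectory. The scalar sketch below checks this identity numerically on a toy quadratic. All values and names are illustrative assumptions.

```python
def grad(w, a):
    # Gradient of the toy objective F(w) = 0.5 * (w - a)**2
    return w - a


a_k = 3.0              # toy client optimum (illustrative)
c_k, c = 0.5, -0.2     # arbitrary starting control variates
eta, num_steps = 0.05, 20

w_initial = 0.0
w = w_initial
seen_grads = []
for _ in range(num_steps):
    g = grad(w, a_k)
    seen_grads.append(g)
    w = w - eta * (g - c_k + c)       # corrected local step

# New client control variate from the net parameter change over K = num_steps steps
c_k_new = c_k - c + (w_initial - w) / (num_steps * eta)
mean_grad = sum(seen_grads) / num_steps

print(c_k_new, mean_grad)  # the two values agree up to floating-point error
```

The identity holds because the correction terms add up to K * eta * (c - c_k) over the K steps, which the formula subtracts back out. Dividing by anything other than the step count would rescale the result away from the mean gradient.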

## Step 4: Server-side Aggregation

The server's role in Scaffold is to aggregate the model weights (just like FedAvg) and to update the global control variate based on the clients' control variate updates.

In [9]:
def aggregate_parameters(client_updates):
    """Aggregates parameters from multiple clients using FedAvg."""
    if not client_updates:
        return None

    global_state_dict = copy.deepcopy(client_updates[0])
    
    # Zero-out the global state dictionary
    for name in global_state_dict.keys():
        global_state_dict[name] = torch.zeros_like(global_state_dict[name])

    for client_state_dict in client_updates:
        for name, param in client_state_dict.items():
            # Average floating-point tensors (weights and BatchNorm running statistics)
            if param.is_floating_point():
                global_state_dict[name] += param / len(client_updates)
            else:
                # Integer buffers such as num_batches_tracked are not averaged;
                # the value from the last client in the list is kept.
                global_state_dict[name] = client_state_dict[name]
    
    return global_state_dict

def update_global_control_variate(c_global, c_local_updates):
    """Updates the global control variate in Scaffold."""
    if not c_local_updates:
        return c_global

    new_c_global = copy.deepcopy(c_global)
    
    for c_update in c_local_updates:
        # Move the client update to the same device as the global variate
        new_c_global += c_update.to(new_c_global.device) / len(c_local_updates)
        
    return new_c_global
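
For reference, the server step in the SCAFFOLD paper adds the sum of the participating clients' control-variate changes (new minus old), divided by the total number of clients N, to the global control variate. The sketch below illustrates that rule next to a plain parameter average, using Python lists in place of tensors. The function names and numbers are hypothetical.

```python
def average_vectors(vectors):
    """Element-wise mean of equally sized lists of floats."""
    n = len(vectors)
    return [sum(vals) / n for vals in zip(*vectors)]


def server_update(c_global, c_deltas, num_total_clients):
    """c <- c + (1 / N) * sum of (c_k_new - c_k) over the participating clients."""
    summed = [sum(vals) for vals in zip(*c_deltas)]
    return [c + s / num_total_clients for c, s in zip(c_global, summed)]


# Two participating clients out of four in total (toy numbers)
client_weights = [[1.0, 2.0], [3.0, 6.0]]
new_weights = average_vectors(client_weights)

c_global = [0.0, 0.0]
c_deltas = [[0.4, -0.2], [0.6, 0.2]]      # each entry is c_k_new - c_k
new_c = server_update(c_global, c_deltas, num_total_clients=4)

print(new_weights)
print(new_c)
```

Dividing by N rather than by the number of participants keeps the global control variate equal to the average over all clients' variates, including those that sat out the round.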

## Step 5: The Federated Training Loop

Finally, we tie everything together. The main loop will now manage not only the global model but also the global control variate and each client's local control variate.

In [10]:
def server_evaluation(model, dataloader):
    """Evaluates the global model on the server's test set."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Initialize the global model and control variates
global_model = ResNet([2, 2, 2]).to(DEVICE)
global_c_variate = torch.zeros_like(get_params_vector(global_model))
client_c_variates = [torch.zeros_like(get_params_vector(global_model)) for _ in range(NUM_CLIENTS)]

print(f"Initial global model accuracy: {server_evaluation(global_model, test_dataloader):.2f}%")


Initial global model accuracy: 0.67%


In [None]:
# Federated Learning Main Loop
NUM_ROUNDS = 10
CLIENTS_PER_ROUND = 5
CLIENT_LEARNING_RATE = 0.01

for round_num in range(NUM_ROUNDS):
    print(f"\n--- Starting Federated Round {round_num + 1}/{NUM_ROUNDS} ---")
    
    participating_clients_indices = random.sample(range(NUM_CLIENTS), CLIENTS_PER_ROUND)
    print(f"Server selects clients: {participating_clients_indices}")

    global_model_state_dict = global_model.state_dict()
    client_updates = []
    client_c_variate_updates = []
    
    for client_idx in participating_clients_indices:
        # Simulate local training on each client
        print(f"Client {client_idx} is training...")
        client_dataloader = client_trainloaders[client_idx]
        
        # Get the client's current local control variate
        c_local = client_c_variates[client_idx]

        # Run client training with Scaffold
        local_state_dict, new_c_local = client_training_scaffold(
            global_model, global_model_state_dict, global_c_variate, c_local, client_dataloader, lr=CLIENT_LEARNING_RATE, epochs=3
        )
        
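        # Store the updates. The server should add (1/N) * sum of (c_new - c_old) to the
        # global control variate. update_global_control_variate averages over the
        # participating clients, so the change is pre-scaled by CLIENTS_PER_ROUND / NUM_CLIENTS.
        client_updates.append(local_state_dict)
        c_delta = new_c_local - c_local.to(new_c_local.device)
        client_c_variate_updates.append(c_delta * (CLIENTS_PER_ROUND / NUM_CLIENTS))

        # Update the client's local control variate for the next round
        client_c_variates[client_idx] = new_c_local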

    # Server aggregates model updates and control variate updates
    new_global_state_dict = aggregate_parameters(client_updates)
    print("Global model parameters aggregated.")

    if new_global_state_dict:
        global_model.load_state_dict(new_global_state_dict)
        print(f"Global model parameters updated.")


    global_c_variate = update_global_control_variate(global_c_variate, client_c_variate_updates)

    accuracy = server_evaluation(global_model, test_dataloader)
    print(f"Global model accuracy after round {round_num + 1}: {accuracy:.2f}%")

print("\n--- Federated Learning Finished ---")


--- Starting Federated Round 1/10 ---
Server selects clients: [6, 1, 5, 3, 2]
Client 6 is training...
Client 1 is training...
Client 5 is training...
Client 3 is training...
