In [24]:
from collections import OrderedDict
from typing import List, Tuple, Optional
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.optim import SGD
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from fedlearn.model import SmallCNN

In [46]:
# Data and log locations, resolved relative to the parent of the directory the
# notebook is run from (presumably the project root — confirm for your layout).
datadir = Path.cwd().parent / "data" / "flower_dataset"
logdir = Path.cwd().parent / "logs" / "scaffold"

# exist_ok=True already makes this a no-op on re-runs, so no exists() guard is needed.
logdir.mkdir(parents=True, exist_ok=True)

In [33]:
# Train on a CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

# Silence the Hugging Face `datasets` progress bars so they don't flood cell output.
disable_progress_bar()

Training on cuda
Flower 1.18.0 / PyTorch 2.6.0+cu126


In [None]:
NUM_CLIENTS = 10
BATCH_SIZE = 32

# FederatedDataset objects cached per (number of partitions, data directory).
# Rebuilding one on every call would redo the dataset setup, and re-load the
# shared test split, every time a client is created.
_FDS_CACHE: dict = {}


def _get_federated_dataset(num_partitions: int) -> FederatedDataset:
    """Return a cached CIFAR-10 FederatedDataset whose train split has `num_partitions` partitions."""
    key = (num_partitions, str(datadir))
    fds = _FDS_CACHE.get(key)
    if fds is None:
        fds = FederatedDataset(
            dataset="cifar10",
            partitioners={"train": num_partitions},
            cache_dir=datadir,
        )
        _FDS_CACHE[key] = fds
    return fds


def load_datasets(partition_id: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build the data loaders for one simulated client.

    Parameters:
        partition_id: Index of the client's partition of the CIFAR-10 train split
                      (expected to lie in [0, NUM_CLIENTS)).
    Returns:
        (trainloader, valloader, testloader): an 80/20 train/validation split of
        the client's partition, plus the full CIFAR-10 "test" split.
        Batches are dicts keyed by dataset column ("img" holds the transformed
        images; the label column is presumably "label" — confirm).
    """
    fds = _get_federated_dataset(NUM_CLIENTS)
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% validation (fixed seed for reproducibility)
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Used via `dataset.with_transform(...)` instead of passing a transform to
        # torchvision's CIFAR10: converts each image to a tensor and normalizes
        # every channel with mean 0.5 / std 0.5.
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    # Create train/val for this partition and wrap each split into a DataLoader
    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader

### Define Scaffold Optimizer

Recall that the local update in Scaffold is given by

$$
w^{(i)} \gets w^{(i)} - \eta_l \left( g_i(w^{(i)}) + c - c_i \right)
$$

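Here $\eta_l$ is the local learning rate, $g_i$ is the client's stochastic gradient, $c$ is the global (server) control variate and $c_i$ is the client's control variate. This update can be seen as a gradient correction to stochastic gradient descent (SGD), so we can extend PyTorch's ```SGD``` class: we first compute the regular SGD step and then apply the correction manually. (When momentum or weight decay is enabled, the first line below is PyTorch's full SGD update rather than the plain $-\eta_l g_i$ shown, and the correction term bypasses the momentum buffer.)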

$$
\begin{align*}
w^{(i)} &\gets w^{(i)} - \eta_l \, g_i\left(w^{(i)}\right) \\
w^{(i)} &\gets w^{(i)} - \eta_l (c - c_i)
\end{align*}
$$

In [11]:
class ScaffoldOptimizer(SGD):
    """
    SGD optimizer whose step also applies the SCAFFOLD drift correction.

    Each call to `step` first performs a regular `torch.optim.SGD` step
    (including momentum / weight decay, if configured) and then moves every
    parameter by `-lr * (global_cv - client_cv)`. The correction is applied
    after the SGD step, so it does not enter the momentum buffer.
    """

    def __init__(self, params, lr=0.01, momentum=0.9, weight_decay=0):
        # Pass by keyword: SGD's 4th positional parameter is `dampening`, not
        # `weight_decay`, so a positional call would silently turn the weight
        # decay into dampening and disable weight decay entirely.
        super().__init__(params, lr=lr, momentum=momentum, weight_decay=weight_decay)

    def step(self, global_cv, client_cv):
        """
        Perform a single optimization step.
        :param global_cv: Global control variates, one tensor per parameter,
            ordered like the parameters across all param groups.
        :param client_cv: Client control variates, same layout as global_cv.
        """
        # regular SGD step:
        #   w <- w - lr * grad   (plus momentum / weight decay if enabled)
        super().step()

        # correction term:
        #   w <- w - lr * (g_cv - c_cv)
        # A single shared iterator keeps the control variates aligned with the
        # parameters even when the optimizer has several param groups.
        cv_pairs = zip(global_cv, client_cv)
        with torch.no_grad():
            for group in self.param_groups:
                for param, (g_cv, c_cv) in zip(group["params"], cv_pairs):
                    # the control variates may live on another device/dtype (e.g. CPU copies)
                    correction = (g_cv - c_cv).to(device=param.device, dtype=param.dtype)
                    param.add_(correction, alpha=-group["lr"])

We can now write a function for the local training. In this function, we simply want to perform gradient-corrected SGD updates over the local data for $E$ epochs.

In [None]:
def train_scaffold(net: torch.nn.Module, 
                   device: torch.device, 
                   trainloader: torch.utils.data.DataLoader,
                   criterion: nn.Module,
                   num_epochs: int, 
                   lr: float, 
                   momentum: float, 
                   weight_decay: float, 
                   global_cv: List[torch.Tensor], 
                   client_cv: List[torch.Tensor],
                   ) -> None:
    """
    Function that trains a model in place using the Scaffold optimization algorithm.
    Parameters:
        net:            The neural network model to train (expected to already be on `device`).
        device:         The device to run the training on (CPU or GPU).
        trainloader:    DataLoader for the training data. Batches may be
                        (inputs, labels) pairs or dicts with "img"/"label" keys.
        criterion:      Loss function to use for training.
        num_epochs:     Number of epochs to train the model.
        lr:             Learning rate for the optimizer.
        momentum:       Momentum factor for the optimizer.
        weight_decay:   Weight decay (L2 penalty) for the optimizer.
        global_cv:      Global control variables for Scaffold, one per entry of net.parameters().
        client_cv:      Client control variables for Scaffold, same layout as global_cv.
    """
    net.train()

    # The control variates may arrive as CPU tensors while the model lives on
    # `device`; move them once here rather than on every optimizer step.
    # (Rebinding the local names leaves the caller's lists untouched.)
    global_cv = [torch.as_tensor(cv, device=device) for cv in global_cv]
    client_cv = [torch.as_tensor(cv, device=device) for cv in client_cv]

    optimizer = ScaffoldOptimizer(
        net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )

    for _ in range(num_epochs):
        for batch in trainloader:
            # Hugging Face datasets (as used by load_datasets) collate into dicts;
            # plain TensorDataset-style loaders yield (inputs, labels) tuples.
            if isinstance(batch, dict):
                inputs, labels = batch["img"], batch["label"]
            else:
                inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()

            # Perform a single optimization step with the control variables
            optimizer.step(global_cv, client_cv)

We will also define a test function, which will be called to evaluate our model. This will give us some metrics to evaluate the performance of the model. As we are working with a classifier, we are interested in both the loss and accuracy of the model.

In [None]:
def test(net: torch.nn.Module, 
         device: torch.device, 
         testloader: torch.utils.data.DataLoader,
         criterion: nn.Module,
         ) -> Tuple[float, float]:
    """
    Function that tests a model on the test dataset.
    Parameters:
        net:        The neural network model to test.
        device:     The device to run the testing on (CPU or GPU).
        testloader: DataLoader for the test data.
        criterion:  Loss function to use for testing.
    Returns:
        Tuple containing the average loss and accuracy on the test set.
    """
    
    net.eval()
    total_loss = 0.0    # Accumulator for total loss
    correct = 0         # tracker for correct predictions
    total = 0           # tracker for total predictions
    
    with torch.no_grad():
        for Xtest, Ytest in testloader:
            Xtest, Ytest = Xtest.to(device), Ytest.to(device)
            output = net(Xtest)
            loss = criterion(output, Ytest)
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += Ytest.size(0)
            correct += predicted.eq(Ytest).sum().item()
    
    avg_loss = total_loss / len(testloader) # compute the average loss
    accuracy = correct / total              # compute the accuracy
    return avg_loss, accuracy

For the simulation, we will use the Flower framework, which was introduced in a previous notebook. To do so, we need to specify both client and server classes. We need to consider a couple of things:
1. We can subclass Flower's ```NumPyClient``` class; however, we need to remember to convert between ```np.ndarray``` and ```torch.Tensor``` before and after the local updates.
2. We need to implement a ```client.fit()``` method containing all the logic for the local update. This method takes two inputs:
   1. parameters: a list of ```np.ndarray``` containing both the global model parameters and the global control variates
   2. config: a dict for specifying the training configuration (we will ignore this for now)
3. We need to implement a ```client.evaluate()``` method, which loads the received global model and reports its loss and accuracy on the client's local validation data.

In [17]:
class ScaffoldClient(NumPyClient):
    """
    Flower NumPyClient that trains its local model with SCAFFOLD.

    `fit` receives a flat list [model arrays..., global control variates...]
    whose two halves have equal length, and returns a list of the same shape
    holding the model update and the control-variate update. The client control
    variates c_i are written to disk after every `fit`, so they survive the
    client object being rebuilt by `client_fn`.
    """

    def __init__(self, 
                 cid: int, 
                 net: torch.nn.Module, 
                 trainloader: torch.utils.data.DataLoader, 
                 valloader: torch.utils.data.DataLoader,
                 criterion: nn.Module,
                 device: torch.device,
                 num_epochs: int,
                 lr: float,
                 momentum: float,
                 weight_decay: float,
                 save_dir: Optional[str] = None,
                 ):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.criterion = criterion
        self.device = device
        self.num_epochs = num_epochs
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        # directory holding the per-client control variates (default: ./client_cvs)
        if save_dir is None:
            save_dir = "client_cvs"
        # NOTE(review): these files outlive the kernel, so a new simulation run
        # resumes from the previous run's c_i unless the directory is cleared.
        os.makedirs(save_dir, exist_ok=True)

        # path of this client's saved control variates
        self.save_name = os.path.join(save_dir, f"client_{self.cid}_cv.pt")

        # c_i starts at zero, one CPU tensor per model parameter
        self.client_cv = [torch.zeros(param.shape) for param in self.net.parameters()]

    # ---- methods required by the NumPyClient interface ----
    def get_parameters(self, config: Optional[dict] = None) -> List[np.ndarray]:
        """Return the model's state_dict values as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for val in self.net.state_dict().values()]

    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """Load arrays (in state_dict order) into the model; extra trailing arrays are ignored by zip."""
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters: List[np.ndarray], config: dict) -> Tuple[List[np.ndarray], int, dict]:
        """
        Run one round of local SCAFFOLD training.

        Parameters:
            parameters: [global model arrays x..., global control variates c...].
            config:     Round configuration from the server (unused).
        Returns:
            ([x - y_i..., c_i+ - c_i...], number of training examples, {}), where
            y_i is the locally trained model and c_i+ the updated client control variate.
        """
        # Separate the payload into the global model and the global control variates.
        # NOTE(review): splitting in half assumes one control variate per model
        # array, i.e. the model's state_dict has no buffers — confirm for SmallCNN.
        split = len(parameters) // 2
        params = parameters[:split]
        global_cv = [np.asarray(cv) for cv in parameters[split:]]

        # load the current global model x
        self.set_parameters(params)

        # restore c_i from the last round this client took part in, if any
        if os.path.exists(self.save_name):
            self.client_cv = torch.load(self.save_name, map_location="cpu")

        train_scaffold(
            net=self.net,
            device=self.device,
            trainloader=self.trainloader,
            criterion=self.criterion,
            num_epochs=self.num_epochs,
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
            # the optimizer adds these to parameters on `device`, so move them there
            global_cv=[torch.as_tensor(cv).to(self.device) for cv in global_cv],
            client_cv=[cv.to(self.device) for cv in self.client_cv],
        )

        y_i = self.get_parameters(config)
        client_cv_old = [cv.detach().cpu().numpy() for cv in self.client_cv]
        client_cv_new = self._updated_client_cv(params, y_i, global_cv, client_cv_old)

        # Model update is x - y_i (the negative of the paper's Δy_i); presumably
        # the server strategy subtracts it — confirm against the strategy.
        server_update_x = [np.asarray(np.asarray(x) - y) for x, y in zip(params, y_i)]
        # Control-variate update Δc_i = c_i+ - c_i
        server_update_c = [np.asarray(new - old) for new, old in zip(client_cv_new, client_cv_old)]

        self.client_cv = [torch.from_numpy(cv) for cv in client_cv_new]
        torch.save(self.client_cv, self.save_name)

        # concatenate the updates in the same [model..., control variates...] layout
        return server_update_x + server_update_c, len(self.trainloader.dataset), {}

    def _updated_client_cv(self,
                           x: List[np.ndarray],
                           y_i: List[np.ndarray],
                           c: List[np.ndarray],
                           c_i: List[np.ndarray],
                           ) -> List[np.ndarray]:
        """SCAFFOLD option-II update c_i+ = c_i - c + (x - y_i) / (K * lr), with K = local steps."""
        num_steps = self.num_epochs * len(self.trainloader)
        if num_steps == 0:
            # no local step was taken, so there is no new gradient information
            return [np.array(cv, copy=True) for cv in c_i]
        coeff = 1.0 / (num_steps * self.lr)
        return [
            np.asarray(ci_j - c_j + coeff * (np.asarray(x_j) - y_j))
            for x_j, y_j, c_j, ci_j in zip(x, y_i, c, c_i)
        ]

    def evaluate(self, parameters: List[np.ndarray], config: dict) -> Tuple[float, int, dict]:
        """Evaluate the received global model on the local validation split."""
        self.set_parameters(parameters)
        avg_loss, accuracy = test(
            net=self.net,
            device=self.device,
            testloader=self.valloader,
            criterion=self.criterion
        )
        return float(avg_loss), len(self.valloader.dataset), {"accuracy": accuracy}

Now that we have the Flower client defined, we need to define a constructor function that the Flower framework can use to instantiate clients on demand.

In [None]:
def client_fn(context: Context) -> Client:
    """
    Build the SCAFFOLD client for the simulated node described by `context`.

    A fresh model, data loaders and client object are created on every call;
    the client's control variates are persisted to disk by ScaffoldClient.
    """
    # Flower's simulation engine stores each node's data partition index under
    # the "partition-id" key of node_config.
    cid = int(context.node_config["partition-id"])
    trainloader, valloader, _ = load_datasets(cid)  # centralized test loader unused here

    net = SmallCNN().to(DEVICE)
    criterion = nn.CrossEntropyLoss()

    # Define hyperparameters for training
    num_epochs = 1
    lr = 0.01
    momentum = 0.9
    weight_decay = 0.0005

    # ClientApp expects a `Client`, so convert the NumPyClient explicitly.
    return ScaffoldClient(
        cid=cid,
        net=net,
        trainloader=trainloader,
        valloader=valloader,
        criterion=criterion,
        device=DEVICE,
        num_epochs=num_epochs,
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        save_dir="client_cvs"
    ).to_client()