# Federated Learning with Flower, PyTorch, and CIFAR-10

This notebook demonstrates how to set up a federated learning pipeline using the [Flower](https://flower.dev/) framework with PyTorch and the CIFAR-10 dataset. Federated learning allows multiple clients to collaboratively train a machine learning model while keeping their data local and private.

To promote modularity and reusability, the notebook is structured to cleanly separate the components that are shared between centralized and federated training (e.g., model architecture, training logic) from those specific to the federated setup (e.g., client/server logic, simulation). This design makes it easy to benchmark, test, and iterate across training modes.

### Notebook Structure

1. **Setup and Imports** – Install dependencies and import core libraries.
2. **Dataset Preparation** – Load and partition CIFAR-10 using Flower Datasets and IID partitioning.
3. **Training and Evaluation Functions** – Define reusable model training and testing logic for both centralized and federated workflows.
4. **ClientApp** – Implement a federated Flower client using the shared training/evaluation logic.
5. **ServerApp** – Configure the federated learning strategy and server behavior.
6. **Simulation** – Run a federated learning simulation across multiple clients.
7. **Evaluation** – Evaluate the final global model performance after training.

> Note: This notebook uses simulated clients and centralized coordination for demonstration purposes.


## 1. Setup and Imports

In this section, we install and import the required libraries for our federated learning setup.

- **Flower (flwr)**: A framework for building federated learning systems.
- **Flower Datasets (flwr_datasets)**: Utilities to download and partition datasets easily.
- **PyTorch**: Used for building and training the neural network.
- **Predefined CNN model**: Imported from `fedlearn.model`.

We will use `FederatedDataset` and `IidPartitioner` from Flower Datasets to partition CIFAR-10 into multiple IID client datasets.


In [1]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

from fedlearn.model import SmallCNN

2025-05-30 16:17:27.918826: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Training on cpu
Flower 1.17.0 / PyTorch 2.2.2


## 2. Dataset Preparation

In this section, we load and partition the CIFAR-10 dataset using the Flower Datasets library.

We perform the following steps:
- Download CIFAR-10 using `FederatedDataset`.
- Apply standard normalization and transformation.
- Use `IidPartitioner` to create IID partitions of the dataset, simulating multiple clients in a federated learning setup.

Each partition corresponds to a different client in the federated system.


In [2]:
NUM_PARTITIONS = 10 # Number of partitions for the federated dataset same as the number of clients
BATCH_SIZE = 32


def load_datasets(partition_id: int, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader

### 2.1 Visualizing Label Distribution Across Clients

To better understand the data each client sees, we visualize the class label distribution with clients on the x-axis. Each bar now represents the number of samples of a particular class on that client.

Since we're using IID partitioning, we expect each class to be fairly evenly distributed across all clients.

This visualization helps verify that the `IidPartitioner` produces balanced partitions with representative data for each client.


## 3. Training and Evaluation Functions

To support both centralized and federated learning, we define reusable functions for training and evaluating models.

These functions can be invoked in two contexts:
- Centrally, to train and evaluate a model using all available data (e.g. for benchmarking).
- Locally on each federated client, for on-device training and reporting.

The training function supports:
- Configurable optimizer and hyperparameters
- Batch-based model updates

The evaluation function returns both loss and accuracy metrics.


In [4]:
def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

## 4. ClientApp

Here we define the Flower client, which encapsulates the logic each federated client uses to train and evaluate the model.

Each client performs the following:
- Receives global model weights from the server.
- Trains the model locally on its own partitioned dataset.
- Sends the updated weights back to the server.
- Evaluates the global model on its local test data.

This class makes use of the shared `train` and `test` functions defined earlier to keep the logic consistent and reusable across different training scenarios.

Common customization points:
- Adjusting the optimizer configuration inside `train()`.
- Changing the number of local training epochs or batch size.
- Extending evaluation logic with custom metrics.


In [5]:
class FlowerClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    net = SmallCNN().to(DEVICE)

    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

## 5. ServerApp

This section defines the server application using Flower's `server_fn` API. The server is responsible for:

- Configuring the federated learning strategy (e.g. FedAvg).
- Reading runtime settings such as number of rounds and client participation ratio from a configuration context.
- Returning the strategy and configuration needed to run the simulation.

This approach makes the server logic highly modular and allows easy injection of config values like `num-server-rounds` or `fraction-fit` during runtime.


In [6]:
# Create an instance of the model and get the parameters
params = get_parameters(SmallCNN())

# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: NDArrays,
    config: Dict[str, Scalar],
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    net = SmallCNN().to(DEVICE)
    _, _, testloader = load_datasets(0, NUM_PARTITIONS)
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}

def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=1.0, # Use all clients for training, C
        fraction_evaluate=0.5, # Use 50% of clients for evaluation
        min_fit_clients=10,  # Minimum number of clients to train
        min_evaluate_clients=5,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(
            params
        ),  # Pass initial model parameters
        evaluate_fn=evaluate,  # Pass the evaluation function
    )

    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=10)
    return ServerAppComponents(strategy=strategy, config=config)

# Create the ServerApp
server = ServerApp(server_fn=server_fn)

## 6. Simulation

With all components in place, we now simulate the federated learning process using Flower’s `start_simulation` function.

This includes:
- Initializing each client with its own data partition.
- Launching multiple client instances in parallel using `ClientApp`.
- Running the simulation across a specified number of federated rounds.

We control how many clients participate per round and how many rounds of training we perform.

Common configuration options:
- Number of clients in the federation
- Number of training rounds
- Client resources (e.g., number of CPUs or GPUs per client)


In [None]:

NUM_PARTITIONS = 10  # Number of partitions (clients)
run_simulation(
    server_app=server, client_app=client, num_supernodes=NUM_PARTITIONS
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      initial parameters (loss, other metrics): 0.07212714471817017, {'accuracy': 0.1}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


Server-side evaluation loss 0.07212714471817017 / accuracy 0.1
[36m(ClientAppActor pid=13967)[0m [Client 2] fit, config: {}
