# Moving Beyond FedAvg with Flower Strategies

Welcome to the KAICD course on federated learning with Flower, part 2!

In this notebook, we'll customize the federated learning system we built in part 1. As in part 1, we use PyTorch for the model training pipeline and data loading, and Flower to federate that PyTorch-based pipeline.

## Part 0: Preparation

Before we begin with any actual code, let's make sure that we have everything we need. We recommend switching to a runtime with GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware accelerator: GPU > Save`).

### Installing dependencies

Next, we install and import the necessary packages:

In [3]:
!pip install torch==1.9.0 torchvision==0.10.0 git+https://github.com/adap/flower.git@release/0.17#egg=flwr["simulation"]

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")

Collecting flwr[simulation]
  Cloning https://github.com/adap/flower.git (to revision release/0.17) to /tmp/pip-install-o5dfdrvu/flwr_bc542f8c16e74044acdbc2f792938cf6
  Running command git clone -q https://github.com/adap/flower.git /tmp/pip-install-o5dfdrvu/flwr_bc542f8c16e74044acdbc2f792938cf6
  Running command git checkout -b release/0.17 --track origin/release/0.17
  Switched to a new branch 'release/0.17'
  Branch 'release/0.17' set up to track remote branch 'release/0.17' from 'origin'.
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing wheel metadata ... done
Training on cuda:0


If you run on Google Colab and your runtime has a GPU accelerator, you should see the output `Training on cuda:0`.

### Data loading

Let's now load the CIFAR-10 training and test sets, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap each of them in its own `DataLoader`:

In [4]:
NUM_CLIENTS = 10

def load_datasets():
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into 10 partitions to simulate individual client datasets
    partition_size = len(trainset) // NUM_CLIENTS
    lengths = [partition_size] * NUM_CLIENTS
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader

trainloaders, valloaders, testloader = load_datasets()

Files already downloaded and verified
Files already downloaded and verified
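
To make the resulting partition sizes concrete, the following sketch repeats only the arithmetic from `load_datasets`, without loading any data. The helper name `partition_sizes` is hypothetical and introduced just for illustration; the figure of 50,000 images is the size of the CIFAR-10 training set.

```python
# Sketch of the partition arithmetic used by `load_datasets` (no data is loaded).
NUM_CLIENTS = 10
TRAIN_SET_SIZE = 50_000  # number of images in the CIFAR-10 training set


def partition_sizes(total: int, num_clients: int, val_divisor: int = 10):
    """Return a (train, val) example count for each simulated client."""
    partition_size = total // num_clients
    len_val = partition_size // val_divisor  # 10 % validation set
    len_train = partition_size - len_val
    return [(len_train, len_val)] * num_clients


sizes = partition_sizes(TRAIN_SET_SIZE, NUM_CLIENTS)
print(sizes[0])
```

Each simulated client therefore holds a small slice of the data, which is typical of federated settings where no single participant sees the full training set.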


### Model training/evaluation

Let's continue with the usual model definition, training and test functions:

In [5]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)  # reuse the forward pass above
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()  # accumulate a float, not a graph-attached tensor
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)  # normalize by the training set size
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy
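
The `16 * 5 * 5` input size of `fc1` in `Net` follows from how the convolution and pooling layers shrink a 32x32 CIFAR-10 image. The sketch below traces that arithmetic using the standard output-size formula for convolution and pooling layers (no padding or dilation); the helper `conv_out` is a hypothetical name used only for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                  # CIFAR-10 images are 32x32
size = conv_out(size, kernel=5)            # conv1: 32 -> 28
size = conv_out(size, kernel=2, stride=2)  # pool:  28 -> 14
size = conv_out(size, kernel=5)            # conv2: 14 -> 10
size = conv_out(size, kernel=2, stride=2)  # pool:  10 -> 5
flattened = 16 * size * size               # conv2 produces 16 channels
print(size, flattened)
```

If the input resolution or a kernel size changes, the first argument of `fc1` and the `view` call in `forward` have to change with it.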

## Part 1: Strategy customization

### Flower client

To implement the Flower client, we create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`:

In [6]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

### Starting with a customized strategy

We now have `FlowerClient` which defines client-side training and evaluation and `client_fn` which allows Flower to create `FlowerClient` instances whenever it needs to call `fit` or `evaluate` on one particular client. The last step is to start the actual simulation using `flwr.simulation.start_simulation`. 

The function `start_simulation` accepts a number of arguments, amongst them the `client_fn` used to create `FlowerClient` instances, the number of clients to simulate `num_clients`, the number of rounds `num_rounds`, and the strategy. The strategy encapsulates the federated learning approach/algorithm, for example, *Federated Averaging* (FedAvg).

Flower comes with a number of built-in strategies, but we can also use our own strategy implementations to customize nearly all aspects of the federated learning approach. For this example, we use the built-in `FedAvg` implementation and customize it using a few basic parameters. The last step is the actual call to `start_simulation` which - you guessed it - starts the simulation.

In [7]:
# # Create FedAvg strategy
# strategy = fl.server.strategy.FedAvg(
#     fraction_fit=1.0,  # Sample 100% of available clients for training
#     fraction_eval=1.0,  # Sample 100% of available clients for evaluation
#     min_fit_clients=10,  # Never sample fewer than 10 clients for training
#     min_eval_clients=10,  # Never sample fewer than 10 clients for evaluation
#     min_available_clients=10,  # Wait until all 10 clients are available
# )

# # Start simulation
# fl.simulation.start_simulation(
#     client_fn=client_fn,
#     num_clients=NUM_CLIENTS,
#     num_rounds=5,
#     strategy=strategy,
# )
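
For intuition about what Federated Averaging computes on the server, here is a minimal NumPy sketch of the aggregation step. It is an illustration rather than Flower's actual implementation: the function name `weighted_average` and the two toy clients are assumptions, and each client's update is weighted by the count it returns from `fit`.

```python
from typing import List, Tuple

import numpy as np


def weighted_average(results: List[Tuple[List[np.ndarray], int]]) -> List[np.ndarray]:
    """Average per-layer weights, weighting each client by its example count."""
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    return [
        sum(weights[layer] * num_examples for weights, num_examples in results)
        / total_examples
        for layer in range(num_layers)
    ]


# Two hypothetical clients, each with a single-layer "model"
client_a = ([np.array([1.0, 1.0])], 100)
client_b = ([np.array([4.0, 4.0])], 300)
aggregated = weighted_average([client_a, client_b])
print(aggregated)
```

Because `client_b` contributes three times as many examples, the aggregated weights land closer to its values than to those of `client_a`.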

## Part 2: Server-side parameter **initialization**

Flower, by default, initializes the global model by asking one random client for the initial parameters. In many cases, we want more control over parameter initialization though. Flower therefore allows you to directly pass the initial parameters to the Strategy:

In [8]:
# Pass initial parameters to the Strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_eval=0.3,
    min_fit_clients=3,
    min_eval_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.weights_to_parameters(
        get_parameters(Net())
    ),
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    num_rounds=3,  # Just three rounds
    strategy=strategy,
)

INFO flower 2021-10-21 21:45:08,267 | app.py:95 | Ray initialized with resources: {'accelerator_type:K80': 1.0, 'memory': 7841171867.0, 'object_store_memory': 3920585932.0, 'CPU': 2.0, 'GPU': 1.0, 'node:172.28.0.2': 1.0}
INFO flower 2021-10-21 21:45:08,277 | app.py:104 | Starting Flower simulation running: {'num_rounds': 3}
INFO flower 2021-10-21 21:45:08,279 | server.py:118 | Initializing global parameters
INFO flower 2021-10-21 21:45:08,284 | server.py:300 | Using initial parameters provided by strategy
INFO flower 2021-10-21 21:45:08,287 | server.py:120 | Evaluating initial parameters
INFO flower 2021-10-21 21:45:08,290 | server.py:133 | FL starting
DEBUG flower 2021-10-21 21:45:08,293 | server.py:255 | fit_round: strategy sampled 3 clients (out of 10)


KeyboardInterrupt: ignored

Passing `initial_parameters` to the `FedAvg` strategy prevents Flower from asking one of the clients for the initial parameters. If we look closely, we can see that the logs do not show any calls to the `FlowerClient.get_parameters` method.
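
The log line `fit_round: strategy sampled 3 clients (out of 10)` is consistent with the strategy settings above. As a rough sketch, assuming the strategy takes the configured fraction of available clients and then enforces the configured minimum, the number of sampled clients can be reproduced as follows. The helper `num_sampled_clients` is hypothetical and not part of Flower's API.

```python
def num_sampled_clients(num_available: int, fraction: float, min_clients: int) -> int:
    """Clients sampled per round: a fraction of the available clients,
    but never fewer than the configured minimum (assumed rule)."""
    return max(int(num_available * fraction), min_clients)


# Settings from the strategy above: fraction_fit=0.3, min_fit_clients=3
print(num_sampled_clients(10, fraction=0.3, min_clients=3))
# Settings that select every client in every round
print(num_sampled_clients(10, fraction=1.0, min_clients=10))
```

Sampling only a subset of clients per round keeps each round cheap, at the price of noisier updates.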

## Part 3: Server-side parameter **evaluation** ("Centralized Evaluation")

Flower can evaluate the aggregated model on the server-side or the client-side. Both approaches measure the quality of the aggregated model, but they differ in where the evaluation data lives and in how stable the results are across rounds.

**Centralized Evaluation** is conceptually simple: it works the same way that evaluation in centralized machine learning does. If there is a server-side dataset that can be used for evaluation purposes, then that's great. We can evaluate the newly aggregated model after each round of training without having to send the model to clients. We're also fortunate in the sense that our entire evaluation dataset is available at all times.

**Federated Evaluation** is more complex, but also more powerful: it doesn't require a centralized dataset and allows us to evaluate models over a larger set of data, which often yields more realistic evaluation results. In fact, many scenarios require us to use **Federated Evaluation** if we want to get representative evaluation results at all. But this power comes at a cost: once we start to evaluate on the client side, we must be aware that our evaluation dataset often changes over consecutive rounds of learning. The evaluation results are not stable, so even if we do not change the model, we'd see our evaluation results fluctuate over time because clients are not always connected and because the dataset on each client can change.
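
The fluctuation described above can be illustrated with a small sketch. It assumes client losses are combined into one number by weighting each loss with the number of examples the client evaluated on; the function name `weighted_loss_avg` and all the numbers are made up for illustration.

```python
from typing import List, Tuple


def weighted_loss_avg(results: List[Tuple[int, float]]) -> float:
    """Aggregate client losses, weighting each by its number of examples."""
    total_examples = sum(num_examples for num_examples, _ in results)
    return sum(num_examples * loss for num_examples, loss in results) / total_examples


# Hypothetical (num_examples, loss) pairs reported by clients. The model is
# identical in both rounds; only the participating clients differ.
round_1 = [(500, 0.8), (500, 1.2)]
round_2 = [(500, 0.8), (1500, 1.6)]
print(weighted_loss_avg(round_1), weighted_loss_avg(round_2))
```

Even though the model did not change, the aggregated loss differs between the two rounds because a different, larger client took part in the second one.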

We've seen how federated evaluation works on the client side (i.e., by implementing the `evaluate` method in `FlowerClient`). Now let's see how we can evaluate on the server-side:

In [1]:
# Create a function that returns an evaluation function
def get_evaluation_fn():
    # Load data and model here to avoid the overhead of doing it in `evaluate` itself
    net = Net().to(DEVICE)  # `test` moves batches to DEVICE, so the model must live there too
    valloader = valloaders[0]

    # The `evaluate` function will be called after every round
    def evaluate(
        weights: fl.common.Weights,
    ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        set_parameters(net, weights)  # Update model with the latest parameters
        loss, accuracy = test(net, valloader)
        print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
        return loss, {"accuracy": accuracy}

    return evaluate

In [None]:
# Pass the evaluation function to the Strategy
strategy=fl.server.strategy.FedAdam(
    fraction_fit=1.0,
    fraction_eval=1.0,
    min_fit_clients=10,
    min_eval_clients=10,
    min_available_clients=10,
    initial_parameters=fl.common.weights_to_parameters(
        get_parameters(Net())
    ),
    eval_fn=get_evaluation_fn(),
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    num_rounds=3,
    strategy=strategy,
)

INFO flower 2021-10-21 21:46:49,811 | app.py:95 | Ray initialized with resources: {'GPU': 1.0, 'node:172.28.0.2': 1.0, 'memory': 7841206272.0, 'accelerator_type:K80': 1.0, 'object_store_memory': 3920603136.0, 'CPU': 2.0}
INFO flower 2021-10-21 21:46:49,815 | app.py:104 | Starting Flower simulation running: {'num_rounds': 3}
INFO flower 2021-10-21 21:46:49,824 | server.py:118 | Initializing global parameters
INFO flower 2021-10-21 21:46:49,829 | server.py:300 | Using initial parameters provided by strategy
INFO flower 2021-10-21 21:46:49,834 | server.py:120 | Evaluating initial parameters


## Part 4: Sending arbitrary values from server to client

TODO

## Part 5: Implementing custom strategies

TODO

## Recap

In this notebook, we've seen how we can gradually enhance our system by customizing the strategy, choosing a different strategy, initializing parameters on the server side, and evaluating models on the server side.

Quite a bit of power for so little code!