# Use a federated learning strategy
Welcome Back! to the Flower federated learning tutorial!

In this notebook, we’ll begin to customize the federated learning system we built in the introductory notebook again, using the Flower framework, Flower Datasets, and PyTorch.

## Step 0: Preparation

### Loading dependencies

In [1]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets.utils.logging import disable_progress_bar

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()


Training on cuda
Flower 1.20.0 / PyTorch 2.7.1+cu126


### Loading Data

In [2]:
NUM_PARTITIONS = 10
BATCH_SIZE = 32


def load_datasets(partition_id: int, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader

### Model training/evaluation

In [3]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

### Define the Flower ClientApp

The first step toward creating a ClientApp is to implement a subclasses of flwr.client.Client or flwr.client.NumPyClient. We use NumPyClient in this tutorial because it is easier to implement and requires us to write less boilerplate. To implement NumPyClient, we create a subclass that implements the three methods get_parameters, fit, and evaluate:

get_parameters: Return the current local model parameters

fit: Receive model parameters from the server, train the model on the local data, and return the updated model parameters to the server

evaluate: Receive model parameters from the server, evaluate the model on the local data, and return the evaluation result to the server

In [4]:
class FlowerClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)

    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

## Strategy customization

### Server-side parameter initialization

In [5]:
# Create an instance of the model and get the parameters
params = get_parameters(Net())

In [6]:
# Create a list of ClientConfig objects, one for each client
client_configs = [{"partition_id": i} for i in range(NUM_PARTITIONS)]

# Specify the resources each of your clients need
# By default, each client will be allocated 1x CPU and 0x GPUs
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0},
                      "client_configs": client_configs  # Pass the client configurations
                     }

# When running on GPU, assign an entire GPU for each client
if DEVICE == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0},
                      "client_configs": client_configs  # Pass the client configurations
                     }
    # Refer to our Flower framework documentation for more details about Flower simulations
    # and how to set up the `backend_config`

In [7]:
def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(
            params
        ),  # Pass initial model parameters
    )

    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)

# Create ServerApp
server = ServerApp(server_fn=server_fn)

In [8]:


# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      


[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=74042)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.06471679359674454, accuracy 0.21225
[36m(ClientAppActor pid=74042)[0m [Client 3] fit, config: {}
[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.0645328238606453, accuracy 0.22325
[36m(ClientAppActor pid=74042)[0m [Client 5] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.06615382432937622, accuracy 0.201
[36m(ClientAppActor pid=74042)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=74042)[0m [Client 6] evaluate, config: {}
[36m(ClientAppActor pid=74042)[0m [Client 7] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=74042)[0m [Client 5] fit, config: {}
[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.057944368571043015, accuracy 0.304
[36m(ClientAppActor pid=74042)[0m [Client 6] fit, config: {}
[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.057033099234104156, accuracy 0.32
[36m(ClientAppActor pid=74042)[0m [Client 8] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.05787914618849754, accuracy 0.32225
[36m(ClientAppActor pid=74042)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=74042)[0m [Client 8] evaluate, config: {}
[36m(ClientAppActor pid=74042)[0m [Client 9] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=74042)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.05401994660496712, accuracy 0.37175
[36m(ClientAppActor pid=74042)[0m [Client 6] fit, config: {}
[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.05263754725456238, accuracy 0.37925
[36m(ClientAppActor pid=74042)[0m [Client 8] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=74042)[0m Epoch 1: train loss 0.053533826023340225, accuracy 0.36525
[36m(ClientAppActor pid=74042)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=74042)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=74042)[0m [Client 9] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 88.20s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.06338811039924622
[92mINFO [0m:      		round 2: 0.05562281262874603
[92mINFO [0m:      		round 3: 0.05197260351975758
[92mINFO [0m:      


### Starting with a customized strategy

In [9]:
def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAdagrad strategy
    strategy = FedAdagrad(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
    )
    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=143168)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.0652865543961525, accuracy 0.221
[36m(ClientAppActor pid=143168)[0m [Client 7] fit, config: {}
[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.06579652428627014, accuracy 0.21375
[36m(ClientAppActor pid=143168)[0m [Client 8] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.06515423953533173, accuracy 0.22225
[36m(ClientAppActor pid=143168)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=143168)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=143168)[0m [Client 7] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=143168)[0m [Client 3] fit, config: {}
[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.5984933972358704, accuracy 0.27975
[36m(ClientAppActor pid=143168)[0m [Client 4] fit, config: {}
[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.6368985176086426, accuracy 0.26625
[36m(ClientAppActor pid=143168)[0m [Client 7] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.6322721242904663, accuracy 0.292
[36m(ClientAppActor pid=143168)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=143168)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=143168)[0m [Client 8] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=143168)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.08759572356939316, accuracy 0.171
[36m(ClientAppActor pid=143168)[0m [Client 6] fit, config: {}
[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.08578860014677048, accuracy 0.178
[36m(ClientAppActor pid=143168)[0m [Client 8] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=143168)[0m Epoch 1: train loss 0.08662308752536774, accuracy 0.1785
[36m(ClientAppActor pid=143168)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=143168)[0m [Client 6] evaluate, config: {}
[36m(ClientAppActor pid=143168)[0m [Client 9] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 92.65s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 4.898074420928955
[92mINFO [0m:      		round 2: 0.3416735932032267
[92mINFO [0m:      		round 3: 0.07738929653167725
[92mINFO [0m:      


### Server-side parameter evaluation

In [10]:
# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: NDArrays,
    config: Dict[str, Scalar],
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    net = Net().to(DEVICE)
    _, _, testloader = load_datasets(0, NUM_PARTITIONS)
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}

In [11]:
def server_fn(context: Context) -> ServerAppComponents:
    # Create the FedAvg strategy
    strategy = FedAvg(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        #initial_parameters=ndarrays_to_parameters(params),
        #evaluate_fn=evaluate,  # Pass the evaluation function
    )
    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

In [12]:
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=144503)[0m [Client 9] get_parameters
[36m(ClientAppActor pid=144503)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.06609909236431122, accuracy 0.20775
[36m(ClientAppActor pid=144503)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.06607887893915176, accuracy 0.204
[36m(ClientAppActor pid=144503)[0m [Client 3] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.06502433121204376, accuracy 0.23025
[36m(ClientAppActor pid=144503)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=144503)[0m [Client 8] evaluate, config: {}
[36m(ClientAppActor pid=144503)[0m [Client 9] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=144503)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.059500642120838165, accuracy 0.296
[36m(ClientAppActor pid=144503)[0m [Client 5] fit, config: {}
[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.059205763041973114, accuracy 0.28975
[36m(ClientAppActor pid=144503)[0m [Client 9] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.05798201262950897, accuracy 0.328
[36m(ClientAppActor pid=144503)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=144503)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=144503)[0m [Client 8] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=144503)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.05443347245454788, accuracy 0.3645
[36m(ClientAppActor pid=144503)[0m [Client 6] fit, config: {}
[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.053423840552568436, accuracy 0.37125
[36m(ClientAppActor pid=144503)[0m [Client 8] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=144503)[0m Epoch 1: train loss 0.05419186130166054, accuracy 0.37
[36m(ClientAppActor pid=144503)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=144503)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=144503)[0m [Client 7] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 86.83s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.06311487909158071
[92mINFO [0m:      		round 2: 0.056315647681554164
[92mINFO [0m:      		round 3: 0.053133439938227334
[92mINFO [0m:      


### Sending/receiving arbitrary values to/from clients

In [13]:
class FlowerClient(NumPyClient):
    def __init__(self, pid, net, trainloader, valloader):
        self.pid = pid  # partition ID of a client
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.pid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        # Read values from config
        server_round = config["server_round"]
        local_epochs = config["local_epochs"]

        # Use values provided by the config
        print(f"[Client {self.pid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.pid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [14]:
def fit_config(server_round: int):
    """Return training configuration dict for each round.

    Perform two rounds of training with one local epoch, increase to two local
    epochs afterwards.
    """
    config = {
        "server_round": server_round,  # The current round of federated learning
        "local_epochs": 1 if server_round < 2 else 2,
    }
    return config

In [15]:
def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
        evaluate_fn=evaluate,
        on_fit_config_fn=fit_config,  # Pass the fit_config function
    )
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)
"""# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)
"""

'# Run simulation\nrun_simulation(\n    server_app=server,\n    client_app=client,\n    num_supernodes=NUM_PARTITIONS,\n    backend_config=backend_config,\n)\n'

### Scaling federated learning

In [16]:
NUM_PARTITIONS = 1000

In [17]:
def fit_config(server_round: int):
    config = {
        "server_round": server_round,
        "local_epochs": 3,
    }
    return config


def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=0.025,  # Train on 25 clients (each round)
        fraction_evaluate=0.05,  # Evaluate on 50 clients (each round)
        min_fit_clients=20,
        min_evaluate_clients=40,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
        on_fit_config_fn=fit_config,
    )
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 20 clients (out of 1000)


[36m(ClientAppActor pid=145791)[0m [Client 6, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11524901539087296, accuracy 0.125
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.11437729746103287, accuracy 0.125
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.1138882040977478, accuracy 0.125
[36m(ClientAppActor pid=145791)[0m [Client 21, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11517267674207687, accuracy 0.2
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.1130889430642128, accuracy 0.2
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.11164098232984543, accuracy 0.2
[36m(ClientAppActor pid=145791)[0m [Client 68, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11540628969669342, accuracy 0.1
[36m(ClientAppActor pid=14579

[92mINFO [0m:      aggregate_fit: received 20 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=145791)[0m [Client 994, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11426031589508057, accuracy 0.1
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.11396493762731552, accuracy 0.1
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.11276853084564209, accuracy 0.2
[36m(ClientAppActor pid=145791)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 15] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 36] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 274] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 282] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 319] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 328] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 410] eval

[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 25 clients (out of 1000)


[36m(ClientAppActor pid=145791)[0m [Client 212] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 217, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11544045060873032, accuracy 0.1
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.11386077851057053, accuracy 0.1
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.11296092718839645, accuracy 0.125
[36m(ClientAppActor pid=145791)[0m [Client 220, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11475949734449387, accuracy 0.075
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.11365550011396408, accuracy 0.075
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.11274579912424088, accuracy 0.1
[36m(ClientAppActor pid=145791)[0m [Client 240, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1

[92mINFO [0m:      aggregate_fit: received 25 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=145791)[0m [Client 765, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11441119015216827, accuracy 0.15
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.11455817520618439, accuracy 0.15
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.11151323467493057, accuracy 0.15
[36m(ClientAppActor pid=145791)[0m [Client 176] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 288] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 332] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 351] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 418] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 490] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 502] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 588] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 

[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 25 clients (out of 1000)


[36m(ClientAppActor pid=145791)[0m [Client 225] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 75, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11550428718328476, accuracy 0.1
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.11364906281232834, accuracy 0.125
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.11327391862869263, accuracy 0.225
[36m(ClientAppActor pid=145791)[0m [Client 243, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11311212927103043, accuracy 0.125
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.11106453090906143, accuracy 0.125
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.1100773960351944, accuracy 0.125
[36m(ClientAppActor pid=145791)[0m [Client 256, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch

[92mINFO [0m:      aggregate_fit: received 25 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=145791)[0m [Client 475, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=145791)[0m Epoch 1: train loss 0.11326520889997482, accuracy 0.2
[36m(ClientAppActor pid=145791)[0m Epoch 2: train loss 0.11073123663663864, accuracy 0.2
[36m(ClientAppActor pid=145791)[0m Epoch 3: train loss 0.10847633332014084, accuracy 0.2
[36m(ClientAppActor pid=145791)[0m [Client 58] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 116] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 171] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 178] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 192] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 194] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 199] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 254] evaluate, config: {}
[36m(ClientAppActor pid=145791)[0m [Client 444]

[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 817.18s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.23007594823837277
[92mINFO [0m:      		round 2: 0.23008259344100943
[92mINFO [0m:      		round 3: 0.22783266162872312
[92mINFO [0m:      


[36m(ClientAppActor pid=145791)[0m [Client 145] evaluate, config: {}
