In [1]:
# Build a strategy from scratch
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision

zsh:1: no matches found: flwr[simulation]


In [2]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

  from .autonotebook import tqdm as notebook_tqdm
2025-09-04 12:29:53,094	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Training on cpu
Flower 1.20.0 / PyTorch 2.2.2


In [3]:
def load_datasets(partition_id, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
    valloader = DataLoader(partition_train_test["test"], batch_size=32)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=32)
    return trainloader, valloader, testloader

In [4]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [None]:
# This code defines the client-side logic for federated learning using the Flower framework.
# 
# 1. The `FlowerClient` class inherits from `NumPyClient` and implements the required methods for a federated client:
#    - `__init__`: Initializes the client with its partition ID, neural network, and data loaders for training and validation.
#    - `get_parameters`: Returns the current model parameters. This is called by the server to get the client's model weights.
#    - `fit`: Receives global model parameters from the server, updates the local model, trains it for one epoch on the client's local data, and returns the updated parameters and the number of training examples.
#    - `evaluate`: Receives global model parameters, updates the local model, evaluates it on the client's validation data, and returns the loss, number of validation examples, and accuracy.
#
# 2. The `client_fn` function is a factory that creates a new `FlowerClient` instance for each client process. It:
#    - Instantiates a new model and moves it to the appropriate device (CPU or GPU).
#    - Retrieves the partition ID and number of partitions from the context.
#    - Loads the training and validation data for this client partition.
#    - Returns the client as a Flower-compatible client object.
#
# 3. Finally, a `ClientApp` is created using the `client_fn`, which is the entry point for running the client in a federated learning experiment.

class FlowerClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [6]:
NUM_PARTITIONS = 10


def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for just 3 rounds of training
    config = ServerConfig(num_rounds=3)
    # If no strategy is provided, by default, ServerAppComponents will use FedAvg
    return ServerAppComponents(config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Specify the resources each of your clients need
# If set to none, by default, each client will be allocated 2x CPU and 0x GPUs
backend_config = {"client_resources": None}
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_gpus": 1}}

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=4048)[0m [Client 8] get_parameters
[36m(ClientAppActor pid=4048)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=4048)[0m Epoch 1: train loss 0.06604356318712234, accuracy 0.212
[36m(ClientAppActor pid=4048)[0m [Client 8] fit, config: {}
[36m(ClientAppActor pid=4043)[0m [Client 5] fit, config: {}
[36m(ClientAppActor pid=4048)[0m Epoch 1: train loss 0.06568748503923416, accuracy 0.218
[36m(ClientAppActor pid=4043)[0m Epoch 1: train loss 0.06581880897283554, accuracy 0.20425
[36m(ClientAppActor pid=4048)[0m [Client 9] fit, config: {}[32m [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=4048)[0m [Client 0] evaluate, config: {}




[36m(ClientAppActor pid=4048)[0m Epoch 1: train loss 0.06530498713254929, accuracy 0.224[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=4046)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=4043)[0m [Client 3] fit, config: {}
[36m(ClientAppActor pid=4041)[0m [Client 6] evaluate, config: {}[32m [repeated 9x across cluster][0m
[36m(ClientAppActor pid=4043)[0m Epoch 1: train loss 0.05664949491620064, accuracy 0.3255
[36m(ClientAppActor pid=4045)[0m [Client 6] fit, config: {}[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=4041)[0m Epoch 1: train loss 0.05793365463614464, accuracy 0.3225


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=4044)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=4044)[0m [Client 8] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=4044)[0m Epoch 1: train loss 0.05727699398994446, accuracy 0.33125[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=4041)[0m [Client 1] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)
[36m(ClientAppActor pid=4045)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/uoft-cs/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/.huggingface.yaml
[36m(ClientAppActor pid=4045)[0m Retrying in 1s [Retry 1/5].
[36m(ClientAppActor pid=4045)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/uoft-cs/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/.huggingface.yaml[32m [repeated 24x across cluster][0m
[36m(ClientAppActor pid=4045)[0m Retrying in 8s [Retry 4/5].[32m [repeated 24x across cluster][0m
[36m(ClientAppActor pid=4047)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/dataset_infos.json[32m [repeated 8

[36m(ClientAppActor pid=4046)[0m [Client 6] fit, config: {}
[36m(ClientAppActor pid=4041)[0m [Client 9] evaluate, config: {}[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=4044)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=4046)[0m Epoch 1: train loss 0.05279345437884331, accuracy 0.37725


[36m(ClientAppActor pid=4046)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 13x across cluster][0m
[36m(ClientAppActor pid=4046)[0m Retrying in 2s [Retry 2/5].[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=4046)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=4046)[0m Retrying in 8s [Retry 4/5].[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=4046)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=4046)[0m Retrying in 8s [Retry 5/5].[32m [repeated 3x across cluster][0m


[36m(ClientAppActor pid=4045)[0m [Client 2] fit, config: {}[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=4041)[0m Epoch 1: train loss 0.05397529527544975, accuracy 0.355[32m [repeated 6x across cluster][0m


[36m(ClientAppActor pid=4046)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/cifar10.py[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=4046)[0m Retrying in 1s [Retry 1/5].[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=4046)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/cifar10.py[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=4046)[0m Retrying in 8s [Retry 4/5].[32m [repeated 6x across cluster][0m


[36m(ClientAppActor pid=4046)[0m [Client 8] fit, config: {}
[36m(ClientAppActor pid=4045)[0m Epoch 1: train loss 0.0542261116206646, accuracy 0.35625
[36m(ClientAppActor pid=4043)[0m [Client 9] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=4046)[0m Epoch 1: train loss 0.053807150572538376, accuracy 0.36625
[36m(ClientAppActor pid=4047)[0m [Client 4] evaluate, config: {}


[36m(ClientAppActor pid=4043)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/cifar10.py
[36m(ClientAppActor pid=4043)[0m Retrying in 8s [Retry 4/5].


[36m(ClientAppActor pid=4043)[0m Epoch 1: train loss 0.053080376237630844, accuracy 0.37225


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 115.68s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.06396813477277756
[92mINFO [0m:      		round 2: 0.054952736127376556
[92mINFO [0m:      		round 3: 0.052292625081539146
[92mINFO [0m:      


[36m(ClientAppActor pid=4047)[0m [Client 8] evaluate, config: {}[32m [repeated 9x across cluster][0m


In [7]:
from typing import Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg


class FedCustom(Strategy):
    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients

    def __repr__(self) -> str:
        return "FedCustom"

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        net = Net()
        ndarrays = get_parameters(net)
        return ndarrays_to_parameters(ndarrays)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Create custom configs
        n_clients = len(clients)
        half_clients = n_clients // 2
        standard_config = {"lr": 0.001}
        higher_lr_config = {"lr": 0.003}
        fit_configurations = []
        for idx, client in enumerate(clients):
            if idx < half_clients:
                fit_configurations.append((client, FitIns(parameters, standard_config)))
            else:
                fit_configurations.append(
                    (client, FitIns(parameters, higher_lr_config))
                )
        return fit_configurations

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""

        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
        metrics_aggregated = {}
        return parameters_aggregated, metrics_aggregated

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""

        if not results:
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate global model parameters using an evaluation function."""

        # Let's assume we won't perform the global model evaluation on the server side.
        return None

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

In [8]:
def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for just 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(
        config=config,
        strategy=FedCustom(),  # <-- pass the new strategy here
    )


# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=13241)[0m [Client 7] get_parameters
[36m(ClientAppActor pid=13241)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=13241)[0m Epoch 1: train loss 0.06429890543222427, accuracy 0.21975
[36m(ClientAppActor pid=13241)[0m [Client 8] fit, config: {}
[36m(ClientAppActor pid=13234)[0m [Client 7] fit, config: {}
[36m(ClientAppActor pid=13241)[0m Epoch 1: train loss 0.06471852213144302, accuracy 0.215
[36m(ClientAppActor pid=13234)[0m Epoch 1: train loss 0.06481718271970749, accuracy 0.213
[36m(ClientAppActor pid=13241)[0m [Client 9] fit, config: {}[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=13241)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=13241)[0m Epoch 1: train loss 0.0644933432340622, accuracy 0.2155[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=13237)[0m [Client 7] fit, config: {}
[36m(ClientAppActor pid=13238)[0m [Client 0] fit, config: {}




[36m(ClientAppActor pid=13238)[0m [Client 9] evaluate, config: {}[32m [repeated 9x across cluster][0m
[36m(ClientAppActor pid=13237)[0m Epoch 1: train loss 0.05834822356700897, accuracy 0.302
[36m(ClientAppActor pid=13236)[0m Epoch 1: train loss 0.05672045424580574, accuracy 0.323
[36m(ClientAppActor pid=13236)[0m [Client 8] fit, config: {}[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=13237)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=13237)[0m Epoch 1: train loss 0.0578608475625515, accuracy 0.31775[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=13236)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=13237)[0m [Client 9] fit, config: {}


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)
[36m(ClientAppActor pid=13237)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md
[36m(ClientAppActor pid=13237)[0m Retrying in 1s [Retry 1/5].
[36m(ClientAppActor pid=13237)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 24x across cluster][0m
[36m(ClientAppActor pid=13237)[0m Retrying in 8s [Retry 4/5].[32m [repeated 24x across cluster][0m
[36m(ClientAppActor pid=13237)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=13237)[0m Retrying in 8s [Retry 5/5].[32m [repeated 8x across cluster][0m
[36m(ClientAppActor

[36m(ClientAppActor pid=13240)[0m [Client 3] fit, config: {}
[36m(ClientAppActor pid=13236)[0m [Client 8] evaluate, config: {}[32m [repeated 8x across cluster][0m


[36m(ClientAppActor pid=13239)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/.huggingface.yaml[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=13239)[0m Retrying in 8s [Retry 4/5].[32m [repeated 7x across cluster][0m


[36m(ClientAppActor pid=13240)[0m Epoch 1: train loss 0.053540877997875214, accuracy 0.3645
[36m(ClientAppActor pid=13240)[0m [Client 8] fit, config: {}[32m [repeated 8x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=13240)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=13241)[0m Epoch 1: train loss 0.05401288717985153, accuracy 0.3695[32m [repeated 9x across cluster][0m




[36m(ClientAppActor pid=13241)[0m [Client 9] fit, config: {}


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 110.04s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.06156734501123427
[92mINFO [0m:      		round 2: 0.05566076983213425
[92mINFO [0m:      		round 3: 0.052821533501148224
[92mINFO [0m:      


[36m(ClientAppActor pid=13241)[0m [Client 9] evaluate, config: {}[32m [repeated 9x across cluster][0m


In [9]:
def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for just 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(
        config=config,
        strategy=FedCustom(),  # <-- pass the new strategy here
    )


# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=14902)[0m [Client 7] get_parameters
[36m(ClientAppActor pid=14902)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=14902)[0m Epoch 1: train loss 0.06465314328670502, accuracy 0.2225
[36m(ClientAppActor pid=14902)[0m [Client 8] fit, config: {}
[36m(ClientAppActor pid=14901)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=14902)[0m Epoch 1: train loss 0.06467147171497345, accuracy 0.22175
[36m(ClientAppActor pid=14901)[0m Epoch 1: train loss 0.06458563357591629, accuracy 0.23175


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=14902)[0m [Client 9] fit, config: {}[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=14902)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=14902)[0m Epoch 1: train loss 0.06452812254428864, accuracy 0.23975[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=14898)[0m [Client 5] fit, config: {}
[36m(ClientAppActor pid=14895)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=14899)[0m [Client 9] evaluate, config: {}[32m [repeated 9x across cluster][0m




[36m(ClientAppActor pid=14897)[0m Epoch 1: train loss 0.056990284472703934, accuracy 0.334
[36m(ClientAppActor pid=14899)[0m [Client 0] fit, config: {}[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=14898)[0m Epoch 1: train loss 0.05775522068142891, accuracy 0.31225


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=14897)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=14898)[0m [Client 9] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=14898)[0m Epoch 1: train loss 0.05718472972512245, accuracy 0.33375[32m [repeated 8x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=14898)[0m [Client 9] evaluate, config: {}[32m [repeated 9x across cluster][0m
[36m(ClientAppActor pid=14902)[0m [Client 2] fit, config: {}[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=14895)[0m Epoch 1: train loss 0.05424690991640091, accuracy 0.35325
[36m(ClientAppActor pid=14897)[0m Epoch 1: train loss 0.054067060351371765, accuracy 0.369


[36m(ClientAppActor pid=14897)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md
[36m(ClientAppActor pid=14897)[0m Retrying in 1s [Retry 1/5].
[36m(ClientAppActor pid=14897)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=14897)[0m Retrying in 8s [Retry 4/5].[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=14897)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=14897)[0m Retrying in 8s [Retry 5/5].[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=14897)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/cifar10.py[32m [repeated 3x across cluster][0m
[36m(Cli

[36m(ClientAppActor pid=14897)[0m [Client 8] fit, config: {}
[36m(ClientAppActor pid=14902)[0m Epoch 1: train loss 0.05395974963903427, accuracy 0.35525[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=14900)[0m [Client 9] fit, config: {}


[36m(ClientAppActor pid=14900)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/.huggingface.yaml[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=14900)[0m Retrying in 4s [Retry 3/5].[32m [repeated 5x across cluster][0m
[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=14902)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=14900)[0m Epoch 1: train loss 0.05394162982702255, accuracy 0.369[32m [repeated 2x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 104.92s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.062151119661331175
[92mINFO [0m:      		round 2: 0.05519558441638948
[92mINFO [0m:      		round 3: 0.05231935760974884
[92mINFO [0m:      


[36m(ClientAppActor pid=14902)[0m [Client 8] evaluate, config: {}[32m [repeated 9x across cluster][0m


