In [2]:
!pip install -q "flwr[simulation]" "flwr-datasets[vision]" torch torchvision scipy


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
from collections import OrderedDict
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

  from .autonotebook import tqdm as notebook_tqdm
2025-09-04 12:58:31,563	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Training on cpu
Flower 1.20.0 / PyTorch 2.2.2


In [4]:
def load_datasets(partition_id: int, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
    valloader = DataLoader(partition_train_test["test"], batch_size=32)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=32)
    return trainloader, valloader, testloader

In [5]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [6]:
class FlowerNumPyClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

In [7]:
def numpyclient_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerNumPyClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
numpyclient = ClientApp(client_fn=numpyclient_fn)

In [8]:
def server_fn(context: Context) -> ServerAppComponents:
    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(config=config)


# Create ServerApp
server = ServerApp(server_fn=server_fn)

In [9]:
# Specify the resources each of your clients need
# If set to none, by default, each client will be allocated 2x CPU and 0x GPUs
backend_config = {"client_resources": None}
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_gpus": 1}}

NUM_PARTITIONS = 10

# Run simulation
run_simulation(
    server_app=server,
    client_app=numpyclient,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=20874)[0m [Client 8] get_parameters
[36m(ClientAppActor pid=20874)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=20874)[0m Epoch 1: train loss 0.06559710204601288, accuracy 0.214
[36m(ClientAppActor pid=20874)[0m [Client 8] fit, config: {}
[36m(ClientAppActor pid=20871)[0m [Client 3] fit, config: {}
[36m(ClientAppActor pid=20867)[0m [Client 7] fit, config: {}




[36m(ClientAppActor pid=20874)[0m Epoch 1: train loss 0.06613316386938095, accuracy 0.21225
[36m(ClientAppActor pid=20869)[0m [Client 5] fit, config: {}[32m [repeated 5x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m




[36m(ClientAppActor pid=20874)[0m [Client 9] fit, config: {}




[36m(ClientAppActor pid=20867)[0m Epoch 1: train loss 0.06529438495635986, accuracy 0.2185


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=20874)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=20874)[0m Epoch 1: train loss 0.06570418179035187, accuracy 0.235[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=20867)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=20867)[0m [Client 9] evaluate, config: {}[32m [repeated 9x across cluster][0m
[36m(ClientAppActor pid=20867)[0m Epoch 1: train loss 0.05954650789499283, accuracy 0.29925
[36m(ClientAppActor pid=20871)[0m [Client 3] fit, config: {}[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=20874)[0m Epoch 1: train loss 0.05962961167097092, accuracy 0.29525


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=20874)[0m [Client 9] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=20867)[0m Epoch 1: train loss 0.05954241007566452, accuracy 0.30375[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=20873)[0m [Client 7] evaluate, config: {}




[36m(ClientAppActor pid=20874)[0m Epoch 1: train loss 0.05907357111573219, accuracy 0.31725


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=20872)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=20873)[0m [Client 8] evaluate, config: {}[32m [repeated 9x across cluster][0m




[36m(ClientAppActor pid=20872)[0m Epoch 1: train loss 0.05596117302775383, accuracy 0.3385
[36m(ClientAppActor pid=20871)[0m [Client 4] fit, config: {}[32m [repeated 7x across cluster][0m


[36m(ClientAppActor pid=20872)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md
[36m(ClientAppActor pid=20872)[0m Retrying in 1s [Retry 1/5].
[36m(ClientAppActor pid=20872)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=20872)[0m Retrying in 8s [Retry 4/5].[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=20872)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=20872)[0m Retrying in 8s [Retry 5/5].[32m [repeated 2x across cluster][0m


[36m(ClientAppActor pid=20867)[0m Epoch 1: train loss 0.054733727127313614, accuracy 0.35125[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=20872)[0m [Client 8] fit, config: {}
[36m(ClientAppActor pid=20873)[0m [Client 9] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=20872)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=20872)[0m Epoch 1: train loss 0.05522831901907921, accuracy 0.35925[32m [repeated 2x across cluster][0m


[36m(ClientAppActor pid=20873)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md
[36m(ClientAppActor pid=20873)[0m Retrying in 8s [Retry 5/5].
[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 166.71s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.06395405081510544
[92mINFO [0m:      		round 2: 0.05688847650289536
[92mINFO [0m:      		round 3: 0.053622226190567016
[92mINFO [0m:      


[36m(ClientAppActor pid=20873)[0m [Client 9] evaluate, config: {}[32m [repeated 9x across cluster][0m


In [10]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)


class FlowerClient(Client):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f"[Client {self.partition_id}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters = ndarrays_to_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        print(f"[Client {self.partition_id}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object
        parameters_updated = ndarrays_to_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(self.trainloader),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        print(f"[Client {self.partition_id}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's
        parameters_original = ins.parameters
        ndarrays_original = parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.valloader),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [11]:
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=28680)[0m [Client 5] get_parameters
[36m(ClientAppActor pid=28680)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=28680)[0m Epoch 1: train loss 0.06568188220262527, accuracy 0.222
[36m(ClientAppActor pid=28680)[0m [Client 8] fit, config: {}
[36m(ClientAppActor pid=28680)[0m Epoch 1: train loss 0.06524818390607834, accuracy 0.22875
[36m(ClientAppActor pid=28673)[0m [Client 7] fit, config: {}




[36m(ClientAppActor pid=28673)[0m Epoch 1: train loss 0.06572027504444122, accuracy 0.19975
[36m(ClientAppActor pid=28675)[0m [Client 5] fit, config: {}[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=28680)[0m [Client 3] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=28675)[0m Epoch 1: train loss 0.06601592898368835, accuracy 0.193[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=28679)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=28677)[0m [Client 7] fit, config: {}
[36m(ClientAppActor pid=28679)[0m [Client 9] evaluate, config: {}[32m [repeated 9x across cluster][0m
[36m(ClientAppActor pid=28677)[0m Epoch 1: train loss 0.059571534395217896, accuracy 0.30075
[36m(ClientAppActor pid=28675)[0m [Client 2] fit, config: {}[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=28679)[0m Epoch 1: train loss 0.05950388312339783, accuracy 0.295


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=28674)[0m [Client 9] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=28677)[0m Epoch 1: train loss 0.05947599560022354, accuracy 0.307[32m [repeated 7x across cluster][0m
[36m(ClientAppActor pid=28674)[0m [Client 0] evaluate, config: {}


[36m(ClientAppActor pid=28674)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md
[36m(ClientAppActor pid=28674)[0m Retrying in 1s [Retry 1/5].
[36m(ClientAppActor pid=28674)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=28674)[0m Retrying in 8s [Retry 4/5].[32m [repeated 6x across cluster][0m
[36m(ClientAppActor pid=28677)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/main/README.md[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=28677)[0m Retrying in 8s [Retry 5/5].[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=28674)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/cifar10.py[32m [repeated 3x across cluster][0m
[36m(Cli

[36m(ClientAppActor pid=28674)[0m Epoch 1: train loss 0.05889628455042839, accuracy 0.314
[36m(ClientAppActor pid=28674)[0m [Client 8] evaluate, config: {}[32m [repeated 8x across cluster][0m


[36m(ClientAppActor pid=28674)[0m HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/cifar10/resolve/0b2714987fa478483af9968de7c934580d0bb9a2/.huggingface.yaml[32m [repeated 3x across cluster][0m
[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=28677)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=28677)[0m [Client 9] evaluate, config: {}


[36m(ClientAppActor pid=28674)[0m Retrying in 1s [Retry 1/5].


[36m(ClientAppActor pid=28677)[0m Epoch 1: train loss 0.055597271770238876, accuracy 0.34875
[36m(ClientAppActor pid=28675)[0m [Client 3] fit, config: {}[32m [repeated 7x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)


[36m(ClientAppActor pid=28677)[0m Epoch 1: train loss 0.05474460870027542, accuracy 0.36575[32m [repeated 8x across cluster][0m
[36m(ClientAppActor pid=28676)[0m [Client 9] fit, config: {}[32m [repeated 2x across cluster][0m
[36m(ClientAppActor pid=28676)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=28676)[0m Epoch 1: train loss 0.05476615950465202, accuracy 0.36275


[92mINFO [0m:      aggregate_evaluate: received 10 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 144.83s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.06479371778964997
[92mINFO [0m:      		round 2: 0.056666840946674346
[92mINFO [0m:      		round 3: 0.05364905495643616
[92mINFO [0m:      


[36m(ClientAppActor pid=28677)[0m [Client 9] evaluate, config: {}[32m [repeated 9x across cluster][0m


In [16]:
from io import BytesIO
from typing import cast

import numpy as np

from flwr.common.typing import NDArray, NDArrays, Parameters


def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:
    """Convert NumPy ndarrays to parameters object."""
    tensors = [ndarray_to_sparse_bytes(ndarray) for ndarray in ndarrays]
    return Parameters(tensors=tensors, tensor_type="numpy.ndarray")


def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Convert parameters object to NumPy ndarrays."""
    return [sparse_bytes_to_ndarray(tensor) for tensor in parameters.tensors]


def ndarray_to_sparse_bytes(ndarray: NDArray) -> bytes:
    """Serialize NumPy ndarray to bytes."""
    bytes_io = BytesIO()

    if len(ndarray.shape) > 1:
        # We convert our ndarray into a sparse matrix
        ndarray = torch.tensor(ndarray).to_sparse_csr()

        # And send it byutilizing the sparse matrix attributes
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.savez(
            bytes_io,  # type: ignore
            crow_indices=ndarray.crow_indices(),
            col_indices=ndarray.col_indices(),
            values=ndarray.values(),
            allow_pickle=False,
        )
    else:
        # WARNING: NEVER set allow_pickle to true.
        # Reason: loading pickled data can execute arbitrary code
        # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html
        np.save(bytes_io, ndarray, allow_pickle=False)
    return bytes_io.getvalue()


def sparse_bytes_to_ndarray(tensor: bytes) -> NDArray:
    """Deserialize NumPy ndarray from bytes."""
    bytes_io = BytesIO(tensor)
    # WARNING: NEVER set allow_pickle to true.
    # Reason: loading pickled data can execute arbitrary code
    # Source: https://numpy.org/doc/stable/reference/generated/numpy.load.html
    loader = np.load(bytes_io, allow_pickle=False)  # type: ignore

    if "crow_indices" in loader:
        # We convert our sparse matrix back to a ndarray, using the attributes we sent
        ndarray_deserialized = (
            torch.sparse_csr_tensor(
                crow_indices=loader["crow_indices"],
                col_indices=loader["col_indices"],
                values=loader["values"],
            )
            .to_dense()
            .numpy()
        )
    else:
        ndarray_deserialized = loader
    return cast(NDArray, ndarray_deserialized)

In [13]:
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
)


class FlowerClient(Client):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f"[Client {self.partition_id}] get_parameters")

        # Get parameters as a list of NumPy ndarray's
        ndarrays: List[np.ndarray] = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object using our custom function
        parameters = ndarrays_to_sparse_parameters(ndarrays)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return GetParametersRes(
            status=status,
            parameters=parameters,
        )

    def fit(self, ins: FitIns) -> FitRes:
        print(f"[Client {self.partition_id}] fit, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's using our custom function
        parameters_original = ins.parameters
        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)

        # Update local model, train, get updated parameters
        set_parameters(self.net, ndarrays_original)
        train(self.net, self.trainloader, epochs=1)
        ndarrays_updated = get_parameters(self.net)

        # Serialize ndarray's into a Parameters object using our custom function
        parameters_updated = ndarrays_to_sparse_parameters(ndarrays_updated)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return FitRes(
            status=status,
            parameters=parameters_updated,
            num_examples=len(self.trainloader),
            metrics={},
        )

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        print(f"[Client {self.partition_id}] evaluate, config: {ins.config}")

        # Deserialize parameters to NumPy ndarray's using our custom function
        parameters_original = ins.parameters
        ndarrays_original = sparse_parameters_to_ndarrays(parameters_original)

        set_parameters(self.net, ndarrays_original)
        loss, accuracy = test(self.net, self.valloader)

        # Build and return response
        status = Status(code=Code.OK, message="Success")
        return EvaluateRes(
            status=status,
            loss=float(loss),
            num_examples=len(self.valloader),
            metrics={"accuracy": float(accuracy)},
        )


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()

In [None]:

parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)


ModuleNotFoundError: No module named 'symbol'