In [70]:
from os import cpu_count

from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

import flwr as fl

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')  # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)

Training on cuda using PyTorch 2.2.1+cu121 and Flower 1.7.0


## Customize The FLOWER Client Class

### Step 0: Preparation

In [71]:
# hyper-parameters

NUM_CLIENT = 10
EPOCHS_CLIENT = 1
BATCH_SIZE = 32
NUM_WORKER = cpu_count()


### Loading The Dataset And Partitioning ###

def load_datasets(num_clients: int):
    
    # image transformation
    img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
    
    # loading the torchvision dataset
    train_dataset = CIFAR10(
        root='./dataset',
        train=True,
        transform=img_transform,
        download=True
    )

    test_dataset = CIFAR10(
        root='./dataset',
        train=False,
        transform=img_transform,
        download=True
    )

    # split training dataset into partitions
    partition_size = len(train_dataset) // num_clients
    lengths = [partition_size] * num_clients
    train_part_dataset = random_split(
        dataset=train_dataset,
        lengths=lengths,
        generator=torch.Generator().manual_seed(42)
        )
    
    # create all client train and val dataloaders
    train_dataloaders = []
    val_dataloaders = []

    # split partition into train and val datasets and wrap into torch dataloaders
    for dataset in train_part_dataset:
        # split the partition
        split_dataset = random_split(
            dataset=dataset,
            lengths=[0.8, 0.2], # train & val dataset split fraction
            generator=torch.Generator().manual_seed(42)
        )
        # wrap with torch dataloader and add to dataloader list
        partition_train_dl = DataLoader(
            dataset=split_dataset[0],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKER
        )
        partition_val_dl = DataLoader(
            dataset=split_dataset[1],
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKER
        )
        train_dataloaders.append(partition_train_dl)
        val_dataloaders.append(partition_val_dl)
        
    # create test dataloader from the test split (Dataset) with transform function
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKER
    )

    # return all the train (partitioned), val (partitioned) & test dataloaders
    return train_dataloaders, val_dataloaders, test_dataloader



In [72]:
train_dataloaders, val_dataloaders, test_dataloader = load_datasets(NUM_CLIENT)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:29<00:00, 5827991.25it/s]


Extracting ./dataset/cifar-10-python.tar.gz to ./dataset
Files already downloaded and verified


In [73]:
### Defining The Model ###

class TinyVGG(nn.Module):
    """Creates the TinyVGG architecture for 32*32 Image Data"""
        
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features= 10 * 8 * 8, out_features=10)
        )

    def forward(self, x: torch.tensor) -> torch.tensor:
        return self.classifier(self.block2(self.block1(x)))

### Train and Test Function For FLOWER Clients
    
# calculate accuracy
def accuracy_fn(y_pred: torch.tensor, y_true: torch.tensor) -> float:
    """Calculates the accuracy of a model on given predictions

    Args:
        y_pred: predicted labels
        y_true: true labels
    
    Returns:
        A float value which is the calculated accuracy.
    """
    return ((torch.eq(y_pred, y_true).sum().item() / len(y_true)) * 100)

# fit the model on training data
def train(model: torch.nn.Module,
          data_loader: torch.utils.data.DataLoader,
          epochs: int,
          device: torch.device,
          verbose=False,
          loss_fn: torch.nn.Module = None,
          optimizer: torch.optim.Optimizer = None) -> Tuple[float, float]:
    """Trains a PyTorch model for the given epochs.

    Turns a target PyTorch model to training mode and then
    runs through all of the required training steps (forward
    pass, loss calculation, optimizer step).

    Args:
        model: A PyTorch model to be trained.
        dataloader: A DataLoader instance for the model to be trained on.
        epochs: Epochs.
        device: A target device to compute on (e.g. "cuda" or "cpu").
        verbose: A boolean value to see the model metrics (loss and accuracy)
        loss_fn: A PyTorch loss function to minimize.
        optimizer: A PyTorch optimizer to help minimize the loss function.
    
    Returns:
        A tuple of training loss and training accuracy metrics.
        In the form (train_loss, train_accuracy). For example:

        (0.0223, 0.8985)
    """

    # optimizer and criterion (loss_fn) if None given
    if loss_fn == None:
        loss_fn = torch.nn.CrossEntropyLoss()
    if optimizer == None:
        optimizer = torch.optim.SGD(model.parameters())

    model.train() # model in train mode
    total_epoch_loss, total_epoch_acc = 0, 0
    for epoch in range(epochs):
        
        train_loss, train_acc = 0, 0
        
        for X, y in data_loader:
            # get data to device
            X = X.to(device)
            y = y.to(device)

            # forward pass
            y_logit = model(X)
            loss = loss_fn(y_logit, y)        

            # backward pass
            optimizer.zero_grad() # empty param's grad
            loss.backward() # backward propagation
            optimizer.step() # updata params (take the gradient descent step)

            # Metrics
            # calculate loss and accuracy per batch
            train_loss += loss.item() * len(y)
            y_pred_labels = torch.argmax(y_logit, dim=1)
            train_acc += accuracy_fn(y_pred_labels, y)
        
        # per epoch
        train_loss /= len(data_loader.dataset)
        train_acc /= len(data_loader)
        
        if verbose:
            print(f"Epoch {epoch+1} | Train Loss {train_loss:.4f} | Train Acc {train_acc:.2f}")

        # for all epochs
        total_epoch_loss += train_loss
        total_epoch_acc += train_acc
    
    return (total_epoch_loss / epochs, total_epoch_acc / epochs)
    

# test the model on test data
def test(model: torch.nn.Module,
         data_loader: torch.utils.data.DataLoader,
         device: torch.device,
         loss_fn: torch.nn.Module=None) -> Tuple[float, float]:
    """Tests a PyTorch model for the given epochs.

    Turns a target PyTorch model to "eval" mode and then performs
    a forward pass on a testing dataset.

    Args:
        model: A PyTorch model to be tested.
        data_loader: A DataLoader instance for the model to be tested on.
        device: A target device to compute on (e.g. "cuda" or "cpu").
        loss_fn: A PyTorch loss function to calculate loss on the test data.

    Returns:
        A tuple of testing loss and testing accuracy metrics.
        In the form (test_loss, test_accuracy). For example:

        (0.0223, 0.8985)
    """

    test_loss, test_acc = 0, 0

    # criterion (loss_fn) if None given
    if loss_fn == None:
        loss_fn = torch.nn.CrossEntropyLoss()

    model.eval() # model in evaluation mode
    with torch.inference_mode():
        for X, y in data_loader:
            # get data to device
            X = X.to(device)
            y = y.to(device)
            
            # forward pss
            y_logit = model(X)
            loss = loss_fn(y_logit, y)

            # calculate loss and accuracy per batch
            test_loss += loss.item() * len(y)
            y_pred_labels = torch.argmax(y_logit, dim=1)
            test_acc += accuracy_fn(y_pred_labels, y)

    test_loss /= len(data_loader.dataset)
    test_acc /= len(data_loader)
    return (test_loss, test_acc)

### Updating Model Parameters (helper functions from client's perspective)

# de-serialized & set client parameters
def set_parameters(model: nn.Module, parameters: List[np.ndarray]):
    # de-serialize the ndarray to tensors
    parameters = [torch.from_numpy(np_arr).to(dtype=torch.float32, device=DEVICE) for np_arr in parameters]
    # match every weight with its model block
    param_dict = zip(model.state_dict().keys(), parameters)
    # convert the param_dict to ordered dict and load back into the model
    model.load_state_dict(OrderedDict(param_dict), strict=True)

# get serialized parameters from client
def get_parameters(model: nn.Module) -> List[np.ndarray]:
    # serialize the model weights into ndarray and return
    return [weights.cpu().numpy() for _, weights in model.state_dict().items()]

In [74]:
### client metric aggregation functions
from flwr.common import Metrics

# aggregate the metrics received from all the client's evaluate function
def eval_weighted_avg(metrics: List[Tuple[int, Metrics]]):
    #  multiply accuracy with each client's number of samples / or is it batch size??
    accuracies = [num_examples * m['test_accuracy'] for num_examples, m in metrics]
    num_samples = [num_examples for num_examples, _ in metrics]

    # aggregate and return the custom metrics (weighted avg)
    return {'test_accuracy': sum(accuracies) / sum(num_samples)}

# aggregate the metrics received from all the client's fit function
def fit_weighted_avg(metrics: List[Tuple[int, Metrics]]):
    #  multiply accuracy with each client's number of samples / or is it batch size??
    accuracies = [num_examples * m['train_accuracy'] for num_examples, m in metrics]
    num_samples = [num_examples for num_examples, _ in metrics]
    # for loss
    losses = [num_examples * m['train_loss'] for num_examples, m in metrics]

    # aggregate and return the custom metrics (weighted avg)
    return {'train_loss': sum(losses) / sum(num_samples), 'train_accuracy': sum(accuracies) / sum(num_samples)}


### server-side parameter evaluation function
# evaluate on the server
def eval_server(server_round: int, params: fl.common.NDArray, config: Dict[str, fl.common.Scalar]):
    # create server model
    model = TinyVGG().to(DEVICE)
    # load validation dataloader
    val_dl = val_dataloaders[0]
    # update model with latest parameters
    set_parameters(model=model, parameters=params)
    # perform centralized evaluation
    loss, accuracy = test(model=model, data_loader=val_dl, device=DEVICE)
    print(f'Server round {server_round}: evaluation loss {loss} | accuracy {accuracy}')
    return loss, {'accuracy': accuracy}

### sending/receiving arbitrary values to/from clients
# configure/set client-side params from server side
# training configuration from server to client
def fit_config(server_round: int):
    """Return training configuration dict form server to client for each round.

    Perform rounds of training with one local epoch.
    """
    config = {
        'server_round': server_round,
        'local_epochs': 1
    }
    return config


### Step 1: Revisiting FLOWER NumPy Client

Whenever a client is called to do some work, `simulation` function calls `client_fn` to create an instance of `NumPyClient` object with client specific dataloaders. However, in reality, Flower wraps the object to make it look like a subclass of `fl.client.Client` and not `fl.client.NumPyClient`.

`fl.client.NumPyClient` is just another abstraction over `fl.client.Client`. Therefore, now we'll see how we can directly subclass from `fl.client.Client` and use it.

The biggest difference is that `Client` baseclass will expect us to take care of parameter serialization and de-serialization by ourselves.
And remember that serialization and de-serialization needs to be done on both client and server side (both will receive and send serialized parameters)

### Step 2: Moving from NumPyClient to Client

In [75]:
from flwr.common import(
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays
)

class FlowerClient(fl.client.Client):
    def __init__(self, cid, model, train_dl, val_dl) -> None:
        super().__init__()

        self.cid = cid
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f'[Client {self.cid}] get_parameters')
        # get client parameters as a list of ndarrays
        ndarray: List[np.ndarray] = get_parameters(self.model)
        # serialize ndarrays into a parameter object
        parameters = ndarrays_to_parameters(ndarrays=ndarray)
        # build and return response/client updates
        status = Status(code=Code.OK, message='success')
        return GetParametersRes(
            status=status,
            parameters=parameters
        )
    
    def fit(self, ins: FitIns) -> FitRes:

        # flower_params (de-serialize) -> np.ndarray -> model_params -> np.ndarray -> flower_params (serialize)

        print(f'[Client {self.cid}] fit, config: {ins.config}')

        # de-serialize received parameters to numpy ndarray
        ndarr_parameter = parameters_to_ndarrays(ins.parameters)

        # update local model parameters
        set_parameters(self.model, ndarr_parameter)

        # train & get updated parameters in ndarray
        train_loss, train_accuracy = train(
            model=self.model,
            data_loader=self.train_dl,
            epochs=EPOCHS_CLIENT,
            device=DEVICE
        )
        updated_params_ndarr = get_parameters(self.model)

        # serialize from ndarray into params object
        params_updated = ndarrays_to_parameters(updated_params_ndarr)
        
        # build and return response/client updates
        status = Status(code=Code.OK, message='success')
        return FitRes(
            status=status,
            parameters=params_updated,
            num_examples=len(self.train_dl),
            metrics={'train_loss': train_loss,
                     'train_accuracy': train_accuracy}
        )
    
    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        
        # flower_params (de-serialize) -> np.ndarray -> model_params -> test()

        print(f'[Client {self.cid}] evaluate, config: {ins.config}')

        # de-serialize received parameters to numpy ndarray
        ndarr_parameter = parameters_to_ndarrays(ins.parameters)

        # update local model parameters
        set_parameters(self.model, ndarr_parameter)
        
        # test & get updated parameters in ndarray
        test_loss, test_accuracy = test(
            model=self.model,
            data_loader=self.val_dl,
            device=DEVICE
        )

        # build and return response/client updates
        status = Status(code=Code.OK, message='success')
        return EvaluateRes(
            status=status,
            loss=test_loss,
            num_examples=len(self.val_dl),
            metrics={'test_accuracy': test_accuracy}
        )
    
def client_fn(cid) -> FlowerClient:
    model = TinyVGG().to(DEVICE)
    train_dl = train_dataloaders[int(cid)]
    val_dl = val_dataloaders[int(cid)]

    return FlowerClient(
        cid=cid,
        model=model,
        train_dl=train_dl,
        val_dl=val_dl
    )

In [76]:
### Start The Training

# client resources (allocate cpus and gpus)
client_resources = {'num_cpus': NUM_WORKER//NUM_CLIENT, 'num_gpus': 1 if torch.cuda.is_available() else 0}

# start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    # strategy=strategy, # <-- custom strategy
    client_resources=client_resources
)

INFO flwr 2024-03-29 11:42:44,517 | app.py:178 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)


2024-03-29 11:42:47,660	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2024-03-29 11:42:48,437 | app.py:213 | Flower VCE: Ray initialized with resources: {'node:__internal_head__': 1.0, 'object_store_memory': 7411503513.0, 'GPU': 1.0, 'accelerator_type:G': 1.0, 'memory': 14823007028.0, 'node:10.255.93.233': 1.0, 'CPU': 16.0}
INFO flwr 2024-03-29 11:42:48,437 | app.py:219 | Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html
INFO flwr 2024-03-29 11:42:48,437 | app.py:242 | Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 1}
INFO flwr 2024-03-29 11:42:48,443 | app.py:288 | Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
INFO flwr 2024-03-29 11:42:48,444 | server.py:89 | Initializing global parameters
INFO flwr 2024-03-29 11:42:48,444 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2024-03-29 11:42:50,196 | server.py:280 | Received initial paramete

[2m[36m(DefaultActor pid=904736)[0m [Client 1] get_parameters
[2m[36m(DefaultActor pid=904736)[0m [Client 1] fit, config: {}
[2m[36m(DefaultActor pid=904736)[0m [Client 0] fit, config: {}


DEBUG flwr 2024-03-29 11:42:53,051 | server.py:236 | fit_round 1 received 2 results and 0 failures
DEBUG flwr 2024-03-29 11:42:53,052 | server.py:173 | evaluate_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=904736)[0m [Client 1] evaluate, config: {}
[2m[36m(DefaultActor pid=904736)[0m [Client 0] evaluate, config: {}


DEBUG flwr 2024-03-29 11:42:55,192 | server.py:187 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2024-03-29 11:42:55,193 | server.py:222 | fit_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=904736)[0m [Client 0] fit, config: {}
[2m[36m(DefaultActor pid=904736)[0m [Client 1] fit, config: {}


DEBUG flwr 2024-03-29 11:42:58,005 | server.py:236 | fit_round 2 received 2 results and 0 failures
DEBUG flwr 2024-03-29 11:42:58,006 | server.py:173 | evaluate_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=904736)[0m [Client 1] evaluate, config: {}
[2m[36m(DefaultActor pid=904736)[0m [Client 0] evaluate, config: {}


DEBUG flwr 2024-03-29 11:43:00,140 | server.py:187 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2024-03-29 11:43:00,140 | server.py:222 | fit_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=904736)[0m [Client 0] fit, config: {}
[2m[36m(DefaultActor pid=904736)[0m [Client 1] fit, config: {}


DEBUG flwr 2024-03-29 11:43:02,810 | server.py:236 | fit_round 3 received 2 results and 0 failures
DEBUG flwr 2024-03-29 11:43:02,811 | server.py:173 | evaluate_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=904736)[0m [Client 0] evaluate, config: {}
[2m[36m(DefaultActor pid=904736)[0m [Client 1] evaluate, config: {}


DEBUG flwr 2024-03-29 11:43:04,936 | server.py:187 | evaluate_round 3 received 2 results and 0 failures
INFO flwr 2024-03-29 11:43:04,936 | server.py:153 | FL finished in 14.739556103013456
INFO flwr 2024-03-29 11:43:04,937 | app.py:226 | app_fit: losses_distributed [(1, 2.30242875957489), (2, 2.302306652069092), (3, 2.302180121421814)]
INFO flwr 2024-03-29 11:43:04,937 | app.py:227 | app_fit: metrics_distributed_fit {}
INFO flwr 2024-03-29 11:43:04,937 | app.py:228 | app_fit: metrics_distributed {}
INFO flwr 2024-03-29 11:43:04,937 | app.py:229 | app_fit: losses_centralized []
INFO flwr 2024-03-29 11:43:04,938 | app.py:230 | app_fit: metrics_centralized {}


History (loss, distributed):
	round 1: 2.30242875957489
	round 2: 2.302306652069092
	round 3: 2.302180121421814

### Step 3: Custom Serialization

Serialization is an essential step in FL as server and clients rely heavily on internet communication for training purpose.
Therefore, writing custom serialization/deserialization functions to convert `ndarray` to `sparse matrices` and vice versa to save bandwidth.

In [77]:
from io import BytesIO
from typing import cast

from flwr.common.typing import NDArray, NDArrays, Parameters

# ndarrays -> sparse bytes
def ndarrays_to_sparse_parameters(ndarrays: NDArrays) -> Parameters:
    """Convert NumPy ndarrays to parameters object."""
    bytes_lst = [ndarray_to_sparse_bytes(ndarray) for ndarray in ndarrays]
    return Parameters(
        tensors=bytes_lst,
        tensor_type='numpy.ndarray'
    )

# sparse bytes -> ndarrays
def sparse_parameters_to_ndarrays(parameters: Parameters) -> NDArrays:
    """Convert parameters object to NumPy ndarrays."""
    return [sparse_bytes_to_ndarray(bytes) for bytes in parameters.tensors]

# convert a single numpy ndarray to a sparse matrix in bytes
def ndarray_to_sparse_bytes(ndarray: NDArray) -> bytes:
    """Serialize a NumPy ndarray to bytes."""
    
    bytes_io = BytesIO() # a file object to store the bytes

    if len(ndarray.shape) > 1:
        # convert ndarray to a torch sparse matrix
        ndarray = torch.tensor(ndarray).to_sparse_csr()
        # save to file
        np.savez(
            file=bytes_io,
            crow_indices=ndarray.crow_indices(),
            col_indices=ndarray.col_indices(),
            values=ndarray.values(),
            allow_pickle=False
        )
    else:
        np.save(
            file=bytes_io,
            arr=ndarray,
            allow_pickle=False
            )

    return bytes_io.getvalue()

# convert a sparse matrix in bytes to a numpy ndarray
def sparse_bytes_to_ndarray(tensor: bytes) -> NDArray:
    """Deserialize bytes to a NumPy ndarray"""

    bytes_io = BytesIO(tensor) # write bytes to a file object
    # load the sparse matrix objects from file
    loader = np.load(bytes_io, allow_pickle=False)

    # convert the torch sparse matrix back to ndarray
    if 'crow_indices' in loader:
        ndarr_deserialized = torch.sparse_csr_tensor(
            crow_indices=loader['crow_indices'],
            col_indices=loader['col_indices'],
            values=loader['values']
        ).to_dense().numpy()
    else:
        ndarr_deserialized = loader
    
    return cast(NDArray, ndarr_deserialized)



#### Client Side Implementation

In [78]:
from flwr.common import(
    Code,
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    GetParametersIns,
    GetParametersRes,
    Status,
    ndarrays_to_parameters,
    parameters_to_ndarrays
)

class FlowerClient(fl.client.Client):
    def __init__(self, cid, model, train_dl, val_dl) -> None:
        super().__init__()

        self.cid = cid
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        print(f'[Client {self.cid}] get_parameters')
        
        # get client parameters as a list of ndarrays
        ndarrays: List[np.ndarray] = get_parameters(self.model)
        
        # serialize ndarrays into a Parameter object (containing sparse matrix bytes)
        sparse_parameters = ndarrays_to_sparse_parameters(ndarrays)
        
        # build and return response/client updates
        status = Status(code=Code.OK, message='success')
        return GetParametersRes(
            status=status,
            parameters=sparse_parameters
        )
    
    def fit(self, ins: FitIns) -> FitRes:

        # flower_params (de-serialize) -> np.ndarray -> model_params -> np.ndarray -> flower_params (serialize)

        print(f'[Client {self.cid}] fit, config: {ins.config}')

        # de-serialize received parameters to numpy ndarray
        ndarr_parameter = sparse_parameters_to_ndarrays(ins.parameters)

        # update local model parameters
        set_parameters(self.model, ndarr_parameter)

        # train & get updated parameters in ndarray
        train_loss, train_accuracy = train(
            model=self.model,
            data_loader=self.train_dl,
            epochs=EPOCHS_CLIENT,
            device=DEVICE
        )
        updated_params_ndarr = get_parameters(self.model)

        # serialize from ndarray into params object
        params_updated = ndarrays_to_sparse_parameters(updated_params_ndarr)
        
        # build and return response/client updates
        status = Status(code=Code.OK, message='success')
        return FitRes(
            status=status,
            parameters=params_updated,
            num_examples=len(self.train_dl),
            metrics={'train_loss': train_loss,
                     'train_accuracy': train_accuracy}
        )
    
    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        
        # flower_params (de-serialize) -> np.ndarray -> model_params -> test()

        print(f'[Client {self.cid}] evaluate, config: {ins.config}')

        # de-serialize received parameters to numpy ndarray
        ndarr_parameter = sparse_parameters_to_ndarrays(ins.parameters)

        # update local model parameters
        set_parameters(self.model, ndarr_parameter)
        
        # test & get updated parameters in ndarray
        test_loss, test_accuracy = test(
            model=self.model,
            data_loader=self.val_dl,
            device=DEVICE
        )

        # build and return response/client updates
        status = Status(code=Code.OK, message='success')
        return EvaluateRes(
            status=status,
            loss=test_loss,
            num_examples=len(self.val_dl),
            metrics={'test_accuracy': test_accuracy}
        )
    
def client_fn(cid) -> FlowerClient:
    model = TinyVGG().to(DEVICE)
    train_dl = train_dataloaders[int(cid)]
    val_dl = val_dataloaders[int(cid)]

    return FlowerClient(
        cid=cid,
        model=model,
        train_dl=train_dl,
        val_dl=val_dl
    )

#### Server Side Implementation

In [79]:
from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import FitRes, MetricsAggregationFn, NDArrays, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg
from flwr.server.strategy.aggregate import aggregate

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""

class FedSparse(fl.server.strategy.FedAvg):
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Custom FedAvg strategy with sparse matrices.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 0.1.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 0.1.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters.
        """

        if (
            min_fit_clients > min_available_clients
            or min_evaluate_clients > min_available_clients
        ):
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

    def aggregate_fit(
            self,
            server_round: int,
            results: List[Tuple[ClientProxy | FitRes]],
            failures: List[Tuple[ClientProxy | FitRes] | BaseException]
            ) -> Tuple[Parameters | None | Dict[str, bool | bytes | float | int | str]]:
        """Aggregate the fit results using weighted average"""

        if not results:
            return None, {}
        
        # do not aggregate if there are failures and they are not accepted
        if not self.accept_failures and failures:
            return None, {}
        
        # deserialize and get the updated parameter & client dataset size
        weights_results = [(sparse_parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results]
        # aggregate the parameters and serialize
        parameters_aggregated = ndarrays_to_sparse_parameters(aggregate(weights_results))
        
        # aggregate custom metrics if aggregration_fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(fit_res.num_examples, fit_res.metrics) for _, fit_res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1: # log warning once
            log(WARNING, 'No fit_metrics_aggregation_fn was provided')

        return parameters_aggregated, metrics_aggregated
    
    # server side evaluation
    def evaluate(
            self,
            server_round: int,
            parameters: Parameters,
            ) -> Tuple[float | Dict[str, bool | bytes | float | int | str]] | None:
        """Evaluate model parameters using an evaluation function."""

        if self.evaluate_fn is None: # no evaluation function provided
            return None
        
        # deserialized the parameters
        parameters_ndarrays = sparse_parameters_to_ndarrays(parameters)

        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

        

In [80]:
### Start The Training

params = get_parameters(TinyVGG())

strategy = FedSparse(
    # fraction_fit=1.0, # C: fraction of client to choose for training
    # fraction_evaluate=0.5, # fraction of client to choose for evaluation
    # min_fit_clients=10, # minimum clients needed for training
    # min_evaluate_clients=5, # minimum clients needed for evaluation
    # min_available_clients=10, # wait till given client are available
    evaluate_metrics_aggregation_fn=eval_weighted_avg, # aggregate the val metrics of clients
    fit_metrics_aggregation_fn=fit_weighted_avg, # aggregate the train metrics of clients
    initial_parameters=fl.common.ndarrays_to_parameters(params), # init parameters passed
    evaluate_fn=eval_server, # server evaluation function passed here
    on_fit_config_fn=fit_config, # client fit config send from server/strategy
)

# client resources (allocate cpus and gpus)
client_resources = {'num_cpus': NUM_WORKER//NUM_CLIENT, 'num_gpus': 1 if torch.cuda.is_available() else 0}

# start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy, # <-- our custom strategy FedSparse
    client_resources=client_resources
)

INFO flwr 2024-03-29 11:43:05,013 | app.py:178 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)


2024-03-29 11:43:08,295	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2024-03-29 11:43:09,108 | app.py:213 | Flower VCE: Ray initialized with resources: {'node:__internal_head__': 1.0, 'object_store_memory': 7404576768.0, 'GPU': 1.0, 'accelerator_type:G': 1.0, 'memory': 14809153536.0, 'node:10.255.93.233': 1.0, 'CPU': 16.0}
INFO flwr 2024-03-29 11:43:09,108 | app.py:219 | Optimize your simulation with Flower VCE: https://flower.dev/docs/framework/how-to-run-simulations.html
INFO flwr 2024-03-29 11:43:09,108 | app.py:242 | Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 1}
INFO flwr 2024-03-29 11:43:09,114 | app.py:288 | Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
INFO flwr 2024-03-29 11:43:09,114 | server.py:89 | Initializing global parameters
INFO flwr 2024-03-29 11:43:09,115 | server.py:272 | Using initial parameters provided by strategy
INFO flwr 2024-03-29 11:43:09,115 | server.py:91 | Evaluating initial parameters
INF

Server round 0: evaluation loss 2.3025579166412355 | accuracy 11.03515625
[2m[36m(DefaultActor pid=908911)[0m [Client 1] fit, config: {'server_round': 1, 'local_epochs': 1}




[2m[36m(DefaultActor pid=908911)[0m [Client 0] fit, config: {'server_round': 1, 'local_epochs': 1}


DEBUG flwr 2024-03-29 11:43:13,374 | server.py:236 | fit_round 1 received 2 results and 0 failures
INFO flwr 2024-03-29 11:43:13,779 | server.py:125 | fit progress: (1, 2.3025342025756834, {'accuracy': 10.9375}, 4.254866642993875)
DEBUG flwr 2024-03-29 11:43:13,780 | server.py:173 | evaluate_round 1: strategy sampled 2 clients (out of 2)


Server round 1: evaluation loss 2.3025342025756834 | accuracy 10.9375
[2m[36m(DefaultActor pid=908911)[0m [Client 1] evaluate, config: {}
[2m[36m(DefaultActor pid=908911)[0m [Client 0] evaluate, config: {}


DEBUG flwr 2024-03-29 11:43:15,889 | server.py:187 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2024-03-29 11:43:15,890 | server.py:222 | fit_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=908911)[0m [Client 1] fit, config: {'server_round': 2, 'local_epochs': 1}
[2m[36m(DefaultActor pid=908911)[0m [Client 0] fit, config: {'server_round': 2, 'local_epochs': 1}


DEBUG flwr 2024-03-29 11:43:18,496 | server.py:236 | fit_round 2 received 2 results and 0 failures
INFO flwr 2024-03-29 11:43:18,890 | server.py:125 | fit progress: (2, 2.3025121841430662, {'accuracy': 10.9375}, 9.365643000928685)
DEBUG flwr 2024-03-29 11:43:18,891 | server.py:173 | evaluate_round 2: strategy sampled 2 clients (out of 2)


Server round 2: evaluation loss 2.3025121841430662 | accuracy 10.9375
[2m[36m(DefaultActor pid=908911)[0m [Client 0] evaluate, config: {}
[2m[36m(DefaultActor pid=908911)[0m [Client 1] evaluate, config: {}


DEBUG flwr 2024-03-29 11:43:20,979 | server.py:187 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2024-03-29 11:43:20,979 | server.py:222 | fit_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(DefaultActor pid=908911)[0m [Client 1] fit, config: {'server_round': 3, 'local_epochs': 1}
[2m[36m(DefaultActor pid=908911)[0m [Client 0] fit, config: {'server_round': 3, 'local_epochs': 1}


DEBUG flwr 2024-03-29 11:43:23,793 | server.py:236 | fit_round 3 received 2 results and 0 failures
INFO flwr 2024-03-29 11:43:24,250 | server.py:125 | fit progress: (3, 2.3024918632507325, {'accuracy': 10.9375}, 14.725478978944011)
DEBUG flwr 2024-03-29 11:43:24,251 | server.py:173 | evaluate_round 3: strategy sampled 2 clients (out of 2)


Server round 3: evaluation loss 2.3024918632507325 | accuracy 10.9375
[2m[36m(DefaultActor pid=908911)[0m [Client 1] evaluate, config: {}
[2m[36m(DefaultActor pid=908911)[0m [Client 0] evaluate, config: {}


DEBUG flwr 2024-03-29 11:43:26,326 | server.py:187 | evaluate_round 3 received 2 results and 0 failures
INFO flwr 2024-03-29 11:43:26,327 | server.py:153 | FL finished in 16.802086564945057
INFO flwr 2024-03-29 11:43:26,327 | app.py:226 | app_fit: losses_distributed [(1, 2.3029259300231932), (2, 2.3028764419555663), (3, 2.3028296508789063)]
INFO flwr 2024-03-29 11:43:26,327 | app.py:227 | app_fit: metrics_distributed_fit {'train_loss': [(1, 2.302782864570618), (2, 2.3027063970565798), (3, 2.3026384296417235)], 'train_accuracy': [(1, 10.2375), (2, 10.2625), (3, 10.2375)]}
INFO flwr 2024-03-29 11:43:26,327 | app.py:228 | app_fit: metrics_distributed {'test_accuracy': [(1, 10.3515625), (2, 10.302734375), (3, 10.302734375)]}
INFO flwr 2024-03-29 11:43:26,328 | app.py:229 | app_fit: losses_centralized [(0, 2.3025579166412355), (1, 2.3025342025756834), (2, 2.3025121841430662), (3, 2.3024918632507325)]
INFO flwr 2024-03-29 11:43:26,328 | app.py:230 | app_fit: metrics_centralized {'accuracy': 

History (loss, distributed):
	round 1: 2.3029259300231932
	round 2: 2.3028764419555663
	round 3: 2.3028296508789063
History (loss, centralized):
	round 0: 2.3025579166412355
	round 1: 2.3025342025756834
	round 2: 2.3025121841430662
	round 3: 2.3024918632507325
History (metrics, distributed, fit):
{'train_loss': [(1, 2.302782864570618), (2, 2.3027063970565798), (3, 2.3026384296417235)], 'train_accuracy': [(1, 10.2375), (2, 10.2625), (3, 10.2375)]}History (metrics, distributed, evaluate):
{'test_accuracy': [(1, 10.3515625), (2, 10.302734375), (3, 10.302734375)]}History (metrics, centralized):
{'accuracy': [(0, 11.03515625), (1, 10.9375), (2, 10.9375), (3, 10.9375)]}