In [16]:
from collections import OrderedDict
from typing import List, Tuple, Optional, Union, Dict
import os
from pathlib import Path
from logging import WARNING, INFO, DEBUG

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import SGD
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    NDArrays,
    Parameters,
    Metrics, 
    Context, 
    Scalar, 
    ndarrays_to_parameters,
    parameters_to_ndarrays,
    )
from flwr.common.logger import log
from flwr.server import ServerApp, ServerConfig, ServerAppComponents, Server
from flwr.server.strategy import FedAvg, Strategy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg
from flwr.server.client_proxy import ClientProxy
from flwr.server.client_manager import ClientManager, SimpleClientManager
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from fedlearn.model import SmallCNN

In [17]:
datadir = Path().cwd().parent / "data" / "flower_dataset"
logdir = Path().cwd().parent / "logs" / "scaffold"

if not logdir.exists():
    logdir.mkdir(parents=True, exist_ok=True)

In [18]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
disable_progress_bar()

Training on cuda
Flower 1.18.0 / PyTorch 2.7.0+cu126


In [None]:
NUM_PARTITIONS = 10 # Number of partitions for the federated dataset same as the number of clients
BATCH_SIZE = 32


def load_datasets(partition_id: int, num_partitions: int):
    #partitioner = 
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions}, cache_dir=datadir)
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader

### Define Scaffold Optimizer

Recall that the local update in Scaffold is given by

$$
w^{(i)} \gets w^{(i)} - \eta_l \left( g_i(w^{(i)}) + c - c_i \right)
$$

Which can be seen as a gradient correction to Stochastic Gradient Descent (SGD). We may therefore extend the pytorch ```SGD``` class. We do this by computing the the regular SGD step, then adding the correction manually:

$$
\begin{align*}
w^{(i)} &\gets w^{(i)} - \eta_l \, g_i\left(w^{(i)}\right) \\
w^{(i)} &\gets w^{(i)} - \eta_l (c - c_i)
\end{align*}
$$

In [5]:
class ScaffoldOptimizer(SGD):
    def __init__(self, params, lr, momentum=0., weight_decay=0.):
        super().__init__(params, lr, momentum, weight_decay)

    def step_custom(self, global_cv, client_cv):
        """
        Perform a single optimization step.
        :param global_cv: Global control variable
        :param client_cv: Client control variable
        """
        # compute regular SGD step
        #   w <- w - lr * grad
        super().step() 

        # now add the correction term
        #   w <- w - lr * (g_cv - c_cv)
        device = self.param_groups[0]["params"][0].device
        for group in self.param_groups:
            for param, g_cv, c_cv in zip(group["params"], global_cv, client_cv):
                # here we add the correction term to each parameter tensor.
                # the alpha value scales the correction term
                    g_cv, c_cv = g_cv.to(device), c_cv.to(device)
                    param.data.add_(g_cv - c_cv, alpha=-group["lr"]) 
                #if param.grad is not None:
                    #g_cv, c_cv = g_cv.to(device), c_cv.to(device)
                    #param.grad.add_(g_cv - c_cv)  #, alpha=-group["lr"]) 
        #super().step()

We can now write a function for the local training. In this function, we want simply want to perform gradient corrected SGD updates over the local data for $E$ epochs.

In [None]:
def train_scaffold(net: torch.nn.Module, 
                   device: torch.device, 
                   trainloader: torch.utils.data.DataLoader,
                   criterion: nn.Module,
                   num_epochs: int, 
                   lr: float, 
                   momentum: float, 
                   weight_decay: float, 
                   global_cv: List[torch.Tensor], 
                   client_cv: List[torch.Tensor],
                   ) -> None:
    """
    Function that trains a model using the Scaffold optimization algorithm.
    Parameters:
        net:            The neural network model to train.
        device:         The device to run the training on (CPU or GPU).
        trainloader:    DataLoader for the training data.
        criterion:      Loss function to use for training.
        num_epochs:     Number of epochs to train the model.
        lr:             Learning rate for the optimizer.
        momentum:       Momentum factor for the optimizer.
        weight_decay:   Weight decay (L2 penalty) for the optimizer.
        global_cv:      Global control variables for Scaffold.
        client_cv:      Client control variables for Scaffold.
    """
    net.to(device)
    net.train()
    optimizer = ScaffoldOptimizer(
        net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    
    for _ in range(num_epochs):
        for batch in trainloader:
            Xtrain, Ytrain = batch["img"].to(device), batch["label"].to(device)
            optimizer.zero_grad()
            output = net(Xtrain)
            loss = criterion(output, Ytrain)

            # for debugging purposes, exit if loss is NaN
            #if torch.isnan(loss):
            #    raise ValueError("Loss is NaN, check your model and data.")

            loss.backward()
            
            # Perform a single optimization step with the control variables
            optimizer.step_custom(global_cv, client_cv)

We will also define a test function, which will be called to evaluate our model. This will give us some metrics to evaluate the performance of the model. As we are working with a classifier, we are interested in both the loss and accuracy of the model.

In [None]:
def test(net: torch.nn.Module, 
         device: torch.device, 
         testloader: torch.utils.data.DataLoader,
         criterion: nn.Module,
         ) -> Tuple[float, float]:
    """
    Function that tests a model on the test dataset.
    Parameters:
        net:        The neural network model to test.
        device:     The device to run the testing on (CPU or GPU).
        testloader: DataLoader for the test data.
        criterion:  Loss function to use for testing.
    Returns:
        Tuple containing the average loss and accuracy on the test set.
    """
    net.to(device)
    net.eval()
    total_loss = 0.0    # Accumulator for total loss
    correct = 0         # tracker for correct predictions
    total = 0           # tracker for total predictions
    
    with torch.no_grad():
        for batch in testloader:
            Xtest, Ytest = batch["img"].to(device), batch["label"].to(device)
            output = net(Xtest)
            loss = criterion(output, Ytest)

            if torch.isnan(loss):
                raise ValueError("Loss is NaN, check your model and data.")

            total_loss += loss.item()
            _, predicted = output.max(1)
            total += Ytest.size(0)
            correct += predicted.eq(Ytest).sum().item()
    
    avg_loss = total_loss / len(testloader) # compute the average loss
    accuracy = correct / total              # compute the accuracy
    return avg_loss, accuracy

For the simulation, we will use the flower framework, which was introduced in the _ notebook. To do so, we need to specify both Client and Server classes. We need to consider a couple of things: 
1. We can inherit the  ```NumPyClient``` class from the flower framework, however we need to remember to convert between ```np.ndarray``` and ```torch.tensor``` before and after local updates.
2. We need to specify a ```client.fit()``` method, containing all the logic for the local update. This method has 2 inputs
   1. parameters: a list of ```np.ndarray```, containing both global model parameters and global control variates
   2. config: dict for specifying training configuration (we will ignore this for now)
3. 

In [8]:
# define helper functions to set and get model parameters
def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(net: torch.nn.Module, parameters: List[np.ndarray]) -> None:
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k : torch.tensor(v).to(torch.float32) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


class ScaffoldClient(NumPyClient):
    def __init__(self, 
                 partition_id: int, 
                 net: torch.nn.Module, 
                 trainloader: torch.utils.data.DataLoader, 
                 valloader: torch.utils.data.DataLoader,
                 criterion: nn.Module,
                 device: torch.device,
                 num_epochs: int,
                 lr: float,
                 momentum: float,
                 weight_decay: float,
                 save_dir: Optional[str] = None,
                 ):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.criterion = criterion
        self.device = device
        self.num_epochs = num_epochs
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        # define directory to save client control variates
        if save_dir is None:
            save_dir = "client_cvs"

        # create directory if it does not exist
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # define the path to save the client control variates
        self.save_name = os.path.join(save_dir, f"client_{self.partition_id}_cv.pt")

        # initialize client control variates
        self.client_cv = [torch.zeros(param.shape).to(torch.float32) for param in self.net.state_dict().values()]


    # Here is where all the training logic and control variate updates happen
    def fit(self, parameters: List[np.ndarray], config: dict) -> Tuple[List[np.ndarray], int, dict]:

        # the global parameters are packed together with the global control variates
        # in the form [params, global_cv]. we start by separating them
        params = parameters[:len(parameters) // 2]          # list of np.ndarray
        global_cv = parameters[len(parameters) // 2:]       # list of np.ndarray

        # load the current global model:
        set_parameters(self.net, params)

        # load client control variates, if they exist:
        if os.path.exists(self.save_name):
            self.client_cv = torch.load(self.save_name)     # list of torch.tensor

        # convert global control variates to tensors
        global_cv_torch = [torch.tensor(cv).to(torch.float32) for cv in global_cv]  # list of torch.tensor

        # call the training function
        train_scaffold(
            net=self.net,
            device=self.device,
            trainloader=self.trainloader,
            criterion=self.criterion,
            num_epochs=self.num_epochs,
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
            global_cv=global_cv_torch,          # passing list of torch.tensor
            client_cv=self.client_cv            # passing list of torch.tensor
        )

        # update the client control variates
        yi = get_parameters(self.net)           # list of np.ndarray

        # compute coefficient for the control variates
        # 1 / (K * eta) where K is the number of backward passes (num_epochs * len(trainloader))
        coeff = 1. / (self.num_epochs * len(self.trainloader) * self.lr) 

        client_cv = [cv.numpy() for cv in self.client_cv]  # list of np.ndarray

        # define new list for udated client control variates
        client_cv_new = []

        # compute client control variate update, list of np.ndarray
        for xj, yj, cj, cij in zip(params, yi, global_cv, client_cv):
            client_cv_new.append(
                cij - cj + coeff * (xj - yj)
            ) 

        # compute server updates
        server_update_x = [yj - xj for xj, yj in zip(params, yi)]
        server_update_c = [cij_n - cij for cij_n, cij in zip(client_cv_new, client_cv)]

        # convert client cvs back to torch tensors
        self.client_cv = [torch.tensor(cv).to(torch.float32) for cv in client_cv_new]  

        # save the updated client control variates
        torch.save(self.client_cv, self.save_name)

        #concatenate server updates
        server_update = server_update_x + server_update_c

        return server_update, len(self.trainloader.dataset), {}



    def evaluate(self, parameters: List[np.ndarray], config: dict) -> Tuple[float, int, dict]:
        set_parameters(self.net, parameters)
        avg_loss, accuracy = test(
            net=self.net,
            device=self.device,
            testloader=self.valloader,
            criterion=self.criterion
        )
        return float(avg_loss), len(self.valloader), {"accuracy": accuracy}

Now that we have the flower client defined, we need to define a constructor function which the flower framework can use to instatiate clients as it goes.

In [9]:
def client_fn(context: Context) -> Client:
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)

    net = SmallCNN().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    
    # Define hyperparameters for training
    num_epochs = 5
    lr = 1e-4
    momentum = 0.
    weight_decay = 0.

    return ScaffoldClient(
        partition_id=partition_id,
        net=net,
        trainloader=trainloader,
        valloader=valloader,
        criterion=criterion,
        device=DEVICE,
        num_epochs=num_epochs,
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        save_dir="client_cvs"
    ).to_client()


client = ClientApp(client_fn=client_fn)

We also need to implement a custom strategy. We can inherit the ```FedAvg``` class. All we need to do is redefine the ```aggregate_fit()``` method. This method must take the following as input:
1. server_round: the current round 
2. results:
3. failures

In [10]:
class ScaffoldStrategy(Strategy):
    def __init__(
        self,
        total_num_clients: int,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[callable] = None,
        accept_failures: bool = False,
        fit_metrics_aggregation_fn: Optional[callable] = None,
    ) -> None:
        super().__init__()
        self.total_num_clients = total_num_clients
        total_num_clients = total_num_clients
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.accept_failures = accept_failures
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn

        self.evaluate_fn = evaluate_fn


    def __repr__(self) -> str:
        return "ScaffoldStrategy"


    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        net = SmallCNN()
        parameters = get_parameters(net)        
        return ndarrays_to_parameters(parameters)


    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

        config = {}
        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        fit_configurations = [(client, fit_ins) for client in clients]
        
        return fit_configurations


    def aggregate_fit(self, 
                      server_round: int, 
                      results: List[Tuple[ClientProxy, FitRes]],
                      failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
                      ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """
        aggregation method for Scaffold strategy.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}
        
        combined_parameters = [
            parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results
        ]

        len_combined_parameters = len(combined_parameters[0]) # combined number of model parameters and control variates

        num_samples_all = [fit_res.num_examples for _, fit_res in results]  # number of training samples from each client

        # The "aggregate()" function expects a list of tuples, where each tuple contains
        # the local parameters and the number of samples for that client.
        aggregation_inputs_parameters = [
            (local_params[:len_combined_parameters // 2], num_samples) 
            for local_params, num_samples in zip(combined_parameters, num_samples_all)
        ]
        
        parameters_aggregated = aggregate(aggregation_inputs_parameters)

        aggregation_inputs_cv = [
            (local_params[len_combined_parameters // 2:], num_samples) 
            for local_params, num_samples in zip(combined_parameters, num_samples_all)
        ]

        cv_aggregated = aggregate(aggregation_inputs_cv)

        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn is not None:
            fit_metrics = [
                (fit_res.num_examples, fit_res.metrics)
                for _, fit_res in results
            ]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:
            log(WARNING, "No fit_metrics_aggregation_fn provided")
        

        return (
            ndarrays_to_parameters(parameters_aggregated + cv_aggregated),
            metrics_aggregated,
        )


    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]


    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""

        if not results:
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated


    # method for evaluating the global model
    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate global model parameters using an evaluation function."""
        if self.evaluate_fn is None:
            return None
            # If an evaluation function is provided, use it
        parameters_ndarray = parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarray, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics


    # boilerplate code
    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients


    # boilerplate code
    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

In [11]:
from flwr.server.server import FitResultsAndFailures, fit_clients

def concat_params(parameters: Parameters, global_cv: List[np.ndarray]) -> Parameters:
    """
    Concatenate model parameters and global control variates.
    """
    parameters_ndarrays = parameters_to_ndarrays(parameters)
    parameters_ndarrays.extend(global_cv)
    return ndarrays_to_parameters(parameters_ndarrays)

class ScaffoldServer(Server):

    def __init__(self, 
                 strategy: Strategy, 
                 client_manager: ClientManager = SimpleClientManager(),
                 ) -> None:
        super().__init__(strategy=strategy, client_manager=client_manager)
        
        self.global_cv: List[np.ndarray] = []  # Global control variates for Scaffold
    
    def _get_initial_parameters(
        self, server_round: int, timeout: Optional[float]
        ) -> Parameters: 
        
        parameters = self.strategy.initialize_parameters(self.client_manager)

        if parameters is not None:
            log(INFO, "Using initial parameters provided by strategy")

            self.global_cv = [
                np.zeros_like(param, dtype=np.float32) for param in parameters_to_ndarrays(parameters)
            ]

            return parameters
        
        log(WARNING, "No initial parameters provided by strategy, shutting down")
        self.disconnect_all_clients()


    def fit_round(
            self,
            server_round: int,
            timeout: Optional[float],
            ) -> Optional[Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]]:
        
        # define client instructions to be passed to "fit_clients" function
        client_instructions = self.strategy.configure_fit(
            server_round=server_round,
            parameters=concat_params(self.parameters, self.global_cv),
            client_manager=self._client_manager,
        )

        # if no clients are selected, return None
        if not client_instructions:
            log(INFO, f"fit_round {server_round}: no clients selected.")
            return None
        
        log(
            DEBUG,
            f"fit_round {server_round}: selected {len(client_instructions)} clients.",
        )

        # Call the "fit_clients" function to perform the training on selected clients
        results, failures = fit_clients(
            client_instructions=client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
            group_id=server_round,
        )

        log(DEBUG,
            f"fit_round {server_round}: received {len(results)} results and {len(failures)} failures.",
        )

        # Aggregate the results from the clients
        aggregated_results = self.strategy.aggregate_fit(
            server_round=server_round,
            results=results,
            failures=failures,
        )

        # Extract the aggregated parameters and control variates
        aggregated_results_combined = []
        if aggregated_results[0] is not None:
            aggregated_results_combined = parameters_to_ndarrays(aggregated_results[0])

        # Split the aggregated results into model parameters and control variates
        aggregated_parameters = aggregated_results_combined[:len(aggregated_results_combined) // 2] # model parameters
        aggregated_cv = aggregated_results_combined[len(aggregated_results_combined) // 2:]         # control variates

        # define the update coefficient for the control variates
        cv_coeff = len(results) / len(self._client_manager.all())

        # Update the global control variates according to
        # global_cv <- global_cv + cv_coeff * aggregated_cv
        # where cv_coeff = |S| / N, |S| is the number of clients that participated in the round
        # and aggregated_cv = (1 / |S|) * sum_{i in S} (c_i^+ - c_i)
        self.global_cv = [
            cv + cv_coeff * new_cv for cv, new_cv in zip(self.global_cv, aggregated_cv)
        ]


        # Update the global model parameters
        # new_parameters = current_parameters + aggregated_parameters
        # where current_parameters are the parameters of the global model before the round
        # and aggregated_parameters = (1 / |S|) * sum_{i in S} (w_i^+ - w)
        current_parameters = parameters_to_ndarrays(self.parameters)
        new_parameters = [
            param + update for param, update in zip(current_parameters, aggregated_parameters)
        ]

        new_parameters = ndarrays_to_parameters(new_parameters)

        return new_parameters, aggregated_results[1], (results, failures)

We are now ready for the server

In [12]:
# Create an instance of the model and get the parameters
params = get_parameters(SmallCNN())
criterion = nn.CrossEntropyLoss()

# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: list[np.ndarray],
    config: dict[str, Scalar],
    ) -> Optional[Tuple[float, dict[str, Scalar]]]:
    
    net = SmallCNN().to(DEVICE)
    _, _, testloader = load_datasets(0, NUM_PARTITIONS)
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(
        net=net, 
        device=DEVICE, 
        testloader=testloader, 
        criterion=criterion,
        )
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}



def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAvg strategy
    strategy = ScaffoldStrategy(
        total_num_clients=NUM_PARTITIONS,       # Total number of clients
        fraction_fit=1.0,                       # Use all clients for training, C
        fraction_evaluate=0.5,                  # Use 50% of clients for evaluation
        min_fit_clients=10,                     # Minimum number of clients to train
        min_evaluate_clients=5,                 # Minimum number of clients to evaluate
        min_available_clients=NUM_PARTITIONS,   # Minimum number of clients available (enforce all clients to be available)
        evaluate_fn=evaluate,                   # Pass the evaluation function
    )

    server = ScaffoldServer(strategy=strategy)

    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(server=server, config=config)

# Create the ServerApp
server = ServerApp(server_fn=server_fn)

In [13]:
NUM_PARTITIONS = 10  # Number of partitions (clients)
backend_config = {
    "ray_init_args": {
        "num_cpus": 1,
        "num_gpus": 1,
    }
}
run_simulation(
    server_app=server, client_app=client, num_supernodes=NUM_PARTITIONS, backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      


[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      initial parameters (loss, other metrics): 2.3037926693693898, {'accuracy': 0.1001}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]


Server-side evaluation loss 2.3037926693693898 / accuracy 0.1001


[91mERROR [0m:     ServerApp thread raised an exception: Loss is NaN, check your model and data.
[91mERROR [0m:     Traceback (most recent call last):
  File "c:\Users\Phill\anaconda3\envs\FLenv2\Lib\site-packages\flwr\simulation\run_simulation.py", line 268, in server_th_with_start_checks
    updated_context = _run(
                      ^^^^^
  File "c:\Users\Phill\anaconda3\envs\FLenv2\Lib\site-packages\flwr\server\run_serverapp.py", line 62, in run
    server_app(grid=grid, context=context)
  File "c:\Users\Phill\anaconda3\envs\FLenv2\Lib\site-packages\flwr\server\server_app.py", line 166, in __call__
    start_grid(
  File "c:\Users\Phill\anaconda3\envs\FLenv2\Lib\site-packages\flwr\server\compat\app.py", line 90, in start_grid
    hist = run_fl(
           ^^^^^^^
  File "c:\Users\Phill\anaconda3\envs\FLenv2\Lib\site-packages\flwr\server\server.py", line 492, in run_fl
    hist, elapsed_time = server.fit(
                         ^^^^^^^^^^^
  File "c:\Users\Phill\anaconda3\e

RuntimeError: Exception in ServerApp thread