# Federated Learning

Welcome to the Flower federated learning tutorial!

In this notebook, we will build a federated learning system using Flower and PyTorch that drops clients based on their performance. To achieve this, we will create a custom strategy and a custom criterion. In each round, we will compute a threshold that identifies underperforming clients, leave them out of aggregation, and exclude them from future sampling.

Note that this method is intended for learning purposes only. For real protection against malicious clients, consider Byzantine-resilient methods such as Krum or Bulyan.

> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Slack to connect, ask questions, and get help: [Join Slack](https://flower.dev/join-slack) 🌼 We'd love to hear from you in the `#introductions` channel! And if anything is unclear, head over to the `#questions` channel.

Let's get started!

In [None]:
!pip install -q flwr[simulation] torch torchvision matplotlib pandas

In [None]:
from collections import OrderedDict
from functools import partial
from logging import WARNING, DEBUG
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.datasets import CIFAR10

import flwr as fl
from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    Metrics,
    NDArray,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.criterion import Criterion
from flwr.server.client_proxy import ClientProxy
from flwr.server.client_manager import ClientManager
from flwr.server.strategy.aggregate import aggregate


DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)


### Loading the data

Federated learning can be applied to many different types of tasks across different domains. In this tutorial, we introduce federated learning by training a simple convolutional neural network (CNN) on the popular CIFAR-10 dataset. CIFAR-10 can be used to train image classifiers that distinguish between images from ten different classes:

In [None]:
CLASSES = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

We simulate having multiple datasets from multiple organizations (also called the "cross-silo" setting in federated learning) by splitting the original CIFAR-10 dataset into multiple partitions. Each partition will represent the data from a single organization. We're doing this purely for experimentation purposes. In the real world, there's no need for data splitting because each organization already has its own data (so the data is naturally partitioned).

Each organization will act as a client in the federated learning system. So having ten organizations participate in a federation means having ten clients connected to the federated learning server:


In [None]:
NUM_CLIENTS = 10

Let's now load the CIFAR-10 training and test sets, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap each resulting partition in a PyTorch `DataLoader`:

In [None]:
BATCH_SIZE = 32


def load_datasets():
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into 10 partitions to simulate the individual dataset
    partition_size = len(trainset) // NUM_CLIENTS
    lengths = [partition_size] * NUM_CLIENTS
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=BATCH_SIZE))
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloaders, valloaders, testloader


trainloaders, valloaders, testloader = load_datasets()

We now have a list of ten training sets and ten validation sets (`trainloaders` and `valloaders`) representing the data of ten different organizations. Each `trainloader`/`valloader` pair contains 4500 training examples and 500 validation examples. There's also a single `testloader` (we did not split the test set). Again, this is only necessary for building research or educational systems. Actual federated learning systems have their data naturally distributed across multiple partitions.
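The 4500/500 figures follow from integer arithmetic on the 50,000 CIFAR-10 training images. The plain-Python sketch below (helper names are hypothetical) reproduces that arithmetic from `load_datasets` and also shows that `len(trainloader)` counts batches, not examples: with `BATCH_SIZE = 32`, a 4500-example loader has 141 batches, while `len(trainloader.dataset)` is 4500.

```python
def partition_sizes(total: int, num_clients: int):
    """Per-client (train, val) sizes, following the integer arithmetic in load_datasets()."""
    partition_size = total // num_clients
    len_val = partition_size // 10  # 10 % validation set
    len_train = partition_size - len_val
    return len_train, len_val


def num_batches(num_examples: int, batch_size: int) -> int:
    """len(DataLoader) with the default drop_last=False: ceiling of examples / batch_size."""
    return -(-num_examples // batch_size)


len_train, len_val = partition_sizes(50_000, 10)  # CIFAR-10 has 50,000 training images
print(len_train, len_val)          # 4500 500
print(num_batches(len_train, 32))  # 141
```

Because `random_split` expects integer lengths that sum exactly to the dataset size, a `NUM_CLIENTS` that does not divide 50,000 evenly would make `load_datasets` fail as written.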

Next, to make one node perform poorly, we replace the labels of a single client (client `0`) with random labels. The updates from this client should be meaningless and harmful to the global model.

In [None]:
class MaliciousDataset(Dataset):
    def __init__(self, original_dataset):
        self.original_dataset = original_dataset

    def __getitem__(self, index):
        data, _ = self.original_dataset[index]  # we ignore original label
        new_label = random.randint(0, 9)  # generate a random integer between 0 and 9
        return data, new_label

    def __len__(self):
        return len(self.original_dataset)

In [None]:
original_train_dataset = trainloaders[0].dataset
original_valid_dataset = valloaders[0].dataset

malicious_train_dataset = MaliciousDataset(original_train_dataset)
malicious_valid_dataset = MaliciousDataset(original_valid_dataset)

trainloaders[0] = DataLoader(
    malicious_train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
valloaders[0] = DataLoader(
    malicious_valid_dataset, batch_size=BATCH_SIZE, shuffle=False
)

Let's take a look at the malicious dataset before we move on.

In [None]:
images, labels = next(iter(trainloaders[0]))

# Reshape and convert images to a NumPy array
# matplotlib requires images with the shape (height, width, 3)
images = images.permute(0, 2, 3, 1).numpy()
# Denormalize
images = images / 2 + 0.5

# Create a figure and a grid of subplots
fig, axs = plt.subplots(4, 8, figsize=(12, 6))

# Loop over the images and plot them
for i, ax in enumerate(axs.flat):
    ax.imshow(images[i])
    ax.set_title(CLASSES[labels[i]])
    ax.axis("off")

# Show the plot
fig.tight_layout()
plt.show()

As expected, the labels don't match the images.
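Because each random label is drawn uniformly and independently of the image, it matches the true class only about one time in ten, so client `0`'s targets carry no information about the images. `MaliciousDataset.__getitem__` also draws a fresh label on every access, so the same image receives a different label each epoch. The quick simulation below (plain Python, hypothetical helper name) checks the one-in-ten figure:

```python
import random

NUM_CLASSES = 10


def random_label_agreement(num_samples: int, seed: int = 0) -> float:
    """Fraction of samples whose random label happens to equal the true label."""
    rng = random.Random(seed)
    true_labels = [rng.randrange(NUM_CLASSES) for _ in range(num_samples)]
    # Same rule as MaliciousDataset.__getitem__: a uniform label in [0, 9]
    random_labels = [rng.randint(0, NUM_CLASSES - 1) for _ in range(num_samples)]
    matches = sum(t == r for t, r in zip(true_labels, random_labels))
    return matches / num_samples


print(random_label_agreement(100_000))  # roughly 0.1
```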

## Defining the model

In [None]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Let's continue with the usual training and test functions:

In [None]:
def train(net, trainloader, epochs: int, verbose=False):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()  # .item() avoids accumulating autograd history
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        if verbose:
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
    return epoch_loss, epoch_acc


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

## Federated Learning with Flower

Now, we'll simulate a situation where we have multiple datasets in multiple organizations and where we train a model over these organizations using federated learning.

### Updating model parameters

In federated learning, the server sends the global model parameters to the client, and the client updates the local model with the parameters received from the server. It then trains the model on the local data (which changes the model parameters locally) and sends the updated/changed model parameters back to the server (or, alternatively, it sends just the gradients back to the server, not the full model parameters).

We need two helper functions to update the local model with parameters received from the server and to get the updated model parameters from the local model: `set_parameters` and `get_parameters`. The following two functions do just that for the PyTorch model above.

The details of how this works are not really important here (feel free to consult the PyTorch documentation if you want to learn more). In essence, we use `state_dict` to access PyTorch model parameter tensors. The parameter tensors are then converted to/from a list of NumPy ndarrays (which Flower knows how to serialize/deserialize):
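One detail worth seeing is that the list of arrays carries no parameter names: `set_parameters` pairs values with `state_dict` keys purely by position. The sketch below mimics the two helpers with plain Python lists standing in for tensors (function names and the two-entry state dict are hypothetical):

```python
from collections import OrderedDict
from typing import Iterable, List


def to_list(state: "OrderedDict[str, List[float]]") -> List[List[float]]:
    # Like get_parameters(): keep only the values, in state_dict order
    return [value for _, value in state.items()]


def from_list(keys: Iterable[str], values: List[List[float]]) -> "OrderedDict[str, List[float]]":
    # Like set_parameters(): pair values with names purely by position
    return OrderedDict(zip(keys, values))


# Hypothetical two-entry "state dict" with plain lists instead of tensors
state = OrderedDict([("fc3.weight", [0.1, 0.2, 0.3]), ("fc3.bias", [0.0])])
params = to_list(state)               # [[0.1, 0.2, 0.3], [0.0]]: no names travel
restored = from_list(state.keys(), params)
print(restored == state)              # True

swapped = from_list(state.keys(), params[::-1])
print(swapped["fc3.weight"])          # [0.0]: the order is the only link
```

With real tensors, `load_state_dict` would reject this particular swap because the shapes differ, but it cannot detect two same-shaped tensors that were exchanged, so the server and all clients must use the same model definition.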

In [None]:
def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

### Implementing a Flower client

With that out of the way, let's move on to the interesting part. Federated learning systems consist of a server and multiple clients. In Flower, we create clients by implementing subclasses of `flwr.client.Client` or `flwr.client.NumPyClient`. We use `NumPyClient` in this tutorial because it is easier to implement and requires us to write less boilerplate.

To implement the Flower client, we create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`:

* `get_parameters`: Return the current local model parameters
* `fit`: Receive model parameters from the server, train the model parameters on the local data, and return the (updated) model parameters to the server
* `evaluate`: Receive model parameters from the server, evaluate the model parameters on the local data, and return the evaluation result to the server

Our clients will use the previously defined PyTorch components for model training and evaluation. Let's see a simple Flower client implementation that brings everything together:

In [None]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = train(self.net, self.trainloader, epochs=1)
        return (
            get_parameters(self.net),
            len(self.trainloader.dataset),  # number of training examples, not batches
            {"accuracy": float(accuracy)},
        )

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

Our class `FlowerClient` defines how local training/evaluation will be performed and allows Flower to call the local training/evaluation through `fit` and `evaluate`. Each instance of `FlowerClient` represents a *single client* in our federated learning system. Federated learning systems have multiple clients (otherwise, there's not much to federate), so each client will be represented by its own instance of `FlowerClient`. If we have, for example, three clients in our workload, then we'd have three instances of `FlowerClient`. Flower calls `FlowerClient.fit` on the respective instance when the server selects a particular client for training (and `FlowerClient.evaluate` for evaluation).

### Using the Virtual Client Engine

In this notebook, we want to simulate a federated learning system with 10 clients on a single machine. This means that the server and all 10 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 10 clients would mean having 10 instances of `FlowerClient` in memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning.

In addition to the regular capabilities where server and clients run on multiple machines, Flower therefore provides special simulation capabilities that create `FlowerClient` instances only when they are actually necessary for training or evaluation. To enable the Flower framework to create clients when necessary, we need to implement a function called `client_fn` that creates a `FlowerClient` instance on demand. Flower calls `client_fn` whenever it needs an instance of one particular client to call `fit` or `evaluate` (those instances are usually discarded after use, so they should not keep any local state). Clients are identified by a client ID, or `cid` for short. The `cid` can be used, for example, to load different local data partitions for different clients, as can be seen below:
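The consequence of on-demand creation can be shown without Flower at all. In the mock below, `FakeClient`, `fake_client_fn`, and the partitions are hypothetical stand-ins: every call builds a fresh instance from the partition selected by `int(cid)`, so any state stored on an earlier instance is gone.

```python
from typing import List


class FakeClient:
    """Hypothetical stand-in for FlowerClient that only stores its data partition."""

    def __init__(self, partition: List[int]) -> None:
        self.partition = partition
        self.rounds_seen = 0  # local state; lost when the instance is discarded

    def fit(self) -> None:
        self.rounds_seen += 1


# Hypothetical partitions: client i owns three copies of the value i
partitions = [[i] * 3 for i in range(10)]


def fake_client_fn(cid: str) -> FakeClient:
    # The cid is a string, so convert it before indexing the partition list
    return FakeClient(partitions[int(cid)])


first = fake_client_fn("3")
first.fit()
second = fake_client_fn("3")  # e.g., the same client selected again later
print(second.partition, second.rounds_seen, first is second)  # [3, 3, 3] 0 False
```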

In [None]:
def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""

    # Load model
    net = Net().to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]

    # Create a single Flower client representing a single organization
    return FlowerClient(net, trainloader, valloader)

### Centralized Evaluation
Finally, we define the function that the server uses for centralized evaluation on the test set.

In [None]:
def evaluate(
    server_round: int,
    parameters: fl.common.NDArrays,
    config: Dict[str, fl.common.Scalar],
    verbose: bool = False,
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
    """Centralized evaluation function"""
    net = Net().to(DEVICE)
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, testloader)
    if verbose:
        print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}

### Low Performance Detection
Now we have all the building blocks except the actual client removal. First, we need to come up with a heuristic for dropping clients.

We have to keep the following in mind:
* we might evaluate only a fraction of the clients (this matters for how we design the dropping rule),
* we might want to allow some patience before dropping a client.

We cannot simply drop the client with the worst results, because some client always performs worst. Instead, we need a way to compare a client's performance against the other clients and/or against its own historical performance.

In this example, we will drop the clients whose performance is more than `n` standard deviations below the mean, where `n` is adjustable. (Note once again that this example is developed mainly for teaching purposes; for robust protection against malicious clients, check out Byzantine-resilient methods like Krum or Bulyan.) To measure performance, we will use the centralized test set. Normally, that dataset is used to evaluate the model after the new weights have been aggregated; here, we additionally use it to evaluate every client's fit result individually, before aggregation.
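A worked example of the rule in plain Python (helper names and accuracies are hypothetical; the notebook's criterion uses NumPy, whose `np.std` defaults to the population standard deviation, matched here by `statistics.pstdev`). It also quantifies the first consideration above: by Samuelson's inequality, no single value can sit more than sqrt(n - 1) population standard deviations from the mean of n values. With `n_stds=2`, a round therefore needs at least six evaluated clients before any single client can be flagged.

```python
import math
import statistics
from typing import Dict, List


def low_performers(cid_to_acc: Dict[str, float], n_stds: float = 2.0) -> List[str]:
    """Clients whose accuracy is below mean - n_stds * std (population std, like np.std)."""
    values = list(cid_to_acc.values())
    threshold = statistics.mean(values) - n_stds * statistics.pstdev(values)
    return [cid for cid, acc in cid_to_acc.items() if acc < threshold]


# Hypothetical accuracies: nine honest clients at 0.45, client "0" near chance
accs = {str(i): 0.45 for i in range(1, 10)}
accs["0"] = 0.10
print(low_performers(accs))  # ['0']  (mean 0.415, std 0.105, threshold 0.205)


def max_stds_below_mean(n: int) -> float:
    """Samuelson's inequality: no value lies more than sqrt(n - 1) population stds from the mean."""
    return math.sqrt(n - 1)


# Round sizes at which a single client can be flagged with n_stds=2
print([n for n in range(2, 11) if max_stds_below_mean(n) > 2.0])  # [6, 7, 8, 9, 10]
```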

In order to achieve these goals we will create a custom `Strategy` and `Criterion`.

#### Custom Criterion

A `Criterion` lets us restrict which clients the client manager may sample for the strategy. We will recompute the criterion's threshold every round from the performance of the clients selected for training, and clients that fall below it will no longer be sampled.
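As far as we know, Flower's `SimpleClientManager.sample` calls `criterion.select(client)` on every available client and samples only from those that return `True`. The pure-Python mock below illustrates that contract; all class and function names are hypothetical, and a fixed blocklist stands in for `low_performance_cid`.

```python
import random
from typing import List, Optional, Set


class FakeClientProxy:
    """Hypothetical stand-in for ClientProxy; only the cid attribute is used."""

    def __init__(self, cid: str) -> None:
        self.cid = cid


class BlocklistCriterion:
    """Same select() contract as a Flower Criterion: True means 'may be sampled'."""

    def __init__(self, blocked: Set[str]) -> None:
        self.blocked = blocked

    def select(self, client: FakeClientProxy) -> bool:
        return client.cid not in self.blocked


def sample(clients: List[FakeClientProxy], num_clients: int,
           criterion: Optional[BlocklistCriterion] = None) -> List[FakeClientProxy]:
    """Simplified sampling: filter by the criterion, then draw without replacement."""
    eligible = [c for c in clients if criterion is None or criterion.select(c)]
    # random.sample raises ValueError if num_clients exceeds len(eligible)
    return random.sample(eligible, num_clients)


clients = [FakeClientProxy(str(i)) for i in range(10)]
criterion = BlocklistCriterion(blocked={"0"})
picked = sample(clients, num_clients=9, criterion=criterion)
print(sorted(c.cid for c in picked))  # ['1', '2', '3', '4', '5', '6', '7', '8', '9']
```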

In [None]:
class NoLowPerformanceCriterion(Criterion):
    def __init__(self, n_stds: int = 2):
        self.n_stds: int = n_stds
        self.low_performance_cid: List[str] = []
        # everything below threshold will be discarded
        # the threshold will be calculated for each round
        self.threshold: float = float("-inf")

    def _calculate_threshold(self, cid_to_performance):
        """Calculate the threshold used for client selection based on the
        current round's performance.

        Note:
            This function needs to be called before select.

            Clients whose performance falls below
            mean - n_stds * std are treated as low performers.

        """
        performances = np.array(list(cid_to_performance.values()))
        perf_mean = np.mean(performances)
        perf_std = np.std(performances)
        self.threshold = perf_mean - self.n_stds * perf_std
        log(DEBUG, "Mean performance: %s", perf_mean)
        log(DEBUG, "Std performance: %s", perf_std)
        log(DEBUG, "cid_to_performance: %s", cid_to_performance)
        log(DEBUG, "Threshold: %s", self.threshold)

    def exclude_low_performers(self, cid_to_performance) -> List[str]:
        """
        Check and exclude any client from the current round that underperforms.

        First calculates the threshold, then checks every client against it.
        It is intended to be called after the fit results of the current round
        have been evaluated. It adds the low performers' cids to
        self.low_performance_cid so they won't be selected anymore.
        """
        self._calculate_threshold(cid_to_performance)
        new_low_performers = []
        for cid, performance in cid_to_performance.items():
            if performance < self.threshold:
                new_low_performers.append(cid)
        self.low_performance_cid.extend(new_low_performers)
        log(
            DEBUG,
            "Total number of low performance clients is %s",
            len(self.low_performance_cid),
        )
        log(
            DEBUG,
            "This round %s new low performance clients were removed.",
            len(new_low_performers),
        )
        return new_low_performers

    def select(self, client: ClientProxy) -> bool:
        """
        Decide whether the client can be sampled.

        The client can be sampled only if it was not previously identified as a
        low performer (i.e., its cid is not in self.low_performance_cid).
        """
        return client.cid not in self.low_performance_cid

#### Custom Strategy

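This custom strategy extends `FedAvg` and uses the `Criterion` defined above. `aggregate_fit` evaluates each client's update on the centralized test set and leaves low performers out of the aggregation, while `configure_fit` and `configure_evaluate` pass the criterion to the client manager so that flagged clients are no longer sampled.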
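The new part of `aggregate_fit` boils down to two steps: drop the updates of flagged clients, then take the example-weighted average of the rest. A minimal sketch with flat Python lists in place of NumPy arrays (helper and variable names are hypothetical; Flower's `aggregate` works layer by layer on arrays) shows how far a single bad update would otherwise pull the average:

```python
from typing import Dict, List, Sequence, Tuple


def weighted_average_params(results: Sequence[Tuple[List[float], int]]) -> List[float]:
    """FedAvg-style aggregation over flat parameter lists, weighted by num_examples."""
    total = sum(num_examples for _, num_examples in results)
    num_params = len(results[0][0])
    return [
        sum(params[i] * num_examples for params, num_examples in results) / total
        for i in range(num_params)
    ]


# Hypothetical fit results: cid -> (flattened parameters, num_examples)
fit_results: Dict[str, Tuple[List[float], int]] = {
    "0": ([10.0, -10.0], 4500),  # update from the client with random labels
    "1": ([1.0, 2.0], 4500),
    "2": ([3.0, 4.0], 4500),
}
new_low_performers = ["0"]

kept = [res for cid, res in fit_results.items() if cid not in new_low_performers]
print(weighted_average_params(kept))                        # [2.0, 3.0]
print(weighted_average_params(list(fit_results.values())))  # about [4.67, -1.33]
```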

In [None]:
class FedAvgWithDrop(fl.server.strategy.FedAvg):
    """FedAvg that drops the clients based on performance on centralized evaluation."""

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        no_low_perfromance_criterion: Optional[Criterion] = None,
        criterion_performance_metric: Optional[str] = None,
    ) -> None:
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.no_low_perfromance_criterion = no_low_perfromance_criterion
        self.criterion_performance_metric = criterion_performance_metric

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}

        # THIS IS NEW: START
        # Drop the clients based on performance
        cid_to_params = {
            client_proxy.cid: fit_res.parameters for client_proxy, fit_res in results
        }

        cid_to_performance = {}
        for cid, params in cid_to_params.items():
            loss, metrics = self.evaluate(server_round, params)
            cid_to_performance[cid] = metrics[self.criterion_performance_metric]
        new_low_performers = self.no_low_perfromance_criterion.exclude_low_performers(
            cid_to_performance
        )

        # Exclude the low_perfomers' parameters from aggregation
        results = [res for res in results if res[0].cid not in new_low_performers]
        # THIS IS NEW: END

        # Convert results
        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""
        config = {}
        if self.on_fit_config_fn is not None:
            # Custom fit config function provided
            config = self.on_fit_config_fn(server_round)
        fit_ins = FitIns(parameters, config)

        # Sample clients
        # THIS IS MODIFIED: START
        # We want to treat the low performance clients as if they were removed, so we
        # decrease the total number of available clients by the number of low
        # performance clients. Without this, fraction_fit=1.0 would request more
        # clients than the criterion lets through. In this case, fraction_fit is
        # applied to the non-malicious clients only.
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
            - len(self.no_low_perfromance_criterion.low_performance_cid)
        )
        clients = client_manager.sample(
            num_clients=sample_size,
            min_num_clients=min_num_clients,
            criterion=self.no_low_perfromance_criterion,
        )
        # THIS IS MODIFIED: END

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        # Do not configure federated evaluation if fraction eval is 0.
        if self.fraction_evaluate == 0.0:
            return []

        # Parameters and config
        config = {}
        if self.on_evaluate_config_fn is not None:
            # Custom evaluation config function provided
            config = self.on_evaluate_config_fn(server_round)
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        # THIS IS MODIFIED: START
        # Discount the excluded clients and pass the criterion to the client manager
        # so that low performers are not sampled for evaluation either
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
            - len(self.no_low_perfromance_criterion.low_performance_cid)
        )
        clients = client_manager.sample(
            num_clients=sample_size,
            min_num_clients=min_num_clients,
            criterion=self.no_low_perfromance_criterion,
        )
        # THIS IS MODIFIED: END

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]
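
Why subtract the excluded clients before computing the sample size? The sketch below is a hypothetical re-implementation of what we understand `FedAvg.num_fit_clients` to do (not Flower's API): take the `fraction_fit` share of the available clients, floored, but never fewer than `min_fit_clients`. With `fraction_fit=1.0` and one excluded client, the unadjusted request is 10 while only 9 clients pass the criterion; in the Flower versions we are aware of, `SimpleClientManager.sample` then logs that sampling failed and returns no clients.

```python
from typing import Tuple


def num_fit_clients(num_available: int, fraction_fit: float = 1.0,
                    min_fit_clients: int = 2, min_available_clients: int = 2) -> Tuple[int, int]:
    """Hypothetical re-implementation of FedAvg's sample-size rule."""
    num_clients = int(num_available * fraction_fit)  # floor of the requested fraction
    return max(num_clients, min_fit_clients), min_available_clients


num_available = 10  # clients registered with the client manager
num_excluded = 1    # clients already flagged as low performers
eligible = num_available - num_excluded

unadjusted, _ = num_fit_clients(num_available)
adjusted, _ = num_fit_clients(eligible)
print(unadjusted, adjusted, eligible)               # 10 9 9
print(unadjusted <= eligible, adjusted <= eligible)  # False True

# With fraction_fit=0.5 the adjustment changes the sample size from 5 to 4
print(num_fit_clients(num_available, 0.5)[0], num_fit_clients(eligible, 0.5)[0])  # 5 4
```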

In [None]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}

Now we can use the building blocks we created.

In [None]:
no_low_perfromance_criterion = NoLowPerformanceCriterion(n_stds=2)

# Create the FedAvgWithDrop strategy (FedAvg plus performance-based dropping)
strategy = FedAvgWithDrop(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_fit_clients=2,
    min_evaluate_clients=2,
    min_available_clients=2,
    fit_metrics_aggregation_fn=weighted_average,
    evaluate_metrics_aggregation_fn=weighted_average,
    evaluate_fn=evaluate,
    no_low_perfromance_criterion=no_low_perfromance_criterion,
    criterion_performance_metric="accuracy",
)

### Start Simulation

In [None]:
# Start simulation

history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=2),
    strategy=strategy,
    client_resources={"num_cpus": 2},
)

In [None]:
history

## Final remarks

Congratulations, you just created and trained a federated learning system that drops clients based on their performance! You are now familiar with a custom `Strategy` and `Criterion` that control client sampling. The same approach can be used with other machine learning frameworks (not just PyTorch) and tasks (not just CIFAR-10 image classification), for example NLP with Hugging Face Transformers or speech with SpeechBrain.


## Next steps

Make sure to join the Flower community on Slack: [Join Slack](https://flower.dev/join-slack/)

There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!