<a href="https://colab.research.google.com/github/TyRoBr/FreeCodeCamp-Pandas-Real-Life-Example/blob/master/Fed_BioMed_practical_session_fl_algorithms_ai4health_exercise.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# EXERCISE 2: PLAY AROUND WITH FEDERATED LEARNING

Now that you have understood the broad concept of federated learning from the previous tutorial, the goal of this tutorial is to introduce various specific federated learning algorithms and their associated parameters. Additionally, we will explore the limitations of these algorithms in the context of data heterogeneity.

# 1 - Environment set-up

First, we install some additional packages that are required by the tutorial but not included in Fed-BioMed.

In [None]:
%pip install -q tqdm colab-xterm "jedi>=0.16"

## 1.1 - Install Fed-BioMed
First, download the wheel file.

In [None]:
!wget 'https://docs.google.com/uc?export=download&id=1R8P5GcAsNQZDPy2ucmixkzd4huFpKvCR' -O fedbiomed-6.1.0-py3-none-any.whl

Install the wheel file. This will take some time.

There might be some errors at the end, which can be safely ignored.

If prompted, restart the session.


In [None]:
%pip install ./fedbiomed-6.1.0-py3-none-any.whl

## 1.2 - Import other packages

Import other python packages useful for the rest of the tutorial.

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

## 1.3 - Define useful functions

In [None]:
def plot_perf(exp, num_rounds):
  """Plot per-node training loss and global-model test accuracy over the rounds."""
  metric_store = exp.monitor()._metric_store
  nodes = list(metric_store.keys())

  # Mean test accuracy of the global model, per node and per round
  global_test_accuracies = {}
  for node in nodes:
    accuracy_by_round = metric_store[node]['testing_global_updates']['ACCURACY']
    global_test_accuracies[node] = [
      np.mean(accuracy_by_round[rnd]['values']) for rnd in accuracy_by_round.keys()
    ]

  # Mean training loss, per node and per round
  train_losses = {}
  for node in nodes:
    loss_by_round = metric_store[node]['training']['Loss']
    train_losses[node] = [
      np.mean(loss_by_round[rnd]['values']) for rnd in loss_by_round.keys()
    ]

  colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']

  # Figure 1: training loss per node
  plt.figure()
  for i, node in enumerate(nodes):
    color = colors[i % len(colors)]
    plt.plot(range(1, num_rounds + 1), train_losses[node], '-', label=f'{node} training loss', color=color)
  _ = plt.xlabel('Round')
  _ = plt.ylabel('Loss')
  plt.legend()

  # Figure 2: accuracy of the global model on each node's test split
  plt.figure()
  for i, node in enumerate(nodes):
    color = colors[i % len(colors)]
    plt.plot(range(num_rounds + 1), global_test_accuracies[node], '-', label=f'{node} global accuracy', color=color)
  _ = plt.xlabel('Round')
  _ = plt.ylabel('Accuracy')
  plt.legend()

  plt.show()

# 2 - Tutorial use-case and federated learning set-up

In this tutorial, we will use a set-up with 2 clients. We will first create the two components.

## 2.1 - Create node components



In [None]:
%%shell
for i in {1..2}; do
    fedbiomed component create -c node -p /content/fedbiomed_components/client_${i}
done

## 2.2 - Download the data

As a use-case, we will use the [MedNIST dataset](https://medmnist.com/). The MedNIST dataset is a standardized collection of medical images designed for benchmarking and training machine learning models. It includes various types of medical images such as X-rays, CT scans, and MRI images, categorized into different classes based on anatomical regions or imaging modalities. It serves as a medical equivalent to the MNIST dataset, facilitating research and development in medical image analysis.

The goal of the tutorial is to create a model for modality classification.

First, download the different datasets.

In [None]:
%%shell
pip install gdown
mkdir -p /content/fedbiomed_components/notebooks/data
cd /content/fedbiomed_components/notebooks/data
gdown https://drive.google.com/uc?id=1SldskvQPiLwAmadgd2dBK_so4hYBYwCV -O /content/fedbiomed_components/notebooks/data/mednist_datasets.zip
unzip /content/fedbiomed_components/notebooks/data/mednist_datasets.zip -d /content/fedbiomed_components/notebooks/data

## 2.3 - Explore the datasets

For each client, there are three datasets downloaded:
- the 'no_skew' dataset, for which the number of images and the labels are homogeneous across the nodes,
- the 'quantity_skew' dataset, for which the number of images is heterogeneous across the nodes,
- the 'label_skew' dataset, for which the distribution of labels is heterogeneous across the nodes.

Let's investigate these differences.

### 2.3.1 - Explore the no skew dataset


In [None]:
# Define the base directory path
base_dir = "/content/fedbiomed_components/notebooks/data/mednist_datasets"

# Define clients
clients = ["client_1", "client_2"]

# Initialize a dictionary to store image counts
data = {client: {} for client in clients}

# Initialize a set to store unique modalities
modalities = set()

# Iterate over each client directory to collect modalities
for client in clients:
    client_path = os.path.join(base_dir, client, "no_skew", "MedNIST")
    if os.path.exists(client_path):
        # Retrieve modalities for the client
        detected_modalities = [d for d in os.listdir(client_path) if os.path.isdir(os.path.join(client_path, d))]
        modalities.update(detected_modalities)

# Convert the set of modalities to a sorted list
modalities = sorted(modalities)

# Prepare data storage with dynamic modalities
for client in clients:
    for modality in modalities:
        data[client][modality] = 0

# Iterate over each client directory to count images
for client in clients:
    client_path = os.path.join(base_dir, client, "no_skew", "MedNIST")
    if os.path.exists(client_path):
        # Iterate over each modality folder
        for modality in modalities:
            modality_path = os.path.join(client_path, modality)
            if os.path.exists(modality_path):
                # Count the number of JPEG images in the modality folder
                image_count = len([f for f in os.listdir(modality_path) if f.endswith('.jpeg')])
                data[client][modality] = image_count

# Plotting
client_indices = np.arange(len(clients))
bar_width = 0.1

fig, ax = plt.subplots(figsize=(10, 6))

# Plot bars for each modality
for i, modality in enumerate(modalities):
    counts = [data[client][modality] for client in clients]
    ax.bar(client_indices + i * bar_width, counts, width=bar_width, label=modality)

# Add labels, title, and legend
ax.set_xlabel('Client')
ax.set_ylabel('Number of Images')
ax.set_title('Distribution of Images by Client and Modality for no skew dataset')
ax.set_xticks(client_indices + bar_width * (len(modalities) - 1) / 2)
ax.set_xticklabels(clients)
ax.legend()

plt.show()

Look at one example image for each modality:

In [None]:
# Define the base directory path for client 1
base_dir = "/content/fedbiomed_components/notebooks/data/mednist_datasets"
client = "client_1"
client_path = os.path.join(base_dir, client, "no_skew", "MedNIST")

# Check if the client path exists
if os.path.exists(client_path):
    # Retrieve modalities for client 1
    modalities = sorted([d for d in os.listdir(client_path) if os.path.isdir(os.path.join(client_path, d))])

    # Plot an example image from each modality
    for modality in modalities:
        modality_path = os.path.join(client_path, modality)
        # List all JPEG images in the modality directory
        image_files = [f for f in os.listdir(modality_path) if f.endswith('.jpeg')]

        if image_files:
            # Select the first image file
            example_image_path = os.path.join(modality_path, image_files[0])
            # Load and display the image
            img = mpimg.imread(example_image_path)
            plt.figure()
            plt.imshow(img, cmap='gray')
            plt.title(f"Example Image from {client} - {modality}")
            plt.axis('off')
            plt.show()
else:
    print(f"The specified path for {client} does not exist.")

### 2.3.2 - Explore the quantity skew dataset



In [None]:
# Define the base directory path
base_dir = "/content/fedbiomed_components/notebooks/data/mednist_datasets"

# Define clients
clients = ["client_1", "client_2"]

# Initialize a dictionary to store image counts
data = {client: {} for client in clients}

# Initialize a set to store unique modalities
modalities = set()

# Iterate over each client directory to collect modalities
for client in clients:
    client_path = os.path.join(base_dir, client, "quantity_skew", "MedNIST")
    if os.path.exists(client_path):
        # Retrieve modalities for the client
        detected_modalities = [d for d in os.listdir(client_path) if os.path.isdir(os.path.join(client_path, d))]
        modalities.update(detected_modalities)

# Convert the set of modalities to a sorted list
modalities = sorted(modalities)

# Prepare data storage with dynamic modalities
for client in clients:
    for modality in modalities:
        data[client][modality] = 0

# Iterate over each client directory to count images
for client in clients:
    client_path = os.path.join(base_dir, client, "quantity_skew", "MedNIST")
    if os.path.exists(client_path):
        # Iterate over each modality folder
        for modality in modalities:
            modality_path = os.path.join(client_path, modality)
            if os.path.exists(modality_path):
                # Count the number of JPEG images in the modality folder
                image_count = len([f for f in os.listdir(modality_path) if f.endswith('.jpeg')])
                data[client][modality] = image_count

# Plotting
client_indices = np.arange(len(clients))
bar_width = 0.1

fig, ax = plt.subplots(figsize=(10, 6))

# Plot bars for each modality
for i, modality in enumerate(modalities):
    counts = [data[client][modality] for client in clients]
    ax.bar(client_indices + i * bar_width, counts, width=bar_width, label=modality)

# Add labels, title, and legend
ax.set_xlabel('Client')
ax.set_ylabel('Number of Images')
ax.set_title('Distribution of Images by Client and Modality for quantity skew dataset')
ax.set_xticks(client_indices + bar_width * (len(modalities) - 1) / 2)
ax.set_xticklabels(clients)
ax.legend()

plt.show()

As you can observe, the number of images in both clients is heterogeneous, but the distribution of labels is homogeneous.



### 2.3.3 - Explore the label skew dataset

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

# Define the base directory path
base_dir = "/content/fedbiomed_components/notebooks/data/mednist_datasets"

# Define clients
clients = ["client_1", "client_2"]

# Initialize a dictionary to store image counts
data = {client: {} for client in clients}

# Initialize a set to store unique modalities
modalities = set()

# Iterate over each client directory to collect modalities
for client in clients:
    client_path = os.path.join(base_dir, client, "label_skew", "MedNIST")
    if os.path.exists(client_path):
        # Retrieve modalities for the client
        detected_modalities = [d for d in os.listdir(client_path) if os.path.isdir(os.path.join(client_path, d))]
        modalities.update(detected_modalities)

# Convert the set of modalities to a sorted list
modalities = sorted(modalities)

# Prepare data storage with dynamic modalities
for client in clients:
    for modality in modalities:
        data[client][modality] = 0

# Iterate over each client directory to count images
for client in clients:
    client_path = os.path.join(base_dir, client, "label_skew", "MedNIST")
    if os.path.exists(client_path):
        # Iterate over each modality folder
        for modality in modalities:
            modality_path = os.path.join(client_path, modality)
            if os.path.exists(modality_path):
                # Count the number of JPEG images in the modality folder
                image_count = len([f for f in os.listdir(modality_path) if f.endswith('.jpeg')])
                data[client][modality] = image_count

# Plotting
client_indices = np.arange(len(clients))
bar_width = 0.1

fig, ax = plt.subplots(figsize=(10, 6))

# Plot bars for each modality
for i, modality in enumerate(modalities):
    counts = [data[client][modality] for client in clients]
    ax.bar(client_indices + i * bar_width, counts, width=bar_width, label=modality)

# Add labels, title, and legend
ax.set_xlabel('Client')
ax.set_ylabel('Number of Images')
ax.set_title('Distribution of Images by Client and Modality for label skew dataset')
ax.set_xticks(client_indices + bar_width * (len(modalities) - 1) / 2)
ax.set_xticklabels(clients)
ax.legend()

plt.show()

As you can observe, the number of images in both clients is homogeneous, but the distribution of labels is heterogeneous.

## 2.4 - Integrate the data in Fed-BioMed nodes and start nodes

For each client, create and add three Fed-BioMed dataset description files, one for each type of MedNIST dataset (no skew, quantity skew, label skew).

In [None]:
%%shell
for i in {1..2}; do
    for skew_type in no_skew quantity_skew label_skew; do
        mkdir -p "/content/fedbiomed_components/client_${i}/data/${skew_type}"
        cd "/content/fedbiomed_components/client_${i}/data/${skew_type}"
        tee dataset.json << END
{
"data_type": "mednist",
"path": "/content/fedbiomed_components/notebooks/data/mednist_datasets/client_${i}/${skew_type}",
"description": "MedNIST dataset with ${skew_type}",
"name": "MedNIST ${skew_type} client ${i}",
"tags": "mednist_${skew_type}"
}
END
    done
done

Use the command line interface (CLI) to add a dataset to the Fed-BioMed database.

In [None]:
%%shell
for i in {1..2}; do
    for skew_type in no_skew quantity_skew label_skew; do
        fedbiomed node --path /content/fedbiomed_components/client_${i} dataset add -f /content/fedbiomed_components/client_${i}/data/${skew_type}/dataset.json
    done
done

Now, let's start the Fed-BioMed nodes, as we already did in the previous tutorial.

First client:

- Activate the Fed-BioMed software and leave it in standby, waiting to receive requests for training from the orchestrator.
- Execute the cell below and wait for the terminal to appear.
- Copy/paste the following line in the terminal, then hit `Enter`:

```shell
fedbiomed node --path /content/fedbiomed_components/client_1 start
```

In [None]:
%load_ext colabxterm

In [None]:
%xterm

When the node starts, you should see the following message appear multiple times:

```shell
fedbiomed DEBUG - Researcher server is not available, will retry connect in 2 seconds
```

This is normal: the hospital is ready to work, but we have not given any workload yet.

Second client:

- Activate the Fed-BioMed software and leave it in standby, waiting to receive requests for training from the orchestrator.
- Execute the cell below and wait for the terminal to appear.
- Copy/paste the following line in the terminal, then hit `Enter`:

```shell
fedbiomed node --path /content/fedbiomed_components/client_2 start
```

In [None]:
%xterm

When the node starts, you should see the following message appear multiple times:

```shell
fedbiomed DEBUG - Researcher server is not available, will retry connect in 2 seconds
```

This is normal: the hospital is ready to work, but we have not given any workload yet.

# 3 - No skew datasets

In this section of the tutorial, we will investigate the behaviour of a classical algorithm called Federated Averaging, depending on different parameters of the training.

Federated Averaging (also known as [FedAvg](https://arxiv.org/pdf/1602.05629)) is a popular algorithm in federated learning. Like most federated learning algorithms, it consists of two main phases: local node optimization and global averaging.

### Initialization

The process begins by initializing a global model with random or pre-trained weights. This global model is denoted as $w_0$.

### Local Node Optimization

In each communication round $t$, a subset of $K$ clients is selected to participate in the training process. Each selected client $k$ performs the following steps:

- **Download the Global Model**: Each client $k$ downloads the current global model weights $w_t$.

- **Local Training**: Each client updates the model locally using its own dataset. This involves performing stochastic gradient descent (SGD) or any other optimization technique on the local data for a certain number of epochs or iterations. The local objective for client $k$ is to minimize the loss function $ F_k(w) $, defined as:

  $$
  F_k(w) = \frac{1}{n_k} \sum_{i=1}^{n_k} f_k(w; x_i, y_i)
  $$

  where $ n_k $ is the number of data samples on client $ k $, and $ f_k(w; x_i, y_i) $ is the loss function for the data sample $ (x_i, y_i) $.

- **Compute Local Update**: After local training, each client computes the update to the global model. The updated local weights are denoted as $ w_{t+1}^k $.

### Global Averaging

After the local training phase, the updates from the participating clients are aggregated to update the global model:

- **Upload Local Updates**: Each client $ k $ sends its local model weights $ w_{t+1}^k $ back to the central server.

- **Aggregate Updates**: The central server aggregates these local updates to form a new global model. The global model update is typically a weighted average of the local models, where the weights are proportional to the number of data samples on each client. The global model update formula is:

  $$
  w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} w_{t+1}^k
  $$

  where $ K $ is the number of participating clients, $ n_k $ is the number of data samples on client $ k $, and $ n = \sum_{k=1}^{K} n_k $ is the total number of data samples across the participating clients.
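
The weighted average above can be sketched in a few lines of NumPy. This is a minimal illustration of the formula only, not Fed-BioMed's `FedAverage` implementation; the helper name `fedavg_aggregate` and the toy weights are hypothetical.

```python
import numpy as np

def fedavg_aggregate(local_weights, sample_counts):
    """Weighted average of per-client parameters (hypothetical helper).

    local_weights: list of dicts mapping layer name -> np.ndarray, one dict per client
    sample_counts: list of n_k, the number of samples held by each client
    """
    total = float(sum(sample_counts))
    aggregated = {}
    for layer in local_weights[0]:
        # sum_k (n_k / n) * w_k for this layer
        aggregated[layer] = sum(
            (n_k / total) * weights[layer]
            for weights, n_k in zip(local_weights, sample_counts)
        )
    return aggregated

# Toy example: two clients, one layer, client 2 holds three times more data
client_1 = {"layer": np.array([1.0, 1.0])}
client_2 = {"layer": np.array([3.0, 3.0])}
global_weights = fedavg_aggregate([client_1, client_2], sample_counts=[1, 3])
print(global_weights["layer"])
```

Because client 2 holds three quarters of the samples, the aggregated values land closer to its weights than to those of client 1.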

### Repeat

The process of local training and global averaging is repeated for a predefined number of communication rounds or until convergence is achieved.

## 3.1 - Experiment 1: FedAverage with small number of rounds and large number of local updates


### 3.1.1 - Define a Training Plan

As the goal of this tutorial is to use federated learning for a classification task of different image modalities in the MedNIST dataset, we will define a simple convolutional neural network using PyTorch. This model is defined in the training plan, as explained in the previous tutorial.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # `F` is used in the model's forward pass
from fedbiomed.common.training_plans import TorchTrainingPlan

# Here we define the model to be used.

class FedAvgTrainingPlan(TorchTrainingPlan):
    class MyModel(torch.nn.Module):
        """definition of a PyTorch model, with its __init__ and forward methods"""
        def __init__(self, model_args: dict):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(16 * 16 * 16, model_args.get('num_classes', 6))

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            return F.log_softmax(x, dim=1)

    def init_dependencies(self):
        deps = ["from torchvision import datasets, transforms",
                "from torchvision.transforms import ToTensor",
                'from torch.optim import AdamW, SGD',
                ]

        return deps

    def init_model(self, model_args: dict):
        """Defines your model here"""
        return self.MyModel(model_args)

    def init_optimizer(self, optimizer_args):
        """Defines your optimizer here"""
        optimizer = AdamW(self.model().parameters(), lr=optimizer_args.get('lr', 0.001))
        return optimizer

    def training_data(self):
        """Defines data handling/parsing here"""
        # Custom torch Dataloader for MedNIST data

        preprocess = transforms.Compose([transforms.ToTensor()])
        train_data = datasets.ImageFolder(self.dataset_path,transform = preprocess)

        return DataManager(dataset=train_data, shuffle=True)

    def training_step(self, data, target):
        """Defines cost function and how to compute loss"""
        output = self.model().forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss
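
The input size of `fc1` (`16 * 16 * 16`) follows from the image size. As a sanity check, the arithmetic below traces the spatial size through the two convolution + pooling stages. It assumes 64×64 input images, which is an assumption about the MedNIST files used here (change `image_size` otherwise); the helper `conv_output_size` is introduced only for this illustration.

```python
def conv_output_size(size, kernel_size, padding, stride=1):
    """Spatial output size of a 2D convolution (standard formula)."""
    return (size + 2 * padding - kernel_size) // stride + 1

image_size = 64  # assumption: images are 64x64 pixels
size = image_size
for _ in range(2):  # conv1 + pool, then conv2 + pool
    size = conv_output_size(size, kernel_size=3, padding=1)  # 3x3 conv with padding=1 keeps the size
    size = size // 2  # MaxPool2d(2, 2) halves the size

out_channels = 16  # output channels of conv2
flattened_features = out_channels * size * size
print(size, flattened_features)
```

If the images had a different resolution, the first argument of `nn.Linear` would have to change accordingly.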

### 3.1.2 - Define model parameters

Here you can define `model_args`, a dictionary that contains parameters and hyperparameters for the model definition. Here, we specify that the number of classes is 6.

In [None]:
model_args = {
    'num_classes': 6,
}

### 3.1.3 - Define training arguments

For the first experiment, we will use the FedAverage algorithm as discussed previously, with a large number of local updates. This parameter is defined in the training arguments.

Define
- batch size 8
- learning rate 0.001
- 50 local updates per round

In [None]:
training_args1 = {
    ??
    'log_interval': 1,
    'test_ratio' : 0.05,
    'test_on_global_updates': True,
}

### 3.1.4 - Create a FL experiment

When creating the FL experiment, we provide the name of the desired aggregator algorithm: FedAverage, and the number of rounds (set it to 5).

After initializing the experiment, check that both nodes have been selected for training.

In [None]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ??

exp1 = Experiment(
    ???
                )

### 3.1.5 - Run the FL training experiment

Launch the experiment.

In [None]:
exp1.run()

### 3.1.6 - Plot the losses

Plot the training losses and the accuracy metric. What do you observe?

In [None]:
plot_perf(exp1, num_rounds_1)

### 3.1.7 - Check FedAverage algorithm

Check that the federated averaging weighted mean formula presented previously is indeed computed.

In [None]:
aggregated_model = exp1.training_plan().model()

In [None]:
replies = exp1.training_replies()
last_round_replies = replies[num_rounds_1-1]
node_model_weights = list()
sample_sizes = list()

example_layer_name = 'conv1.weight'

??

## 3.2 - Experiment 2: FedAverage with large number of rounds and small number of local updates

We will restart an experiment with the Federated Averaging algorithm, this time with a small number of local updates and a large number of global updates.

Keep the same arguments as the previous experiment, but set
- 5 local updates per round

### 3.2.1 - Define training arguments

In [None]:
??

### 3.2.2 - Create a FL experiment

This time, set the number of rounds to 50.

In [None]:
exp2 = ??

### 3.2.3 - Run the FL training experiment

In [None]:
exp2.run()

### 3.2.4 - Plot the losses

In [None]:
plot_perf(exp2, num_rounds_2)

There's a trade-off between communication cost and model convergence/stability.

In case of homogeneous datasets:
- More local updates (less communication) = faster local progress but risk of divergence.
- More global updates (more communication) = slower local training but better global consistency.
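
To make the comparison concrete, the short calculation below counts the optimizer steps and model exchanges of the two experiments above. Counting one download and one upload per round is a simplification made for this illustration.

```python
# (number of rounds, local updates per round) for the two experiments above
experiments = {"experiment 1": (5, 50), "experiment 2": (50, 5)}

for name, (num_rounds, num_updates) in experiments.items():
    total_updates = num_rounds * num_updates  # optimizer steps performed by each node
    model_exchanges = 2 * num_rounds          # simplification: one download + one upload per round
    print(f"{name}: {total_updates} local steps, {model_exchanges} model exchanges per node")

# Total number of local steps per node, keyed by experiment
budgets = {name: rounds * updates for name, (rounds, updates) in experiments.items()}
```

Both experiments perform the same number of local steps per node; they differ only in how often the local models are synchronized.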


# 4 - Quantity skew datasets

In federated learning, a "quantity skew dataset" refers to a scenario where the amount of data varies significantly across different clients or nodes. This means that some clients have a large volume of data, while others have much less. Such skew can impact the training process, as clients with more data may have a larger influence on the global model, potentially leading to biased or uneven model performance. Addressing quantity skew is important for ensuring that the federated learning model generalizes well across all clients.

In this section of the tutorial, we will investigate the behaviour of several federated learning optimization algorithms with quantity skew datasets.
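
As a rough illustration of why quantity skew matters, the sketch below computes the FedAvg aggregation weights $n_k / n$ and the fraction of each local dataset visited in one round. The batch size and number of updates match Experiment 3 below; the client sample counts are hypothetical and are not the actual sizes of the tutorial datasets.

```python
batch_size = 8    # as in the training arguments of Experiment 3
num_updates = 25  # local updates per round, as in Experiment 3
client_sizes = {"client_1": 5000, "client_2": 500}  # hypothetical sample counts

samples_per_round = batch_size * num_updates
total = sum(client_sizes.values())

# FedAvg weight n_k / n of each client
aggregation_weights = {client: n_k / total for client, n_k in client_sizes.items()}

for client, n_k in client_sizes.items():
    fraction_seen = samples_per_round / n_k  # share of the local dataset visited in one round
    print(f"{client}: weight={aggregation_weights[client]:.2f}, "
          f"fraction of local data seen per round={fraction_seen:.2f}")
```

With these made-up sizes, the larger client dominates the average, while the smaller client cycles through its data ten times faster.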

## 4.1 - Experiment 3: FedAverage

### 4.1.1 - Define training arguments

For the first experiment, we will use the FedAverage algorithm as discussed previously.

In [None]:
training_args3 = {
    'loader_args': { 'batch_size': 8, },
    'optimizer_args': {
        'lr': 1e-3
    },
    'num_updates': 25,
    'dry_run': False,
    'log_interval': 1,
    'test_ratio' : 0.05,
    'test_on_global_updates': True,
}

### 4.1.2 - Create a FL experiment

When creating the FL experiment, we provide the name of the desired aggregator algorithm: FedAverage.

After initializing the experiment, check that both nodes have been selected for training.

In [None]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['mednist_quantity_skew']
num_rounds_3 = 50

exp3 = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=FedAvgTrainingPlan,
                 training_args=training_args3,
                 round_limit=num_rounds_3,
                 aggregator=FedAverage(),
                 tensorboard=True
                )

### 4.1.3 - Run the FL training experiment

Launch the experiment.

In [None]:
exp3.run()

### 4.1.4 - Plot the losses

Plot the training losses and the accuracy metric. What do you observe?

In [None]:
plot_perf(exp3, num_rounds_3)

In this example, we can observe a convergence problem for the node with fewer images, as well as uneven model performance: the node with fewer images performs worse.

Due to the heterogeneity of data across clients, local updates drift away from the global model, leading to slower convergence and reduced performance.

## 4.2 - Experiment 4: FedProx with large mu

In the previous experiment, we saw the limitations of the FedAverage algorithm on heterogeneous data, in terms of both performance and convergence.

Regularization techniques can improve the convergence and performance of training by smoothing the impact of global updates.

In this experiment, we will use a federated learning algorithm with regularization, called [FedProx](https://arxiv.org/pdf/1812.06127).

Federated Proximal (FedProx) is an advanced federated learning algorithm that builds upon Federated Averaging (FedAvg) by incorporating a proximal term. This term helps to stabilize training and improve convergence in heterogeneous settings.

The initialization and global averaging steps are identical to those of the FedAvg algorithm (see the explanation of FedAvg in the previous experiments). Only the local node optimization part is modified, as follows:

### Local Node Optimization

In each communication round $ t $, a subset of $ K $ clients is selected to participate in the training process. Each selected client $ k $ performs the following steps:

- **Download the Global Model**: Each client $ k $ downloads the current global model weights $ w_t $.

- **Local Training with Proximal Term**: Each client updates the model locally using its own dataset, but with an added proximal term to constrain the local updates. The local objective for client $ k $ is to minimize the following loss function $F_k(w) $:

  $$
  F_k(w) = \frac{1}{n_k} \sum_{i=1}^{n_k} f_k(w; x_i, y_i) + \frac{\mu}{2} \| w - w_t\|^2
  $$

  where:
  - $ n_k $ is the number of data samples on client $ k $.
  - $ f_k(w; x_i, y_i) $ is the loss function for the data sample $ (x_i, y_i) $.
  - $ \mu $ is a hyperparameter that controls the strength of the proximal term.
  - $ \|w - w_t\|^2 $ is the squared Euclidean distance between the local model weights $ w $ and the global model weights $ w_t $.

- **Compute Local Update**: After local training, each client computes the update to the global model. The updated local weights are denoted as $ w_{t+1}^k $.

The proximal term $ \frac{\mu}{2} \|w - w_t\|^2 $ helps to prevent the local models from diverging too far from the global model, thus stabilizing the training process and improving convergence.

- **Handling Data Heterogeneity**: FedProx is particularly effective in handling non-IID data across clients, as the proximal term helps to mitigate the impact of data heterogeneity.

- **Hyperparameter $ \mu $**: The hyperparameter $ \mu $ controls the strength of the proximal term. A higher value of $ \mu $ enforces the local models to stay closer to the global model, while a lower value allows more flexibility in local updates. Selecting an appropriate value for $ \mu $ is crucial. Too high a value can restrict the local models too much, leading to slow convergence, while too low a value may not effectively mitigate client drift.
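
The effect of $ \mu $ can be seen on a one-dimensional toy problem. The sketch below runs plain gradient descent on a local loss plus the proximal term; it is not Fed-BioMed's implementation, and the helper names and the quadratic toy loss are assumptions made for illustration.

```python
def fedprox_local_update(w_global, grad_fn, mu, lr=0.1, num_updates=200):
    """Gradient descent on f_k(w) + (mu / 2) * (w - w_global)**2 (1-D toy, hypothetical helper)."""
    w = w_global
    for _ in range(num_updates):
        # gradient of the local loss plus gradient of the proximal term
        grad = grad_fn(w) + mu * (w - w_global)
        w -= lr * grad
    return w

def client_grad(w):
    """Gradient of the toy local loss f_k(w) = 0.5 * (w - 1)**2, minimized at w = 1."""
    return w - 1.0

w_t = 0.0  # current global model
for mu in (0.0, 1.0, 10.0):
    print(f"mu={mu}: local model = {fedprox_local_update(w_t, client_grad, mu):.4f}")
```

For this toy loss the regularized minimizer is $ 1 / (1 + \mu) $: with $ \mu = 0 $ the client reaches its own optimum, and the larger $ \mu $ is, the closer the local model stays to $ w_t $.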

### Activate FedProx in Fed-BioMed

To enable FedProx in Fed-BioMed, you need to set a value for the `fedprox_mu` key inside `training_args`.

### 4.2.1 - Define training arguments

Perform an experiment with
- 10 local updates per round
- a value of 1 for the fedprox regularizer

In [None]:
training_args4 = {
    ??
    'log_interval': 1,
    'test_ratio' : 0.05,
    'test_on_global_updates': True,
}

### 4.2.2 - Create a FL experiment

Run the experiment for 50 rounds.

In [None]:
exp4 = ??

### 4.2.3 - Run the FL training experiment

In [None]:
exp4.run()

### 4.2.4 - Plot the losses

In [None]:
plot_perf(exp4, num_rounds_4)

Both clients exhibit a general downward trend in training loss over the rounds, indicating that the model is learning and improving.

The local accuracy for both clients shows improvement over rounds.

As you can see, the addition of the proximal term in FedProx helps stabilize the training process and improve convergence compared to FedAvg, especially in this heterogeneous setting. The fluctuations in loss and accuracy are typically less severe than with FedAvg, as the proximal term mitigates client drift by constraining local updates.

However, the model converges rather slowly, because of the high value of the mu parameter.

## 4.3 - Experiment 5: FedProx with small mu

In this experiment, we will decrease the proximal parameter mu. What behaviour can you expect from a small mu?

### 4.3.1 - Define training arguments

In [None]:
training_args5 = {
    ??
    'log_interval': 1,
    'test_ratio' : 0.05,
    'test_on_global_updates': True,
}

### 4.3.2 - Create a FL experiment

In [None]:
exp5 = ??

### 4.3.3 - Run the FL training experiment

In [None]:
exp5.run()

### 4.3.4 - Plot the losses

In [None]:
plot_perf(exp5, num_rounds_5)

Here you can observe performance similar to FedAverage, with poorer convergence and worse accuracy for the node with fewer images.

# 5 - Label skew datasets

In federated learning, a "label skew dataset" refers to a scenario where the distribution of class labels varies significantly across different clients or nodes. This means that certain classes may be overrepresented in some clients' datasets while being underrepresented or completely absent in others. Label skew in federated learning can lead to biased models, convergence issues, and generalization challenges.
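
One simple way to quantify label skew between two clients is the total variation distance between their label distributions (0 for identical distributions, 1 for disjoint label sets). The sketch below is an illustration only; the class names and counts are hypothetical, not the actual MedNIST counts.

```python
def label_distribution(counts):
    """Turn per-class image counts into a probability distribution."""
    total = sum(counts.values())
    return {label: n / total for label, n in counts.items()}

def total_variation(counts_a, counts_b):
    """Total variation distance between two label distributions (0 = identical, 1 = disjoint)."""
    p, q = label_distribution(counts_a), label_distribution(counts_b)
    labels = set(p) | set(q)
    return 0.5 * sum(abs(p.get(label, 0.0) - q.get(label, 0.0)) for label in labels)

# Hypothetical per-class image counts on two clients
client_1 = {"class_a": 90, "class_b": 10}
client_2 = {"class_a": 10, "class_b": 90}

print(total_variation(client_1, client_2))  # strong label skew
print(total_variation(client_1, client_1))  # no skew
```

The same function can be applied to the per-modality counts collected in section 2.3 to compare the 'no_skew' and 'label_skew' datasets.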

## 5.1 - Experiment 6: FedAverage

We will try to solve the classification problem of MedNIST with FedAverage.

### 5.1.1 - Define training arguments

In [None]:
training_args6 = {
    'loader_args': { 'batch_size': 8, },
    'optimizer_args': {
        'lr': 1e-3
    },
    'num_updates': 10,
    'dry_run': False,
    'log_interval': 1,
    'test_ratio' : 0.05,
    'test_on_global_updates': True,
}

### 5.1.2 - Create a FL experiment

Don't forget to set the proper `tags` and choose a suitable [`aggregator`](https://fedbiomed.org/latest/tutorials/pytorch/04-Aggregation_in_Fed-BioMed/).

In [None]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['mednist_label_skew']
num_rounds_6 = 50

exp6 = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=FedAvgTrainingPlan,
                 training_args=training_args6,
                 round_limit=num_rounds_6,
                 aggregator=FedAverage(),
                 tensorboard=True
                )

### 5.1.3 - Run the FL training experiment

In [None]:
exp6.run()

### 5.1.4 - Plot the losses

In [None]:
plot_perf(exp6, num_rounds_6)

We can observe slow and 'chaotic' convergence. This is due to the fact that each node converges towards its local minimum, but the global minimum is not reached.
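
This effect can be reproduced on a one-dimensional toy problem with two clients whose quadratic losses have different minima and curvatures. The sketch below is an assumption-laden illustration (plain gradient descent, equal client sizes, made-up losses), not a model of the MedNIST experiment.

```python
def local_training(w, curvature, minimum, num_updates, lr=0.1):
    """Gradient descent on the toy loss 0.5 * curvature * (w - minimum)**2."""
    for _ in range(num_updates):
        w -= lr * curvature * (w - minimum)
    return w

def fedavg(num_rounds, num_updates, clients, w=0.0):
    """Unweighted averaging of the local models (both toy clients have the same size)."""
    for _ in range(num_rounds):
        local_models = [local_training(w, a, m, num_updates) for a, m in clients]
        w = sum(local_models) / len(local_models)
    return w

# (curvature, minimum) of each client's loss: the two clients disagree on the best w
clients = [(1.0, 0.0), (9.0, 1.0)]

# Minimizer of the average loss: curvature-weighted mean of the local minima
global_minimum = sum(a * m for a, m in clients) / sum(a for a, _ in clients)

w_many_local = fedavg(num_rounds=20, num_updates=200, clients=clients)
w_one_local = fedavg(num_rounds=200, num_updates=1, clients=clients)
print(global_minimum, w_many_local, w_one_local)
```

With many local updates, each client reaches its own minimum and the average settles halfway between them; with a single local update per round, the averaged model reaches the minimizer of the average loss, at the cost of many more communication rounds.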

## 5.2 - Experiment 7: Scaffold

We will test another algorithm called [SCAFFOLD](https://arxiv.org/abs/1910.06378).

SCAFFOLD (Stochastic Controlled Averaging for Federated Learning) is an advanced federated learning algorithm designed to improve convergence and performance in the presence of data heterogeneity. It introduces correction terms to mitigate the "client-drift" issue that arises due to the variance in local updates.

For the details of the algorithm, check the [original paper](https://arxiv.org/abs/1910.06378).
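
As a rough illustration of the correction terms, the sketch below runs plain federated averaging and a SCAFFOLD-style update on a toy one-dimensional problem. Everything in it is an assumption made for illustration: the quadratic client losses, the values of `a` and `b`, the step sizes, full client participation, and the choice of the paper's "Option II" control-variate update. It is not the Fed-BioMed or Declearn implementation.

```python
import numpy as np

# Toy setup (hypothetical): client i minimises f_i(x) = 0.5 * a[i] * (x - b[i])**2
a = np.array([1.0, 10.0])         # curvatures differ across clients
b = np.array([0.0, 1.0])          # local optima differ across clients
x_star = (a * b).sum() / a.sum()  # minimiser of the average loss

def grad(i, x):
    """Exact gradient of client i's loss at x."""
    return a[i] * (x - b[i])

def run(use_scaffold, rounds=50, local_steps=10, lr=0.05, server_lr=1.0):
    n = len(a)
    x = 0.0            # global model
    c = 0.0            # server control variate
    c_i = np.zeros(n)  # client control variates
    for _round in range(rounds):
        delta_y = np.zeros(n)
        delta_c = np.zeros(n)
        for i in range(n):
            y = x
            for _step in range(local_steps):
                # SCAFFOLD corrects the local gradient by (c - c_i)
                correction = (c - c_i[i]) if use_scaffold else 0.0
                y -= lr * (grad(i, y) + correction)
            delta_y[i] = y - x
            if use_scaffold:
                # "Option II" update of the client control variate
                new_c_i = c_i[i] - c + (x - y) / (local_steps * lr)
                delta_c[i] = new_c_i - c_i[i]
                c_i[i] = new_c_i
        x += server_lr * delta_y.mean()
        c += delta_c.mean()  # full participation: every client is sampled
    return x

x_fedavg = run(use_scaffold=False)
x_scaffold = run(use_scaffold=True)
print(f"optimum : {x_star:.4f}")
print(f"FedAvg  : {x_fedavg:.4f}")
print(f"SCAFFOLD: {x_scaffold:.4f}")
```

In this toy problem the uncorrected local steps drift towards each client's own optimum, so the averaged model settles away from `x_star`, whereas the corrected updates end up at it.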



### 5.2.1 - Define a Training Plan

To use [SCAFFOLD in Fed-BioMed](https://fedbiomed.org/latest/user-guide/researcher/aggregation/#scaffold), we need to change the training plan because, as explained previously, this algorithm performs part of the optimization on the researcher side and part on the node side.

In addition, Fed-BioMed relies on a library called [Declearn](https://fedbiomed.org/latest/user-guide/advanced-optimization/) to provide this advanced optimizer.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # used as F.relu / F.log_softmax in the model below
from fedbiomed.common.training_plans import TorchTrainingPlan

class ScaffoldTrainingPlan(TorchTrainingPlan):
    class MyModel(torch.nn.Module):
        """definition of a PyTorch model, with its __init__ and forward methods"""
        def __init__(self, model_args: dict):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(16 * 16 * 16, model_args.get('num_classes', 6))

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            return F.log_softmax(x, dim=1)

    def init_dependencies(self):
        deps = ["from torchvision import datasets, transforms",
                "from torchvision.transforms import ToTensor",
                "from fedbiomed.common.optimizers.optimizer import Optimizer",
                "from fedbiomed.common.optimizers.declearn import ScaffoldClientModule"
                ]

        return deps

    def init_model(self, model_args: dict):
        """Defines your model here"""
        return self.MyModel(model_args)

    def init_optimizer(self, optimizer_args):
        """Defines your optimizer here"""
        return Optimizer(lr=optimizer_args["lr"],
                         modules=[ScaffoldClientModule()])

    def training_data(self):
        """Defines data handling/parsing here"""
        # Custom torch Dataloader for MedNIST data

        preprocess = transforms.Compose([transforms.ToTensor()])
        train_data = datasets.ImageFolder(self.dataset_path,transform = preprocess)

        return DataManager(dataset=train_data, shuffle=True)

    def training_step(self, data, target):
        """Defines cost function and how to compute loss"""
        output = self.model().forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss

### 5.2.2 - Define model parameters

Here you can define `model_args`, a dictionary that contains the parameters and hyperparameters used to define the model.

In [None]:
model_args = {
    'num_classes': 6,
}

### 5.2.3 - Define training arguments

In [None]:
training_args7 = {
    'loader_args': { 'batch_size': 8, },
    'optimizer_args': {
        'lr': 1e-3
    },
    'num_updates': 10,
    'dry_run': False,
    'log_interval': 1,
    'test_ratio' : 0.05,
    'test_on_global_updates': True,
}

### 5.2.4 - Create a FL experiment

Don't forget to set the proper `tags` and choose a suitable [`aggregator`](https://fedbiomed.org/latest/tutorials/pytorch/04-Aggregation_in_Fed-BioMed/).

In [None]:
tags =  ['mednist_label_skew']
num_rounds_7 = 50

exp7 = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=ScaffoldTrainingPlan,
                 training_args=training_args7,
                 round_limit=num_rounds_7,
                 aggregator=FedAverage(),
                 tensorboard=True
                )

### 5.2.5 - Define the server optimizer for SCAFFOLD

In [None]:
from fedbiomed.common.optimizers import Optimizer
from fedbiomed.common.optimizers.declearn import ScaffoldServerModule

scaffold_opt = Optimizer(lr=.8, modules=[ScaffoldServerModule()])
exp7.set_agg_optimizer(scaffold_opt)

### 5.2.6 - Run the FL training experiment

In [None]:
exp7.run()

### 5.2.7 - Plot the losses

In [None]:
plot_perf(exp7, num_rounds_7)

SCAFFOLD converges more slowly here, but the per-node curves can be more homogeneous.

# 6 - Optional exercises

- Look for the other [optimizers](https://fedbiomed.org/latest/user-guide/advanced-optimization/) available in Fed-BioMed
- Play around with [batches](https://fedbiomed.org/latest/user-guide/researcher/experiment/#sub-arguments-for-optimizer-and-differential-privacy) instead of `num_updates`
- Try to find an optimal `mu` for FedProx (for several values of `mu`, try to optimize both the time to convergence and the loss)
- Investigate how you can implement a [custom aggregator](https://fedbiomed.org/latest/user-guide/researcher/aggregation/) in Fed-BioMed
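
For the FedProx exercise, the kind of sweep involved can be sketched on a toy problem before running full experiments. The example below is only an illustration under stated assumptions: quadratic client losses with made-up values `a` and `b`, a hypothetical helper `fedprox`, and a stopping rule based on the change of the global model between rounds. It does not use Fed-BioMed.

```python
import numpy as np

# Toy setup (hypothetical): client i minimises f_i(x) = 0.5 * a[i] * (x - b[i])**2
a = np.array([1.0, 10.0])
b = np.array([0.0, 1.0])
x_star = (a * b).sum() / a.sum()  # minimiser of the average loss

def fedprox(mu, max_rounds=200, local_steps=10, lr=0.05, tol=1e-6):
    """Run FedProx-style rounds; return (final model, rounds until it stops moving)."""
    x = 0.0
    for round_idx in range(1, max_rounds + 1):
        local_models = []
        for a_i, b_i in zip(a, b):
            y = x
            for _ in range(local_steps):
                # Local gradient plus the gradient of the proximal term mu/2 * (y - x)**2
                g = a_i * (y - b_i) + mu * (y - x)
                y -= lr * g
            local_models.append(y)
        x_new = float(np.mean(local_models))
        if abs(x_new - x) < tol:
            return x_new, round_idx
        x = x_new
    return x, max_rounds

# Sweep a few values of mu (mu = 0 reduces to plain averaging of local models).
results = {mu: fedprox(mu) for mu in (0.0, 0.1, 1.0, 10.0)}
for mu, (x_final, rounds) in results.items():
    print(f"mu={mu:>4}: rounds={rounds:3d}, distance to optimum={abs(x_final - x_star):.4f}")
```

In this toy setting a larger `mu` keeps local models closer to the global one, which reduces the final bias but requires more rounds. The same two quantities (rounds to converge and final loss) are the ones to compare when sweeping `mu` on MedNIST.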
