# Fed-BioMed Researcher base example with FedProx

In this notebook we show how to use the FedProx (https://proceedings.mlsys.org/paper/2020/file/38af86134b65d0f10fe33d30dd76442e-Paper.pdf) optimization scheme in the basic PyTorch example with the MNIST dataset. FedProx addresses the problem of heterogeneity across datasets by optimizing a regularized loss: a proximal term keeps the local parameters close to the latest aggregated parameters.
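
To make the proximal term concrete, here is a minimal sketch in plain Python of the regularized local objective from the FedProx paper, `task_loss + (mu / 2) * ||w_local - w_global||^2`. The function name and the flat-list representation of the parameters are assumptions made for illustration; this is not how Fed-BioMed implements it internally.

```python
def fedprox_objective(task_loss, local_params, global_params, mu):
    """Return the task loss plus the FedProx proximal penalty.

    local_params / global_params: flat lists of floats (illustrative only).
    mu: proximal coefficient (plays the role of `fedprox_mu` in this notebook).
    """
    squared_distance = sum(
        (w - w_g) ** 2 for w, w_g in zip(local_params, global_params)
    )
    return task_loss + 0.5 * mu * squared_distance


local_w = [1.0, 2.0, 3.0]
global_w = [1.0, 1.0, 1.0]

# With mu = 0 the objective reduces to the plain task loss.
print(fedprox_objective(0.5, local_w, global_w, mu=0.0))
# A positive mu penalizes drifting away from the aggregated parameters.
print(fedprox_objective(0.5, local_w, global_w, mu=0.01))
```

The larger `mu` is, the more strongly each node is pulled back towards the aggregated model during local training.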

This example uses the MNIST dataset. Please check the `README.md` file in the `notebooks` directory for instructions on loading the MNIST dataset and configuring nodes.

Check available clients:

In [1]:
from fedbiomed.researcher.requests import Requests
from fedbiomed.researcher.config import config
req = Requests(config)
req.list(verbose=True)

2025-07-31 14:26:31,008 fedbiomed INFO - Starting researcher service...

2025-07-31 14:26:31,022 fedbiomed INFO - Waiting 3s for nodes to connect...

2025-07-31 14:26:32,139 fedbiomed DEBUG - Node: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57 polling for the tasks

2025-07-31 14:26:34,109 fedbiomed DEBUG - Node: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57 polling for the tasks

2025-07-31 14:26:34,114 fedbiomed INFO - 
 Node: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57 | Number of Datasets: 1 
+--------+-------------+------------------------+----------------+--------------------+----------------------------------------------+----------------------+
| name   | data_type   | tags                   | description    | shape              | dataset_id                                   | dataset_parameters   |
| MNIST  | default     | ['#MNIST', '#dataset'] | MNIST database | [60000, 1, 28, 28] | dataset_bd9de2a0-25e4-4241-9706-2a2422e29d77 |                      |
+--------+-------------+------------------------+----------------+--------------------+----------------------------------------------+----------------------+


{'NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57': [{'name': 'MNIST',
   'data_type': 'default',
   'tags': ['#MNIST', '#dataset'],
   'description': 'MNIST database',
   'shape': [60000, 1, 28, 28],
   'dataset_id': 'dataset_bd9de2a0-25e4-4241-9706-2a2422e29d77',
   'dataset_parameters': None}]}
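
The value returned above is a plain dictionary that maps each node ID to the list of datasets it serves. As a small sketch, it can be filtered by tags before setting up an experiment. The helper name `nodes_with_tags` is hypothetical, and the rule that a dataset must carry all requested tags is an assumption of this example rather than documented Fed-BioMed behavior.

```python
def nodes_with_tags(datasets_by_node, required_tags):
    """Map node ID -> names of datasets carrying every tag in required_tags."""
    required = set(required_tags)
    matches = {}
    for node_id, datasets in datasets_by_node.items():
        names = [d["name"] for d in datasets if required <= set(d["tags"])]
        if names:
            matches[node_id] = names
    return matches


# Same structure as the dictionary printed above.
listing = {
    "NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57": [
        {
            "name": "MNIST",
            "data_type": "default",
            "tags": ["#MNIST", "#dataset"],
            "description": "MNIST database",
            "shape": [60000, 1, 28, 28],
            "dataset_id": "dataset_bd9de2a0-25e4-4241-9706-2a2422e29d77",
            "dataset_parameters": None,
        }
    ]
}

print(nodes_with_tags(listing, ["#MNIST", "#dataset"]))
```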

## Define an experiment model and parameters

Declare a torch training plan class, `MyTrainingPlan`, to send to the node for training.

Note: write **only** the code to export in the following cell.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # provides F.relu, F.max_pool2d, F.log_softmax used in Net.forward
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms


# Here we define the training plan to be used.
# You can use any class name (here 'MyTrainingPlan')
class MyTrainingPlan(TorchTrainingPlan):
    
    # Defines and returns the model
    def init_model(self, model_args):
        return self.Net(model_args = model_args)
    
    # Defines and returns the optimizer
    def init_optimizer(self, optimizer_args):
        return torch.optim.Adam(self.model().parameters(), lr = optimizer_args["lr"])
    
    # Declares and returns dependencies
    def init_dependencies(self):
        deps = ["from torchvision import datasets, transforms"]
        return deps
    
    class Net(nn.Module):
        def __init__(self, model_args):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)


            output = F.log_softmax(x, dim=1)
            return output

    def training_data(self):
        # Custom torch Dataloader for MNIST data
        transform = transforms.Compose([transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        train_kwargs = { 'shuffle': True}
        return DataManager(dataset=dataset1, **train_kwargs)
    
    def training_step(self, data, target):
        output = self.model().forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss
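
The input size of `fc1` (9216) follows from the layer definitions in `Net`. The short sketch below reproduces the arithmetic for a 28x28 MNIST image, using the standard output-size formula for a convolution without padding; the helper `conv_out` is introduced here for illustration only.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                        # MNIST images are 28x28
side = conv_out(side, kernel=3)  # conv1 (3x3, stride 1): 26
side = conv_out(side, kernel=3)  # conv2 (3x3, stride 1): 24
side = side // 2                 # F.max_pool2d(x, 2): 12

channels = 64                    # output channels of conv2
flat_features = channels * side * side
print(flat_features)             # 9216, the input size of fc1
```

If you change the kernel sizes or the number of channels, the first argument of `nn.Linear` in `fc1` has to be updated accordingly.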


The two groups of arguments are:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side.
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.

If FedProx optimization is requested, the `fedprox_mu` parameter must be defined in `training_args`. It must also be a float between XX and YY.

**NOTE:** typos and/or missing required (positional) arguments will raise an error. 🤓

In [3]:
model_args = {}

training_args = {
    'loader_args': { 'batch_size': 48, },
    'optimizer_args': {
        'lr': 1e-3,
    },
    'fedprox_mu': 0.01, 
    'epochs': 1, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}
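
As a quick sanity check on these settings, the sketch below computes how many samples one local epoch will process, assuming (as the inline comment states) that `batch_maxnum` caps the number of batches per epoch. The helper name is hypothetical; the dataset size is the one reported by the node above.

```python
def samples_per_epoch(batch_size, batch_maxnum, dataset_size):
    """Number of samples seen in one epoch when batches are capped at batch_maxnum."""
    return min(batch_size * batch_maxnum, dataset_size)


n_samples = samples_per_epoch(batch_size=48, batch_maxnum=100, dataset_size=60000)
print(n_samples)  # 4800
```

This matches the `Samples: .../4800` counter that appears in the training logs further down.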

## Declare and run the experiment

- search nodes serving data for these `tags`, optionally filter on a list of node IDs with `nodes`
- run a round of local training on nodes with the training plan defined in `training_plan_class` + federation with `aggregator`
- run for `round_limit` rounds, applying the `node_selection_strategy` between the rounds
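
To illustrate what the federation step does conceptually, here is a generic sketch of federated averaging in plain Python: each node's parameters are weighted by its share of the training samples. This is not the code of Fed-BioMed's `FedAverage` class; the function name, the node names, and the flat-list parameter format are assumptions for illustration.

```python
def federated_average(node_params, node_sample_counts):
    """Weighted average of per-node parameters.

    node_params: {node_id: {param_name: [float, ...]}}
    node_sample_counts: {node_id: number of samples used for training}
    """
    total = sum(node_sample_counts.values())
    averaged = {}
    for node_id, params in node_params.items():
        weight = node_sample_counts[node_id] / total
        for name, values in params.items():
            acc = averaged.setdefault(name, [0.0] * len(values))
            for i, value in enumerate(values):
                acc[i] += weight * value
    return averaged


# Two hypothetical nodes; node_b trained on three times as many samples.
params = {
    "node_a": {"fc.weight": [0.0, 4.0]},
    "node_b": {"fc.weight": [4.0, 0.0]},
}
counts = {"node_a": 1, "node_b": 3}
print(federated_average(params, counts))  # {'fc.weight': [3.0, 1.0]}
```

With a single node, as in this notebook, the average is simply that node's parameters.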

In [4]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['#MNIST', '#dataset']
rounds = 3

exp = Experiment(tags=tags,
                 training_plan_class=MyTrainingPlan,
                 model_args=model_args,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

2025-07-31 14:26:34,170 fedbiomed INFO - Updating training data. This action will update FederatedDataset, and the nodes that will participate to the experiment.

2025-07-31 14:26:34,230 fedbiomed DEBUG - Node: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57 polling for the tasks

2025-07-31 14:26:34,232 fedbiomed INFO - Node selected for training -> NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57

<function extract_symbols at 0x7f80efd67b50>


2025-07-31 14:26:34,239 fedbiomed DEBUG - Model file has been saved: /home/gersa/fedbiomed-dcm/Modulo_DICOM/fbm-researcher/var/experiments/Experiment_0020/model_e0fe74c2-4ff4-4a73-a318-2b8d823eeb10.py

Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with ``secure_mode`` turned on.


2025-07-31 14:26:34,256 fedbiomed INFO - Removing tensorboard logs from previous experiment

Let's start the experiment.

By default, `exp.run()` does not stop until all `round_limit` rounds are done for all the nodes.

In [5]:
exp.run()

2025-07-31 14:26:34,271 fedbiomed INFO - Sampled nodes in round 0 ['NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57']

<function extract_symbols at 0x7f80efd67b50>


2025-07-31 14:26:34,278 fedbiomed INFO - Sending request
					 To: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Request: : TRAIN
 -----------------------------------------------------------------

2025-07-31 14:26:34,332 fedbiomed DEBUG - Node: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57 polling for the tasks

2025-07-31 14:26:35,031 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 1/100 (1%) | Samples: 48/4800
 					 Loss: 2.317689
					 ---------

2025-07-31 14:26:35,614 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 10/100 (10%) | Samples: 480/4800
 					 Loss: 1.534979
					 ---------

2025-07-31 14:26:36,147 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 20/100 (20%) | Samples: 960/4800
 					 Loss: 0.976598
					 ---------

2025-07-31 14:26:36,613 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 30/100 (30%) | Samples: 1440/4800
 					 Loss: 0.664990
					 ---------

2025-07-31 14:26:37,155 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 40/100 (40%) | Samples: 1920/4800
 					 Loss: 0.807914
					 ---------

2025-07-31 14:26:37,721 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 50/100 (50%) | Samples: 2400/4800
 					 Loss: 0.444567
					 ---------

2025-07-31 14:26:38,216 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 60/100 (60%) | Samples: 2880/4800
 					 Loss: 0.628533
					 ---------

2025-07-31 14:26:38,679 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 70/100 (70%) | Samples: 3360/4800
 					 Loss: 0.397321
					 ---------

2025-07-31 14:26:39,097 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 80/100 (80%) | Samples: 3840/4800
 					 Loss: 0.577502
					 ---------

2025-07-31 14:26:39,554 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 90/100 (90%) | Samples: 4320/4800
 					 Loss: 0.261285
					 ---------

2025-07-31 14:26:39,969 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 1 Epoch: 1 | Iteration: 100/100 (100%) | Samples: 4800/4800
 					 Loss: 0.494717
					 ---------

2025-07-31 14:26:40,280 fedbiomed INFO - Nodes that successfully reply in round 0 ['NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57']

2025-07-31 14:26:40,292 fedbiomed INFO - Sampled nodes in round 1 ['NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57']

<function extract_symbols at 0x7f80efd67b50>


2025-07-31 14:26:40,297 fedbiomed INFO - Sending request
					 To: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Request: : TRAIN
 -----------------------------------------------------------------

2025-07-31 14:26:40,342 fedbiomed DEBUG - Node: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57 polling for the tasks

2025-07-31 14:26:40,533 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 1/100 (1%) | Samples: 48/4800
 					 Loss: 0.273614
					 ---------

2025-07-31 14:26:40,899 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 10/100 (10%) | Samples: 480/4800
 					 Loss: 0.332518
					 ---------

2025-07-31 14:26:41,239 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 20/100 (20%) | Samples: 960/4800
 					 Loss: 0.398060
					 ---------

2025-07-31 14:26:41,575 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 30/100 (30%) | Samples: 1440/4800
 					 Loss: 0.364511
					 ---------

2025-07-31 14:26:41,940 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 40/100 (40%) | Samples: 1920/4800
 					 Loss: 0.311296
					 ---------

2025-07-31 14:26:42,296 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 50/100 (50%) | Samples: 2400/4800
 					 Loss: 0.284638
					 ---------

2025-07-31 14:26:42,633 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 60/100 (60%) | Samples: 2880/4800
 					 Loss: 0.197522
					 ---------

2025-07-31 14:26:42,973 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 70/100 (70%) | Samples: 3360/4800
 					 Loss: 0.534620
					 ---------

2025-07-31 14:26:43,335 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 80/100 (80%) | Samples: 3840/4800
 					 Loss: 0.255883
					 ---------

2025-07-31 14:26:43,759 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 90/100 (90%) | Samples: 4320/4800
 					 Loss: 0.140865
					 ---------

2025-07-31 14:26:44,130 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 2 Epoch: 1 | Iteration: 100/100 (100%) | Samples: 4800/4800
 					 Loss: 0.168257
					 ---------

2025-07-31 14:26:44,442 fedbiomed INFO - Nodes that successfully reply in round 1 ['NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57']

2025-07-31 14:26:44,452 fedbiomed INFO - Sampled nodes in round 2 ['NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57']

<function extract_symbols at 0x7f80efd67b50>


2025-07-31 14:26:44,458 fedbiomed INFO - Sending request
					 To: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Request: : TRAIN
 -----------------------------------------------------------------

2025-07-31 14:26:44,564 fedbiomed DEBUG - Node: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57 polling for the tasks

2025-07-31 14:26:44,793 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 1/100 (1%) | Samples: 48/4800
 					 Loss: 0.201485
					 ---------

2025-07-31 14:26:45,182 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 10/100 (10%) | Samples: 480/4800
 					 Loss: 0.172084
					 ---------

2025-07-31 14:26:45,593 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 20/100 (20%) | Samples: 960/4800
 					 Loss: 0.199399
					 ---------

2025-07-31 14:26:46,089 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 30/100 (30%) | Samples: 1440/4800
 					 Loss: 0.159567
					 ---------

2025-07-31 14:26:46,533 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 40/100 (40%) | Samples: 1920/4800
 					 Loss: 0.064499
					 ---------

2025-07-31 14:26:46,998 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 50/100 (50%) | Samples: 2400/4800
 					 Loss: 0.234464
					 ---------

2025-07-31 14:26:47,381 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 60/100 (60%) | Samples: 2880/4800
 					 Loss: 0.199160
					 ---------

2025-07-31 14:26:47,831 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 70/100 (70%) | Samples: 3360/4800
 					 Loss: 0.114253
					 ---------

2025-07-31 14:26:48,290 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 80/100 (80%) | Samples: 3840/4800
 					 Loss: 0.220274
					 ---------

2025-07-31 14:26:48,773 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 90/100 (90%) | Samples: 4320/4800
 					 Loss: 0.177317
					 ---------

2025-07-31 14:26:49,225 fedbiomed INFO - TRAINING
					 NODE_ID: NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57
					 Round 3 Epoch: 1 | Iteration: 100/100 (100%) | Samples: 4800/4800
 					 Loss: 0.336939
					 ---------

2025-07-31 14:26:49,518 fedbiomed INFO - Nodes that successfully reply in round 2 ['NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57']

3

Save the trained model to a file:

In [6]:
exp.training_plan().export_model('./trained_model')

Local training results for each round and each node are available via `exp.training_replies()` (indexed from 0 to `rounds` - 1).

For example you can view the training results for the last round below.

Different timings (in seconds) are reported for each dataset of a node participating in a round:
- `rtime_training` real time (clock time) spent in the training function on the node
- `ptime_training` process time (user and system CPU) spent in the training function on the node
- `rtime_total` real time (clock time) spent in the researcher between sending the request and handling the response, at the `Job()` layer
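
The difference between real time and process time can be reproduced with the Python standard library. The sketch below is illustrative only (the `timed` helper is hypothetical and is not how Fed-BioMed collects its timings): `time.perf_counter()` measures elapsed wall-clock time, while `time.process_time()` sums the user and system CPU time of the process.

```python
import time


def timed(fn, *args):
    """Run fn(*args) and report both wall-clock and CPU time in seconds."""
    start_real = time.perf_counter()
    start_proc = time.process_time()
    result = fn(*args)
    timing = {
        "rtime": time.perf_counter() - start_real,  # real (clock) time
        "ptime": time.process_time() - start_proc,  # user + system CPU time
    }
    return result, timing


result, timing = timed(sum, range(1_000_000))
print(result, timing)
```

Because process time adds up CPU time across all threads of the process, it can be larger than the real time when training runs on several CPU cores, which is one plausible explanation for a `ptime_training` well above `rtime_training`.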

In [7]:
print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies()[rounds - 1]
for r in round_data.values():
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = r['node_id'],
        rtraining = r['timing']['rtime_training'],
        ptraining = r['timing']['ptime_training'],
        rtotal = r['timing']['rtime_total']))
print('\n')



List the training rounds :  dict_keys([0, 1, 2])

List the nodes for the last training round and their timings : 
	- NODE_8761223e-feb5-4fb9-9f0b-9238a6db2a57 :    
		rtime_training=4.49 seconds    
		ptime_training=35.65 seconds    
		rtime_total=5.05 seconds




Federated parameters for each round are available via `exp.aggregated_params()` (indexed from 0 to `rounds` - 1).

For example you can view the federated parameters for the last round of the experiment :

In [8]:
print("\nList the training rounds : ", exp.aggregated_params().keys())

print("\nAccess the federated params for the last training round :")
print("\t- parameter data: ", exp.aggregated_params()[rounds - 1]['params'].keys())



List the training rounds :  dict_keys([0, 1, 2])

Access the federated params for the last training round :
	- parameter data:  dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
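
The parameter names listed above correspond one-to-one to the layers declared in `Net`. As a rough sketch, the number of values exchanged per round can be estimated from the layer definitions alone. The shapes below are derived from the `nn.Conv2d` and `nn.Linear` arguments in the training plan and follow PyTorch's usual weight layout; they are not read from the experiment object.

```python
import math

# Shapes implied by the layer definitions in Net (assumed, not read from exp).
param_shapes = {
    "conv1.weight": (32, 1, 3, 3),
    "conv1.bias": (32,),
    "conv2.weight": (64, 32, 3, 3),
    "conv2.bias": (64,),
    "fc1.weight": (128, 9216),
    "fc1.bias": (128,),
    "fc2.weight": (10, 128),
    "fc2.bias": (10,),
}

param_counts = {name: math.prod(shape) for name, shape in param_shapes.items()}
total_params = sum(param_counts.values())
print(total_params)  # 1199882
```

Almost all of these values come from `fc1.weight`, so that layer dominates the size of each model update.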


Feel free to run other sample notebooks or try your own models :D