# Fedbiomed Researcher - Saving and Loading breakpoints

## Setting the client up
A node must be configured before running this notebook:
1. `./scripts/fedbiomed_run node add`
  * Select option 2 (default) to add MNIST to the client
  * Confirm default tags by hitting "y" and ENTER
  * Pick the folder where MNIST is downloaded (this is due to torch issue https://github.com/pytorch/vision/issues/3549)
  * The data must have been added (a warning saying that data must be unique means it has already been added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node run`. Wait until you get `Connected with result code 0`, which means you are online.

## Create an experiment to train a model on the data found

Declare a `MyTrainingPlan` class (a torch.nn module) to send to the node for training.

In [1]:
from fedbiomed.researcher.environ import TMP_DIR
import tempfile
tmp_dir_model = tempfile.TemporaryDirectory(dir=TMP_DIR+'/')
model_file = tmp_dir_model.name + '/class_export_mnist.py'

Note: write **only** the code to export in the following cell.

In [2]:
%%writefile "$model_file"

import torch
import torch.nn as nn
import torch.nn.functional as F  # `F` is used in forward() below
from fedbiomed.common.torchnn import TorchTrainingPlan
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Here we define the model to be used.
# You can use any class name (here 'MyTrainingPlan')
class MyTrainingPlan(TorchTrainingPlan):
    def __init__(self):
        super(MyTrainingPlan, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        # In this case, we need the torch DataLoader classes
        # Since we will train on MNIST, we need datasets and transform from torchvision
        deps = ["from torchvision import datasets, transforms",
               "from torch.utils.data import DataLoader"]
        self.add_dependency(deps)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def training_data(self, batch_size = 48):
        # Custom torch Dataloader for MNIST data
        transform = transforms.Compose([transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        data_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
        return data_loader
    
    def training_step(self, data, target):
        output = self.forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss


Writing /home/ybouilla/fedbiomed/var/tmp/tmpx2s4it24/class_export_mnist.py
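
The `9216` input size of `fc1` is not arbitrary: it follows from the 28x28 MNIST images passing through two 3x3 convolutions (stride 1, no padding) and one 2x2 max-pooling. The sketch below redoes that arithmetic in plain Python; `conv_out` is an illustrative helper written for this note, not a Fed-BioMed or torch function.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square conv/pool window (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28                          # MNIST images are 28x28 with 1 channel
side = conv_out(side, 3)           # conv1 = nn.Conv2d(1, 32, 3, 1)  -> 26
side = conv_out(side, 3)           # conv2 = nn.Conv2d(32, 64, 3, 1) -> 24
side = conv_out(side, 2, stride=2) # F.max_pool2d(x, 2)              -> 12

# conv2 produces 64 channels, so torch.flatten(x, 1) yields 64 * 12 * 12 values
flat_features = 64 * side * side
print(flat_features)  # 9216
```

If the kernel sizes or the input resolution change, the first argument of `nn.Linear(9216, 128)` has to be recomputed the same way.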


The two groups of arguments defined below are, respectively:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the client side.
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the client side.

**NOTE:** typos and/or missing positional (required) arguments will raise an error. 🤓

In [3]:
model_args = {}

training_args = {
    'batch_size': 48, 
    'lr': 1e-3, 
    'epochs': 1, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}
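
Since a typo in a key only surfaces once the arguments reach the node, it can help to check the dictionary locally first. The sketch below is a hypothetical helper, not part of Fed-BioMed: `EXPECTED_KEYS` is simply the set of keys used in this notebook (an assumption, not an official list), and the returned value is the sample cap described by the `batch_maxnum` comment above.

```python
# Hypothetical local check; EXPECTED_KEYS only mirrors the keys used in this notebook.
EXPECTED_KEYS = {'batch_size', 'lr', 'epochs', 'dry_run', 'batch_maxnum'}

def check_training_args(args: dict) -> int:
    """Reject unknown keys and return the batch_maxnum * batch_size sample cap."""
    unknown = set(args) - EXPECTED_KEYS
    if unknown:
        raise KeyError(f"unexpected training argument(s): {sorted(unknown)}")
    return args['batch_maxnum'] * args['batch_size']

training_args = {
    'batch_size': 48,
    'lr': 1e-3,
    'epochs': 1,
    'dry_run': False,
    'batch_maxnum': 100,
}

print(check_training_args(training_args))  # 4800
```

With the values used here, the fast pass is capped at 100 * 48 = 4800 samples, well below the 60000 images of the MNIST training set.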

Define an experiment with saved breakpoints:
- search nodes serving data for these `tags`, optionally filtering on a list of client IDs with `clients`
- run a round of local training on nodes with the model defined in `model_path`, then federate with `aggregator`
- run for `rounds` rounds, applying the `client_selection_strategy` between the rounds
- specify `save_breakpoints` to save a breakpoint at the end of each round. Breakpoints are saved in an `Experiment_xx` folder under `fedbiomed/var/breakpoints`

In [4]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['#MNIST', '#dataset']
rounds = 5

exp = Experiment(tags=tags,
                 #clients=None,
                 model_path=model_file,
                 model_args=model_args,
                 model_class='MyTrainingPlan',
                 training_args=training_args,
                 rounds=rounds,
                 aggregator=FedAverage(),
                 client_selection_strategy=None,
                 save_breakpoints=True)

2021-10-19 18:46:58,612 fedbiomed INFO - Messaging researcher_5969754c-72e2-477e-bd27-1a3648f445c4 successfully connected to the message broker, object = <fedbiomed.common.messaging.Messaging object at 0x7f1f2f842700>
2021-10-19 18:46:58,659 fedbiomed INFO - Searching dataset with data tags: ['#MNIST', '#dataset'] for all nodes


TIMEOUT =  5


2021-10-19 18:47:03,665 fedbiomed INFO - No available dataset has found in nodes with tags: ['#MNIST', '#dataset']
2021-10-19 18:47:03,918 fedbiomed INFO - Messaging NodeTrainingFeedbackClient successfully connected to the message broker, object = <fedbiomed.common.messaging.Messaging object at 0x7f1f2f7e0880>


filename /home/ybouilla/fedbiomed/var/tmpt95vqg_g/my_model_0e447078-c864-49ac-8716-467b16cf9cd4.py


In [5]:
exp.run()


2021-10-19 18:47:05,166 fedbiomed INFO - Sampled clients in round 0 []
2021-10-19 18:47:05,170 fedbiomed INFO - Clients that successfully reply in round 0 []


AssertionError: An empty list of models was passed.

## Resume an experiment

While the experiment is running, you can shut it down (after the first round) and resume it from the following cell, or wait for the experiment to complete.


**To load the latest breakpoint of the latest experiment**

Run `Experiment.load_breakpoint()`. It reloads the latest breakpoint and bypasses the `search` method.

Then use the `.run` method as you would with an existing experiment.

**To load a specific breakpoint**, specify the breakpoint folder when calling `Experiment.load_breakpoint("breakpoints/Experiment_xx/breakpoint_yy")`. Replace `xx` and `yy` with the real values.

In [None]:
from fedbiomed.researcher.experiment import Experiment

loaded_exp = Experiment.load_breakpoint()
loaded_exp.run()

In [None]:
print("\nList the training rounds : ", exp.training_replies.keys())

print("\nList the clients for the last training round and their timings : ")
round_data = exp.training_replies[rounds - 1].data
for c in range(len(round_data)):
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = round_data[c]['client_id'],
        rtraining = round_data[c]['timing']['rtime_training'],
        ptraining = round_data[c]['timing']['ptime_training'],
        rtotal = round_data[c]['timing']['rtime_total']))
print('\n')
    
exp.training_replies[rounds - 1].dataframe