# Performing Validation at Each Round of Training 

Useful during development: automatically reloads changes made to imported packages.

In [None]:
# Automatically reload modified modules/packages so code edits are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

## Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

## Setting the node up
A node must be configured beforehand:
1. `./scripts/fedbiomed_run node add`
  * Select option 2 (default) to add MNIST to the node
  * Confirm default tags by hitting "y" and ENTER
  * Pick the folder where MNIST is downloaded (this is due to a torch issue: https://github.com/pytorch/vision/issues/3549)
  * Data must have been added (if you get a warning saying that data must be unique, it is because the data has already been added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node start`. Wait until you get `Starting task manager`: it means the node is online.

## 1. Validating a PyTorch Model Using Predefined Evaluation Metrics at Each Round of Federated Training

Declare a torch training plan class, `MyTrainingPlan`, to be sent to the node for training.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by Net.forward; keeps this cell self-contained
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms


# Here we define the model to be used.
# You can use any class name (here 'Net')
class MyTrainingPlan(TorchTrainingPlan):
    """Torch training plan that trains a small CNN classifier on MNIST.

    The network outputs log-probabilities, so training uses the negative
    log-likelihood loss.
    """

    def init_model(self, model_args):
        """Build and return the network; ``model_args`` is forwarded to ``Net``."""
        return self.Net(model_args=model_args)

    def init_optimizer(self, optimizer_args):
        """Return an Adam optimizer over the model parameters.

        Args:
            optimizer_args: dict that must contain the learning rate under ``"lr"``.
        """
        return torch.optim.Adam(self.model().parameters(), lr=optimizer_args["lr"])

    def init_dependencies(self):
        """Return the extra import statements this training plan relies on.

        NOTE(review): presumably executed together with the plan's source on the
        node, so that ``datasets``/``transforms`` exist in ``training_data``.
        """
        deps = ["from torchvision import datasets, transforms"]
        return deps

    class Net(nn.Module):
        """Two conv layers and two fully connected layers, returning log-softmax over 10 classes.

        ``fc1`` expects 9216 = 64 * 12 * 12 features, i.e. single-channel 28x28
        inputs (two unpadded 3x3 convs: 28 -> 26 -> 24, then 2x2 max-pooling -> 12).
        ``model_args`` is accepted but not used.
        """

        def __init__(self, model_args):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            """Return per-class log-probabilities of shape (batch, 10)."""
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)

            output = F.log_softmax(x, dim=1)
            return output

    def training_data(self):
        """Load the MNIST training split from ``self.dataset_path`` wrapped in a shuffling DataManager."""
        # Commonly used MNIST mean/std normalization constants
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        train_kwargs = {'shuffle': True}
        return DataManager(dataset=dataset1, **train_kwargs)

    def training_step(self, data, target):
        """Return the negative log-likelihood loss of the model on one batch."""
        output = self.model().forward(data)
        # nll_loss matches Net's log_softmax output
        loss = torch.nn.functional.nll_loss(output, target)
        return loss


### 1.1 Declare and run the experiment
The model is trained on the **MNIST dataset** for classification. For validation, we will be using the **F1-Score** as a metric. Validation will be performed on both **local updates and global updates**.

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage


# No model hyper-parameters are needed: Net does not use model_args.
model_args = {}

# Local training configuration sent to the nodes.
training_args = dict(
    loader_args=dict(batch_size=48),
    optimizer_args=dict(lr=1e-3),
    epochs=1,
    dry_run=False,
    # Fast pass for development: only use (batch_maxnum * batch_size) samples
    batch_maxnum=100,
)

tags = ['#MNIST', '#dataset']
rounds = 2

exp = Experiment(
    tags=tags,
    model_args=model_args,
    training_plan_class=MyTrainingPlan,
    training_args=training_args,
    round_limit=rounds,
    aggregator=FedAverage(),
    node_selection_strategy=None,
    tensorboard=True,
)

#### Declaring Validation Arguments 

- **test_ratio:** The fraction of the data used for the validation partition
- **test_metric:** The metric that is going to be used for validation
- **Validation on local updates:** Validation is performed after training, on the parameters obtained by training from the aggregated parameters
- **Validation on global updates:** Validation is performed on the aggregated parameters, before training.


You can display all the default metrics supported in Fed-BioMed. They are all based on sklearn metrics.

In [None]:
from fedbiomed.common.metrics import MetricTypes
# Display all default metrics supported by Fed-BioMed (shown as the cell output)
MetricTypes.get_all_metrics()

In [None]:
# Use 10% of the data as the validation partition
exp.set_test_ratio(0.1)
# Validate after training (on the locally updated parameters)
exp.set_test_on_local_updates(True)
# Validate the aggregated parameters before training
exp.set_test_on_global_updates(True)
# Predefined metric computed on the validation partition
exp.set_test_metric(MetricTypes.F1_SCORE)

Launch tensorboard

In [None]:
from fedbiomed.researcher.environ import environ
# TensorBoard results directory read from the researcher environment; used as TensorBoard's log directory
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']

In [None]:
# Load the TensorBoard notebook extension, which provides the `tensorboard` magic
%load_ext tensorboard

In [None]:
tensorboard --logdir "$tensorboard_dir"

Let's start the experiment.

By default, `exp.run()` does not return until all `round_limit` rounds have been completed for all the nodes.

In [None]:
# Run the experiment: returns only once all `round_limit` rounds are done
exp.run()

2023-09-05 14:43:51,807 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 14:43:51,809 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 14:43:51,811 fedbiomed INFO - Waiting for tasks
2023-09-05 14:43:51,815 fedbiomed INFO - Received request form node_41533df5-d07b-4027-a826-d1f67410d627
2023-09-05 14:43:51,816 fedbiomed INFO - Node agent created node_41533df5-d07b-4027-a826-d1f67410d627
2023-09-05 14:43:51,817 fedbiomed INFO - Waiting for tasks
2023-09-05 14:43:55,859 fedbiomed INFO - Received request form node_41533df5-d07b-4027-a826-d1f67410d627
2023-09-05 14:43:55,863 fedbiomed INFO - Node agent created node_41533df5-d07b-4027-a826-d1f67410d627
2023-09-05 14:43:55,869 fedbiomed INFO - Waiting for tasks
2023-09-05 14:44:51,812 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 14:44:51,814 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759

2023-09-05 14:58:53,884 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 14:58:53,886 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 14:58:53,889 fedbiomed INFO - Waiting for tasks
2023-09-05 14:58:55,895 fedbiomed INFO - Received request form node_41533df5-d07b-4027-a826-d1f67410d627
2023-09-05 14:58:55,898 fedbiomed INFO - Node agent created node_41533df5-d07b-4027-a826-d1f67410d627
2023-09-05 14:58:55,901 fedbiomed INFO - Waiting for tasks
2023-09-05 14:59:53,887 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 14:59:53,888 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 14:59:53,891 fedbiomed INFO - Waiting for tasks
2023-09-05 14:59:55,898 fedbiomed INFO - Received request form node_41533df5-d07b-4027-a826-d1f67410d627
2023-09-05 14:59:55,900 fedbiomed INFO - Node agent created node_41533df5-d07b-4027-a826-d1f67410

2023-09-05 15:15:24,180 fedbiomed INFO - Waiting for tasks
2023-09-05 15:15:25,446 fedbiomed INFO - [1mCRITICAL[0m
					[1m NODE[0m node_41533df5-d07b-4027-a826-d1f67410d627
					[1m MESSAGE:[0m Node stopped in signal_handler, probably by user decision (Ctrl C)[0m
-----------------------------------------------------------------
2023-09-05 15:15:25,663 fedbiomed INFO - [1mCRITICAL[0m
					[1m NODE[0m node_41533df5-d07b-4027-a826-d1f67410d627
					[1m MESSAGE:[0m Node stopped in signal_handler, probably by user decision (Ctrl C)[0m
-----------------------------------------------------------------
2023-09-05 15:15:53,953 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:15:53,955 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:15:53,957 fedbiomed INFO - Waiting for tasks
2023-09-05 15:16:53,958 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 1

2023-09-05 15:37:54,048 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:37:54,052 fedbiomed INFO - Waiting for tasks
2023-09-05 15:38:54,043 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:38:54,044 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:38:54,044 fedbiomed INFO - Waiting for tasks
2023-09-05 15:39:54,050 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:39:54,052 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:39:54,054 fedbiomed INFO - Waiting for tasks
2023-09-05 15:40:54,054 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:40:54,055 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 15:40:54,057 fedbiomed INFO - Waiting for tasks
2023-09-05 15:41:54,056 fedbiomed INFO - 

2023-09-05 16:08:54,171 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:08:54,174 fedbiomed INFO - Waiting for tasks
2023-09-05 16:09:54,175 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:09:54,178 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:09:54,184 fedbiomed INFO - Waiting for tasks
2023-09-05 16:10:54,180 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:10:54,183 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:10:54,185 fedbiomed INFO - Waiting for tasks
2023-09-05 16:11:54,182 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:11:54,184 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:11:54,185 fedbiomed INFO - Waiting for tasks
2023-09-05 16:12:54,195 fedbiomed INFO - 

2023-09-05 16:39:54,300 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:39:54,301 fedbiomed INFO - Waiting for tasks
2023-09-05 16:40:54,303 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:40:54,304 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:40:54,304 fedbiomed INFO - Waiting for tasks
2023-09-05 16:41:54,308 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:41:54,310 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:41:54,314 fedbiomed INFO - Waiting for tasks
2023-09-05 16:42:54,314 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:42:54,318 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 16:42:54,320 fedbiomed INFO - Waiting for tasks
2023-09-05 16:43:54,320 fedbiomed INFO - 

2023-09-05 17:10:54,440 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 17:10:54,441 fedbiomed INFO - Waiting for tasks
2023-09-05 17:11:54,443 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 17:11:54,445 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 17:11:54,448 fedbiomed INFO - Waiting for tasks
2023-09-05 17:12:54,447 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 17:12:54,448 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 17:12:54,449 fedbiomed INFO - Waiting for tasks
2023-09-05 17:13:54,451 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 17:13:54,452 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-05 17:13:54,453 fedbiomed INFO - Waiting for tasks
2023-09-05 17:14:54,457 fedbiomed INFO - 

2023-09-06 07:06:56,455 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:06:56,455 fedbiomed INFO - Waiting for tasks
2023-09-06 07:07:56,456 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:07:56,456 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:07:56,457 fedbiomed INFO - Waiting for tasks
2023-09-06 07:08:56,458 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:08:56,459 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:08:56,459 fedbiomed INFO - Waiting for tasks
2023-09-06 07:09:56,459 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:09:56,459 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:09:56,460 fedbiomed INFO - Waiting for tasks
2023-09-06 07:10:56,461 fedbiomed INFO - 

2023-09-06 07:38:02,586 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:38:02,588 fedbiomed INFO - Waiting for tasks
2023-09-06 07:39:02,589 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:39:02,592 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:39:02,595 fedbiomed INFO - Waiting for tasks
2023-09-06 07:40:02,592 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:40:02,594 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:40:02,596 fedbiomed INFO - Waiting for tasks
2023-09-06 07:41:02,598 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:41:02,600 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 07:41:02,603 fedbiomed INFO - Waiting for tasks
2023-09-06 07:42:02,600 fedbiomed INFO - 

2023-09-06 08:13:28,167 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 08:13:28,168 fedbiomed INFO - Waiting for tasks
2023-09-06 08:14:28,172 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 08:14:28,174 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 08:14:28,177 fedbiomed INFO - Waiting for tasks
2023-09-06 08:15:28,175 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 08:15:28,177 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 08:15:28,180 fedbiomed INFO - Waiting for tasks
2023-09-06 08:16:28,180 fedbiomed INFO - Received request form node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 08:16:28,184 fedbiomed INFO - Node agent created node_415010f8-df19-4a90-8059-cf76a759d3f5
2023-09-06 08:16:28,186 fedbiomed INFO - Waiting for tasks
2023-09-06 08:17:28,184 fedbiomed INFO - 



## 2. Training and Validation with sklearn Perceptron model


Now we will use the validation facility with a scikit-learn (sklearn) training plan.

In [None]:
from fedbiomed.common.training_plans import FedPerceptron


class SkLearnClassifierTrainingPlan(FedPerceptron):
    """Perceptron training plan fed with flattened MNIST images as numpy arrays."""

    def init_dependencies(self):
        """Return the extra import statements this training plan needs (torchvision)."""
        extra_imports = ["from torchvision import datasets, transforms"]
        return extra_imports

    def training_data(self):
        """Load the MNIST training split as a (n_samples, 784) array plus its labels.

        NOTE(review): torchvision applies ``transform`` only when the dataset is
        indexed; reading ``.data`` directly yields the raw pixels, so the
        normalization is not applied to the arrays below — confirm this is intended.
        """
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        mnist = datasets.MNIST(self.dataset_path, train=True, download=False, transform=normalize)

        n_pixels = 28 * 28
        features = mnist.data.numpy().reshape(-1, n_pixels)
        labels = mnist.targets.numpy()

        return DataManager(dataset=features, target=labels, shuffle=True)

It is also possible to define validation options directly in the training arguments (see the `training_args` of section 3.1 for an example).

In [None]:
# Model arguments: input/output sizes plus estimator hyper-parameters
# (NOTE(review): presumably forwarded to the underlying sklearn estimator).
model_args = dict(
    max_iter=1000,
    tol=1e-4,
    model='Perceptron',
    n_features=28 * 28,  # flattened 28x28 images
    n_classes=10,        # digits 0-9
    eta0=1e-6,
    random_state=1234,
    alpha=0.1,
)

training_args = dict(epochs=5)

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ['#MNIST', '#dataset']
rounds = 10

# Nodes whose datasets match `tags` take part in this experiment
exp = Experiment(
    tags=tags,
    model_args=model_args,
    training_plan_class=SkLearnClassifierTrainingPlan,
    training_args=training_args,
    round_limit=rounds,
    aggregator=FedAverage(),
    node_selection_strategy=None,
    tensorboard=True,
)

# Use 20% of the data for validation
exp.set_test_ratio(.2)
# Optional, to pick a specific predefined metric:
# exp.set_test_metric(MetricTypes.PRECISION, average='macro')
exp.set_test_on_global_updates(True)

In [None]:
# Presumably runs a single round (unlike `exp.run()`, which runs all `round_limit` rounds).
# NOTE(review): `increase=True` presumably allows raising round_limit if it is already reached — confirm in the Experiment API.
exp.run_once(increase=True)

Feel free to run other sample notebooks or try your own models :D

## 3. Validation facility using your own testing metric

If users want to define their own testing metric, they can do so by defining the `testing_step` method in the training plan.

`testing_step` is defined the same way as `training_step`:

When defining a `testing_step` method in the training plan, the user has to:
- predict classes or probabilities from the model
- compute one or several scalar metrics

The `testing_step` method can return a scalar, a list of scalars, or a dictionary mapping metric names to scalars (as in section 3.2). In TensorBoard, a list of scalars is displayed as the output of several metrics.


### 3.1 PyTorch Training Plan

Below we showcase an example of a TorchTrainingPlan with a `testing_step` computing 3 metrics: a negative log-likelihood loss, a cross-entropy loss, and a custom accuracy metric.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by Net.forward; keeps this cell self-contained
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms

# Here we define the model to be used.
# You can use any class name (here 'Net')
class MyTrainingPlanCM(TorchTrainingPlan):
    """Torch training plan on MNIST with a custom ``testing_step`` returning three metrics."""

    def init_model(self, model_args):
        """Build and return the network; ``model_args`` is forwarded to ``Net``."""
        return self.Net(model_args=model_args)

    def init_optimizer(self, optimizer_args):
        """Return an Adam optimizer over the model parameters.

        Args:
            optimizer_args: dict that must contain the learning rate under ``"lr"``.
        """
        return torch.optim.Adam(self.model().parameters(), lr=optimizer_args["lr"])

    def init_dependencies(self):
        """Return the extra import statements this training plan relies on.

        NOTE(review): presumably executed together with the plan's source on the
        node, so that ``datasets``/``transforms`` exist in ``training_data``.
        """
        deps = ["from torchvision import datasets, transforms"]
        return deps

    class Net(nn.Module):
        """Two conv layers and two fully connected layers, returning log-softmax over 10 classes.

        ``fc1`` expects 9216 = 64 * 12 * 12 features, i.e. single-channel 28x28
        inputs (two unpadded 3x3 convs: 28 -> 26 -> 24, then 2x2 max-pooling -> 12).
        ``model_args`` is accepted but not used.
        """

        def __init__(self, model_args):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            """Return per-class log-probabilities of shape (batch, 10)."""
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)

            output = F.log_softmax(x, dim=1)
            return output

    def training_data(self):
        """Load the MNIST training split from ``self.dataset_path`` wrapped in a shuffling DataManager."""
        # Commonly used MNIST mean/std normalization constants
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        train_kwargs = {'shuffle': True}
        return DataManager(dataset=dataset1, **train_kwargs)

    def training_step(self, data, target):
        """Return the negative log-likelihood loss of the model on one batch."""
        output = self.model().forward(data)
        loss = torch.nn.functional.nll_loss(output, target)
        return loss

    def testing_step(self, data, target):
        """Compute three validation metrics on one batch.

        Returns:
            list ``[nll_loss, cross_entropy_loss, accuracy]`` of scalar tensors;
            each entry is reported as a separate metric.
        """
        output = self.model().forward(data)

        # negative log-likelihood loss
        loss1 = torch.nn.functional.nll_loss(output, target)

        # cross entropy: cross_entropy applies log_softmax internally. Since
        # ``output`` is already log-softmax (and log_softmax is idempotent),
        # this equals loss1 up to floating-point error.
        loss2 = torch.nn.functional.cross_entropy(output, target)

        # accuracy: fraction of samples whose most likely class matches the target
        predicted = output.argmax(dim=1)
        n_correct = (predicted == target).sum()
        accuracy = n_correct / len(target)

        # Returning results as list
        return [loss1, loss2, accuracy]

In [None]:
model_args = {}

# Validation is configured directly in the training arguments here
# (test_ratio / test_on_local_updates / test_on_global_updates).
training_args = dict(
    loader_args=dict(batch_size=48),
    optimizer_args=dict(lr=1e-3),
    epochs=1,
    dry_run=False,
    batch_maxnum=100,  # Fast pass for development: only use (batch_maxnum * batch_size) samples
    test_ratio=.3,
    test_on_local_updates=True,
    test_on_global_updates=True,
)

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ['#MNIST', '#dataset']
rounds = 2

# Validation settings come from training_args, so no exp.set_test_* call is made here.
exp = Experiment(
    tags=tags,
    model_args=model_args,
    training_plan_class=MyTrainingPlanCM,
    training_args=training_args,
    round_limit=rounds,
    aggregator=FedAverage(),
    node_selection_strategy=None,
    tensorboard=True,
)

In [None]:
# Run all `round_limit` rounds; validation is configured through `training_args`
exp.run()

### 3.2 Sklearn Training Plan

Below we showcase an example of a SklearnTrainingPlan with a `testing_step` computing two metrics: the hinge loss and the accuracy on images of the digit 1.

In [None]:
from fedbiomed.common.training_plans import FedPerceptron
from fedbiomed.common.data import DataManager
import numpy as np
from sklearn.metrics import hinge_loss


class SkLearnClassifierTrainingPlan(FedPerceptron):
    """Perceptron training plan on flattened MNIST with a custom ``testing_step``.

    ``testing_step`` reports the multiclass hinge loss and the accuracy
    restricted to images of the digit 1.
    """

    def init_dependencies(self):
        """Return the extra import statements this training plan relies on.

        NOTE(review): presumably executed together with the plan's source on the
        node. numpy is listed because ``compute_accuracy_for_specific_digit`` uses ``np``.
        """
        return ["from torchvision import datasets, transforms",
                "from torch.utils.data import DataLoader",
                "from sklearn.metrics import hinge_loss",
                "import numpy as np"]

    def compute_accuracy_for_specific_digit(self, data, target, digit: int):
        """Return the fraction of samples labelled ``digit`` that the model predicts as ``digit``.

        Args:
            data: feature array, indexed row-wise by a boolean mask.
            target: labels aligned with ``data`` (a trailing singleton axis is tolerated).
            digit: the class label to evaluate.

        Returns:
            The per-class accuracy, or ``float('nan')`` when no sample carries
            ``digit``. The early return avoids calling ``predict`` on an empty
            array (which sklearn rejects) and dividing by zero.
        """
        # ravel (rather than squeeze) keeps a 1-D mask even for a single sample
        idx_data_equal_to_digit = (np.ravel(target) == digit)
        n_samples_of_digit = np.sum(idx_data_equal_to_digit)
        if n_samples_of_digit == 0:
            return float('nan')

        predicted = self.model().predict(data[idx_data_equal_to_digit])
        well_predicted_label = np.sum(predicted == digit) / n_samples_of_digit
        return well_predicted_label

    def training_data(self):
        """Load the MNIST training split as a (n_samples, 784) array plus its labels, shuffled.

        NOTE(review): torchvision applies ``transform`` only when the dataset is
        indexed; reading ``.data`` directly yields the raw pixels, so the
        normalization is not applied to ``X_train`` — confirm this is intended.
        """
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        dataset = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)

        train_kwargs = {'shuffle': True}
        X_train = dataset.data.numpy()
        X_train = X_train.reshape(-1, 28*28)
        Y_train = dataset.targets.numpy()

        # Forward the loader options (shuffling), as the other training plans in this notebook do
        return DataManager(dataset=X_train, target=Y_train, **train_kwargs)

    def testing_step(self, data, target):
        """Compute validation metrics for the current model.

        Returns:
            dict mapping metric names to values: ``'Hinge Loss'`` (multiclass
            hinge loss) and ``'Well Predicted Label 1'`` (accuracy on digit 1).
        """
        model = self.model()
        # hinge loss from the decision scores (one column per class in the multiclass case)
        distance_from_hyperplan = model.decision_function(data)
        # Passing the model's classes keeps hinge_loss working when this test split
        # lacks some digit (sklearn otherwise raises a ValueError).
        loss = hinge_loss(target, distance_from_hyperplan, labels=model.classes_)

        # get the accuracy only on images representing digit 1
        well_predicted_label_1 = self.compute_accuracy_for_specific_digit(data, target, 1)

        # Returning results as dict
        return {'Hinge Loss': loss, 'Well Predicted Label 1' : well_predicted_label_1}

In [None]:
# Model arguments: input/output sizes plus estimator hyper-parameters
# (NOTE(review): presumably forwarded to the underlying sklearn estimator).
model_args = dict(
    max_iter=1000,
    tol=1e-4,
    model='Perceptron',
    n_features=28 * 28,  # flattened 28x28 images
    n_classes=10,        # digits 0-9
    eta0=1e-6,
    random_state=1234,
    alpha=0.1,
)

training_args = dict(epochs=5)


In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ['#MNIST', '#dataset']
rounds = 10

# Nodes whose datasets match `tags` take part in this experiment
exp = Experiment(
    tags=tags,
    model_args=model_args,
    training_plan_class=SkLearnClassifierTrainingPlan,
    training_args=training_args,
    round_limit=rounds,
    aggregator=FedAverage(),
    node_selection_strategy=None,
    tensorboard=True,
)

# Use 20% of the data for validation, on both global and local updates
exp.set_test_ratio(.2)
# Optional, to pick a specific predefined metric:
# exp.set_test_metric(MetricTypes.PRECISION, average='macro')
exp.set_test_on_global_updates(True)
exp.set_test_on_local_updates(True)

In [None]:
# Run the experiment up to `round_limit` rounds.
# NOTE(review): `increase=True` presumably only matters when more rounds than round_limit are requested — confirm.
exp.run(increase=True)