# Using Differential Privacy with Opacus on Fed-BioMed

In this notebook we show how `opacus` (https://opacus.ai/) can be used in Fed-BioMed. Opacus is a library for training PyTorch models with differential privacy. We will train the basic MNIST example using two nodes.

### Setting the node up
Each node must be configured beforehand:
1. `{FEDBIOMED_DIR}/scripts/fedbiomed_run node dataset add`
  * Select option 2 (default)
  * Confirm the default tags by typing "y" and pressing ENTER
  * Pick the folder where MNIST is downloaded (this is due to the torch issue https://github.com/pytorch/vision/issues/3549)
  * The data should now be added (if you get a warning saying that the data must be unique, it means the dataset has already been added)

2. Check that your data has been added by executing `{FEDBIOMED_DIR}/scripts/fedbiomed_run node dataset list`
3. Run the node using `{FEDBIOMED_DIR}/scripts/fedbiomed_run node start`. Wait until you see `Starting task manager`, which means the node is online.

## Defining a Training Plan and Parameters

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms

# Here we define the training plan to be used in the experiment.
class MyTrainingPlan(TorchTrainingPlan):
    def init_dependencies(self):
        # Imports the node needs to run this training plan
        deps = ["from torchvision import datasets, transforms",
                "import torch.nn.functional as F"]
        return deps

    def init_model(self):
        # Small CNN for 28x28 grayscale MNIST digits.
        # It contains no BatchNorm layer, which Opacus does not support.
        model = nn.Sequential(nn.Conv2d(1, 32, 3, 1),
                              nn.ReLU(),
                              nn.Conv2d(32, 64, 3, 1),
                              nn.ReLU(),
                              nn.MaxPool2d(2),
                              nn.Dropout(0.25),
                              nn.Flatten(),
                              nn.Linear(9216, 128),
                              nn.ReLU(),
                              nn.Dropout(0.5),
                              nn.Linear(128, 10),
                              nn.LogSoftmax(dim=1))
        return model

    def training_data(self):
        # Wrap the node's local MNIST training set (already on disk) in a DataManager
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        loader_arguments = {'shuffle': True}
        return DataManager(dataset1, **loader_arguments)

    def training_step(self, data, target):
        # The model outputs log-probabilities, so use the negative log-likelihood loss
        output = self.model()(data)
        loss = F.nll_loss(output, target)
        return loss
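The `9216` input features of the first `nn.Linear` layer follow from the shape arithmetic of the convolution and pooling layers applied to 28×28 MNIST images. A small pure-Python check of that arithmetic (`conv2d_out` is a hypothetical helper implementing the standard output-size formula for dilation 1):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Output side length of a square Conv2d / MaxPool2d layer (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                    # MNIST images are 28x28, 1 channel
side = conv2d_out(side, kernel=3)            # Conv2d(1, 32, 3, 1)  -> 26
side = conv2d_out(side, kernel=3)            # Conv2d(32, 64, 3, 1) -> 24
side = conv2d_out(side, kernel=2, stride=2)  # MaxPool2d(2)         -> 12
flat_features = 64 * side * side             # channels * height * width
print(flat_features)  # 9216
```

Changing the image size or the convolution stack changes this number, so the first `nn.Linear` must be updated accordingly.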


The experiment takes two groups of arguments:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). It is passed to the model class on the node side and is empty in this example.
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). It is passed to the training routine on the node side. The differential privacy parameters also go here.

**NOTE:** typos and/or missing positional (required) arguments will raise an error. 🤓

In the cell below, we define `dp_args` inside the `training_args` dictionary. Based on these parameters, each node trains its model with Opacus differential privacy.

* `sigma` (Opacus `noise_multiplier`): the ratio of the standard deviation of the Gaussian noise to the L2 sensitivity of the function to which the noise is added, i.e. how much noise to add.

* `clip` (Opacus `max_grad_norm`): the maximum L2 norm of the per-sample gradients. Any per-sample gradient with a larger norm is scaled down to this norm.

* `type`: the differential privacy type, either `local` or `central`.
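To make `sigma` and `clip` concrete, here is a minimal pure-Python sketch of what a single DP-SGD step does to a batch of per-sample gradients, modelled on Opacus's flat clipping (clip factor `min(1, clip / (norm + 1e-6))`, noise standard deviation `sigma * clip` added to the summed gradient). It is a simplified illustration, not Opacus's implementation: gradients are plain lists of floats, there is no Poisson sampling or privacy accounting, and the noisy sum is divided by the actual batch size rather than the expected one. `dp_sgd_gradient` is a hypothetical name.

```python
import math
import random


def dp_sgd_gradient(per_sample_grads, clip, sigma, rng=None):
    """Return a clipped, noised, averaged gradient (simplified DP-SGD step).

    per_sample_grads: list of per-sample gradients, each a list of floats
    clip:  maximum L2 norm of each per-sample gradient (`max_grad_norm`)
    sigma: noise multiplier (`noise_multiplier`)
    """
    rng = rng or random.Random()
    batch_size = len(per_sample_grads)
    summed = [0.0] * len(per_sample_grads[0])

    for grad in per_sample_grads:
        norm = math.sqrt(sum(g * g for g in grad))
        # Scale the gradient down so that its L2 norm is at most `clip`
        factor = min(1.0, clip / (norm + 1e-6))
        for i, g in enumerate(grad):
            summed[i] += g * factor

    # One sample changes the sum by at most `clip` (its L2 sensitivity),
    # so the Gaussian noise std is sigma * clip
    std = sigma * clip
    noisy = [s + rng.gauss(0.0, std) for s in summed]

    # Average over the batch, as for a mean-reduced loss
    return [n / batch_size for n in noisy]


grads = [[3.0, 4.0], [0.0, 0.1]]
print(dp_sgd_gradient(grads, clip=1.0, sigma=0.0))  # roughly [0.3, 0.45]
```

With the values used in the next cell (`sigma=0.4`, `clip=0.005`), the noise added to the summed gradient has a standard deviation of `0.4 * 0.005 = 0.002`, while every per-sample gradient is clipped to norm at most `0.005`.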

In [2]:
model_args = {}

training_args = {
    'loader_args': {'batch_size': 48},
    'optimizer_args': {
        'lr': 1e-3
    },
    'epochs': 1,
    'dry_run': False,
    # DP arguments for differential privacy
    'dp_args': {
        'type': 'local',
        'sigma': 0.4,
        'clip': 0.005
    },
    'batch_maxnum': 50  # Fast pass for development: only use (batch_maxnum * batch_size) samples
}
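Since typos in `dp_args` only surface once the nodes receive the request, a researcher may want to check the dictionary locally first. The helper below is a hypothetical sketch, not part of Fed-BioMed: it only mirrors the constraints described above (the three keys, `type` being `local` or `central`), and the numeric bounds (`sigma >= 0`, `clip > 0`) are our own assumptions.

```python
def check_dp_args(dp_args):
    """Hypothetical sanity check for a `dp_args` dictionary (not a Fed-BioMed API)."""
    missing = {'type', 'sigma', 'clip'} - dp_args.keys()
    if missing:
        raise ValueError(f"missing dp_args keys: {sorted(missing)}")
    if dp_args['type'] not in ('local', 'central'):
        raise ValueError("dp_args['type'] must be 'local' or 'central'")
    if not isinstance(dp_args['sigma'], (int, float)) or dp_args['sigma'] < 0:
        raise ValueError("dp_args['sigma'] must be a non-negative number")
    if not isinstance(dp_args['clip'], (int, float)) or dp_args['clip'] <= 0:
        raise ValueError("dp_args['clip'] must be a positive number")
    return True


print(check_dp_args({'type': 'local', 'sigma': 0.4, 'clip': 0.005}))  # True
```

Note also that with `batch_maxnum=50` and `batch_size=48`, each node processes at most `50 * 48 = 2400` samples per round, which is the `/2400` total shown in the training logs below.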

## Declare and run the experiment

In [3]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ['#MNIST', '#dataset']
rounds = 3

exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=MyTrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

2024-04-03 13:33:16,159 fedbiomed INFO - Starting researcher service...

2024-04-03 13:33:16,183 fedbiomed INFO - Waiting 3s for nodes to connect...

2024-04-03 13:33:17,383 fedbiomed DEBUG - Node: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b polling for the tasks

2024-04-03 13:33:17,385 fedbiomed DEBUG - Node: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac polling for the tasks

2024-04-03 13:33:19,187 fedbiomed INFO - Updating training data. This action will update FederatedDataset, and the nodes that will participate to the experiment.

2024-04-03 13:33:19,199 fedbiomed DEBUG - Node: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac polling for the tasks

2024-04-03 13:33:19,200 fedbiomed DEBUG - Node: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b polling for the tasks

2024-04-03 13:33:19,203 fedbiomed INFO - Node selected for training -> NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b

2024-04-03 13:33:19,204 fedbiomed INFO - Node selected for training -> NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac

2024-04-03 13:33:19,208 fedbiomed DEBUG - Model file has been saved: /home/ybouilla/Documents/github/fedbiomed/var/experiments/Experiment_0015/model_e02e0d96-a047-45b6-9c0a-8682f79247f9.py

Secure RNG turned off. This is perfectly fine for experimentation as it allows for much faster training performance, but remember to turn it on and retrain one last time before production with ``secure_mode`` turned on.


Let's start the experiment.

By default, `exp.run()` does not return until all `rounds` have been completed by all the nodes.

In [4]:
exp.run()

2024-04-03 13:33:19,226 fedbiomed INFO - Sampled nodes in round 0 ['NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b', 'NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac']

2024-04-03 13:33:19,229 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-04-03 13:33:19,231 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-04-03 13:33:19,287 fedbiomed DEBUG - Node: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac polling for the tasks

2024-04-03 13:33:19,288 fedbiomed DEBUG - Node: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b polling for the tasks

2024-04-03 13:33:21,098 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 1 Epoch: 1 | Iteration: 1/50 (2%) | Samples: 41/2400
 					 Loss: [1m2.331853[0m 
					 ---------

2024-04-03 13:33:21,216 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 1 Epoch: 1 | Iteration: 1/50 (2%) | Samples: 47/2400
 					 Loss: [1m2.299112[0m 
					 ---------

2024-04-03 13:33:35,462 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 1 Epoch: 1 | Iteration: 10/50 (20%) | Samples: 457/2400
 					 Loss: [1m2.288610[0m 
					 ---------

2024-04-03 13:33:36,769 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 1 Epoch: 1 | Iteration: 10/50 (20%) | Samples: 434/2400
 					 Loss: [1m2.229743[0m 
					 ---------

2024-04-03 13:33:51,107 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 1 Epoch: 1 | Iteration: 20/50 (40%) | Samples: 903/2400
 					 Loss: [1m2.209047[0m 
					 ---------

2024-04-03 13:33:51,135 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 1 Epoch: 1 | Iteration: 20/50 (40%) | Samples: 871/2400
 					 Loss: [1m2.174485[0m 
					 ---------

2024-04-03 13:34:05,048 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 1 Epoch: 1 | Iteration: 30/50 (60%) | Samples: 1346/2400
 					 Loss: [1m2.162850[0m 
					 ---------

2024-04-03 13:34:06,069 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 1 Epoch: 1 | Iteration: 30/50 (60%) | Samples: 1409/2400
 					 Loss: [1m2.168503[0m 
					 ---------

2024-04-03 13:34:21,529 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 1 Epoch: 1 | Iteration: 40/50 (80%) | Samples: 1838/2400
 					 Loss: [1m2.000157[0m 
					 ---------

2024-04-03 13:34:21,885 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 1 Epoch: 1 | Iteration: 40/50 (80%) | Samples: 1906/2400
 					 Loss: [1m2.075603[0m 
					 ---------

2024-04-03 13:34:36,096 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 1 Epoch: 1 | Iteration: 50/50 (100%) | Samples: 2303/2303
 					 Loss: [1m1.911565[0m 
					 ---------

2024-04-03 13:34:36,298 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 1 Epoch: 1 | Iteration: 50/50 (100%) | Samples: 2396/2396
 					 Loss: [1m1.868311[0m 
					 ---------

2024-04-03 13:34:36,340 fedbiomed INFO - Nodes that successfully reply in round 0 ['NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b', 'NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac']

2024-04-03 13:34:36,345 fedbiomed INFO - Sampled nodes in round 1 ['NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b', 'NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac']

2024-04-03 13:34:36,350 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-04-03 13:34:36,352 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-04-03 13:34:36,409 fedbiomed DEBUG - Node: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b polling for the tasks

2024-04-03 13:34:36,410 fedbiomed DEBUG - Node: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac polling for the tasks

2024-04-03 13:34:37,853 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 2 Epoch: 1 | Iteration: 1/50 (2%) | Samples: 49/2400
 					 Loss: [1m1.981983[0m 
					 ---------

2024-04-03 13:34:38,480 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 2 Epoch: 1 | Iteration: 1/50 (2%) | Samples: 45/2400
 					 Loss: [1m1.945740[0m 
					 ---------

2024-04-03 13:34:49,596 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 2 Epoch: 1 | Iteration: 10/50 (20%) | Samples: 494/2400
 					 Loss: [1m1.710354[0m 
					 ---------

2024-04-03 13:34:53,897 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 2 Epoch: 1 | Iteration: 10/50 (20%) | Samples: 469/2400
 					 Loss: [1m1.900555[0m 
					 ---------

2024-04-03 13:35:03,463 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 2 Epoch: 1 | Iteration: 20/50 (40%) | Samples: 956/2400
 					 Loss: [1m1.693094[0m 
					 ---------

2024-04-03 13:35:08,025 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 2 Epoch: 1 | Iteration: 20/50 (40%) | Samples: 1010/2400
 					 Loss: [1m1.734439[0m 
					 ---------

2024-04-03 13:35:17,905 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 2 Epoch: 1 | Iteration: 30/50 (60%) | Samples: 1443/2400
 					 Loss: [1m1.647621[0m 
					 ---------

2024-04-03 13:35:23,206 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 2 Epoch: 1 | Iteration: 30/50 (60%) | Samples: 1476/2400
 					 Loss: [1m1.569272[0m 
					 ---------

2024-04-03 13:35:33,140 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 2 Epoch: 1 | Iteration: 40/50 (80%) | Samples: 1894/2400
 					 Loss: [1m1.296453[0m 
					 ---------

2024-04-03 13:35:38,601 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 2 Epoch: 1 | Iteration: 40/50 (80%) | Samples: 1973/2400
 					 Loss: [1m1.475632[0m 
					 ---------

2024-04-03 13:35:46,590 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 2 Epoch: 1 | Iteration: 50/50 (100%) | Samples: 2387/2387
 					 Loss: [1m1.144556[0m 
					 ---------

2024-04-03 13:35:46,971 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 2 Epoch: 1 | Iteration: 50/50 (100%) | Samples: 2493/2493
 					 Loss: [1m1.439208[0m 
					 ---------

2024-04-03 13:35:47,019 fedbiomed INFO - Nodes that successfully reply in round 1 ['NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b', 'NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac']

2024-04-03 13:35:47,024 fedbiomed INFO - Sampled nodes in round 2 ['NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b', 'NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac']

2024-04-03 13:35:47,030 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-04-03 13:35:47,032 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					[1m Request: [0m: TRAIN
 -----------------------------------------------------------------

2024-04-03 13:35:47,099 fedbiomed DEBUG - Node: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b polling for the tasks

2024-04-03 13:35:47,102 fedbiomed DEBUG - Node: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac polling for the tasks

2024-04-03 13:35:48,857 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 3 Epoch: 1 | Iteration: 1/50 (2%) | Samples: 62/2400
 					 Loss: [1m1.584836[0m 
					 ---------

2024-04-03 13:35:49,164 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 3 Epoch: 1 | Iteration: 1/50 (2%) | Samples: 43/2400
 					 Loss: [1m1.528233[0m 
					 ---------

2024-04-03 13:36:04,778 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 3 Epoch: 1 | Iteration: 10/50 (20%) | Samples: 455/2400
 					 Loss: [1m1.245793[0m 
					 ---------

2024-04-03 13:36:05,614 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 3 Epoch: 1 | Iteration: 10/50 (20%) | Samples: 459/2400
 					 Loss: [1m1.171741[0m 
					 ---------

2024-04-03 13:36:21,691 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 3 Epoch: 1 | Iteration: 20/50 (40%) | Samples: 914/2400
 					 Loss: [1m1.048743[0m 
					 ---------

2024-04-03 13:36:22,029 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 3 Epoch: 1 | Iteration: 20/50 (40%) | Samples: 978/2400
 					 Loss: [1m1.151399[0m 
					 ---------

2024-04-03 13:36:37,169 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 3 Epoch: 1 | Iteration: 30/50 (60%) | Samples: 1419/2400
 					 Loss: [1m0.952462[0m 
					 ---------

2024-04-03 13:36:38,318 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 3 Epoch: 1 | Iteration: 30/50 (60%) | Samples: 1473/2400
 					 Loss: [1m0.926908[0m 
					 ---------

2024-04-03 13:36:53,786 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 3 Epoch: 1 | Iteration: 40/50 (80%) | Samples: 1869/2400
 					 Loss: [1m1.245580[0m 
					 ---------

2024-04-03 13:36:53,933 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 3 Epoch: 1 | Iteration: 40/50 (80%) | Samples: 1908/2400
 					 Loss: [1m1.372809[0m 
					 ---------

2024-04-03 13:37:07,889 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac 
					 Round 3 Epoch: 1 | Iteration: 50/50 (100%) | Samples: 2303/2303
 					 Loss: [1m0.875097[0m 
					 ---------

2024-04-03 13:37:08,014 fedbiomed INFO - [1mTRAINING[0m 
					 NODE_ID: NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b 
					 Round 3 Epoch: 1 | Iteration: 50/50 (100%) | Samples: 2387/2387
 					 Loss: [1m1.509128[0m 
					 ---------

2024-04-03 13:37:08,049 fedbiomed INFO - Nodes that successfully reply in round 2 ['NODE_a16525b2-5b6a-42cf-9cef-ff746b2a6d4b', 'NODE_90a7e795-9022-4bf0-b6c6-f1a2394b6dac']

3

Save the trained model to a file.

In [5]:
exp.training_plan().export_model('./trained_model')

Federated parameters for each round are available in `exp.aggregated_params()`, indexed from 0 to `rounds - 1`.

For example, you can view the federated parameters for the last round of the experiment:

In [6]:
print("\nList the training rounds : ", exp.aggregated_params().keys())

print("\nAccess the federated params for the last training round :")
print("\t- parameter data: ", exp.aggregated_params()[rounds - 1]['params'].keys())


List the training rounds :  dict_keys([0, 1, 2])

Access the federated params for the last training round :
	- parameter data:  dict_keys(['0.weight', '0.bias', '2.weight', '2.bias', '7.weight', '7.bias', '10.weight', '10.bias'])
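These per-round parameters are produced by the `FedAverage` aggregator. As a rough illustration of federated averaging, here is a minimal pure-Python sketch that averages each named parameter across nodes, weighting every node by its number of training samples. This is the classic FedAvg weighting; we assume, rather than quote from Fed-BioMed's source, that `FedAverage` weights nodes this way. Parameters are plain lists of floats instead of tensors, and `federated_average` is a hypothetical helper, not a Fed-BioMed function.

```python
def federated_average(node_params, n_samples):
    """Sample-weighted average of per-node parameters.

    node_params: list of dicts mapping parameter name -> list of floats
    n_samples:   number of training samples used by each node
    """
    total = sum(n_samples)
    averaged = {}
    for name in node_params[0]:
        length = len(node_params[0][name])
        # Each coordinate is the weighted mean over nodes, weights n / total
        averaged[name] = [
            sum(p[name][i] * n / total for p, n in zip(node_params, n_samples))
            for i in range(length)
        ]
    return averaged


node_a = {'0.bias': [1.0, 2.0]}
node_b = {'0.bias': [3.0, 6.0]}
print(federated_average([node_a, node_b], n_samples=[100, 300]))  # {'0.bias': [2.5, 5.0]}
```

A node holding three times as much data pulls the average three times as strongly; with equal sample counts this reduces to a plain mean.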


## Testing

We define a small testing routine that computes the average loss and the accuracy on the test dataset.

In [7]:
import torch
import torch.nn.functional as F


def testing_Accuracy(model, data_loader):
    """Return (average test loss, accuracy in %) of `model` on `data_loader`."""
    model.eval()
    test_loss = 0
    correct = 0
    device = 'cpu'

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(data_loader.dataset)
    accuracy = 100 * correct / len(data_loader.dataset)

    return test_loss, accuracy

In [8]:
from torchvision import datasets, transforms
from fedbiomed.researcher.environ import environ
import os

local_mnist = os.path.join(environ['TMP_DIR'], 'local_mnist')

transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

test_set = datasets.MNIST(root = local_mnist, download = True, train = False, transform = transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw/train-images-idx3-ubyte.gz

Extracting /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw/train-images-idx3-ubyte.gz to /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw/train-labels-idx1-ubyte.gz

Extracting /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz

Extracting /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz

Extracting /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/ybouilla/Documents/github/fedbiomed/var/tmp/local_mnist/MNIST/raw



In [9]:
fed_model = exp.training_plan().model()
fed_model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])

acc_federated = testing_Accuracy(fed_model, test_loader)

print('\nAccuracy federated training:  {:.4f}'.format(acc_federated[1]))

print('\nError federated training:  {:.4f}'.format(acc_federated[0]))


Accuracy federated training:  68.3900

Error federated training:  1.0612
