# Using Differential Privacy with OPACUS on Fed-BioMed

In this notebook we show how `opacus` (https://opacus.ai/) can be used in Fed-BioMed. Opacus is a library for training PyTorch models with differential privacy. We will train the basic MNIST example using two nodes.

## Setting up Fed-BioMed Environment

### Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

### Setting the node up
A node must be configured before running the experiment:
1. `./scripts/fedbiomed_run node add`
  * Select option 2 (default)
  * Confirm default tags by hitting "y" and ENTER
  * Pick the folder where MNIST is downloaded (this is due to torch issue https://github.com/pytorch/vision/issues/3549)
  * The data must have been added (a warning saying that data must be unique means it has already been added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node run`. Wait until you get `Starting task manager`, which means the node is online.

## Define a model and parameters

Declare a `MyTrainingPlan` class, built on `torch.nn`, to send to the node for training.

In the cells below, we define the model using opacus for differential privacy. For this example, we use the `make_private` method of the `PrivacyEngine` class from `opacus.privacy_engine`. Two hyperparameters should be defined:
* `noise_multiplier`: the ratio of the standard deviation of the Gaussian noise to the L2-sensitivity of the function to which the noise is added (i.e. how much noise to add).
* `max_grad_norm`: the maximum norm of the per-sample gradients. Any gradient with a norm higher than this is clipped to this value.

Note that, in order to use the opacus `PrivacyEngine` class, the training plan must define a `model`, a `dataloader` and an `optimizer` as attributes.
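
To make the two hyperparameters concrete, here is a minimal pure-Python sketch of the clip-then-add-noise step that differentially private SGD performs on per-sample gradients. It is not Opacus code: the function names `clip_gradient` and `private_mean_gradient` are hypothetical, gradients are plain lists of floats, and the noise standard deviation is assumed to be `noise_multiplier * max_grad_norm`.

```python
import math
import random


def clip_gradient(grad, max_grad_norm):
    """Scale one per-sample gradient so its L2 norm is at most max_grad_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    # Small constant avoids division by zero for an all-zero gradient.
    scale = min(1.0, max_grad_norm / (norm + 1e-6))
    return [g * scale for g in grad]


def private_mean_gradient(per_sample_grads, noise_multiplier, max_grad_norm, rng):
    """Clip every per-sample gradient, sum, add Gaussian noise, then average."""
    clipped = [clip_gradient(g, max_grad_norm) for g in per_sample_grads]
    dim = len(clipped[0])
    summed = [sum(g[i] for g in clipped) for i in range(dim)]
    # Assumption: noise standard deviation = noise_multiplier * max_grad_norm.
    std = noise_multiplier * max_grad_norm
    noisy = [s + rng.gauss(0.0, std) for s in summed]
    return [v / len(per_sample_grads) for v in noisy]


rng = random.Random(0)
grads = [[3.0, 4.0], [0.3, 0.4]]      # L2 norms 5.0 and 0.5
print(clip_gradient(grads[0], 1.0))   # rescaled so that its norm is about 1.0
print(clip_gradient(grads[1], 1.0))   # norm below the bound, left unchanged
print(private_mean_gradient(grads, 0.0, 1.0, rng))   # clipping only, no noise
print(private_mean_gradient(grads, 1.1, 1.0, rng))   # clipping plus noise
```

With `noise_multiplier = 0.0` only the clipping has an effect; any positive value perturbs the averaged gradient.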

In [1]:
from fedbiomed.researcher.requests import Requests
req = Requests()
req.list(verbose=True)

2022-03-31 14:51:31,828 fedbiomed INFO - Component environment:
2022-03-31 14:51:31,829 fedbiomed INFO - type = ComponentType.RESEARCHER
2022-03-31 14:51:32,405 fedbiomed INFO - Messaging researcher_2a29744e-4430-4e85-9661-1c02afbdd825 successfully connected to the message broker, object = <fedbiomed.common.messaging.Messaging object at 0x1116e54c0>
2022-03-31 14:51:32,425 fedbiomed INFO - Listing available datasets in all nodes... 
2022-03-31 14:51:32,442 fedbiomed INFO - log from: node_b472c750-0198-450d-85cf-8faddc7f54e0 / DEBUG - Message received: {'researcher_id': 'researcher_2a29744e-4430-4e85-9661-1c02afbdd825', 'command': 'list'}
2022-03-31 14:51:32,443 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - Message received: {'researcher_id': 'researcher_2a29744e-4430-4e85-9661-1c02afbdd825', 'command': 'list'}
2022-03-31 14:51:32,451 fedbiomed INFO - log from: node_278d405c-015c-4089-a6a4-25506c07fd24 / DEBUG - Message received: {'researcher_id': 'resea

{'node_b472c750-0198-450d-85cf-8faddc7f54e0': [{'name': 'mednist',
   'data_type': 'images',
   'tags': ['mednist'],
   'description': 'bla',
   'shape': [16954, 3, 64, 64]}],
 'node_aabe8200-9df6-48e7-a0c8-820be37261e2': [{'name': 'mnist',
   'data_type': 'images',
   'tags': ['mnist'],
   'description': 'bla',
   'shape': [18000, 3, 64, 64]}],
 'node_278d405c-015c-4089-a6a4-25506c07fd24': [{'name': 'mnist',
   'data_type': 'images',
   'tags': ['mnist'],
   'description': 'bla',
   'shape': [18000, 3, 64, 64]}]}

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
from fedbiomed.common.training_plans import TorchTrainingPlan
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import decollate_batch
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from monai.transforms import (
    Activations,
    AddChannel,
    AsDiscrete,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    EnsureType,
)
from monai.utils import set_determinism



# Here we define the model to be used. 
# You can use any class name (here 'MyTrainingPlan')
class MyTrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args: dict = {}):
        super(MyTrainingPlan, self).__init__(model_args)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        # In this case, we need the torch DataLoader class
        # Since we will train on MedNIST images, we need the MONAI network, transforms and utilities
        deps = ["import numpy as np",
                "import os",
                "from torch.utils.data import DataLoader",
                "from monai.apps import download_and_extract",
                "from monai.config import print_config",
                "from monai.data import decollate_batch",
                "from monai.metrics import ROCAUCMetric",
                "from monai.networks.nets import DenseNet121",
                "from monai.transforms import ( Activations, AddChannel, AsDiscrete, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, EnsureType, )",
                "from monai.utils import set_determinism",
                "from opacus import PrivacyEngine",]
        self.add_dependency(deps)
         
        self.num_class =  model_args['num_class']
        
        self.model = DenseNet121(spatial_dims=2, in_channels=1,
                    out_channels = self.num_class, norm=('GROUP', {'num_groups': 8}))
        
        self.loss_function = torch.nn.CrossEntropyLoss()
        
        self.noise_multiplier = model_args['noise_multiplier']
        self.max_grad_norm = model_args['max_grad_norm']
        

    def forward(self, x):
        return self.model(x)

    class MedNISTDataset(torch.utils.data.Dataset):
            def __init__(self, image_files, labels, transforms):
                self.image_files = image_files
                self.labels = labels
                self.transforms = transforms

            def __len__(self):
                return len(self.image_files)

            def __getitem__(self, index):
                return self.transforms(self.image_files[index]), self.labels[index]
    
    def parse_data(self, path):
        print(self.dataset_path)
        class_names = sorted(x for x in os.listdir(path)
                     if os.path.isdir(os.path.join(path, x)))
        num_class = len(class_names)
        image_files = [
                        [
                            os.path.join(path, class_names[i], x)
                            for x in os.listdir(os.path.join(path, class_names[i]))
                        ]
                        for i in range(num_class)
                      ]
        
        return image_files, num_class
    
    def training_data(self, batch_size = 48):
        self.image_files, num_class = self.parse_data(self.dataset_path)
        
        if self.num_class!=num_class:
                raise Exception('number of available classes does not match declared classes')
        
        num_each = [len(self.image_files[i]) for i in range(self.num_class)]
        image_files_list = []
        image_class = []
        
        for i in range(self.num_class):
            image_files_list.extend(self.image_files[i])
            image_class.extend([i] * num_each[i])
        num_total = len(image_class)
        
        
        length = len(image_files_list)
        indices = np.arange(length)
        np.random.shuffle(indices)

        val_split = int(1. * length) 
        train_indices = indices[:val_split]

        train_x = [image_files_list[i] for i in train_indices]
        train_y = [image_class[i] for i in train_indices]


        train_transforms = Compose(
            [
                LoadImage(image_only=True),
                AddChannel(),
                ScaleIntensity(),
                RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
                RandFlip(spatial_axis=0, prob=0.5),
                RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
                EnsureType(),
            ]
        )

        val_transforms = Compose(
            [LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])

        y_pred_trans = Compose([EnsureType(), Activations(softmax=True)])
        y_trans = Compose([EnsureType(), AsDiscrete(to_onehot=num_class)])

        print(
            f"Training count: {len(train_x)}")
        
        
        train_ds = self.MedNISTDataset(train_x, train_y, train_transforms)
        train_loader = torch.utils.data.DataLoader(
            train_ds, batch_size, shuffle=True)
        
        
        # enter PrivacyEngine
        # make_private returns the wrapped model, optimizer and data loader;
        # the returned loader is the one to use for training
        privacy_engine = PrivacyEngine()
        self.model, self.optimizer, train_loader = privacy_engine.make_private(module=self.model,
                                                                    optimizer=self.optimizer,
                                                                    data_loader=train_loader,
                                                                    noise_multiplier=self.noise_multiplier,
                                                                    max_grad_norm=self.max_grad_norm,
                                                                    )
        
        return train_loader
    
    def training_step(self, data, target):
        output = self.forward(data)
        loss   = self.loss_function(output, target)
        return loss


    def postprocess(self, params):
        # params keys are changed by the privacy engine (as _module.param_key): should be re-named
        params_keys = list(params.keys())
        for key in params_keys:
            if '_module' in key:
                newkey = key.replace('_module.', '')
                params[newkey] = params.pop(key)
        return params
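
The renaming done in `postprocess` can be tried on a plain dictionary. The sketch below is a non-mutating variant that only strips a leading `_module.` prefix; the helper name `strip_module_prefix` and the parameter names are hypothetical, and real keys depend on the wrapped model.

```python
from collections import OrderedDict


def strip_module_prefix(params):
    """Return a copy of params with a leading '_module.' removed from each key."""
    return OrderedDict(
        (key.replace('_module.', '', 1) if key.startswith('_module.') else key, value)
        for key, value in params.items()
    )


# Hypothetical parameter names, standing in for a wrapped model's state dict.
wrapped = OrderedDict([('_module.features.conv0.weight', 1.0),
                       ('_module.class_layers.out.bias', 2.0)])
print(list(strip_module_prefix(wrapped)))
# ['features.conv0.weight', 'class_layers.out.bias']
```

Returning a new dictionary instead of popping keys in place avoids modifying a mapping while working through its keys.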

These arguments correspond, respectively, to:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side. For instance, the privacy parameters should be passed here.
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.

**NOTE:** typos and/or missing positional (required) arguments will raise an error. 🤓
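
As a hedged illustration of the note above, a small check run on the researcher side can catch a misspelled key before the arguments are sent. The helper `missing_model_args` is hypothetical and not part of Fed-BioMed; the required keys are the ones read by `MyTrainingPlan.__init__` in this notebook.

```python
# Keys read from model_args by MyTrainingPlan.__init__ in this notebook.
REQUIRED_MODEL_ARGS = ('noise_multiplier', 'max_grad_norm', 'num_class')


def missing_model_args(model_args, required=REQUIRED_MODEL_ARGS):
    """Return the required keys that are absent from model_args."""
    return [key for key in required if key not in model_args]


good = {'noise_multiplier': 0.0, 'max_grad_norm': 1.0, 'num_class': 6}
typo = {'noise_multipler': 0.0, 'max_grad_norm': 1.0, 'num_class': 6}
print(missing_model_args(good))  # []
print(missing_model_args(typo))  # ['noise_multiplier']
```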

In [3]:
# NOTE: with noise_multiplier = 0. no noise is added; only gradient clipping is applied
model_args = {'noise_multiplier':0., 'max_grad_norm':1.0, 'num_class':6,}

training_args = {
    'batch_size': 48, 
    'lr': 1e-3, 
    'epochs': 3, 
    'dry_run': False,  
    'batch_maxnum': 250 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

## Declare and run the experiment

- search nodes serving data for these `tags`, optionally filtering on a list of node IDs with `nodes`
- run a round of local training on nodes with the model defined in `model_class`, followed by federation with `aggregator`
- run for `rounds` rounds, applying the `node_selection_strategy` between the rounds
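
The federation step can be pictured as a weighted average of the parameters returned by the nodes. The sketch below illustrates federated averaging on plain dictionaries of floats; the function `federated_average` is hypothetical, and weighting by sample count is an assumption, not a statement about how Fed-BioMed's `FedAverage` is implemented.

```python
def federated_average(node_params, node_sample_counts):
    """Average parameter dicts, weighting each node by its number of samples."""
    total = sum(node_sample_counts)
    weights = [count / total for count in node_sample_counts]
    averaged = {}
    for name in node_params[0]:
        averaged[name] = sum(w * params[name]
                             for w, params in zip(weights, node_params))
    return averaged


# Two hypothetical nodes, each holding scalar parameters for brevity.
node_a = {'weight': 1.0, 'bias': 0.0}
node_b = {'weight': 3.0, 'bias': 2.0}
print(federated_average([node_a, node_b], node_sample_counts=[100, 300]))
# {'weight': 2.5, 'bias': 1.5}
```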

In [4]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['mnist']
rounds = 3

exp = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MyTrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

2022-03-31 14:51:45,747 fedbiomed INFO - Searching dataset with data tags: ['mnist'] for all nodes
2022-03-31 14:51:45,755 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - Message received: {'researcher_id': 'researcher_2a29744e-4430-4e85-9661-1c02afbdd825', 'tags': ['mnist'], 'command': 'search'}
2022-03-31 14:51:45,757 fedbiomed INFO - log from: node_b472c750-0198-450d-85cf-8faddc7f54e0 / DEBUG - Message received: {'researcher_id': 'researcher_2a29744e-4430-4e85-9661-1c02afbdd825', 'tags': ['mnist'], 'command': 'search'}
2022-03-31 14:51:45,758 fedbiomed INFO - log from: node_278d405c-015c-4089-a6a4-25506c07fd24 / DEBUG - Message received: {'researcher_id': 'researcher_2a29744e-4430-4e85-9661-1c02afbdd825', 'tags': ['mnist'], 'command': 'search'}
2022-03-31 14:51:55,754 fedbiomed INFO - Node selected for training -> node_aabe8200-9df6-48e7-a0c8-820be37261e2
2022-03-31 14:51:55,755 fedbiomed INFO - Node selected for training -> node_278d405c-015c-4089-a6a

Let's start the experiment.

By default, this function does not return until all the `rounds` are done for all the nodes.

In [5]:
exp.run()

2022-03-31 14:51:58,897 fedbiomed INFO - Sampled nodes in round 0 ['node_aabe8200-9df6-48e7-a0c8-820be37261e2', 'node_278d405c-015c-4089-a6a4-25506c07fd24']
2022-03-31 14:51:58,898 fedbiomed INFO - Send message to node node_aabe8200-9df6-48e7-a0c8-820be37261e2 - {'researcher_id': 'researcher_2a29744e-4430-4e85-9661-1c02afbdd825', 'job_id': '2c1f1a9b-e550-4024-b8ba-0ed128e35750', 'training_args': {'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'batch_maxnum': 250}, 'model_args': {'noise_multiplier': 0.0, 'max_grad_norm': 1.0, 'num_class': 6}, 'command': 'train', 'model_url': 'http://localhost:8844/media/uploads/2022/03/31/my_model_8eadde75-dd5c-4e3d-83e4-2293e4eea2c1.py', 'params_url': 'http://localhost:8844/media/uploads/2022/03/31/aggregated_params_init_41ea1992-6a41-4e25-bae0-214aa6201781.pt', 'model_class': 'MyTrainingPlan', 'training_data': {'node_aabe8200-9df6-48e7-a0c8-820be37261e2': ['dataset_130fc886-a393-4d20-8b8f-a59738405b4a']}}
2022-03-31 14:51:58,900 fedbiom

2022-03-31 14:54:27,444 fedbiomed INFO - log from: node_b472c750-0198-450d-85cf-8faddc7f54e0 / CRITICAL - Node stopped in signal_handler, probably by user decision (Ctrl C)
2022-03-31 14:54:34,037 fedbiomed INFO - Error message received during training: FB312: Node stopped in SIGTERM signal handler


2022-03-31 15:17:03,439 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - Reached 250 batches for this epoch, ignore remaining data
2022-03-31 15:17:06,083 fedbiomed INFO - log from: node_278d405c-015c-4089-a6a4-25506c07fd24 / DEBUG - Reached 250 batches for this epoch, ignore remaining data


2022-03-31 15:40:33,845 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - Reached 250 batches for this epoch, ignore remaining data
2022-03-31 15:40:38,717 fedbiomed INFO - log from: node_278d405c-015c-4089-a6a4-25506c07fd24 / DEBUG - Reached 250 batches for this epoch, ignore remaining data
2022-03-31 16:04:17,778 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - Reached 250 batches for this epoch, ignore remaining data
2022-03-31 16:04:17,795 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - running model.postprocess() method


2022-03-31 16:04:20,417 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - upload (HTTP POST request) of file /Users/mlorenzi/works/temp/fedbiomed/var/tmp/node_params_f8e96538-b623-4e71-9ca3-d6d079be7f95.pt successful, with status code 201
2022-03-31 16:04:20,460 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / INFO - results uploaded successfully 
2022-03-31 16:04:23,471 fedbiomed INFO - log from: node_278d405c-015c-4089-a6a4-25506c07fd24 / DEBUG - Reached 250 batches for this epoch, ignore remaining data
2022-03-31 16:04:23,482 fedbiomed INFO - log from: node_278d405c-015c-4089-a6a4-25506c07fd24 / DEBUG - running model.postprocess() method
2022-03-31 16:04:25,509 fedbiomed INFO - log from: node_278d405c-015c-4089-a6a4-25506c07fd24 / DEBUG - upload (HTTP POST request) of file /Users/mlorenzi/works/temp/fedbiomed/var/tmp/node_params_c21ebfa2-4c6c-43bd-bffb-8c10c0137c5e.pt successful, with status code 201
2022-03-31 16:04:25,536 fedbiome

2022-03-31 16:04:38,932 fedbiomed INFO - log from: node_278d405c-015c-4089-a6a4-25506c07fd24 / DEBUG - upload (HTTP GET request) of file my_model_cb833a7e-dd24-4367-9583-540831c99c58.pt successful, with status code 200
2022-03-31 16:04:39,268 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / INFO - training with arguments {'monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x137d15070>, 'node_args': {'gpu': False, 'gpu_num': None, 'gpu_only': False}, 'batch_size': 48, 'lr': 0.001, 'epochs': 3, 'dry_run': False, 'batch_maxnum': 250}
2022-03-31 16:04:39,269 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - Dataset path has been set as/Users/mlorenzi/works/temp/MedNIST/client_2
2022-03-31 16:04:39,270 fedbiomed INFO - log from: node_aabe8200-9df6-48e7-a0c8-820be37261e2 / DEBUG - Using device cpu for training (cuda_available=False, gpu=False, gpu_only=False, use_gpu=False, gpu_num=None)
2022-03-31 16:04:39,291 fedbiomed INF


--------------------
Fed-BioMed researcher stopped due to exception:
FB407: list of nodes became empty when training
--------------------


Local training results for each round and each node are available in `exp.training_replies()` (indexed from 0 to `rounds` - 1).

For example you can view the training results for the last round below.

Different timings (in seconds) are reported for each dataset of a node participating in a round:
- `rtime_training`: real time (clock time) spent in the training function on the node
- `ptime_training`: process time (user and system CPU) spent in the training function on the node
- `rtime_total`: real time (clock time) spent in the researcher between sending the request and handling the response, at the `Job()` layer
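
The difference between real time and process time can be reproduced with the Python standard library. This is only a sketch: `timed` and `fake_training` are hypothetical names, and the choice of `time.perf_counter` and `time.process_time` is an assumption rather than a description of the timers Fed-BioMed uses.

```python
import time


def timed(func, *args):
    """Run func and return (result, real_seconds, process_seconds)."""
    real_start = time.perf_counter()   # wall-clock timer
    cpu_start = time.process_time()    # user + system CPU time of this process
    result = func(*args)
    real = time.perf_counter() - real_start
    cpu = time.process_time() - cpu_start
    return result, real, cpu


def fake_training(n):
    # Stand-in for a training function: some CPU work plus a short wait.
    total = sum(i * i for i in range(n))
    time.sleep(0.05)
    return total


_, rtime, ptime = timed(fake_training, 100_000)
print(f"real time: {rtime:.3f} s, process time: {ptime:.3f} s")
```

The time spent sleeping counts toward the real time but not toward the process time, which is why the two values can differ noticeably when a node waits on I/O.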

In [None]:
print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies()[rounds - 1].data()
for reply in round_data:
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = reply['node_id'],
        rtraining = reply['timing']['rtime_training'],
        ptraining = reply['timing']['ptime_training'],
        rtotal = reply['timing']['rtime_total']))
print('\n')
    
exp.training_replies()[rounds - 1].dataframe

Federated parameters for each round are available in `exp.aggregated_params()` (indexed from 0 to `rounds` - 1).

For example, you can view the federated parameters for the last round of the experiment:

In [None]:
print("\nList the training rounds : ", exp.aggregated_params().keys())

print("\nAccess the federated params for the last training round :")
print("\t- params_path: ", exp.aggregated_params()[rounds - 1]['params_path'])
print("\t- parameter data: ", exp.aggregated_params()[rounds - 1]['params'].keys())

## Testing

We define a small testing routine that computes the loss and accuracy on the testing dataset.

In [None]:
import torch
import torch.nn.functional as F


def testing_Accuracy(model, data_loader):
    """Return (average loss, accuracy in percent) of `model` on `data_loader`."""
    model.eval()
    test_loss = 0
    correct = 0
    device = 'cpu'

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # the model outputs raw logits (it is trained with CrossEntropyLoss),
            # so cross_entropy is used here rather than nll_loss
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(data_loader.dataset)
    accuracy = 100* correct/len(data_loader.dataset)

    return(test_loss, accuracy)

In [None]:
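from torchvision import datasets, transforms
import os

# `environ` holds the researcher configuration, including the temporary directory
from fedbiomed.researcher.environ import environ

local_mnist = os.path.join(environ['TMP_DIR'], 'local_mnist')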

transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

test_set = datasets.MNIST(root = local_mnist, download = True, train = False, transform = transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True)

In [None]:
fed_model = exp.model_instance()
fed_model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])

acc_federated = testing_Accuracy(fed_model, test_loader)

print('\nAccuracy federated training:  {:.4f}'.format(acc_federated[1]))

print('\nError federated training:  {:.4f}'.format(acc_federated[0]))