## Running FL with secure aggregation using homomorphic encryption

This notebook walks you through how to set up FL with homomorphic encryption (HE).


## Prerequisites
Before starting this notebook, please familiarize yourself with the other FL notebooks in this repo.

- (Optional) Look at the introduction Notebook for [Federated Learning with Clara Train SDK](FederatedLearning.ipynb).
- (Optional) Look at [Client Notebook](Client.ipynb).
- (Optional) Look at [Admin Notebook](Admin.ipynb).
- Run the [Provisioning Notebook](Provisioning.ipynb) and start the server.

Make sure the project.yml used for provisioning contains these HE-related settings:

    # homomorphic encryption
    he:
      lib: tenseal
      config:
        poly_modulus_degree: 8192
        coeff_mod_bit_sizes: [60, 40, 40]
        scale_bits: 40
        scheme: CKKS
        
*Note:* These settings are recommended and should work for most tasks, but they could be further optimized for your specific model architecture and machine learning task. See this [tutorial on the CKKS scheme](https://github.com/OpenMined/TenSEAL/blob/master/tutorials/Tutorial%202%20-%20Working%20with%20Approximate%20Numbers.ipynb) and [benchmarking](https://github.com/OpenMined/TenSEAL/blob/master/tutorials/Tutorial%203%20-%20Benchmarks.ipynb) for more information on different settings.
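To sanity-check a modified configuration before provisioning, the plain-Python sketch below (no TenSEAL needed) derives a few properties that follow from the CKKS parameters: the number of slots per ciphertext (half of `poly_modulus_degree`) and the total coefficient-modulus bit count, which Microsoft SEAL caps at 218 bits for `poly_modulus_degree` 8192 at 128-bit security. The helper `check_ckks_config` and the rule of thumb that the middle primes should equal `scale_bits` are illustrative assumptions based on common SEAL/TenSEAL guidance, not part of the provisioning tool.

```python
from typing import Any, Dict

# HE settings from project.yml, transcribed as a plain dict (no YAML parser needed)
he_config: Dict[str, Any] = {
    "poly_modulus_degree": 8192,
    "coeff_mod_bit_sizes": [60, 40, 40],
    "scale_bits": 40,
    "scheme": "CKKS",
}

# Maximum total coefficient-modulus bits for 128-bit security (Microsoft SEAL defaults)
MAX_COEFF_BITS_128 = {1024: 27, 2048: 54, 4096: 109, 8192: 218, 16384: 438, 32768: 881}


def check_ckks_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: derive a few properties of a CKKS configuration."""
    n = cfg["poly_modulus_degree"]
    bits = cfg["coeff_mod_bit_sizes"]
    total_bits = sum(bits)
    return {
        # CKKS packs n / 2 real values into one ciphertext
        "slots": n // 2,
        "total_bits": total_bits,
        "within_128bit_budget": total_bits <= MAX_COEFF_BITS_128[n],
        # the last prime is SEAL's special prime and the first one keeps the final result,
        # so the primes in between roughly bound the number of rescale operations
        "rescale_levels": len(bits) - 2,
        # common guideline: intermediate primes have the same bit size as the scale
        "middle_primes_match_scale": all(b == cfg["scale_bits"] for b in bits[1:-1]),
    }


report = check_ckks_config(he_config)
print(report)
```

Raising `poly_modulus_degree` to 16384 allows a longer prime chain, at the cost of larger ciphertexts and slower encryption and aggregation.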

## Dataset 

##### Option 1 
This notebook uses a sample dataset (i.e., a single image volume of the spleen dataset) provided in the package to train a small neural network for a few epochs.
This single file is duplicated 32 times for the training set and 9 times for the validation set to mimic the full spleen dataset.
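The duplication in Option 1 can be sketched in a few lines of Python. This assumes a Decathlon-style datalist with `training` and `validation` lists of `{"image", "label"}` entries; the file names and the `build_datalist` helper are hypothetical placeholders, not the files shipped with the package.

```python
import json
from typing import Dict, List


def build_datalist(image: str, label: str, n_train: int = 32, n_val: int = 9) -> Dict[str, List[Dict[str, str]]]:
    """Duplicate a single image/label pair to mimic a larger dataset."""
    entry = {"image": image, "label": label}
    return {
        "training": [dict(entry) for _ in range(n_train)],
        "validation": [dict(entry) for _ in range(n_val)],
    }


# hypothetical relative paths to the single sample volume and its label
datalist = build_datalist("imagesTr/spleen_sample.nii.gz", "labelsTr/spleen_sample.nii.gz")

# serialized form, as it could be written to a datalist JSON file
text = json.dumps(datalist, indent=2)
print(len(datalist["training"]), len(datalist["validation"]))
```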

##### Option 2  
With minor changes, as recommended in the exercise section, you can instead train on the full spleen segmentation task. The dataset used is Task09_Spleen.tar from
the [Medical Segmentation Decathlon](http://medicaldecathlon.com/).
Prior to running this notebook, the data should be downloaded following
the steps in the [Data Download Notebook](../../Data_Download.ipynb).

### Disclaimer  
We will be training a small network so that both clients can fit the model on one GPU.
Training will run for only a couple of epochs to demonstrate the concepts; we are not targeting accuracy.

# Let's get started
In order to learn how FL works with homomorphic encryption (HE) in Clara Train SDK, we will first give some background on what homomorphic encryption is and how the MMAR configurations need to be modified to enable it.
<br><img src="./screenShots/homomorphic_encryption.png" alt="Drawing" style="height: 450px;"/><br> 


# New HE components

We implemented secure aggregation during FL with homomorphic encryption using the [TenSEAL library](https://github.com/OpenMined/TenSEAL) by OpenMined, a convenient wrapper around [Microsoft SEAL](https://github.com/microsoft/SEAL). Both libraries are available as open source and provide an implementation of ["Homomorphic encryption for arithmetic of approximate numbers"](https://eprint.iacr.org/2016/421.pdf), a.k.a. the "CKKS" scheme, which was proposed as a solution for [encrypted machine learning](https://en.wikipedia.org/wiki/Homomorphic_encryption#Fourth-generation_FHE) and which we use for these FL experiments.

The configuration files in `adminMMAR_HE` use the following new components, which are needed in addition to, or instead of, the standard FL components used in Clara Train.

## Client-side 
See `config_fed_client.json`:

### `HEModelEncryptor`
A filter to encrypt the Shareable object that is being sent to the server.

```
Args:
    tenseal_context_file: tenseal context files containing encryption keys and parameters
    encrypt_layers: if not specified (None), all layers are being encrypted;
                    if list of variable/layer names, only specified variables are encrypted;
                    if string containing regular expression (e.g. "conv"), only matched variables are being encrypted.
    aggregation_weights: dictionary of client aggregation `{"client1": 1.0, "client2": 2.0, "client3": 3.0}`;
                         defaults to a weight of 1.0 if not specified.
    weigh_by_local_iter: If true, multiply by the client weights first, before encryption (default: `True`, which is recommended for HE)
```

HE increases the message sizes when encrypting the model updates of each client. To reduce this overhead, one can choose to encrypt only selected layers instead of all of them; see the argument `encrypt_layers`.
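The three modes of `encrypt_layers` can be mimicked with a small selection function. This is a hypothetical re-implementation for illustration, not the flare code; in particular, it assumes the regular expression is applied with `re.search` (a match anywhere in the name), which may differ from the actual component. The layer names are taken from the UNet printout in this notebook.

```python
import re
from typing import Iterable, List, Optional, Sequence, Union


def select_layers(names: Iterable[str],
                  encrypt_layers: Optional[Union[Sequence[str], str]] = None) -> List[str]:
    """Hypothetical helper: return the subset of layer names that would be encrypted."""
    names = list(names)
    if encrypt_layers is None:
        # None: encrypt every layer
        return names
    if isinstance(encrypt_layers, str):
        # string: treat it as a regular expression
        pattern = re.compile(encrypt_layers)
        return [n for n in names if pattern.search(n)]
    # list: encrypt exactly the listed names
    wanted = set(encrypt_layers)
    return [n for n in names if n in wanted]


layer_names = [
    "model.0.conv.unit0.conv.weight",
    "model.0.conv.unit0.conv.bias",
    "model.0.conv.unit0.adn.A.weight",
    "model.0.conv.unit1.conv.weight",
]

print(select_layers(layer_names))                                     # all four layers
print(select_layers(layer_names, ["model.0.conv.unit0.conv.weight"]))  # only the listed one
print(select_layers(layer_names, r"conv\.weight$"))                   # convolution weights only
```

A bare substring such as `"conv"` would match nearly every key in this model, including the activation `adn.A.weight` entries, so an anchored pattern gives tighter control over message size.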

To choose the layer names for a given model, one can use:

```python
from monai.networks.nets.unet import UNet

# use the same configuration as in adminMMAR_HE
net = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=2,
    channels=[16, 32, 64, 128, 256],
    strides=[2, 2, 2, 2],
    num_res_units=2,
)

# here, we only print convolutional layers that we might want to encrypt
for key in net.state_dict().keys():
    if 'conv' in key:
        print(key)
```

```
model.0.conv.unit0.conv.weight
model.0.conv.unit0.conv.bias
model.0.conv.unit0.adn.A.weight
model.0.conv.unit1.conv.weight
model.0.conv.unit1.conv.bias
model.0.conv.unit1.adn.A.weight
model.1.submodule.0.conv.unit0.conv.weight
model.1.submodule.0.conv.unit0.conv.bias
model.1.submodule.0.conv.unit0.adn.A.weight
model.1.submodule.0.conv.unit1.conv.weight
model.1.submodule.0.conv.unit1.conv.bias
model.1.submodule.0.conv.unit1.adn.A.weight
model.1.submodule.1.submodule.0.conv.unit0.conv.weight
model.1.submodule.1.submodule.0.conv.unit0.conv.bias
model.1.submodule.1.submodule.0.conv.unit0.adn.A.weight
model.1.submodule.1.submodule.0.conv.unit1.conv.weight
model.1.submodule.1.submodule.0.conv.unit1.conv.bias
model.1.submodule.1.submodule.0.conv.unit1.adn.A.weight
model.1.submodule.1.submodule.1.submodule.0.conv.unit0.conv.weight
model.1.submodule.1.submodule.1.submodule.0.conv.unit0.conv.bias
model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.A.weight
model.1.submodule.1.submodule.1.
```

Based on this output, our example `config_fed_client.json` chooses three layers to encrypt:

```
"outbound_filters": [
  {
    "path": "flare.experimental.homomorphic_encryption.he_model_encryptor.HEModelEncryptor",
    "args": {
      "encrypt_layers": [
        "model.0.conv.unit0.conv.weight",
        "model.1.submodule.1.submodule.1.submodule.2.0.conv.weight",
        "model.2.1.conv.unit0.conv.weight"
      ],
      "aggregation_weights": {
        "client1": 0.4,
        "client2": 0.6
      }
    }
  }
]
```
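Weighting on the client before encryption means the server only has to add ciphertexts. The numpy sketch below checks, in plaintext, that scaling each update by its `aggregation_weights` entry on the client and then summing and normalizing on the server gives the same result as a direct weighted average. The update values are made up, and normalizing by the total weight is an assumption about how the aggregate is formed, not a statement of what flare does internally.

```python
import numpy as np

# made-up per-client updates for one layer (illustrative numbers only)
updates = {
    "client1": np.array([0.2, -0.4, 1.0]),
    "client2": np.array([0.6, 0.1, -0.5]),
}
aggregation_weights = {"client1": 0.4, "client2": 0.6}

# (a) direct weighted average, which requires seeing every plaintext update
reference = sum(aggregation_weights[c] * u for c, u in updates.items()) / sum(aggregation_weights.values())

# (b) client side: each client scales its own update before it would be encrypted
#     (clients missing from the dictionary default to a weight of 1.0)
scaled = {c: aggregation_weights.get(c, 1.0) * u for c, u in updates.items()}

# server side: only a sum plus one multiplication by a public scalar is needed
total_weight = sum(aggregation_weights.get(c, 1.0) for c in updates)
aggregated = sum(scaled.values()) / total_weight

assert np.allclose(aggregated, reference)
print(aggregated)
```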

### `HEModelDecryptor`
A filter to decrypt the Shareable object, i.e., the updated global model received from the server.

```
Args:
    tenseal_context_file: tenseal context files containing decryption keys and parameters
```

*Note:* The tenseal_context_file for the client will be generated by the provision tool and is part of the startup kit, see [Provisioning](./Provisioning.ipynb).

### `HEPTModelReaderWriter`

This component is passed as an argument to `ClientTrainer` to reshape the decrypted parameter vectors back into the local PyTorch model for training.
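A TenSEAL `CKKSVector` is one-dimensional, so each decrypted layer arrives as a flat list of floats and must be reshaped to the tensor shape the model expects. The numpy sketch below shows that round trip; `flatten_state` and `unflatten_state` are hypothetical helpers, not the `HEPTModelReaderWriter` API, and the example shape is only chosen to resemble a 3D convolution weight.

```python
from typing import Dict, List, Tuple

import numpy as np


def flatten_state(state: Dict[str, np.ndarray]) -> Tuple[Dict[str, List[float]], Dict[str, Tuple[int, ...]]]:
    """Turn each array into a flat list (what gets encrypted) and remember its shape."""
    vectors = {k: v.ravel().tolist() for k, v in state.items()}
    shapes = {k: v.shape for k, v in state.items()}
    return vectors, shapes


def unflatten_state(vectors: Dict[str, List[float]],
                    shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, np.ndarray]:
    """Reshape decrypted flat lists back to the original tensor shapes."""
    return {k: np.asarray(vec, dtype=np.float32).reshape(shapes[k]) for k, vec in vectors.items()}


rng = np.random.default_rng(0)
state = {
    "model.0.conv.unit0.conv.weight": rng.normal(size=(16, 1, 3, 3, 3)).astype(np.float32),
    "model.0.conv.unit0.conv.bias": np.zeros(16, dtype=np.float32),
}

vectors, shapes = flatten_state(state)
restored = unflatten_state(vectors, shapes)
print({k: v.shape for k, v in restored.items()})
```

Here the round trip is exact; after real CKKS decryption the restored values only approximate the originals, because the scheme is approximate.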

### `HEEvalDecryptor`
A filter to decrypt encrypted Shareable objects (i.e., global models) distributed during cross-site validation. Currently, only the global server models are encrypted; the locally best models are shared unencrypted.

*Note:* Cross-site validation is optional, and a client does not need to participate if it does not want to.

```
Args:
    tenseal_context_file: tenseal context files containing decryption keys and parameters
    defaults to `False` for use during FL training
```

## Server-side 
See `config_fed_server.json`:

### `HEInTimeAccumulateWeightedAggregator`

This aggregator performs federated averaging (i.e., the [`FedAvg`](http://proceedings.mlr.press/v54/mcmahan17a.html) algorithm) in encrypted space. The server doesn't have a key for decryption and only processes the encrypted values sent by the clients.

```
Args:
    exclude_vars: variable names that should be excluded from aggregation (use regular expression)
    aggregation_weights: dictionary of client aggregation `{"client1": 1.0, "client2": 2.0, "client3": 3.0}`;
                         defaults to a weight of 1.0 if not specified. Will be ignored if weigh_by_local_iter: False (default for HE)
    weigh_by_local_iter: If true, multiply by the client weights in encrypted space
                         (default: `False`, which is recommended for HE because the multiplication already happens in `HEModelEncryptor`)
```
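To see why the server can aggregate without a secret key, the sketch below uses a deliberately tiny **Paillier** cryptosystem in pure Python (3.8+ for `pow(x, -1, n)`). Paillier is a different, additively homomorphic scheme, swapped in here only because it fits in a few lines; it is *not* CKKS and not what TenSEAL or flare uses, and the small primes make it completely insecure. The idea carries over: the server combines ciphertexts using only public values, and only the key holder can read the result.

```python
import math
import random

# Toy Paillier parameters: far too small to be secure, illustration only
P, Q = 1009, 1013
N = P * Q                       # public key
N2 = N * N
LAM = (P - 1) * (Q - 1) // math.gcd(P - 1, Q - 1)  # secret: lcm(p - 1, q - 1)
MU = pow(LAM, -1, N)            # secret; valid because the generator is g = N + 1
SCALE = 1000                    # fixed-point scale for encoding floats as integers


def encrypt(x: float) -> int:
    """Client side: needs only the public key N."""
    m = round(x * SCALE) % N    # negative values wrap around modulo N
    while True:
        r = random.randrange(1, N)
        if math.gcd(r, N) == 1:
            break
    return (pow(N + 1, m, N2) * pow(r, N, N2)) % N2


def add_encrypted(c1: int, c2: int) -> int:
    """Server side: multiplying ciphertexts adds the underlying plaintexts."""
    return (c1 * c2) % N2


def decrypt(c: int) -> float:
    """Client side: needs the secret values LAM and MU."""
    m = ((pow(c, LAM, N2) - 1) // N) * MU % N
    if m > N // 2:              # map back to the signed range
        m -= N
    return m / SCALE


client1 = [0.125, -0.5, 1.0]
client2 = [0.375, 0.25, -2.0]
enc_sum = [add_encrypted(encrypt(a), encrypt(b)) for a, b in zip(client1, client2)]
avg = [decrypt(c) / 2 for c in enc_sum]
print(avg)  # [0.25, -0.125, -0.5]
```

In this toy the division by the number of clients happens after decryption; CKKS can also multiply a ciphertext by a real-valued plaintext scalar, so with the real scheme that step can stay in encrypted space.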

### `HEPTFileModelPersistor`

This model persistor is used to save the encrypted models on the server.

### `HEModelShareableGenerator`

`HEModelShareableGenerator` converts between Shareable and Learnable objects generated with HE. It is used to update the global model weights using the averaged encrypted updates from the clients. The updated global model stays encrypted.
    
```
Args:
    tenseal_context_file: tenseal context files containing decryption keys and parameters
```

# Running FL experiment with HE

## 1 - Start server and clients (if they are not already running)
Open four terminals in JupyterLab.

In the server terminal run:
```
cd /claraDevDay/FL/project1/server/startup
./start.sh
```  
In the client1 terminal run:
```
cd /claraDevDay/FL/project1/client1/startup
./start.sh
```  
In the client2 terminal run:
```
cd /claraDevDay/FL/project1/client2/startup
./start.sh
```  

## 2 - Starting Admin Shell
In the admin terminal, if you haven't already started the admin console, go to the admin folder inside your project and run:
```
cd /claraDevDay/FL/project1/admin/startup
./fl_admin.sh
``` 
You should see:
```
Admin Server: localhost on port 5000
User Name:
```
Type `admin@admin.com`. The console then shows:
```
Admin Server: localhost on port 8003
User Name: admin@admin.com

Type ? to list commands; type "? cmdName" to show usage of a command.
```

## 3 - Check server/client status
Type
```
> check_status server
```
to see 
```
FL run number has not been set.
FL server status: training not started
Registered clients: 2 
-------------------------------------------------------------------------------------------------
| CLIENT NAME | TOKEN                                | LAST ACCEPTED ROUND | CONTRIBUTION COUNT |
-------------------------------------------------------------------------------------------------
| client1     | f735c245-ce35-4a08-89e0-0292bb053a9c |                     | 0                  |
| client2     | e36db52e-2624-4989-855a-28fa195f58e9 |                     | 0                  |
-------------------------------------------------------------------------------------------------
```
To check on the clients, type
```
> check_status client
```
to see 
```
instance:client1 : client name: client1 token: 3c3d2276-c3bf-40c1-bc02-9be84d7c339f     status: training not started
instance:client2 : client name: client2 token: 92806548-5515-4977-894e-612900ff8b1b     status: training not started
```
To check the folder structure, type

```
> info
```
to see:
```
Local Upload Source: /claraDevDay/FL/project1/admin/startup/../transfer
Local Download Destination: /claraDevDay/FL/project1/admin/startup/../transfer
Server Upload Destination: /claraDevDay/FL/project1/server/startup/../transfer
Server Download Source: /claraDevDay/FL/project1/server/startup/../transfer
```

## 4 - Upload and deploy the MMAR configurations for HE and set the FL run number
First, set a run number (choose a different one if you don't want to overwrite previous results):
```
> set_run_number 1
```

Then, upload the HE MMAR and deploy to server and clients
```
> upload_folder ../../../adminMMAR_HE
> deploy adminMMAR_HE server
> deploy adminMMAR_HE client
```

## 5 - Start Training
Now you can start training by:

1. `> start server`
2. `> start client`

You can check on the status of the training using:

3. `> check_status client` or `> check_status server` to see:

```
FL run number:1
FL server status: training started
run number:1    start round:0   max round:2     current round:0
min_num_clients:2       max_num_clients:100
Registered clients: 2 
Total number of clients submitted models for current round: 0
-------------------------------------------------------------------------------------------------
| CLIENT NAME | TOKEN                                | LAST ACCEPTED ROUND | CONTRIBUTION COUNT |
-------------------------------------------------------------------------------------------------
| client1     | f735c245-ce35-4a08-89e0-0292bb053a9c |                     | 0                  |
| client2     | e36db52e-2624-4989-855a-28fa195f58e9 |                     | 0                  |
-------------------------------------------------------------------------------------------------
```

4. Get logs from the server or clients using `cat server log.txt` or `cat client1 log.txt`.

## 6 - Stop Training (if needed)
You can send signals to stop the training using:
- `abort client`
- `abort server`

## 7 - Cross-site validation
Once training is complete, you will want to get the cross-site validation matrix.
This is another area where Clara FL shines.
One of the promises of FL is that it enables training more generalizable models because of the more diverse data accessed across the clients. The off-diagonal values show how well the locally best and global models trained in FL generalize across the different client sites.
Without Clara, you would need to move either the data or the selected model to each site and run validation at each site separately.
With the cross-site validation feature, this is done automatically for you.
All you need to do is include the file `config_cross_site_validation.json` in your MMAR and set the flag
`"cross_site_validate": true` in the client section of `config_fed_client.json`.
These settings are already configured in this example, so all that's left is to run `validate all` to show the cross-site validation results. You can also run `validate source_site target_site` to see the performance of a certain model on a certain site.

You should see something like:
```
validate all
{'client1': {'client2': {'validation': {'mean_dice': 0.0637669786810875}}, 'client1': {'validation': {'mean_dice': 0.07123523205518723}}, 'server': {'validation': {'mean_dice': 0.07032141834497452}}}, 'client2': {'client2': {'validation': {'mean_dice': 0.06376668065786362}}, 'client1': {'validation': {'mean_dice': 0.07123514264822006}}, 'server': {'validation': {'mean_dice': 0.07032135874032974}}}}
Done [11570 usecs] 2020-09-03 18:49:41.485214
``` 
Parsing this JSON and putting it in a table (rows follow the outer keys, columns the inner keys) gives:

| | Client 1 | Client 2 | Server |
| :--- | :--- | :--- | :--- |
| Client 1 | 0.07123523205518723 | 0.0637669786810875 | 0.07032141834497452 |
| Client 2 | 0.07123514264822006 | 0.06376668065786362 | 0.07032135874032974 |
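The table can be generated directly from the console output. The sketch copies the printed dictionary as a string, parses it with `ast.literal_eval` (the output uses Python-style single quotes, so `json.loads` would reject it), and renders a Markdown table rounded to four decimals. The helper name `to_markdown_table` is hypothetical.

```python
import ast

# dictionary printed by `validate all`, copied verbatim from the console
raw = ("{'client1': {'client2': {'validation': {'mean_dice': 0.0637669786810875}}, "
       "'client1': {'validation': {'mean_dice': 0.07123523205518723}}, "
       "'server': {'validation': {'mean_dice': 0.07032141834497452}}}, "
       "'client2': {'client2': {'validation': {'mean_dice': 0.06376668065786362}}, "
       "'client1': {'validation': {'mean_dice': 0.07123514264822006}}, "
       "'server': {'validation': {'mean_dice': 0.07032135874032974}}}}")
results = ast.literal_eval(raw)


def to_markdown_table(results: dict, metric: str = "mean_dice") -> str:
    """Rows follow the outer keys, columns the inner keys; missing entries show as n/a."""
    rows = sorted(results)
    cols = sorted({col for inner in results.values() for col in inner})
    lines = ["| | " + " | ".join(cols) + " |", "|" + "---|" * (len(cols) + 1)]
    for row in rows:
        cells = []
        for col in cols:
            entry = results[row].get(col)
            cells.append(f"{entry['validation'][metric]:.4f}" if entry else "n/a")
        lines.append(f"| {row} | " + " | ".join(cells) + " |")
    return "\n".join(lines)


table = to_markdown_table(results)
print(table)
```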

## 8 - Server security

To illustrate that the server cannot decrypt the messages sent by the clients, we can execute this small test script.

```python
import pickle
import numpy as np
import tenseal as ts
from flare.experimental.homomorphic_encryption.homomorphic_encrypt import count_encrypted_layers, load_tenseal_context

global_model_file = "/claraDevDay/FL/project1/server/run_1/mmar_server/models/best_FL_global_model.pt"
server_context_file = "/claraDevDay/FL/project1/server/startup/server_context.tenseal"
client_context_file = "/claraDevDay/FL/project1/client1/startup/client_context.tenseal"

# load the server and client TenSEAL context files
server_ts_ctx = load_tenseal_context(server_context_file)
client_ts_ctx = load_tenseal_context(client_context_file)

# load the global model saved on the server
with open(global_model_file, "rb") as f:
    model = pickle.load(f)

print("model:", list(model.keys()))

# maps each layer name to a flag telling whether that layer is encrypted
encrypted_layers = model["he_encrypted_layers"]
model = model["model"]

count_encrypted_layers(encrypted_layers)

# try decrypting the first encrypted layer (check the flag, not just the name)
for layer_name, is_encrypted in encrypted_layers.items():
    if not is_encrypted:
        continue
    print(f"{layer_name} is encrypted. Trying to decrypt...")
    print(type(model[layer_name]))

    try:
        # the server can deserialize the bytes...
        ckks_vector = ts.ckks_vector_from(server_ts_ctx, model[layer_name])
        # ...but decryption is expected to fail: the server context holds no secret key
        ckks_vector.decrypt()
    except Exception as e:
        print(f"Server decryption failed with: {e}!")

    # however, the client can decrypt using its own TenSEAL context
    ckks_vector = ts.ckks_vector_from(client_ts_ctx, model[layer_name])
    decrypted_params = ckks_vector.decrypt()

    print(f"Client decrypted parameters for {layer_name}:")
    np.set_printoptions(threshold=10)  # don't show all values
    print(np.asarray(decrypted_params))

    break
```

```
Loaded TenSEAL context from /claraDevDay/FL/project1/server/startup/server_context.tenseal
Loaded TenSEAL context from /claraDevDay/FL/project1/client1/startup/client_context.tenseal
model: ['model', 'train_conf', 'he_encrypted_layers']
3 of 63 layers are encrypted.
model.0.conv.unit0.conv.weight is encrypted. Trying to decrypt...
<class 'bytes'>
Server decryption failed with: the current context of the tensor doesn't hold a secret_key, please provide one as argument!
Client decrypted parameters for model.0.conv.unit0.conv.weight:
[ 0.06758314 -0.02286799  0.15225781 ...  0.17290769 -0.14571007
  0.04857256]
```

## 9 - Done

Congratulations! You have trained and evaluated an FL model using secure aggregation with homomorphic encryption.