# Missing data imputation with Fedbiomed using MIWAE

In this notebook we show how to impute missing at random (MAR) data in a federated setting using MIWAE (https://arxiv.org/abs/2006.12871). 

In [1]:
%load_ext autoreload
%autoreload 2

## Prepare the data

For this experiment we will use the Boston housing data, downloaded from the CMU StatLib archive. The nodes below still register it under the tag `breast_cancer`.

In [2]:
import pandas as pd
import numpy as np

data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]
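
The slicing above relies on the layout of the StatLib file, where each record spans two consecutive rows. A minimal sketch on a toy array (the values and the reduced column count are made up for illustration) shows how `[::2]` and `[1::2]` reassemble the records:

```python
import numpy as np

# Toy stand-in for raw_df.values: each record is split over two rows.
# Even rows hold the first features; odd rows hold the remaining two
# features, then the target, then NaN padding.
raw = np.array([
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 24.0, np.nan],
    [1.1, 1.2, 1.3, 1.4],
    [1.5, 1.6, 21.6, np.nan],
])

# Glue the feature part of each odd row to the right of the even row
data = np.hstack([raw[::2, :], raw[1::2, :2]])
# The third value of each odd row is the target
target = raw[1::2, 2]

print(data.shape)
print(target)
```

With the real file the same pattern yields 13 feature columns per record, since the even rows carry 11 values instead of 4.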

In [3]:
from sklearn.model_selection import train_test_split

data_train, data_test, labels_train, labels_test = train_test_split(data, target, test_size=0.20, random_state=42)
df_data_train = pd.DataFrame(data_train)
N_train = len(df_data_train)
client_1, client_2, client_3 = np.split(df_data_train.sample(frac=1),
                                        [int(.33*N_train), int(.66*N_train)])

Clients_data=[client_1, client_2, client_3]

# randomly remove 50% of the entries from each client's dataset
np.random.seed(1234)

perc_miss = 0.5 # 50% of missing data

Clients_missing = []
for c in Clients_data:
    n = c.shape[0] # number of observations
    p = c.shape[1] # number of features
    xmiss = np.copy(c)
    xmiss = (xmiss - np.mean(xmiss,0))/np.std(xmiss,0)
    xmiss_flat = xmiss.flatten()
    miss_pattern = np.random.choice(n*p, int(np.floor(n*p*perc_miss)),
                                    replace=False)
    xmiss_flat[miss_pattern] = np.nan 
    xmiss = xmiss_flat.reshape([n,p]) # in xmiss, the missing values are represented by nans
    mask = np.isfinite(xmiss) # boolean mask: True where a value is observed, False where it is missing
    Clients_missing.append(xmiss)

import os 
os.makedirs('clients_data', exist_ok=True) 
for i in range(len(Clients_missing)):
    pd.DataFrame(Clients_missing[i]).to_csv('clients_data/client_'+str(i+1)+'.csv')

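
The masking step can be checked in isolation. The following is a minimal sketch, assuming a toy Gaussian matrix in place of a client's data; the helper name `mask_uniformly_at_random` and the use of `np.random.default_rng` are choices made for this illustration, not part of the notebook:

```python
import numpy as np

def mask_uniformly_at_random(x, perc_miss, rng):
    """Standardise x column-wise, then set a fixed fraction of entries to NaN."""
    n, p = x.shape
    x_std = (x - x.mean(axis=0)) / x.std(axis=0)
    flat = x_std.flatten()
    n_missing = int(np.floor(n * p * perc_miss))
    # Draw positions without replacement so exactly n_missing entries are removed
    idx = rng.choice(n * p, n_missing, replace=False)
    flat[idx] = np.nan
    x_miss = flat.reshape(n, p)
    mask = np.isfinite(x_miss)  # True where the value is observed
    return x_miss, mask

rng = np.random.default_rng(1234)
toy = rng.normal(size=(10, 4))  # hypothetical data: 10 observations, 4 features
x_miss, mask = mask_uniformly_at_random(toy, 0.5, rng)

print(np.isnan(x_miss).sum(), "missing entries out of", x_miss.size)
```

Because positions are drawn uniformly over the flattened matrix, the missing fraction is exact overall but varies from column to column.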


## Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

## Setting the nodes up
Configure a node before running the experiment:
1. `./scripts/fedbiomed_run node add`
  * Select option 1 (csv) to add the client_1 dataset to the first node
  * Provide the correct tag by entering: `breast_cancer`
  * Pick the folder where the client_1 dataset has been saved
  * The data should now be added (a warning saying that data must be unique means it has already been added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node start`. Wait until you see `Starting task manager`, which means the node is online.
4. Following the same procedure, create additional nodes for clients 2 and 3.

Check available clients:

In [4]:
from fedbiomed.researcher.requests import Requests
req = Requests()
req.list(verbose=True)

2022-04-15 14:20:57,824 fedbiomed INFO - Component environment:
2022-04-15 14:20:57,825 fedbiomed INFO - type = ComponentType.RESEARCHER
2022-04-15 14:20:57,865 fedbiomed INFO - Messaging researcher_cda2fb34-f12f-41cc-9c23-d61d107b97f4 successfully connected to the message broker, object = <fedbiomed.common.messaging.Messaging object at 0x11b3c5f70>
2022-04-15 14:20:57,915 fedbiomed INFO - Listing available datasets in all nodes... 
2022-04-15 14:21:07,941 fedbiomed INFO - 
 Node: node_0fe17841-14d6-4f76-b61d-7b2708b6d95d | Number of Datasets: 1 
+---------------+-------------+-------------------+---------------+-----------+
| name          | data_type   | tags              | description   | shape     |
| breast_cancer | csv         | ['breast_cancer'] | breast_cancer | [133, 14] |
+---------------+-------------+-------------------+---------------+-----------+

2022-04-15 14:21:07,942 fedbiomed INFO - 
 Node: node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764 | Number of Datasets: 1 
+---------

{'node_0fe17841-14d6-4f76-b61d-7b2708b6d95d': [{'name': 'breast_cancer',
   'data_type': 'csv',
   'tags': ['breast_cancer'],
   'description': 'breast_cancer',
   'shape': [133, 14]}],
 'node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764': [{'name': 'breast_cancer',
   'data_type': 'csv',
   'tags': ['breast_cancer'],
   'description': 'breast_cancer',
   'shape': [138, 14]}],
 'node_bc08b50a-1b46-407a-8341-2bf317f2506d': [{'name': 'breast_cancer',
   'data_type': 'csv',
   'tags': ['breast_cancer'],
   'description': 'breast_cancer',
   'shape': [133, 14]}]}
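
Assuming `req.list` returns a dictionary shaped like the one printed above, it can be inspected before launching an experiment. The sketch below copies that output into a literal (the variable names are hypothetical) and checks how many rows each node serves under the `breast_cancer` tag:

```python
# Copied from the output above; `datasets_by_node` is a hypothetical name
datasets_by_node = {
    'node_0fe17841-14d6-4f76-b61d-7b2708b6d95d': [
        {'name': 'breast_cancer', 'data_type': 'csv', 'tags': ['breast_cancer'],
         'description': 'breast_cancer', 'shape': [133, 14]}],
    'node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764': [
        {'name': 'breast_cancer', 'data_type': 'csv', 'tags': ['breast_cancer'],
         'description': 'breast_cancer', 'shape': [138, 14]}],
    'node_bc08b50a-1b46-407a-8341-2bf317f2506d': [
        {'name': 'breast_cancer', 'data_type': 'csv', 'tags': ['breast_cancer'],
         'description': 'breast_cancer', 'shape': [133, 14]}],
}

# Number of rows per node among the datasets carrying the tag of interest
rows_per_node = {
    node_id: sum(ds['shape'][0] for ds in datasets if 'breast_cancer' in ds['tags'])
    for node_id, datasets in datasets_by_node.items()
}
total_rows = sum(rows_per_node.values())

for node_id, n_rows in rows_per_node.items():
    print(node_id, n_rows)
print("total rows:", total_rows)
```

The second entry of each `shape` is 14 because the CSV files were written with their index column, which is why `n_features` is set to 14 later on.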

## Define an experiment model and parameters

Declare a `MIWAETrainingPlan` class, derived from `TorchTrainingPlan`, to send to the nodes for training.

Note: write **only** the code to export in the following cell.

In [59]:
import torch
import torch.nn as nn
#from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms
import numpy as np
#import scipy.stats
#import scipy.io
#import scipy.sparse
#from scipy.io import loadmat
import torch.distributions as td
import pandas as pd

from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes

# Here we define the model to be used. 
# You can use any class name (here 'MIWAETrainingPlan')
class MIWAETrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args: dict = {}):
        super(MIWAETrainingPlan, self).__init__(model_args)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["from torchvision import datasets, transforms",
               "import torch.distributions as td",
               "import pandas as pd",
               "import numpy as np"]
        
        self.n_features=model_args['n_features']
        self.n_latent=model_args['n_latent']
        self.n_hidden=model_args['n_hidden']
        self.n_samples=model_args['n_samples']
        
        self.add_dependency(deps)
        
        # the encoder will output both the mean and the diagonal covariance
        self.encoder=nn.Sequential(
                        torch.nn.Linear(self.n_features, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 2*self.n_latent),  
                        )
        # the decoder will output the mean, the scale, 
        # and the number of degrees of freedom (hence the 3*n_features)
        self.decoder = nn.Sequential(
                        torch.nn.Linear(self.n_latent, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 3*self.n_features),  
                        )
        
        self.optimizer = torch.optim.Adam(list(self.encoder.parameters()) \
                                    + list(self.decoder.parameters()),lr=1e-3)
    
    def miwae_loss(self,iota_x,mask):
        batch_size = iota_x.shape[0]
        out_encoder = self.encoder(iota_x)
        # prior
        p_z = td.Independent(td.Normal(loc=torch.zeros(self.n_latent).to(self.device)\
                                       ,scale=torch.ones(self.n_latent).to(self.device)),1)
        
        q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :self.n_latent],\
                                                scale=torch.nn.Softplus()\
                                                (out_encoder[..., self.n_latent:\
                                                             (2*self.n_latent)])),1)

        zgivenx = q_zgivenxobs.rsample([self.n_samples])
        zgivenx_flat = zgivenx.reshape([self.n_samples*batch_size,self.n_latent])

        out_decoder = self.decoder(zgivenx_flat)
        all_means_obs_model = out_decoder[..., :self.n_features]
        all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., self.n_features:\
                                                               (2*self.n_features)]) + 0.001
        all_degfreedom_obs_model = torch.nn.Softplus()\
        (out_decoder[..., (2*self.n_features):(3*self.n_features)]) + 3

        data_flat = torch.Tensor.repeat(iota_x,[self.n_samples,1]).reshape([-1,1])
        tiledmask = torch.Tensor.repeat(mask,[self.n_samples,1])

        all_log_pxgivenz_flat = torch.distributions.StudentT\
        (loc=all_means_obs_model.reshape([-1,1]),\
         scale=all_scales_obs_model.reshape([-1,1]),\
         df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
        all_log_pxgivenz = all_log_pxgivenz_flat.reshape([self.n_samples*batch_size,self.n_features])

        logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([self.n_samples,batch_size])
        logpz = p_z.log_prob(zgivenx)
        logq = q_zgivenxobs.log_prob(zgivenx)

        neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,0))

        return neg_bound

    def training_data(self,  batch_size = 48):
        
        df = pd.read_csv(self.dataset_path, sep=',', index_col=False)
        x_train = df.values
        x_mask = np.isfinite(x_train)
        xhat_0 = np.copy(x_train)
        xhat_0[np.isnan(x_train)] = 0
        print(x_train[0])
        print(x_mask[0])
        print(xhat_0[0])
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        
        data_manager = DataManager(dataset=xhat_0 , target=x_mask , **train_kwargs)
        
        return data_manager
    
    def training_step(self, data, mask):
        self.encoder.zero_grad()
        self.decoder.zero_grad()
        loss = self.miwae_loss(iota_x = data,mask = mask)
        return loss
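
For reference, `miwae_loss` follows the MIWAE lower bound on the log-likelihood of the observed entries. A sketch of the per-observation bound, in notation assumed here rather than taken from the notebook ($K$ corresponds to `n_samples`, $x^{\mathrm{o}}$ to the entries where `mask` is true, $q_\gamma$ to the encoder and $p_\theta$ to the Student's t decoder):

$$
\mathcal{L}_K(x^{\mathrm{o}}) = \mathbb{E}_{z_1,\dots,z_K \sim q_\gamma(z \mid x^{\mathrm{o}})}
\left[ \log \frac{1}{K} \sum_{k=1}^{K}
\frac{p_\theta(x^{\mathrm{o}} \mid z_k)\, p(z_k)}{q_\gamma(z_k \mid x^{\mathrm{o}})} \right]
$$

The three terms inside the sum are `logpxobsgivenz`, `logpz` and `logq` in log space. The code returns the negative batch average of their `logsumexp` over the $K$ samples, which differs from the negative bound only by the constant $\log K$ and therefore has the same gradients.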

These arguments are, respectively:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side. 
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.
  If FedProx optimisation is requested, the `fedprox_mu` parameter must be defined here. It must also be a float between XX and YY.

**NOTE:** typos and/or missing positional (required) arguments will raise an error. 🤓
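
Since a misspelled key only fails once the training plan is instantiated, it can help to check `model_args` on the researcher side first. A minimal sketch, assuming the four keys read in `MIWAETrainingPlan.__init__`; the helper `missing_model_args` is hypothetical and not part of Fed-BioMed:

```python
# Keys that MIWAETrainingPlan.__init__ reads from model_args
REQUIRED_MODEL_ARGS = ('n_features', 'n_latent', 'n_hidden', 'n_samples')

def missing_model_args(model_args):
    """Return the required keys absent from model_args, in declaration order."""
    return [key for key in REQUIRED_MODEL_ARGS if key not in model_args]

good = {'n_features': 14, 'n_latent': 10, 'n_hidden': 128, 'n_samples': 20}
typo = {'n_features': 14, 'n_latnet': 10, 'n_hidden': 128, 'n_samples': 20}

print(missing_model_args(good))  # → []
print(missing_model_args(typo))  # → ['n_latent']
```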

In [60]:
model_args = {'n_features':14, 'n_latent':10,'n_hidden':128,'n_samples':20}

training_args = {
    'batch_size': 64, 
    'lr': 1e-3, 
    #'fedprox_mu': 0.01, 
    'epochs': 5, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

## Declare and run the experiment

- search nodes serving data for these `tags`, optionally filtering on a list of node IDs with `nodes`
- run a round of local training on nodes with the model defined in `model_class`, followed by federation with `aggregator`
- run for `round_limit` rounds, applying the `node_selection_strategy` between the rounds
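
To make the federation step concrete, here is a minimal sketch of federated averaging on plain Python lists. It is not Fed-BioMed's `FedAverage` implementation; the function name and the weighting by per-node sample counts are assumptions of this illustration:

```python
def federated_average(node_params, node_weights):
    """Weighted average of per-node parameter dicts (name -> list of floats)."""
    total = float(sum(node_weights))
    averaged = {}
    for name in node_params[0]:
        length = len(node_params[0][name])
        averaged[name] = [
            sum(w * params[name][i] for params, w in zip(node_params, node_weights)) / total
            for i in range(length)
        ]
    return averaged

# Two hypothetical nodes holding 1 and 3 samples respectively
node_params = [{'encoder.0.bias': [1.0, 2.0]}, {'encoder.0.bias': [3.0, 4.0]}]
node_weights = [1, 3]

avg = federated_average(node_params, node_weights)
print(avg)  # → {'encoder.0.bias': [2.5, 3.5]}
```

Each round, the averaged parameters would be sent back to the nodes as the starting point for the next local training.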

In [61]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['breast_cancer']
rounds = 3

exp = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

2022-04-15 16:16:39,597 fedbiomed INFO - Searching dataset with data tags: ['breast_cancer'] for all nodes
2022-04-15 16:16:49,619 fedbiomed INFO - Node selected for training -> node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764
2022-04-15 16:16:49,622 fedbiomed INFO - Node selected for training -> node_bc08b50a-1b46-407a-8341-2bf317f2506d
2022-04-15 16:16:49,623 fedbiomed INFO - Node selected for training -> node_0fe17841-14d6-4f76-b61d-7b2708b6d95d
2022-04-15 16:16:49,638 fedbiomed INFO - Checking data quality of federated datasets...
2022-04-15 16:16:49,653 fedbiomed DEBUG - Model file has been saved: /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0015/my_model_f2898798-9ca7-488b-8f17-06f63d214d0e.py
2022-04-15 16:16:50,104 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0015/my_model_f2898798-9ca7-488b-8f17-06f63d214d0e.py successful, with status code 201
2022-04-15 

Let's start the experiment.

By default, `exp.run()` does not return until all `round_limit` rounds have completed on all the nodes.

In [62]:
exp.run()

2022-04-15 16:16:51,031 fedbiomed INFO - Sampled nodes in round 0 ['node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764', 'node_bc08b50a-1b46-407a-8341-2bf317f2506d', 'node_0fe17841-14d6-4f76-b61d-7b2708b6d95d']
2022-04-15 16:16:51,033 fedbiomed INFO - Sending request 
					 To: node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764 
					 Request: : Perform training with the arguments: {'researcher_id': 'researcher_cda2fb34-f12f-41cc-9c23-d61d107b97f4', 'job_id': 'c5dcaa30-9128-46bc-92ea-c7133cb1492f', 'training_args': {'test_ratio': 0.0, 'test_on_local_updates': False, 'test_on_global_updates': False, 'test_metric': None, 'test_metric_args': {}, 'batch_size': 64, 'lr': 0.001, 'epochs': 5, 'dry_run': False, 'batch_maxnum': 100}, 'training': True, 'model_args': {'n_features': 14, 'n_latent': 10, 'n_hidden': 128, 'n_samples': 20}, 'command': 'train', 'model_url': 'http://localhost:8844/media/uploads/2022/04/15/my_model_f2898798-9ca7-488b-8f17-06f63d214d0e.py', 'params_url': 'http://loc

2022-04-15 16:17:01,252 fedbiomed INFO - Nodes that successfully reply in round 0 ['node_bc08b50a-1b46-407a-8341-2bf317f2506d', 'node_0fe17841-14d6-4f76-b61d-7b2708b6d95d', 'node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764']
2022-04-15 16:17:01,393 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0015/aggregated_params_82db454f-d3b3-4324-abbc-eaaf19efccac.pt successful, with status code 201
2022-04-15 16:17:01,395 fedbiomed INFO - Saved aggregated params for round 0 in /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0015/aggregated_params_82db454f-d3b3-4324-abbc-eaaf19efccac.pt
2022-04-15 16:17:01,398 fedbiomed INFO - Sampled nodes in round 1 ['node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764', 'node_bc08b50a-1b46-407a-8341-2bf317f2506d', 'node_0fe17841-14d6-4f76-b61d-7b2708b6d95d']
2022-04-15 16:17:01,399 fedbiomed INFO - Sending request 
					 To: node_d1b0882

2022-04-15 16:17:11,534 fedbiomed DEBUG - upload (HTTP GET request) of file node_params_a1dadf27-5b67-4a51-b424-fff1b61086fb.pt successful, with status code 200
2022-04-15 16:17:11,543 fedbiomed INFO - Downloading model params after training on node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764 - from http://localhost:8844/media/uploads/2022/04/15/node_params_df8c0b10-ecda-4700-8272-ada65f0f38c4.pt
2022-04-15 16:17:11,573 fedbiomed DEBUG - upload (HTTP GET request) of file node_params_8eaac61e-97bb-4800-aa60-75b9161b43ba.pt successful, with status code 200
2022-04-15 16:17:11,580 fedbiomed INFO - Nodes that successfully reply in round 1 ['node_bc08b50a-1b46-407a-8341-2bf317f2506d', 'node_0fe17841-14d6-4f76-b61d-7b2708b6d95d', 'node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764']
2022-04-15 16:17:11,915 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0015/aggregated_params_7812a501-da99-480e-ba49-f100f3dd834f.pt succ

2022-04-15 16:17:21,952 fedbiomed INFO - Downloading model params after training on node_0fe17841-14d6-4f76-b61d-7b2708b6d95d - from http://localhost:8844/media/uploads/2022/04/15/node_params_71ad15ba-a8e8-4be0-9789-f4d37e5367e9.pt
2022-04-15 16:17:22,015 fedbiomed DEBUG - upload (HTTP GET request) of file node_params_7b5b4161-67e0-4413-a591-2c3f8c267cad.pt successful, with status code 200
2022-04-15 16:17:22,032 fedbiomed INFO - Downloading model params after training on node_bc08b50a-1b46-407a-8341-2bf317f2506d - from http://localhost:8844/media/uploads/2022/04/15/node_params_f010b2b3-a2cf-4428-bfe4-66e0c7ad11bc.pt
2022-04-15 16:17:22,076 fedbiomed DEBUG - upload (HTTP GET request) of file node_params_e2ce8f1d-fce7-4c88-9cb2-de7ef2af16d4.pt successful, with status code 200
2022-04-15 16:17:22,081 fedbiomed INFO - Downloading model params after training on node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764 - from http://localhost:8844/media/uploads/2022/04/15/node_params_07ee8476-79c6-4040-94b

3

Local training results for each round and each node are available via `exp.training_replies()` (indexed from 0 to `rounds - 1`).

For example, you can view the training results for the last round below.

Different timings (in seconds) are reported for each dataset of a node participating in a round:
- `rtime_training` real time (clock time) spent in the training function on the node
- `ptime_training` process time (user and system CPU) spent in the training function on the node
- `rtime_total` real time (clock time) spent in the researcher between sending the request and handling the response, at the `Job()` layer
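
These timings can be combined to see where a round spends its time. A minimal sketch, using values rounded from this notebook's sample output for one node; the derived quantities `overhead` and `cpu_share` are names introduced for this illustration:

```python
# Timing entry for one node, rounded from the sample output of this notebook
timing = {'rtime_training': 0.43, 'ptime_training': 0.31, 'rtime_total': 10.02}

# Wall-clock time seen by the researcher that was not spent in the training function
overhead = timing['rtime_total'] - timing['rtime_training']
# Fraction of the training wall-clock time during which the CPU was busy
cpu_share = timing['ptime_training'] / timing['rtime_training']

print(f"overhead outside training: {overhead:.2f} s")
print(f"CPU share during training: {cpu_share:.0%}")
```

In this run almost all of `rtime_total` falls outside the training function, which is expected for such a small model and dataset.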

In [63]:
print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies()[rounds - 1].data()
for c in range(len(round_data)):
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = round_data[c]['node_id'],
        rtraining = round_data[c]['timing']['rtime_training'],
        ptraining = round_data[c]['timing']['ptime_training'],
        rtotal = round_data[c]['timing']['rtime_total']))
print('\n')
    
exp.training_replies()[rounds - 1].dataframe()


List the training rounds :  dict_keys([0, 1, 2])

List the nodes for the last training round and their timings : 
	- node_0fe17841-14d6-4f76-b61d-7b2708b6d95d :    
		rtime_training=0.43 seconds    
		ptime_training=0.31 seconds    
		rtime_total=10.02 seconds
	- node_bc08b50a-1b46-407a-8341-2bf317f2506d :    
		rtime_training=0.38 seconds    
		ptime_training=0.31 seconds    
		rtime_total=10.11 seconds
	- node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764 :    
		rtime_training=0.43 seconds    
		ptime_training=0.35 seconds    
		rtime_total=10.16 seconds




|   | success | msg | dataset_id | node_id | params_path | params | timing |
|---|---------|-----|------------|---------|-------------|--------|--------|
| 0 | True |  | dataset_94c0396b-5784-4b8c-9fb5-f2ede90c4e1e | node_0fe17841-14d6-4f76-b61d-7b2708b6d95d | /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed... | {'encoder.0.weight': [[tensor(-0.0412), tensor... | {'rtime_training': 0.4267892830002893, 'ptime_... |
| 1 | True |  | dataset_04e1d4c6-ec4b-4175-8873-bdbbd0293f6d | node_bc08b50a-1b46-407a-8341-2bf317f2506d | /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed... | {'encoder.0.weight': [[tensor(-0.0469), tensor... | {'rtime_training': 0.381567835000169, 'ptime_t... |
| 2 | True |  | dataset_7d188ba5-0398-489c-92fb-005643ba5cd8 | node_d1b08824-64cd-43a6-a4e5-e0ef62d3d764 | /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed... | {'encoder.0.weight': [[tensor(-0.0374), tensor... | {'rtime_training': 0.4315536129997781, 'ptime_... |


Federated parameters for each round are available via `exp.aggregated_params()` (indexed from 0 to `rounds - 1`).

For example, you can view the federated parameters for the last round of the experiment:

In [64]:
print("\nList the training rounds : ", exp.aggregated_params().keys())

print("\nAccess the federated params for the last training round :")
print("\t- params_path: ", exp.aggregated_params()[rounds - 1]['params_path'])
print("\t- parameter data: ", exp.aggregated_params()[rounds - 1]['params'].keys())



List the training rounds :  dict_keys([0, 1, 2])

Access the federated params for the last training round :
	- params_path:  /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0015/aggregated_params_be7ed1df-1702-4395-ba71-aa126704ee97.pt
	- parameter data:  odict_keys(['encoder.0.weight', 'encoder.0.bias', 'encoder.2.weight', 'encoder.2.bias', 'encoder.4.weight', 'encoder.4.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.2.weight', 'decoder.2.bias', 'decoder.4.weight', 'decoder.4.bias'])
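
The parameter names listed above follow the `nn.Sequential` layout of the training plan: linear layers sit at positions 0, 2 and 4, with ReLU activations in between. The sketch below derives the expected tensor shapes from `model_args` in pure Python, relying on PyTorch's convention that `nn.Linear(n_in, n_out)` stores its weight as `(n_out, n_in)`; the helper `expected_shapes` is hypothetical:

```python
import math

model_args = {'n_features': 14, 'n_latent': 10, 'n_hidden': 128, 'n_samples': 20}

def expected_shapes(args):
    """Map each parameter name to the shape implied by the encoder/decoder definition."""
    f, l, h = args['n_features'], args['n_latent'], args['n_hidden']
    blocks = {
        'encoder': [(f, h), (h, h), (h, 2 * l)],  # outputs mean and scale of q(z|x)
        'decoder': [(l, h), (h, h), (h, 3 * f)],  # outputs mean, scale and df per feature
    }
    shapes = {}
    for block, layers in blocks.items():
        for position, (n_in, n_out) in zip((0, 2, 4), layers):
            shapes[f'{block}.{position}.weight'] = (n_out, n_in)
            shapes[f'{block}.{position}.bias'] = (n_out,)
    return shapes

shapes = expected_shapes(model_args)
n_params = sum(math.prod(shape) for shape in shapes.values())

for name, shape in shapes.items():
    print(name, shape)
print("total parameters:", n_params)  # → total parameters: 44350
```

Comparing these shapes with the tensors stored under `exp.aggregated_params()[rounds - 1]['params']` is a quick sanity check that the aggregated model matches the declared architecture.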
