# Missing data imputation with Fedbiomed using MIWAE

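In this notebook we show how to impute data that are missing at random (MAR) in a federated setting using MIWAE (Mattei & Frellsen, 2019, https://arxiv.org/abs/1812.02633). In the example below the values are removed completely at random (MCAR), which is a special case of MAR.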

In [1]:
%load_ext autoreload
%autoreload 2

## Prepare the data

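For this experiment we use the Boston housing dataset (13 features), downloaded from the CMU StatLib archive. The node tag `breast_cancer` used below is only a label; it must simply match the tag given when the datasets are added to the nodes.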

In [2]:
import pandas as pd
import numpy as np

# Boston housing data in the original StatLib text format: after a 22-line header,
# each record spans two physical lines. The even lines carry the first 11 features,
# the odd lines carry the last 2 features followed by the target value.
data_url = "http://lib.stat.cmu.edu/datasets/boston"
# raw string: "\s" is not a valid Python escape sequence (SyntaxWarning on Python >= 3.12)
raw_df = pd.read_csv(data_url, sep=r"\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]
assert data.shape[0] == target.shape[0], "features and target rows are misaligned"

In [3]:
from sklearn.model_selection import train_test_split

data_train, data_test, labels_train, labels_test = train_test_split(data, target, test_size=0.20, random_state=42)
df_data_train = pd.DataFrame(data_train)
N_train = len(df_data_train)

# Shuffle the training rows with a fixed seed (so the client split is reproducible)
# and cut them into three clients holding ~33% / ~33% / ~34% of the rows.
# iloc slicing replaces np.split on a DataFrame, which relies on the deprecated
# DataFrame.swapaxes in recent pandas versions.
df_shuffled = df_data_train.sample(frac=1, random_state=42)
cut_1, cut_2 = int(.33 * N_train), int(.66 * N_train)
client_1 = df_shuffled.iloc[:cut_1]
client_2 = df_shuffled.iloc[cut_1:cut_2]
client_3 = df_shuffled.iloc[cut_2:]

Clients_data = [client_1, client_2, client_3]

# from each dataset we will remove randomly 50% of data
np.random.seed(1234)

perc_miss = 0.5  # 50% of missing data

Clients_missing = []
for c in Clients_data:
    n, p = c.shape  # number of observations, number of features
    # standardise each client with its own column means / standard deviations
    xmiss = np.copy(c)
    xmiss = (xmiss - np.mean(xmiss, 0)) / np.std(xmiss, 0)
    # blank out floor(n*p*perc_miss) distinct cells drawn uniformly at random
    xmiss_flat = xmiss.flatten()
    miss_pattern = np.random.choice(n * p, int(np.floor(n * p * perc_miss)), replace=False)
    xmiss_flat[miss_pattern] = np.nan
    xmiss = xmiss_flat.reshape([n, p])  # in xmiss, the missing values are represented by nans
    Clients_missing.append(xmiss)

# One CSV per client (NaN = missing value), to be added to the nodes
import os
os.makedirs('clients_data', exist_ok=True)
for i, client_missing in enumerate(Clients_missing, start=1):
    pd.DataFrame(client_missing).to_csv(f'clients_data/client_{i}.csv', index=False)

## Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

## Setting the nodes up
Before running the experiment, configure one node per client:
1. `./scripts/fedbiomed_run node add`
  * Select option 1 (csv) to add client_1 dataset to the first node
  * When asked for the tag, enter `breast_cancer` (it must match the `tags` used by the experiment below)
  * Pick the folder where client_1 dataset has been saved
  * The data should now be added (if you get a warning saying that data must be unique, it is because this dataset has already been added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node start`. Wait until you see `Starting task manager`: it means the node is online.
4. Following the same procedure, you can create additional nodes for clients 2 and 3.

Check available clients:

In [4]:
from fedbiomed.researcher.requests import Requests
req = Requests()
# A single query is enough: with verbose=True the per-node dataset table is logged,
# and the same {node_id: [dataset_description, ...]} mapping is returned
# (the original sent two identical requests, each waiting for the nodes to answer).
xx = req.list(verbose=True)

# Number of columns of the first dataset of each node; nodes without any dataset are
# skipped instead of raising an IndexError.
# NOTE(review): assumes the first dataset listed on each node is the MIWAE one.
dataset_size = [datasets[0]['shape'][1] for datasets in xx.values() if datasets]
if not dataset_size:
    raise RuntimeError("No dataset found on the nodes: add the client CSV files with "
                       "`./scripts/fedbiomed_run node add` first")
# MIWAE needs the same number of features on every node
assert min(dataset_size) == max(dataset_size), "nodes expose datasets with different numbers of features"
data_size = dataset_size[0]

2022-04-20 15:13:35,639 fedbiomed INFO - Component environment:
2022-04-20 15:13:35,642 fedbiomed INFO - type = ComponentType.RESEARCHER
2022-04-20 15:13:35,682 fedbiomed INFO - Messaging researcher_aaf86456-e652-46b0-8054-b7bb516705db successfully connected to the message broker, object = <fedbiomed.common.messaging.Messaging object at 0x122cc1820>
2022-04-20 15:13:35,721 fedbiomed INFO - Listing available datasets in all nodes... 
2022-04-20 15:13:45,758 fedbiomed INFO - 
 Node: node_13d7233c-daad-49e1-8f1c-c8dbac2aa845 | Number of Datasets: 1 
+---------------+-------------+-------------------+---------------+-----------+
| name          | data_type   | tags              | description   | shape     |
| breast_cancer | csv         | ['breast_cancer'] | breast_cancer | [134, 13] |
+---------------+-------------+-------------------+---------------+-----------+

2022-04-20 15:13:45,760 fedbiomed INFO - 
 Node: node_1ff16015-8a76-43a9-a0c9-9d9f9167f500 | Number of Datasets: 1 
+---------

## Define an experiment model and parameters

Declare a `MIWAETrainingPlan` class (a Fed-BioMed `TorchTrainingPlan` built from torch.nn modules) to send for training on the nodes.

Note : write **only** the code to export in the following cell

In [5]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import numpy as np
import torch.distributions as td
import pandas as pd

from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes

# Here we define the model to be used.
# You can use any class name (here 'MIWAETrainingPlan')
class MIWAETrainingPlan(TorchTrainingPlan):
    """Fed-BioMed training plan for MIWAE (importance-weighted autoencoder for missing data).

    The encoder maps a zero-imputed row to the mean and the (pre-softplus) scale of a
    diagonal Gaussian q(z | x_obs). The decoder maps a latent sample to the location,
    the (pre-softplus) scale and the (pre-softplus) degrees of freedom of a per-feature
    Student-t observation model. Training minimises the negative MIWAE bound, computed
    on the observed entries only.

    Required ``model_args`` keys: ``n_features``, ``n_latent``, ``n_hidden``, ``n_samples``.
    """
    # NOTE(review): mutable default argument; the {} default would fail on the key
    # lookups below anyway, so model_args must always be supplied by the caller.
    def __init__(self, model_args: dict = {}):
        super(MIWAETrainingPlan, self).__init__(model_args)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        # (registered with add_dependency below). NOTE(review): torchvision is not used
        # by this class; presumably this entry could be dropped -- confirm on the nodes.
        deps = ["from torchvision import datasets, transforms",
               "import torch.distributions as td",
               "import pandas as pd",
               "import numpy as np"]
        
        # Architecture / sampling hyper-parameters sent by the researcher
        self.n_features=model_args['n_features']
        self.n_latent=model_args['n_latent']
        self.n_hidden=model_args['n_hidden']
        self.n_samples=model_args['n_samples']  # number of importance samples K
        
        self.add_dependency(deps)
        
        # the encoder will output both the mean and the (pre-softplus) scale of the
        # diagonal Gaussian q(z | x_obs), hence 2*n_latent outputs
        self.encoder=nn.Sequential(
                        torch.nn.Linear(self.n_features, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 2*self.n_latent),  
                        )
        # the decoder will output the location, the scale,
        # and the number of degrees of freedom of a Student-t per feature (hence 3*n_features)
        self.decoder = nn.Sequential(
                        torch.nn.Linear(self.n_latent, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 3*self.n_features),  
                        )
        
        # Adam over the parameters of both networks.
        # NOTE(review): lr is hard-coded to 1e-3 here and does not read training_args['lr'];
        # confirm which value the framework actually uses.
        self.optimizer = torch.optim.Adam(list(self.encoder.parameters()) \
                                    + list(self.decoder.parameters()),lr=1e-3)
              
        self.encoder.apply(self.weights_init)
        self.decoder.apply(self.weights_init)
    
    def weights_init(self,layer):
        """Orthogonal initialisation of the weight matrix of ``nn.Linear`` layers (biases untouched)."""
        if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
    
    def miwae_loss(self,iota_x,mask):
        """Negative MIWAE bound of a mini-batch, without the constant -log(n_samples) term.

        Args:
            iota_x: batch with missing entries replaced by 0; the reshapes below treat it
                as (batch_size, n_features).
            mask: same shape as ``iota_x``; 1/True where the value is observed, 0/False
                where it is missing. Missing entries contribute nothing to log p(x_obs | z).

        Returns:
            Scalar tensor
            ``-mean_batch logsumexp_k(log p(x_obs|z_k) + log p(z_k) - log q(z_k|x_obs))``.
        """
        batch_size = iota_x.shape[0]
        out_encoder = self.encoder(iota_x)
        # prior: standard normal over the latent space, on the plan's device
        p_z = td.Independent(td.Normal(loc=torch.zeros(self.n_latent).to(self.device)\
                                       ,scale=torch.ones(self.n_latent).to(self.device)),1)
        
        # approximate posterior q(z | x_obs): first half of the encoder output is the mean,
        # softplus of the second half is the standard deviation
        q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :self.n_latent],\
                                                scale=torch.nn.Softplus()\
                                                (out_encoder[..., self.n_latent:\
                                                             (2*self.n_latent)])),1)

        # n_samples reparameterised draws per row, flattened to feed the decoder
        zgivenx = q_zgivenxobs.rsample([self.n_samples])
        zgivenx_flat = zgivenx.reshape([self.n_samples*batch_size,self.n_latent])

        # Student-t parameters; the +0.001 / +3 offsets keep scale > 0 and df > 3
        out_decoder = self.decoder(zgivenx_flat)
        all_means_obs_model = out_decoder[..., :self.n_features]
        all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., self.n_features:\
                                                               (2*self.n_features)]) + 0.001
        all_degfreedom_obs_model = torch.nn.Softplus()\
        (out_decoder[..., (2*self.n_features):(3*self.n_features)]) + 3

        # tile data and mask n_samples times so they line up with the decoder outputs
        data_flat = torch.Tensor.repeat(iota_x,[self.n_samples,1]).reshape([-1,1])
        tiledmask = torch.Tensor.repeat(mask,[self.n_samples,1])

        # element-wise log p(x | z) under the Student-t observation model
        all_log_pxgivenz_flat = torch.distributions.StudentT\
        (loc=all_means_obs_model.reshape([-1,1]),\
         scale=all_scales_obs_model.reshape([-1,1]),\
         df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
        all_log_pxgivenz = all_log_pxgivenz_flat.reshape([self.n_samples*batch_size,self.n_features])

        # keep only the observed entries (mask) and sum over features
        logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([self.n_samples,batch_size])
        logpz = p_z.log_prob(zgivenx)
        logq = q_zgivenxobs.log_prob(zgivenx)

        # importance-weighted bound: logsumexp over the samples, mean over the batch
        neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,0))

        return neg_bound

    def training_data(self,  batch_size = 48):
        """Load this node's CSV dataset (NaN = missing) and wrap it in a DataManager.

        The inputs are the data with NaNs replaced by 0, and the "target" is the
        observation mask (True where the value is observed).
        NOTE(review): presumably the framework hands this target to ``training_step`` as
        its second argument, and sets ``self.dataset_path`` to the node's CSV file.
        """
        
        df = pd.read_csv(self.dataset_path, sep=',', index_col=False)
        x_train = df.values
        x_mask = np.isfinite(x_train)  # True where observed, False where missing (NaN)
        # xhat_0: missing values are replaced by zeros. 
        #This x_hat0 is what will be fed to our encoder.
        xhat_0 = np.copy(x_train)
        xhat_0[np.isnan(x_train)] = 0
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        
        data_manager = DataManager(dataset=xhat_0 , target=x_mask , **train_kwargs)
        
        return data_manager
    
    def training_step(self, data, mask):
        """Return the MIWAE loss of one batch.

        ``data`` is the zero-imputed batch and ``mask`` the matching observation mask
        built in ``training_data``. The gradients of both networks are reset here; the
        backward pass and the optimizer step are left to the caller (presumably the
        Fed-BioMed training routine).
        """
        self.encoder.zero_grad()
        self.decoder.zero_grad()
        loss = self.miwae_loss(iota_x = data,mask = mask)
        return loss

These two groups of arguments correspond respectively to:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side. 
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.
If FedProx optimisation is requested, the `fedprox_mu` parameter (the non-negative weight of the FedProx proximal term) must be defined here as a float. It is commented out below, so FedProx is not used in this notebook.

**NOTE:** typos and/or lack of positional (required) arguments will raise error. 🤓

In [6]:
h = 128  # number of hidden units (same for all MLPs)
d = 10  # dimension of the latent space
K = 20  # number of importance samples (IS) used during training

# Sent to MIWAETrainingPlan on the nodes; n_features comes from the datasets listed above
model_args = {'n_features': data_size, 'n_latent': d, 'n_hidden': h, 'n_samples': K}

training_args = {
    'batch_size': 64,
    # NOTE(review): MIWAETrainingPlan builds its own Adam optimizer with lr=1e-3 hard-coded;
    # confirm whether this value is actually used by the node.
    'lr': 1e-3,
    # 'fedprox_mu': 0.01,  # uncomment to enable FedProx
    'epochs': 10,
    'dry_run': False,
    'batch_maxnum': 250,  # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

# Must match the tag given to the datasets when they were added to the nodes
tags = ['breast_cancer']
rounds = 10

## Declare and run the experiment

- search nodes serving data for these `tags`, optionally filter on a list of node ID with `nodes`
- run a round of local training on nodes with the training plan given by `model_class` + federation with `aggregator`
- run for `round_limit` rounds, applying the `node_selection_strategy` between the rounds

In [7]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

# Look up the nodes holding datasets tagged with `tags` and set up the federated
# training of MIWAETrainingPlan, aggregated with federated averaging.
exp = Experiment(
    tags=tags,
    model_class=MIWAETrainingPlan,
    model_args=model_args,
    training_args=training_args,
    round_limit=rounds,
    aggregator=FedAverage(),
    node_selection_strategy=None,
)

2022-04-20 15:13:59,193 fedbiomed INFO - Searching dataset with data tags: ['breast_cancer'] for all nodes
2022-04-20 15:14:09,207 fedbiomed INFO - Node selected for training -> node_8a14aca2-59e6-45fd-b00a-4c74206b334f
2022-04-20 15:14:09,209 fedbiomed INFO - Node selected for training -> node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
2022-04-20 15:14:09,210 fedbiomed INFO - Node selected for training -> node_13d7233c-daad-49e1-8f1c-c8dbac2aa845
2022-04-20 15:14:09,217 fedbiomed INFO - Checking data quality of federated datasets...
2022-04-20 15:14:09,347 fedbiomed DEBUG - Model file has been saved: /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0010/my_model_ba90f1a3-3b9d-4ad2-80b0-2340dcd3ae2e.py
2022-04-20 15:14:09,441 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0010/my_model_ba90f1a3-3b9d-4ad2-80b0-2340dcd3ae2e.py successful, with status code 201
2022-04-20 

Let's start the experiment.

By default, this function doesn't stop until all the `round_limit` rounds are done for all the nodes

In [8]:
# Run the federated training; by default the call returns once all `round_limit` rounds are done
exp.run()

2022-04-20 15:14:09,684 fedbiomed INFO - Sampled nodes in round 0 ['node_8a14aca2-59e6-45fd-b00a-4c74206b334f', 'node_1ff16015-8a76-43a9-a0c9-9d9f9167f500', 'node_13d7233c-daad-49e1-8f1c-c8dbac2aa845']
2022-04-20 15:14:09,686 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: node_8a14aca2-59e6-45fd-b00a-4c74206b334f 
					[1m Request: [0m: Perform training with the arguments: {'researcher_id': 'researcher_aaf86456-e652-46b0-8054-b7bb516705db', 'job_id': '56a8c6d8-af1c-470a-ab7f-09c09b33a5f5', 'training_args': {'test_ratio': 0.0, 'test_on_local_updates': False, 'test_on_global_updates': False, 'test_metric': None, 'test_metric_args': {}, 'batch_size': 64, 'lr': 0.001, 'epochs': 10, 'dry_run': False, 'batch_maxnum': 250}, 'training': True, 'model_args': {'n_features': 13, 'n_latent': 10, 'n_hidden': 128, 'n_samples': 20}, 'command': 'train', 'model_url': 'http://localhost:8844/media/uploads/2022/04/20/my_model_ba90f1a3-3b9d-4ad2-80b0-2340dcd3ae2e.py', 'params_url': 'http://lo

2022-04-20 15:14:19,952 fedbiomed INFO - Nodes that successfully reply in round 0 ['node_8a14aca2-59e6-45fd-b00a-4c74206b334f', 'node_1ff16015-8a76-43a9-a0c9-9d9f9167f500', 'node_13d7233c-daad-49e1-8f1c-c8dbac2aa845']
2022-04-20 15:14:20,142 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0010/aggregated_params_b4110c39-34f2-49e8-8f82-d2137820c40c.pt successful, with status code 201
2022-04-20 15:14:20,143 fedbiomed INFO - Saved aggregated params for round 0 in /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0010/aggregated_params_b4110c39-34f2-49e8-8f82-d2137820c40c.pt
2022-04-20 15:14:20,147 fedbiomed INFO - Sampled nodes in round 1 ['node_8a14aca2-59e6-45fd-b00a-4c74206b334f', 'node_1ff16015-8a76-43a9-a0c9-9d9f9167f500', 'node_13d7233c-daad-49e1-8f1c-c8dbac2aa845']
2022-04-20 15:14:20,152 fedbiomed INFO - [1mSending request[0m 
					[1m To[0m: node_8a14aca

2022-04-20 15:14:30,268 fedbiomed DEBUG - upload (HTTP GET request) of file node_params_61f06d0a-efe2-466a-bcc2-2dbba8a1e049.pt successful, with status code 200
2022-04-20 15:14:30,273 fedbiomed INFO - Downloading model params after training on node_13d7233c-daad-49e1-8f1c-c8dbac2aa845 - from http://localhost:8844/media/uploads/2022/04/20/node_params_9f6ed3ae-1650-43ed-a58e-24f6082d3d48.pt
2022-04-20 15:14:30,300 fedbiomed DEBUG - upload (HTTP GET request) of file node_params_9b0c4144-f4b4-4fae-be9b-a41abf17011f.pt successful, with status code 200
2022-04-20 15:14:30,307 fedbiomed INFO - Nodes that successfully reply in round 1 ['node_8a14aca2-59e6-45fd-b00a-4c74206b334f', 'node_1ff16015-8a76-43a9-a0c9-9d9f9167f500', 'node_13d7233c-daad-49e1-8f1c-c8dbac2aa845']
2022-04-20 15:14:30,409 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0010/aggregated_params_bd7815fb-b38d-452b-885e-1a2dee70ff11.pt succ

2022-04-20 15:14:40,462 fedbiomed INFO - Downloading model params after training on node_13d7233c-daad-49e1-8f1c-c8dbac2aa845 - from http://localhost:8844/media/uploads/2022/04/20/node_params_781b8840-4166-4c9b-8b0c-ff25088d0973.pt
2022-04-20 15:14:40,496 fedbiomed DEBUG - upload (HTTP GET request) of file node_params_94053e80-0e7d-4ab4-830e-5c815a87b296.pt successful, with status code 200
2022-04-20 15:14:40,504 fedbiomed INFO - Downloading model params after training on node_8a14aca2-59e6-45fd-b00a-4c74206b334f - from http://localhost:8844/media/uploads/2022/04/20/node_params_4ed6febb-3959-4026-94df-a027e6db9730.pt
2022-04-20 15:14:40,725 fedbiomed DEBUG - upload (HTTP GET request) of file node_params_a3d59072-c746-4398-8bfc-129fb87f5ebc.pt successful, with status code 200
2022-04-20 15:14:40,731 fedbiomed INFO - Downloading model params after training on node_1ff16015-8a76-43a9-a0c9-9d9f9167f500 - from http://localhost:8844/media/uploads/2022/04/20/node_params_b63289de-bc6a-4233-982

2022-04-20 15:14:44,084 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
					[1m MESSAGE:[0m results uploaded successfully [0m
-----------------------------------------------------------------
2022-04-20 15:14:44,263 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_8a14aca2-59e6-45fd-b00a-4c74206b334f
					[1m MESSAGE:[0m results uploaded successfully [0m
-----------------------------------------------------------------
2022-04-20 15:14:44,362 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_13d7233c-daad-49e1-8f1c-c8dbac2aa845
					[1m MESSAGE:[0m results uploaded successfully [0m
-----------------------------------------------------------------
2022-04-20 15:14:51,690 fedbiomed INFO - Downloading model params after training on node_1ff16015-8a76-43a9-a0c9-9d9f9167f500 - from http://localhost:8844/media/uploads/2022/04/20/node_params_b4f5ec89-39fc-4d40-99e0-8410edae23a9.pt
2022-04-20 15:14:51,717 fedbiomed DEBUG - upload (H

2022-04-20 15:14:52,216 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
					[1m MESSAGE:[0m training with arguments {'history_monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x133f5c4c0>, 'node_args': {'gpu': False, 'gpu_num': None, 'gpu_only': False}, 'batch_size': 64, 'lr': 0.001, 'epochs': 10, 'dry_run': False, 'batch_maxnum': 250}[0m
-----------------------------------------------------------------
2022-04-20 15:14:53,099 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_13d7233c-daad-49e1-8f1c-c8dbac2aa845
					[1m MESSAGE:[0m results uploaded successfully [0m
-----------------------------------------------------------------
2022-04-20 15:14:53,241 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_8a14aca2-59e6-45fd-b00a-4c74206b334f
					[1m MESSAGE:[0m results uploaded successfully [0m
-----------------------------------------------------------------
2022-04-20 15:14:53,350 fedbiomed INFO - [1mINFO

					[1m NODE[0m node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
					[1m MESSAGE:[0m There is no test activated for the round. Please set flag for `test_on_global_updates`, `test_on_local_updates`, or both. Splitting dataset for testing will be ignored[0m
-----------------------------------------------------------------
2022-04-20 15:15:02,510 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
					[1m MESSAGE:[0m training with arguments {'history_monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x13344e130>, 'node_args': {'gpu': False, 'gpu_num': None, 'gpu_only': False}, 'batch_size': 64, 'lr': 0.001, 'epochs': 10, 'dry_run': False, 'batch_maxnum': 250}[0m
-----------------------------------------------------------------
2022-04-20 15:15:03,453 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_13d7233c-daad-49e1-8f1c-c8dbac2aa845
					[1m MESSAGE:[0m results uploaded successfully [0m
---------------------------------

2022-04-20 15:15:12,862 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_13d7233c-daad-49e1-8f1c-c8dbac2aa845
					[1m MESSAGE:[0m training with arguments {'history_monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x13a76d430>, 'node_args': {'gpu': False, 'gpu_num': None, 'gpu_only': False}, 'batch_size': 64, 'lr': 0.001, 'epochs': 10, 'dry_run': False, 'batch_maxnum': 250}[0m
-----------------------------------------------------------------
					[1m NODE[0m node_8a14aca2-59e6-45fd-b00a-4c74206b334f
					[1m MESSAGE:[0m There is no test activated for the round. Please set flag for `test_on_global_updates`, `test_on_local_updates`, or both. Splitting dataset for testing will be ignored[0m
-----------------------------------------------------------------
2022-04-20 15:15:12,959 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_8a14aca2-59e6-45fd-b00a-4c74206b334f
					[1m MESSAGE:[0m training with arguments {'history_monitor': <fedbiomed.node.history_

					[1m NODE[0m node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
					[1m MESSAGE:[0m There is no test activated for the round. Please set flag for `test_on_global_updates`, `test_on_local_updates`, or both. Splitting dataset for testing will be ignored[0m
-----------------------------------------------------------------
2022-04-20 15:15:23,292 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
					[1m MESSAGE:[0m training with arguments {'history_monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x13344e460>, 'node_args': {'gpu': False, 'gpu_num': None, 'gpu_only': False}, 'batch_size': 64, 'lr': 0.001, 'epochs': 10, 'dry_run': False, 'batch_maxnum': 250}[0m
-----------------------------------------------------------------
					[1m NODE[0m node_13d7233c-daad-49e1-8f1c-c8dbac2aa845
					[1m MESSAGE:[0m There is no test activated for the round. Please set flag for `test_on_global_updates`, `test_on_local_updates`, or both. 

2022-04-20 15:15:33,535 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
					[1m MESSAGE:[0m training with arguments {'history_monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x13344e130>, 'node_args': {'gpu': False, 'gpu_num': None, 'gpu_only': False}, 'batch_size': 64, 'lr': 0.001, 'epochs': 10, 'dry_run': False, 'batch_maxnum': 250}[0m
-----------------------------------------------------------------
					[1m NODE[0m node_8a14aca2-59e6-45fd-b00a-4c74206b334f
					[1m MESSAGE:[0m There is no test activated for the round. Please set flag for `test_on_global_updates`, `test_on_local_updates`, or both. Splitting dataset for testing will be ignored[0m
-----------------------------------------------------------------
2022-04-20 15:15:33,601 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_8a14aca2-59e6-45fd-b00a-4c74206b334f
					[1m MESSAGE:[0m training with arguments {'history_monitor': <fedbiomed.node.history_

2022-04-20 15:15:43,841 fedbiomed DEBUG - researcher_aaf86456-e652-46b0-8054-b7bb516705db
					[1m NODE[0m node_8a14aca2-59e6-45fd-b00a-4c74206b334f
					[1m MESSAGE:[0m There is no test activated for the round. Please set flag for `test_on_global_updates`, `test_on_local_updates`, or both. Splitting dataset for testing will be ignored[0m
-----------------------------------------------------------------
2022-04-20 15:15:44,068 fedbiomed INFO - [1mINFO[0m
					[1m NODE[0m node_8a14aca2-59e6-45fd-b00a-4c74206b334f
					[1m MESSAGE:[0m training with arguments {'history_monitor': <fedbiomed.node.history_monitor.HistoryMonitor object at 0x13497deb0>, 'node_args': {'gpu': False, 'gpu_num': None, 'gpu_only': False}, 'batch_size': 64, 'lr': 0.001, 'epochs': 10, 'dry_run': False, 'batch_maxnum': 250}[0m
-----------------------------------------------------------------
					[1m NODE[0m node_1ff16015-8a76-43a9-a0c9-9d9f9167f500
					[1m MESSAGE:[0m There is no test activated for t

10

Local training results for each round and each node are available via `exp.training_replies()` (index 0 to (`rounds` - 1) ).

For example you can view the training results for the last round below.

Different timings (in seconds) are reported for each dataset of a node participating in a round :
- `rtime_training` real time (clock time) spent in the training function on the node
- `ptime_training` process time (user and system CPU) spent in the training function on the node
- `rtime_total` real time (clock time) spent in the researcher between sending the request and handling the response, at the `Job()` layer

In [9]:
# Replies of the nodes that took part in the last training round
last_round_replies = exp.training_replies()[rounds - 1]

print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
# Adjacent f-strings instead of backslash continuations inside one literal: the
# continuations leaked the source indentation into the output as trailing spaces.
for reply in last_round_replies.data():
    timing = reply['timing']
    print(f"\t- {reply['node_id']} :"
          f"\n\t\trtime_training={timing['rtime_training']:.2f} seconds"
          f"\n\t\tptime_training={timing['ptime_training']:.2f} seconds"
          f"\n\t\trtime_total={timing['rtime_total']:.2f} seconds")
print('\n')

last_round_replies.dataframe()


List the training rounds :  dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

List the nodes for the last training round and their timings : 
	- node_1ff16015-8a76-43a9-a0c9-9d9f9167f500 :    
		rtime_training=1.19 seconds    
		ptime_training=0.65 seconds    
		rtime_total=10.03 seconds
	- node_8a14aca2-59e6-45fd-b00a-4c74206b334f :    
		rtime_training=1.32 seconds    
		ptime_training=0.67 seconds    
		rtime_total=10.08 seconds
	- node_13d7233c-daad-49e1-8f1c-c8dbac2aa845 :    
		rtime_training=1.30 seconds    
		ptime_training=0.66 seconds    
		rtime_total=10.12 seconds




Unnamed: 0,success,msg,dataset_id,node_id,params_path,params,timing
0,True,,dataset_a5c55855-7841-4f2d-847d-189b5c0e92b0,node_1ff16015-8a76-43a9-a0c9-9d9f9167f500,/Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed...,"{'encoder.0.weight': [[tensor(-0.0364), tensor...","{'rtime_training': 1.1890774219991727, 'ptime_..."
1,True,,dataset_74f13994-4acb-4419-bd39-5d019afee4c3,node_8a14aca2-59e6-45fd-b00a-4c74206b334f,/Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed...,"{'encoder.0.weight': [[tensor(-0.0424), tensor...","{'rtime_training': 1.318693543000336, 'ptime_t..."
2,True,,dataset_7bc4bb27-e23b-4f93-8af1-d88f5ff13726,node_13d7233c-daad-49e1-8f1c-c8dbac2aa845,/Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed...,"{'encoder.0.weight': [[tensor(-0.0502), tensor...","{'rtime_training': 1.301112667999405, 'ptime_t..."


Federated parameters for each round are available via `exp.aggregated_params()` (index 0 to (`rounds` - 1) ).

For example you can view the federated parameters for the last round of the experiment :

In [10]:
all_rounds_params = exp.aggregated_params()
last_round_params = all_rounds_params[rounds - 1]

print("\nList the training rounds : ", all_rounds_params.keys())

print("\nAccess the federated params for the last training round :")
print("\t- params_path: ", last_round_params['params_path'])
print("\t- parameter data: ", last_round_params['params'].keys())


List the training rounds :  dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

Access the federated params for the last training round :
	- params_path:  /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0010/aggregated_params_4e575ea7-19dd-40aa-99b0-7ded51b48892.pt
	- parameter data:  odict_keys(['encoder.0.weight', 'encoder.0.bias', 'encoder.2.weight', 'encoder.2.bias', 'encoder.4.weight', 'encoder.4.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.2.weight', 'decoder.2.bias', 'decoder.4.weight', 'decoder.4.bias'])


## Testing on an external dataset

In [11]:
# from the test dataset, we will remove randomly 50% of data
np.random.seed(1234)

perc_miss = 0.5  # 50% of missing data

n, p = data_test.shape  # number of observations, number of features
print(n, p)
# Standardise the test set with its own column means / standard deviations
xfull = np.copy(data_test)
xfull = (xfull - np.mean(xfull, 0)) / np.std(xfull, 0)
xmiss = np.copy(xfull)
# Blank out floor(n*p*perc_miss) distinct cells drawn uniformly at random
xmiss_flat = xmiss.flatten()
miss_pattern = np.random.choice(n * p, int(np.floor(n * p * perc_miss)), replace=False)
xmiss_flat[miss_pattern] = np.nan
xmiss = xmiss_flat.reshape([n, p])  # in xmiss, the missing values are represented by nans
mask = np.isfinite(xmiss)  # binary mask: True where the value is observed, False where it is missing
# xhat_0: missing values replaced by zeros; this is what is fed to the encoder
xhat_0 = np.copy(xmiss)
xhat_0[np.isnan(xmiss)] = 0
xhat = np.copy(xhat_0)  # This will be our imputed data matrix

102 13


In [12]:
L = 100  # number of importance samples used at imputation time

# Rebuild the training plan locally and load the weights aggregated in the last round
model = exp.model_instance()
last_round_state = exp.aggregated_params()[rounds - 1]['params']
model.load_state_dict(last_round_state)

# The imputation code below works directly on the two sub-networks
encoder, decoder = model.encoder, model.decoder

In [13]:
# Standard normal prior p(z) over the d-dimensional latent space
p_z = td.Independent(td.Normal(torch.zeros(d), torch.ones(d)), reinterpreted_batch_ndims=1)

def miwae_impute(iota_x, mask, L, encoder_net=None, decoder_net=None):
    """Single imputation of every entry with a trained MIWAE model.

    Self-normalised importance sampling estimate of E[x | x_obs]: L latent samples are
    drawn from q(z | x_obs), each is decoded into a per-feature Student-t, and the
    Student-t means are averaged with the normalised weights
    p(x_obs | z) p(z) / q(z | x_obs).

    Args:
        iota_x: (batch, p) float tensor with the missing entries replaced by 0.
        mask: (batch, p) float tensor, 1 where observed and 0 where missing.
        L: number of importance samples.
        encoder_net: encoder network; defaults to the global ``encoder``.
        decoder_net: decoder network; defaults to the global ``decoder``.

    Returns:
        (batch, p) tensor of imputations for every entry (observed ones included;
        callers keep only the missing entries). Computed without autograd.
    """
    enc = encoder if encoder_net is None else encoder_net
    dec = decoder if decoder_net is None else decoder_net
    batch_size, n_features = iota_x.shape

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        out_encoder = enc(iota_x)
        # the latent size is implied by the encoder output (mean half + scale half)
        n_latent = out_encoder.shape[-1] // 2
        p_prior = td.Independent(td.Normal(loc=torch.zeros(n_latent), scale=torch.ones(n_latent)), 1)
        q_zgivenxobs = td.Independent(
            td.Normal(loc=out_encoder[..., :n_latent],
                      scale=torch.nn.Softplus()(out_encoder[..., n_latent:(2 * n_latent)])), 1)

        zgivenx = q_zgivenxobs.rsample([L])
        zgivenx_flat = zgivenx.reshape([L * batch_size, n_latent])

        # Student-t parameters; the +0.001 / +3 offsets keep scale > 0 and df > 3
        out_decoder = dec(zgivenx_flat)
        all_means_obs_model = out_decoder[..., :n_features]
        all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., n_features:(2 * n_features)]) + 0.001
        all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2 * n_features):(3 * n_features)]) + 3

        data_flat = torch.Tensor.repeat(iota_x, [L, 1]).reshape([-1, 1])
        tiledmask = torch.Tensor.repeat(mask, [L, 1])

        all_log_pxgivenz_flat = td.StudentT(loc=all_means_obs_model.reshape([-1, 1]),
                                            scale=all_scales_obs_model.reshape([-1, 1]),
                                            df=all_degfreedom_obs_model.reshape([-1, 1])).log_prob(data_flat)
        all_log_pxgivenz = all_log_pxgivenz_flat.reshape([L * batch_size, n_features])

        # log p(x_obs | z): only observed entries count
        logpxobsgivenz = torch.sum(all_log_pxgivenz * tiledmask, 1).reshape([L, batch_size])
        logpz = p_prior.log_prob(zgivenx)
        logq = q_zgivenxobs.log_prob(zgivenx)

        xgivenz = td.Independent(td.StudentT(loc=all_means_obs_model,
                                             scale=all_scales_obs_model,
                                             df=all_degfreedom_obs_model), 1)

        # normalised importance weights w_1, ..., w_L for every row of the batch
        imp_weights = torch.nn.functional.softmax(logpxobsgivenz + logpz - logq, 0)
        # conditional means E[x | z_k], shape (L, batch, p)
        xms = xgivenz.mean.reshape([L, batch_size, n_features])
        xm = torch.einsum('ki,kij->ij', imp_weights, xms)

    return xm

In [14]:
def mse(xhat, xtrue, mask):  # MSE function for imputations
    """Mean squared error between ``xhat`` and ``xtrue`` on the missing entries only.

    Args:
        xhat: imputed data matrix (array-like).
        xtrue: ground-truth data matrix, same shape as ``xhat``.
        mask: array-like of the same shape, truthy where the value is observed and
            falsy where it is missing. It is cast to bool so that 0/1 integer or float
            masks select the same entries as a boolean one (``~`` on an int array is a
            bitwise NOT and would index the wrong elements).

    Returns:
        Mean of ``(xhat - xtrue) ** 2`` over the entries where ``mask`` is False
        (NaN, with a RuntimeWarning, if no entry is missing).
    """
    xhat = np.asarray(xhat)
    xtrue = np.asarray(xtrue)
    missing = ~np.asarray(mask, dtype=bool)
    return np.mean(np.power(xhat - xtrue, 2)[missing])

In [15]:
### Now we do the imputation

imputed = miwae_impute(iota_x=torch.from_numpy(xhat_0).float(),
                       mask=torch.from_numpy(mask).float(),
                       L=L).detach().cpu().numpy()
# Only the missing entries are replaced; observed entries keep their values
xhat[~mask] = imputed[~mask]
# mse() returns a scalar, which "%g" formats directly (formatting a 1-element array
# relies on array-to-scalar conversion, deprecated since NumPy 1.25)
err = mse(xhat, xfull, mask)
print('Imputation MSE  %g' % err)
print('-----')
print('-----')

Imputation MSE  0.634149
-----


## Local training on a client

In [16]:
data_client_1 = Clients_data[0]
n, p = data_client_1.shape  # number of observations, number of features
print(n, p)
# Standardise the client data with its own column means / standard deviations
xfull = np.copy(data_client_1)
xfull = (xfull - np.mean(xfull, 0)) / np.std(xfull, 0)
xmiss = np.copy(xfull)
# Blank out floor(n*p*perc_miss) distinct cells drawn uniformly at random.
# The global RNG is in a different state than when clients_data/client_1.csv was
# written, so this missingness pattern in general differs from the node's one.
xmiss_flat = xmiss.flatten()
miss_pattern = np.random.choice(n * p, int(np.floor(n * p * perc_miss)), replace=False)
xmiss_flat[miss_pattern] = np.nan
xmiss = xmiss_flat.reshape([n, p])  # in xmiss, the missing values are represented by nans
mask = np.isfinite(xmiss)  # binary mask: True where the value is observed, False where it is missing
# xhat_0: missing values replaced by zeros; this is what is fed to the encoder
xhat_0 = np.copy(xmiss)
xhat_0[np.isnan(xmiss)] = 0
xhat = np.copy(xhat_0)  # This will be our imputed data matrix

133 13


In [17]:
# Standard normal prior p(z) over the d-dimensional latent space
p_z = td.Independent(td.Normal(torch.zeros(d), torch.ones(d)), reinterpreted_batch_ndims=1)

def miwae_loss(iota_x, mask, encoder_net=None, decoder_net=None, n_samples=None):
    """Negative MIWAE bound of a mini-batch, without the constant -log(K) term.

    The feature count and latent size are read from the tensors instead of the global
    ``p`` / ``d``, so the function does not depend on which cell last set them.

    Args:
        iota_x: (batch, p) float tensor with the missing entries replaced by 0.
        mask: (batch, p) tensor, 1 where observed and 0 where missing; missing entries
            contribute nothing to log p(x_obs | z).
        encoder_net: encoder network; defaults to the global ``encoder``.
        decoder_net: decoder network; defaults to the global ``decoder``.
        n_samples: number of importance samples; defaults to the global ``K``.

    Returns:
        Scalar tensor
        ``-mean_batch logsumexp_k(log p(x_obs|z_k) + log p(z_k) - log q(z_k|x_obs))``.
    """
    enc = encoder if encoder_net is None else encoder_net
    dec = decoder if decoder_net is None else decoder_net
    k = K if n_samples is None else n_samples
    batch_size, n_features = iota_x.shape

    out_encoder = enc(iota_x)
    # the latent size is implied by the encoder output (mean half + scale half)
    n_latent = out_encoder.shape[-1] // 2
    p_prior = td.Independent(td.Normal(loc=torch.zeros(n_latent), scale=torch.ones(n_latent)), 1)
    q_zgivenxobs = td.Independent(
        td.Normal(loc=out_encoder[..., :n_latent],
                  scale=torch.nn.Softplus()(out_encoder[..., n_latent:(2 * n_latent)])), 1)

    zgivenx = q_zgivenxobs.rsample([k])
    zgivenx_flat = zgivenx.reshape([k * batch_size, n_latent])

    # Student-t parameters; the +0.001 / +3 offsets keep scale > 0 and df > 3
    out_decoder = dec(zgivenx_flat)
    all_means_obs_model = out_decoder[..., :n_features]
    all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., n_features:(2 * n_features)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2 * n_features):(3 * n_features)]) + 3

    data_flat = torch.Tensor.repeat(iota_x, [k, 1]).reshape([-1, 1])
    tiledmask = torch.Tensor.repeat(mask, [k, 1])

    all_log_pxgivenz_flat = td.StudentT(loc=all_means_obs_model.reshape([-1, 1]),
                                        scale=all_scales_obs_model.reshape([-1, 1]),
                                        df=all_degfreedom_obs_model.reshape([-1, 1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([k * batch_size, n_features])

    # log p(x_obs | z): only observed entries count
    logpxobsgivenz = torch.sum(all_log_pxgivenz * tiledmask, 1).reshape([k, batch_size])
    logpz = p_prior.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)

    neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq, 0))

    return neg_bound

In [18]:
n_epochs = 10 * rounds  # training_args['epochs'] (10) per round x number of rounds, to mirror the federated run
bs = 64  # batch size

torch.manual_seed(1234)  # reproducible initialisation and latent sampling

encoder = nn.Sequential(
    torch.nn.Linear(p, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 2*d),  # the encoder will output both the mean and the (pre-softplus) scale of q(z|x_obs)
)

decoder = nn.Sequential(
    torch.nn.Linear(d, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 3*p),  # the decoder will output both the mean, the scale, and the number of degrees of freedoms (hence the 3*p)
)

optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)


def weights_init(layer):
    """Orthogonal initialisation of the weight matrix of every ``nn.Linear`` layer."""
    if type(layer) == nn.Linear:
        torch.nn.init.orthogonal_(layer.weight)


encoder.apply(weights_init)
decoder.apply(weights_init)

# np.array_split needs an integer number of sections >= 1 (n / bs is a float and
# int(n / bs) is 0 when n < bs); each batch holds roughly bs rows
n_batches = max(1, n // bs)

for ep in range(1, n_epochs + 1):  # epochs 1..n_epochs (range(1, n_epochs) ran one epoch too few)
    perm = np.random.permutation(n)  # We use the "random reshuffling" version of SGD
    batches_data = np.array_split(xhat_0[perm, ], n_batches)
    batches_mask = np.array_split(mask[perm, ], n_batches)
    for batch_data, batch_mask in zip(batches_data, batches_mask):
        # the optimizer covers both networks, so this resets every gradient
        optimizer.zero_grad()
        b_data = torch.from_numpy(batch_data).float()
        b_mask = torch.from_numpy(batch_mask).float()
        loss = miwae_loss(iota_x=b_data, mask=b_mask)
        loss.backward()
        optimizer.step()
    if ep % 5 == 1:
        # Monitoring only: evaluate the bound on the whole client dataset without autograd
        with torch.no_grad():
            full_neg_bound = miwae_loss(iota_x=torch.from_numpy(xhat_0).float(),
                                        mask=torch.from_numpy(mask).float())
        print('Epoch %g' % ep)
        print('MIWAE likelihood bound  %g' % (-np.log(K) - full_neg_bound.cpu().numpy()))

Epoch 1
MIWAE likelihood bound  -8.96027
Epoch 6
MIWAE likelihood bound  -8.2552
Epoch 11
MIWAE likelihood bound  -7.86011
Epoch 16
MIWAE likelihood bound  -7.32598
Epoch 21
MIWAE likelihood bound  -6.79499
Epoch 26
MIWAE likelihood bound  -6.43164
Epoch 31
MIWAE likelihood bound  -6.11862
Epoch 36
MIWAE likelihood bound  -5.41204
Epoch 41
MIWAE likelihood bound  -4.71362
Epoch 46
MIWAE likelihood bound  -4.57333
Epoch 51
MIWAE likelihood bound  -4.2328
Epoch 56
MIWAE likelihood bound  -4.00634
Epoch 61
MIWAE likelihood bound  -4.06771
Epoch 66
MIWAE likelihood bound  -4.26915
Epoch 71
MIWAE likelihood bound  -4.14849
Epoch 76
MIWAE likelihood bound  -3.71494
Epoch 81
MIWAE likelihood bound  -3.93023
Epoch 86
MIWAE likelihood bound  -3.58164
Epoch 91
MIWAE likelihood bound  -3.51318
Epoch 96
MIWAE likelihood bound  -3.33888


In [19]:
### Now we do the imputation

imputed = miwae_impute(iota_x=torch.from_numpy(xhat_0).float(),
                       mask=torch.from_numpy(mask).float(),
                       L=L).detach().cpu().numpy()
# Only the missing entries are replaced; observed entries keep their values
xhat[~mask] = imputed[~mask]
# mse() returns a scalar, which "%g" formats directly (formatting a 1-element array
# relies on array-to-scalar conversion, deprecated since NumPy 1.25)
err = mse(xhat, xfull, mask)
print('Imputation MSE  %g' % err)
print('-----')
print('-----')

Imputation MSE  0.73025
-----
