# Missing data imputation with Fedbiomed using MIWAE

In this notebook we show how to impute missing not at random (MAR) data in a federated setting using MIWAE (https://arxiv.org/abs/2006.12871). We will compare results of federated training using FedAvg, FedProx and [FedCos](https://arxiv.org/abs/2204.03174) with local results.

In [1]:
%load_ext autoreload
%autoreload 2

## Prepare the data

For this experiment we will use the breast cancer data from sklearn.

In [2]:
import pandas as pd
import numpy as np

data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]

In [3]:
from sklearn.model_selection import train_test_split

data_train, data_test, labels_train, labels_test = train_test_split(data, target, test_size=0.20, random_state=42)
df_data_train = pd.DataFrame(data_train)
N_train = len(df_data_train)
client_1, client_2, client_3 = np.split(df_data_train.sample(frac=1,random_state=42), \
                                        [int(.33*N_train), int(.66*len(df_data_train))])

Clients_data=[client_1, client_2, client_3]

# from each dataset we will remove randomly 50% of data
np.random.seed(1234)

# 50% of missing data for client 1, 30% for client 2, 60% for client 3
perc_miss_list = [0.5,0.3,0.6] 

Clients_missing = []
for perc,c in enumerate(Clients_data):
    perc_miss=perc_miss_list[perc]
    n = c.shape[0] # number of observations
    p = c.shape[1] # number of features
    xmiss = np.copy(c)
    xmiss = (xmiss - np.mean(xmiss,0))/np.std(xmiss,0)
    xmiss_flat = xmiss.flatten()
    miss_pattern = np.random.choice(n*p, np.floor(n*p*perc_miss).astype(np.int_),\
                                    replace=False)
    xmiss_flat[miss_pattern] = np.nan 
    xmiss = xmiss_flat.reshape([n,p]) # in xmiss, the missing values are represented by nans
    mask = np.isfinite(xmiss) # binary mask that indicates which values are missing
    Clients_missing.append(xmiss)

import os 
os.makedirs('clients_data', exist_ok=True) 
for i in range(len(Clients_missing)):
    pd.DataFrame(Clients_missing[i]).to_csv('clients_data/client_'+str(i+1)+'.csv',index=False)

## Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

## Setting the nodes up
It is necessary to previously configure a node:
1. `./scripts/fedbiomed_run node add`
  * Select option 1 (csv) to add client_1 dataset to the first node
  * Provide the correct tag by entering:  breast_cancer
  * Pick the folder where client_1 dataset has been saved
  * Data must have been added (if you get a warning saying that data must be unique is because it's been already added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node start`. Wait until you get `Starting task manager`. it means you are online.
4. Following the same procedure, you can create additional nodes for clients 2 and 3.

Check available clients:

In [4]:
from fedbiomed.researcher.requests import Requests
req = Requests()
req.list(verbose=True)
xx = req.list()
dataset_size = [xx[i][0]['shape'][1] for i in xx]
assert min(dataset_size)==max(dataset_size)
data_size = dataset_size[0]

2022-05-20 10:03:32,228 fedbiomed INFO - Component environment:
2022-05-20 10:03:32,229 fedbiomed INFO - type = ComponentType.RESEARCHER
2022-05-20 10:03:32,270 fedbiomed INFO - Messaging researcher_46feb887-a218-4d51-adea-4c996b331186 successfully connected to the message broker, object = <fedbiomed.common.messaging.Messaging object at 0x13850f190>
2022-05-20 10:03:32,371 fedbiomed INFO - Listing available datasets in all nodes... 
2022-05-20 10:03:42,389 fedbiomed INFO - 
 Node: node_249fa875-5b53-49c9-aebd-c53009d774a5 | Number of Datasets: 1 
+---------------+-------------+-------------------+---------------+-----------+
| name          | data_type   | tags              | description   | shape     |
| breast_cancer | csv         | ['breast_cancer'] | breast_cancer | [134, 13] |
+---------------+-------------+-------------------+---------------+-----------+

2022-05-20 10:03:42,391 fedbiomed INFO - 
 Node: node_e27fee9e-415f-4927-8f52-41ac937784cf | Number of Datasets: 1 
+---------

## Define an experiment model and parameters

Declare a torch.nn MIWAETrainingPlan class to send for training on the node

Note : write **only** the code to export in the following cell

In [5]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import numpy as np
import torch.distributions as td
import pandas as pd

from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes

# Here we define the model to be used. 
# You can use any class name (here 'Net')
class MIWAETrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args: dict = {}):
        super(MIWAETrainingPlan, self).__init__(model_args)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["from torchvision import datasets, transforms",
               "import torch.distributions as td",
               "import pandas as pd",
               "import numpy as np"]
        
        self.n_features=model_args['n_features']
        self.n_latent=model_args['n_latent']
        self.n_hidden=model_args['n_hidden']
        self.n_samples=model_args['n_samples']
        
        self.add_dependency(deps)
        
        # the encoder will output both the mean and the diagonal covariance
        self.encoder=nn.Sequential(
                        torch.nn.Linear(self.n_features, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 2*self.n_latent),  
                        )
        # the decoder will output both the mean, the scale, 
        # and the number of degrees of freedoms (hence the 3*p)
        self.decoder = nn.Sequential(
                        torch.nn.Linear(self.n_latent, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 3*self.n_features),  
                        )
        
        self.optimizer = torch.optim.Adam(list(self.encoder.parameters()) \
                                    + list(self.decoder.parameters()),lr=1e-3)
              
        self.encoder.apply(self.weights_init)
        self.decoder.apply(self.weights_init)
    
    def weights_init(self,layer):
        if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
    
    def miwae_loss(self,iota_x,mask):
        batch_size = iota_x.shape[0]
        out_encoder = self.encoder(iota_x)
        # prior
        p_z = td.Independent(td.Normal(loc=torch.zeros(self.n_latent).to(self.device)\
                                       ,scale=torch.ones(self.n_latent).to(self.device)),1)
        
        q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :self.n_latent],\
                                                scale=torch.nn.Softplus()\
                                                (out_encoder[..., self.n_latent:\
                                                             (2*self.n_latent)])),1)

        zgivenx = q_zgivenxobs.rsample([self.n_samples])
        zgivenx_flat = zgivenx.reshape([self.n_samples*batch_size,self.n_latent])

        out_decoder = self.decoder(zgivenx_flat)
        all_means_obs_model = out_decoder[..., :self.n_features]
        all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., self.n_features:\
                                                               (2*self.n_features)]) + 0.001
        all_degfreedom_obs_model = torch.nn.Softplus()\
        (out_decoder[..., (2*self.n_features):(3*self.n_features)]) + 3

        data_flat = torch.Tensor.repeat(iota_x,[self.n_samples,1]).reshape([-1,1])
        tiledmask = torch.Tensor.repeat(mask,[self.n_samples,1])

        all_log_pxgivenz_flat = torch.distributions.StudentT\
        (loc=all_means_obs_model.reshape([-1,1]),\
         scale=all_scales_obs_model.reshape([-1,1]),\
         df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
        all_log_pxgivenz = all_log_pxgivenz_flat.reshape([self.n_samples*batch_size,self.n_features])

        logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([self.n_samples,batch_size])
        logpz = p_z.log_prob(zgivenx)
        logq = q_zgivenxobs.log_prob(zgivenx)

        neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,0))

        return neg_bound

    def training_data(self,  batch_size = 48):
        
        df = pd.read_csv(self.dataset_path, sep=',', index_col=False)
        x_train = df.values
        x_mask = np.isfinite(x_train)
        # xhat_0: missing values are replaced by zeros. 
        #This x_hat0 is what will be fed to our encoder.
        xhat_0 = np.copy(x_train)
        xhat_0[np.isnan(x_train)] = 0
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        
        data_manager = DataManager(dataset=xhat_0 , target=x_mask , **train_kwargs)
        
        return data_manager
    
    def training_step(self, data, mask):
        self.encoder.zero_grad()
        self.decoder.zero_grad()
        loss = self.miwae_loss(iota_x = data,mask = mask)
        return loss

This group of arguments correspond respectively:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side. 
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.
* data `tags` to search nodes for training.
* total number of `rounds`.
If FedProx optimisation is requested, `fedprox_mu` parameter must be defined here. It also must be a float between XX and YY.

**NOTE:** typos and/or lack of positional (required) arguments will raise error. 🤓

In [6]:
h = 128 # number of hidden units in (same for all MLPs)
d = 10 # dimension of the latent space, we choose d=1 for visualisation purposes
K = 20 # number of IS during training

n_epochs=5

model_args = {'n_features':data_size, 'n_latent':d,'n_hidden':h,'n_samples':K}

training_args = {
    'batch_size': 48, 
    'lr': 1e-3, 
    'log_interval' : 1,
    'epochs': n_epochs, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

tags =  ['breast_cancer']
rounds = 30

## Declare and run the experiment

- search nodes serving data for these `tags`, optionally filter on a list of node ID with `nodes`
- run a round of local training on nodes with model defined in `model_path` + federation with `aggregator`
- run for `round_limit` rounds, applying the `node_selection_strategy` between the rounds

In [7]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

exp = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

2022-05-20 10:03:55,794 fedbiomed INFO - Searching dataset with data tags: ['breast_cancer'] for all nodes
2022-05-20 10:04:05,812 fedbiomed INFO - Node selected for training -> node_249fa875-5b53-49c9-aebd-c53009d774a5
2022-05-20 10:04:05,813 fedbiomed INFO - Node selected for training -> node_e27fee9e-415f-4927-8f52-41ac937784cf
2022-05-20 10:04:05,814 fedbiomed INFO - Node selected for training -> node_efa8cc39-825f-438a-940c-fce941364667
2022-05-20 10:04:05,817 fedbiomed INFO - Checking data quality of federated datasets...
2022-05-20 10:04:05,976 fedbiomed DEBUG - Model file has been saved: /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0000/my_model_d017b264-e90c-4c8b-a879-a2a02e3157c8.py
2022-05-20 10:04:06,295 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0000/my_model_d017b264-e90c-4c8b-a879-a2a02e3157c8.py successful, with status code 201
2022-05-20 

Let's start the experiment.

By default, this function doesn't stop until all the `round_limit` rounds are done for all the nodes

In [None]:
exp.run()

Local training results for each round and each node are available via `exp.training_replies()` (index 0 to (`rounds` - 1) ).

For example you can view the training results for the last round below.

Different timings (in seconds) are reported for each dataset of a node participating in a round :
- `rtime_training` real time (clock time) spent in the training function on the node
- `ptime_training` process time (user and system CPU) spent in the training function on the node
- `rtime_total` real time (clock time) spent in the researcher between sending the request and handling the response, at the `Job()` layer

In [9]:
print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
round_data = exp.training_replies()[rounds - 1].data()
for c in range(len(round_data)):
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = round_data[c]['node_id'],
        rtraining = round_data[c]['timing']['rtime_training'],
        ptraining = round_data[c]['timing']['ptime_training'],
        rtotal = round_data[c]['timing']['rtime_total']))
print('\n')
    
exp.training_replies()[rounds - 1].dataframe()


List the training rounds :  dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])

List the nodes for the last training round and their timings : 
	- node_249fa875-5b53-49c9-aebd-c53009d774a5 :    
		rtime_training=0.67 seconds    
		ptime_training=0.33 seconds    
		rtime_total=10.03 seconds
	- node_efa8cc39-825f-438a-940c-fce941364667 :    
		rtime_training=0.70 seconds    
		ptime_training=0.35 seconds    
		rtime_total=10.07 seconds
	- node_e27fee9e-415f-4927-8f52-41ac937784cf :    
		rtime_training=0.70 seconds    
		ptime_training=0.35 seconds    
		rtime_total=10.12 seconds




Unnamed: 0,success,msg,dataset_id,node_id,params_path,params,timing
0,True,,dataset_fca98937-6cb8-4587-bc69-7f0fdad09abd,node_249fa875-5b53-49c9-aebd-c53009d774a5,/Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed...,"{'encoder.0.weight': [[tensor(-0.1305), tensor...","{'rtime_training': 0.6657301810000149, 'ptime_..."
1,True,,dataset_b231ccd0-b513-4a32-b4d6-af5412e8c679,node_efa8cc39-825f-438a-940c-fce941364667,/Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed...,"{'encoder.0.weight': [[tensor(-0.1253), tensor...","{'rtime_training': 0.6970942529999888, 'ptime_..."
2,True,,dataset_60c3d82d-8590-4d12-815d-ba11aae8af13,node_e27fee9e-415f-4927-8f52-41ac937784cf,/Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed...,"{'encoder.0.weight': [[tensor(-0.1385), tensor...","{'rtime_training': 0.7043897529999867, 'ptime_..."


Federated parameters for each round are available via `exp.aggregated_params()` (index 0 to (`rounds` - 1) ).

For example you can view the federated parameters for the last round of the experiment :

In [10]:
print("\nList the training rounds : ", exp.aggregated_params().keys())

print("\nAccess the federated params for the last training round :")
print("\t- params_path: ", exp.aggregated_params()[rounds - 1]['params_path'])
print("\t- parameter data: ", exp.aggregated_params()[rounds - 1]['params'].keys())


List the training rounds :  dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])

Access the federated params for the last training round :
	- params_path:  /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0000/aggregated_params_6e1939d0-890e-430c-b695-945de713f0a4.pt
	- parameter data:  odict_keys(['encoder.0.weight', 'encoder.0.bias', 'encoder.2.weight', 'encoder.2.bias', 'encoder.4.weight', 'encoder.4.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.2.weight', 'decoder.2.bias', 'decoder.4.weight', 'decoder.4.bias'])


## Run the experiment with FedProx

We repeat the federated training but using FedProx as aggregation scheme (starting from the second iteration).

In [11]:
# To make the method fairly comparable with FedCos, during the first round
# we will simply use FedAvg with standard optimization scheme: the FedProx penalization
# term will be introduced exclusively from the second round.
# training_args.update(fedprox_mu = 0.)

exp_fedprox = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

2022-05-20 10:09:53,302 fedbiomed INFO - Searching dataset with data tags: ['breast_cancer'] for all nodes
2022-05-20 10:10:03,324 fedbiomed INFO - Node selected for training -> node_249fa875-5b53-49c9-aebd-c53009d774a5
2022-05-20 10:10:03,325 fedbiomed INFO - Node selected for training -> node_efa8cc39-825f-438a-940c-fce941364667
2022-05-20 10:10:03,326 fedbiomed INFO - Node selected for training -> node_e27fee9e-415f-4927-8f52-41ac937784cf
2022-05-20 10:10:03,328 fedbiomed INFO - Checking data quality of federated datasets...
2022-05-20 10:10:03,356 fedbiomed DEBUG - Model file has been saved: /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0001/my_model_803a2080-cce2-415f-a88b-82b003c0643c.py
2022-05-20 10:10:03,448 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0001/my_model_803a2080-cce2-415f-a88b-82b003c0643c.py successful, with status code 201
2022-05-20 

In [None]:
exp_fedprox.run_once()

In [None]:
# Starting from the second round, FedProx is used with mu=0.1
# We first update the training args
training_args.update(fedprox_mu = 0.1)
# Then update training args in the experiment
exp_fedprox.set_training_args(training_args)
exp_fedprox.run()

## Run the experiment with FedCos

And finally we propose to use FedCos as well, which introduce an alternative penalization term with cosine similarity:

In [14]:
from fedbiomed.researcher.aggregators.fedcos import FedCos

# Warning: we can not perform simultaneously FedProx and FedCos, so only one (at most) 
# between 'fedprox_mu' and 'fedcos_mu' should be provided in 'training_args'.
# Therefore we will remove 'fedprox_mu' from 'training_args' before defining 'fedcos_mu'
del training_args['fedprox_mu'] 

training_args.update(fedcos_mu = 0.01)

exp_fedcos = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedCos(),
                 node_selection_strategy=None)

2022-05-20 10:15:15,530 fedbiomed INFO - Searching dataset with data tags: ['breast_cancer'] for all nodes
2022-05-20 10:15:25,552 fedbiomed INFO - Node selected for training -> node_e27fee9e-415f-4927-8f52-41ac937784cf
2022-05-20 10:15:25,554 fedbiomed INFO - Node selected for training -> node_249fa875-5b53-49c9-aebd-c53009d774a5
2022-05-20 10:15:25,554 fedbiomed INFO - Node selected for training -> node_efa8cc39-825f-438a-940c-fce941364667
2022-05-20 10:15:25,560 fedbiomed INFO - Checking data quality of federated datasets...
2022-05-20 10:15:25,589 fedbiomed DEBUG - Model file has been saved: /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0002/my_model_dcd0983c-ca68-4eb5-b395-85361d347271.py
2022-05-20 10:15:25,742 fedbiomed DEBUG - upload (HTTP POST request) of file /Users/balelli/ownCloud/INRIA_EPIONE/FedBioMed/fedbiomed/var/experiments/Experiment_0002/my_model_dcd0983c-ca68-4eb5-b395-85361d347271.py successful, with status code 201
2022-05-20 

In [None]:
exp_fedcos.run()

# Test and comparison to local training

## 1. Testing on an external dataset

First of all we are going to test the performance of the final federated model to impute missing data on a test dataset. To this extent we are going to remove randomly 50% of samples from the test dataset, `data_test`, defined at the beginning of this notebook.

In [16]:
# from the test dataset, we will remove randomly 50% of data
np.random.seed(1234)

perc_miss = 0.5 # 50% of missing data

n = data_test.shape[0] # number of observations
p = data_test.shape[1] # number of features
xfull = np.copy(data_test)
xfull = (xfull - np.mean(xfull,0))/np.std(xfull,0)
xmiss = np.copy(xfull)
xmiss_flat = xmiss.flatten()
miss_pattern = np.random.choice(n*p, np.floor(n*p*perc_miss).astype(np.int_),\
                                replace=False)
xmiss_flat[miss_pattern] = np.nan 
xmiss = xmiss_flat.reshape([n,p]) # in xmiss, the missing values are represented by nans
mask = np.isfinite(xmiss) # binary mask that indicates which values are missing
xhat_0 = np.copy(xmiss)
xhat_0[np.isnan(xmiss)] = 0
xhat = np.copy(xhat_0) # This will be out imputed data matrix

We define the MIWAE imputation routine:

In [17]:
def miwae_impute(encoder,decoder,iota_x,mask,d,L):
    
    p_z = td.Independent(td.Normal(loc=torch.zeros(d),scale=torch.ones(d)),1)
    
    batch_size = iota_x.shape[0]
    out_encoder = encoder(iota_x)
    q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :d],scale=torch.nn.Softplus()(out_encoder[..., d:(2*d)])),1)

    zgivenx = q_zgivenxobs.rsample([L])
    zgivenx_flat = zgivenx.reshape([L*batch_size,d])

    out_decoder = decoder(zgivenx_flat)
    all_means_obs_model = out_decoder[..., :p]
    all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., p:(2*p)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2*p):(3*p)]) + 3

    data_flat = torch.Tensor.repeat(iota_x,[L,1]).reshape([-1,1])
    tiledmask = torch.Tensor.repeat(mask,[L,1])

    all_log_pxgivenz_flat = torch.distributions.StudentT(loc=all_means_obs_model.reshape([-1,1]),scale=all_scales_obs_model.reshape([-1,1]),df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([L*batch_size,p])

    logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([L,batch_size])
    logpz = p_z.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)

    xgivenz = td.Independent(td.StudentT(loc=all_means_obs_model, scale=all_scales_obs_model, df=all_degfreedom_obs_model),1)

    imp_weights = torch.nn.functional.softmax(logpxobsgivenz + logpz - logq,0) # these are w_1,....,w_L for all observations in the batch
    xms = xgivenz.mean.reshape([L,batch_size,p])  # that's the only line that changed!
    xm=torch.einsum('ki,kij->ij', imp_weights, xms) 

    return xm

As well as the MSE function:

In [18]:
def mse(xhat,xtrue,mask): # MSE function for imputations
    xhat = np.array(xhat)
    xtrue = np.array(xtrue)
    return np.mean(np.power(xhat-xtrue,2)[~mask])

We instantiate the model using last updated federated parameters:

In [19]:
# extract federated model into PyTorch framework
model = exp.model_instance()
model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])

encoder = model.encoder
decoder = model.decoder

Same for the models trained with FedProx and FedCos

In [20]:
model_fedprox = exp_fedprox.model_instance()
model_fedprox.load_state_dict(exp_fedprox.aggregated_params()[rounds - 1]['params'])

encoder_fedprox = model_fedprox.encoder
decoder_fedprox = model_fedprox.decoder

model_fedcos = exp_fedcos.model_instance()
# We should remove the 'disp_global' key from the aggregated paramters since this
# is not needed to instantiate the model (just needed for training)
aggregated_params_fedcos = exp_fedcos.aggregated_params()[rounds - 1]['params']
del aggregated_params_fedcos['disp_global']
model_fedcos.load_state_dict(aggregated_params_fedcos)

encoder_fedcos = model_fedcos.encoder
decoder_fedcos = model_fedcos.decoder

And we finally do the imputation and evaluate the corresponding imputation error through MSE for each federated model:

In [21]:
L = 100

xhat[~mask] = miwae_impute(encoder = encoder,decoder = decoder,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model on testing data %g' %err_test_data)
print('-----')

xhat[~mask] = miwae_impute(encoder = encoder_fedprox,decoder = decoder_fedprox,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data_fedprox = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model (with fedprox) on testing data  %g' %err_test_data_fedprox)
print('-----')

xhat[~mask] = miwae_impute(encoder = encoder_fedcos,decoder = decoder_fedcos,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data_fedcos = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model (with fedcos) on testing data  %g' %err_test_data_fedcos)
print('-----')

Imputation MSE of fed model on testing data 0.584008
-----
Imputation MSE of fed model (with fedprox) on testing data  0.56444
-----
Imputation MSE of fed model (with fedcos) on testing data  0.581398
-----


## 2. Testing on a client's dataset

We are now going to use the final federated model to impute missing data of client 1, which have been used for training:

In [22]:
# We first recover data (full and with missing entries) from client 1
data_client_1 = np.copy(Clients_data[0])
xfull_cl1 = np.copy(data_client_1)
xfull_cl1 = (xfull_cl1 - np.mean(xfull_cl1,0))/np.std(xfull_cl1,0)
xmiss_cl1 = np.copy(Clients_missing[0])

mask_cl1 = np.isfinite(xmiss_cl1) # binary mask that indicates which values are missing
xhat_0_cl1 = np.copy(xmiss_cl1)
xhat_0_cl1[np.isnan(xmiss_cl1)] = 0
xhat_cl1 = np.copy(xhat_0_cl1) # This will be out imputed data matrix

### Now we do the imputation

xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder,decoder = decoder, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model on data from client 1  %g' %err_cl1_data)
print('-----')

xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder_fedprox,decoder = decoder_fedprox, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data_fedprox = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model (with fedprox) on data from client 1  %g' %err_cl1_data_fedprox)
print('-----')

xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder_fedcos,decoder = decoder_fedcos, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data_fedcos = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model (with fedcos) on data from client 1  %g' %err_cl1_data_fedcos)
print('-----')

Imputation MSE of fed model on data from client 1  0.596102
-----
Imputation MSE of fed model (with fedprox) on data from client 1  0.602942
-----
Imputation MSE of fed model (with fedcos) on data from client 1  0.591246
-----


## 3. Local training and testing on a client

Finally, we test the performance of the same model trained locally and tested on the dataset from client 1. We will use a total of `epochs`x`rounds` local epochs.

In [23]:
def miwae_loss(iota_x,mask,d):
    
    p_z = td.Independent(td.Normal(loc=torch.zeros(d),scale=torch.ones(d)),1)
    
    batch_size = iota_x.shape[0]
    out_encoder = encoder(iota_x)
    q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :d],scale=torch.nn.Softplus()(out_encoder[..., d:(2*d)])),1)

    zgivenx = q_zgivenxobs.rsample([K])
    zgivenx_flat = zgivenx.reshape([K*batch_size,d])

    out_decoder = decoder(zgivenx_flat)
    all_means_obs_model = out_decoder[..., :p]
    all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., p:(2*p)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2*p):(3*p)]) + 3

    data_flat = torch.Tensor.repeat(iota_x,[K,1]).reshape([-1,1])
    tiledmask = torch.Tensor.repeat(mask,[K,1])

    all_log_pxgivenz_flat = torch.distributions.StudentT(loc=all_means_obs_model.reshape([-1,1]),scale=all_scales_obs_model.reshape([-1,1]),df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([K*batch_size,p])

    logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([K,batch_size])
    logpz = p_z.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)

    neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,0))

    return neg_bound

We perform the local training:

In [24]:
# We recover again data (full and with missing entries) from client 1
data_client_1 = np.copy(Clients_data[0])
xfull_cl1 = np.copy(data_client_1)
xfull_cl1 = (xfull_cl1 - np.mean(xfull_cl1,0))/np.std(xfull_cl1,0)
xmiss_cl1 = np.copy(Clients_missing[0])

mask_cl1 = np.isfinite(xmiss_cl1) # binary mask that indicates which values are missing
xhat_0_cl1 = np.copy(xmiss_cl1)
xhat_0_cl1[np.isnan(xmiss_cl1)] = 0
xhat_cl1 = np.copy(xhat_0_cl1) # This will be out imputed data matrix

n_epochs_local = n_epochs*rounds
bs = 48 # batch size

encoder_cl1 = nn.Sequential(
    torch.nn.Linear(p, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 2*d),  # the encoder will output both the mean and the diagonal covariance
)

decoder_cl1 = nn.Sequential(
    torch.nn.Linear(d, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 3*p),  # the decoder will output both the mean, the scale, and the number of degrees of freedoms (hence the 3*p)
)

optimizer_cl1 = torch.optim.Adam(list(encoder_cl1.parameters()) + list(decoder_cl1.parameters()),lr=1e-3)

def weights_init(layer):
    if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
        
encoder_cl1.apply(weights_init)
decoder_cl1.apply(weights_init)

for ep in range(1,n_epochs_local):
    perm = np.random.permutation(n) # We use the "random reshuffling" version of SGD
    batches_data = np.array_split(xhat_0_cl1[perm,], n/bs)
    batches_mask = np.array_split(mask_cl1[perm,], n/bs)
    for it in range(len(batches_data)):
        optimizer_cl1.zero_grad()
        encoder_cl1.zero_grad()
        decoder_cl1.zero_grad()
        b_data = torch.from_numpy(batches_data[it]).float()
        b_mask = torch.from_numpy(batches_mask[it]).float()
        loss = miwae_loss(iota_x = b_data,mask = b_mask, d = d)
        loss.backward()
        optimizer_cl1.step()
    if ep % rounds == 1:
        print('Epoch %g' %ep)
        print('MIWAE likelihood bound  %g' %(-np.log(K)-miwae_loss(iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(), d = d).cpu().data.numpy())) # Gradient step      

Epoch 1
MIWAE likelihood bound  -2.8003
Epoch 31
MIWAE likelihood bound  -2.86463
Epoch 61
MIWAE likelihood bound  -2.7242
Epoch 91
MIWAE likelihood bound  -2.89334
Epoch 121
MIWAE likelihood bound  -2.81006


And we do the imputation on the same dataset:

In [25]:
xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder_cl1, decoder = decoder_cl1, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_local_cl1_data = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of local model on data from same client (cl 1)  %g' %err_local_cl1_data)
print('-----')

Imputation MSE of local model on data from same client (cl 1)  1.04139
-----


As well as the imputation on the external test dataset:

In [26]:
xhat[~mask] = miwae_impute(encoder = encoder_cl1,decoder = decoder_cl1,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_local_cl1_test_data = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of local model on testing data %g' %err_local_cl1_test_data)
print('-----')

Imputation MSE of local model on testing data 1.03618
-----


## Comparison of obtained results:

In [27]:
from tabulate import tabulate

print('Imputation MSE on testing data')
print('-----')
data = [['FedAvg', err_test_data],
['FedProx', err_test_data_fedprox],
['FedCos', err_test_data_fedcos],
['Local (cl1)', err_local_cl1_test_data]]
print (tabulate(data, headers=["Model", "Mean Squared Error (\u2193)"]))
print('-----')
print('-----')
print('Imputation MSE on local data from client 1')
print('-----')
data = [['FedAvg', err_cl1_data],
['FedProx', err_cl1_data_fedprox],
['FedCos', err_cl1_data_fedcos],
['Local (cl1)', err_local_cl1_data]]
print (tabulate(data, headers=["Model", "Mean Squared Error (\u2193)"]))

Imputation MSE on testing data
-----
Model          Mean Squared Error (↓)
-----------  ------------------------
FedAvg                       0.584008
FedProx                      0.56444
FedCos                       0.581398
Local (cl1)                  1.03618
-----
-----
Imputation MSE on local data from client 1
-----
Model          Mean Squared Error (↓)
-----------  ------------------------
FedAvg                       0.596102
FedProx                      0.602942
FedCos                       0.591246
Local (cl1)                  1.04139


2022-05-20 10:23:18,146 fedbiomed INFO - [1mCRITICAL[0m
					[1m NODE[0m node_efa8cc39-825f-438a-940c-fce941364667
					[1m MESSAGE:[0m Node stopped in signal_handler, probably by user decision (Ctrl C)[0m
-----------------------------------------------------------------
2022-05-20 10:23:19,730 fedbiomed INFO - [1mCRITICAL[0m
					[1m NODE[0m node_e27fee9e-415f-4927-8f52-41ac937784cf
					[1m MESSAGE:[0m Node stopped in signal_handler, probably by user decision (Ctrl C)[0m
-----------------------------------------------------------------
2022-05-20 10:23:20,980 fedbiomed INFO - [1mCRITICAL[0m
					[1m NODE[0m node_249fa875-5b53-49c9-aebd-c53009d774a5
					[1m MESSAGE:[0m Node stopped in signal_handler, probably by user decision (Ctrl C)[0m
-----------------------------------------------------------------


As you can see, the federated model performs better than the local one!