# Missing data imputation with Fed-BioMed using MIWAE

In this notebook we show how to impute missing-at-random (MAR) data in a federated setting using MIWAE (https://arxiv.org/abs/2006.12871). We compare the results of federated training with FedAvg and FedProx against those of local training.

In [1]:
%load_ext autoreload
%autoreload 2

## Prepare the data

For this experiment we will use the data extracted from ADNI.

In [2]:
import pandas as pd
import numpy as np
import os
from copy import deepcopy

cwd = os.getcwd()
data_file = os.path.join(cwd, "data/ADNI/data_irene_dx_2g_corrected.csv")
raw_df = pd.read_csv(data_file, sep=",",index_col="RID")
target = raw_df["DX"]
data = deepcopy(raw_df)
del data["DX"]
print(data.shape)

(311, 130)


In [3]:
from sklearn.model_selection import train_test_split

#train test split
data_train, data_test, labels_train, labels_test = train_test_split(data, target, test_size=0.20, random_state=42)

# split train across datasets in a non-iid manner: 
#client 1 will only contain AD subjects, client 2 will only contain controls
client_1 = data_train.loc[labels_train == 'AD']
client_2 = data_train.loc[labels_train == 'NL']

#Clients_data contains original full data of each client 
Clients_data=[client_1, client_2]

In [4]:
# from each dataset we will remove randomly 30% of data
np.random.seed(1234)

# 30% of missing data for client 1, 30% for client 2
perc_miss_list = [0.3,0.3] 

#Clients_missing contains data of each client with missing entries wrt perc_miss_list
Clients_missing = []
for perc,c in enumerate(Clients_data):
    perc_miss=perc_miss_list[perc]
    n = c.shape[0] # number of observations
    p = c.shape[1] # number of features
    xmiss = np.copy(c)
    xmiss_flat = xmiss.flatten()
    miss_pattern = np.random.choice(n*p, np.floor(n*p*perc_miss).astype(np.int_),\
                                    replace=False)
    xmiss_flat[miss_pattern] = np.nan 
    xmiss = xmiss_flat.reshape([n,p]) # in xmiss, the missing values are represented by nans
    mask = np.isfinite(xmiss) # binary mask: True where a value is observed, False where it is missing
    Clients_missing.append(xmiss)
    
# Finally, we save the clients' data to be provided to the nodes (os is already imported above)
os.makedirs('data/clients_data', exist_ok=True) 
for i in range(len(Clients_missing)):
    pd.DataFrame(Clients_missing[i]).to_csv('data/clients_data/client_'+str(i+1)+'.csv',index=False)
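
The masking step above can be isolated in a small helper. The following is a minimal sketch, not part of the original notebook: the function name `mask_entries` and the toy matrix are hypothetical, and it assumes entries are removed uniformly at random, as in the cell above.

```python
import numpy as np

def mask_entries(x, frac_missing, rng):
    """Return a float copy of `x` with a fraction of its entries set to NaN."""
    x_miss = np.array(x, dtype=float)   # copy; float dtype so NaN is representable
    n_entries = x_miss.size
    n_missing = int(np.floor(n_entries * frac_missing))
    # Draw distinct flat indices uniformly at random
    idx = rng.choice(n_entries, size=n_missing, replace=False)
    x_miss.flat[idx] = np.nan
    return x_miss

rng = np.random.default_rng(1234)
x = np.arange(20, dtype=float).reshape(4, 5)   # toy data (hypothetical)
x_miss = mask_entries(x, 0.25, rng)
observed = np.isfinite(x_miss)                 # True where a value is observed
```

The same helper could be reused for the test set later in the notebook by changing only the fraction of missing entries.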

## Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

## Setting the nodes up
Before running the experiment, configure a node:
1. `./scripts/fedbiomed_run node add`
  * Select option 1 (csv) to add the client_1 dataset to the first node
  * Provide the correct tag by entering: `adni`
  * Pick the folder where the client_1 dataset has been saved
  * The data should now be added (a warning saying that data must be unique means it has already been added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node start`. Wait until you get `Starting task manager`, which means the node is online.
4. Following the same procedure, create an additional node for client 2.

Check available clients:

In [None]:
from fedbiomed.researcher.requests import Requests
req = Requests()
req.list(verbose=True)
xx = req.list()
dataset_size = [xx[i][0]['shape'][1] for i in xx]
assert min(dataset_size)==max(dataset_size)
data_size = dataset_size[0]

## Recover global mean and std

We first run a model whose only objective is to recover the federated mean and std, which we then use to standardize each client's dataset with respect to the global dataset. This model must be run for exactly 1 round of 1 epoch.

In [6]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import pandas as pd
from copy import deepcopy

from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes

# Here we define the training plan to be used. 
# You can use any class name (here 'FedMeanStdTrainingPlan')
class FedMeanStdTrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args: dict = {}):
        super(FedMeanStdTrainingPlan, self).__init__(model_args)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["import pandas as pd",
               "import numpy as np",
               "from copy import deepcopy"]
        
        self.add_dependency(deps)
        
        self.n_features=model_args['n_features']
        
        # Here we define the parameters of our custom model: we want to recover the mean and std of each client;
        # a specifically designed aggregator then evaluates the global mean and std
        self.mean = nn.Parameter(torch.zeros(self.n_features),requires_grad=False)
        self.std = nn.Parameter(torch.zeros(self.n_features),requires_grad=False)
        self.size = nn.Parameter(torch.zeros(self.n_features),requires_grad=False)
        # Fake parameter with requires_grad=True to be passed to the backward (included in TorchTrainingPlan)
        self.fake = nn.Parameter(torch.randn(1),requires_grad=True)
        
    def training_data(self):
        
        df = pd.read_csv(self.dataset_path, sep=',', index_col=False)
        
        ### NOTE: batch_size should be == dataset size ###
        batch_size = df.shape[0]
        x_train = df.values
        x_mask = np.isfinite(x_train)
        xhat_0 = np.copy(x_train)
        ### NOTE: for standardization purposes, we keep nan when data is missing
        #xhat_0[np.isnan(x_train)] = 0
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        
        data_manager = DataManager(dataset=xhat_0 , target=x_mask , **train_kwargs)
        
        return data_manager
    
    def training_step(self, data, mask):
        data_np = data.numpy()
        self.size += torch.Tensor([data_np[:,dim].size - np.count_nonzero(np.isnan(data_np[:,dim]))\
                                   for dim in range(self.n_features)])
        self.mean += torch.from_numpy(np.nanmean(data_np,0))
        self.std += torch.from_numpy(np.nanstd(data_np,0))
        return self.fake

In [7]:
model_args = {'n_features':data_size}

training_args = {
    'batch_size': 48, 
    'lr': 1e-3, 
    'log_interval' : 1,
    'epochs': 1, 
    'dry_run': False
}

tags =  ['adni']

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedstandard import FedStandard

fed_mean_std = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=FedMeanStdTrainingPlan,
                 training_args=training_args,
                 round_limit=1,
                 aggregator=FedStandard(),
                 node_selection_strategy=None)

In [None]:
fed_mean_std.run()

In [10]:
# We save the federated mean and std:
fed_mean = fed_mean_std.aggregated_params()[0]['params']['fed_mean']
fed_std = fed_mean_std.aggregated_params()[0]['params']['fed_std']
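
How `FedStandard` combines the per-client statistics is not shown in this notebook. As a hedged illustration only, the sketch below shows one standard way to pool per-feature counts, means and (population) stds into a global mean and std. The function name `pool_mean_std` and the synthetic clients are assumptions made for the example, not the library's implementation.

```python
import numpy as np

def pool_mean_std(counts, means, stds):
    """Pool per-client, per-feature statistics; every input has shape (n_clients, n_features)."""
    counts = np.asarray(counts, dtype=float)
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    total = counts.sum(axis=0)
    # Weighted mean of the client means
    pooled_mean = (counts * means).sum(axis=0) / total
    # E[x^2] pooled over clients, then Var = E[x^2] - E[x]^2
    second_moment = (counts * (stds ** 2 + means ** 2)).sum(axis=0) / total
    pooled_std = np.sqrt(second_moment - pooled_mean ** 2)
    return pooled_mean, pooled_std

# Two synthetic clients with a few missing entries (hypothetical data)
rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(30, 3))
b = rng.normal(2.0, 0.5, size=(50, 3))
a[0, 1] = np.nan
b[3, 2] = np.nan

counts = [np.isfinite(c).sum(axis=0) for c in (a, b)]
means = [np.nanmean(c, axis=0) for c in (a, b)]
stds = [np.nanstd(c, axis=0) for c in (a, b)]
fed_mean_sketch, fed_std_sketch = pool_mean_std(counts, means, stds)

# Reference values computed on the concatenated data
full = np.vstack([a, b])
```

With these formulas the pooled statistics match `np.nanmean` and `np.nanstd` computed on the concatenated data, without any client sharing its raw observations.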

## Define the MIWAE experiment model and parameters

Declare a torch.nn MIWAETrainingPlan class to send for training on the node.

Note: we include a function, ``standardize_data``, which standardizes the data either with respect to a mean and std provided by the user, or locally, using only each client's own data.

In [11]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import numpy as np
import torch.distributions as td
import pandas as pd

from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes

# Here we define the training plan to be used. 
# You can use any class name (here 'MIWAETrainingPlan')
class MIWAETrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args: dict = {}):
        super(MIWAETrainingPlan, self).__init__(model_args)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["from torchvision import datasets, transforms",
               "import torch.distributions as td",
               "import pandas as pd",
               "import numpy as np"]
        
        self.n_features=model_args['n_features']
        self.n_latent=model_args['n_latent']
        self.n_hidden=model_args['n_hidden']
        self.n_samples=model_args['n_samples']
        
        self.add_dependency(deps)
        
        # Standardization is optional; set defaults so the attributes always exist
        self.standardization = 'standardization' in model_args
        self.fed_mean = None
        self.fed_std = None
        if self.standardization:
            std_args = model_args['standardization']
            # Use the federated statistics when both are provided, local ones otherwise
            if ('fed_mean' in std_args) and ('fed_std' in std_args):
                self.fed_mean = np.array(std_args['fed_mean'])
                self.fed_std = np.array(std_args['fed_std'])
        
        # the encoder will output both the mean and the diagonal covariance
        self.encoder=nn.Sequential(
                        torch.nn.Linear(self.n_features, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 2*self.n_latent),  
                        )
        # the decoder will output the mean, the scale, 
        # and the number of degrees of freedom (hence the 3*n_features)
        self.decoder = nn.Sequential(
                        torch.nn.Linear(self.n_latent, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 3*self.n_features),  
                        )
        
        self.optimizer = torch.optim.Adam(list(self.encoder.parameters()) \
                                    + list(self.decoder.parameters()),lr=1e-3)
              
        self.encoder.apply(self.weights_init)
        self.decoder.apply(self.weights_init)
        
        # prior
        self.p_z = td.Independent(td.Normal(loc=torch.zeros(self.n_latent).to(self.device)\
                                       ,scale=torch.ones(self.n_latent).to(self.device)),1)
    
    def weights_init(self,layer):
        if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
    
    def miwae_loss(self,iota_x,mask):
        batch_size = iota_x.shape[0]
        out_encoder = self.encoder(iota_x)
        
        q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :self.n_latent],\
                                                scale=torch.nn.Softplus()\
                                                (out_encoder[..., self.n_latent:\
                                                             (2*self.n_latent)])),1)

        zgivenx = q_zgivenxobs.rsample([self.n_samples])
        zgivenx_flat = zgivenx.reshape([self.n_samples*batch_size,self.n_latent])

        out_decoder = self.decoder(zgivenx_flat)
        all_means_obs_model = out_decoder[..., :self.n_features]
        all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., self.n_features:\
                                                               (2*self.n_features)]) + 0.001
        all_degfreedom_obs_model = torch.nn.Softplus()\
        (out_decoder[..., (2*self.n_features):(3*self.n_features)]) + 3

        data_flat = torch.Tensor.repeat(iota_x,[self.n_samples,1]).reshape([-1,1])
        tiledmask = torch.Tensor.repeat(mask,[self.n_samples,1])

        all_log_pxgivenz_flat = torch.distributions.StudentT\
        (loc=all_means_obs_model.reshape([-1,1]),\
         scale=all_scales_obs_model.reshape([-1,1]),\
         df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
        all_log_pxgivenz = all_log_pxgivenz_flat.reshape([self.n_samples*batch_size,self.n_features])

        logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([self.n_samples,batch_size])
        logpz = self.p_z.log_prob(zgivenx)
        logq = q_zgivenxobs.log_prob(zgivenx)

        neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,0))

        return neg_bound

    def training_data(self,  batch_size = 48):
        
        df = pd.read_csv(self.dataset_path, sep=',', index_col=False)
        x_train = df.values
        x_mask = np.isfinite(x_train)
        # xhat_0: missing values are replaced by zeros. 
        #This x_hat0 is what will be fed to our encoder.
        xhat_0 = np.copy(x_train)
        
        # Data standardization
        if self.standardization:
            xhat_0 = self.standardize_data(xhat_0)
            
        xhat_0[np.isnan(x_train)] = 0
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        
        data_manager = DataManager(dataset=xhat_0 , target=x_mask , **train_kwargs)
        
        return data_manager
    
    def standardize_data(self,data):
        data_norm = np.copy(data)
        if ((self.fed_mean is not None) and (self.fed_std is not None)):
            print('FEDERATED STANDARDIZATION')
            data_norm = (data_norm - self.fed_mean)/self.fed_std
        else:
            print('LOCAL STANDARDIZATION')
            data_norm = (data_norm - np.nanmean(data_norm,0))/np.nanstd(data_norm,0)
        return data_norm
    
    def training_step(self, data, mask):
        self.encoder.zero_grad()
        self.decoder.zero_grad()
        loss = self.miwae_loss(iota_x = data,mask = mask)
        return loss
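
For reference, the quantity optimized by `miwae_loss` can be written out explicitly. The notation below ($\theta$ for the decoder, $\gamma$ for the encoder, $x_i^{o}$ for the observed part of sample $i$) is an assumption chosen for this note and does not appear in the code. For a batch of $n$ samples and $K$ importance samples, the MIWAE bound is

$$
\mathcal{L}_K(\theta,\gamma) = \sum_{i=1}^{n} \mathbb{E}_{z_{i1},\dots,z_{iK} \sim q_\gamma(z \mid x_i^{o})}\left[\log \frac{1}{K}\sum_{k=1}^{K} \frac{p_\theta(x_i^{o} \mid z_{ik})\, p(z_{ik})}{q_\gamma(z_{ik} \mid x_i^{o})}\right].
$$

In the code, `logpxobsgivenz`, `logpz` and `logq` are the three log terms inside the ratio, and the mask restricts $\log p_\theta(x_i^{o} \mid z_{ik})$ to observed features. The returned `neg_bound` is the negative batch average of $\log \sum_{k} \exp(\cdot)$, so it appears to differ from $-\mathcal{L}_K / n$ only by the constant $\log K$, which does not affect the gradients.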

These arguments are, respectively:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side. 
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.
* `tags`: the data tags used to search for nodes to train on.
* `rounds`: the total number of training rounds.

If FedProx optimisation is requested, the `fedprox_mu` parameter must be defined in `training_args`. It must also be a float between XX and YY.

**NOTE:** typos and/or missing positional (required) arguments will raise an error. 🤓

In [12]:
h = 128 # number of hidden units (same for all MLPs)
d = 5 # dimension of the latent space
K = 100 # number of importance samples (IS) during training

n_epochs=10

batch_size = 32

model_args = {'n_features':data_size,
              'n_latent':d,
              'n_hidden':h,
              'n_samples':K,
              'standardization':{'fed_mean':fed_mean.tolist(),'fed_std':fed_std.tolist()}}

training_args = {
    'batch_size': batch_size, 
    'lr': 1e-3, 
    'log_interval' : 1,
    'epochs': n_epochs, 
    'dry_run': False,  
    #'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

tags =  ['adni']
rounds = 150

## Declare and run the experiment

- search nodes serving data for these `tags`, optionally filtering on a list of node IDs with `nodes`
- run a round of local training on nodes with the model defined in `model_class` + federation with `aggregator`
- run for `round_limit` rounds, applying the `node_selection_strategy` between the rounds

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

exp = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

Let's start the experiment.

By default, this function doesn't stop until all the `round_limit` rounds are done for all the nodes.

In [None]:
exp.run()

## Run the experiment with FedProx

We repeat the federated training, this time using FedProx as the optimization scheme (starting from the second round).

In [None]:
# NOTE: during the first round
# we will simply use FedAvg with standard optimization scheme: the FedProx penalization
# term will be introduced exclusively from the second round.
# training_args.update(fedprox_mu = 0.)

exp_fedprox = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

In [None]:
exp_fedprox.run_once()

In [None]:
# Starting from the second round, FedProx is used with mu=0.1
# We first update the training args
training_args.update(fedprox_mu = 0.1)
# Then update training args in the experiment
exp_fedprox.set_training_args(training_args)
exp_fedprox.run()
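
To make the role of `fedprox_mu` concrete, the sketch below computes the proximal term that FedProx adds to each node's local loss, $\frac{\mu}{2}\lVert w - w^{t}\rVert^2$, where $w^{t}$ are the global parameters received at the start of the round. This is an illustration of the general technique, not Fed-BioMed's internal code; the function name `fedprox_penalty` and the toy parameters are hypothetical.

```python
import numpy as np

def fedprox_penalty(local_params, global_params, mu):
    """Proximal term (mu / 2) * ||w - w_global||^2, summed over all parameter arrays."""
    sq_dist = sum(
        np.sum((w - w_g) ** 2) for w, w_g in zip(local_params, global_params)
    )
    return 0.5 * mu * sq_dist

# Toy parameters: one weight matrix and one bias vector (hypothetical values)
global_params = [np.zeros((2, 2)), np.zeros(2)]
local_params = [np.ones((2, 2)), np.array([3.0, 4.0])]

penalty = fedprox_penalty(local_params, global_params, mu=0.1)
no_penalty = fedprox_penalty(local_params, global_params, mu=0.0)
```

With `mu = 0` the penalty vanishes and local training reduces to the standard scheme, which is consistent with running the first round without `fedprox_mu`.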

## Run the experiment with FedProx and local standardization

Finally, we repeat the FedProx experiment, but this time each node standardizes its data locally (with its own mean and std) instead of using the federated statistics:

In [None]:
del training_args['fedprox_mu'] 

model_args.update(standardization = {})

exp_fedprox_std_local = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

In [None]:
exp_fedprox_std_local.run_once()

In [None]:
training_args.update(fedprox_mu = 0.1)
# Then update training args in the experiment
exp_fedprox_std_local.set_training_args(training_args)
exp_fedprox_std_local.run()

# Test and comparison to local training

## 1. Testing on an external dataset

First, we test how well the final federated model imputes missing data on a held-out test dataset. To this end, we randomly remove 50% of the entries of the test dataset, `data_test`, defined at the beginning of this notebook.

In [21]:
# from the test dataset, we will remove randomly 50% of data
np.random.seed(1234)

perc_miss = 0.5 # 50% of missing data

n = data_test.shape[0] # number of observations
p = data_test.shape[1] # number of features

xmiss = np.copy(data_test)
xmiss_flat = xmiss.flatten()
miss_pattern = np.random.choice(n*p, np.floor(n*p*perc_miss).astype(np.int_),\
                                replace=False)
xmiss_flat[miss_pattern] = np.nan 
xmiss = xmiss_flat.reshape([n,p]) # in xmiss, the missing values are represented by nans

# We save the test dataset with missing values before normalization: we will need it later
data_test_missing = np.copy(xmiss)

In [22]:
xfull = np.copy(data_test)
xfull = (xfull - fed_mean.numpy())/fed_std.numpy()
xmiss = np.copy(data_test_missing)
xmiss = (xmiss - fed_mean.numpy())/fed_std.numpy() #xmiss is normalized wrt global mean and std
mask = np.isfinite(xmiss) # binary mask: True where a value is observed, False where it is missing
xhat_0 = np.copy(xmiss)
xhat_0[np.isnan(xmiss)] = 0
xhat = np.copy(xhat_0) # This will be our imputed data matrix

In [23]:
def miwae_impute_single(encoder,decoder,iota_x,mask,d,L):
    
    p_z = td.Independent(td.Normal(loc=torch.zeros(d),scale=torch.ones(d)),1)
    
    batch_size = iota_x.shape[0]
    p = iota_x.shape[1] # number of features, read from the input rather than from a global variable
    out_encoder = encoder(iota_x)
    q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :d],scale=torch.nn.Softplus()(out_encoder[..., d:(2*d)])),1)

    zgivenx = q_zgivenxobs.rsample([L])
    zgivenx_flat = zgivenx.reshape([L*batch_size,d])

    out_decoder = decoder(zgivenx_flat)
    all_means_obs_model = out_decoder[..., :p]
    all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., p:(2*p)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2*p):(3*p)]) + 3

    data_flat = torch.Tensor.repeat(iota_x,[L,1]).reshape([-1,1])
    tiledmask = torch.Tensor.repeat(mask,[L,1])

    all_log_pxgivenz_flat = torch.distributions.StudentT(loc=all_means_obs_model.reshape([-1,1]),scale=all_scales_obs_model.reshape([-1,1]),df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([L*batch_size,p])

    logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([L,batch_size])
    logpz = p_z.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)

    xgivenz = td.Independent(td.StudentT(loc=all_means_obs_model, scale=all_scales_obs_model, df=all_degfreedom_obs_model),1)

    imp_weights = torch.nn.functional.softmax(logpxobsgivenz + logpz - logq,0) # these are w_1,....,w_L for all observations in the batch
    xms = xgivenz.mean.reshape([L,batch_size,p])  # decoder means for each of the L samples
    xm=torch.einsum('ki,kij->ij', imp_weights, xms) 

    return xm
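
The last three lines of `miwae_impute_single` perform self-normalized importance sampling: the imputation is the weighted average of the decoder means, with weights given by a softmax of the log importance weights over the `L` samples. A minimal NumPy sketch of this step alone follows; the function name `snis_impute` and the random inputs are hypothetical and stand in for the quantities computed from the encoder and decoder.

```python
import numpy as np

def snis_impute(log_w, cond_means):
    """Weighted average of conditional means.

    log_w:      (L, n) unnormalized log importance weights
    cond_means: (L, n, p) decoder means for each sample
    """
    log_w = log_w - log_w.max(axis=0, keepdims=True)   # for numerical stability
    w = np.exp(log_w)
    w /= w.sum(axis=0, keepdims=True)                  # softmax over the L samples
    return np.einsum('ki,kij->ij', w, cond_means)

rng = np.random.default_rng(0)
n_samples, n_obs, n_feat = 4, 3, 2
log_w = rng.normal(size=(n_samples, n_obs))
cond_means = rng.normal(size=(n_samples, n_obs, n_feat))

x_imputed = snis_impute(log_w, cond_means)
# With equal log-weights the estimate reduces to a plain average over samples
x_uniform = snis_impute(np.zeros((n_samples, n_obs)), cond_means)
```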

We also define the MSE function used to evaluate the imputations:

In [24]:
def mse(xhat,xtrue,mask): # MSE function for imputations
    xhat = np.array(xhat)
    xtrue = np.array(xtrue)
    return np.mean(np.power(xhat-xtrue,2)[~mask])
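
As a quick check of the convention used here (the mask is `True` for observed entries, so the error is computed on `~mask` only), the following toy example applies the same formula to a 2x2 matrix. The function name `mse_on_missing` and the values are hypothetical.

```python
import numpy as np

def mse_on_missing(xhat, xtrue, mask):
    """Mean squared error restricted to the entries where mask is False (missing)."""
    xhat = np.asarray(xhat, dtype=float)
    xtrue = np.asarray(xtrue, dtype=float)
    return np.mean((xhat - xtrue)[~mask] ** 2)

xtrue = np.array([[1.0, 2.0], [3.0, 4.0]])
mask = np.array([[True, False], [False, True]])   # True = observed
xhat = np.array([[1.0, 2.5], [2.0, 4.0]])         # imputed at the two missing entries

err = mse_on_missing(xhat, xtrue, mask)           # errors 0.5 and -1.0 on missing entries
```

Because the data are standardized before imputation, filling the missing entries with zeros corresponds to mean imputation, and the same function can be used to compute that baseline for comparison.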

We instantiate the model using the last updated federated parameters:

In [25]:
# extract federated model into PyTorch framework
model = exp.model_instance()
model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])

encoder = model.encoder
decoder = model.decoder

We do the same for the model obtained with FedProx (and federated standardization):

In [26]:
model_fedprox = exp_fedprox.model_instance()
model_fedprox.load_state_dict(exp_fedprox.aggregated_params()[rounds - 1]['params'])

encoder_fedprox = model_fedprox.encoder
decoder_fedprox = model_fedprox.decoder

In [27]:
model_fedprox_std_local = exp_fedprox_std_local.model_instance()
model_fedprox_std_local.load_state_dict(exp_fedprox_std_local.aggregated_params()[rounds - 1]['params'])

encoder_fedprox_std_local = model_fedprox_std_local.encoder
decoder_fedprox_std_local = model_fedprox_std_local.decoder

And we finally do the imputation and evaluate the corresponding imputation error through MSE for each federated model:

In [28]:
L = 500

xhat[~mask] = miwae_impute_single(encoder = encoder,decoder = decoder,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model on testing data %g' %err_test_data)
print('-----')

xhat[~mask] = miwae_impute_single(encoder = encoder_fedprox,decoder = decoder_fedprox,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data_fedprox = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model (with fedprox) on testing data  %g' %err_test_data_fedprox)
print('-----')

Imputation MSE of fed model on testing data 0.62712
-----
Imputation MSE of fed model (with fedprox) on testing data  0.604235
-----


## 2. Testing on the client's datasets

We now use the final federated model to impute the missing data of client 1, which was used for training:

In [29]:
# We first recover data (full and with missing entries) from client 1
data_client_1 = np.copy(Clients_data[0])
n = data_client_1.shape[0] # number of observations
p = data_client_1.shape[1] # number of features

xfull_cl1 = np.copy(data_client_1)
xfull_cl1 = (xfull_cl1 - fed_mean.numpy())/fed_std.numpy()

xmiss_cl1 = np.copy(Clients_missing[0])
xmiss_cl1 = (xmiss_cl1 - fed_mean.numpy())/fed_std.numpy()
mask_cl1 = np.isfinite(xmiss_cl1) # binary mask: True where a value is observed, False where it is missing
xhat_0_cl1 = np.copy(xmiss_cl1)
xhat_0_cl1[np.isnan(xmiss_cl1)] = 0
xhat_cl1 = np.copy(xhat_0_cl1) # This will be our imputed data matrix

### Now we do the imputation

xhat_cl1[~mask_cl1] = miwae_impute_single(encoder = encoder,decoder = decoder, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model on data from client 1  %g' %err_cl1_data)
print('-----')

xhat_cl1[~mask_cl1] = miwae_impute_single(encoder = encoder_fedprox,decoder = decoder_fedprox, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data_fedprox = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model (with fedprox) on data from client 1  %g' %err_cl1_data_fedprox)
print('-----')

Imputation MSE of fed model on data from client 1  0.603596
-----
Imputation MSE of fed model (with fedprox) on data from client 1  0.630319
-----


And on client 2

In [30]:
# We first recover data (full and with missing entries) from client 2
data_client_2 = np.copy(Clients_data[1])
n = data_client_2.shape[0] # number of observations
p = data_client_2.shape[1] # number of features

xfull_cl2 = np.copy(data_client_2)
xfull_cl2 = (xfull_cl2 - fed_mean.numpy())/fed_std.numpy()

xmiss_cl2 = np.copy(Clients_missing[1])
xmiss_cl2 = (xmiss_cl2 - fed_mean.numpy())/fed_std.numpy()
mask_cl2 = np.isfinite(xmiss_cl2) # binary mask: True where a value is observed, False where it is missing
xhat_0_cl2 = np.copy(xmiss_cl2)
xhat_0_cl2[np.isnan(xmiss_cl2)] = 0
xhat_cl2 = np.copy(xhat_0_cl2) # This will be our imputed data matrix

### Now we do the imputation

xhat_cl2[~mask_cl2] = miwae_impute_single(encoder = encoder,decoder = decoder, iota_x = torch.from_numpy(xhat_0_cl2).float(),mask = torch.from_numpy(mask_cl2).float(),d = d,L= L).cpu().data.numpy()[~mask_cl2]
err_cl2_data = np.array([mse(xhat_cl2,xfull_cl2,mask_cl2)])
print('Imputation MSE of fed model on data from client 2  %g' %err_cl2_data)
print('-----')

xhat_cl2[~mask_cl2] = miwae_impute_single(encoder = encoder_fedprox,decoder = decoder_fedprox, iota_x = torch.from_numpy(xhat_0_cl2).float(),mask = torch.from_numpy(mask_cl2).float(),d = d,L= L).cpu().data.numpy()[~mask_cl2]
err_cl2_data_fedprox = np.array([mse(xhat_cl2,xfull_cl2,mask_cl2)])
print('Imputation MSE of fed model (with fedprox) on data from client 2  %g' %err_cl2_data_fedprox)
print('-----')

Imputation MSE of fed model on data from client 2  0.532805
-----
Imputation MSE of fed model (with fedprox) on data from client 2  0.53263
-----


## 3. Testing of FedProx model with local standardization

We now test the federated model trained with FedProx and local standardization. For consistency with training, each dataset is also standardized locally (with its own mean and std) in the testing phase.

In [35]:
# We recover the model
model_fedprox_std_local = exp_fedprox_std_local.model_instance()
model_fedprox_std_local.load_state_dict(exp_fedprox_std_local.aggregated_params()[rounds - 1]['params'])

encoder_fedprox_std_local = model_fedprox_std_local.encoder
decoder_fedprox_std_local = model_fedprox_std_local.decoder

# We re-create the testing dataset, and standardize it with respect to its own data
n = data_test.shape[0] # number of observations
p = data_test.shape[1] # number of features
xmiss = np.copy(data_test_missing)
mean_miss_test = np.nanmean(xmiss,0)
std_miss_test = np.nanstd(xmiss,0)
xmiss = (xmiss - mean_miss_test)/std_miss_test
mask = np.isfinite(xmiss) # binary mask: True where a value is observed, False where it is missing
xhat_0 = np.copy(xmiss)
xhat_0[np.isnan(xmiss)] = 0
xhat = np.copy(xhat_0) # This will be our imputed data matrix

xfull = np.copy(data_test)
xfull = (xfull - mean_miss_test)/std_miss_test

# We do the imputation
xhat[~mask] = miwae_impute_single(encoder = encoder_fedprox_std_local,decoder = decoder_fedprox_std_local,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data_fedprox_std_local = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model (with fedprox and local standardization) on testing data  %g' %err_test_data_fedprox_std_local)
print('-----')

# Same for the dataset from client 1
data_client_1 = np.copy(Clients_data[0])
n = data_client_1.shape[0] # number of observations
p = data_client_1.shape[1] # number of features

xmiss_cl1 = np.copy(Clients_missing[0])
mean_miss_cl1 = np.nanmean(xmiss_cl1,0)
std_miss_cl1 = np.nanstd(xmiss_cl1,0)
xmiss_cl1 = (xmiss_cl1 - mean_miss_cl1)/std_miss_cl1
mask_cl1 = np.isfinite(xmiss_cl1) # binary mask: True where a value is observed, False where it is missing
xhat_0_cl1 = np.copy(xmiss_cl1)
xhat_0_cl1[np.isnan(xmiss_cl1)] = 0
xhat_cl1 = np.copy(xhat_0_cl1) # This will be our imputed data matrix
xfull_cl1 = np.copy(data_client_1)
xfull_cl1 = (xfull_cl1 - mean_miss_cl1)/std_miss_cl1

xhat_cl1[~mask_cl1] = miwae_impute_single(encoder = encoder_fedprox_std_local,decoder = decoder_fedprox_std_local, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data_fedprox_std_local = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model (with fedprox and local standardization) on data from client 1  %g' %err_cl1_data_fedprox_std_local)
print('-----')

# And for the dataset from client 2
data_client_2 = np.copy(Clients_data[1])
n = data_client_2.shape[0] # number of observations
p = data_client_2.shape[1] # number of features

xmiss_cl2 = np.copy(Clients_missing[1])
mean_miss_cl2 = np.nanmean(xmiss_cl2,0)
std_miss_cl2 = np.nanstd(xmiss_cl2,0)
xmiss_cl2 = (xmiss_cl2 - mean_miss_cl2)/std_miss_cl2
mask_cl2 = np.isfinite(xmiss_cl2) # binary mask: True where a value is observed, False where it is missing
xhat_0_cl2 = np.copy(xmiss_cl2)
xhat_0_cl2[np.isnan(xmiss_cl2)] = 0
xhat_cl2 = np.copy(xhat_0_cl2) # This will be our imputed data matrix
xfull_cl2 = np.copy(data_client_2)
xfull_cl2 = (xfull_cl2 - mean_miss_cl2)/std_miss_cl2

xhat_cl2[~mask_cl2] = miwae_impute_single(encoder = encoder_fedprox_std_local,decoder = decoder_fedprox_std_local, iota_x = torch.from_numpy(xhat_0_cl2).float(),mask = torch.from_numpy(mask_cl2).float(),d = d,L= L).cpu().data.numpy()[~mask_cl2]
err_cl2_data_fedprox_std_local = np.array([mse(xhat_cl2,xfull_cl2,mask_cl2)])
print('Imputation MSE of fed model (with fedprox and local standardization) on data from client 2  %g' %err_cl2_data_fedprox_std_local)
print('-----')

Imputation MSE of fed model (with fedprox and local standardization) on testing data  0.851073
-----
Imputation MSE of fed model (with fedprox and local standardization) on data from client 1  0.615122
-----
Imputation MSE of fed model (with fedprox and local standardization) on data from client 2  0.641046
-----
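
The local standardization used in this section follows the same three steps for every dataset: compute NaN-aware statistics, standardize, then replace missing entries by zeros before feeding the encoder. A minimal sketch of these steps in one function is given below. The name `standardize_with_nans` and the toy matrix are hypothetical, and the sketch assumes no feature has zero standard deviation.

```python
import numpy as np

def standardize_with_nans(x, mean=None, std=None):
    """Standardize `x` column-wise, ignoring NaNs, then zero-fill missing entries.

    If `mean` / `std` are not given, they are estimated from the observed entries of `x`.
    """
    x = np.asarray(x, dtype=float)
    if mean is None:
        mean = np.nanmean(x, axis=0)
    if std is None:
        std = np.nanstd(x, axis=0)
    x_std = (x - mean) / std
    observed = np.isfinite(x_std)               # True where a value is observed
    x_filled = np.where(observed, x_std, 0.0)   # zero = feature mean after standardization
    return x_filled, observed, mean, std

x = np.array([[1.0, 10.0],
              [np.nan, 20.0],
              [3.0, np.nan],
              [5.0, 30.0]])
x_filled, observed, mean, std = standardize_with_nans(x)
```

Passing a `mean` and `std` explicitly reproduces the federated standardization, while omitting them reproduces the local one.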


## 4. Local training and testing on client 1

Finally, we test the performance of the same model trained locally and tested on the dataset from client 1. We will use a total of `epochs`x`rounds` local epochs.

In [32]:
p_z = td.Independent(td.Normal(loc=torch.zeros(d),scale=torch.ones(d)),1)
def miwae_loss(iota_x,mask, encoder, decoder, d, p, K, batch_size):
    
    batch_size = iota_x.shape[0]
    out_encoder = encoder(iota_x)

    q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :d],scale=torch.nn.Softplus()(out_encoder[..., d:(2*d)])),1)

    zgivenx = q_zgivenxobs.rsample([K])
    zgivenx_flat = zgivenx.reshape([K*batch_size,d])

    out_decoder = decoder(zgivenx_flat)
    all_means_obs_model = out_decoder[..., :p]
    all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., p:(2*p)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2*p):(3*p)]) + 3

    data_flat = torch.Tensor.repeat(iota_x,[K,1]).reshape([-1,1])
    tiledmask = torch.Tensor.repeat(mask,[K,1])

    all_log_pxgivenz_flat = torch.distributions.StudentT(loc=all_means_obs_model.reshape([-1,1]),scale=all_scales_obs_model.reshape([-1,1]),df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([K*batch_size,p])

    logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([K,batch_size])
    logpz = p_z.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)

    neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,0))

    return neg_bound
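
Note that `miwae_loss` returns the negative log-sum-exp of the K importance weights (averaged over the batch) without the 1/K normalization, which is why the training loop below reports `-np.log(K) - loss` as the likelihood bound. The following is a minimal, dependency-free sketch of that identity for a single observation. The log-weights are made-up numbers and the helper names are illustrative; they are not part of the notebook.

```python
import math

def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def miwae_bound_single(log_weights):
    """Log of the average importance weight: log((1/K) * sum_k w_k)."""
    K = len(log_weights)
    return logsumexp(log_weights) - math.log(K)

# Hypothetical values of log p(x_obs|z_k) + log p(z_k) - log q(z_k|x_obs), k = 1..K
log_w = [-3.0, -2.5, -4.0, -2.8]

neg_bound = -logsumexp(log_w)               # what miwae_loss computes for one observation
bound = -math.log(len(log_w)) - neg_bound   # what the training loop prints
direct = math.log(sum(math.exp(v) for v in log_w) / len(log_w))

print(bound, direct)  # the two values agree up to floating-point error
```

Since `log(K)` is a constant, dropping it from the loss does not change the gradients; it only shifts the reported value.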

We perform the local training on data from client 1, standardized with respect to its own data:

In [33]:
n_epochs_local = n_epochs*rounds

bs = training_args.get('batch_size')
lr = training_args.get('lr')

n = xfull_cl1.shape[0] # number of observations
p = xfull_cl1.shape[1] # number of features

h = model_args.get('n_hidden') 
d = model_args.get('n_latent') 
K = model_args.get('n_samples') 

encoder_cl1 = nn.Sequential(
    torch.nn.Linear(p, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 2*d),  # the encoder outputs the mean and the (pre-softplus) scale of the diagonal Gaussian
)

decoder_cl1 = nn.Sequential(
    torch.nn.Linear(d, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 3*p),  # the decoder outputs the mean, the scale, and the degrees of freedom (hence 3*p)
)

optimizer_cl1 = torch.optim.Adam(list(encoder_cl1.parameters()) + list(decoder_cl1.parameters()),lr=lr)

def weights_init(layer):
    if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
        
encoder_cl1.apply(weights_init)
decoder_cl1.apply(weights_init)

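for ep in range(1, n_epochs_local + 1):  # run exactly n_epochs_local epochs
    perm = np.random.permutation(n) # We use the "random reshuffling" version of SGD
    n_batches = max(1, n // bs)  # array_split expects an integer number of sections
    batches_data = np.array_split(xhat_0_cl1[perm,], n_batches)
    batches_mask = np.array_split(mask_cl1[perm,], n_batches)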
    for it in range(len(batches_data)):
        optimizer_cl1.zero_grad()
        encoder_cl1.zero_grad()
        decoder_cl1.zero_grad()
        b_data = torch.from_numpy(batches_data[it]).float()
        b_mask = torch.from_numpy(batches_mask[it]).float()
        loss = miwae_loss(iota_x = b_data,mask = b_mask, encoder = encoder_cl1, decoder = decoder_cl1, d = d, p = p, K = K, batch_size = bs)
        loss.backward()
        optimizer_cl1.step()
    if ep % 100 == 1:
        print('Epoch %g' %ep)
        # Evaluate the bound on the full client 1 dataset; no gradients are needed here
        with torch.no_grad():
            full_loss = miwae_loss(iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(), encoder = encoder_cl1, decoder = decoder_cl1, d = d, p = p, K = K, batch_size = bs)
        print('MIWAE likelihood bound  %g' %(-np.log(K) - full_loss.item()))
        print('Loss: {:.6f}'.format(loss.item()))  # loss of the last mini-batch in this epoch

Epoch 1
MIWAE likelihood bound  -132.373
Loss: 122.189293
Epoch 101
MIWAE likelihood bound  -69.277
Loss: 63.758804
Epoch 201
MIWAE likelihood bound  -44.2144
Loss: 42.097584
Epoch 301
MIWAE likelihood bound  -29.7305
Loss: 27.548759
Epoch 401
MIWAE likelihood bound  -14.2438
Loss: 12.528351
Epoch 501
MIWAE likelihood bound  -5.80943
Loss: -4.812748
Epoch 601
MIWAE likelihood bound  0.699142
Loss: -6.148830
Epoch 701
MIWAE likelihood bound  12.8062
Loss: -34.333340
Epoch 801
MIWAE likelihood bound  18.5732
Loss: -20.412441
Epoch 901
MIWAE likelihood bound  25.4719
Loss: -33.905273
Epoch 1001
MIWAE likelihood bound  30.8882
Loss: -38.381958
Epoch 1101
MIWAE likelihood bound  31.5629
Loss: -32.376366
Epoch 1201
MIWAE likelihood bound  40.3618
Loss: -23.410786
Epoch 1301
MIWAE likelihood bound  42.746
Loss: -37.975700
Epoch 1401
MIWAE likelihood bound  46.3974
Loss: -47.984089
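
The training loop above reshuffles the rows at every epoch and cuts the permutation into roughly `n/bs` nearly equal parts with `np.array_split`, so individual batches can be slightly larger than `bs`. Here is a small standard-library sketch of that batching scheme; `reshuffled_batches` is a hypothetical helper written for illustration, not a function from the notebook or from NumPy.

```python
import random

def reshuffled_batches(n, bs, rng):
    """Split a random permutation of range(n) into max(1, n // bs) nearly equal batches.

    Mirrors the sizing rule of np.array_split: the first (n % n_batches)
    batches receive one extra element.
    """
    perm = list(range(n))
    rng.shuffle(perm)
    n_batches = max(1, n // bs)
    base, extra = divmod(n, n_batches)
    batches, start = [], 0
    for i in range(n_batches):
        size = base + (1 if i < extra else 0)
        batches.append(perm[start:start + size])
        start += size
    return batches

rng = random.Random(0)
batches = reshuffled_batches(n=11, bs=4, rng=rng)
sizes = [len(b) for b in batches]
print(sizes)  # two batches covering all 11 indices, each larger than bs=4
```

Every index appears exactly once per epoch, which is what distinguishes random reshuffling from sampling batches with replacement.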


We then impute the missing values of the same dataset:

In [36]:
xhat_cl1_loc = np.copy(xhat_cl1)
xhat_cl1_loc[~mask_cl1] = miwae_impute_single(encoder = encoder_cl1, decoder = decoder_cl1, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_local_cl1_data = np.array([mse(xhat_cl1_loc,xfull_cl1,mask_cl1)])
print('Imputation MSE of local model on data from same client (cl 1)  %g' %err_local_cl1_data)
print('-----')

Imputation MSE of local model on data from same client (cl 1)  0.700608
-----
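
The `mse` helper is defined earlier in the notebook. Assuming it averages the squared error over the missing entries only (those where the mask is `False`, consistent with the `~mask` indexing used for imputation), a dependency-free sketch of the metric could look as follows. The name `masked_mse` and the toy values are illustrative.

```python
def masked_mse(xhat, xtrue, mask):
    """Mean squared error over entries where mask is False (originally missing)."""
    sq_errors = [
        (h - t) ** 2
        for row_h, row_t, row_m in zip(xhat, xtrue, mask)
        for h, t, m in zip(row_h, row_t, row_m)
        if not m
    ]
    return sum(sq_errors) / len(sq_errors)

# Toy example: True marks an observed entry, False a missing (imputed) one
xtrue = [[1.0, 2.0], [3.0, 4.0]]
xhat = [[1.0, 2.5], [2.0, 4.0]]
mask = [[True, False], [False, True]]

score = masked_mse(xhat, xtrue, mask)
print(score)
```

Under this assumption, observed entries never contribute to the score, so the metric measures imputation quality rather than reconstruction quality.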


We also impute the held-out test dataset:

In [37]:
xhat_loc1 = np.copy(xhat)
xhat_loc1[~mask] = miwae_impute_single(encoder = encoder_cl1,decoder = decoder_cl1,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_local_cl1_test_data = np.array([mse(xhat_loc1,xfull,mask)])
print('Imputation MSE of local (cl1) model on testing data %g' %err_local_cl1_test_data)
print('-----')

Imputation MSE of local (cl1) model on testing data 0.879791
-----


## 4. Local training and testing on client 2

In [38]:
n_epochs_local = n_epochs*rounds

bs = training_args.get('batch_size')
lr = training_args.get('lr')

n = xfull_cl2.shape[0] # number of observations
p = xfull_cl2.shape[1] # number of features

h = model_args.get('n_hidden') 
d = model_args.get('n_latent') 
K = model_args.get('n_samples') 

encoder_cl2 = nn.Sequential(
    torch.nn.Linear(p, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 2*d),  # the encoder outputs the mean and the (pre-softplus) scale of the diagonal Gaussian
)

decoder_cl2 = nn.Sequential(
    torch.nn.Linear(d, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 3*p),  # the decoder outputs the mean, the scale, and the degrees of freedom (hence 3*p)
)

optimizer_cl2 = torch.optim.Adam(list(encoder_cl2.parameters()) + list(decoder_cl2.parameters()),lr=lr)

def weights_init(layer):
    if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
        
encoder_cl2.apply(weights_init)
decoder_cl2.apply(weights_init)

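for ep in range(1, n_epochs_local + 1):  # run exactly n_epochs_local epochs
    perm = np.random.permutation(n) # We use the "random reshuffling" version of SGD
    n_batches = max(1, n // bs)  # array_split expects an integer number of sections
    batches_data = np.array_split(xhat_0_cl2[perm,], n_batches)
    batches_mask = np.array_split(mask_cl2[perm,], n_batches)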
    for it in range(len(batches_data)):
        optimizer_cl2.zero_grad()
        encoder_cl2.zero_grad()
        decoder_cl2.zero_grad()
        b_data = torch.from_numpy(batches_data[it]).float()
        b_mask = torch.from_numpy(batches_mask[it]).float()
        loss = miwae_loss(iota_x = b_data,mask = b_mask, encoder = encoder_cl2, decoder = decoder_cl2, d = d, p = p, K = K, batch_size = bs)
        loss.backward()
        optimizer_cl2.step()
    if ep % 500 == 1:
        print('Epoch %g' %ep)
        # Evaluate the bound on the full client 2 dataset; no gradients are needed here
        with torch.no_grad():
            full_loss = miwae_loss(iota_x = torch.from_numpy(xhat_0_cl2).float(),mask = torch.from_numpy(mask_cl2).float(), encoder = encoder_cl2, decoder = decoder_cl2, d = d, p = p, K = K, batch_size = bs)
        print('MIWAE likelihood bound  %g' %(-np.log(K) - full_loss.item()))
        print('Loss: {:.6f}'.format(loss.item()))  # loss of the last mini-batch in this epoch

Epoch 1
MIWAE likelihood bound  -132.631
Loss: 122.724594
Epoch 501
MIWAE likelihood bound  18.5143
Loss: -14.616351
Epoch 1001
MIWAE likelihood bound  70.4203
Loss: -69.737221


In [39]:
xhat_cl2_loc = np.copy(xhat_cl2)
xhat_cl2_loc[~mask_cl2] = miwae_impute_single(encoder = encoder_cl2, decoder = decoder_cl2, iota_x = torch.from_numpy(xhat_0_cl2).float(),mask = torch.from_numpy(mask_cl2).float(),d = d,L= L).cpu().data.numpy()[~mask_cl2]
err_local_cl2_data = np.array([mse(xhat_cl2_loc,xfull_cl2,mask_cl2)])
print('Imputation MSE of local model on data from same client (cl 2)  %g' %err_local_cl2_data)
print('-----')

Imputation MSE of local model on data from same client (cl 2)  0.75784
-----


In [40]:
xhat_loc2 = np.copy(xhat)
xhat_loc2[~mask] = miwae_impute_single(encoder = encoder_cl2,decoder = decoder_cl2,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_local_cl2_test_data = np.array([mse(xhat_loc2,xfull,mask)])
print('Imputation MSE of local (cl2) model on testing data %g' %err_local_cl2_test_data)
print('-----')

Imputation MSE of local (cl2) model on testing data 0.940715
-----


## Comparison of results

In [41]:
from tabulate import tabulate

print('Imputation MSE on testing data')
print('-----')
# `rows` avoids overwriting the `data` DataFrame loaded at the top of the notebook
rows = [['FedAvg', err_test_data],
        ['FedProx', err_test_data_fedprox],
        ['FedLocStd', err_test_data_fedprox_std_local],  # FedProx with local standardization
        ['Local (cl1)', err_local_cl1_test_data],
        ['Local (cl2)', err_local_cl2_test_data]]
print(tabulate(rows, headers=["Model", "Mean Squared Error (\u2193)"]))
print('-----')
print('-----')
print('Imputation MSE on local data from client 1')
print('-----')
rows = [['FedAvg', err_cl1_data],
        ['FedProx', err_cl1_data_fedprox],
        ['FedLocStd', err_cl1_data_fedprox_std_local],
        ['Local (cl1)', err_local_cl1_data]]
print(tabulate(rows, headers=["Model", "Mean Squared Error (\u2193)"]))
print('-----')
print('-----')
print('Imputation MSE on local data from client 2')
print('-----')
rows = [['FedAvg', err_cl2_data],
        ['FedProx', err_cl2_data_fedprox],
        ['FedLocStd', err_cl2_data_fedprox_std_local],
        ['Local (cl2)', err_local_cl2_data]]
print(tabulate(rows, headers=["Model", "Mean Squared Error (\u2193)"]))

Imputation MSE on testing data
-----
Model          Mean Squared Error (↓)
-----------  ------------------------
FedAvg                       0.62712
FedProx                      0.604235
FedLocStd                    0.851073
Local (cl1)                  0.879791
Local (cl2)                  0.940715
-----
-----
Imputation MSE on local data from client 1
-----
Model          Mean Squared Error (↓)
-----------  ------------------------
FedAvg                       0.603596
FedProx                      0.630319
FedLocStd                    0.615122
Local (cl1)                  0.700608
-----
-----
Imputation MSE on local data from client 2
-----
Model          Mean Squared Error (↓)
-----------  ------------------------
FedAvg                       0.532805
FedProx                      0.53263
FedLocStd                    0.641046
Local (cl2)                  0.75784
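
To read the test-set table more easily, one option is to express each federated result as a relative reduction in MSE with respect to the best locally trained model. The sketch below hard-codes the test MSE values printed above; the helper `relative_reduction` is an illustrative name, not part of the notebook.

```python
# Test-set MSE values copied from the table above
test_mse = {
    "FedAvg": 0.62712,
    "FedProx": 0.604235,
    "FedLocStd": 0.851073,
    "Local (cl1)": 0.879791,
    "Local (cl2)": 0.940715,
}

def relative_reduction(baseline, candidate):
    """Fraction by which `candidate` lowers the error relative to `baseline`."""
    return (baseline - candidate) / baseline

best_local = min(test_mse["Local (cl1)"], test_mse["Local (cl2)"])
reductions = {
    name: relative_reduction(best_local, value)
    for name, value in test_mse.items()
    if not name.startswith("Local")
}

for name, r in reductions.items():
    print(f"{name}: {100 * r:.1f}% lower test MSE than the best local model")
```

These figures come from a single run with a single train/test split, so they indicate a trend rather than a statistically validated ranking.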
