# Missing data imputation with Fedbiomed using MIWAE

In this notebook we show:
* how to obtain mean and std in a federated manner, to perform afterwards local dataset standardization with respect to the global dataset
* how to impute missing not at random (MAR) data in a federated setting using MIWAE (https://arxiv.org/abs/2006.12871). 

We will compare results of federated training using FedAvg, FedProx (with both local standardization and federated standardization), with local results.

In [1]:
%load_ext autoreload
%autoreload 2

## Prepare the data

For this experiment we will use the breast cancer data from sklearn.

In [2]:
import pandas as pd
import numpy as np

data_url = "http://lib.stat.cmu.edu/datasets/boston"
raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
target = raw_df.values[1::2, 2]

In [3]:
from sklearn.model_selection import train_test_split

#train test split
data_train, data_test, labels_train, labels_test = train_test_split(data, target, test_size=0.20, random_state=42)
df_data_train = pd.DataFrame(data_train)
N_train = len(df_data_train)
mean_train = np.mean(data_train,0)
std_train = np.std(data_train,0)

print("N_train: ",N_train)
print("mean_train: ",mean_train)
print("std_train: ",std_train)

N_train:  404
mean_train:  [3.60912463e+00 1.15693069e+01 1.09850495e+01 7.17821782e-02
 5.56484158e-01 6.31589109e+00 6.85564356e+01 3.80819505e+00
 9.35643564e+00 4.04032178e+02 1.83183168e+01 3.56278342e+02
 1.24573515e+01]
std_train:  [8.86406744e+00 2.31238090e+01 6.88607935e+00 2.58126901e-01
 1.17558710e-01 7.08573178e-01 2.79602535e+01 2.12858714e+00
 8.57908366e+00 1.65966869e+02 2.22594093e+00 9.14531376e+01
 7.10157559e+00]


In [4]:
# split train across datasets
client_1, client_2, client_3 = np.split(df_data_train.sample(frac=1,random_state=42), \
                                        [int(.33*N_train), int(.66*len(df_data_train))])

Clients_data=[client_1, client_2, client_3]

N_cl = [len(i) for i in Clients_data]
mean_cl = [np.nanmean(i,0) for i in Clients_data]
std_cl = [np.nanstd(i,0) for i in Clients_data]
cl = len(Clients_data)

N_train_post = sum(N_cl)
mean_train_post = sum([N_cl[i]*np.array(mean_cl[i])/N_train_post for i in range(cl)])
std_train_post = np.sqrt(sum([((N_cl[i]-1)*np.array(std_cl[i])**2+N_cl[i]*np.array(mean_cl[i])**2)/(N_train_post-cl) for i in range(cl)])-(N_train_post/(N_train_post-cl))*mean_train_post**2)

print("N_train_post", N_train_post)
print("mean_train_post", mean_train_post)
print("std_train_post",std_train_post)

N_train_post 404
mean_train_post [3.60912463e+00 1.15693069e+01 1.09850495e+01 7.17821782e-02
 5.56484158e-01 6.31589109e+00 6.85564356e+01 3.80819505e+00
 9.35643564e+00 4.04032178e+02 1.83183168e+01 3.56278342e+02
 1.24573515e+01]
std_train_post [8.86436413e+00 2.31246317e+01 6.88614271e+00 2.58134386e-01
 1.17559296e-01 7.08579110e-01 2.79608726e+01 2.12859150e+00
 8.57923305e+00 1.65969122e+02 2.22596627e+00 9.14558358e+01
 7.10158461e+00]


In [5]:
# from each dataset we will remove randomly 50% of data
np.random.seed(1234)

# 50% of missing data for client 1, 30% for client 2, 60% for client 3
perc_miss_list = [0.5,0.3,0.6] 

Clients_missing = []
for perc,c in enumerate(Clients_data):
    perc_miss=perc_miss_list[perc]
    n = c.shape[0] # number of observations
    p = c.shape[1] # number of features
    xmiss = np.copy(c)
    #xmiss = (xmiss - np.mean(xmiss,0))/np.std(xmiss,0)
    xmiss_flat = xmiss.flatten()
    miss_pattern = np.random.choice(n*p, np.floor(n*p*perc_miss).astype(np.int_),\
                                    replace=False)
    xmiss_flat[miss_pattern] = np.nan 
    xmiss = xmiss_flat.reshape([n,p]) # in xmiss, the missing values are represented by nans
    mask = np.isfinite(xmiss) # binary mask that indicates which values are missing
    Clients_missing.append(xmiss)

In [6]:
p = Clients_missing[0].shape[1]
N_cl = [np.array([Clients_missing[c][:,dim].size - np.count_nonzero(np.isnan(Clients_missing[c][:,dim])) for dim in range(p)]) for c in range(cl)]
print(N_cl)
mean_cl = [np.nanmean(i,0) for i in Clients_missing]
std_cl = [np.nanstd(i,0) for i in Clients_missing]
cl = len(Clients_missing)

N_train_post = np.array(sum([N_cl[c] for c in range(cl)]))
print(N_train_post)
mean_train_post = np.array(sum([N_cl[i]*mean_cl[i]/N_train_post for i in range(cl)]))
std_train_post = np.sqrt(sum([((N_cl[i]-1)*(std_cl[i]**2)+N_cl[i]*(mean_cl[i]**2))/(N_train_post-cl) for i in range(cl)])-(N_train_post/(N_train_post-cl))*mean_train_post**2)

print("N_train_post after missing", N_train_post)
print("mean_train_post after missing", mean_train_post)
print("std_train_post after missing",std_train_post)

[array([65, 70, 62, 72, 64, 65, 73, 62, 63, 70, 72, 63, 64]), array([ 89,  87,  96,  91,  89,  92,  91,  88,  97, 102,  98,  96,  95]), array([61, 50, 48, 57, 51, 54, 58, 58, 59, 58, 63, 50, 51])]
[215 207 206 220 204 211 222 208 219 230 233 209 210]
N_train_post after missing [215 207 206 220 204 211 222 208 219 230 233 209 210]
mean_train_post after missing [3.07500809e+00 1.11014493e+01 1.12669903e+01 7.72727273e-02
 5.58312745e-01 6.34579147e+00 7.05545045e+01 3.82085865e+00
 9.24657534e+00 4.15421739e+02 1.83309013e+01 3.51897033e+02
 1.23715238e+01]
std_train_post after missing [8.04864157e+00 2.30081521e+01 6.92799094e+00 2.67054724e-01
 1.10852174e-01 7.53521123e-01 2.74137196e+01 2.09178388e+00
 8.56198414e+00 1.67620328e+02 2.22012681e+00 9.87384587e+01
 6.58858144e+00]


In [7]:
Clients_missing_norm = []
for data in Clients_missing:
    data_norm = np.copy(data)
    data_norm = (data_norm - mean_train_post)/std_train_post
    Clients_missing_norm.append(data_norm)

import os 
os.makedirs('data/clients_data', exist_ok=True) 
for i in range(len(Clients_missing)):
    pd.DataFrame(Clients_missing[i]).to_csv('data/clients_data/client_'+str(i+1)+'.csv',index=False)

In [8]:
### We centralize all data to evaluate mean and std with missing
#print([i.shape for i in Clients_missing])
Clients_missing_tot = np.concatenate(Clients_missing,axis=0)
#print(Clients_missing_tot.shape)
mean_clients_missing = np.nanmean(Clients_missing_tot,0)
std_clients_missing = np.nanstd(Clients_missing_tot,0)

print("mean centralized after missing", mean_clients_missing)
print("std centralized after missing", std_clients_missing)

mean centralized after missing [3.07500809e+00 1.11014493e+01 1.12669903e+01 7.72727273e-02
 5.58312745e-01 6.34579147e+00 7.05545045e+01 3.82085865e+00
 9.24657534e+00 4.15421739e+02 1.83309013e+01 3.51897033e+02
 1.23715238e+01]
std centralized after missing [8.05344570e+00 2.30072840e+01 6.92754048e+00 2.67023694e-01
 1.10866213e-01 7.53666761e-01 2.74117840e+01 2.09158817e+00
 8.56207122e+00 1.67622662e+02 2.21998097e+00 9.88107497e+01
 6.58763897e+00]


## Start the network
Before running this notebook, start the network with `./scripts/fedbiomed_run network`

## Setting the nodes up
It is necessary to previously configure a node:
1. `./scripts/fedbiomed_run node add`
  * Select option 1 (csv) to add client_1 dataset to the first node
  * Provide the correct tag by entering:  breast_cancer
  * Pick the folder where client_1 dataset has been saved
  * Data must have been added (if you get a warning saying that data must be unique is because it's been already added)
  
2. Check that your data has been added by executing `./scripts/fedbiomed_run node list`
3. Run the node using `./scripts/fedbiomed_run node start`. Wait until you get `Starting task manager`. it means you are online.
4. Following the same procedure, you can create additional nodes for clients 2 and 3.

Check available clients:

In [None]:
from fedbiomed.researcher.requests import Requests
req = Requests()
req.list(verbose=True)
xx = req.list()
dataset_size = [xx[i][0]['shape'][1] for i in xx]
assert min(dataset_size)==max(dataset_size)
data_size = dataset_size[0]

## Recover global mean and std

In [10]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import pandas as pd
from copy import deepcopy

from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes

# Here we define the model to be used. 
# You can use any class name (here 'Net')
class FedMeanStdTrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args: dict = {}):
        super(FedMeanStdTrainingPlan, self).__init__(model_args)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["import pandas as pd",
               "import numpy as np",
               "from copy import deepcopy"]
        
        self.add_dependency(deps)
        
        self.n_features=model_args['n_features']
        
        self.mean = nn.Parameter(torch.zeros(self.n_features,dtype=torch.float64),requires_grad=False)
        self.std = nn.Parameter(torch.zeros(self.n_features,dtype=torch.float64),requires_grad=False)
        self.size = nn.Parameter(torch.zeros(self.n_features,dtype=torch.float64),requires_grad=False)
        self.fake = nn.Parameter(torch.randn(1),requires_grad=True)
        
    def training_data(self):
        
        df = pd.read_csv(self.dataset_path, sep=',', index_col=False)
        
        ### NOTE: batch_size should be == dataset size ###
        batch_size = df.shape[0]
        x_train = df.values
        x_mask = np.isfinite(x_train)
        xhat_0 = np.copy(x_train)
        ### NOTE: we keep nan when data is missing
        #xhat_0[np.isnan(x_train)] = 0
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        
        data_manager = DataManager(dataset=xhat_0 , target=x_mask , **train_kwargs)
        
        return data_manager
    
    def training_step(self, data, mask):
        
        ### Implementing with np.nanmean, np.nanstd
        
        data_np = data.numpy()
        self.size += torch.Tensor([data_np[:,dim].size - np.count_nonzero(np.isnan(data_np[:,dim]))\
                                   for dim in range(self.n_features)])
        self.mean += torch.from_numpy(np.nanmean(data_np,0))
        self.std += torch.from_numpy(np.nanstd(data_np,0))
        
        # ### Implementing with torch.mean, torch.std
        
        # size_loc = torch.zeros(self.n_features)
        # mean_loc = torch.zeros(self.n_features)
        # std_loc = torch.zeros(self.n_features)
        # for dim in range(self.n_features):
        #     data_i = deepcopy(data[:,dim][mask[:,dim].bool()])
        #     size_loc[dim] = data_i.shape[0]
        #     mean_loc[dim] = torch.mean(data_i, dim=0)
        #     std_loc[dim] = torch.std(data_i, unbiased=False, dim=0)
        # self.size += size_loc
        # self.mean += mean_loc
        # self.std += std_loc
        
        return self.fake

In [11]:
model_args = {'n_features':data_size}

training_args = {
    'batch_size': 48, 
    'lr': 1e-3, 
    'log_interval' : 1,
    'epochs': 1, 
    'dry_run': False,  
    #'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

tags =  ['breast_cancer']

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedstandard import FedStandard

fed_mean_std = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=FedMeanStdTrainingPlan,
                 training_args=training_args,
                 round_limit=1,
                 aggregator=FedStandard(),
                 node_selection_strategy=None)

In [None]:
fed_mean_std.run()

In [14]:
fed_mean = fed_mean_std.aggregated_params()[0]['params']['fed_mean']
fed_std = fed_mean_std.aggregated_params()[0]['params']['fed_std']

In [15]:
print(fed_mean_std.aggregated_params()[0]['params'])
print(mean_clients_missing)
print(std_clients_missing)
print(fed_mean-mean_clients_missing)
print(fed_std-std_clients_missing)
print(fed_mean.dtype, fed_std.dtype)

{'fed_mean': tensor([3.0750e+00, 1.1101e+01, 1.1267e+01, 7.7273e-02, 5.5831e-01, 6.3458e+00,
        7.0555e+01, 3.8209e+00, 9.2466e+00, 4.1542e+02, 1.8331e+01, 3.5190e+02,
        1.2372e+01], dtype=torch.float64), 'fed_std': tensor([8.0486e+00, 2.3008e+01, 6.9280e+00, 2.6705e-01, 1.1085e-01, 7.5352e-01,
        2.7414e+01, 2.0918e+00, 8.5620e+00, 1.6762e+02, 2.2201e+00, 9.8738e+01,
        6.5886e+00], dtype=torch.float64), 'N_tot': tensor([215., 207., 206., 220., 204., 211., 222., 208., 219., 230., 233., 209.,
        210.], dtype=torch.float64)}
[3.07500809e+00 1.11014493e+01 1.12669903e+01 7.72727273e-02
 5.58312745e-01 6.34579147e+00 7.05545045e+01 3.82085865e+00
 9.24657534e+00 4.15421739e+02 1.83309013e+01 3.51897033e+02
 1.23715238e+01]
[8.05344570e+00 2.30072840e+01 6.92754048e+00 2.67023694e-01
 1.10866213e-01 7.53666761e-01 2.74117840e+01 2.09158817e+00
 8.56207122e+00 1.67622662e+02 2.21998097e+00 9.88107497e+01
 6.58763897e+00]
tensor([ 6.0892e-08, -1.9350e-07,  1.6666e-0

## Define experiment model and parameters

Declare a torch.nn MIWAETrainingPlan class to send for training on the node

Note : we include a function, ``standardize_data``, which allow to standardize data either with respect to a mean and std provided by the user, or locally, considering only local data for each client.

In [16]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import numpy as np
import torch.distributions as td
import pandas as pd

from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from fedbiomed.common.constants import ProcessTypes

# Here we define the model to be used. 
# You can use any class name (here 'Net')
class MIWAETrainingPlan(TorchTrainingPlan):
    def __init__(self, model_args: dict = {}):
        super(MIWAETrainingPlan, self).__init__(model_args)
        
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["from torchvision import datasets, transforms",
               "import torch.distributions as td",
               "import pandas as pd",
               "import numpy as np"]
        
        self.n_features=model_args['n_features']
        self.n_latent=model_args['n_latent']
        self.n_hidden=model_args['n_hidden']
        self.n_samples=model_args['n_samples']
        
        if 'standardization' in model_args:
            self.standardization = True
            if (('fed_mean' in model_args['standardization']) and ('fed_std' in model_args['standardization'])):
                self.fed_mean = np.array(model_args['standardization']['fed_mean'])
                self.fed_std = np.array(model_args['standardization']['fed_std'])
            else:
                self.fed_mean = None
                self.fed_std = None
        
        self.add_dependency(deps)
        
        # the encoder will output both the mean and the diagonal covariance
        self.encoder=nn.Sequential(
                        torch.nn.Linear(self.n_features, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 2*self.n_latent),  
                        )
        # the decoder will output both the mean, the scale, 
        # and the number of degrees of freedoms (hence the 3*p)
        self.decoder = nn.Sequential(
                        torch.nn.Linear(self.n_latent, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, self.n_hidden),
                        torch.nn.ReLU(),
                        torch.nn.Linear(self.n_hidden, 3*self.n_features),  
                        )
        
        self.optimizer = torch.optim.Adam(list(self.encoder.parameters()) \
                                    + list(self.decoder.parameters()),lr=1e-3)
              
        self.encoder.apply(self.weights_init)
        self.decoder.apply(self.weights_init)
    
    def weights_init(self,layer):
        if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
    
    def miwae_loss(self,iota_x,mask):
        batch_size = iota_x.shape[0]
        out_encoder = self.encoder(iota_x)
        # prior
        p_z = td.Independent(td.Normal(loc=torch.zeros(self.n_latent).to(self.device)\
                                       ,scale=torch.ones(self.n_latent).to(self.device)),1)
        
        q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :self.n_latent],\
                                                scale=torch.nn.Softplus()\
                                                (out_encoder[..., self.n_latent:\
                                                             (2*self.n_latent)])),1)

        zgivenx = q_zgivenxobs.rsample([self.n_samples])
        zgivenx_flat = zgivenx.reshape([self.n_samples*batch_size,self.n_latent])

        out_decoder = self.decoder(zgivenx_flat)
        all_means_obs_model = out_decoder[..., :self.n_features]
        all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., self.n_features:\
                                                               (2*self.n_features)]) + 0.001
        all_degfreedom_obs_model = torch.nn.Softplus()\
        (out_decoder[..., (2*self.n_features):(3*self.n_features)]) + 3

        data_flat = torch.Tensor.repeat(iota_x,[self.n_samples,1]).reshape([-1,1])
        tiledmask = torch.Tensor.repeat(mask,[self.n_samples,1])

        all_log_pxgivenz_flat = torch.distributions.StudentT\
        (loc=all_means_obs_model.reshape([-1,1]),\
         scale=all_scales_obs_model.reshape([-1,1]),\
         df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
        all_log_pxgivenz = all_log_pxgivenz_flat.reshape([self.n_samples*batch_size,self.n_features])

        logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([self.n_samples,batch_size])
        logpz = p_z.log_prob(zgivenx)
        logq = q_zgivenxobs.log_prob(zgivenx)

        neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,0))

        return neg_bound

    def training_data(self,  batch_size = 48):
        
        df = pd.read_csv(self.dataset_path, sep=',', index_col=False)
        x_train = df.values
        x_mask = np.isfinite(x_train)
        # xhat_0: missing values are replaced by zeros. 
        #This x_hat0 is what will be fed to our encoder.
        xhat_0 = np.copy(x_train)
        
        # Data standardization
        if self.standardization:
            xhat_0 = self.standardize_data(xhat_0)
            
        xhat_0[np.isnan(x_train)] = 0
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        
        data_manager = DataManager(dataset=xhat_0 , target=x_mask , **train_kwargs)
        
        return data_manager
    
    def standardize_data(self,data):
        data_norm = np.copy(data)
        if ((self.fed_mean is not None) and (self.fed_std is not None)):
            print('FEDERATED STANDARDIZATION')
            data_norm = (data_norm - self.fed_mean)/self.fed_std
        else:
            print('LOCAL STANDARDIZATION')
            data_norm = (data_norm - np.nanmean(data_norm,0))/np.nanstd(data_norm,0)
        return data_norm
    
    def training_step(self, data, mask):
        self.encoder.zero_grad()
        self.decoder.zero_grad()
        loss = self.miwae_loss(iota_x = data,mask = mask)
        return loss

This group of arguments correspond respectively:
* `model_args`: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side. 
* `training_args`: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.
* data `tags` to search nodes for training.
* total number of `rounds`.
If FedProx optimisation is requested, `fedprox_mu` parameter must be defined here. It also must be a float between XX and YY.

**NOTE:** typos and/or lack of positional (required) arguments will raise error. 🤓

In [17]:
h = 128 # number of hidden units in (same for all MLPs)
d = 10 # dimension of the latent space, we choose d=1 for visualisation purposes
K = 20 # number of IS during training

n_epochs=5

print(type(fed_mean.tolist()[0]))

model_args = {'n_features':data_size, 'n_latent':d,'n_hidden':h,'n_samples':K,'standardization':{'fed_mean':fed_mean.tolist(),'fed_std':fed_std.tolist()}}

training_args = {
    'batch_size': 48, 
    'lr': 1e-3, 
    'log_interval' : 1,
    'epochs': n_epochs, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}

tags =  ['breast_cancer']
rounds = 50

<class 'float'>


## Declare and run the experiment

- search nodes serving data for these `tags`, optionally filter on a list of node ID with `nodes`
- run a round of local training on nodes with model defined in `model_path` + federation with `aggregator`
- run for `round_limit` rounds, applying the `node_selection_strategy` between the rounds

In [None]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

exp = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

Let's start the experiment.

By default, this function doesn't stop until all the `round_limit` rounds are done for all the nodes

In [None]:
exp.run()

## Run the experiment with FedProx

We repeat the federated training but using FedProx as aggregation scheme (starting from the second iteration).

In [None]:
# During the first round we will simply use FedAvg 
# with standard optimization scheme: the FedProx penalization
# term will be introduced exclusively from the second round.
# training_args.update(fedprox_mu = 0.)

exp_fedprox = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

In [None]:
exp_fedprox.run_once()

In [None]:
# Starting from the second round, FedProx is used with mu=0.1
# We first update the training args
training_args.update(fedprox_mu = 0.1)

# Then update training args in the experiment
exp_fedprox.set_training_args(training_args)
exp_fedprox.run()

## Run the experiment with FedProx and performing the standardization locally

And finally we propose to use FedCos as well, which introduce an alternative penalization term with cosine similarity:

In [None]:
del training_args['fedprox_mu'] 

model_args.update(standardization = {})

exp_fedprox_std_local = Experiment(tags=tags,
                 model_args=model_args,
                 model_class=MIWAETrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)

In [None]:
exp_fedprox_std_local.run_once()

In [None]:
training_args.update(fedprox_mu = 0.1)

exp_fedprox_std_local.set_training_args(training_args)
exp_fedprox_std_local.run()

# Test and comparison to local training

## 1. Testing on an external dataset

First of all we are going to test the performance of the final federated model to impute missing data on a test dataset. To this extent we are going to remove randomly 50% of samples from the test dataset, `data_test`, defined at the beginning of this notebook.

In [26]:
# from the test dataset, we will remove randomly 50% of data
np.random.seed(1234)

perc_miss = 0.5 # 50% of missing data

n = data_test.shape[0] # number of observations
p = data_test.shape[1] # number of features

xmiss = np.copy(data_test)
xmiss_flat = xmiss.flatten()
miss_pattern = np.random.choice(n*p, np.floor(n*p*perc_miss).astype(np.int_),\
                                replace=False)
xmiss_flat[miss_pattern] = np.nan 
xmiss = xmiss_flat.reshape([n,p]) # in xmiss, the missing values are represented by nans

mask_test = np.isfinite(xmiss) # binary mask that indicates which values are missing

mean_test_missing = np.nanmean(xmiss,0)
std_test_missing = np.nanstd(xmiss,0)
xmiss = (xmiss - mean_test_missing)/std_test_missing

xhat_0_test = np.copy(xmiss)
xhat_0_test[np.isnan(xmiss)] = 0
xhat_test = np.copy(xhat_0_test) # This will be out imputed data matrix

xfull_test = np.copy(data_test)
xfull_test = (xfull_test - mean_test_missing)/std_test_missing

We define the MIWAE imputation routine:

In [27]:
def miwae_impute(encoder,decoder,iota_x,mask,d,L):
    
    p_z = td.Independent(td.Normal(loc=torch.zeros(d),scale=torch.ones(d)),1)
    
    batch_size = iota_x.shape[0]
    out_encoder = encoder(iota_x)
    q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :d],scale=torch.nn.Softplus()(out_encoder[..., d:(2*d)])),1)

    zgivenx = q_zgivenxobs.rsample([L])
    zgivenx_flat = zgivenx.reshape([L*batch_size,d])

    out_decoder = decoder(zgivenx_flat)
    all_means_obs_model = out_decoder[..., :p]
    all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., p:(2*p)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2*p):(3*p)]) + 3

    data_flat = torch.Tensor.repeat(iota_x,[L,1]).reshape([-1,1])
    tiledmask = torch.Tensor.repeat(mask,[L,1])

    all_log_pxgivenz_flat = torch.distributions.StudentT(loc=all_means_obs_model.reshape([-1,1]),scale=all_scales_obs_model.reshape([-1,1]),df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([L*batch_size,p])

    logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([L,batch_size])
    logpz = p_z.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)

    xgivenz = td.Independent(td.StudentT(loc=all_means_obs_model, scale=all_scales_obs_model, df=all_degfreedom_obs_model),1)

    imp_weights = torch.nn.functional.softmax(logpxobsgivenz + logpz - logq,0) # these are w_1,....,w_L for all observations in the batch
    xms = xgivenz.mean.reshape([L,batch_size,p])  # that's the only line that changed!
    xm=torch.einsum('ki,kij->ij', imp_weights, xms) 

    return xm

As well as the MSE function:

In [28]:
def mse(xhat,xtrue,mask): # MSE function for imputations
    xhat = np.array(xhat)
    xtrue = np.array(xtrue)
    return np.mean(np.power(xhat-xtrue,2)[~mask])

We instantiate the model using last updated federated parameters:

In [29]:
# extract federated model into PyTorch framework
model = exp.model_instance()
model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])

encoder = model.encoder
decoder = model.decoder

Same for the models trained with FedProx and FedCos

In [30]:
model_fedprox = exp_fedprox.model_instance()
model_fedprox.load_state_dict(exp_fedprox.aggregated_params()[rounds - 1]['params'])

encoder_fedprox = model_fedprox.encoder
decoder_fedprox = model_fedprox.decoder

And we finally do the imputation and evaluate the corresponding imputation error through MSE for each federated model:

In [31]:
L = 100

xhat = np.copy(xhat_test)
mask = np.copy(mask_test)
xhat_0 = np.copy(xhat_0_test)
xfull = np.copy(xfull_test)

xhat[~mask] = miwae_impute(encoder = encoder,decoder = decoder,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model on testing data %g' %err_test_data)
print('-----')

xhat[~mask] = miwae_impute(encoder = encoder_fedprox,decoder = decoder_fedprox,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data_fedprox = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model (with fedprox) on testing data  %g' %err_test_data_fedprox)
print('-----')

Imputation MSE of fed model on testing data 0.772005
-----
Imputation MSE of fed model (with fedprox) on testing data  0.7994
-----


## 2. Testing on a client's dataset

We are now going to use the final federated model to impute missing data of client 1, which have been used for training:

In [32]:
# We first recover data (full and with missing entries) from client 1
n = Clients_data[0].shape[0] # number of observations
p = Clients_data[0].shape[1] # number of features

xfull_cl1 = np.copy(Clients_data[0])
xfull_cl1 = (xfull_cl1 - fed_mean.numpy())/fed_std.numpy()

xmiss_cl1 = np.copy(Clients_missing[0])
xmiss_cl1 = (xmiss_cl1 - fed_mean.numpy())/fed_std.numpy()
mask_cl1 = np.isfinite(xmiss_cl1) # binary mask that indicates which values are missing
xhat_0_cl1 = np.copy(xmiss_cl1)
xhat_0_cl1[np.isnan(xmiss_cl1)] = 0
xhat_cl1 = np.copy(xhat_0_cl1) # This will be out imputed data matrix

### Now we do the imputation

xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder,decoder = decoder, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model on data from client 1  %g' %err_cl1_data)
print('-----')

xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder_fedprox,decoder = decoder_fedprox, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data_fedprox = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model (with fedprox) on data from client 1  %g' %err_cl1_data_fedprox)
print('-----')

Imputation MSE of fed model on data from client 1  0.555779
-----
Imputation MSE of fed model (with fedprox) on data from client 1  0.544626
-----


## 3. Testing of FedProx model with local standardization

We are going to test the federated model with FedProx, where data standardization is performed locally. In order to be as much coherent as possible, each time the standardization will be realized locally as well in the testing phase

In [33]:
# We recover the model
model_fedprox_std_local = exp_fedprox_std_local.model_instance()
model_fedprox_std_local.load_state_dict(exp_fedprox_std_local.aggregated_params()[rounds - 1]['params'])

encoder_fedprox_std_local = model_fedprox_std_local.encoder
decoder_fedprox_std_local = model_fedprox_std_local.decoder

# We do the imputation on the test data
n = data_test.shape[0] # number of observations
p = data_test.shape[1] # number of features
xhat = np.copy(xhat_test)
mask = np.copy(mask_test)
xhat_0 = np.copy(xhat_0_test)
xfull = np.copy(xfull_test)

xhat[~mask] = miwae_impute(encoder = encoder_fedprox_std_local,decoder = decoder_fedprox_std_local,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_test_data_fedprox_std_local = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of fed model (with fedprox and local standardization) on testing data  %g' %err_test_data_fedprox_std_local)
print('-----')

# Same for the dataset from client 1. In this case the dataset standardization is done with respect to his own data
n = Clients_data[0].shape[0] # number of observations
p = Clients_data[0].shape[1] # number of features

xmiss_cl1 = np.copy(Clients_missing[0])
mean_cl1_missing = np.nanmean(xmiss_cl1,0)
std_cl1_missing = np.nanstd(xmiss_cl1,0)
xmiss_cl1 = (xmiss_cl1 - mean_cl1_missing)/std_cl1_missing
mask_cl1 = np.isfinite(xmiss_cl1) # binary mask that indicates which values are missing
xhat_0_cl1 = np.copy(xmiss_cl1)
xhat_0_cl1[np.isnan(xmiss_cl1)] = 0
xhat_cl1 = np.copy(xhat_0_cl1) # This will be out imputed data matrix

xfull_cl1 = np.copy(Clients_data[0])
xfull_cl1 = (xfull_cl1 - mean_cl1_missing)/std_cl1_missing

xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder_fedprox_std_local,decoder = decoder_fedprox_std_local, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_cl1_data_fedprox_std_local = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of fed model (with fedprox and local standardization) on data from client 1  %g' %err_cl1_data_fedprox_std_local)
print('-----')

Imputation MSE of fed model (with fedprox and local standardization) on testing data  0.77014
-----
Imputation MSE of fed model (with fedprox and local standardization) on data from client 1  0.768991
-----


## 4. Local training and testing on a client

Finally, we test the performance of the same model trained locally and tested on the dataset from client 1. We will use a total of `epochs`x`rounds` local epochs.

In [41]:
def miwae_loss(encoder,decoder,iota_x,mask,d):
    
    p_z = td.Independent(td.Normal(loc=torch.zeros(d),scale=torch.ones(d)),1)
    
    batch_size = iota_x.shape[0]
    out_encoder = encoder(iota_x)
    q_zgivenxobs = td.Independent(td.Normal(loc=out_encoder[..., :d],scale=torch.nn.Softplus()(out_encoder[..., d:(2*d)])),1)

    zgivenx = q_zgivenxobs.rsample([K])
    zgivenx_flat = zgivenx.reshape([K*batch_size,d])

    out_decoder = decoder(zgivenx_flat)
    all_means_obs_model = out_decoder[..., :p]
    all_scales_obs_model = torch.nn.Softplus()(out_decoder[..., p:(2*p)]) + 0.001
    all_degfreedom_obs_model = torch.nn.Softplus()(out_decoder[..., (2*p):(3*p)]) + 3

    data_flat = torch.Tensor.repeat(iota_x,[K,1]).reshape([-1,1])
    tiledmask = torch.Tensor.repeat(mask,[K,1])

    all_log_pxgivenz_flat = torch.distributions.StudentT(loc=all_means_obs_model.reshape([-1,1]),scale=all_scales_obs_model.reshape([-1,1]),df=all_degfreedom_obs_model.reshape([-1,1])).log_prob(data_flat)
    all_log_pxgivenz = all_log_pxgivenz_flat.reshape([K*batch_size,p])

    logpxobsgivenz = torch.sum(all_log_pxgivenz*tiledmask,1).reshape([K,batch_size])
    logpz = p_z.log_prob(zgivenx)
    logq = q_zgivenxobs.log_prob(zgivenx)

    neg_bound = -torch.mean(torch.logsumexp(logpxobsgivenz + logpz - logq,0))

    return neg_bound

We perform the local training:

In [43]:
# Recall all hyperparameters

n_epochs_local = n_epochs*rounds

bs = training_args.get('batch_size')
lr = training_args.get('lr')

n = xfull_cl1.shape[0] # number of observations
p = xfull_cl1.shape[1] # number of features

h = model_args.get('n_hidden') 
d = model_args.get('n_latent') 
K = model_args.get('n_samples') 

# Data

n = Clients_data[0].shape[0] # number of observations
p = Clients_data[0].shape[1] # number of features

xmiss_cl1 = np.copy(Clients_missing[0])
mean_cl1_missing = np.nanmean(xmiss_cl1,0)
std_cl1_missing = np.nanstd(xmiss_cl1,0)
xmiss_cl1 = (xmiss_cl1 - mean_cl1_missing)/std_cl1_missing
mask_cl1 = np.isfinite(xmiss_cl1) # binary mask that indicates which values are missing
xhat_0_cl1 = np.copy(xmiss_cl1)
xhat_0_cl1[np.isnan(xmiss_cl1)] = 0
xhat_cl1 = np.copy(xhat_0_cl1) # This will be out imputed data matrix

xfull_cl1 = np.copy(Clients_data[0])
xfull_cl1 = (xfull_cl1 - mean_cl1_missing)/std_cl1_missing

encoder_cl1 = nn.Sequential(
    torch.nn.Linear(p, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 2*d),  # the encoder will output both the mean and the diagonal covariance
)

decoder_cl1 = nn.Sequential(
    torch.nn.Linear(d, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 3*p),  # the decoder will output both the mean, the scale, and the number of degrees of freedoms (hence the 3*p)
)

optimizer_cl1 = torch.optim.Adam(list(encoder_cl1.parameters()) + list(decoder_cl1.parameters()),lr=1e-3)

def weights_init(layer):
    if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
        
encoder_cl1.apply(weights_init)
decoder_cl1.apply(weights_init)

for ep in range(1,n_epochs_local):
    perm = np.random.permutation(n) # We use the "random reshuffling" version of SGD
    batches_data = np.array_split(xhat_0_cl1[perm,], n/bs)
    batches_mask = np.array_split(mask_cl1[perm,], n/bs)
    for it in range(len(batches_data)):
        optimizer_cl1.zero_grad()
        encoder_cl1.zero_grad()
        decoder_cl1.zero_grad()
        b_data = torch.from_numpy(batches_data[it]).float()
        b_mask = torch.from_numpy(batches_mask[it]).float()
        loss = miwae_loss(encoder = encoder_cl1,decoder = decoder_cl1, iota_x = b_data,mask = b_mask, d = d)
        loss.backward()
        optimizer_cl1.step()
    if ep % rounds == 1:
        print('Epoch %g' %ep)
        print('MIWAE likelihood bound  %g' %(-np.log(K)-miwae_loss(encoder = encoder_cl1,decoder = decoder_cl1, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(), d = d).cpu().data.numpy())) # Gradient step      

Epoch 1
MIWAE likelihood bound  -9.09662
Epoch 51
MIWAE likelihood bound  -3.76211
Epoch 101
MIWAE likelihood bound  -3.02126
Epoch 151
MIWAE likelihood bound  -2.67265
Epoch 201
MIWAE likelihood bound  -1.76575


And we do the imputation on the same dataset:

In [44]:
xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder_cl1, decoder = decoder_cl1, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_local_cl1_data = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of local model on data from same client (cl 1)  %g' %err_local_cl1_data)
print('-----')

Imputation MSE of local model on data from same client (cl 1)  0.807842
-----


As well as the imputation on the external test dataset:

In [50]:
n = data_test.shape[0] # number of observations
p = data_test.shape[1] # number of features
xhat = np.copy(xhat_test)
mask = np.copy(mask_test)
xhat_0 = np.copy(xhat_0_test)
xfull = np.copy(xfull_test)

xhat[~mask] = miwae_impute(encoder = encoder_cl1,decoder = decoder_cl1,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_local_cl1_test_data = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of local model on testing data %g' %err_local_cl1_test_data)
print('-----')

Imputation MSE of local model on testing data 0.889466
-----


## 6. Local training on centralized data:

We centralized data from all clients, and perform the local training, hence test results on the external testing dataset as well as data from client 1, as we deed for the others models.

In [48]:
# Epochs

n_epochs_local = n_epochs*rounds*len(Clients_missing)

# Data

xmiss_tot = np.concatenate(Clients_missing,axis=0)

n = xmiss_tot.shape[0] # number of observations
p = xmiss_tot.shape[1] # number of features

mean_tot_missing = np.nanmean(xmiss_tot,0)
std_tot_missing = np.nanstd(xmiss_tot,0)
xmiss_tot = (xmiss_tot - mean_tot_missing)/std_tot_missing
mask_tot = np.isfinite(xmiss_tot) # binary mask that indicates which values are missing
xhat_0_tot = np.copy(xmiss_tot)
xhat_0_tot[np.isnan(xmiss_tot)] = 0
xhat_tot = np.copy(xhat_0_tot) # This will be out imputed data matrix

xfull_tot = np.concatenate(Clients_data,axis=0)
xfull_tot = (xfull_tot - mean_tot_missing)/std_tot_missing

# Model

encoder_tot = nn.Sequential(
    torch.nn.Linear(p, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 2*d),  # the encoder will output both the mean and the diagonal covariance
)

decoder_tot = nn.Sequential(
    torch.nn.Linear(d, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, h),
    torch.nn.ReLU(),
    torch.nn.Linear(h, 3*p),  # the decoder will output both the mean, the scale, and the number of degrees of freedoms (hence the 3*p)
)

optimizer_tot = torch.optim.Adam(list(encoder_tot.parameters()) + list(decoder_tot.parameters()),lr=1e-3)

def weights_init(layer):
    if type(layer) == nn.Linear: torch.nn.init.orthogonal_(layer.weight)
        
encoder_tot.apply(weights_init)
decoder_tot.apply(weights_init)

# Training loop

for ep in range(1,n_epochs_local):
    perm = np.random.permutation(n) # We use the "random reshuffling" version of SGD
    batches_data = np.array_split(xhat_0_tot[perm,], n/bs)
    batches_mask = np.array_split(mask_tot[perm,], n/bs)
    for it in range(len(batches_data)):
        optimizer_tot.zero_grad()
        encoder_tot.zero_grad()
        decoder_tot.zero_grad()
        b_data = torch.from_numpy(batches_data[it]).float()
        b_mask = torch.from_numpy(batches_mask[it]).float()
        loss = miwae_loss(encoder = encoder_tot,decoder = decoder_tot, iota_x = b_data,mask = b_mask, d = d)
        loss.backward()
        optimizer_tot.step()
    if ep % rounds == 1:
        print('Epoch %g' %ep)
        print('MIWAE likelihood bound  %g' %(-np.log(K)-miwae_loss(encoder = encoder_tot,decoder = decoder_tot, iota_x = torch.from_numpy(xhat_0_tot).float(),mask = torch.from_numpy(mask_tot).float(), d = d).cpu().data.numpy())) # Gradient step      

Epoch 1
MIWAE likelihood bound  -9.29921
Epoch 51
MIWAE likelihood bound  -3.01712
Epoch 101
MIWAE likelihood bound  -2.26313
Epoch 151
MIWAE likelihood bound  -1.73027
Epoch 201
MIWAE likelihood bound  -1.21782
Epoch 251
MIWAE likelihood bound  -0.0541024
Epoch 301
MIWAE likelihood bound  0.435457
Epoch 351
MIWAE likelihood bound  1.02151
Epoch 401
MIWAE likelihood bound  0.92007
Epoch 451
MIWAE likelihood bound  0.783818
Epoch 501
MIWAE likelihood bound  1.40394
Epoch 551
MIWAE likelihood bound  1.28166
Epoch 601
MIWAE likelihood bound  2.22676
Epoch 651
MIWAE likelihood bound  2.30973
Epoch 701
MIWAE likelihood bound  2.1331


In [51]:
n = Clients_data[0].shape[0] # number of observations
p = Clients_data[0].shape[1] # number of features

xhat_cl1[~mask_cl1] = miwae_impute(encoder = encoder_tot, decoder = decoder_tot, iota_x = torch.from_numpy(xhat_0_cl1).float(),mask = torch.from_numpy(mask_cl1).float(),d = d,L= L).cpu().data.numpy()[~mask_cl1]
err_local_tot_cl1_data = np.array([mse(xhat_cl1,xfull_cl1,mask_cl1)])
print('Imputation MSE of local model on the whole dataset, on data from same client 1  %g' %err_local_tot_cl1_data)
print('-----')

n = data_test.shape[0] # number of observations
p = data_test.shape[1] # number of features
xhat = np.copy(xhat_test)
mask = np.copy(mask_test)
xhat_0 = np.copy(xhat_0_test)
xfull = np.copy(xfull_test)

xhat[~mask] = miwae_impute(encoder = encoder_tot,decoder = decoder_tot,iota_x = torch.from_numpy(xhat_0).float(),mask = torch.from_numpy(mask).float(),d = d,L= L).cpu().data.numpy()[~mask]
err_local_tot_cl1_test_data = np.array([mse(xhat,xfull,mask)])
print('Imputation MSE of local model on the whole dataset, on testing data %g' %err_local_tot_cl1_test_data)
print('-----')

Imputation MSE of local model on the whole dataset, on data from same client 1  0.797672
-----
Imputation MSE of local model on the whole dataset, on testing data 0.871439
-----


## 6. Summary of obtained results:

In [53]:
from tabulate import tabulate

print('Imputation MSE on testing data')
print('-----')
data = [['FedAvg', err_test_data],
['FedProx', err_test_data_fedprox],
['FedLocStd', err_test_data_fedprox_std_local],
['Local (cl1)', err_local_cl1_test_data],
['Centralized', err_local_tot_cl1_test_data]]
print (tabulate(data, headers=["Model", "Mean Squared Error (\u2193)"]))
print('-----')
print('-----')
print('Imputation MSE on local data from client 1')
print('-----')
data = [['FedAvg', err_cl1_data],
['FedProx', err_cl1_data_fedprox],
['FedLocStd', err_cl1_data_fedprox_std_local],
['Local (cl1)', err_local_cl1_data],
['Centralized', err_local_tot_cl1_test_data]]
print (tabulate(data, headers=["Model", "Mean Squared Error (\u2193)"]))

Imputation MSE on testing data
-----
Model          Mean Squared Error (↓)
-----------  ------------------------
FedAvg                       0.772005
FedProx                      0.7994
FedLocStd                    0.77014
Local (cl1)                  0.889466
Centralized                  0.871439
-----
-----
Imputation MSE on local data from client 1
-----
Model          Mean Squared Error (↓)
-----------  ------------------------
FedAvg                       0.555779
FedProx                      0.544626
FedLocStd                    0.768991
Local (cl1)                  0.807842
Centralized                  0.871439
