## Importing libraries

In [23]:
import pypots
import os
import sys
from pypots.utils.metrics import calc_mae
from pypots.optim import Adam
from pypots.imputation import SAITS, BRITS, USGAN
import numpy as np
import benchpots
from pypots.utils.random import set_random_seed
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import pandas as pd    
from pypotsModify.benchpotsMAE.datasets import preprocess_physionet2012

## Loading database

In [3]:
set_random_seed()

# Load the PhysioNet-2012 dataset
physionet2012_dataset = preprocess_physionet2012(subset="all", rate=0.1)

# Take a look at the generated PhysioNet-2012 dataset, you'll find that everything has been prepared for you,
# data splitting, normalization, additional artificially-missing values for evaluation, etc.
print(physionet2012_dataset.keys())

2025-02-05 14:26:23 [INFO]: Have set the random seed as 2022 for numpy and pytorch.
2025-02-05 14:26:23 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link: 
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012
2025-02-05 14:26:23 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...
2025-02-05 14:26:23 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...
2025-02-05 14:26:23 [INFO]: Loaded successfully!
2025-02-05 14:26:39 [INFO]: 68807 values masked out in the val set as ground truth, take 9.97% of the original observed values
2025-02-05 14:26:39 [INFO]: 68807 values masked out in the val set as ground truth, take 9.97% of the original observed values
2025-02-05 14:26:39 [INFO]: 86319 values masked out in the test set as ground truth, take 9.99% of the original observed values
2025-02-05 14:26:39 [INFO]: 86319 valu

dict_keys(['n_classes', 'n_steps', 'n_features', 'scaler', 'train_X', 'train_y', 'train_ICUType', 'val_X', 'val_y', 'val_ICUType', 'test_X', 'test_y', 'test_ICUType', 'female_gender_test_X', 'female_gender_test_y', 'test_ICUType_female_gender', 'male_gender_test_X', 'male_gender_test_y', 'test_ICUType_male_gender', 'undefined_gender_test_X', 'undefined_gender_test_y', 'test_ICUType_undefined_gender', 'more_than_or_equal_to_65_test_X', 'more_than_or_equal_to_65_test_y', 'test_ICUType_more_than_or_equal_to_65', 'less_than_65_test_X', 'less_than_65_test_y', 'test_ICUType_less_than_65', 'ICUType_1_test_X', 'ICUType_1_test_y', 'test_ICUType_1', 'ICUType_2_test_X', 'ICUType_2_test_y', 'test_ICUType_2', 'ICUType_3_test_X', 'ICUType_3_test_y', 'test_ICUType_3', 'ICUType_4_test_X', 'ICUType_4_test_y', 'test_ICUType_4', 'classificacao_undefined_test_X', 'classificacao_undefined_test_y', 'test_ICUType_classificacao_undefined', 'classificacao_baixo_peso_test_X', 'classificacao_baixo_peso_test_y', 

In [4]:
# assemble the datasets for training
dataset_for_training = {
    "X": physionet2012_dataset['train_X'],
}
# assemble the datasets for validation
dataset_for_validating = {
    "X": physionet2012_dataset['val_X'],
    "X_ori": physionet2012_dataset['val_X_ori'],
}

dataset_for_testing_ori = {
    "X_ori": physionet2012_dataset['test_X_ori'],
    "female_gender_test_X_ori": physionet2012_dataset['female_gender_test_X_ori'],
    "male_gender_test_X_ori": physionet2012_dataset['male_gender_test_X_ori'],
    "undefined_gender_test_X_ori": physionet2012_dataset['undefined_gender_test_X_ori'],
    "more_than_or_equal_to_65_test_X_ori":  physionet2012_dataset['more_than_or_equal_to_65_test_X_ori'],
    "less_than_65_test_X_ori": physionet2012_dataset['less_than_65_test_X_ori'],
    "ICUType_1_test_X_ori": physionet2012_dataset['ICUType_1_test_X_ori'],
    "ICUType_2_test_X_ori": physionet2012_dataset['ICUType_2_test_X_ori'],
    "ICUType_3_test_X_ori": physionet2012_dataset['ICUType_3_test_X_ori'],
    "ICUType_4_test_X_ori": physionet2012_dataset['ICUType_4_test_X_ori'],
    "classificacao_undefined_test_X_ori": physionet2012_dataset['classificacao_undefined_test_X_ori'],
    "classificacao_baixo_peso_test_X_ori": physionet2012_dataset['classificacao_baixo_peso_test_X_ori'],
    "classificacao_normal_peso_test_X_ori": physionet2012_dataset['classificacao_normal_peso_test_X_ori'],
    "classificacao_sobrepeso_test_X_ori": physionet2012_dataset['classificacao_sobrepeso_test_X_ori'],
    "classificacao_obesidade_1_test_X_ori": physionet2012_dataset['classificacao_obesidade_1_test_X_ori'],
    "classificacao_obesidade_2_test_X_ori": physionet2012_dataset['classificacao_obesidade_2_test_X_ori'],
    "classificacao_obesidade_3_test_X_ori": physionet2012_dataset['classificacao_obesidade_3_test_X_ori']
}

# assemble the datasets for test
dataset_for_testing = {
    "X": physionet2012_dataset['test_X'],
    "female_gender_test_X": physionet2012_dataset['female_gender_test_X'],
    "male_gender_test_X": physionet2012_dataset['male_gender_test_X'],
    "undefined_gender_test_X": physionet2012_dataset['undefined_gender_test_X'],
    "more_than_or_equal_to_65_test_X":  physionet2012_dataset['more_than_or_equal_to_65_test_X'],
    "less_than_65_test_X": physionet2012_dataset['less_than_65_test_X'],
    "ICUType_1_test_X": physionet2012_dataset['ICUType_1_test_X'],
    "ICUType_2_test_X": physionet2012_dataset['ICUType_2_test_X'],
    "ICUType_3_test_X": physionet2012_dataset['ICUType_3_test_X'],
    "ICUType_4_test_X": physionet2012_dataset['ICUType_4_test_X'],
    "classificacao_undefined_test_X": physionet2012_dataset['classificacao_undefined_test_X'],
    "classificacao_baixo_peso_test_X": physionet2012_dataset['classificacao_baixo_peso_test_X'],
    "classificacao_normal_peso_test_X": physionet2012_dataset['classificacao_normal_peso_test_X'],
    "classificacao_sobrepeso_test_X": physionet2012_dataset['classificacao_sobrepeso_test_X'],
    "classificacao_obesidade_1_test_X": physionet2012_dataset['classificacao_obesidade_1_test_X'],
    "classificacao_obesidade_2_test_X": physionet2012_dataset['classificacao_obesidade_2_test_X'],
    "classificacao_obesidade_3_test_X": physionet2012_dataset['classificacao_obesidade_3_test_X']
}

## calculate the mask to indicate the ground truth positions in test_X_ori, will be used by metric funcs to evaluate models
test_X_indicating_mask = []
test_X_ori = []
for i, j in zip(dataset_for_testing_ori.values(), dataset_for_testing.values()):
    test_X_indicating_mask.append(np.isnan(i) ^ np.isnan(j))
    test_X_ori.append(np.nan_to_num(i))# metric functions do not accpet input with NaNs, hence fill NaNs with 0


In [5]:
test_X_indicating_mask_variable = []
test_X_ori_variable = []
for i in range(len(test_X_indicating_mask)):
    test_X_indicating_mask_variable.append(test_X_indicating_mask[i].reshape(37, len(test_X_indicating_mask[i]) * 48))
    test_X_ori_variable.append(test_X_ori[i].reshape(37, len(test_X_ori[i]) * 48))

## Inicialize the models

#### SAITS

In [5]:
saits = SAITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    n_layers=1,
    d_model=256,
    d_ffn=128,
    n_heads=4,
    d_k=64,
    d_v=64,
    dropout=0.1,
    ORT_weight=1,  # you can adjust the weight values of arguments ORT_weight
    # and MIT_weight to make the SAITS model focus more on one task. Usually you can just leave them to the default values, i.e. 1.
    MIT_weight=1,
    batch_size=32,
    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance
    epochs=10,
    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.
    # You can leave it to defualt as None to disable early stopping.
    patience=3,
    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when
    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.
    optimizer=Adam(lr=1e-3),
    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.
    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.
    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed
    num_workers=0,
    # just leave it to default as None, PyPOTS will automatically assign the best device for you.
    # Set it as 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices, even parallelly on ['cuda:0', 'cuda:1']
    device=None,
    # set the path for saving tensorboard and trained model files
    saving_path="tutorial_results/imputation/saits",
    # only save the best model after training finished.
    # You can also set it as "better" to save models performing better ever during training.
    model_saving_strategy="best",
)

2025-02-04 16:05:56 [INFO]: No given device, using default device: cuda
2025-02-04 16:05:56 [INFO]: Model files will be saved to tutorial_results/imputation/saits/20250204_T160556
2025-02-04 16:05:56 [INFO]: Tensorboard file will be saved to tutorial_results/imputation/saits/20250204_T160556/tensorboard
2025-02-04 16:05:56 [INFO]: SAITS initialized with the given hyperparameters, the number of trainable parameters: 720,182


In [6]:
saits = SAITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    n_layers=1,
    d_model=256,
    d_ffn=128,
    n_heads=4,
    d_k=64,
    d_v=64,
    dropout=0.1,
    ORT_weight=1,  # you can adjust the weight values of arguments ORT_weight
    # and MIT_weight to make the SAITS model focus more on one task. Usually you can just leave them to the default values, i.e. 1.
    MIT_weight=1,
    batch_size=32,
    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance
    epochs=10,
    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.
    # You can leave it to defualt as None to disable early stopping.
    patience=3,
    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when
    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.
    optimizer=Adam(lr=1e-3),
    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.
    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.
    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed
    num_workers=0,
    # just leave it to default as None, PyPOTS will automatically assign the best device for you.
    # Set it as 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices, even parallelly on ['cuda:0', 'cuda:1']
    device=None,
    # set the path for saving tensorboard and trained model files
    # only save the best model after training finished.
    # You can also set it as "better" to save models performing better ever during training.
    model_saving_strategy="best",
)

2025-02-05 14:28:28 [INFO]: No given device, using default device: cuda
2025-02-05 14:28:28 [INFO]: SAITS initialized with the given hyperparameters, the number of trainable parameters: 720,182


#### BRITS

In [29]:
brits = BRITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    rnn_hidden_size=128,
    batch_size=32,
    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance
    epochs=10,
    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.
    # You can leave it to defualt as None to disable early stopping.
    patience=3,
    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when
    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.
    optimizer=Adam(lr=1e-3),
    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.
    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.
    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed
    num_workers=0,
    # just leave it to default as None, PyPOTS will automatically assign the best device for you.
    # Set it as 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices, even parallelly on ['cuda:0', 'cuda:1']
    device=None,
    # set the path for saving tensorboard and trained model files
    saving_path="tutorial_results/imputation/brits",
    # only save the best model after training finished.
    # You can also set it as "better" to save models performing better ever during training.
    model_saving_strategy="best",
)

2025-02-04 16:21:08 [INFO]: No given device, using default device: cuda
2025-02-04 16:21:08 [INFO]: Model files will be saved to tutorial_results/imputation/brits/20250204_T162108
2025-02-04 16:21:08 [INFO]: Tensorboard file will be saved to tutorial_results/imputation/brits/20250204_T162108/tensorboard
2025-02-04 16:21:08 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 239,344


In [8]:
brits = BRITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    rnn_hidden_size=128,
    batch_size=32,
    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance
    epochs=10,
    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.
    # You can leave it to defualt as None to disable early stopping.
    patience=3,
    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when
    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.
    optimizer=Adam(lr=1e-3),
    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.
    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.
    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed
    num_workers=0,
    # just leave it to default as None, PyPOTS will automatically assign the best device for you.
    # Set it as 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices, even parallelly on ['cuda:0', 'cuda:1']
    device=None,
    # set the path for saving tensorboard and trained model files
    # only save the best model after training finished.
    # You can also set it as "better" to save models performing better ever during training.
    model_saving_strategy="best",
)

2025-02-05 14:30:11 [INFO]: No given device, using default device: cuda
2025-02-05 14:30:11 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 239,344


#### USGAN

In [24]:
us_gan = USGAN(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    rnn_hidden_size=256,
    lambda_mse=1,
    dropout=0.1,
    G_steps=1,
    D_steps=1,
    batch_size=32,
    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance
    epochs=10,
    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.
    # You can leave it to defualt as None to disable early stopping.
    patience=3,
    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when
    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.
    G_optimizer=Adam(lr=1e-3),
    D_optimizer=Adam(lr=1e-3),
    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.
    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.
    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed
    num_workers=0,
    # just leave it to default as None, PyPOTS will automatically assign the best device for you.
    # Set it as 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices, even parallelly on ['cuda:0', 'cuda:1']
    device=None,
    # set the path for saving tensorboard and trained model files
    saving_path="tutorial_results/imputation/us_gan",
    # only save the best model after training finished.
    # You can also set it as "better" to save models performing better ever during training.
    model_saving_strategy="best",
)

2025-02-05 14:47:07 [INFO]: No given device, using default device: cuda
2025-02-05 14:47:07 [INFO]: Model files will be saved to tutorial_results/imputation/us_gan/20250205_T144707
2025-02-05 14:47:07 [INFO]: Tensorboard file will be saved to tutorial_results/imputation/us_gan/20250205_T144707/tensorboard
2025-02-05 14:47:07 [INFO]: USGAN initialized with the given hyperparameters, the number of trainable parameters: 1,258,517


In [None]:
us_gan = USGAN(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    rnn_hidden_size=256,
    lambda_mse=1,
    dropout=0.1,
    G_steps=1,
    D_steps=1,
    batch_size=32,
    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance
    epochs=10,
    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.
    # You can leave it to defualt as None to disable early stopping.
    patience=3,
    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when
    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.
    G_optimizer=Adam(lr=1e-3),
    D_optimizer=Adam(lr=1e-3),
    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.
    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.
    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed
    num_workers=0,
    # just leave it to default as None, PyPOTS will automatically assign the best device for you.
    # Set it as 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices, even parallelly on ['cuda:0', 'cuda:1']
    device=None,
    # set the path for saving tensorboard and trained model files
    # only save the best model after training finished.
    # You can also set it as "better" to save models performing better ever during training.
    model_saving_strategy="best",
)

In [None]:
us_gan.load("tutorial_results/imputation/us_gan/20250205_T144707/USGAN.pypots")

## Train/Load the models

#### SAITS

In [6]:
# train the model on the training set, and validate it on the validating set to select the best model for testing in the next step
saits.fit(train_set=dataset_for_training, val_set=dataset_for_validating)

2025-02-04 16:06:11 [INFO]: Epoch 001 - training loss: 0.7230, validation loss: 6.7596
2025-02-04 16:06:11 [INFO]: Saved the model to tutorial_results/imputation/saits/20250204_T160556/SAITS_epoch1_loss6.75964179734389.pypots
2025-02-04 16:06:16 [INFO]: Epoch 002 - training loss: 0.5370, validation loss: 6.7420
2025-02-04 16:06:16 [INFO]: Saved the model to tutorial_results/imputation/saits/20250204_T160556/SAITS_epoch2_loss6.741982015470664.pypots
2025-02-04 16:06:22 [INFO]: Epoch 003 - training loss: 0.4953, validation loss: 6.7242
2025-02-04 16:06:22 [INFO]: Saved the model to tutorial_results/imputation/saits/20250204_T160556/SAITS_epoch3_loss6.724190835654736.pypots
2025-02-04 16:06:28 [INFO]: Epoch 004 - training loss: 0.4641, validation loss: 6.7040
2025-02-04 16:06:28 [INFO]: Saved the model to tutorial_results/imputation/saits/20250204_T160556/SAITS_epoch4_loss6.703969217836857.pypots
2025-02-04 16:06:34 [INFO]: Epoch 005 - training loss: 0.4387, validation loss: 6.6843
2025-0

In [7]:
saits.load("tutorial_results/imputation/saits/20250204_T160556/SAITS.pypots")

  loaded_model = torch.load(path, map_location=self.device)
2025-02-05 14:28:53 [INFO]: Model loaded successfully from tutorial_results/imputation/saits/20250204_T160556/SAITS.pypots


#### BRITS

In [30]:
# train the model on the training set, and validate it on the validating set to select the best model for testing in the next step
brits.fit(train_set=dataset_for_training, val_set=dataset_for_validating)

2025-02-04 16:23:39 [INFO]: Epoch 001 - training loss: 0.9443, validation loss: 6.7825
2025-02-04 16:23:39 [INFO]: Saved the model to tutorial_results/imputation/brits/20250204_T162108/BRITS_epoch1_loss6.782457655668258.pypots
2025-02-04 16:25:32 [INFO]: Epoch 002 - training loss: 0.7357, validation loss: 6.7422
2025-02-04 16:25:32 [INFO]: Saved the model to tutorial_results/imputation/brits/20250204_T162108/BRITS_epoch2_loss6.742190875361363.pypots
2025-02-04 16:27:24 [INFO]: Epoch 003 - training loss: 0.6829, validation loss: 6.7320
2025-02-04 16:27:24 [INFO]: Saved the model to tutorial_results/imputation/brits/20250204_T162108/BRITS_epoch3_loss6.73199325154225.pypots
2025-02-04 16:29:14 [INFO]: Epoch 004 - training loss: 0.6586, validation loss: 6.7264
2025-02-04 16:29:14 [INFO]: Saved the model to tutorial_results/imputation/brits/20250204_T162108/BRITS_epoch4_loss6.726365307718515.pypots
2025-02-04 16:31:08 [INFO]: Epoch 005 - training loss: 0.6433, validation loss: 6.7244
2025-0

In [9]:
brits.load("tutorial_results/imputation/brits/20250204_T162108/BRITS.pypots")

2025-02-05 14:38:01 [INFO]: Model loaded successfully from tutorial_results/imputation/brits/20250204_T162108/BRITS.pypots


#### USGAN

In [25]:
us_gan.fit(train_set=dataset_for_training, val_set=dataset_for_validating)

2025-02-05 14:52:52 [INFO]: Epoch 001 - generator training loss: 0.4299, discriminator training loss: 0.1869, validation loss: 6.7646
2025-02-05 14:52:52 [INFO]: Saved the model to tutorial_results/imputation/us_gan/20250205_T144707/USGAN_epoch1_loss6.764622490356365.pypots
2025-02-05 14:56:00 [INFO]: Epoch 002 - generator training loss: 0.3612, discriminator training loss: 0.0552, validation loss: 6.7262
2025-02-05 14:56:00 [INFO]: Saved the model to tutorial_results/imputation/us_gan/20250205_T144707/USGAN_epoch2_loss6.7262225342293585.pypots
2025-02-05 14:59:29 [INFO]: Epoch 003 - generator training loss: 0.3365, discriminator training loss: 0.0370, validation loss: 6.6997
2025-02-05 14:59:29 [INFO]: Saved the model to tutorial_results/imputation/us_gan/20250205_T144707/USGAN_epoch3_loss6.699678348998229.pypots
2025-02-05 15:02:31 [INFO]: Epoch 004 - generator training loss: 0.3192, discriminator training loss: 0.0312, validation loss: 6.6862
2025-02-05 15:02:31 [INFO]: Saved the mo

## Testing stage

#### SAITS

In [10]:
saits_imputation = []
for value in  dataset_for_testing.values():
   _dict = {'X':value}
   saits_results = saits.predict(_dict)
   saits_imputation.append(saits_results["imputation"])


In [11]:
saits_imputation_variable = []
for i in range(len(saits_imputation)):
    saits_imputation_variable.append(saits_imputation[i].reshape(37, len(saits_imputation[i]) * 48))


#### BRITS

In [12]:
brits_imputation = []
for value in dataset_for_testing.values():
    _dict = {'X':value}
    brits_results = brits.predict(_dict)
    brits_imputation.append(brits_results["imputation"])

In [13]:
brits_imputation_variable = []
for i in range(len(brits_imputation)):
    brits_imputation_variable.append(brits_imputation[i].reshape(37, len(brits_imputation[i]) * 48))

#### USGAN

In [28]:
usgan_imputation = []
for value in  dataset_for_testing.values():
   _dict = {'X':value}
   usgan_results = us_gan.predict(_dict)
   usgan_imputation.append(usgan_results["imputation"])

In [29]:
usgan_imputation_variable = []
for i in range(len(usgan_imputation)):
    usgan_imputation_variable.append(usgan_imputation[i].reshape(37, len(usgan_imputation[i]) * 48))

## Calculate mean absolute error

#### SAITS

In [14]:
testing_mae_saits_append_subgroups = []
testing_mae_saits_append_variables = []
for i in range(len(saits_imputation_variable)):
    for j in range(len(saits_imputation_variable[i])):
        testing_mae_saits_append_variables.append(calc_mae(saits_imputation_variable[i][j], test_X_ori_variable[i][j], test_X_indicating_mask_variable[i][j]))
    testing_mae_saits_append_subgroups.append(testing_mae_saits_append_variables)
    testing_mae_saits_append_variables = [] 

#### BRITS

In [15]:
testing_mae_brits_append_subgroups = []
testing_mae_brits_append_variables = []
for i in range(len(brits_imputation_variable)):
    for j in range(len(brits_imputation_variable[i])):
      testing_mae_brits_append_variables.append(calc_mae(brits_imputation_variable[i][j], test_X_ori_variable[i][j], test_X_indicating_mask_variable[i][j]))
    testing_mae_brits_append_subgroups.append(testing_mae_brits_append_variables)
    testing_mae_brits_append_variables = []

#### USGAN

In [30]:
testing_mae_usgan_append_subgroups = []
testing_mae_usgan_append_variables = []
for i in range(len(usgan_imputation_variable)):
    for j in range(len(usgan_imputation_variable[i])):
        testing_mae_usgan_append_variables.append(calc_mae(usgan_imputation_variable[i][j], test_X_ori_variable[i][j], test_X_indicating_mask_variable[i][j]))
    testing_mae_usgan_append_subgroups.append(testing_mae_usgan_append_variables)
    testing_mae_usgan_append_variables = []

## Results mean absolute error

In [16]:
subgroups = ["General", "Female", "Male", "Undefined Gender", "+65", "-65", "ICUType 1", "ICUType 2", "ICUType 3", "ICUType 4", "Undefined classification", "Low Weight", "Normal Weight", "Overweight", "Obesity 1", "Obesity 2", "Obesity 3"]

In [17]:
variables = ["ALP", "ALT", "AST", "Albumin", "BUN", "Bilirubin", "Cholesterol", "Creatinine", "DiasABP", "FiO2", "GCS", "Glucose", "HCO3", "HCT", "HR", "K", "Lactate", "MAP", "MechVent", "Mg", "NIDiasABP", "NIMAP", "NISysABP", "Na", "PaCO2", "PaO2", "Platelets", "RespRate", "SaO2", "SysABP", "Temp", "TroponinI", "TroponinT", "Urine", "WBC", "Weight", "Ph"]

#### SAITS

In [18]:
print("SAITS - MAE")
print("************")
for i in range(len(subgroups)): 
    print(subgroups[i]) 
    print("-------------")
    for j in range(len(variables)):
        print(variables[j], ":" ,testing_mae_saits_append_subgroups[i][j])

SAITS - MAE
************
General
-------------
ALP : 0.2460293682153394
ALT : 0.22043134772111558
AST : 0.26230819338030253
Albumin : 0.30293279432474296
BUN : 0.23659828366631006
Bilirubin : 0.2312914200797686
Cholesterol : 0.251679622012055
Creatinine : 0.2557353119481593
DiasABP : 0.23659807986738537
FiO2 : 0.2439170004762481
GCS : 0.27224084306168234
Glucose : 0.26067426800470134
HCO3 : 0.25306096736987144
HCT : 0.23665193821730493
HR : 0.22776729540049326
K : 0.23719489152605258
Lactate : 0.2550807744721406
MAP : 0.2596030041903261
MechVent : 0.2590311293507116
Mg : 0.2570448834883059
NIDiasABP : 0.2297923796966098
NIMAP : 0.23891888162436045
NISysABP : 0.25352442673224646
Na : 0.2576979141754115
PaCO2 : 0.22809380110830693
PaO2 : 0.2380665948777946
Platelets : 0.23528387027030165
RespRate : 0.25253001590961144
SaO2 : 0.255004959199687
SysABP : 0.2352274283737271
Temp : 0.24351372214867664
TroponinI : 0.24266481480201985
TroponinT : 0.2294270462157759
Urine : 0.24475188600873804
W

In [19]:
df_saits_mae = pd.DataFrame(variables)

for i in range(len(subgroups)):
    df_saits_mae[subgroups[i]] = testing_mae_saits_append_subgroups[i]

In [20]:
df_saits_mae

Unnamed: 0,0,General,Female,Male,Undefined Gender,+65,-65,ICUType 1,ICUType 2,ICUType 3,ICUType 4,Undefined classification,Low Weight,Normal Weight,Overweight,Obesity 1,Obesity 2,Obesity 3
0,ALP,0.246029,0.220452,0.281318,0.024085,0.254108,0.286183,0.291996,0.220202,0.261112,0.264622,0.292983,0.277895,0.247686,0.237087,0.291813,0.180335,0.239991
1,ALT,0.220431,0.231447,0.222804,0.0,0.254721,0.251318,0.278319,0.228516,0.228215,0.239234,0.248507,0.326892,0.210455,0.191214,0.277488,0.329962,0.228294
2,AST,0.262308,0.263395,0.237008,0.295356,0.248577,0.261737,0.274582,0.216824,0.246338,0.239527,0.266911,0.324861,0.244048,0.266117,0.21383,0.227629,0.178151
3,Albumin,0.302933,0.245631,0.308385,0.25696,0.225847,0.242713,0.218852,0.215799,0.271407,0.23835,0.241231,0.368676,0.253718,0.22256,0.287282,0.245201,0.181224
4,BUN,0.236598,0.224053,0.245069,0.038733,0.277838,0.235693,0.308771,0.204856,0.252615,0.247639,0.271124,0.281561,0.237951,0.256407,0.2469,0.205267,0.306612
5,Bilirubin,0.231291,0.250564,0.265393,0.098811,0.267808,0.258753,0.256349,0.205388,0.303023,0.255608,0.265551,0.239788,0.24481,0.222906,0.190586,0.242239,0.210118
6,Cholesterol,0.25168,0.231705,0.233628,0.169514,0.234425,0.274246,0.292118,0.201527,0.292166,0.258185,0.268551,0.178686,0.185244,0.216891,0.247719,0.277958,0.250856
7,Creatinine,0.255735,0.263072,0.283228,0.242628,0.258071,0.278139,0.214152,0.23179,0.269252,0.251534,0.281256,0.188572,0.222748,0.263322,0.231723,0.248706,0.18941
8,DiasABP,0.236598,0.265557,0.260386,0.12953,0.238133,0.36964,0.231949,0.207031,0.286422,0.259389,0.293405,0.250004,0.206029,0.202719,0.227379,0.176758,0.2079
9,FiO2,0.243917,0.263631,0.248451,0.118587,0.236096,0.279574,0.257063,0.218261,0.281995,0.23735,0.246835,0.178762,0.210799,0.248204,0.257538,0.231925,0.243773


#### BRITS

In [21]:
print("BRITS - MAE")
print("************")
for i in range(len(subgroups)):
    print(subgroups[i]) 
    print("-------------")
    for j in range(len(variables)):
        print(variables[j], ":" ,testing_mae_brits_append_subgroups[i][j])

BRITS - MAE
************
General
-------------
ALP : 0.2707367647804506
ALT : 0.23896646826300247
AST : 0.28878421876356886
Albumin : 0.32158155068506383
BUN : 0.24372880799652813
Bilirubin : 0.2516354143011555
Cholesterol : 0.25355475512462006
Creatinine : 0.2723656167137088
DiasABP : 0.25076286871595893
FiO2 : 0.2514528318158146
GCS : 0.5404476843954903
Glucose : 0.2736264937761537
HCO3 : 0.2620078276782997
HCT : 0.24734390750479204
HR : 0.24368038771195344
K : 0.24904497537352557
Lactate : 0.2648993516274657
MAP : 0.27131795054600477
MechVent : 0.26718994902053
Mg : 0.2740963009645579
NIDiasABP : 0.23742945499040968
NIMAP : 0.24410922965602125
NISysABP : 0.26102740091037013
Na : 0.27076051164665826
PaCO2 : 0.24434648335436876
PaO2 : 0.261024932337629
Platelets : 0.24876741391001614
RespRate : 0.25789694312491024
SaO2 : 0.2707833075331891
SysABP : 0.25657190476225933
Temp : 0.2649548352687018
TroponinI : 0.2501403631064177
TroponinT : 0.2427855355938498
Urine : 0.2717610269676055
WBC

In [22]:
df_brits_mae = pd.DataFrame(variables)

for i in range(len(subgroups)):
    df_brits_mae[subgroups[i]] = testing_mae_brits_append_subgroups[i]

#### USGAN 

In [31]:
print("USGAN - MAE")
print("************")
for i in range(len(subgroups)):
    print(subgroups[i]) 
    print("-------------")
    for j in range(len(variables)):
        print(variables[j], ":" ,testing_mae_usgan_append_subgroups[i][j])

USGAN - MAE
************
General
-------------
ALP : 0.28327919700514537
ALT : 0.25698387822352164
AST : 0.29381425780834014
Albumin : 0.3316851965822937
BUN : 0.2558014721066809
Bilirubin : 0.26569659912151444
Cholesterol : 0.2717877965417484
Creatinine : 0.2807011303352554
DiasABP : 0.25648366397544786
FiO2 : 0.26940385736786227
GCS : 0.30802102369019674
Glucose : 0.27290957048290815
HCO3 : 0.2796737059846553
HCT : 0.2611459327041813
HR : 0.25871885034518405
K : 0.2614191212323757
Lactate : 0.2785402883573093
MAP : 0.28843045647664245
MechVent : 0.28180003190404546
Mg : 0.2834200527338646
NIDiasABP : 0.2516281229311397
NIMAP : 0.26122135945979064
NISysABP : 0.27605309628950486
Na : 0.27928460154237206
PaCO2 : 0.25462611017019066
PaO2 : 0.27273692419304074
Platelets : 0.2615611081737651
RespRate : 0.26835562858972456
SaO2 : 0.2824451172805901
SysABP : 0.2731640090539184
Temp : 0.27050709716394056
TroponinI : 0.2706214236991907
TroponinT : 0.259756080398352
Urine : 0.28173089277097757


In [32]:
df_usgan_mae = pd.DataFrame(variables)

for i in range(len(subgroups)):
    df_usgan_mae[subgroups[i]] = testing_mae_usgan_append_subgroups[i]

In [33]:
df_usgan_mae

Unnamed: 0,0,General,Female,Male,Undefined Gender,+65,-65,ICUType 1,ICUType 2,ICUType 3,ICUType 4,Undefined classification,Low Weight,Normal Weight,Overweight,Obesity 1,Obesity 2,Obesity 3
0,ALP,0.283279,0.260204,0.292055,0.639126,0.27895,0.3112,0.324575,0.23286,0.284586,0.274382,0.3084,0.256726,0.27865,0.256221,0.302301,0.219735,0.282214
1,ALT,0.256984,0.266539,0.25154,0.0,0.297289,0.277807,0.285709,0.251689,0.266937,0.261526,0.291585,0.301074,0.233951,0.228284,0.302541,0.448541,0.250768
2,AST,0.293814,0.303223,0.269565,0.384212,0.279343,0.280914,0.302549,0.236463,0.29215,0.263452,0.298456,0.373478,0.268835,0.295488,0.255971,0.239183,0.225336
3,Albumin,0.331685,0.281944,0.324433,0.283477,0.256423,0.277574,0.263155,0.21553,0.296938,0.245375,0.260768,0.352473,0.268574,0.237813,0.322414,0.23258,0.216323
4,BUN,0.255801,0.252679,0.267172,0.152919,0.299424,0.271188,0.310056,0.249563,0.286176,0.262668,0.302922,0.35031,0.263515,0.274829,0.271573,0.211423,0.304615
5,Bilirubin,0.265697,0.279324,0.293216,0.230283,0.285931,0.266058,0.283463,0.226065,0.332316,0.280361,0.293361,0.288386,0.26928,0.271721,0.254418,0.240187,0.236993
6,Cholesterol,0.271788,0.254231,0.258275,0.284804,0.255848,0.291156,0.310502,0.21985,0.307705,0.280798,0.290638,0.186191,0.21431,0.250294,0.257985,0.307338,0.283547
7,Creatinine,0.280701,0.296907,0.30121,0.221904,0.285033,0.29613,0.24412,0.244509,0.312245,0.275603,0.292766,0.186234,0.280157,0.285964,0.269462,0.279966,0.219875
8,DiasABP,0.256484,0.288802,0.280281,0.180854,0.264324,0.368516,0.293311,0.208421,0.310414,0.278863,0.313247,0.289131,0.223244,0.222269,0.25946,0.24075,0.203677
9,FiO2,0.269404,0.282501,0.275456,0.070404,0.265951,0.308677,0.299912,0.23076,0.305293,0.266495,0.265278,0.19652,0.244525,0.245101,0.278537,0.242282,0.225017
