### Importing libraries

In [39]:
import pypots
import os
import sys
from pypots.utils.metrics import calc_mae
from pypots.optim import Adam
from pypots.imputation import SAITS, BRITS
import numpy as np
import benchpots
from pypots.utils.random import set_random_seed
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import pandas as pd    

### Loading database

In [40]:
set_random_seed()

from pypotsModify.benchpotsMAE.datasets import preprocess_physionet2012

# Load the PhysioNet-2012 dataset
physionet2012_dataset = preprocess_physionet2012(subset="all", rate=0.1)

# Take a look at the generated PhysioNet-2012 dataset, you'll find that everything has been prepared for you,
# data splitting, normalization, additional artificially-missing values for evaluation, etc.
print(physionet2012_dataset.keys())

2025-02-03 21:14:00 [INFO]: Have set the random seed as 2022 for numpy and pytorch.
2025-02-03 21:14:00 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link: 
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012
2025-02-03 21:14:00 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...
2025-02-03 21:14:00 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...


2025-02-03 21:14:00 [INFO]: Loaded successfully!
2025-02-03 21:14:13 [INFO]: 68807 values masked out in the val set as ground truth, take 9.97% of the original observed values
2025-02-03 21:14:13 [INFO]: 68807 values masked out in the val set as ground truth, take 9.97% of the original observed values
2025-02-03 21:14:13 [INFO]: 86319 values masked out in the test set as ground truth, take 9.99% of the original observed values
2025-02-03 21:14:13 [INFO]: 86319 values masked out in the test set as ground truth, take 9.99% of the original observed values
2025-02-03 21:14:13 [INFO]: Total sample number: 11988
2025-02-03 21:14:13 [INFO]: Total sample number: 11988
2025-02-03 21:14:13 [INFO]: Training set size: 7671 (63.99%)
2025-02-03 21:14:13 [INFO]: Training set size: 7671 (63.99%)
2025-02-03 21:14:13 [INFO]: Validation set size: 1918 (16.00%)
2025-02-03 21:14:13 [INFO]: Validation set size: 1918 (16.00%)
2025-02-03 21:14:13 [INFO]: Test set size: 2399 (20.01%)
2025-02-03 21:14:13 [INFO]

dict_keys(['n_classes', 'n_steps', 'n_features', 'scaler', 'train_X', 'train_y', 'train_ICUType', 'val_X', 'val_y', 'val_ICUType', 'test_X', 'test_y', 'test_ICUType', 'female_gender_test_X', 'female_gender_test_y', 'test_ICUType_female_gender', 'male_gender_test_X', 'male_gender_test_y', 'test_ICUType_male_gender', 'undefined_gender_test_X', 'undefined_gender_test_y', 'test_ICUType_undefined_gender', 'more_than_or_equal_to_65_test_X', 'more_than_or_equal_to_65_test_y', 'test_ICUType_more_than_or_equal_to_65', 'less_than_65_test_X', 'less_than_65_test_y', 'test_ICUType_less_than_65', 'ICUType_1_test_X', 'ICUType_1_test_y', 'test_ICUType_1', 'ICUType_2_test_X', 'ICUType_2_test_y', 'test_ICUType_2', 'ICUType_3_test_X', 'ICUType_3_test_y', 'test_ICUType_3', 'ICUType_4_test_X', 'ICUType_4_test_y', 'test_ICUType_4', 'classificacao_undefined_test_X', 'classificacao_undefined_test_y', 'test_ICUType_classificacao_undefined', 'classificacao_baixo_peso_test_X', 'classificacao_baixo_peso_test_y', 

In [41]:
# assemble the datasets for training
dataset_for_training = {
    "X": physionet2012_dataset['train_X'],
}
# assemble the datasets for validation
dataset_for_validating = {
    "X": physionet2012_dataset['val_X'],
    "X_ori": physionet2012_dataset['val_X_ori'],
}

dataset_for_testing_ori = {
    "X_ori": physionet2012_dataset['test_X_ori'],
    "female_gender_test_X_ori": physionet2012_dataset['female_gender_test_X_ori'],
    "male_gender_test_X_ori": physionet2012_dataset['male_gender_test_X_ori'],
    "undefined_gender_test_X_ori": physionet2012_dataset['undefined_gender_test_X_ori'],
    "more_than_or_equal_to_65_test_X_ori":  physionet2012_dataset['more_than_or_equal_to_65_test_X_ori'],
    "less_than_65_test_X_ori": physionet2012_dataset['less_than_65_test_X_ori'],
    "ICUType_1_test_X_ori": physionet2012_dataset['ICUType_1_test_X_ori'],
    "ICUType_2_test_X_ori": physionet2012_dataset['ICUType_2_test_X_ori'],
    "ICUType_3_test_X_ori": physionet2012_dataset['ICUType_3_test_X_ori'],
    "ICUType_4_test_X_ori": physionet2012_dataset['ICUType_4_test_X_ori'],
    "classificacao_undefined_test_X_ori": physionet2012_dataset['classificacao_undefined_test_X_ori'],
    "classificacao_baixo_peso_test_X_ori": physionet2012_dataset['classificacao_baixo_peso_test_X_ori'],
    "classificacao_normal_peso_test_X_ori": physionet2012_dataset['classificacao_normal_peso_test_X_ori'],
    "classificacao_sobrepeso_test_X_ori": physionet2012_dataset['classificacao_sobrepeso_test_X_ori'],
    "classificacao_obesidade_1_test_X_ori": physionet2012_dataset['classificacao_obesidade_1_test_X_ori'],
    "classificacao_obesidade_2_test_X_ori": physionet2012_dataset['classificacao_obesidade_2_test_X_ori'],
    "classificacao_obesidade_3_test_X_ori": physionet2012_dataset['classificacao_obesidade_3_test_X_ori']
}

# assemble the datasets for test
dataset_for_testing = {
    "X": physionet2012_dataset['test_X'],
    "female_gender_test_X": physionet2012_dataset['female_gender_test_X'],
    "male_gender_test_X": physionet2012_dataset['male_gender_test_X'],
    "undefined_gender_test_X": physionet2012_dataset['undefined_gender_test_X'],
    "more_than_or_equal_to_65_test_X":  physionet2012_dataset['more_than_or_equal_to_65_test_X'],
    "less_than_65_test_X": physionet2012_dataset['less_than_65_test_X'],
    "ICUType_1_test_X": physionet2012_dataset['ICUType_1_test_X'],
    "ICUType_2_test_X": physionet2012_dataset['ICUType_2_test_X'],
    "ICUType_3_test_X": physionet2012_dataset['ICUType_3_test_X'],
    "ICUType_4_test_X": physionet2012_dataset['ICUType_4_test_X'],
    "classificacao_undefined_test_X": physionet2012_dataset['classificacao_undefined_test_X'],
    "classificacao_baixo_peso_test_X": physionet2012_dataset['classificacao_baixo_peso_test_X'],
    "classificacao_normal_peso_test_X": physionet2012_dataset['classificacao_normal_peso_test_X'],
    "classificacao_sobrepeso_test_X": physionet2012_dataset['classificacao_sobrepeso_test_X'],
    "classificacao_obesidade_1_test_X": physionet2012_dataset['classificacao_obesidade_1_test_X'],
    "classificacao_obesidade_2_test_X": physionet2012_dataset['classificacao_obesidade_2_test_X'],
    "classificacao_obesidade_3_test_X": physionet2012_dataset['classificacao_obesidade_3_test_X']
}

## calculate the mask to indicate the ground truth positions in test_X_ori, will be used by metric funcs to evaluate models
test_X_indicating_mask = []
test_X_ori = []
for i, j in zip(dataset_for_testing_ori.values(), dataset_for_testing.values()):
    test_X_indicating_mask.append(np.isnan(i) ^ np.isnan(j))
    test_X_ori.append(np.nan_to_num(i))# metric functions do not accpet input with NaNs, hence fill NaNs with 0


teste = "Teste"


In [42]:
test_X_indicating_mask_variable = []
test_X_ori_variable = []
for i in range(len(test_X_indicating_mask)):
    test_X_indicating_mask_variable.append(test_X_indicating_mask[i].reshape(37, len(test_X_indicating_mask[i]) * 48))
    test_X_ori_variable.append(test_X_ori[i].reshape(37, len(test_X_ori[i]) * 48))

### Inicialize the models

In [43]:
saits = SAITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    n_layers=1,
    d_model=256,
    d_ffn=128,
    n_heads=4,
    d_k=64,
    d_v=64,
    dropout=0.1,
    ORT_weight=1,  # you can adjust the weight values of arguments ORT_weight
    # and MIT_weight to make the SAITS model focus more on one task. Usually you can just leave them to the default values, i.e. 1.
    MIT_weight=1,
    batch_size=32,
    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance
    epochs=10,
    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.
    # You can leave it to defualt as None to disable early stopping.
    patience=3,
    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when
    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.
    optimizer=Adam(lr=1e-3),
    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.
    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.
    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed
    num_workers=0,
    # just leave it to default as None, PyPOTS will automatically assign the best device for you.
    # Set it as 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices, even parallelly on ['cuda:0', 'cuda:1']
    device=None,
    # set the path for saving tensorboard and trained model files
    saving_path="tutorial_results/imputation/saits",
    # only save the best model after training finished.
    # You can also set it as "better" to save models performing better ever during training.
    model_saving_strategy="best",
)

2025-02-03 21:14:13 [INFO]: No given device, using default device: cuda
2025-02-03 21:14:13 [INFO]: Model files will be saved to tutorial_results/imputation/saits/20250203_T211413
2025-02-03 21:14:13 [INFO]: Tensorboard file will be saved to tutorial_results/imputation/saits/20250203_T211413/tensorboard
2025-02-03 21:14:13 [INFO]: SAITS initialized with the given hyperparameters, the number of trainable parameters: 720,182


In [44]:
brits = BRITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    rnn_hidden_size=128,
    batch_size=32,
    # here we set epochs=10 for a quick demo, you can set it to 100 or more for better performance
    epochs=10,
    # here we set patience=3 to early stop the training if the evaluting loss doesn't decrease for 3 epoches.
    # You can leave it to defualt as None to disable early stopping.
    patience=3,
    # give the optimizer. Different from torch.optim.Optimizer, you don't have to specify model's parameters when
    # initializing pypots.optim.Optimizer. You can also leave it to default. It will initilize an Adam optimizer with lr=0.001.
    optimizer=Adam(lr=1e-3),
    # this num_workers argument is for torch.utils.data.Dataloader. It's the number of subprocesses to use for data loading.
    # Leaving it to default as 0 means data loading will be in the main process, i.e. there won't be subprocesses.
    # You can increase it to >1 if you think your dataloading is a bottleneck to your model training speed
    num_workers=0,
    # just leave it to default as None, PyPOTS will automatically assign the best device for you.
    # Set it as 'cpu' if you don't have CUDA devices. You can also set it to 'cuda:0' or 'cuda:1' if you have multiple CUDA devices, even parallelly on ['cuda:0', 'cuda:1']
    device=None,
    # set the path for saving tensorboard and trained model files
    saving_path="tutorial_results/imputation/brits",
    # only save the best model after training finished.
    # You can also set it as "better" to save models performing better ever during training.
    model_saving_strategy="best",
)

2025-02-03 21:14:13 [INFO]: No given device, using default device: cuda
2025-02-03 21:14:13 [INFO]: Model files will be saved to tutorial_results/imputation/brits/20250203_T211413
2025-02-03 21:14:13 [INFO]: Tensorboard file will be saved to tutorial_results/imputation/brits/20250203_T211413/tensorboard
2025-02-03 21:14:13 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 239,344


### Train the models

In [46]:
# train the model on the training set, and validate it on the validating set to select the best model for testing in the next step
saits.fit(train_set=dataset_for_training, val_set=dataset_for_validating)

2025-02-03 21:15:44 [INFO]: Epoch 001 - training loss: 0.5679, validation loss: 6.7518
2025-02-03 21:15:47 [INFO]: Epoch 002 - training loss: 0.5101, validation loss: 6.7300
2025-02-03 21:15:50 [INFO]: Epoch 003 - training loss: 0.4759, validation loss: 6.7068
2025-02-03 21:15:53 [INFO]: Epoch 004 - training loss: 0.4479, validation loss: 6.6785
2025-02-03 21:15:56 [INFO]: Epoch 005 - training loss: 0.4287, validation loss: 6.6738
2025-02-03 21:15:59 [INFO]: Epoch 006 - training loss: 0.4127, validation loss: 6.6652
2025-02-03 21:16:01 [INFO]: Epoch 007 - training loss: 0.4000, validation loss: 6.6616
2025-02-03 21:16:04 [INFO]: Epoch 008 - training loss: 0.3918, validation loss: 6.6614
2025-02-03 21:16:07 [INFO]: Epoch 009 - training loss: 0.3826, validation loss: 6.6523
2025-02-03 21:16:10 [INFO]: Epoch 010 - training loss: 0.3778, validation loss: 6.6492
2025-02-03 21:16:10 [INFO]: Finished training. The best model is from epoch#10.
2025-02-03 21:16:10 [INFO]: Saved the model to tut

In [None]:
# train the model on the training set, and validate it on the validating set to select the best model for testing in the next step
brits.fit(train_set=dataset_for_training, val_set=dataset_for_validating)

2025-02-03 20:42:18 [INFO]: Epoch 001 - training loss: 0.9367, validation loss: 6.7860
2025-02-03 20:43:06 [INFO]: Epoch 002 - training loss: 0.7334, validation loss: 6.7463
2025-02-03 20:43:55 [INFO]: Epoch 003 - training loss: 0.6826, validation loss: 6.7330
2025-02-03 20:44:44 [INFO]: Epoch 004 - training loss: 0.6580, validation loss: 6.7299
2025-02-03 20:45:34 [INFO]: Epoch 005 - training loss: 0.6434, validation loss: 6.7283
2025-02-03 20:46:23 [INFO]: Epoch 006 - training loss: 0.6319, validation loss: 6.7274
2025-02-03 20:47:13 [INFO]: Epoch 007 - training loss: 0.6230, validation loss: 6.7291
2025-02-03 20:48:02 [INFO]: Epoch 008 - training loss: 0.6156, validation loss: 6.7310
2025-02-03 20:48:51 [INFO]: Epoch 009 - training loss: 0.6090, validation loss: 6.7335
2025-02-03 20:48:51 [INFO]: Exceeded the training patience. Terminating the training procedure...
2025-02-03 20:48:51 [INFO]: Finished training. The best model is from epoch#6.
2025-02-03 20:48:51 [INFO]: Saved the mo

### The testing stage

In [10]:
# the testing stage, impute the originally-missing values and artificially-missing values in the test set
saits_imputation = []
for value in  dataset_for_testing.values():
   _dict = {'X':value}
   saits_results = saits.predict(_dict)
   saits_imputation.append(saits_results["imputation"])  

teste = 'Teste'


In [None]:
saits_imputation_variable = []
for i in range(len(saits_imputation)):
    saits_imputation_variable.append(saits_imputation[i].reshape(37, len(saits_imputation[i]) * 48))

teste = 'Teste'


In [13]:
# the testing stage, impute the originally-missing values and artificially-missing values in the test set
brits_imputation = []
for value in dataset_for_testing.values():
    _dict = {'X':value}
    brits_results = brits.predict(_dict)
    brits_imputation.append(brits_results["imputation"])

In [14]:
brits_imputation_variable = []
for i in range(len(brits_imputation)):
    brits_imputation_variable.append(brits_imputation[i].reshape(37, len(brits_imputation[i]) * 48))

### Calculate mean absolute error

In [15]:
testing_mae_saits_append_subgroups = []
testing_mae_saits_append_variables = []
for i in range(len(saits_imputation_variable)):
    for j in range(len(saits_imputation_variable[i])):
        testing_mae_saits_append_variables.append(calc_mae(saits_imputation_variable[i][j], test_X_ori_variable[i][j], test_X_indicating_mask_variable[i][j]))
    testing_mae_saits_append_subgroups.append(testing_mae_saits_append_variables)
    testing_mae_saits_append_variables = []
Teste = 'Teste'  

In [16]:
len(saits_imputation_variable[0][0])

115152

In [17]:
saits_imputation_variable[0][0]

array([-0.17572749, -0.33525062, -0.2920901 , ..., -0.39829695,
       -0.4246574 , -0.32629728], dtype=float32)

In [18]:
testing_mae_brits_append_subgroups = []
testing_mae_brits_append_variables = []
for i in range(len(brits_imputation_variable)):
    for j in range(len(brits_imputation_variable[i])):
      testing_mae_brits_append_variables.append(calc_mae(brits_imputation_variable[i][j], test_X_ori_variable[i][j], test_X_indicating_mask_variable[i][j]))
    testing_mae_brits_append_subgroups.append(testing_mae_brits_append_variables)
    testing_mae_brits_append_variables = []
Teste = 'Teste'      

In [19]:
subgroups = ["General", "Female", "Male", "Undefined Gender", "+65", "-65", "ICUType 1", "ICUType 2", "ICUType 3", "ICUType 4", "Undefined classification", "Low Weight", "Normal Weight", "Overweight", "Obesity 1", "Obesity 2", "Obesity 3"]

In [20]:
variables = ["ALP", "ALT", "AST", "Albumin", "BUN", "Bilirubin", "Cholesterol", "Creatinine", "DiasABP", "FiO2", "GCS", "Glucose", "HCO3", "HCT", "HR", "K", "Lactate", "MAP", "MechVent", "Mg", "NIDiasABP", "NIMAP", "NISysABP", "Na", "PaCO2", "PaO2", "Platelets", "RespRate", "SaO2", "SysABP", "Temp", "TroponinI", "TroponinT", "Urine", "WBC", "Weight", "Ph"]

In [51]:
print("SAITS - MAE")
print("************")
for i in range(len(subgroups)): 
    print(subgroups[i]) 
    print("-------------")
    for j in range(len(variables)):
        print(variables[j], ":" ,testing_mae_saits_append_subgroups[i][j])

SAITS - MAE
************
General
-------------
ALP : 0.2440735081867829
ALT : 0.22046217364373275
AST : 0.25568722882330536
Albumin : 0.29847833600882456
BUN : 0.22985173455468852
Bilirubin : 0.22725130828766607
Cholesterol : 0.24701664973419732
Creatinine : 0.2592212683114156
DiasABP : 0.243781030860427
FiO2 : 0.23853749568653815
GCS : 0.27676375468717657
Glucose : 0.2580730157772219
HCO3 : 0.2467922036158317
HCT : 0.23359712286280157
HR : 0.22216202758297737
K : 0.23554960434035374
Lactate : 0.2474192102983516
MAP : 0.24995386348025248
MechVent : 0.2518169563606991
Mg : 0.25303408847351844
NIDiasABP : 0.2252697300550535
NIMAP : 0.23511521913807404
NISysABP : 0.24801578352897705
Na : 0.2544188781201658
PaCO2 : 0.2258657992958765
PaO2 : 0.23870115539483938
Platelets : 0.23428297313636684
RespRate : 0.24681471685847184
SaO2 : 0.2476932683745194
SysABP : 0.23040656793541414
Temp : 0.23933265423486466
TroponinI : 0.2422161649604066
TroponinT : 0.23088828778385911
Urine : 0.244265709769190

In [52]:
print("BRITS - MAE")
print("************")
for i in range(len(subgroups)):
    print(subgroups[i]) 
    print("-------------")
    for j in range(len(variables)):
        print(variables[j], ":" ,testing_mae_brits_append_subgroups[i][j])

BRITS - MAE
************
General
-------------
ALP : 0.266644733489862
ALT : 0.23900266150625304
AST : 0.2854938721784185
Albumin : 0.3203138448779601
BUN : 0.24102697640531243
Bilirubin : 0.24960542526114163
Cholesterol : 0.2519937112900071
Creatinine : 0.2680821614127534
DiasABP : 0.246698476118878
FiO2 : 0.24946330367166566
GCS : 0.6473109930776697
Glucose : 0.28087676441624365
HCO3 : 0.2583291079764593
HCT : 0.24207674792433023
HR : 0.2410863112357303
K : 0.24739804172497776
Lactate : 0.25947332814140556
MAP : 0.2684142254010802
MechVent : 0.2627332565047349
Mg : 0.269234288659691
NIDiasABP : 0.2316112509677321
NIMAP : 0.24172368575241565
NISysABP : 0.2554514344344389
Na : 0.2665953580476764
PaCO2 : 0.2399055118050039
PaO2 : 0.25675305111767854
Platelets : 0.24342133381562436
RespRate : 0.2554206545499794
SaO2 : 0.26948549719958176
SysABP : 0.25779013358763925
Temp : 0.26511418636617634
TroponinI : 0.24506253875316777
TroponinT : 0.24178232451883985
Urine : 0.26819732476676317
WBC 