# Import libraries

In [1]:
import os
import sys
import pypots
import numpy as np
import benchpots
import matplotlib.pyplot as plt
from pypots.optim import Adam
from pypots.imputation import SAITS, BRITS, USGAN, GPVAE, MRNN
from pypots.utils.random import set_random_seed
from functions.toolkits import toolkits
# Add the parent directory of the current working directory to sys.path (once),
# so that packages located there can be imported.
# NOTE(review): presumably MAEModify lives in that parent directory — confirm.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Project-local helper; called later as `mae, ae = calc_mae(imputation, target, mask)`.
from MAEModify.error import calc_mae


2025-05-08 22:37:17.773499: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746754637.791887  227524 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746754637.797563  227524 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-08 22:37:17.815676: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


[34m
████████╗██╗███╗   ███╗███████╗    ███████╗███████╗██████╗ ██╗███████╗███████╗    █████╗ ██╗
╚══██╔══╝██║████╗ ████║██╔════╝    ██╔════╝██╔════╝██╔══██╗██║██╔════╝██╔════╝   ██╔══██╗██║
   ██║   ██║██╔████╔██║█████╗█████╗███████╗█████╗  ██████╔╝██║█████╗  ███████╗   ███████║██║
   ██║   ██║██║╚██╔╝██║██╔══╝╚════╝╚════██║██╔══╝  ██╔══██╗██║██╔══╝  ╚════██║   ██╔══██║██║
   ██║   ██║██║ ╚═╝ ██║███████╗    ███████║███████╗██║  ██║██║███████╗███████║██╗██║  ██║██║
   ╚═╝   ╚═╝╚═╝     ╚═╝╚══════╝    ╚══════╝╚══════╝╚═╝  ╚═╝╚═╝╚══════╝╚══════╝╚═╝╚═╝  ╚═╝╚═╝
ai4ts v0.0.3 - building AI for unified time-series analysis, https://time-series.ai [0m



# Load Dataset

In [2]:
# Seed the RNGs with the library default before any data is touched.
set_random_seed()

# Load and preprocess PhysioNet-2012 (all subsets) with rate=0.1.
physionet2012_dataset = benchpots.datasets.preprocess_physionet2012(
    subset="all",
    rate=0.1,
)
print(physionet2012_dataset.keys())

2025-05-08 22:37:23 [INFO]: Have set the random seed as 2022 for numpy and pytorch.
2025-05-08 22:37:23 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link: 
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012
2025-05-08 22:37:23 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...
2025-05-08 22:37:23 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...
2025-05-08 22:37:23 [INFO]: Loaded successfully!
2025-05-08 22:37:38 [INFO]: 68807 values masked out in the val set as ground truth, take 9.97% of the original observed values
2025-05-08 22:37:38 [INFO]: 86319 values masked out in the test set as ground truth, take 9.99% of the original observed values
2025-05-08 22:37:38 [INFO]: Total sample number: 11988
2025-05-08 22:37:38 [INFO]: Training set size: 7671 (63.99%)
2025-05-08 22:37:38 [INFO]: Validation set size: 

dict_keys(['n_classes', 'n_steps', 'n_features', 'scaler', 'train_X', 'train_y', 'train_ICUType', 'val_X', 'val_y', 'val_ICUType', 'test_X', 'test_y', 'test_ICUType', 'val_X_ori', 'test_X_ori'])


In [3]:
# Per-split input dicts in the {"X": ...} layout that is passed to predict() below.
dataset_for_training = dict(X=physionet2012_dataset["train_X"])

dataset_for_validating = dict(
    X=physionet2012_dataset["val_X"],
    X_ori=physionet2012_dataset["val_X_ori"],
)

dataset_for_testing = dict(X=physionet2012_dataset["test_X"])

# True exactly where a value is NaN in one of test_X / test_X_ori but not in the
# other, i.e. the positions that were artificially masked out of the test split.
test_X_indicating_mask = np.logical_xor(
    np.isnan(physionet2012_dataset["test_X_ori"]),
    np.isnan(physionet2012_dataset["test_X"]),
)
# Ground-truth values with the remaining NaNs replaced by 0.
test_X_ori = np.nan_to_num(physionet2012_dataset["test_X_ori"])

# Train/Load Models

## SAITS

In [4]:
# Instantiate SAITS. No fit() call appears in the visible cells; the weights are
# restored from a checkpoint in the following cell.
saits = SAITS(
    # Data dimensions, read from the preprocessed dataset dict.
    n_steps=physionet2012_dataset["n_steps"],
    n_features=physionet2012_dataset["n_features"],
    # Network hyperparameters (presumably must match the saved checkpoint — TODO confirm).
    n_layers=1,
    d_model=256,
    n_heads=4,
    d_k=64,
    d_v=64,
    d_ffn=128,
    dropout=0.1,
    ORT_weight=1,
    MIT_weight=1,
    # Training settings.
    optimizer=Adam(lr=1e-3),
    batch_size=32,
    epochs=10,
    patience=3,
    model_saving_strategy="best",
    # Runtime: device=None leaves the device choice to the library default.
    device=None,
    num_workers=0,
)

2025-05-08 22:37:42 [INFO]: No given device, using default device: cpu
2025-05-08 22:37:42 [INFO]: Using customized MAE as the training loss function.
2025-05-08 22:37:42 [INFO]: Using customized MSE as the validation metric function.
2025-05-08 22:37:42 [INFO]: SAITS initialized with the given hyperparameters, the number of trainable parameters: 720,182


In [5]:
# Restore SAITS from a saved checkpoint; the path is relative to the working directory.
saits.load("../mae/tutorial_results/imputation/saits/20250422_T181642/SAITS.pypots")

2025-05-08 22:37:44 [INFO]: Model loaded successfully from ../mae/tutorial_results/imputation/saits/20250422_T181642/SAITS.pypots


## BRITS

In [6]:
# Instantiate BRITS. No fit() call appears in the visible cells; the weights are
# restored from a checkpoint in the following cell.
brits = BRITS(
    # Data dimensions, read from the preprocessed dataset dict.
    n_steps=physionet2012_dataset["n_steps"],
    n_features=physionet2012_dataset["n_features"],
    # Network hyperparameter (presumably must match the saved checkpoint — TODO confirm).
    rnn_hidden_size=128,
    # Training settings.
    optimizer=Adam(lr=1e-3),
    batch_size=32,
    epochs=10,
    patience=3,
    model_saving_strategy="best",
    # Runtime: device=None leaves the device choice to the library default.
    device=None,
    num_workers=0,
)

2025-05-08 22:37:46 [INFO]: No given device, using default device: cpu
2025-05-08 22:37:46 [INFO]: Using customized MAE as the training loss function.
2025-05-08 22:37:46 [INFO]: Using customized MSE as the validation metric function.
2025-05-08 22:37:46 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 239,344


In [7]:
# Restore BRITS from a saved checkpoint; the path is relative to the working directory.
brits.load("../mae/tutorial_results/imputation/brits/20250422_T181643/BRITS.pypots")

2025-05-08 22:37:48 [INFO]: Model loaded successfully from ../mae/tutorial_results/imputation/brits/20250422_T181643/BRITS.pypots


## US-GAN

In [8]:
# Instantiate US-GAN. No fit() call appears in the visible cells; the weights are
# restored from a checkpoint in the following cell.
us_gan = USGAN(
    # Data dimensions, read from the preprocessed dataset dict.
    n_steps=physionet2012_dataset["n_steps"],
    n_features=physionet2012_dataset["n_features"],
    # Network hyperparameters (presumably must match the saved checkpoint — TODO confirm).
    rnn_hidden_size=256,
    dropout=0.1,
    lambda_mse=1,
    # Training settings; generator and discriminator each get their own optimizer.
    G_steps=1,
    D_steps=1,
    G_optimizer=Adam(lr=1e-3),
    D_optimizer=Adam(lr=1e-3),
    batch_size=32,
    epochs=10,
    patience=3,
    model_saving_strategy="best",
    # Runtime: device=None leaves the device choice to the library default.
    device=None,
    num_workers=0,
)

2025-05-08 22:37:48 [INFO]: No given device, using default device: cpu
2025-05-08 22:37:48 [INFO]: USGAN initialized with the given hyperparameters, the number of trainable parameters: 1,258,517


In [9]:
# Restore US-GAN from a saved checkpoint; the path is relative to the working directory.
us_gan.load("../mae/tutorial_results/imputation/us_gan/20250422_T181643/USGAN.pypots")

2025-05-08 22:37:50 [INFO]: Model loaded successfully from ../mae/tutorial_results/imputation/us_gan/20250422_T181643/USGAN.pypots


## GP-VAE

In [10]:
# Instantiate GP-VAE. No fit() call appears in the visible cells; the weights are
# restored from a checkpoint in the following cell.
gp_vae = GPVAE(
    # Data dimensions, read from the preprocessed dataset dict.
    n_steps=physionet2012_dataset["n_steps"],
    n_features=physionet2012_dataset["n_features"],
    # Network hyperparameters (presumably must match the saved checkpoint — TODO confirm).
    latent_size=37,
    encoder_sizes=(128, 128),
    decoder_sizes=(256, 256),
    window_size=24,
    # Kernel / prior settings.
    kernel="cauchy",
    kernel_scales=1,
    length_scale=7.0,
    sigma=1.005,
    beta=0.2,
    M=1,
    K=1,
    # Training settings.
    optimizer=Adam(lr=1e-3),
    batch_size=32,
    epochs=10,
    patience=3,
    model_saving_strategy="best",
    # Runtime: device=None leaves the device choice to the library default.
    device=None,
    num_workers=0,
)

2025-05-08 22:37:52 [INFO]: No given device, using default device: cpu
2025-05-08 22:37:52 [INFO]: GPVAE initialized with the given hyperparameters, the number of trainable parameters: 229,652


In [11]:
# Restore GP-VAE from a saved checkpoint; the path is relative to the working directory.
gp_vae.load("../mae/tutorial_results/imputation/gp_vae/20250422_T181643/GPVAE.pypots")

2025-05-08 22:37:54 [INFO]: Model loaded successfully from ../mae/tutorial_results/imputation/gp_vae/20250422_T181643/GPVAE.pypots


## MRNN

In [12]:
# Instantiate M-RNN. No fit() call appears in the visible cells; the weights are
# restored from a checkpoint in the following cell. Unlike the other models,
# no batch_size is passed here, so the library default applies.
mrnn = MRNN(
    # Data dimensions, read from the preprocessed dataset dict.
    n_steps=physionet2012_dataset["n_steps"],
    n_features=physionet2012_dataset["n_features"],
    # Network hyperparameter (presumably must match the saved checkpoint — TODO confirm).
    rnn_hidden_size=128,
    # Training settings.
    optimizer=Adam(lr=1e-3),
    epochs=10,
    patience=3,
    model_saving_strategy="best",
    # Runtime: device=None leaves the device choice to the library default.
    device=None,
    num_workers=0,
)

2025-05-08 22:37:56 [INFO]: No given device, using default device: cpu
2025-05-08 22:37:56 [INFO]: Using customized RMSE as the training loss function.
2025-05-08 22:37:56 [INFO]: Using customized MSE as the validation metric function.
2025-05-08 22:37:56 [INFO]: MRNN initialized with the given hyperparameters, the number of trainable parameters: 107,951


In [13]:
# Restore M-RNN from a saved checkpoint; the path is relative to the working directory.
mrnn.load("../mae/tutorial_results/imputation/mrnn/20250422_T181643/MRNN.pypots")

2025-05-08 22:37:58 [INFO]: Model loaded successfully from ../mae/tutorial_results/imputation/mrnn/20250422_T181643/MRNN.pypots


# Impute the test set with each model

## SAITS

In [14]:
# Run SAITS on the test inputs and keep the "imputation" entry of the result dict.
saits_results = saits.predict(dataset_for_testing)
saits_imputation = saits_results["imputation"]

## BRITS

In [15]:
# Run BRITS on the test inputs and keep the "imputation" entry of the result dict.
brits_results = brits.predict(dataset_for_testing)
brits_imputation = brits_results["imputation"]

## US-GAN

In [16]:
# Run US-GAN on the test inputs and keep the "imputation" entry of the result dict.
us_gan_results = us_gan.predict(dataset_for_testing)
us_gan_imputation = us_gan_results["imputation"]

## GP-VAE

In [17]:
# Run GP-VAE on the test inputs and keep the "imputation" entry of the result dict.
# This output carries an extra axis 1 that is squeezed away before the MAE step.
gp_vae_results = gp_vae.predict(dataset_for_testing)
gp_vae_imputation = gp_vae_results["imputation"]

## MRNN

In [18]:
# Run M-RNN on the test inputs and keep the "imputation" entry of the result dict.
mrnn_results = mrnn.predict(dataset_for_testing)
mrnn_imputation = mrnn_results["imputation"]

# Absolute error (AE) and mean absolute error (MAE) per model

## SAITS

In [19]:
# calc_mae returns a pair, unpacked as (MAE, AE).
# NOTE(review): presumably the aggregate MAE and the element-wise absolute errors
# on the artificially masked positions — verify against MAEModify.error.
saits_mae, saits_ae = calc_mae(saits_imputation, test_X_ori, test_X_indicating_mask)

## BRITS

In [20]:
# calc_mae returns a pair, unpacked as (MAE, AE).
# NOTE(review): presumably the aggregate MAE and the element-wise absolute errors
# on the artificially masked positions — verify against MAEModify.error.
brits_mae, brits_ae = calc_mae(brits_imputation, test_X_ori, test_X_indicating_mask)

## US-GAN

In [21]:
# calc_mae returns a pair, unpacked as (MAE, AE).
# NOTE(review): presumably the aggregate MAE and the element-wise absolute errors
# on the artificially masked positions — verify against MAEModify.error.
usgan_mae, usgan_ae = calc_mae(us_gan_imputation, test_X_ori, test_X_indicating_mask)

## GP-VAE

In [22]:
# The GP-VAE output carries one more axis (axis 1) than the ground truth.
# Drop it only while it is still present, so that re-running this cell does not
# try to squeeze an already-squeezed array and raise.
if gp_vae_imputation.ndim == test_X_ori.ndim + 1:
    gp_vae_imputation = np.squeeze(gp_vae_imputation, axis=1)

# calc_mae returns a pair, unpacked as (MAE, AE).
# NOTE(review): presumably the aggregate MAE and the element-wise absolute errors
# on the artificially masked positions — verify against MAEModify.error.
gpvae_mae, gpvae_ae = calc_mae(
    gp_vae_imputation,
    test_X_ori,
    test_X_indicating_mask,
)

## MRNN

In [23]:
mrnn_mae, mrnn_ae = calc_mae(
    mrnn_imputation,
    test_X_ori,
    test_X_indicating_mask,
)

# Bootstrap

## SAITS

In [24]:
# Flatten the absolute errors to 1-D. reshape(-1) infers the length, so the cell
# no longer hard-codes the 48 x 37 dimensions and can be re-run safely (the old
# `len(x) * 48 * 37` form raised on a second run because `x` was already flat).
saits_ae = saits_ae.reshape(-1)

In [25]:
# Flatten the evaluation mask to 1-D so it lines up element-wise with the
# flattened errors; reshape(-1) avoids hard-coding the 48 x 37 dimensions.
saits_mask = test_X_indicating_mask.reshape(-1)

In [None]:
# Bootstrap the SAITS errors using the flattened mask.
# NOTE(review): 9000 is presumably the number of bootstrap resamples — confirm
# against toolkits.bootstrap_v3.
results_bootstrap_saits_general = toolkits.bootstrap_v3(
    saits_ae,
    saits_mask,
    9000,
)

#### Calculating lower bound and upper bound

In [None]:
# Derive the lower and upper bounds from the SAITS bootstrap results.
(
    lower_bounds_saits_general,
    upper_bounds_saits_general,
) = toolkits.calc_lower_and_upper_bound_percentile(results_bootstrap_saits_general)

print(lower_bounds_saits_general, upper_bounds_saits_general, sep="\n")

#### Mean values of lower bound and upper bound


In [None]:
# Combine the SAITS lower/upper bounds into their mean value(s).
mean_values_ci_saits_general = toolkits.calc_mean_values_ci(
    lower_bounds_saits_general,
    upper_bounds_saits_general,
)

print(mean_values_ci_saits_general)

## BRITS

In [29]:
# Flatten the absolute errors to 1-D. reshape(-1) infers the length, so the cell
# no longer hard-codes the 48 x 37 dimensions and can be re-run safely (the old
# `len(x) * 48 * 37` form raised on a second run because `x` was already flat).
brits_ae = brits_ae.reshape(-1)

In [30]:
# Flatten the evaluation mask to 1-D so it lines up element-wise with the
# flattened errors; reshape(-1) avoids hard-coding the 48 x 37 dimensions.
brits_mask = test_X_indicating_mask.reshape(-1)

In [None]:
# Bootstrap the BRITS errors using the flattened mask.
# NOTE(review): 9000 is presumably the number of bootstrap resamples — confirm
# against toolkits.bootstrap_v3.
results_bootstrap_brits_general = toolkits.bootstrap_v3(
    brits_ae,
    brits_mask,
    9000,
)

#### Calculating lower bound and upper bound

In [None]:
# Derive the lower and upper bounds from the BRITS bootstrap results.
(
    lower_bounds_brits_general,
    upper_bounds_brits_general,
) = toolkits.calc_lower_and_upper_bound_percentile(results_bootstrap_brits_general)

print(lower_bounds_brits_general, upper_bounds_brits_general, sep="\n")

#### Mean values of lower bound and upper bound


In [None]:
# Combine the BRITS lower/upper bounds into their mean value(s).
mean_values_ci_brits_general = toolkits.calc_mean_values_ci(
    lower_bounds_brits_general,
    upper_bounds_brits_general,
)

print(mean_values_ci_brits_general)

## US-GAN

In [34]:
# Flatten the absolute errors to 1-D. reshape(-1) infers the length, so the cell
# no longer hard-codes the 48 x 37 dimensions and can be re-run safely (the old
# `len(x) * 48 * 37` form raised on a second run because `x` was already flat).
usgan_ae = usgan_ae.reshape(-1)

In [35]:
# Flatten the evaluation mask to 1-D so it lines up element-wise with the
# flattened errors; reshape(-1) avoids hard-coding the 48 x 37 dimensions.
usgan_mask = test_X_indicating_mask.reshape(-1)

In [None]:
# Bootstrap the US-GAN errors using the flattened mask.
# NOTE(review): 9000 is presumably the number of bootstrap resamples — confirm
# against toolkits.bootstrap_v3.
results_bootstrap_usgan_general = toolkits.bootstrap_v3(
    usgan_ae,
    usgan_mask,
    9000,
)

#### Calculating lower bound and upper bound

In [None]:
# Derive the lower and upper bounds from the US-GAN bootstrap results.
(
    lower_bounds_usgan_general,
    upper_bounds_usgan_general,
) = toolkits.calc_lower_and_upper_bound_percentile(results_bootstrap_usgan_general)

print(lower_bounds_usgan_general, upper_bounds_usgan_general, sep="\n")

#### Mean values of lower bound and upper bound


In [None]:
# Combine the US-GAN lower/upper bounds into their mean value(s).
mean_values_ci_usgan_general = toolkits.calc_mean_values_ci(
    lower_bounds_usgan_general,
    upper_bounds_usgan_general,
)

print(mean_values_ci_usgan_general)

## GP-VAE

In [39]:
# Flatten the absolute errors to 1-D. reshape(-1) infers the length, so the cell
# no longer hard-codes the 48 x 37 dimensions and can be re-run safely (the old
# `len(x)*48*37` form raised on a second run because `x` was already flat).
gpvae_ae = gpvae_ae.reshape(-1)

In [40]:
# Flatten the evaluation mask to 1-D so it lines up element-wise with the
# flattened errors; reshape(-1) avoids hard-coding the 48 x 37 dimensions.
gpvae_mask = test_X_indicating_mask.reshape(-1)

In [None]:
# Bootstrap the GP-VAE errors using the flattened mask.
# NOTE(review): 9000 is presumably the number of bootstrap resamples — confirm
# against toolkits.bootstrap_v3.
results_bootstrap_gpvae_general = toolkits.bootstrap_v3(
    gpvae_ae,
    gpvae_mask,
    9000,
)

#### Calculating lower bound and upper bound

In [None]:
# Derive the lower and upper bounds from the GP-VAE bootstrap results.
(
    lower_bounds_gpvae_general,
    upper_bounds_gpvae_general,
) = toolkits.calc_lower_and_upper_bound_percentile(results_bootstrap_gpvae_general)

print(lower_bounds_gpvae_general, upper_bounds_gpvae_general, sep="\n")

#### Mean values of lower bound and upper bound


In [None]:
# Combine the GP-VAE lower/upper bounds into their mean value(s).
mean_values_ci_gpvae_general = toolkits.calc_mean_values_ci(
    lower_bounds_gpvae_general,
    upper_bounds_gpvae_general,
)

print(mean_values_ci_gpvae_general)

## MRNN

In [44]:
# Flatten the absolute errors to 1-D. reshape(-1) infers the length, so the cell
# no longer hard-codes the 48 x 37 dimensions and can be re-run safely (the old
# `len(x)*48*37` form raised on a second run because `x` was already flat).
mrnn_ae = mrnn_ae.reshape(-1)

In [45]:
# Flatten the evaluation mask to 1-D so it lines up element-wise with the
# flattened errors; reshape(-1) avoids hard-coding the 48 x 37 dimensions.
mrnn_mask = test_X_indicating_mask.reshape(-1)

In [None]:
# Bootstrap the M-RNN errors using the flattened mask.
# NOTE(review): 9000 is presumably the number of bootstrap resamples — confirm
# against toolkits.bootstrap_v3.
results_bootstrap_mrnn_general = toolkits.bootstrap_v3(
    mrnn_ae,
    mrnn_mask,
    9000,
)

#### Calculating lower bound and upper bound

In [None]:
# Derive the lower and upper bounds from the M-RNN bootstrap results.
(
    lower_bounds_mrnn_general,
    upper_bounds_mrnn_general,
) = toolkits.calc_lower_and_upper_bound_percentile(results_bootstrap_mrnn_general)

print(lower_bounds_mrnn_general, upper_bounds_mrnn_general, sep="\n")

#### Mean values of lower bound and upper bound


In [None]:
# Combine the M-RNN lower/upper bounds into their mean value(s).
mean_values_ci_mrnn_general = toolkits.calc_mean_values_ci(
    lower_bounds_mrnn_general,
    upper_bounds_mrnn_general,
)

print(mean_values_ci_mrnn_general)