In [2]:
import numpy as np
from sklearn.preprocessing import StandardScaler
from pygrinder import mcar, calc_missing_rate
from benchpots.datasets import preprocess_physionet2012

In [2]:
data = preprocess_physionet2012(subset='set-a',rate=0.1) # Our ecosystem libs will automatically download and extract it
train_X, val_X, test_X = data["train_X"], data["val_X"], data["test_X"]
print(train_X.shape)  # (n_samples, n_steps, n_features)
print(val_X.shape)  # samples (n_samples) in train set and val set are different, but they have the same sequence len (n_steps) and feature dim (n_features)
print(f"We have {calc_missing_rate(train_X):.1%} values missing in train_X") 

2025-06-19 22:51:36 [INFO]: You're using dataset physionet_2012, please cite it properly in your work. You can find its reference information at the below link: 
https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/physionet_2012
2025-06-19 22:51:36 [INFO]: Dataset physionet_2012 has already been downloaded. Processing directly...
2025-06-19 22:51:36 [INFO]: Dataset physionet_2012 has already been cached. Loading from cache directly...
2025-06-19 22:51:36 [INFO]: Loaded successfully!
2025-06-19 22:51:38 [INFO]: 22900 values masked out in the val set as ground truth, take 10.04% of the original observed values
2025-06-19 22:51:38 [INFO]: 28501 values masked out in the test set as ground truth, take 9.90% of the original observed values
2025-06-19 22:51:38 [INFO]: Total sample number: 3997
2025-06-19 22:51:38 [INFO]: Training set size: 2557 (63.97%)
2025-06-19 22:51:38 [INFO]: Validation set size: 640 (16.01%)
2025-06-19 22:51:38 [INFO]: Test set size: 800 (20.02%)
2025-06-19 22:5

(2557, 48, 37)
(640, 48, 37)
We have 79.7% values missing in train_X


In [3]:
train_set = {"X": train_X}  # in training set, simply put the incomplete time series into it
val_set = {
    "X": val_X,
    "X_ori": data["val_X_ori"],  # in validation set, we need ground truth for evaluation and picking the best model checkpoint
}
test_set = {"X": test_X}  # in test set, only give the testing incomplete time series for model to impute
test_X_ori = data["test_X_ori"]  # test_X_ori bears ground truth for evaluation
indicating_mask = np.isnan(test_X) ^ np.isnan(test_X_ori)  # mask indicates the values that are missing in X but not in X_ori, i.e. where the gt values are 

In [4]:
from pypots.imputation import SAITS  # import the model you want to use
from pypots.nn.functional import calc_mae
saits = SAITS(n_steps=train_X.shape[1], n_features=train_X.shape[2], n_layers=2, d_model=256, n_heads=4, d_k=64, d_v=64, d_ffn=128, dropout=0.1, epochs=5)
saits.fit(train_set, val_set)  # train the model on the dataset
imputation = saits.impute(test_set)  # impute the originally-missing values and artificially-missing values
mae = calc_mae(imputation, np.nan_to_num(test_X_ori), indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)
saits.save("save_it_here/saits_physionet2012.pypots")  # save the model for future use
saits.load("save_it_here/saits_physionet2012.pypots")  # reload the serialized model file for following imputation or training

2025-06-19 22:52:11 [INFO]: No given device, using default device: cpu
2025-06-19 22:52:11 [INFO]: Using customized MAE as the training loss function.
2025-06-19 22:52:11 [INFO]: Using customized MSE as the validation metric function.
2025-06-19 22:52:11 [INFO]: SAITS initialized with the given hyperparameters, the number of trainable parameters: 1,378,358


[34m
████████╗██╗███╗   ███╗███████╗    ███████╗███████╗██████╗ ██╗███████╗███████╗    █████╗ ██╗
╚══██╔══╝██║████╗ ████║██╔════╝    ██╔════╝██╔════╝██╔══██╗██║██╔════╝██╔════╝   ██╔══██╗██║
   ██║   ██║██╔████╔██║█████╗█████╗███████╗█████╗  ██████╔╝██║█████╗  ███████╗   ███████║██║
   ██║   ██║██║╚██╔╝██║██╔══╝╚════╝╚════██║██╔══╝  ██╔══██╗██║██╔══╝  ╚════██║   ██╔══██║██║
   ██║   ██║██║ ╚═╝ ██║███████╗    ███████║███████╗██║  ██║██║███████╗███████║██╗██║  ██║██║
   ╚═╝   ╚═╝╚═╝     ╚═╝╚══════╝    ╚══════╝╚══════╝╚═╝  ╚═╝╚═╝╚══════╝╚══════╝╚═╝╚═╝  ╚═╝╚═╝
ai4ts v0.0.3 - building AI for unified time-series analysis, https://time-series.ai [0m





ValueError: Something is wrong. best_loss is NaN/Inf after training.

# Training SAITS for VitalDb data

In [9]:
!pip install vitaldb



## Load & resample all VitalDB files

In [17]:
from pathlib import Path
import os, vitaldb as vdb, numpy as np, pandas as pd

# ---------- settings ----------
vital_dir      = Path("/Users/muhammadaneequz.zaman/Dropbox/Digital Twin (Umer Huzaifa)/vitalDB_v1")
track_keep     = ["SNUADC/ART ", "SNUADC/ECG_II", "SNUADC/ECG_V5 ", 
                  "SNUADC/PLETH", "Primus/CO2", "BIS/EEG1_WAV", "BIS/EEG2_WAV"]   # pick tracks present in *every* case
target_fs      = "1S"                                           # 1-second grid
sequence_len   = 60*10                                          # 10-min snippets → n_steps = 600
# ------------------------------

def read_one(file_path, tracks_keep):
    """
    Return a pandas DataFrame whose index is the native VitalDB timestamp
    (datetime) and whose columns are the requested track names.
    Missing samples remain as NaN.
    """
    all_tracks = vdb.vital_trks(str(file_path))
    #print(all_tracks)
    print(tracks_keep)
    #numeric_tracks = [t for t in all_tracks if t in tracks_keep]

    # if not numeric_tracks:
    #     raise ValueError("none of the requested tracks in this file")
    
    return vdb.vital_recs(
        str(file_path),
        # track_names=all_tracks,
        track_names=tracks_keep,
        return_timestamp=False,      # keep absolute clock time
        return_datetime=False,
        return_pandas=True,
    )

all_cases = []                      # dict: filename  -> DataFrame
bad_files = []
for f in sorted(vital_dir.glob("*.vital")):
    try:
        df = read_one(f, track_keep)
        df_numeric = df.apply(pd.to_numeric, errors="coerce")  # strings -> NaN
        all_cases.append(df_numeric.to_numpy(dtype=np.float32))   # shape (sequence_len, n_features)
    except Exception as e:
        print(f"skip {f.name}: {e}")
        bad_files.append(f.name)

dataset = np.stack(all_cases, axis=0)               # ==> (n_samples, n_steps, n_features)
print("Dataset shape:", dataset.shape, "  (skipped", len(bad_files), "files)")

['SNUADC/ART ', 'SNUADC/ECG_II', 'SNUADC/ECG_V5 ', 'SNUADC/PLETH', 'Primus/CO2', 'BIS/EEG1_WAV', 'BIS/EEG2_WAV']
Dataset shape: (1, 5770575, 7)   (skipped 0 files)


## Train / val / test split

In [18]:
from sklearn.model_selection import train_test_split

X            = dataset.squeeze(0)          # → (5_771_049, 80)
window       = 600                         # 10 minutes if you resample to 1 Hz
stride       = 600                         # non-overlapping; use <window for overlap
segments = [
    X[i : i + window]
    for i in range(0, X.shape[0] - window + 1, stride)
]
segments = np.stack(segments)              # (n_segments, 600, 80)
print("segments shape:", segments.shape)

train_X, test_X = train_test_split(segments,  test_size=0.15, random_state=42)
train_X,  val_X = train_test_split(train_X, test_size=0.15, random_state=42)

print("train", train_X.shape, "val", val_X.shape, "test", test_X.shape)


segments shape: (9617, 600, 7)
train (6947, 600, 7) val (1227, 600, 7) test (1443, 600, 7)


## Add extra synthetic missingness on val set

In [19]:
from pygrinder import mcar, calc_missing_rate

val_X_ori = val_X.copy()             # keep a pristine copy
val_X     = mcar(val_X, p=0.10)   # mask-at-random 10 %

test_X_ori = test_X.copy()           # ditto for the test set
indicating_mask = np.isnan(test_X) ^ np.isnan(test_X_ori)
print(f"Real miss rate train  : {calc_missing_rate(train_X):.1%}")
print(f"Real+fake miss rate val: {calc_missing_rate(val_X):.1%}")

Real miss rate train  : 28.6%
Real+fake miss rate val: 35.7%


## Wrap in the dictionaries SAITS expects

In [20]:
train_set = {"X": train_X}
val_set   = {"X": val_X, "X_ori": val_X_ori}
test_set  = {"X": test_X}

## Instantiate, train, evaluate just like the example

In [21]:
from pypots.imputation import SAITS
from pypots.nn.functional import calc_mae

saits = SAITS(
    n_steps   = train_X.shape[1],
    n_features= train_X.shape[2],
    n_layers  = 2,
    d_model   = 256,
    n_heads   = 4,
    d_k       = 64,
    d_v       = 64,
    d_ffn     = 128,
    dropout   = 0.1,
    epochs    = 20,
    patience  = 5,                  # early-stop patience (optional)
    device    = "cpu"
    #device    = "cuda:0"
)

2025-06-21 12:54:51 [INFO]: Using the given device: cpu
2025-06-21 12:54:51 [INFO]: Using customized MAE as the training loss function.
2025-06-21 12:54:51 [INFO]: Using customized MSE as the validation metric function.
2025-06-21 12:54:51 [INFO]: SAITS initialized with the given hyperparameters, the number of trainable parameters: 1,331,942


In [None]:
saits.fit(train_set, val_set)

# ---- test-time imputation ----
imputation = saits.impute(test_set)               # same shape as test_X
mae        = calc_mae(imputation, np.nan_to_num(test_X_ori), indicating_mask)
print("MAE on held-out values:", mae)

In [23]:
saits.save("models/saits_vitaldb.pypots", overwrite=True)

2025-06-22 13:20:09 [INFO]: Saved the model to models/saits_vitaldb.pypots


In [9]:
!ls

CITATION.cff                  [34mmodels[m[m
conda_env_dependencies.yml    [34mNNI_tuning[m[m
[34mconfigs[m[m                       Paper_SAITS.pdf
[34mdataset_generating_scripts[m[m    README.md
[34mfigs[m[m                          run_models.py
Global_Config.py              [34msave_it_here[m[m
LICENSE                       Simple_example.ipynb
[34mmodeling[m[m                      Simple_RNN_on_imputed_data.py


In [8]:
saits.load("models/saits_vitaldb.pypots")

  loaded_file = torch.load(path, map_location=map_location)
2025-06-20 11:49:36 [INFO]: Model loaded successfully from models/saits_vitaldb.pypots


## IMPUTE THE ORIGINAL DATA  (train + val + test, NaNs only)

In [None]:
# Concatenate all three splits so we fill every real gap in one go
orig_concat = np.concatenate([train_X, val_X_ori, test_X_ori], axis=0)
orig_imputed = saits.impute({"X": orig_concat})        # <-- returns np.ndarray

# You can now split it back if you want
n_train = train_X.shape[0]
n_val   = val_X_ori.shape[0]
imputed_train = orig_imputed[:n_train]
imputed_val_full = orig_imputed[n_train:n_train+n_val]  # val set with real NaNs filled
imputed_test  = orig_imputed[n_train+n_val:]

## IMPUTE THE SYNTHETICALLY MASKED VALIDATION SET

In [24]:
# val_X has *extra* 10 % MCAR holes; we already built val_set = {"X": val_X}
imputed_val_masked = saits.impute(val_set)              # same shape as val_X
# evaluate MAE on those artificial holes
masked_mae = calc_mae(imputed_val_masked, 
                      np.nan_to_num(val_X_ori), 
                      np.isnan(val_X) ^ np.isnan(val_X_ori))
print("MAE on synthetically missing points in val set:", masked_mae)

MAE on synthetically missing points in val set: 4.696758137825921


## SHOW 15 RANDOM IMPUTATIONS vs. GROUND-TRUTH

In [27]:
import random, pandas as pd

feature_names = track_keep                        # your 7 channels in that order
mask_idx = np.where((np.isnan(val_X)) & ~np.isnan(val_X_ori))  # positions you hid
n_show = min(30, mask_idx[0].size)               # show up to 15 rows
rows = random.sample(range(mask_idx[0].size), n_show)

print(mask_idx)
records = []
for k in rows:
    s, t, f = mask_idx[0][k], mask_idx[1][k], mask_idx[2][k]
    records.append({
        "sample#":    s,
        "time_step":  t,
        "channel":    feature_names[f],
        "ground_truth": float(val_X_ori[s, t, f]),
        "imputed":     float(imputed_val_masked[s, t, f]),
        "abs_error":   abs(val_X_ori[s, t, f] - imputed_val_masked[s, t, f]),
    })

comparison_df = pd.DataFrame(records)
print("\n===  SAITS imputation on synthetic holes (random 30)  ===")
print(comparison_df.round(3).to_string(index=False))

(array([   0,    0,    0, ..., 1226, 1226, 1226]), array([  2,   5,   6, ..., 592, 594, 596]), array([4, 6, 3, ..., 5, 1, 1]))

===  SAITS imputation on synthetic holes (random 30)  ===
 sample#  time_step       channel  ground_truth  imputed  abs_error
     138        587  BIS/EEG1_WAV        42.250   35.298   6.952000
    1091        517  SNUADC/PLETH        29.260   30.146   0.886000
     447        549    Primus/CO2        22.275   23.010   0.735000
     276         98  BIS/EEG2_WAV        13.250    9.953   3.297000
     370        592 SNUADC/ECG_II        -0.049    0.000   0.049000
     374        449 SNUADC/ECG_II         0.327    0.016   0.311000
     541        308  BIS/EEG2_WAV        11.700   29.209  17.509001
     147        174  SNUADC/PLETH        35.184   35.505   0.320000
     803        438  BIS/EEG2_WAV        23.400   23.049   0.351000
      43        469  BIS/EEG1_WAV        48.500   53.088   4.588000
     975        438 SNUADC/ECG_II         0.781    0.014   0.76700