In [1]:
import torch
from Models.Baseline_Methods.last_imputation import last_imputation
from Models.Baseline_Methods.mean_imputation import mean_imputation
from Models.Baseline_Methods.median_imputation import median_imputation
from Models.Baseline_Methods.zero_imputation import zero_imputation
from Utils.metric_service import evaluate_metrics
from Models.BRITS.delta_service import make_delta_simplified
from Models.SAITS.saits_utils import saits_evaluate
from Models.BRITS.brits_utils import brits_evaluate
from Models.MRNN.mrnn_utils import mrnn_evaluate
import statistics

In [2]:
# Use the first CUDA GPU when one is visible, otherwise stay on the CPU.
# Change "cuda:0" to another gpu if needed.
device = "cuda:0" if torch.cuda.is_available() else 'cpu'

In [3]:
def metric_list_to_eval(metric_list):
    """Summarise per-run imputation metrics as ``'<mean> +/- <std>'`` strings.

    Parameters
    ----------
    metric_list : sequence of dict
        One dict per evaluated run. Every dict must provide the keys
        ``'imputation mae'``, ``'imputation rmse'`` and ``'imputation mre'``;
        each value must expose ``.item()`` returning a Python number.

    Returns
    -------
    dict
        Keys ``'mae'``, ``'rmse'`` and ``'mre'``, each mapped to the string
        ``'<mean> +/- <sample standard deviation>'`` over all runs.

    Raises
    ------
    statistics.StatisticsError
        If ``metric_list`` is empty (there is no mean to report).

    Notes
    -----
    ``statistics.stdev`` needs at least two data points. When only a single
    run is supplied the spread is reported as ``0.0`` instead of raising.
    """
    def _summarise(key):
        # Pull the scalar out of every run once and reuse it for both statistics.
        values = [run_metrics[key].item() for run_metrics in metric_list]
        # A single run has no sample spread; avoid StatisticsError from stdev.
        spread = statistics.stdev(values) if len(values) > 1 else 0.0
        return '{mean} +/- {std}'.format(mean=statistics.mean(values), std=spread)

    return {
        'mae': _summarise('imputation mae'),
        'rmse': _summarise('imputation rmse'),
        'mre': _summarise('imputation mre'),
    }

# ALL ICU DATA

In [4]:
# Load the all-ICU dataset dictionary and move the tensors used below to `device`.
all_icu_dic = torch.load('all_feat_all_icu_compatible_unique')
# Feature count handed to the mean / median baseline imputers.
num_features = 10
# The validation entries ('validation_set', 'missing_mask_validation',
# 'indicating_mask_validation') were commented out in the original cell and
# stay unloaded here.
(
    train_set_all_icu,
    train_missing_mask_all_icu,
    test_set_all_icu,
    test_missing_mask_all_icu,
    test_indicating_mask_all_icu,
) = (
    all_icu_dic[key].to(device)
    for key in (
        'train_set',
        'missing_mask_train',
        'test_set',
        'missing_mask_test',
        'indicating_mask_test',
    )
)
# The deltas are built from the mask as stored in the dictionary (before any
# device move); only the result is moved to `device`.
test_deltas = make_delta_simplified(all_icu_dic['missing_mask_test']).to(device)

In [5]:
# baseline methods
# Build the four baseline imputers; only mean and median get a train() call.
model_last = last_imputation()
model_mean = mean_imputation(num_features)
model_median = median_imputation(num_features)
model_zero = zero_imputation()
# Fit on the training values with the missing-mask applied (masked entries
# are multiplied by the mask before being handed to train()).
for fitted_baseline in (model_mean, model_median):
    fitted_baseline.train(
        train_set_all_icu * train_missing_mask_all_icu,
        train_missing_mask_all_icu,
    )

In [6]:
# Run every baseline on the masked all-ICU test set; each call receives the
# test values multiplied by the missing-mask plus the mask itself.
imputation_last, imputation_mean, imputation_median, imputation_zero = (
    baseline.impute(test_set_all_icu * test_missing_mask_all_icu, test_missing_mask_all_icu)
    for baseline in (model_last, model_mean, model_median, model_zero)
)

In [7]:
# Score each baseline imputation against the all-ICU test set, using the
# indicating mask passed to evaluate_metrics.
eval_dic_all_icu = {
    method: evaluate_metrics(imputed, test_set_all_icu, test_indicating_mask_all_icu)
    for method, imputed in (
        ('last', imputation_last),
        ('mean', imputation_mean),
        ('median', imputation_median),
        ('zero', imputation_zero),
    )
}

In [8]:
# Display the all-ICU baseline metrics collected so far.
eval_dic_all_icu

{'last': {'rmse': tensor(0.5219, device='cuda:0'),
  'mae': tensor(0.1690, device='cuda:0'),
  'mre': tensor(0.5690, device='cuda:0')},
 'mean': {'rmse': tensor(0.6580, device='cuda:0'),
  'mae': tensor(0.2972, device='cuda:0'),
  'mre': tensor(1.0005, device='cuda:0')},
 'median': {'rmse': tensor(0.6595, device='cuda:0'),
  'mae': tensor(0.2944, device='cuda:0'),
  'mre': tensor(0.9913, device='cuda:0')},
 'zero': {'rmse': tensor(0.6580, device='cuda:0'),
  'mae': tensor(0.2970, device='cuda:0'),
  'mre': tensor(1., device='cuda:0')}}

In [9]:
# Load the saved runs for each model family. Later cells iterate over these
# objects and read the 'model' entry of every element.
# map_location=device makes the tensors inside the checkpoints land on the
# device chosen above, so the CPU fallback works for checkpoints saved on a GPU
# and models that are never moved with .to(device) still match the test data.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load files
# from a trusted source. On PyTorch >= 2.6 the default weights_only=True
# presumably rejects pickled model objects; pass weights_only=False there.
saits_all_icu_list = torch.load('saits_all_feat_all_icu_list', map_location=device)
brits_all_icu_list = torch.load('brits_all_feat_all_icu_list', map_location=device)
mrnn_all_icu_list = torch.load('mrnn_all_feat_all_icu_list', map_location=device)

In [10]:
# Evaluate every saved run of each model family on the all-ICU test split.
saits_all_icu_metric_list = [
    saits_evaluate(
        run['model'].to(device),
        test_set_all_icu,
        test_missing_mask_all_icu,
        test_indicating_mask_all_icu,
    )
    for run in saits_all_icu_list
]
# BRITS additionally receives the time deltas derived from the missing mask.
brits_all_icu_metric_list = [
    brits_evaluate(
        run['model'].to(device),
        test_set_all_icu,
        test_missing_mask_all_icu,
        test_indicating_mask_all_icu,
        test_deltas,
    )
    for run in brits_all_icu_list
]
# MRNN models are passed exactly as loaded (no .to(device) call here).
mrnn_all_icu_metric_list = [
    mrnn_evaluate(
        run['model'],
        test_set_all_icu,
        test_missing_mask_all_icu,
        test_indicating_mask_all_icu,
        test_deltas,
    )
    for run in mrnn_all_icu_list
]



In [11]:
# Add the mean +/- std summary of each model family to the all-ICU results.
eval_dic_all_icu.update(
    (label, metric_list_to_eval(run_metrics))
    for label, run_metrics in (
        ('SAITS', saits_all_icu_metric_list),
        ('BRITS', brits_all_icu_metric_list),
        ('MRNN', mrnn_all_icu_metric_list),
    )
)

In [12]:
# Display the all-ICU results: baseline metrics plus the SAITS / BRITS / MRNN summaries.
eval_dic_all_icu

{'last': {'rmse': tensor(0.5219, device='cuda:0'),
  'mae': tensor(0.1690, device='cuda:0'),
  'mre': tensor(0.5690, device='cuda:0')},
 'mean': {'rmse': tensor(0.6580, device='cuda:0'),
  'mae': tensor(0.2972, device='cuda:0'),
  'mre': tensor(1.0005, device='cuda:0')},
 'median': {'rmse': tensor(0.6595, device='cuda:0'),
  'mae': tensor(0.2944, device='cuda:0'),
  'mre': tensor(0.9913, device='cuda:0')},
 'zero': {'rmse': tensor(0.6580, device='cuda:0'),
  'mae': tensor(0.2970, device='cuda:0'),
  'mre': tensor(1., device='cuda:0')},
 'SAITS': {'mae': '0.1317450985312462 +/- 0.0007521993712309375',
  'rmse': '0.45556535124778746 +/- 0.0037330136215141447',
  'mre': '0.44357359409332275 +/- 0.0025325938230718677'},
 'BRITS': {'mae': '0.13817179650068284 +/- 0.0007273374806953895',
  'rmse': '0.4558850258588791 +/- 0.001058615201040269',
  'mre': '0.4652116984128952 +/- 0.0024488801004228264'},
 'MRNN': {'mae': '0.1536129578948021 +/- 0.004476994229747861',
  'rmse': '0.482818284630775

# Heart Data

In [13]:
# Load the heart-only dataset dictionary and move the tensors used below to `device`.
heart_only_dic = torch.load('all_feat_heart_only_compatible_unique')
# Feature count handed to the mean / median baseline imputers.
num_features = 10
# No validation tensors are loaded for the heart-only data; the original cell
# only carried commented-out validation loads copied from the all-ICU section.
(
    train_set_heart_only,
    train_missing_mask_heart_only,
    test_set_heart_only,
    test_missing_mask_heart_only,
    test_indicating_mask_heart_only,
) = (
    heart_only_dic[key].to(device)
    for key in (
        'train_set',
        'missing_mask_train',
        'test_set',
        'missing_mask_test',
        'indicating_mask_test',
    )
)
# The deltas are built from the mask as stored in the dictionary (before any
# device move); only the result is moved to `device`.
test_deltas_heart_only = make_delta_simplified(heart_only_dic['missing_mask_test']).to(device)

In [14]:
# baseline methods
model_last = last_imputation()
model_mean_heart = mean_imputation(num_features)
model_mean.train(train_set_heart_only*train_missing_mask_heart_only,train_missing_mask_heart_only)
model_median_heart = median_imputation(num_features)
model_median.train(train_set_heart_only*train_missing_mask_heart_only,train_missing_mask_heart_only)
model_zero = zero_imputation()

In [15]:
imputation_last_heart = model_last.impute(test_set_heart_only*test_missing_mask_heart_only,test_missing_mask_heart_only)
imputation_mean_heart = model_mean.impute(test_set_heart_only*test_missing_mask_heart_only,test_missing_mask_heart_only)
imputation_median_heart = model_median.impute(test_set_heart_only*test_missing_mask_heart_only,test_missing_mask_heart_only)
imputation_zero_heart = model_zero.impute(test_set_heart_only*test_missing_mask_heart_only,test_missing_mask_heart_only)

In [16]:
# Score each baseline imputation against the heart-only test set, using the
# indicating mask passed to evaluate_metrics.
eval_dic_heart_only = {
    method: evaluate_metrics(imputed, test_set_heart_only, test_indicating_mask_heart_only)
    for method, imputed in (
        ('last', imputation_last_heart),
        ('mean', imputation_mean_heart),
        ('median', imputation_median_heart),
        ('zero', imputation_zero_heart),
    )
}

In [17]:
# Load the saved runs trained on heart-only data and the transfer runs. Later
# cells iterate over these objects and read the 'model' entry of every element.
# map_location=device makes the tensors inside the checkpoints land on the
# device chosen above, so the CPU fallback works for checkpoints saved on a GPU
# and models that are never moved with .to(device) still match the test data.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load files
# from a trusted source. On PyTorch >= 2.6 the default weights_only=True
# presumably rejects pickled model objects; pass weights_only=False there.
saits_heart_only_list = torch.load('saits_all_feat_heart_only_list', map_location=device)
brits_heart_only_list = torch.load('brits_all_feat_heart_only_list', map_location=device)
mrnn_heart_only_list = torch.load('mrnn_all_feat_heart_only_list', map_location=device)

saits_transfer_list = torch.load('saits_all_feat_transfer_list', map_location=device)
brits_transfer_list = torch.load('brits_all_feat_transfer_list', map_location=device)
mrnn_transfer_list = torch.load('mrnn_all_feat_transfer_list', map_location=device)

In [18]:
# Evaluate every model family on the heart-only test split.
#
# SAITS and BRITS models are moved to `device` in every group, matching what
# the "trained on heart" loops already did. Previously the transfer and
# all-ICU-trained SAITS/BRITS models were used as loaded, so this cell only
# worked when the checkpoints happened to sit on the right device or an earlier
# cell had already moved the all-ICU models.
# NOTE(review): MRNN models are passed as loaded — the notebook never calls
# .to(device) on them, so it is unclear whether they support it; confirm.

# --- models trained on heart-only data ---
saits_heart_only_metric_list = [
    saits_evaluate(run['model'].to(device), test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only)
    for run in saits_heart_only_list
]

brits_heart_only_metric_list = [
    brits_evaluate(run['model'].to(device), test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only, test_deltas_heart_only)
    for run in brits_heart_only_list
]

mrnn_heart_only_metric_list = [
    mrnn_evaluate(run['model'], test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only, test_deltas_heart_only)
    for run in mrnn_heart_only_list
]

# --- transfer models ---
saits_transfer_metric_list = [
    saits_evaluate(run['model'].to(device), test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only)
    for run in saits_transfer_list
]

brits_transfer_metric_list = [
    brits_evaluate(run['model'].to(device), test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only, test_deltas_heart_only)
    for run in brits_transfer_list
]

mrnn_transfer_metric_list = [
    mrnn_evaluate(run['model'], test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only, test_deltas_heart_only)
    for run in mrnn_transfer_list
]

# --- models trained on all-ICU data, evaluated on heart-only data ---
saits_all_train_heart_eval_metric_list = [
    saits_evaluate(run['model'].to(device), test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only)
    for run in saits_all_icu_list
]

brits_all_train_heart_eval_metric_list = [
    brits_evaluate(run['model'].to(device), test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only, test_deltas_heart_only)
    for run in brits_all_icu_list
]

mrnn_all_train_heart_eval_metric_list = [
    mrnn_evaluate(run['model'], test_set_heart_only, test_missing_mask_heart_only, test_indicating_mask_heart_only, test_deltas_heart_only)
    for run in mrnn_all_icu_list
]

In [19]:
# Add the mean +/- std summary of every evaluated model group to the
# heart-only results, in the same order as before.
eval_dic_heart_only.update(
    (label, metric_list_to_eval(run_metrics))
    for label, run_metrics in (
        ('SAITS trained on heart', saits_heart_only_metric_list),
        ('BRITS trained on heart', brits_heart_only_metric_list),
        ('MRNN trained on heart', mrnn_heart_only_metric_list),
        ('SAITS transfer', saits_transfer_metric_list),
        ('BRITS transfer', brits_transfer_metric_list),
        ('MRNN transfer', mrnn_transfer_metric_list),
        ('SAITS trained on ALL', saits_all_train_heart_eval_metric_list),
        ('BRITS trained on ALL', brits_all_train_heart_eval_metric_list),
        ('MRNN trained on ALL', mrnn_all_train_heart_eval_metric_list),
    )
)

In [20]:
# Display the heart-only results: baseline metrics plus the per-model summaries.
eval_dic_heart_only

{'last': {'rmse': tensor(0.3859, device='cuda:0'),
  'mae': tensor(0.1662, device='cuda:0'),
  'mre': tensor(0.5673, device='cuda:0')},
 'mean': {'rmse': tensor(0.5566, device='cuda:0'),
  'mae': tensor(0.2931, device='cuda:0'),
  'mre': tensor(1.0007, device='cuda:0')},
 'median': {'rmse': tensor(0.5588, device='cuda:0'),
  'mae': tensor(0.2907, device='cuda:0'),
  'mre': tensor(0.9926, device='cuda:0')},
 'zero': {'rmse': tensor(0.5565, device='cuda:0'),
  'mae': tensor(0.2929, device='cuda:0'),
  'mre': tensor(1., device='cuda:0')},
 'SAITS trained on heart': {'mae': '0.13596890419721602 +/- 0.0015334600416566575',
  'rmse': '0.3106510162353516 +/- 0.0034608727334073117',
  'mre': '0.46422640681266786 +/- 0.005235557400733331'},
 'BRITS trained on heart': {'mae': '0.1379660561680794 +/- 0.0006379032548373497',
  'rmse': '0.3097578167915344 +/- 0.0012445385046467704',
  'mre': '0.4710451036691666 +/- 0.0021779366661780703'},
 'MRNN trained on heart': {'mae': '0.16008420586585997 +/- 

In [21]:
# One line per method, in the order mae / mre / rmse announced by the header.
print('Heart Only: mae/mre/rmse')
for method, scores in eval_dic_heart_only.items():
    print(method, scores['mae'], '/', scores['mre'], '/', scores['rmse'])

Heart Only: mae/mre/rmse
last tensor(0.1662, device='cuda:0') / tensor(0.5673, device='cuda:0') / tensor(0.3859, device='cuda:0')
mean tensor(0.2931, device='cuda:0') / tensor(1.0007, device='cuda:0') / tensor(0.5566, device='cuda:0')
median tensor(0.2907, device='cuda:0') / tensor(0.9926, device='cuda:0') / tensor(0.5588, device='cuda:0')
zero tensor(0.2929, device='cuda:0') / tensor(1., device='cuda:0') / tensor(0.5565, device='cuda:0')
SAITS trained on heart 0.13596890419721602 +/- 0.0015334600416566575 / 0.46422640681266786 +/- 0.005235557400733331 / 0.3106510162353516 +/- 0.0034608727334073117
BRITS trained on heart 0.1379660561680794 +/- 0.0006379032548373497 / 0.4710451036691666 +/- 0.0021779366661780703 / 0.3097578167915344 +/- 0.0012445385046467704
MRNN trained on heart 0.16008420586585997 +/- 0.008459495103872592 / 0.5465610980987549 +/- 0.028882496932761972 / 0.34419241547584534 +/- 0.011560419290776445
SAITS transfer 0.12501098662614823 +/- 9.153766706572154e-05 / 0.4268137

In [22]:
# One line per method, in the order mae / mre / rmse announced by the header.
print('All ICU: mae/mre/rmse')
for method, scores in eval_dic_all_icu.items():
    print(method, scores['mae'], '/', scores['mre'], '/', scores['rmse'])

All ICU: mae/mre/rmse
last tensor(0.1690, device='cuda:0') / tensor(0.5690, device='cuda:0') / tensor(0.5219, device='cuda:0')
mean tensor(0.2972, device='cuda:0') / tensor(1.0005, device='cuda:0') / tensor(0.6580, device='cuda:0')
median tensor(0.2944, device='cuda:0') / tensor(0.9913, device='cuda:0') / tensor(0.6595, device='cuda:0')
zero tensor(0.2970, device='cuda:0') / tensor(1., device='cuda:0') / tensor(0.6580, device='cuda:0')
SAITS 0.1317450985312462 +/- 0.0007521993712309375 / 0.44357359409332275 +/- 0.0025325938230718677 / 0.45556535124778746 +/- 0.0037330136215141447
BRITS 0.13817179650068284 +/- 0.0007273374806953895 / 0.4652116984128952 +/- 0.0024488801004228264 / 0.4558850258588791 +/- 0.001058615201040269
MRNN 0.1536129578948021 +/- 0.004476994229747861 / 0.5172006726264954 +/- 0.015073633915467875 / 0.4828182846307755 +/- 0.006551771879296376
