In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from argparse import Namespace
from transformers.optimization import AdamW
from evaluator import Evaluator
from tqdm import tqdm
from typing import Tuple
from modeling_gru import GRU_TS_Explained

from timeshap.explainer import local_report, global_report, calc_local_report, plot_local_report,calc_global_explanations,plot_global_report
from timeshap.wrappers import TorchModelWrapper
from utils import parse_args,set_all_seeds,set_output_dir,Logger, save_results
from dataset import Dataset
from timeshap.utils import calc_avg_event
from timeshap.utils import get_avg_score_with_avg_event


In [10]:
parser = parse_args(model_type = 'gru', hid_dim =32,dropout = 0.2, lr = 0.0005, max_epochs = 1, gradient_accumulation_steps = 8)
args, _ = parser.parse_known_args()
set_all_seeds(args.seed+int(args.run.split('o')[0]))
set_output_dir(args)
args.logger = Logger(args.output_dir, 'log.txt')
args.logger.write('\n'+str(args))
dataset = Dataset(args)
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


16/10/2024 14:13:00 >> Namespace(dataset='mimic_iii', train_frac=0.7, run='1o10', model_type='gru', max_obs=880, hid_dim=32, num_layers=2, num_heads=4, dropout=0.2, attention_dropout=0.2, kernel_size=4, r=24, M=12, max_timesteps=880, num_ts_feat=51, num_demo_feat=3, hours_look_ahead=24, ref_points=24, pretrain=0, output_dir='../outputs/mimic_iii/gru,hid_dim:32,dropout:0.2,lr:0.0005,gradient_accumulation_steps:8,max_epochs:1|train_frac:0.7|run:1o10', output_dir_prefix='', seed=2024, max_epochs=1, patience=10, lr=0.0005, train_batch_size=16, gradient_accumulation_steps=8, eval_batch_size=16, print_train_loss_every=100, validate_after=-1, validate_every=None, load_ckpt_path=None, logger=<utils.Logger object at 0x2b5519720>)

16/10/2024 14:13:02 >> Preparing dataset mimic_iii
16/10/2024 14:13:02 >> Removing variables not in training set: ['.3% normal Saline']
16/10/2024 14:13:03 >> # train, val, test TS: [3918, 1120, 800]
16/10/2024 14:13:03 >> pos class weight: 3.934508816120907
16/10/20

12199it [00:00, 1833655.19it/s]


16/10/2024 14:13:03 >> # intervals: 24


941654it [00:00, 1737845.85it/s]


In [11]:
d_ts = dataset.X
d_demo = dataset.demo
demo_expanded = np.expand_dims(d_demo, axis=1)
demo_tiled = np.tile(demo_expanded, (1, 24, 1))
d_all = np.concatenate((d_ts, demo_tiled), axis=2)

num_samples, num_timestamps, num_features = d_all.shape

sample_ids = np.repeat(np.arange(num_samples), num_timestamps)
timestamps = np.tile(np.arange(num_timestamps), num_samples)

# Flatten the data
flattened_data = d_all.reshape(num_samples * num_timestamps, num_features)

# Combine into a DataFrame
d_all_transformed = pd.DataFrame(flattened_data, columns=[f'feature_{i}' for i in range(num_features)])
d_all_transformed['sample_id'] = sample_ids
d_all_transformed['timestamp'] = timestamps

# Reorder columns to have 'sample_id' and 'timestamp' first
columns = ['sample_id', 'timestamp'] + [f'feature_{i}' for i in range(num_features)]
d_all_transformed = d_all_transformed[columns]

# Rename features
variable_names = list(dataset.var_to_ind_mapping.keys())
feat_names = [v+'_value' for v in variable_names]
feat_names += [v+'_obs' for v in variable_names]
feat_names += [v+'_delta' for v in variable_names]
feat_names += ['age','gender','height']
d_all_transformed.columns = ['sample_id', 'timestamp'] + feat_names

train_sample_ids = dataset.splits['train']
test_sample_ids = dataset.splits['test']
d_train_transformed = d_all_transformed[d_all_transformed['sample_id'].isin(train_sample_ids)]
d_test_transformed = d_all_transformed[d_all_transformed['sample_id'].isin(test_sample_ids)]

feats_in_scope = [v+'_value' for v in variable_names]
average_event = calc_avg_event(d_train_transformed, numerical_feats=feat_names, categorical_feats=[])
average_event

positive_sequence_id = test_sample_ids[0]
pos_x_pd = d_test_transformed[d_test_transformed['sample_id'] == positive_sequence_id]

# select model features only
pos_x_data = pos_x_pd[feat_names]
# convert the instance to numpy so TimeSHAP receives it
pos_x_data = np.expand_dims(pos_x_data.to_numpy().copy(), axis=0)
print("pos x data shape: ", pos_x_data.shape)

pos x data shape:  (1, 24, 156)


In [12]:
model = GRU_TS_Explained(args)
model.to(args.device)
results = {'epoch':[],'train_auroc':[],'val_auroc':[],'test_auroc':[]}

#training
num_train = len(dataset.splits['train'])
args.logger.write('\nSize of training data: ' + str(num_train))
num_batches_per_epoch = num_train/args.train_batch_size
args.logger.write('\nNo. of training batches per epoch: '
                    +str(num_batches_per_epoch))
args.max_steps = int(round(num_batches_per_epoch)*args.max_epochs)
args.logger.write('\nMax steps: ' + str(num_batches_per_epoch))

if args.validate_every is None:
    args.validate_every = int(np.ceil(num_batches_per_epoch)) #validate after each batch

num_steps = 0
optimizer = AdamW(filter(lambda p:p.requires_grad, model.parameters()), lr=args.lr)
train_bar = tqdm(range(args.max_steps))
evaluator = Evaluator(args)

#before training, calculate metrics
if args.validate_after<0:
    res_val = evaluator.evaluate_gru_explained(model, dataset, 'val',  train_step=-1)
    res_train = evaluator.evaluate_gru_explained(model, dataset, 'eval_train', train_step=-1)
    res_test = evaluator.evaluate_gru_explained(model, dataset, 'test', train_step=-1)
    results['epoch'].append(0)
    results['train_auroc'].append(res_train['auroc'])
    results['val_auroc'].append(res_val['auroc'])
    results['test_auroc'].append(res_test['auroc'])

#training
model.train()
for step in train_bar:
    data_batch = dataset.get_batch()
    data_batch = {k:v.to(args.device) for k,v in data_batch.items()}
    tensor_a = data_batch['ts']  # torch.Size([16, 24, 150])
    tensor_b = data_batch['demo']      # torch.Size([16, 2])
    # Expand tensor_b to match the dimensions of tensor_a (i.e., [16, 24, 2])
    tensor_b_expanded = tensor_b.unsqueeze(1).expand(-1, 24, -1)  # torch.Size([16, 24, 2])
    # Concatenate tensor_b_expanded along the last dimension with tensor_a
    data = torch.cat((tensor_a, tensor_b_expanded), dim=2)  # torch.Size([16, 24, 152])
    
    pred,logits = model(data)
    loss = model.binary_cls_final(logits,data_batch['labels'])

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(),0.3)
    if (step+1)%args.gradient_accumulation_steps==0:
        optimizer.step()
        optimizer.zero_grad()

    num_steps += 1

    # run validatation
    if (num_steps>=args.validate_after) and (num_steps%args.validate_every==0):
        # get metrics on test and validation splits
        res_val = evaluator.evaluate_gru_explained(model, dataset, 'val', train_step=step)
        res_train = evaluator.evaluate_gru_explained(model, dataset, 'eval_train', train_step=step)
        res_test = evaluator.evaluate_gru_explained(model, dataset, 'test', train_step=step)
        results['train_auroc'].append(res_train['auroc'])
        results['val_auroc'].append(res_val['auroc'])
        results['test_auroc'].append(res_test['auroc'])
        model.train(True)

dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1



16/10/2024 14:13:44 >> Size of training data: 3918

16/10/2024 14:13:44 >> No. of training batches per epoch: 244.875

16/10/2024 14:13:44 >> Max steps: 244.875





16/10/2024 14:13:44 >> Evaluating on split = val



[A
running forward pass: 100%|██████████| 70/70 [00:00<00:00, 610.48it/s]


16/10/2024 14:13:44 >> Result on val split at train step -1: AUROC: 0.4806245249875427: AUPRC: 0.18891448277186873: MINRP: 0.20521739130434782

16/10/2024 14:13:44 >> Evaluating on split = eval_train



[A
running forward pass: 100%|██████████| 125/125 [00:00<00:00, 1171.54it/s]


16/10/2024 14:13:45 >> Result on eval_train split at train step -1: AUROC: 0.46951989845663594: AUPRC: 0.20252150126540075: MINRP: 0.214321482223335

16/10/2024 14:13:45 >> Evaluating on split = test



running forward pass: 100%|██████████| 50/50 [00:00<00:00, 1218.44it/s]


16/10/2024 14:13:45 >> Result on test split at train step -1: AUROC: 0.45888809353376286: AUPRC: 0.18916672688752756: MINRP: 0.20898876404494382





16/10/2024 14:13:45 >> Evaluating on split = val



running forward pass: 100%|██████████| 70/70 [00:00<00:00, 1245.16it/s]


16/10/2024 14:13:45 >> Result on val split at train step 244: AUROC: 0.5766940642946663: AUPRC: 0.248782793335385: MINRP: 0.26582278481012656

16/10/2024 14:13:45 >> Evaluating on split = eval_train



running forward pass: 100%|██████████| 125/125 [00:00<00:00, 1311.26it/s]


16/10/2024 14:13:45 >> Result on eval_train split at train step 244: AUROC: 0.5792579249007158: AUPRC: 0.26009096995161046: MINRP: 0.28738317757009346

16/10/2024 14:13:45 >> Evaluating on split = test



running forward pass: 100%|██████████| 50/50 [00:00<00:00, 1273.27it/s]
100%|██████████| 245/245 [00:01<00:00, 211.58it/s]

16/10/2024 14:13:46 >> Result on test split at train step 244: AUROC: 0.5467955141970889: AUPRC: 0.23145497221297126: MINRP: 0.24848484848484848





In [16]:
from timeshap.explainer import prune_all, pruning_statistics, event_explain_all, feat_explain_all
from timeshap.plot import plot_global_event, plot_global_feat

In [18]:
model_wrapped = TorchModelWrapper(model)
f_hs = lambda ts, y=None: model_wrapped.predict_last_hs(ts, y)
schema = list(d_all_transformed.columns)

pruning_dict = {'tol': [0.00001]}
prun_indexes = prune_all(f_hs, pos_x_pd, pruning_dict, average_event, feat_names, schema, 'sample_id', 'timestamp')
pruning_stats = pruning_statistics(prun_indexes, pruning_dict.get('tol'))
pruning_stats

Unnamed: 0,Tolerance,Mean,Std
0,0.00001,24.0,
1,No Pruning,24.0,


In [None]:
plot_features = {k:k for k in feat_names}
feature_dict = {'path': '/content/drive/My Drive/Colab Notebooks/code/outputs/explainable/feat_gru_2309.csv','rs': [42], 'nsamples': [1600], 'feature_names': feat_names, 'plot_features': plot_features,}

feat_data = feat_explain_all(f_hs, d_all_transformed, feature_dict, prun_indexes, average_event, feat_names, schema, 'sample_id', 'timestamp')