In [None]:
import torch
from eicu_reader import eICUReader
from mimic_reader import MIMICReader

from torch.optim import Adam
from dtsc_caff_model import TempSepConv_CAFF

from sklearn import metrics
import pandas as pd
import numpy as np

import os
import random


In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration.
# Everything a reader may want to tune lives in `config`; the dict is also
# handed to the model constructor further down in the notebook.
# ---------------------------------------------------------------------------
config = {}
config['dataset'] = 'eICU'  # 'MIMIC'
config['task'] = 'LoS'  # 'mortality'
config['diagnosis_size'] = 64
config['loss'] = 'msle'
config['last_linear_size'] = 32

if config['dataset'] == 'eICU':
    config['no_diag'] = False
    config['main_dropout_rate'] = 0.45
    # Both tasks ('LoS' and 'mortality') train for the same number of epochs on eICU.
    config['n_epochs'] = 6
    config['batch_size'] = 16
    config['batch_size_test'] = 8
    config['n_layers'] = 11
    config['kernel_size'] = 4
    config['temp_kernels'] = [12] * config['n_layers']
    config['learning_rate'] = 0.002
    config['temp_dropout_rate'] = 0.05

elif config['dataset'] == 'MIMIC':
    config['no_diag'] = True
    config['main_dropout_rate'] = 0
    # Compare strings by value: `is not` tests object identity, which only
    # works by accident of interning and raises a SyntaxWarning on Python 3.8+.
    config['n_epochs'] = 6 if config['task'] == 'mortality' else 10
    config['batch_size'] = 8
    config['batch_size_test'] = 8
    config['n_layers'] = 8
    config['kernel_size'] = 5
    config['learning_rate'] = 0.002
    config['temp_dropout_rate'] = 0.05
    config['temp_kernels'] = [11] * config['n_layers']

else:
    # Fail here with a clear message rather than with a KeyError on a
    # dataset-specific setting several cells later.
    raise ValueError("Unknown dataset {!r}; expected 'eICU' or 'MIMIC'".format(config['dataset']))

config['model_ckpt_path'] = './mdl_checkpoints/'

config['mdl_name'] = 'DTSC-CAFF'
config['window'] = 335  # max 15 days prediction window for LoS prediction
config['L2_regularisation'] = 0.0001
# Single source of truth for this flag (it used to be set to False near the top
# and then overwritten here, so True was always the effective value).
config['sum_losses'] = True


# NOTE(review): the GPU index is hard-coded; adjust it for the machine at hand.
if torch.cuda.is_available():
    device = torch.device('cuda:3')
else:
    device = torch.device('cpu')

# Pick the reader class and data directory that match the selected dataset.
if config['dataset'] == 'MIMIC':
    datareader = MIMICReader
    data_path = "./data/MIMIC_data/"
else:
    datareader = eICUReader
    data_path = "./data/eICU_data/"

### Data Loaders

In [None]:
# One reader per data split, built in train -> val -> test order. All three share
# the same maximum sequence length (config['window']) and target device.
train_datareader, val_datareader, test_datareader = (
    datareader(data_path + split, max_len=config['window'], device=device)
    for split in ('train', 'val', 'test')
)

### Model 

In [None]:
# Input feature counts are read from the training reader.
# ('no_daig_features' is the spelling used by the reader and the model; keep it.)
feature_sizes = {
    'no_ts_features': train_datareader.no_ts_features,
    'no_daig_features': train_datareader.no_daig_features,
    'no_flat_features': train_datareader.no_flat_features,
}
model = TempSepConv_CAFF(config=config, **feature_sizes).to(device=device)

# Adam with L2 regularisation applied through weight decay.
optimiser = Adam(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['L2_regularisation'],
)

### Metrics

In [None]:
class CustomBins:
    """Half-open [low, high) intervals used to bucket length-of-stay values.

    The values are presumably in days (TODO confirm): one bin for anything
    below 1, unit-width bins from 1 up to 8, one bin for 8-14 and one for
    14 and above.
    """
    # Large finite stand-in for infinity, used as the outermost bin edges.
    inf = 1e18
    # Ordered list of (low, high) tuples; a value x belongs to a bin when low <= x < high.
    bins = [(-inf, 1)] + [(low, low + 1) for low in range(1, 8)] + [(8, 14), (14, inf)]
    # Number of bins defined above.
    nbins = len(bins)

def get_bin_custom(x, nbins, one_hot=False):
    """Return the index of the first ``CustomBins`` interval that contains ``x``.

    Args:
        x: scalar value to bucket.
        nbins: how many bins (from the start of ``CustomBins.bins``) to search.
        one_hot: when true, return a one-hot vector of length
            ``CustomBins.nbins`` instead of the integer index.

    Returns:
        The bin index (or its one-hot encoding), or ``None`` when no searched
        bin satisfies ``low <= x < high`` (for example when ``x`` is NaN).
    """
    for idx in range(nbins):
        lower, upper = CustomBins.bins[idx]
        if not (lower <= x < upper):
            continue
        if not one_hot:
            return idx
        encoded = np.zeros((CustomBins.nbins,))
        encoded[idx] = 1
        return encoded
    return None

def mean_absolute_percentage_error(y_true, y_pred):
    """Mean absolute percentage error with a floored denominator.

    The denominator is ``max(4/24, y_true)`` element-wise, which keeps the
    percentage from exploding when a true value happens to be very small.

    Returns:
        The mean of ``|y_true - y_pred| / max(4/24, y_true)`` times 100.
    """
    floored_truth = np.maximum(4/24, y_true)
    relative_error = np.abs((y_true - y_pred) / floored_truth)
    return np.mean(relative_error) * 100

def mean_squared_logarithmic_error(y_true, y_pred):
    """Mean of the squared log-ratio ``log(y_true / y_pred)``.

    Note that this uses the plain ratio, with no ``+ 1`` offset, so zero
    values in either argument produce infinities or NaNs.
    """
    log_ratio = np.log(y_true / y_pred)
    return np.mean(np.square(log_ratio))

def root_mean_squared_logarithmic_error(y_true, y_pred):
    """Root mean squared logarithmic error using the ``log(y + 1)`` form.

    Returns:
        ``sqrt(mean((log(y_true + 1) - log(y_pred + 1)) ** 2))``.
    """
    log_diff = np.log(y_true + 1) - np.log(y_pred + 1)
    return np.sqrt(np.mean(np.square(log_diff)))

def print_metrics_regression(y_true, predictions, verbose=1):
    """Compute (and optionally print) length-of-stay regression metrics.

    Args:
        y_true: array of true values.
        predictions: array of predicted values, aligned with ``y_true``.
        verbose: when truthy, print the confusion matrix and every metric.

    Returns:
        ``[mad, mse, rmse, mape, msle, rmsle, r2, kappa]``, where ``kappa`` is
        the linearly weighted Cohen kappa between the binned true values and
        the binned predictions.
    """
    print('==> Length of Stay:')
    y_true_bins = [get_bin_custom(x, CustomBins.nbins) for x in y_true]
    prediction_bins = [get_bin_custom(x, CustomBins.nbins) for x in predictions]

    # Pass the full, ordered label set explicitly. Without it scikit-learn keeps
    # only the bins that actually occur, which changes the matrix shape and
    # shrinks the ordinal distance between bins used by the linear kappa weights.
    bin_labels = list(range(CustomBins.nbins))
    cf = metrics.confusion_matrix(y_true_bins, prediction_bins, labels=bin_labels)
    if verbose:
        print('Custom bins confusion matrix:')
        print(cf)

    kappa = metrics.cohen_kappa_score(y_true_bins, prediction_bins, labels=bin_labels, weights='linear')
    mad = metrics.mean_absolute_error(y_true, predictions)
    mse = metrics.mean_squared_error(y_true, predictions)
    rmse = np.sqrt(mse)  # reuse the MSE instead of computing it a second time
    mape = mean_absolute_percentage_error(y_true, predictions)
    msle = mean_squared_logarithmic_error(y_true, predictions)
    rmsle = root_mean_squared_logarithmic_error(y_true, predictions)
    r2 = metrics.r2_score(y_true, predictions)

    if verbose:
        print('Mean absolute deviation (MAD) = {}'.format(mad))
        print('Mean squared error (MSE) = {}'.format(mse))
        print('Root Mean squared error (RMSE) = {}'.format(rmse))
        print('Mean absolute percentage error (MAPE) = {}'.format(mape))
        print('Mean squared logarithmic error (MSLE) = {}'.format(msle))
        print('Root Mean squared logarithmic error (RMSLE) = {}'.format(rmsle))
        print('R^2 Score = {}'.format(r2))
        print('Cohen kappa score = {}'.format(kappa))

    return [mad, mse, rmse, mape, msle, rmsle, r2, kappa]

def print_metrics_mortality(y_true, prediction_probs, verbose=1):
    """Compute (and optionally print) binary mortality classification metrics.

    Args:
        y_true: binary labels (class 1 is reported as "Died", class 0 as "Survived").
        prediction_probs: predicted probability of class 1 for each sample.
        verbose: when truthy, print the confusion matrix and every metric.

    Returns:
        ``[acc, prec0, prec1, rec0, rec1, auroc, auprc, f1macro]``.
    """
    print('==> Mortality:')
    positive_probs = np.array(prediction_probs)
    # Two columns per sample: [P(class 0), P(class 1)].
    prediction_probs = np.stack([1 - positive_probs, positive_probs]).T
    predictions = prediction_probs.argmax(axis=1)
    cf = metrics.confusion_matrix(y_true, predictions, labels=range(2))

    if verbose:
        print('Confusion matrix:')
        print(cf)
    cf = cf.astype(np.float32)

    # Rows are true classes, columns are predicted classes.
    true0_pred0, true0_pred1 = cf[0][0], cf[0][1]
    true1_pred0, true1_pred1 = cf[1][0], cf[1][1]

    acc = (true0_pred0 + true1_pred1) / np.sum(cf)
    prec0 = true0_pred0 / (true0_pred0 + true1_pred0)
    prec1 = true1_pred1 / (true1_pred1 + true0_pred1)
    rec0 = true0_pred0 / (true0_pred0 + true0_pred1)
    rec1 = true1_pred1 / (true1_pred1 + true1_pred0)

    class1_scores = prediction_probs[:, 1]
    auroc = metrics.roc_auc_score(y_true, class1_scores)
    precisions, recalls, thresholds = metrics.precision_recall_curve(y_true, class1_scores)
    auprc = metrics.auc(recalls, precisions)
    f1macro = metrics.f1_score(y_true, predictions, average='macro')

    results = {'Accuracy': acc, 'Precision Survived': prec0, 'Precision Died': prec1, 'Recall Survived': rec0,
               'Recall Died': rec1, 'Area Under the Receiver Operating Characteristic curve (AUROC)': auroc,
               'Area Under the Precision Recall curve (AUPRC)': auprc, 'F1 score (macro averaged)': f1macro}
    if verbose:
        for metric_name, metric_value in results.items():
            print('{} = {}'.format(metric_name, metric_value))

    return [acc, prec0, prec1, rec0, rec1, auroc, auprc, f1macro]

In [None]:
def _remove_padding(y, mask, device):
    """
        Filters out padding from tensor of predictions or labels

        Args:
            y: tensor of los predictions or labels
            mask (bool_type): tensor showing which values are padding (0) and which are data (1)
    """
    
    y = y.where(mask, torch.tensor(float('nan')).to(device=device)).flatten().detach().cpu().numpy()
    y = y[~np.isnan(y)]
    return y


### Training and Validation

In [None]:
# Number of full training batches per epoch (reported in the progress messages).
no_train_batches = len(train_datareader.patients) // config['batch_size']
checkpoint_counter = 0
n_epochs = config['n_epochs']

# Best-so-far trackers used for model selection on the validation set.
max_auroc = 0
max_auprc = 0
max_f1macro = 0
# Despite the name this tracks the *lowest* validation MSLE seen so far. Start at
# +inf so the first epoch always produces a "best" checkpoint, whatever the scale
# of the metric.
max_msle = float('inf')
max_r2 = 0
max_kapa = 0

# Index along the time axis at which mortality predictions are evaluated
# (presumably hours since admission -- TODO confirm against the data reader).
mort_pred_time = 18
# Whether evaluation should load the "best" (True) or the "last" (False) checkpoint.
best = True

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(config['model_ckpt_path'], exist_ok=True)
file_name_best = os.path.join(config['model_ckpt_path'],
                              '{}_{}_Best.pth'.format(config['dataset'], config['task']))
file_name_last = os.path.join(config['model_ckpt_path'],
                              '{}_{}_Last.pth'.format(config['dataset'], config['task']))

remove_padding = lambda y, mask: _remove_padding(y, mask, device=device)

# Pick the boolean tensor type from the device *type*, so any CUDA device (not
# just one specific index) gets the CUDA variant.
bool_type = torch.cuda.BoolTensor if device.type == 'cuda' else torch.BoolTensor

In [None]:
def _unpack_batch(batch):
    """Split a reader batch into named fields.

    MIMIC batches have six fields and no diagnoses (``diagnoses`` is returned as
    ``None``); batches from the other dataset have seven.
    """
    if config['dataset'] == 'MIMIC':
        padded, mask, flat, los_labels, mort_labels, seq_lengths = batch
        diagnoses = None
    else:
        padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths = batch
    return padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths


def _save_checkpoint(path, epoch):
    """Save model weights, optimiser state and the epoch index to ``path``."""
    state = {
        'net': model.state_dict(),
        'optimiser': optimiser.state_dict(),
        'epoch': epoch
    }
    torch.save(state, path)


for epoch in range(n_epochs):
    ############################################# Training ##########################################
    model.train()
    train_batches = train_datareader.batch_gen(batch_size=config['batch_size'])
    train_loss = []
    train_y_hat_los = np.array([])
    train_y_los = np.array([])
    train_y_hat_mort = np.array([])
    train_y_mort = np.array([])
    print('Train Epoch#{}'.format(epoch))
    for batch_idx, batch in enumerate(train_batches):
        # Batches with a single sample are skipped.
        if batch[0].size(0) <= 1:
            continue

        padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths = _unpack_batch(batch)

        optimiser.zero_grad()
        y_hat_los, y_hat_mort, w1, w2 = model(padded, diagnoses, flat)

        loss = model.loss(y_hat_los, y_hat_mort, los_labels, mort_labels, mask, seq_lengths, device,
                          config['sum_losses'], config['loss'])
        if batch_idx % 1000 == 0:
            print('Train Loss: {} - batch: {} / Num Batches: {}'.format(loss.item(), batch_idx, no_train_batches))

        loss.backward()
        optimiser.step()
        train_loss.append(loss.item())

        bool_mask = mask.type(bool_type)
        # Use equality: `in ('LoS')` is a substring test on a plain string, not tuple membership.
        if config['task'] == 'LoS':
            train_y_hat_los = np.append(train_y_hat_los, remove_padding(y_hat_los, bool_mask))
            train_y_los = np.append(train_y_los, remove_padding(los_labels, bool_mask))
        # Strictly greater: index `mort_pred_time` only exists when the sequence
        # axis is longer than it (`>=` raised IndexError on exact-length batches).
        if config['task'] == 'mortality' and mort_labels.shape[1] > mort_pred_time:
            train_y_hat_mort = np.append(train_y_hat_mort,
                                         remove_padding(y_hat_mort[:, mort_pred_time],
                                                        bool_mask[:, mort_pred_time]))
            train_y_mort = np.append(train_y_mort,
                                     remove_padding(mort_labels[:, mort_pred_time],
                                                    bool_mask[:, mort_pred_time]))

    print('Train Metrics:')
    # Guard against an epoch in which every batch was skipped.
    mean_train_loss = sum(train_loss) / len(train_loss) if train_loss else float('nan')
    if config['task'] == 'LoS':
        los_metrics_list = print_metrics_regression(train_y_los, train_y_hat_los)

    if config['task'] == 'mortality':
        mort_metrics_list = print_metrics_mortality(train_y_mort, train_y_hat_mort)

    print('Epoch: {} | Train Loss: {:3.4f}'.format(epoch, mean_train_loss))

    ########################################### Validation #########################################

    model.eval()
    val_batches = val_datareader.batch_gen(batch_size=config['batch_size_test'])
    val_loss = []
    val_y_hat_los = np.array([])
    val_y_los = np.array([])
    val_y_hat_mort = np.array([])
    val_y_mort = np.array([])
    print('Validation Epoch#{}'.format(epoch))
    # No gradients are needed for evaluation; disabling autograd saves memory and time.
    with torch.no_grad():
        for batch in val_batches:
            if batch[0].size(0) <= 1:
                continue

            padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths = _unpack_batch(batch)

            y_hat_los, y_hat_mort, w1, w2 = model(padded, diagnoses, flat)

            loss = model.loss(y_hat_los, y_hat_mort, los_labels, mort_labels, mask, seq_lengths, device,
                              config['sum_losses'], config['loss'])
            val_loss.append(loss.item())

            bool_mask = mask.type(bool_type)
            if config['task'] == 'LoS':
                val_y_hat_los = np.append(val_y_hat_los, remove_padding(y_hat_los, bool_mask))
                val_y_los = np.append(val_y_los, remove_padding(los_labels, bool_mask))
            if config['task'] == 'mortality' and mort_labels.shape[1] > mort_pred_time:
                val_y_hat_mort = np.append(val_y_hat_mort,
                                           remove_padding(y_hat_mort[:, mort_pred_time],
                                                          bool_mask[:, mort_pred_time]))
                val_y_mort = np.append(val_y_mort,
                                       remove_padding(mort_labels[:, mort_pred_time],
                                                      bool_mask[:, mort_pred_time]))

    print('Validation Metrics:')
    mean_val_loss = sum(val_loss) / len(val_loss) if val_loss else float('nan')

    print('Epoch: {} | Validation Loss: {:3.4f}'.format(epoch, mean_val_loss))

    ########################################## Saving Model ########################################
    # The "best" checkpoint follows the task's selection metric (lowest MSLE for
    # LoS, highest AUROC for mortality). The "last" checkpoint is always written
    # after the final epoch, even when that epoch is also the best one, so that
    # both files exist once training finishes.
    is_last_epoch = epoch == n_epochs - 1

    if config['task'] == 'LoS':
        los_metrics_list = print_metrics_regression(val_y_los, val_y_hat_los)
        cur_msle = los_metrics_list[4]  # 4: 'msle'
        print('cur_msle: ', cur_msle)
        print('max_msle: ', max_msle)
        if cur_msle < max_msle:
            print('\n------------ Save model checkpoint best ------------\n')
            max_msle = cur_msle
            _save_checkpoint(file_name_best, epoch)
        if is_last_epoch:
            print('\n------------ Save model checkpoint last------------\n')
            _save_checkpoint(file_name_last, epoch)

    if config['task'] == 'mortality':
        mort_metrics_list = print_metrics_mortality(val_y_mort, val_y_hat_mort)
        cur_auroc = mort_metrics_list[5]  # 5: 'auroc', 6: 'auprc', 7: 'f1macro'
        if cur_auroc > max_auroc:
            print('\n------------ Save model checkpoint best ------------\n')
            max_auroc = cur_auroc
            _save_checkpoint(file_name_best, epoch)
        if is_last_epoch:
            print('\n------------ Save model checkpoint last------------\n')
            _save_checkpoint(file_name_last, epoch)

### Test

In [None]:

#################### Load checkpoint (best or last) ######################
# `best` selects which saved checkpoint to evaluate. `map_location` lets a
# checkpoint written on one device be restored on whichever device is in use now.
checkpoint_path = file_name_best if best else file_name_last
checkpoint = torch.load(checkpoint_path, map_location=device)
save_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['net'])
optimiser.load_state_dict(checkpoint['optimiser'])
##########################################################################

model.eval()
test_batches = test_datareader.batch_gen(batch_size=config['batch_size_test'])
test_loss = []
test_y_hat_los = np.array([])
test_y_los = np.array([])
test_y_hat_mort = np.array([])
test_y_mort = np.array([])

# Evaluation only: disable autograd so no computation graph is kept per batch.
with torch.no_grad():
    for batch in test_batches:
        # Batches with a single sample are skipped.
        if batch[0].size(0) <= 1:
            continue

        # MIMIC batches carry no diagnoses field.
        if config['dataset'] == 'MIMIC':
            padded, mask, flat, los_labels, mort_labels, seq_lengths = batch
            diagnoses = None
        else:
            padded, mask, diagnoses, flat, los_labels, mort_labels, seq_lengths = batch

        y_hat_los, y_hat_mort, w1, w2 = model(padded, diagnoses, flat)
        loss = model.loss(y_hat_los, y_hat_mort, los_labels, mort_labels, mask, seq_lengths, device,
                          config['sum_losses'], config['loss'])
        test_loss.append(loss.item())

        bool_mask = mask.type(bool_type)
        if config['task'] == 'LoS':
            test_y_hat_los = np.append(test_y_hat_los, remove_padding(y_hat_los, bool_mask))
            test_y_los = np.append(test_y_los, remove_padding(los_labels, bool_mask))
        # Strictly greater: index `mort_pred_time` only exists when the sequence
        # axis is longer than it (`>=` raised IndexError on exact-length batches).
        if config['task'] == 'mortality' and mort_labels.shape[1] > mort_pred_time:
            test_y_hat_mort = np.append(test_y_hat_mort,
                                        remove_padding(y_hat_mort[:, mort_pred_time],
                                                       bool_mask[:, mort_pred_time]))
            test_y_mort = np.append(test_y_mort,
                                    remove_padding(mort_labels[:, mort_pred_time],
                                                   bool_mask[:, mort_pred_time]))

print('Test Metrics:')
# Guard against a test set in which every batch was skipped.
mean_test_loss = sum(test_loss) / len(test_loss) if test_loss else float('nan')

if config['task'] == 'LoS':
    print('  ====> test_y_los.shape: ', test_y_los.shape)
    print('  ====> test_y_hat_los.shape: ', test_y_hat_los.shape)
    los_metrics_list = print_metrics_regression(test_y_los, test_y_hat_los)
if config['task'] == 'mortality':
    mort_metrics_list = print_metrics_mortality(test_y_mort, test_y_hat_mort)

print('Test Loss: {:3.4f}'.format(mean_test_loss))
