# Train models

With this notebook, we train simple, untuned recurrent models whose performance is sufficient to demonstrate the effect of data augmentation strategies.

In [None]:
import os, time, pickle, boto3, sys
from datetime import datetime

import tensorflow as tf
import tensorflow.keras as tfk
import tensorflow.keras.layers as tfkl
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_curve, average_precision_score, roc_curve, auc, confusion_matrix

In [None]:
from utils.datagen import PROCESSED_DATAPATH, MODEL_INPUT_DATAPATH, RESULTFILE_DATAPATH, DataGen, strategies, targets, get_datafile
from utils.utils import read_data, dump_data
from utils.connections import model_input_bucket, processed_data_bucket, get_s3_keys_as_generator, download_file

# Check prerequisites: TensorFlow 2.x and at least one visible GPU.
assert tf.__version__.startswith('2.'), f'TensorFlow 2.x required, found {tf.__version__}'

# Query the GPU list once and reuse it for both the check and the configuration.
physical_devices = tf.config.list_physical_devices('GPU')
assert len(physical_devices) > 0, 'no GPU visible to TensorFlow'

# Enable memory growth on every GPU rather than only the first one:
# TensorFlow requires the memory-growth setting to be identical across GPUs,
# so configuring a single device fails on multi-GPU hosts.
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, enable=True)

In [None]:
def create_model(embedding_weight, name, timesteps=500, dropout=0.3):
    """Build a small bidirectional-LSTM classifier with a 2-way softmax output.

    Only the shape of ``embedding_weight`` is read here: its first dimension
    (plus one) sizes the embedding vocabulary and its second dimension sets
    the embedding width. The values themselves are not loaded into the layer.

    Args:
        embedding_weight: array-like exposing ``.shape[0]`` and ``.shape[1]``.
        name: prefix used for the input, embedding, recurrent and output
            layer names.
        timesteps: length of each input sequence.
        dropout: rate used by both dropout layers.

    Returns:
        An uncompiled ``tfk.Model`` mapping ``(timesteps,)`` inputs to a
        2-unit softmax.
    """
    vocab_rows = embedding_weight.shape[0]
    embed_width = embedding_weight.shape[1]

    tokens = tfk.Input(shape=(timesteps,), name=f'{name}_input')
    hidden = tfkl.Embedding(vocab_rows+1, embed_width, name=f'{name}_embed')(tokens)

    # The LSTM returns only its final state; Bidirectional runs it in both
    # directions over the sequence.
    recurrent_cell = tfkl.LSTM(
        units=10,
        name=f'{name}_recurrent',
        return_sequences=False,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
    )
    hidden = tfkl.Bidirectional(recurrent_cell)(hidden)
    hidden = tfkl.LeakyReLU(alpha=0.3)(hidden)
    hidden = tfkl.Dropout(dropout)(hidden)

    # Dense head: 10 ReLU units, dropout, then the 2-class softmax.
    hidden = tfkl.Dense(10, activation='relu')(hidden)
    hidden = tfkl.Dropout(dropout)(hidden)
    class_probs = tfkl.Dense(2, activation='softmax', name=f'{name}_output')(hidden)

    return tfk.Model(inputs=tokens, outputs=class_probs)

In [None]:
# This is definitely not the most efficient way of performing validation - the tfrecord
# format is optimised for training, and for validation where no filtering is required.
# It would have been better to not apply oversampling to the validation set for 
# oversample_basic and oversample_tte strategies, but it can be worked around.

# For data augmentation strategies, we need to randomly filter the same number of 
# augmented samples for each ID before performing validation


def _safe_ratio(numerator, denominator, default=0):
    """Return ``numerator / denominator``, or ``default`` if the denominator is zero."""
    return numerator / denominator if denominator else default


def valid_epoch(data_valid, model, valid_step, data_index):
    """Run one validation pass and compute classification metrics.

    Each batch from ``data_valid`` must unpack into
    ``(x_clin, x_vs, y, ids, tt)``. Predictions are collected for every
    batch, then reduced to one sample per ID (the first occurrence of each
    ID in iteration order) before metrics are computed. Column 0 of the
    model output is used as the score, and scores >= 0.5 are predicted as
    class 1.

    Args:
        data_valid: iterable of validation batches.
        model: model handed through to ``valid_step``.
        valid_step: callable ``(x, y, model) -> (pred, loss)`` where ``pred``
            exposes ``.numpy()``.
        data_index: 0 to feed only ``x_clin``, 1 to feed only ``x_vs``;
            any negative value feeds ``[x_clin, x_vs]``.

    Returns:
        Tuple ``(tp, fp, tn, fn, accuracy, val_auc, sensitivity,
        specificity, WDR)`` where ``WDR`` is ``(tp + fp) / tp``. Any ratio
        whose denominator is zero is reported as 0.

    Raises:
        ValueError: if ``data_valid`` yields no batches.
    """
    batch_scores = []
    batch_labels = []
    batch_ids = []
    for x_clin, x_vs, y, ids, _ in data_valid:
        x = [x_clin, x_vs]
        if data_index >= 0:
            x = x[data_index]
        pred, _ = valid_step(x, y, model)
        batch_scores.append(pred.numpy())
        batch_labels.append(y.numpy())
        batch_ids.append(ids.numpy())

    if not batch_scores:
        raise ValueError('data_valid yielded no batches; cannot compute validation metrics')

    probabilities = np.vstack(batch_scores)[:,0]
    Y = np.hstack(batch_labels)
    IDS = np.hstack(batch_ids)

    # Keep a single row per ID: np.unique returns the index of the first
    # occurrence of each distinct ID.
    first_per_id = np.unique(IDS, return_index=True)[1]
    probabilities = probabilities[first_per_id]
    Y = Y[first_per_id]

    predictions = np.where(probabilities < 0.5, 0., 1.)
    correct = predictions == Y
    flagged = predictions == 1
    cleared = predictions == 0
    tp = int(np.count_nonzero(correct & flagged))
    fp = int(np.count_nonzero(~correct & flagged))
    tn = int(np.count_nonzero(correct & cleared))
    fn = int(np.count_nonzero(~correct & cleared))

    fpr, tpr, thresholds = roc_curve(Y, probabilities)
    val_auc = auc(fpr, tpr)

    # All ratios share the same zero-denominator fallback so that a
    # validation set lacking one class (or lacking positive predictions)
    # does not abort the training run.
    accuracy = _safe_ratio(tp + tn, len(predictions))
    sensitivity = _safe_ratio(tp, tp + fn)
    specificity = _safe_ratio(tn, tn + fp)
    WDR = _safe_ratio(tp + fp, tp)
    return tp, fp, tn, fn, accuracy, val_auc, sensitivity, specificity, WDR

def train_epoch(data_train, model, train_step, data_index):
    """Run one training pass over ``data_train`` and time it.

    Each batch must unpack into ``(x_clin, x_vs, y, ids, tt)``. The input
    handed to ``train_step`` is selected the same way as in ``valid_epoch``.

    Args:
        data_train: iterable of training batches.
        model: model handed through to ``train_step``.
        train_step: callable ``(x, y, model) -> (preds, loss)`` where
            ``loss`` exposes ``.numpy()``.
        data_index: 0 to feed only ``x_clin``, 1 to feed only ``x_vs``;
            any negative value feeds ``[x_clin, x_vs]``.

    Returns:
        Tuple ``(elapsed_seconds, mean_loss)`` where ``mean_loss`` is the
        mean of the per-batch losses.
    """
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # documented replacement for measuring elapsed intervals.
    started = time.perf_counter()
    epoch_losses = []
    for x_clin, x_vs, y, _ids, _tt in data_train:
        x = [x_clin, x_vs]
        if data_index >= 0:
            x = x[data_index]
        # Train the model that was passed in, not a module-level global.
        _, loss = train_step(x, y, model)
        epoch_losses.append(loss.numpy())
    elapsed = time.perf_counter() - started
    return elapsed, np.mean(epoch_losses)

In [None]:
def train_model(data_train, data_valid, model, data_index, epochs, label, epoch_start=0):
    """Train ``model`` epoch by epoch, validating and checkpointing after each.

    For every epoch in ``range(epoch_start, epochs)`` this runs
    ``train_epoch`` and ``valid_epoch``, prints a tab-separated result line,
    saves the weights to ``RESULTFILE_DATAPATH/{label}_e{epoch}.h5`` and
    appends the line to a per-run TSV under ``RESULTFILE_DATAPATH/summaries``.

    The inner ``train_step`` reads a variable named ``optimizer`` from the
    enclosing module scope; it must be defined before this function is
    called.

    Args:
        data_train: iterable of training batches.
        data_valid: iterable of validation batches.
        model: Keras model to train; updated in place.
        data_index: input selector forwarded to ``train_epoch`` and
            ``valid_epoch``.
        epochs: exclusive upper bound of the epoch range.
        label: prefix for checkpoint and summary file names, and the first
            column of each result line.
        epoch_start: first epoch number to run.

    Returns:
        List of the tab-separated result lines, one per epoch run.
    """
    result_strings = []

    @tf.function
    def get_loss(Y, predictions):
        # Keras' backend signature is binary_crossentropy(target, output):
        # labels go first, predicted probabilities second. Column 0 of the
        # softmax output is the score compared against the labels.
        # NOTE(review): the arguments were previously swapped, so loss values
        # from earlier runs are not directly comparable with these.
        return tfk.backend.binary_crossentropy(tf.cast(Y, tf.float32), predictions[:,0])

    @tf.function
    def train_step(x, y, model):
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            loss = tf.reduce_mean(get_loss(y, preds))
        gradients = tape.gradient(loss, model.trainable_variables)
        # `optimizer` is resolved from module scope (see docstring).
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return preds, loss

    @tf.function
    def valid_step(x, y, model, training=False):
        preds = model(x, training=training)
        loss = tf.reduce_mean(get_loss(y, preds))
        return preds, loss

    # One summary file per invocation, named after the start time.
    runfile_name = datetime.now().strftime("%Y%m%d_%H:%M:%S")
    summary_dir = os.path.join(RESULTFILE_DATAPATH, 'summaries')
    # Create the summaries directory up front so the first append cannot fail
    # after a full epoch of training has already been spent.
    os.makedirs(summary_dir, exist_ok=True)
    summary_path = os.path.join(summary_dir, f'{label}_{runfile_name}.tsv')

    for e in range(epoch_start, epochs):
        elapsed_time, epoch_loss = train_epoch(data_train, model, train_step, data_index)
        tp, fp, tn, fn, accuracy, val_auc, sensitivity, specificity, WDR = valid_epoch(data_valid, model, valid_step, data_index)
        result_string = (f'{label}\t{e}\t{elapsed_time}\t{np.mean(epoch_loss)}\t{tp}\t{fp}\t{tn}\t{fn}\t{accuracy}\t{val_auc}\t{sensitivity}\t{specificity}\t{WDR}')
        print(result_string)
        result_strings.append(result_string)
        model.save_weights(os.path.join(RESULTFILE_DATAPATH, f'{label}_e{e}.h5'))
        # Append after every epoch so partial results survive an interruption.
        with open(summary_path, 'a') as outfile:
            outfile.write(result_string + '\n')
    return result_strings

In [None]:
# Resume training: for each target/strategy pair, reload the most recent
# checkpoint and continue training up to three times the completed epoch count.
fold = 0
batch_size = 128

for target in ['hosp_death', 'icu_death']:
    for strategy in strategies:
        # NOTE(review): the ('hosp_death', 'original') combination is skipped;
        # presumably it has already been trained - confirm.
        if target == 'hosp_death' and strategy == 'original':
            continue

        label = f'{strategy}_{target}'
        data_train = get_datafile(target, strategy, fold=fold,
                                  phase='train', batch_size=batch_size,
                                  model_type='both')
        data_valid = get_datafile(target, strategy, fold=fold,
                                  phase='valid', batch_size=batch_size,
                                  model_type='both')

        # Match checkpoints by their exact '{strategy}_{target}_e{N}.h5' name.
        # A plain substring test would also pick up files belonging to another
        # strategy or target whose name merely contains this one.
        prefix, suffix = f'{label}_e', '.h5'
        weight_files = [
            f for f in os.listdir(RESULTFILE_DATAPATH)
            if f.startswith(prefix) and f.endswith(suffix)
            and f[len(prefix):-len(suffix)].isdigit()
        ]
        if not weight_files:
            raise ValueError(
                f'no saved weights matching {prefix}<epoch>{suffix} '
                f'found in {RESULTFILE_DATAPATH}'
            )
        completed_epochs = max(int(f[len(prefix):-len(suffix)]) for f in weight_files)
        target_weight_file = f'{label}_e{completed_epochs}.h5'

        # NOTE(review): `create_single_recurrent` and `clin_embedding_weight`
        # are not defined in the visible notebook cells. `create_model` takes
        # the same (embedding_weight, name) arguments - confirm which builder
        # produced the saved weights before swapping it in.
        clin_model = create_single_recurrent(clin_embedding_weight, 'clin')
        clin_model.load_weights(os.path.join(RESULTFILE_DATAPATH, target_weight_file))

        # train_model's inner train_step reads this module-level optimizer.
        optimizer = tfk.optimizers.Adam(learning_rate=1e-4, clipnorm=1)

        # Checkpoints are numbered by the epoch that produced them, so
        # training resumes at the next epoch number.
        clin_results = train_model(data_train, data_valid, clin_model, 0, completed_epochs*3, label, completed_epochs + 1)
        print('='*100)
            print('='*100)
        
