In [None]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils import data
from torch import nn 
from torch.utils.data import SequentialSampler

import argparse
import copy
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from time import time
from sklearn.metrics import roc_auc_score, precision_recall_curve, roc_curve, f1_score, auc, average_precision_score, confusion_matrix, classification_report
from sklearn.utils.fixes import signature
from sklearn.model_selection import KFold
# Seed PyTorch and NumPy with the same fixed value so repeated runs start from
# the same random state.
torch.manual_seed(1)
np.random.seed(1)

# Ask cuDNN for deterministic kernels and keep its benchmark mode switched off.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Project-local imports, currently disabled.
# NOTE(review): main_finetune references clinical_xlnet_lstm_FAST and
# Data_Encoder_FAST; with these lines commented out those names have to be
# defined some other way before it is called -- confirm.
# from models_xlnet import clinical_xlnet_seq, clinical_xlnet_lstm_FAST
# from stream_xlnet import Data_Encoder_Seq, Data_Encoder_FAST
# from configuration_xlnet import XLNetConfig
# from Ranger import Ranger

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

In [None]:
def main_finetune(lr, batch_size, train_epoch, dataFolder, prediction_label):
    '''
    Finetune the LSTM classifier (clinical_xlnet_lstm_FAST) on precomputed
    document embeddings and keep the weights with the best validation AUROC.

    Parameters
    ----------
    lr : float
        Learning rate handed to torch.optim.Adam.
    batch_size : int
        Batch size used by the training, validation and testing loaders.
    train_epoch : int
        Number of passes over the training loader.
    dataFolder : str
        Directory containing train.csv / val.csv / test.csv and the matching
        train_doc_emb.pt / val_doc_emb.pt / test_doc_emb.pt files. Each
        embedding file must hold one entry per row of its CSV, because the
        entries are attached to the frame as the DOC_EMB column.
    prediction_label : str
        'PMV' (target column 'Label'), 'Mortality' (target column 'DEATH_90'),
        or the name of any other target column present in all three CSVs.
        The target is fed to a binary cross-entropy loss.

    Returns
    -------
    (model_max, loss_history)
        model_max is a deep copy of the model taken at the epoch with the
        highest validation AUROC (the untrained model if no epoch scores
        above 0). loss_history is a list with one detached loss tensor per
        training batch.

    Raises
    ------
    ValueError
        If 'HADM_ID' or the target column is missing from any of the CSVs.
    '''
    BATCH_SIZE = batch_size
    loss_history = []

    # The two built-in tasks map to fixed column names; any other value is
    # taken to be the target column itself.
    label_col = {'PMV': 'Label', 'Mortality': 'DEATH_90'}.get(prediction_label, prediction_label)

    print('--- Data Preparation ---')

    df_train = pd.read_csv(dataFolder + '/train.csv')
    df_val = pd.read_csv(dataFolder + '/val.csv')
    df_test = pd.read_csv(dataFolder + '/test.csv')

    # Fail fast, before any model is allocated, when the requested target is
    # not available; otherwise the problem would only surface later as an
    # unrelated error.
    for split_name, frame in (('train', df_train), ('val', df_val), ('test', df_test)):
        missing = [col for col in ('HADM_ID', label_col) if col not in frame.columns]
        if missing:
            raise ValueError(
                "prediction_label " + repr(prediction_label) + " needs column(s) "
                + str(missing) + " in " + split_name + ".csv. "
                + "Please modify the label value for your own downstream prediction task."
            )

    doc_train = torch.load(dataFolder + '/train_doc_emb.pt')
    doc_val = torch.load(dataFolder + '/val_doc_emb.pt')
    doc_test = torch.load(dataFolder + '/test_doc_emb.pt')

    def doc_emb_to_df(doc, df):
        # One (1, ...) slice of the embedding tensor per CSV row.
        output_seq = [torch.unsqueeze(doc[i], 0) for i in range(doc.shape[0])]
        return df.assign(DOC_EMB = output_seq)

    df_train = doc_emb_to_df(doc_train, df_train)
    df_val = doc_emb_to_df(doc_val, df_val)
    df_test = doc_emb_to_df(doc_test, df_test)

    def make_generator(df, training):
        # One (HADM_ID, target) pair per admission; the full frame is handed
        # to the dataset alongside the pairs.
        unique = df[['HADM_ID', label_col]].drop_duplicates().reset_index(drop = True)
        dataset = Data_Encoder_FAST(unique.HADM_ID.values, unique[label_col].values, df)
        if training:
            shuffle, drop_last = True, True
        else:
            # Evaluate on a fixed order and on every sample, so metrics are
            # comparable across epochs. The only batch still dropped is a
            # trailing batch of one: torch.squeeze, as applied to the scores
            # in the training loop below, collapses such a batch to a 0-d
            # tensor that no longer matches the label shape.
            # NOTE(review): assumes test_finetune scores batches the same
            # way and that the dataset length equals len(unique) -- confirm.
            shuffle = False
            drop_last = len(unique) % BATCH_SIZE == 1
        return data.DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = shuffle,
                               num_workers = 0, drop_last = drop_last)

    training_generator = make_generator(df_train, training = True)
    validation_generator = make_generator(df_val, training = False)
    testing_generator = make_generator(df_test, training = False)

    # Pick the device from availability instead of calling .cuda() blindly.
    # NOTE(review): test_finetune must place its batches on the same device.
    run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = clinical_xlnet_lstm_FAST()
    model.to(run_device)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, dim = 0)

    opt = torch.optim.Adam(model.parameters(), lr = lr)
    # Model selection: remember the weights with the best validation AUROC.
    max_auc = 0
    model_max = copy.deepcopy(model)

    # Loop-invariant modules, built once instead of once per batch.
    loss_fct = torch.nn.BCELoss()
    m = torch.nn.Sigmoid()

    print('--- Go for Training ---')
    # NOTE(review): this re-enables cuDNN benchmark mode globally, which
    # contradicts the `benchmark = False` reproducibility setting made in the
    # notebook's setup cell. Kept because callers may rely on the flag.
    torch.backends.cudnn.benchmark = True
    for epo in range(train_epoch):
        model.train()
        for i, (output, label) in enumerate(training_generator):
            score = model(output.to(run_device))

            label = torch.from_numpy(np.array(label)).float().to(run_device)

            n = torch.squeeze(m(score))

            loss = loss_fct(n, label)
            # Store a detached tensor so the history does not keep every
            # batch's autograd graph alive.
            loss_history.append(loss.detach())

            opt.zero_grad()
            loss.backward()
            opt.step()

        # Validate after every epoch, in eval mode and without gradients.
        model.eval()
        with torch.no_grad():
            val_auc, val_auprc, _, _ = test_finetune(validation_generator, model)
            if val_auc > max_auc:
                model_max = copy.deepcopy(model)
                max_auc = val_auc

            print('Validation at Epoch '+ str(epo + 1) + ' , AUROC: '+ str(val_auc) + ' , AUPRC: ' + str(val_auprc))

    print('--- Go for Testing ---')
    try:
        model_max.eval()
        with torch.no_grad():
            test_auc, test_auprc, _, test_loss = test_finetune(testing_generator, model_max)
            print('Testing AUROC: ' + str(test_auc) + ' , AUPRC: ' + str(test_auprc) + ' , Test loss: '+str(test_loss))
    except Exception as err:
        # Testing stays best-effort so the trained model is still returned,
        # but the reason for the failure is reported instead of being hidden.
        print('testing failed: ' + repr(err))
    return model_max, loss_history

In [None]:
# Hyperparameters and target selection for the finetuning run.
args = argparse.Namespace(
    Learning_Rate_Finetune = 2e-5,
    Batch_Size_Finetune = 128,
    Training_Epoch_Finetune = 20,
    prediction_label = 'delta_in_2_days',
)

# Placeholder value: point this at the directory that main_finetune should
# read its CSV and embedding files from before running the cell.
dataFolder = 'FIXME'

# Train, then keep the selected model and the per-batch training losses.
model_max, loss_history = main_finetune(
    args.Learning_Rate_Finetune,
    args.Batch_Size_Finetune,
    args.Training_Epoch_Finetune,
    dataFolder,
    args.prediction_label,
)