In [8]:
from typing import *
import pandas as pd
import os
import torch.nn as nn
import torch
from transformers import LongformerTokenizerFast, LongformerForSequenceClassification, Trainer, TrainingArguments, LongformerConfig, T5Config
from datasets import Dataset
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import wandb
import argparse
from functools import partial

# FIXME import below need path from root
from helpers import find_next_dir_index , explode_train_target, convert_to_list, merge_embs_to_seq, compute_metrics, group_train_val
from MyTrainer import MyTrainer 
from BoXHED_Fuse.models.ClinicalLSTM import ClinicalLSTM
# from ..models.ClinicalLSTM import ClinicalLSTM



# ===== Initialize Args ===== 
# parser = argparse.ArgumentParser()
# parser.add_argument('--test', action='store_true', help='enable testing mode')
# parser.add_argument('--use-wandb', action = 'store_true', help = 'enable wandb', default=False)
# parser.add_argument('--gpu-no', dest = 'GPU_NO', help='use GPU_NO specified (this may be a single number or several. eg: 1 or 1,2,3,4)')
# parser.add_argument('--note-type', dest = 'note_type', help='which notes, radiology or discharge?')
# parser.add_argument('--num-epochs', dest = 'num_epochs', help = 'num_epochs to train')
# parser.add_argument('--noteid-mode', dest = 'noteid_mode', help = 'kw: all or recent')
# args = parser.parse_args()

args = argparse.Namespace
args.test = True
args.GPU_NO = -1
args.note_type = 'radiology'
args.num_epochs = 1
args.noteid_mode = 'all'
args.model_name = 'LSTM'

args.num_epochs = int(args.num_epochs)

model_name_ft1 = 'Clinical-T5-Base'
train_embs_path = f'{os.getenv("BHF_ROOT")}/JSS_SUBMISSION_NEW/data/embs{"/testing" if args.test else ""}/{model_name_ft1}_{args.note_type[:3]}_{args.noteid_mode}_out/from_epoch1/10/train_embs.pt'
train_target_path = f'{os.getenv("BHF_ROOT")}/JSS_SUBMISSION_NEW/data/targets{"/testing" if args.test else ""}/till_end_mimic_iv_extra_features_train_NOTE_TARGET_2_{args.note_type[:3]}_{args.noteid_mode}.csv'
model_name = 'clinical_lstm'
model_out_dir = f'{os.getenv("BHF_ROOT")}/model_outputs/{model_name}_{args.note_type[:3]}_{args.noteid_mode}_out'


In [13]:
def validate_train_embseq(train_embseq, train_target):
    train_embseq = train_embseq.copy()
    train_embseq['emb_seq_len'] = train_embseq['emb_seq'].apply(len)
    train_target['NOTE_ID_SEQ_len'] = train_target['NOTE_ID_SEQ'].apply(len)
    assert((train_target['NOTE_ID_SEQ_len'] == train_embseq['emb_seq_len']).all())


In [14]:

if not os.path.exists(model_out_dir):
    os.makedirs(model_out_dir)
run_cntr = find_next_dir_index(model_out_dir)
model_out_dir = os.path.join(model_out_dir, str(run_cntr))
assert(not os.path.exists(model_out_dir))
# os.makedirs(model_out_dir) # FIXME
print(f'created all dirs in model_out_dir', model_out_dir)

if args.test:
    # train_target_path = os.path.join(os.path.dirname(train_target_path), 'testing', os.path.basename(train_target_path))
    model_out_dir = os.path.join(os.path.dirname(model_out_dir), 'testing', os.path.basename(model_out_dir))


# ===== Read Data =====
train_embs = torch.load(train_embs_path)
train_target = pd.read_csv(train_target_path, converters = {'NOTE_ID_SEQ': convert_to_list})
target = 'delta_in_2_days' # FIXME
train_target.rename(columns = {target:'label'}, inplace=True)
# ===== Merge data into {note_embs_seq, label}, where note_seq is a list of embs =====
train_target_exploded = explode_train_target(train_target)
train_embs_df = pd.DataFrame()
train_embs_df['emb'] = [np.array(e) for e in train_embs]
train_embs_df = pd.concat([train_target_exploded, train_embs_df], axis=1)
train_embseq = merge_embs_to_seq(train_target, train_embs_df)


created all dirs in model_out_dir /home/ugrads/a/aa_ron_su/BoXHED_Fuse/BoXHED_Fuse//model_outputs/clinical_lstm_rad_all_out/7


In [15]:
# ===== Validate =====

validate_train_embseq(train_embseq, train_target)

In [34]:
# ===== Prepare train val split =====
train_idxs, val_idxs = group_train_val(train_embseq['ICUSTAY_ID'])
train_data = train_embseq.iloc[train_idxs]
val_data = train_embseq.iloc[val_idxs]
train_data = Dataset.from_pandas(train_data).select_columns(['emb_seq', 'label'])
val_data = Dataset.from_pandas(val_data).select_columns(['emb_seq', 'label'])

train_data.set_format('torch', columns=['emb_seq', 'label'])
val_data.set_format('torch', columns=['emb_seq', 'label'])

# ===== Train LSTM ===== 

clin_lstm = ClinicalLSTM()
# main_finetune()



# ===== Save Sequential Embeddings =====
# extract_emb_seq()



In [16]:
from torch.utils import data

# maximum sequence length
max_num_notes = 32
doc_emb_size = 64 # 768
    
class Data_Encoder_FAST(data.Dataset):

    def __init__(self, list_IDs, labels, df):
        'Initialization'
        self.labels = labels
        self.list_IDs = list_IDs
        self.df = df

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        # Load data and get label        
        y = self.labels[index]
        index = self.list_IDs[index]
        doc_seqs = torch.cat(list(self.df[self.df.HADM_ID == index].DOC_EMB.values), 0)

        xlnet_outputs = torch.zeros(size=(max_num_notes, doc_emb_size), dtype=torch.float)        
        if len(doc_seqs) > max_num_notes:
            xlnet_outputs[:max_num_notes] = doc_seqs[:max_num_notes]
        else:
            xlnet_outputs[:len(doc_seqs)] = doc_seqs

        return xlnet_outputs.cuda(), y

In [1]:
import copy

def main_finetune(lr, batch_size, train_epoch, dataFolder, prediction_label):
    '''
    finetune the model. Here, we are finetuning the LSTM!

    
    '''
    lr = lr
    BATCH_SIZE = batch_size
    train_epoch = train_epoch
    
    loss_history = []
    
    model = ClinicalLSTM()
    model.cuda()
    
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, dim = 0)
            
    print('--- Data Preparation ---')
    
    params = {'batch_size': BATCH_SIZE,
              'shuffle': True,
              'num_workers': 0, 
              'drop_last': True}

    # df_train = pd.read_csv(dataFolder + '/train.csv')
    # df_val = pd.read_csv(dataFolder + '/val.csv')
    # df_test = pd.read_csv(dataFolder + '/test.csv')
    
    # doc_train = torch.load(dataFolder + '/train_doc_emb.pt')
    # doc_val = torch.load(dataFolder + '/val_doc_emb.pt')
    # doc_test = torch.load(dataFolder + '/test_doc_emb.pt')
    
    def doc_emb_to_df(doc, df):
        output_seq = [torch.unsqueeze(doc[i],0) for i in range(doc.shape[0])]
        df = df.assign(DOC_EMB = output_seq)
        return df
    
    df_train = doc_emb_to_df(doc_train, df_train)
    df_val = doc_emb_to_df(doc_val, df_val)
    df_test = doc_emb_to_df(doc_test, df_test)
    
    if prediction_label == 'PMV':
        train_unique = df_train[['HADM_ID','Label']].drop_duplicates().reset_index(drop = True)
        val_unique = df_val[['HADM_ID','Label']].drop_duplicates().reset_index(drop = True)
        test_unique = df_test[['HADM_ID','Label']].drop_duplicates().reset_index(drop = True)

        training_set = Data_Encoder_FAST(train_unique.HADM_ID.values, train_unique.Label.values, df_train)
        training_generator = data.DataLoader(training_set, **params)
        breakpoint() # dimensions of training_set? dataloader? 

        validation_set = Data_Encoder_FAST(val_unique.HADM_ID.values, val_unique.Label.values, df_val)
        validation_generator = data.DataLoader(validation_set, **params)

        testing_set = Data_Encoder_FAST(test_unique.HADM_ID.values, test_unique.Label.values, df_test)
        testing_generator = data.DataLoader(testing_set, **params)
    
    elif prediction_label == 'Mortality':
        train_unique = df_train[['HADM_ID','DEATH_90']].drop_duplicates().reset_index(drop = True)
        val_unique = df_val[['HADM_ID','DEATH_90']].drop_duplicates().reset_index(drop = True)
        test_unique = df_test[['HADM_ID','DEATH_90']].drop_duplicates().reset_index(drop = True)
        
        training_set = Data_Encoder_FAST(train_unique.HADM_ID.values, train_unique.DEATH_90.values, df_train)
        training_generator = data.DataLoader(training_set, **params)

        validation_set = Data_Encoder_FAST(val_unique.HADM_ID.values, val_unique.DEATH_90.values, df_val)
        validation_generator = data.DataLoader(validation_set, **params)
    
        testing_set = Data_Encoder_FAST(test_unique.HADM_ID.values, test_unique.DEATH_90.values, df_test)
        testing_generator = data.DataLoader(testing_set, **params)
    else:
        print("Please modify the label value for your own downstream prediction task.")
    
    
    opt = torch.optim.Adam(model.parameters(), lr = lr)
    # early stopping
    max_auc = 0
    model_max = copy.deepcopy(model)
   
    print('--- Go for Training ---')
    torch.backends.cudnn.benchmark = True
    for epo in range(train_epoch):
        model.train()
        for i, (output, label) in enumerate(training_generator):
            breakpoint() # what is the output? How does the LSTM take in data?
            score = model(output.cuda())
       
            label = Variable(torch.from_numpy(np.array(label)).float()).cuda()
            
            loss_fct = torch.nn.BCELoss()
            m = torch.nn.Sigmoid()
            n = torch.squeeze(m(score))
            
            loss = loss_fct(n, label)
            loss_history.append(loss)
            
            opt.zero_grad()
            loss.backward()
            opt.step()
           
        # every epoch test
        with torch.set_grad_enabled(False):
            auc, auprc, logits, loss = test_finetune(validation_generator, model)
            if auc > max_auc:
                model_max = copy.deepcopy(model)
                max_auc = auc
                 
            print('Validation at Epoch '+ str(epo + 1) + ' , AUROC: '+ str(auc) + ' , AUPRC: ' + str(auprc))
    
    print('--- Go for Testing ---')
    try:
        with torch.set_grad_enabled(False):
            auc, auprc, logits, loss = test_finetune(testing_generator, model_max)
            print('Testing AUROC: ' + str(auc) + ' , AUPRC: ' + str(auprc) + ' , Test loss: '+str(loss))
    except:
        print('testing failed')
    return model_max, loss_history

In [30]:
train_data['emb_seq'][0].shape

torch.Size([45, 64])

In [None]:
train_embseq.ICUSTAY_ID.nunique()

13

In [27]:
(train_embseq.emb_seq.iloc[0])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,54,55,56,57,58,59,60,61,62,63
0,0.045718,0.02955,-0.006603,-0.030292,0.037942,-0.016469,-0.094294,-0.01997,-0.014836,0.082431,...,-0.044541,0.055484,-0.101424,-0.070186,0.006019,0.02989,0.090776,0.048945,0.032682,0.0251
1,0.014879,0.053706,-0.026616,-0.055694,0.026509,-0.09789,-0.020897,-0.097928,-0.014248,0.010718,...,-0.011992,-0.055746,0.078403,0.005336,-0.027869,-0.003987,0.028131,0.107723,0.034473,0.030731
2,-0.025563,0.094409,-0.040374,-0.070944,0.156529,-0.005459,-0.025024,-0.080697,-0.032454,0.163034,...,-0.018262,-0.037139,0.01197,-0.074624,0.048606,0.013519,0.093437,0.066523,-0.047791,-0.023772
3,0.034758,0.04123,0.053005,-0.012524,-0.050618,0.050894,-0.061397,-0.112498,0.039355,0.024527,...,-0.139965,0.122681,-0.043349,-0.016505,0.013531,0.004904,0.054962,0.005899,0.011502,0.007898
4,-0.050177,0.03969,-0.013134,-0.057276,0.044675,0.002194,0.001921,-0.077664,-0.065639,-0.0312,...,-0.000984,-0.005697,0.05384,0.021765,0.003299,0.047243,0.057308,0.103604,-0.004837,0.013525
5,-0.014579,0.068516,-0.006934,-0.046155,0.016,0.005347,-0.097081,-0.084453,0.020215,0.074538,...,-0.046301,0.034606,-0.050701,-0.110988,0.076473,0.027367,0.031817,0.073614,0.004667,-0.048693
6,-0.003845,0.025578,-0.016001,-0.068146,-0.084224,-0.011092,-0.076685,-0.05153,-0.051489,0.038842,...,-0.026712,-0.009303,0.015449,-0.022771,0.09602,0.043913,0.03223,0.028959,-0.101444,0.00662
7,0.001945,0.137261,0.076068,0.071644,0.054881,-0.108232,-0.139761,-0.130954,-0.088365,0.100763,...,-0.214328,-0.034124,-0.086951,-0.035603,-0.018621,0.015729,-0.053997,0.040064,-0.00053,0.042536
8,0.025607,0.021977,-0.025969,-0.02683,0.027689,0.028476,-0.121105,-0.104698,0.039678,0.127245,...,-0.100903,0.085193,-0.091519,-0.071606,0.011091,-0.02074,0.127078,0.006519,0.021186,0.070626
9,0.031743,0.113015,0.034682,-0.006114,-0.006546,-0.061236,-0.013636,-0.120102,-0.050542,0.134738,...,-0.140696,0.096582,-0.068511,-0.057512,0.016246,-0.024609,0.093552,0.073588,0.147881,-0.008034


In [None]:
train_target_exploded = explode_train_target(train_target)
train_embs_df = pd.DataFrame(train_embs.tolist())
train_embs_df = pd.concat([train_target_exploded, train_embs_df], axis=1)
train_embseq = merge_embs_to_seq(train_target, train_embs_df)

In [3]:
train_embs_df['NOTE_ID'] = train_target_exploded['NOTE_ID']

In [19]:
train_embseq.emb_seq.iloc[0].shape

(45, 64)

In [None]:
train_target_exploded = explode_train_target(train_target)
train_target_exploded

In [None]:
train_target_exploded