In [1]:
import numpy as np
import pandas as pd
import sklearn
import torch
import torch.nn as nn
import torch.nn.functional as F

FILE_NAME_SUFFIX = ".multi_admission" 
# FILE_NAME_SUFFIX = "_1000"
# FILE_NAME_SUFFIX = "" 
DATASET_PATH = "Datasets/"

# !ls $DATASET_PATH

def read_samples(input_file, file_name_suffix):
    file_name = DATASET_PATH + input_file + file_name_suffix
    return pd.read_csv(file_name, error_bad_lines=False, keep_default_na=False)

In [2]:
patients = read_samples("PATIENTS.csv", FILE_NAME_SUFFIX)
admissions = read_samples("ADMISSIONS.csv", FILE_NAME_SUFFIX)
diagnoses = read_samples("DIAGNOSES_ICD.csv", FILE_NAME_SUFFIX)
icu_stays = read_samples("ICUSTAYS.csv", FILE_NAME_SUFFIX)
procedures = read_samples("PROCEDURES_ICD.csv", FILE_NAME_SUFFIX)

In [3]:
# print("Patients:")
# print(patients.head(5))
# print("Admissions:")
# print(admissions.head(5))
# print("Diagnoses:")
# print(diagnoses.head(5))
# print("ICU stays:")
# print(icu_stays.head(5))
# print("Procedures:")
# print(procedures.head(5))

In [4]:
patients = patients.set_index("SUBJECT_ID", drop=False)
patients["num_admissions"] = admissions.groupby("SUBJECT_ID").size().to_frame("num_admissions")
patients = patients[patients.num_admissions > 1]

In [5]:
admissions = admissions[admissions.SUBJECT_ID.isin(patients.SUBJECT_ID)]
procedures = procedures[procedures.SUBJECT_ID.isin(patients.SUBJECT_ID)]
diagnoses = diagnoses[diagnoses.SUBJECT_ID.isin(patients.SUBJECT_ID)]
# icu_stays = icu_stays[icu_stays.SUBJECT_ID.isin(patients.SUBJECT_ID)]

In [6]:
last_admission = admissions[admissions.groupby(['SUBJECT_ID'])['ADMITTIME'].transform(max) == admissions['ADMITTIME']]
# print(last_admission.size)
# print(admissions.size)
previous_admissions = admissions[admissions.groupby(['SUBJECT_ID'])['ADMITTIME'].transform(max) != admissions['ADMITTIME']]
# print(previous_admissions.size)

In [7]:
seqs = []
morts = []
types = {}
for patient in patients.itertuples():
#     print(patient.ROW_ID)
#     print(patient.SUBJECT_ID)
    patient_admissions = []
    for patient_admission in previous_admissions[previous_admissions.SUBJECT_ID == patient.SUBJECT_ID].itertuples():
        icd9_codes = []
#         print(patient_admission.HADM_ID)
#         newdf = df[(df.origin == "JFK") & (df.carrier == "B6")]
        diagnoses_filtered = diagnoses[(diagnoses.SUBJECT_ID == patient.SUBJECT_ID) & (diagnoses.HADM_ID == patient_admission.HADM_ID)]
        procedures_filtered = procedures[(procedures.SUBJECT_ID == patient.SUBJECT_ID) & (procedures.HADM_ID == patient_admission.HADM_ID)]
#         print(diagnoses_filtered)
        for admission_diagnosis in diagnoses_filtered.itertuples():
            if admission_diagnosis.ICD9_CODE in types:
                icd9_codes.append(types[admission_diagnosis.ICD9_CODE])
            else:
                types[admission_diagnosis.ICD9_CODE] = len(types)
                icd9_codes.append(types[admission_diagnosis.ICD9_CODE])
        for admission_procedures in procedures_filtered.itertuples():
            if admission_procedures.ICD9_CODE in types:
                icd9_codes.append(types[admission_procedures.ICD9_CODE])
            else:
                types[admission_procedures.ICD9_CODE] = len(types)
                icd9_codes.append(types[admission_procedures.ICD9_CODE])
#             print(admission_diagnosis.SEQ_NUM)
#             print(icd9_codes)
        patient_admissions.append(icd9_codes)
    morts.append(patient.EXPIRE_FLAG)
    seqs.append(patient_admissions)
# print(patient_morts)    

In [8]:
# print(np.array(seq).shape)
# print(seqs)
# print("Mortality:", morts[335])
# for visit in range(len(seqs[335])):
#     print(f"\t{visit}-th admission diagnosis labels:", seqs[335][visit])
# print(f"admission diagnosis labels:", seqs[335])

In [9]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, seqs, morts):
        
        self.x = seqs
        self.y = morts
    
    def __len__(self):
        
        return len(self.y)
    
    def __getitem__(self, index):
        
        #Tuple of (seq,label) where seq is seq[i][j][k] and label is mortality
        label = self.y[index]
        sequence = self.x[index]
        return(sequence,label)
        

dataset = CustomDataset(seqs, morts)

In [10]:
def collate_fn(data):
    sequences, labels = zip(*data)
    import copy

    max_visits = 0
    max_codes = 0
    x = []
    masks = []
    
    for patient_visits in sequences:
        max_visits = max(len(patient_visits),max_visits)
        for patient_visit_codes in patient_visits:
            max_codes = max(len(patient_visit_codes),max_codes)
            
#     print(max_visits)
#     print(max_codes) 

    for patient_visits in sequences:
        
        patient_masks = []
        patient_visits_c = []
        
        for patient_single_visit_codes in patient_visits:
            mask = [1]*len(patient_single_visit_codes)
            patient_single_visit_codes_c = copy.deepcopy(patient_single_visit_codes)
#             print("mask before")
#             print(mask)
            if len(patient_single_visit_codes) < max_codes:
                padding = max_codes - len(patient_single_visit_codes)
#                 print(patient_single_visit_codes)
                patient_single_visit_codes_c += [0] * padding
                mask += [0] * padding
#             print("mask after")
#             print(mask)
#             print(print(patient_single_visit_codes))
            patient_visits_c.append(patient_single_visit_codes_c)
            patient_masks.append(mask)
#                 print(patient_visit_codes)   
#             print("patient_masks")
#             print(patient_masks)
        
#         print(patient_visits)
        
        if len(patient_visits) < max_visits:
            for i in range (0, (max_visits - len(patient_visits))):        
                patient_visits_c.append(([0] * max_codes))
                patient_masks.append(([0] * max_codes))                
#         print(patient_visits)

        x.append(patient_visits_c)
        masks.append(patient_masks)
    
#     print("masks") 
#     print(masks)
#     print("x")
#     print(x)
#     print(sequences)
    x = torch.Tensor(x).long()
    masks = torch.Tensor(masks).bool()
    
    y = torch.Tensor(labels).float()
          
    return x, masks, y    

In [11]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 6029
Length of val dataset: 1508


In [12]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    '''
    TODO: Implement this function to return the data loader for  train and validation dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    # your code here
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset,batch_size=32, collate_fn=collate_fn)

    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)
# print(len(train_loader))
# print(len(val_loader))

In [13]:
def sum_embeddings_with_mask(x, masks):
    """
    TODO: mask select the embeddings for true visits (not padding visits) and then
        sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
        

    """
    
    # your code here
#     print(x[1])
#     print(masks[1])
    x_copy = x.clone()
    x_copy[masks==False] = 0
    sum_embeddings = torch.sum(x_copy,2)
#     print(sum_embeddings)
    return sum_embeddings

In [14]:
def get_last_visit(hidden_states, masks):
    batch_size, visits, embedding_dim = hidden_states.shape
    masks = torch.sum(masks, 2)
    masks = torch.min(masks, torch.ones_like(masks))
    masks = torch.sum(masks, 1)
    masks = masks - torch.ones_like(masks)
    masks = masks.unsqueeze(1).expand(batch_size, embedding_dim).unsqueeze(1)
    masks = torch.max(masks, torch.zeros_like(masks)) 
    last_visit = torch.gather(hidden_states, 1, masks)
    last_visit = torch.flatten(last_visit, 1, 2)
    return last_visit

In [15]:
class NaiveRNN(nn.Module):

    
    def __init__(self, num_codes):
        super().__init__()

        self.em = nn.Embedding(num_embeddings = num_codes, embedding_dim = 128)
        
        self.rnn = nn.GRU(input_size = 128, hidden_size = 128, batch_first = True)
        
        self.fc = nn.Linear(128, 1)
        
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks):

        x = self.em(x)
        x = sum_embeddings_with_mask(x, masks)
        x , _ = self.rnn(x)
        x = get_last_visit(x, masks)
        x = self.fc(x)
        x = self.sigmoid(x)
        x = x.view(-1)
        return x
    

# load the model here
naive_rnn = NaiveRNN(num_codes = len(diagnoses))
naive_rnn

NaiveRNN(
  (em): Embedding(260326, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [16]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001)


In [17]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, val_loader):

    model.eval()
    val_labels = []
    val_probs = []
    
    for step, batch in enumerate(val_loader):
        x, masks, labels = batch
        
        with torch.no_grad():
            
            probs = model(x, masks)
            val_labels.extend(labels.detach().tolist())
            val_probs.extend(probs.detach().numpy().reshape(-1).tolist())
            
    precision, recall, f1, _ = precision_recall_fscore_support(val_labels, np.array(val_probs) > 0.5, average='binary')
    roc_auc = roc_auc_score(val_labels, val_probs)
    print(sum(val_labels))
    print(len(val_labels))
    
    print(f"roc_auc:{roc_auc:3f}, precision:{precision:.3f},recall:{recall:3f},f1:{f1:3f}")
    return precision, recall, f1, roc_auc

In [18]:
def train(model, train_loader, val_loader, n_epochs):


    model.train()
    
    for epoch in range(n_epochs):
        train_loss = 0
        for step, batch in enumerate(train_loader):
            x, masks, labels = batch
            
            y_hat = model.forward(x, masks)
            optimizer.zero_grad()
            loss = criterion(y_hat, labels)
            loss.backward()
            
            optimizer.step()
            train_loss+=loss.item()
            
        train_loss = train_loss/len(train_loader)
#         print(train_loss)
        eval_model(model, val_loader)

    
# number of epochs to train the model
n_epochs = 50
train(naive_rnn, train_loader, val_loader, n_epochs)

788.0
1508
roc_auc:0.716950, precision:0.663,recall:0.691624,f1:0.677019
788.0
1508
roc_auc:0.723858, precision:0.651,recall:0.717005,f1:0.682367
788.0
1508
roc_auc:0.730831, precision:0.666,recall:0.706853,f1:0.685961
788.0
1508
roc_auc:0.725130, precision:0.667,recall:0.695431,f1:0.680745


KeyboardInterrupt: 