In [40]:
import numpy as np
import pandas as pd
import sklearn
import torch
import torch.nn as nn
import torch.nn.functional as F

# FILE_NAME_SUFFIX = ".multi_admission" 
FILE_NAME_SUFFIX = "_1000"
# FILE_NAME_SUFFIX = "" 
DATASET_PATH = "Datasets/"

# !ls $DATASET_PATH

def read_samples(input_file, file_name_suffix):
    file_name = DATASET_PATH + input_file + file_name_suffix
    return pd.read_csv(file_name, error_bad_lines=False, keep_default_na=False)

In [41]:
patients = read_samples("PATIENTS.csv", FILE_NAME_SUFFIX)
admissions = read_samples("ADMISSIONS.csv", FILE_NAME_SUFFIX)
diagnoses = read_samples("DIAGNOSES_ICD.csv", FILE_NAME_SUFFIX)
icu_stays = read_samples("ICUSTAYS.csv", FILE_NAME_SUFFIX)
procedures = read_samples("PROCEDURES_ICD.csv", FILE_NAME_SUFFIX)
notes = read_samples("NOTEEVENTS.csv", FILE_NAME_SUFFIX)

In [42]:
# print("Patients:")
# print(patients.head(5))
# print("Admissions:")
# print(admissions.head(5))
# print("Diagnoses:")
# print(diagnoses.head(5))
# print("ICU stays:")
# print(icu_stays.head(5))
# print("Procedures:")
# print(procedures.head(5))

In [43]:
print(admissions.groupby('ADMISSION_TYPE').size())

ADMISSION_TYPE
ELECTIVE     137
EMERGENCY    845
NEWBORN      220
URGENT        40
dtype: int64


In [44]:
admissions = admissions[admissions['ADMISSION_TYPE']!='NEWBORN']

In [45]:
# admissions.ADMITTIME = pd.to_datetime(admissions.ADMITTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')
admissions['ADMIT_DATE'] = admissions.ADMITTIME.apply(lambda x: str(x).split(' ')[0])
admissions['ADMIT_DATE'] = pd.to_datetime(admissions.ADMIT_DATE, format = '%Y-%m-%d', errors = 'coerce')
# admissions.DISCHTIME = pd.to_datetime(admissions.DISCHTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')
admissions['DISCH_DATE'] = admissions.DISCHTIME.apply(lambda x: str(x).split(' ')[0])
admissions['DISCH_DATE'] = pd.to_datetime(admissions.DISCH_DATE, format = '%Y-%m-%d', errors = 'coerce')
admissions.DEATHTIME = pd.to_datetime(admissions.DEATHTIME, format='%Y-%m-%d %H:%M:%S', errors='coerce')

In [46]:
print(admissions[['SUBJECT_ID','ADMITTIME','ADMIT_DATE','DISCHTIME','DISCH_DATE']])

      SUBJECT_ID            ADMITTIME ADMIT_DATE            DISCHTIME  \
0             22  2196-04-09 12:26:00 2196-04-09  2196-04-10 15:54:00   
1             23  2153-09-03 07:15:00 2153-09-03  2153-09-08 19:10:00   
2             23  2157-10-18 19:34:00 2157-10-18  2157-10-25 14:00:00   
3             24  2139-06-06 16:14:00 2139-06-06  2139-06-09 12:48:00   
4             25  2160-11-02 02:06:00 2160-11-02  2160-11-05 14:55:00   
...          ...                  ...        ...                  ...   
1231         961  2101-11-13 00:41:00 2101-11-13  2101-11-24 15:00:00   
1233         963  2195-04-24 14:50:00 2195-04-24  2195-04-30 13:27:00   
1234         963  2204-10-25 16:42:00 2204-10-25  2204-11-04 13:55:00   
1240         969  2151-09-23 14:35:00 2151-09-23  2151-09-30 16:03:00   
1241         969  2162-05-03 15:31:00 2162-05-03  2162-05-11 14:24:00   

     DISCH_DATE  
0    2196-04-10  
1    2153-09-08  
2    2157-10-25  
3    2139-06-09  
4    2160-11-05  
...         ...

In [47]:
admissions = admissions.sort_values(['SUBJECT_ID','ADMITTIME'])
admissions = admissions.reset_index(drop = True)
admissions['NEXT_ADMIT_DATE'] = admissions.groupby('SUBJECT_ID').ADMIT_DATE.shift(-1)
admissions['NEXT_ADMISSION_TYPE'] = admissions.groupby('SUBJECT_ID').ADMISSION_TYPE.shift(-1)

In [48]:
print(admissions[['SUBJECT_ID','ADMITTIME','NEXT_ADMIT_DATE','NEXT_ADMISSION_TYPE']])

      SUBJECT_ID            ADMITTIME NEXT_ADMIT_DATE NEXT_ADMISSION_TYPE
0              3  2101-10-20 19:08:00             NaT                 NaN
1              4  2191-03-16 00:28:00             NaT                 NaN
2              6  2175-05-30 07:15:00             NaT                 NaN
3              9  2149-11-09 13:06:00             NaT                 NaN
4             11  2178-04-16 06:18:00             NaT                 NaN
...          ...                  ...             ...                 ...
1017         998  2152-06-26 16:22:00      2153-09-05            ELECTIVE
1018         998  2153-09-05 09:00:00      2153-10-07           EMERGENCY
1019         998  2153-10-07 13:57:00             NaT                 NaN
1020         999  2119-06-04 21:36:00             NaT                 NaN
1021        1000  2144-01-19 20:15:00             NaT                 NaN

[1022 rows x 4 columns]


In [49]:
rows = admissions.NEXT_ADMISSION_TYPE == 'ELECTIVE'
admissions.loc[rows,'NEXT_ADMIT_DATE'] = pd.NaT
admissions.loc[rows,'NEXT_ADMISSION_TYPE'] = np.NaN

In [50]:
admissions = admissions.sort_values(['SUBJECT_ID','ADMITTIME'])

In [51]:
# Back fill for a patient for which we removed the next admission as ELECTIVE
admissions[['NEXT_ADMIT_DATE','NEXT_ADMISSION_TYPE']] = admissions.groupby(['SUBJECT_ID'])[['NEXT_ADMIT_DATE','NEXT_ADMISSION_TYPE']].fillna(method = 'bfill')

In [52]:
# Calculate days until next admission
admissions['DAYS_NEXT_ADMIT'] = (admissions.NEXT_ADMIT_DATE - admissions.DISCH_DATE)


In [53]:
print(admissions[['SUBJECT_ID','DISCH_DATE','NEXT_ADMIT_DATE','DAYS_NEXT_ADMIT']])

      SUBJECT_ID DISCH_DATE NEXT_ADMIT_DATE DAYS_NEXT_ADMIT
0              3 2101-10-31             NaT             NaT
1              4 2191-03-23             NaT             NaT
2              6 2175-06-15             NaT             NaT
3              9 2149-11-14             NaT             NaT
4             11 2178-05-11             NaT             NaT
...          ...        ...             ...             ...
1017         998 2152-06-30      2153-10-07        464 days
1018         998 2153-09-18      2153-10-07         19 days
1019         998 2153-10-23             NaT             NaT
1020         999 2119-06-15             NaT             NaT
1021        1000 2144-02-25             NaT             NaT

[1022 rows x 4 columns]


In [54]:
admissions['OUTPUT_LABEL'] = (admissions.DAYS_NEXT_ADMIT < pd.Timedelta(days=30)).astype('int')

In [55]:
print(admissions[['SUBJECT_ID','DAYS_NEXT_ADMIT','OUTPUT_LABEL']])

      SUBJECT_ID DAYS_NEXT_ADMIT  OUTPUT_LABEL
0              3             NaT             0
1              4             NaT             0
2              6             NaT             0
3              9             NaT             0
4             11             NaT             0
...          ...             ...           ...
1017         998        464 days             0
1018         998         19 days             1
1019         998             NaT             0
1020         999             NaT             0
1021        1000             NaT             0

[1022 rows x 3 columns]


In [56]:
admissions['DURATION'] = (admissions['DISCH_DATE']- admissions['ADMIT_DATE'])

In [57]:
print(admissions[['SUBJECT_ID','ADMIT_DATE','DISCH_DATE','DURATION']])

      SUBJECT_ID ADMIT_DATE DISCH_DATE DURATION
0              3 2101-10-20 2101-10-31  11 days
1              4 2191-03-16 2191-03-23   7 days
2              6 2175-05-30 2175-06-15  16 days
3              9 2149-11-09 2149-11-14   5 days
4             11 2178-04-16 2178-05-11  25 days
...          ...        ...        ...      ...
1017         998 2152-06-26 2152-06-30   4 days
1018         998 2153-09-05 2153-09-18  13 days
1019         998 2153-10-07 2153-10-23  16 days
1020         999 2119-06-04 2119-06-15  11 days
1021        1000 2144-01-19 2144-02-25  37 days

[1022 rows x 4 columns]


In [58]:

# notes = notes[notes.HADM_ID.isin(admissions.HADM_ID)]
# notes['HADM_ID'] = notes['HADM_ID'].astype(int)
# notes = notes.sort_values(by=['SUBJECT_ID','HADM_ID','CHARTDATE'])
# notes['CHARTDATE'] = pd.to_datetime(notes.CHARTDATE, format = '%Y-%m-%d', errors = 'coerce')
# print(notes.dtypes)
# print(admissions.dtypes)

In [59]:
# admissions_notes = pd.merge(admissions[['SUBJECT_ID','HADM_ID','ADMIT_DATE','DISCH_DATE','DAYS_NEXT_ADMIT','NEXT_ADMIT_DATE','ADMISSION_TYPE','DEATHTIME','OUTPUT_LABEL','DURATION']],
#                         notes[['SUBJECT_ID','HADM_ID','CHARTDATE','TEXT','CATEGORY']],
#                         on = ['SUBJECT_ID','HADM_ID'],
#                         how = 'left')

In [60]:
print(admissions.head)

<bound method NDFrame.head of       ROW_ID  SUBJECT_ID  HADM_ID            ADMITTIME            DISCHTIME  \
0          2           3   145834  2101-10-20 19:08:00  2101-10-31 13:58:00   
1          3           4   185777  2191-03-16 00:28:00  2191-03-23 18:41:00   
2          5           6   107064  2175-05-30 07:15:00  2175-06-15 16:00:00   
3          8           9   150750  2149-11-09 13:06:00  2149-11-14 10:15:00   
4         10          11   194540  2178-04-16 06:18:00  2178-05-11 19:00:00   
...      ...         ...      ...                  ...                  ...   
1017    1238         998   166191  2152-06-26 16:22:00  2152-06-30 20:06:00   
1018    1239         998   171544  2153-09-05 09:00:00  2153-09-18 15:30:00   
1019    1240         998   149668  2153-10-07 13:57:00  2153-10-23 17:00:00   
1020    1241         999   173415  2119-06-04 21:36:00  2119-06-15 11:25:00   
1021    1242        1000   143040  2144-01-19 20:15:00  2144-02-25 06:05:00   

               DEATHT

In [61]:
admissions.OUTPUT_LABEL.value_counts()

0    936
1     86
Name: OUTPUT_LABEL, dtype: int64

In [62]:
len(admissions.groupby('SUBJECT_ID')['SUBJECT_ID'].unique())

730

In [76]:
subject_labels = admissions[['SUBJECT_ID', 'OUTPUT_LABEL']]
print(subject_labels)

      SUBJECT_ID  OUTPUT_LABEL
0              3             0
1              4             0
2              6             0
3              9             0
4             11             0
...          ...           ...
1017         998             0
1018         998             1
1019         998             0
1020         999             0
1021        1000             0

[1022 rows x 2 columns]


In [77]:
subject_labels = subject_labels.groupby('SUBJECT_ID')[['OUTPUT_LABEL']].sum().reset_index()
print(subject_labels)

     SUBJECT_ID  OUTPUT_LABEL
0             3             0
1             4             0
2             6             0
3             9             0
4            11             0
..          ...           ...
725         994             0
726         995             0
727         998             1
728         999             0
729        1000             0

[730 rows x 2 columns]


In [78]:
subject_labels.rename(columns={"OUTPUT_LABEL":"OUTPUT_LABELS_SUMMED"}, inplace=True)
print(subject_labels)

     SUBJECT_ID  OUTPUT_LABELS_SUMMED
0             3                     0
1             4                     0
2             6                     0
3             9                     0
4            11                     0
..          ...                   ...
725         994                     0
726         995                     0
727         998                     1
728         999                     0
729        1000                     0

[730 rows x 2 columns]


In [79]:
subject_labels['OUTPUT_LABEL'] = (subject_labels['OUTPUT_LABELS_SUMMED'] >= 1).astype(int)
print(subject_labels)

     SUBJECT_ID  OUTPUT_LABELS_SUMMED  OUTPUT_LABEL
0             3                     0             0
1             4                     0             0
2             6                     0             0
3             9                     0             0
4            11                     0             0
..          ...                   ...           ...
725         994                     0             0
726         995                     0             0
727         998                     1             1
728         999                     0             0
729        1000                     0             0

[730 rows x 3 columns]


In [80]:
subject_labels.drop(columns=['OUTPUT_LABELS_SUMMED'], inplace=True)

In [81]:
print(subject_labels)

     SUBJECT_ID  OUTPUT_LABEL
0             3             0
1             4             0
2             6             0
3             9             0
4            11             0
..          ...           ...
725         994             0
726         995             0
727         998             1
728         999             0
729        1000             0

[730 rows x 2 columns]


In [21]:
seqs = []
readmission = []
types = {}
for patient in patients.itertuples():
#     print(patient.ROW_ID)
#     print(patient.SUBJECT_ID)
    patient_admissions = []
    for patient_admission in previous_admissions[previous_admissions.SUBJECT_ID == patient.SUBJECT_ID].itertuples():
        icd9_codes = []
#         print(patient_admission.HADM_ID)
#         newdf = df[(df.origin == "JFK") & (df.carrier == "B6")]
        diagnoses_filtered = diagnoses[(diagnoses.SUBJECT_ID == patient.SUBJECT_ID) & (diagnoses.HADM_ID == patient_admission.HADM_ID)]
        procedures_filtered = procedures[(procedures.SUBJECT_ID == patient.SUBJECT_ID) & (procedures.HADM_ID == patient_admission.HADM_ID)]
#         print(diagnoses_filtered)
        for admission_diagnosis in diagnoses_filtered.itertuples():
            if admission_diagnosis.ICD9_CODE in types:
                icd9_codes.append(types[admission_diagnosis.ICD9_CODE])
            else:
                types[admission_diagnosis.ICD9_CODE] = len(types)
                icd9_codes.append(types[admission_diagnosis.ICD9_CODE])
        for admission_procedures in procedures_filtered.itertuples():
            if admission_procedures.ICD9_CODE in types:
                icd9_codes.append(types[admission_procedures.ICD9_CODE])
            else:
                types[admission_procedures.ICD9_CODE] = len(types)
                icd9_codes.append(types[admission_procedures.ICD9_CODE])
#             print(admission_diagnosis.SEQ_NUM)
#             print(icd9_codes)
        patient_admissions.append(icd9_codes)
    readmission.append(patient.EXPIRE_FLAG)
    seqs.append(patient_admissions)
# print(patient_morts)    

NameError: name 'previous_admissions' is not defined

In [None]:
# print(np.array(seq).shape)
# print(seqs)
# print("Mortality:", morts[335])
# for visit in range(len(seqs[335])):
#     print(f"\t{visit}-th admission diagnosis labels:", seqs[335][visit])
# print(f"admission diagnosis labels:", seqs[335])

In [None]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, seqs, morts):
        
        self.x = seqs
        self.y = morts
    
    def __len__(self):
        
        return len(self.y)
    
    def __getitem__(self, index):
        
        #Tuple of (seq,label) where seq is seq[i][j][k] and label is mortality
        label = self.y[index]
        sequence = self.x[index]
        return(sequence,label)
        

dataset = CustomDataset(seqs, morts)

In [None]:
def collate_fn(data):
    sequences, labels = zip(*data)
    import copy

    max_visits = 0
    max_codes = 0
    x = []
    masks = []
    
    for patient_visits in sequences:
        max_visits = max(len(patient_visits),max_visits)
        for patient_visit_codes in patient_visits:
            max_codes = max(len(patient_visit_codes),max_codes)
            
#     print(max_visits)
#     print(max_codes) 

    for patient_visits in sequences:
        
        patient_masks = []
        patient_visits_c = []
        
        for patient_single_visit_codes in patient_visits:
            mask = [1]*len(patient_single_visit_codes)
            patient_single_visit_codes_c = copy.deepcopy(patient_single_visit_codes)
#             print("mask before")
#             print(mask)
            if len(patient_single_visit_codes) < max_codes:
                padding = max_codes - len(patient_single_visit_codes)
#                 print(patient_single_visit_codes)
                patient_single_visit_codes_c += [0] * padding
                mask += [0] * padding
#             print("mask after")
#             print(mask)
#             print(print(patient_single_visit_codes))
            patient_visits_c.append(patient_single_visit_codes_c)
            patient_masks.append(mask)
#                 print(patient_visit_codes)   
#             print("patient_masks")
#             print(patient_masks)
        
#         print(patient_visits)
        
        if len(patient_visits) < max_visits:
            for i in range (0, (max_visits - len(patient_visits))):        
                patient_visits_c.append(([0] * max_codes))
                patient_masks.append(([0] * max_codes))                
#         print(patient_visits)

        x.append(patient_visits_c)
        masks.append(patient_masks)
    
#     print("masks") 
#     print(masks)
#     print("x")
#     print(x)
#     print(sequences)
    x = torch.Tensor(x).long()
    masks = torch.Tensor(masks).bool()
    
    y = torch.Tensor(labels).float()
          
    return x, masks, y    

In [None]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

In [None]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    '''
    TODO: Implement this function to return the data loader for  train and validation dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    # your code here
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset,batch_size=32, collate_fn=collate_fn)

    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)
# print(len(train_loader))
# print(len(val_loader))

In [None]:
def sum_embeddings_with_mask(x, masks):
    """
    TODO: mask select the embeddings for true visits (not padding visits) and then
        sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
        

    """
    
    # your code here
#     print(x[1])
#     print(masks[1])
    x_copy = x.clone()
    x_copy[masks==False] = 0
    sum_embeddings = torch.sum(x_copy,2)
#     print(sum_embeddings)
    return sum_embeddings

In [None]:
def get_last_visit(hidden_states, masks):
    batch_size, visits, embedding_dim = hidden_states.shape
    masks = torch.sum(masks, 2)
    masks = torch.min(masks, torch.ones_like(masks))
    masks = torch.sum(masks, 1)
    masks = masks - torch.ones_like(masks)
    masks = masks.unsqueeze(1).expand(batch_size, embedding_dim).unsqueeze(1)
    masks = torch.max(masks, torch.zeros_like(masks)) 
    last_visit = torch.gather(hidden_states, 1, masks)
    last_visit = torch.flatten(last_visit, 1, 2)
    return last_visit

In [None]:
class NaiveRNN(nn.Module):

    
    def __init__(self, num_codes):
        super().__init__()

        self.em = nn.Embedding(num_embeddings = num_codes, embedding_dim = 128)
        
        self.rnn = nn.GRU(input_size = 128, hidden_size = 128, batch_first = True)
        
        self.fc = nn.Linear(128, 1)
        
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks):

        x = self.em(x)
        x = sum_embeddings_with_mask(x, masks)
        x , _ = self.rnn(x)
        x = get_last_visit(x, masks)
        x = self.fc(x)
        x = self.sigmoid(x)
        x = x.view(-1)
        return x
    

# load the model here
naive_rnn = NaiveRNN(num_codes = len(diagnoses))
naive_rnn

In [None]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=0.001)


In [None]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, val_loader):

    model.eval()
    val_labels = []
    val_probs = []
    
    for step, batch in enumerate(val_loader):
        x, masks, labels = batch
        
        with torch.no_grad():
            
            probs = model(x, masks)
            val_labels.extend(labels.detach().tolist())
            val_probs.extend(probs.detach().numpy().reshape(-1).tolist())
            
    precision, recall, f1, _ = precision_recall_fscore_support(val_labels, np.array(val_probs) > 0.5, average='binary')
    roc_auc = roc_auc_score(val_labels, val_probs)
    
    print(f"roc_auc:{roc_auc:3f}, precision:{precision:.3f},recall:{recall:3f},f1:{f1:3f}")
    return precision, recall, f1, roc_auc

In [None]:
def train(model, train_loader, val_loader, n_epochs):

    model.train()
    
    for epoch in range(n_epochs):
        train_loss = 0
        for step, batch in enumerate(train_loader):
            x, masks, labels = batch
            
            y_hat = model.forward(x, masks)
            optimizer.zero_grad()
            loss = criterion(y_hat, labels)
            loss.backward()
            
            optimizer.step()
            train_loss+=loss.item()
            
        train_loss = train_loss/len(train_loader)
#         print(train_loss)
        eval_model(model, val_loader)

    
# number of epochs to train the model
n_epochs = 50
train(naive_rnn, train_loader, val_loader, n_epochs)