## RNN

- Reference codes : Deep Learning for Healthcare, CS598, UIUC, Spring2022

In [2]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split


# set seed
seed = 100
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "./preprocessed/"

In [3]:
pids = pickle.load(open(os.path.join(DATA_PATH,'pid.pkl'), 'rb'))
x_dem = pickle.load(open(os.path.join(DATA_PATH,'x_dem.pkl'), 'rb'))
x_per = pickle.load(open(os.path.join(DATA_PATH,'x_per_added.pkl'), 'rb'))
x_d = pickle.load(open(os.path.join(DATA_PATH,'x_d_added.pkl'), 'rb'))


In [4]:
print(len(pids))
print(len(x_dem))
print(len(x_per))
print(len(x_d))

46520
46520
46520
46520


In [5]:
print("Patient ID:", pids[0])
print("DemoGraphic:", x_dem[0])
print("Examination of first visit:", x_per[0][0])
print("Prescription of first visit:", x_d[0][0])


# for visit in range(len(vids[3])):
#     print(f"\t{visit}-th visit id:", vids[3][visit])
#     print(f"\t{visit}-th visit diagnosis labels:", seqs[3][visit])
#     print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in seqs[3][visit]])

Patient ID: 249
DemoGraphic: [0, 1, 1, 3, 0]
Examination of first visit: [9.8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Prescription of first visit: [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [6]:
### Leaving Only patients who have both of examination and prescriptions
dem_no_idx = set()
per_no_idx = set()
d_no_idx = set()

for idx in range(len(x_dem)):
    
    if len(x_dem[idx]) == 0:
        dem_no_idx.add(idx)
    
    if len(x_per[idx]) == 0:
        per_no_idx.add(idx)
    
    if len(x_d[idx]) == 0:
        d_no_idx.add(idx)

In [7]:
tot_no_idx = per_no_idx | d_no_idx
all_idx = set(range(len(x_dem)))
alive_idx = all_idx - tot_no_idx

alive_idx_list = list(alive_idx)
alive_idx_list.sort()


x_dem_f = [x_dem[i] for i in alive_idx_list]
x_per_f = [x_per[i] for i in alive_idx_list]
x_d_f = [x_d[i] for i in alive_idx_list]


In [8]:
print(len(x_dem_f))
print(len(x_per_f))
print(len(x_d_f))


19432
19432
19432


### Build Dataset

In [9]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, x_dem, x_per, x_d):
        
        self.x_dem = x_dem
        self.x_per = x_per
        self.x_d = x_d
    
    def __len__(self):
        
        return len(self.x_d)
    
    def __getitem__(self, index):
        
        x_dem = self.x_dem[index]
        x_per = self.x_per[index]
        x_d = self.x_d[index]
        return x_dem, x_per, x_d
        

dataset = CustomDataset(x_dem_f, x_per_f, x_d_f)

### Collate

In [10]:
num_patients = len(x_per_f) # 11619
num_visits = [len(patient) for patient in x_per_f] #
max_num_visits = max(num_visits)

num_codes = len(x_d_f[0][0])
num_per_codes = len(x_per_f[0][0])
print(max_num_visits)


146


In [11]:
## 1. (x_d) For all patients, Number of Visit (len(x_d)) must be same -> Need to be padded
## 2. (x_d) For all patients, Number of Prescription codes must be same -> Already Done
## 3. (x_per) For all patients, Number of Visit (len(x_per)) must be same -> Need to be padded
## 4. Mask -> 0 For padded intputs -> No Need

# def collate_fn(data):

#     x_dem_now, x_per_now, x_d_now = zip(*data)

#     x_dem_now = torch.tensor(x_dem_now[0], dtype=torch.float32)
#     x_per_now = x_per_now[0]
#     x_d_now = x_d_now[0]

#     # print(len(x_per_now))
#     # print(len(x_d_now))


#     # x_d_new = torch.zeros((1, max_num_visits, num_codes), dtype=torch.long)
#     # x_per_new = torch.zeros((1, max_num_visits, num_per_codes), dtype=torch.long)


#     ## (x_d) 
#     length_valid_visit = len(x_d_now)
#     # print("Valid Visit : {0}".format(length_valid_visit))
#     num_pad_visit = max_num_visits - length_valid_visit
#     pad = [0] * num_codes  # [0, ... , 0] : For 400 zeros
#     tmp = x_d_now.copy()
#     tmp.extend([pad] * num_pad_visit)
#     x_d_now = torch.tensor(tmp.copy(), dtype=torch.float32)

#     # (x_per)
#     length_valid = len(x_per_now)
#     num_pad_visit = max_num_visits - length_valid
#     pad = [0] * num_per_codes
#     tmp2 = x_per_now.copy()
#     tmp2.extend([pad] * num_pad_visit)
#     x_per_now = torch.tensor(tmp2.copy(), dtype=torch.float32)
    
#     return x_dem_now, x_per_now, x_d_now


## No collating
def collate_fn(data):

    x_dem_now, x_per_now, x_d_now = zip(*data)

    x_dem_now = torch.tensor(x_dem_now[0], dtype=torch.float32)
    x_per_now = torch.tensor(x_per_now[0], dtype=torch.float32)
    x_d_now = torch.tensor(x_d_now[0], dtype=torch.float32)


    
    return x_dem_now, x_per_now, x_d_now

In [12]:
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
loader_iter = iter(loader)
x_dem_new, x_per_new, x_d_new = next(loader_iter)
# print(x[0])
# print(x_d_new[0])

In [13]:
# x_per, x_d : num patients, max num visits, num codes
# x_dem : num_patients, num demo codes

print(x_dem_new.size())
print(x_per_new.size())  
print(x_d_new.size())  


torch.Size([5])
torch.Size([11, 19])
torch.Size([11, 400])


In [14]:
split = int(len(dataset)*0.9)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 17488
Length of val dataset: 1944


In [15]:
def load_data(train_dataset, val_dataset, collate_fn):
    
    batch_size = 1
    train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False, collate_fn=collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [16]:
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)


In [17]:
print(len(train_loader)) ## divided by 32, no remainder
print(len(val_loader)) ## remainder exists

17488
1944


In [18]:
for a, b, c in train_loader:
    print(a.size())
    print(b.size())
    print(c.size())
    break




torch.Size([5])
torch.Size([32, 19])
torch.Size([32, 400])


In [31]:
### RNN architecture
### W_hh, W_ih, W_ho
### Two Rnns, one for x_c, one for x_d
### Combine Two Rnns
class TwoRNN(nn.Module):
    
    def __init__(self, num_codes, num_codes_d):
        super().__init__()

        self.num_hidden_dim = 128
        self.num_codes_d = num_codes_d


        ## W_ih
        self.embedding = nn.Linear(num_codes + 5, self.num_hidden_dim)
        self.embedding_d = nn.Linear(self.num_codes_d, self.num_hidden_dim)

        # W_hh
        self.linear_general = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim)

        # W_ho
        self.linear_fin = nn.Linear(self.num_hidden_dim *2, self.num_codes_d)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    
    def forward(self, x_dem_new, x_per_new, x_d_new):
        
        # batch_size = x_d_new.shape[0] # batch_size = 1
        x_dem_new = x_dem_new.squeeze(0) # num_demos
        x_per_new = x_per_new.squeeze(0) # num_visit, num_examinations
        x_d_new = x_d_new.squeeze(0) #num_visit, num_codes

        if len(x_per_new.size()) == 1:
            x_per_new = x_per_new.unsqueeze(0)
            x_d_new = x_d_new.unsqueeze(0)

        num_visits = x_per_new.size()[0]

        for v_idx in range(num_visits):
            if v_idx == 0:
                h_t_minus_1_c = nn.init.kaiming_uniform_(torch.empty(1, self.num_hidden_dim)).squeeze(0)
                h_t_minus_1_d = nn.init.kaiming_uniform_(torch.empty(1, self.num_hidden_dim)).squeeze(0)


                # W_ih

                x_c = self.embedding(torch.concat([x_dem_new, x_per_new[v_idx]], dim=0))
                x_d = self.embedding_d(x_d_new[v_idx])

                # W_hh
                h_t_c = self.linear_general(torch.concat([x_c,h_t_minus_1_c]))
                h_t_d = self.linear_general(torch.concat([x_d,h_t_minus_1_d]))


                # W_ho
                y_t_hat = self.linear_fin(torch.concat([h_t_c, h_t_d]))
                y_t_hat = self.sigmoid(y_t_hat)
                
                # y_hats = y_t_hat.clone().detach().unsqueeze(0)
                y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)

            else:
                # W_ih
                x_c = self.embedding(torch.concat([x_dem_new, x_per_new[v_idx]], dim=0))
                x_d = self.embedding_d(x_d_new[v_idx])

                # W_hh
                h_t_c = self.linear_general(torch.concat([x_c,h_t_c]))
                h_t_d = self.linear_general(torch.concat([x_d,h_t_d]))


                # W_ho
                y_t_hat = self.linear_fin(torch.concat([h_t_c, h_t_d]))
                y_t_hat = self.sigmoid(y_t_hat)

                y_hats = torch.concat([y_hats, y_t_hat.unsqueeze(0)], dim=0)


        return y_hats
    

# load the model here
twornn = TwoRNN(num_codes = num_per_codes, num_codes_d=num_codes)
y_hats = twornn.forward(x_dem_new, x_per_new, x_d_new)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


In [32]:
twornn

TwoRNN(
  (embedding): Linear(in_features=24, out_features=128, bias=True)
  (embedding_d): Linear(in_features=400, out_features=128, bias=True)
  (linear_general): Linear(in_features=256, out_features=128, bias=True)
  (linear_fin): Linear(in_features=256, out_features=400, bias=True)
  (sigmoid): Sigmoid()
  (tanh): Tanh()
  (relu): ReLU()
)

In [33]:
criterion = nn.BCELoss(reduction='none')
optimizer = torch.optim.Adam(list(twornn.parameters()), lr=1e-2)

In [43]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval_model(model, val_loader):
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    
    for idx, (x_dem_new, x_per_new, x_d_new) in enumerate(val_loader):
        y_hat = model.forward(x_dem_new, x_per_new, x_d_new)
        y_hat = (y_hat > 0.5).int()
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, x_d_new.detach().to('cpu')), dim=0)

        if idx == 100:
            break
    """
    TODO:
        Calculate precision, recall, f1, and roc auc scores.
        Use `average='binary'` for calculating precision, recall, and fscore.
    """
    p, r, f, roc_auc = None, None, None, None
    # print(y_pred.size()) ## batch_size *76, 400
    # print(y_true.size()) ## batch_size *76, 400

    # Micro F1
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
    return p, r, f

In [35]:
def train(model, train_loader, val_loader, n_epochs):
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for idx, (x_dem_new, x_per_new, x_d_new) in enumerate(train_loader):
            """
            TODO:
                1. zero grad
                2. model forward
                3. calculate loss
                4. loss backward
                5. optimizer step
            """
            loss = None
            # your code here
            optimizer.zero_grad()
            
            ## model forward
            y_hat = model.forward(x_dem_new, x_per_new, x_d_new)
            
            ##loss
            # print(y_hat.size())
            # print(x_d_new.size())

            if len(x_d_new.size()) == 1:
                x_d_new = x_d_new.unsqueeze(0)
                loss = criterion(y_hat, x_d_new)
            else:
                loss = criterion(y_hat, x_d_new).sum()
            
            # print(loss.size()) (76,400) -> 1
            # print(loss)
            loss.requires_grad_(True)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            # if idx % 100 == 0:
            #     print(idx)
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f = eval_model(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}'
              .format(epoch+1, p, r, f))

In [36]:
# number of epochs to train the model
# about 1 hour, for 1 epoch
n_epochs = 1
twornn = TwoRNN(num_codes = num_per_codes, num_codes_d=num_codes)
train(twornn, train_loader, val_loader, n_epochs)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


Epoch: 1 	 Training Loss: 2404.108556
Epoch: 1 	 Validation p: 0.03, r:0.52, f: 0.05


In [44]:
p,r,f = eval_model(twornn, val_loader)
print(p)
print(r)
print(f)

0.025941455006729865
0.5205985231247571
0.04942029386535322


  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


In [71]:
y_hats = dsca_net.forward(x_dem_new, x_per_new, x_d_new)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


In [133]:
from sklearn.metrics import jaccard_score


def eval_model_jaccard(model, val_loader):
    
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    
    for idx, (x_dem_new, x_per_new, x_d_new) in enumerate(val_loader):
        y_hat = model.forward(x_dem_new, x_per_new, x_d_new)
        y_hat = (y_hat > 0.5).int()
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, x_d_new.detach().to('cpu')), dim=0)

        if idx == 100:
            break
    """
    TODO:
        Calculate precision, recall, f1, and roc auc scores.
        Use `average='binary'` for calculating precision, recall, and fscore.
    """
    p, r, f, roc_auc = None, None, None, None
    # print(y_pred.size()) ## batch_size *76, 400
    # print(y_true.size()) ## batch_size *76, 400

    # Micro F1
    j = jaccard_score(y_true, y_pred, average='micro')
    return j

In [134]:
## Jaccard Score
j = eval_model_jaccard(twornn, val_loader)
print(j)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


0.024604706478719464
