## DSCA-NET

- Reference codes : Deep Learning for Healthcare, CS598, UIUC, Spring2022

In [291]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split


# set seed
seed = 100
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "./preprocessed/"

In [306]:
pids = pickle.load(open(os.path.join(DATA_PATH,'pid.pkl'), 'rb'))
x_dem = pickle.load(open(os.path.join(DATA_PATH,'x_dem.pkl'), 'rb'))
x_per = pickle.load(open(os.path.join(DATA_PATH,'x_per.pkl'), 'rb'))
x_d = pickle.load(open(os.path.join(DATA_PATH,'x_d.pkl'), 'rb'))


In [307]:
print(len(pids))
print(len(x_dem))
print(len(x_per))
print(len(x_d))

46520
46520
46520
46520


In [308]:
print("Patient ID:", pids[0])
print("DemoGraphic:", x_dem[0])
print("Examination of first visit:", x_per[0][0])
print("Prescription of first visit:", x_d[0][0])


# for visit in range(len(vids[3])):
#     print(f"\t{visit}-th visit id:", vids[3][visit])
#     print(f"\t{visit}-th visit diagnosis labels:", seqs[3][visit])
#     print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in seqs[3][visit]])

Patient ID: 249
DemoGraphic: [0, 1, 1, 3, 0]
Examination of first visit: [7.45, 330.0, 41.0, 0]
Prescription of first visit: [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [309]:
### Leaving Only patients who have both of examination and prescriptions
dem_no_idx = set()
per_no_idx = set()
d_no_idx = set()

for idx in range(len(x_dem)):
    
    if len(x_dem[idx]) == 0:
        dem_no_idx.add(idx)
    
    if len(x_per[idx]) == 0:
        per_no_idx.add(idx)
    
    if len(x_d[idx]) == 0:
        d_no_idx.add(idx)

In [310]:
tot_no_idx = per_no_idx | d_no_idx
all_idx = set(range(len(x_dem)))
alive_idx = all_idx - tot_no_idx

alive_idx_list = list(alive_idx)
alive_idx_list.sort()


x_dem_f = [x_dem[i] for i in alive_idx_list]
x_per_f = [x_per[i] for i in alive_idx_list]
x_d_f = [x_d[i] for i in alive_idx_list]


In [311]:
print(len(x_dem_f))
print(len(x_per_f))
print(len(x_d_f))


11619
11619
11619


### Build Dataset

In [312]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, x_dem, x_per, x_d):
        
        self.x_dem = x_dem
        self.x_per = x_per
        self.x_d = x_d
    
    def __len__(self):
        
        return len(self.x_d)
    
    def __getitem__(self, index):
        
        x_dem = self.x_dem[index]
        x_per = self.x_per[index]
        x_d = self.x_d[index]
        return x_dem, x_per, x_d
        

dataset = CustomDataset(x_dem_f, x_per_f, x_d_f)

### Collate

In [383]:
num_patients = len(x_per_f) # 11619
num_visits = [len(patient) for patient in x_per_f] #
max_num_visits = max(num_visits)

num_codes = 400
num_per_codes = 4
print(max_num_visits)


76


In [154]:
len(x_d_f[0])

3

In [198]:
p = [0] * 4
ps = [p] * 7
print(ps)

[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]


In [210]:
a = [[1,2,3,4]]

In [211]:
a.extend(ps)

In [212]:
a

[[1, 2, 3, 4],
 [0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0]]

In [213]:
torch.tensor(a)

tensor([[1, 2, 3, 4],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]])

In [290]:
len(x_per[0])

2253

In [479]:
## 1. (x_d) For all patients, Number of Visit (len(x_d)) must be same -> Need to be padded
## 2. (x_d) For all patients, Number of Prescription codes must be same -> Already Done
## 3. (x_per) For all patients, Number of Visit (len(x_per)) must be same -> Need to be padded
## 4. Mask -> 0 For padded intputs -> No Need
def collate_fn(data):

    x_dem_now, x_per_now, x_d_now = zip(*data)

    x_dem_now = torch.tensor(x_dem_now[0], dtype=torch.float32)
    x_per_now = x_per_now[0]
    x_d_now = x_d_now[0]

    # print(len(x_per_now))
    # print(len(x_d_now))


    # x_d_new = torch.zeros((1, max_num_visits, num_codes), dtype=torch.long)
    # x_per_new = torch.zeros((1, max_num_visits, num_per_codes), dtype=torch.long)


    ## (x_d) 
    length_valid_visit = len(x_d_now)
    # print("Valid Visit : {0}".format(length_valid_visit))
    num_pad_visit = max_num_visits - length_valid_visit
    pad = [0] * num_codes  # [0, ... , 0] : For 400 zeros
    tmp = x_d_now.copy()
    tmp.extend([pad] * num_pad_visit)
    x_d_now = torch.tensor(tmp.copy(), dtype=torch.float32)

    # (x_per)
    length_valid = len(x_per_now)
    num_pad_visit = max_num_visits - length_valid
    pad = [0] * num_per_codes
    tmp2 = x_per_now.copy()
    tmp2.extend([pad] * num_pad_visit)
    x_per_now = torch.tensor(tmp2.copy(), dtype=torch.float32)
    
    return x_dem_now, x_per_now, x_d_now

In [480]:
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
loader_iter = iter(loader)
x_dem_new, x_per_new, x_d_new = next(loader_iter)
# print(x[0])
# print(x_d_new[0])

In [482]:
# x_per, x_d : num patients, max num visits, num codes
# x_dem : num_patients, num demo codes

print(x_dem_new.size())
print(x_per_new.size())  
print(x_d_new.size())  


torch.Size([5])
torch.Size([76, 4])
torch.Size([76, 400])


In [483]:
split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 9295
Length of val dataset: 2324


In [484]:
def load_data(train_dataset, val_dataset, collate_fn):
    
    batch_size = 1
    train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False, collate_fn=collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [485]:
train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)


In [486]:
print(len(train_loader)) ## divided by 32, no remainder
print(len(val_loader)) ## remainder exists

9295
2324


In [487]:
for a, b, c in train_loader:
    print(a.size())
    print(b.size())
    print(c.size())
    break




torch.Size([5])
torch.Size([76, 4])
torch.Size([76, 400])


In [490]:
## cross attention
def CA(x_dem_per_comb_frac, x_d_new_frac, w_dim=128):
    # concatenate x_dem and x_per_new

        x_c = x_dem_per_comb_frac.clone()
        x_d = x_d_new_frac.clone()

        c_linear = nn.Linear(x_c.size()[0], w_dim)
        d_linear = nn.Linear(x_d.size()[0], w_dim)
        tanh = nn.Tanh()
        sigmoid = nn.Sigmoid()


        # print(x_c.size())  # 5 + 128 = 133
        # print(x_d.size())  # 128

        x_c1 = c_linear(x_c)
        x_d1 = d_linear(x_d)

        x_c2 = c_linear(x_c)
        x_d2 = d_linear(x_d)


        s_cd = sigmoid(c_linear(x_c) * tanh(x_c1 + x_d1))
        s_dc = sigmoid(d_linear(x_d) * tanh(x_c2 + x_d2))

        # weight_c = torch.sum(x_c, dim=0)
        # weight_d = torch.sum(x_d, dim=0)
        # s_dc = weight_d * x_c
        # s_cd = weight_c * x_d


        x = torch.concat([torch.mul(s_cd, x_d), torch.mul(s_dc, x_c)])
        s = torch.concat([s_cd, s_dc])

        # print(x.size())

        return x,s


In [491]:
class DSCARNN(nn.Module):

    def __init__(self, h_t_minus_1, ca_t_minus_1, x_t, s_t, num_hidden_dim, num_codes_d, tau):
        super().__init__()

        self.h_t_minus_1 = h_t_minus_1
        self.ca_t_minus_1 = ca_t_minus_1
        self.x_t = x_t
        self.s_t = s_t
        self.num_hidden_dim = num_hidden_dim
        self.num_codes_d = num_codes_d
        
        self.linear = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim * 2)
        self.linear_dec = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim)
        self.linear_out = nn.Linear(self.num_hidden_dim *2, self.num_codes_d)
    
        self.tau = tau
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
    
    def forward(self):
        d_t = self.sigmoid(self.linear(self.x_t))
        ca_t = self.s_t + (1 - d_t) * self.ca_t_minus_1

        ca_t_tilde = self.tanh(self.linear(ca_t))

        h_tilde = self.sigmoid(torch.concat([self.linear_dec(self.x_t), self.linear_dec(self.h_t_minus_1)], dim=0))

        h_t = self.tau * ca_t_tilde + (1-self.tau) * self.sigmoid(h_tilde)

        y_t_hat = self.sigmoid(self.linear_out(self.relu(self.linear(h_t))))

        return h_t, ca_t, y_t_hat

In [467]:
# a = torch.zeros((1, 32))
# b = torch.zeros((1, 32))
# c = torch.concat([a, b], dim=0)

In [493]:
class DSCANet(nn.Module):
    
    def __init__(self, num_codes, num_codes_d, tau):
        super().__init__()

        # self.embedding = nn.Embedding(num_embeddings = num_codes, embedding_dim = 128)
        # self.embedding_d = nn.Embedding(num_embeddings = num_codes_d, embedding_dim = 128)
        self.num_hidden_dim = 128
        self.embedding = nn.Linear(num_codes + 5, self.num_hidden_dim)
        self.embedding_d = nn.Linear(num_codes_d, self.num_hidden_dim)

        self.num_codes_d = num_codes_d

        # self.linear = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim * 2)
        # self.linear_dec = nn.Linear(self.num_hidden_dim * 2, self.num_hidden_dim)
        # self.linear_out = nn.Linear(self.num_hidden_dim *2, num_codes_d)
    

        # self.rnn = nn.GRU(input_size = 128, hidden_size = 128, batch_first=True)
        # self.fc = nn.Linear(in_features = 256, out_features=1)
        self.tau = tau
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    
    def forward(self, x_dem_new, x_per_new, x_d_new):
        
        # batch_size = x_d_new.shape[0] # batch_size = 1
        x_dem_new = x_dem_new.squeeze(0) # num_demos
        x_per_new = x_per_new.squeeze(0) # num_visit, num_examinations
        x_d_new = x_d_new.squeeze(0) #num_visit, num_codes

        # print(x_dem_new.size()) # 5
        # print(x_per_new.size()) # 76, 4
        # print(x_d_new.size()) # 76,400

        num_visits = x_per_new.size()[0]

        

        for v_idx in range(num_visits):
            if v_idx == 0:

                # if v_idx == 0, initialize h_t and ca_t
                h_t_minus_1 = nn.init.kaiming_uniform_(torch.empty(1, self.num_hidden_dim * 2)).squeeze(0)
                ca_t_minus_1 = nn.init.kaiming_uniform_(torch.empty(1, self.num_hidden_dim * 2)).squeeze(0)

                
                # getting embeddings and cross attentions
                embedded_per = self.embedding(torch.concat([x_dem_new, x_per_new[v_idx]], dim=0))
                embedded_d = self.embedding_d(x_d_new[v_idx])

                x_t, s_t = CA(embedded_per, embedded_d, w_dim=self.num_hidden_dim)

                h_t_minus_1, ca_t_minus_1, y_t_hat = DSCARNN(h_t_minus_1, ca_t_minus_1, x_t, s_t, self.num_hidden_dim, self.num_codes_d, self.tau).forward()
                y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)
                

            else:
                # getting embeddings and cross attentions
                embedded_per = self.embedding(torch.concat([x_dem_new, x_per_new[v_idx]], dim=0))
                embedded_d = self.embedding_d(x_d_new[v_idx])
                x_t, s_t = CA(embedded_per, embedded_d, w_dim=self.num_hidden_dim)
                
                h_t_minus_1, ca_t_minus_1, y_t_hat = DSCARNN(h_t_minus_1, ca_t_minus_1, x_t, s_t, self.num_hidden_dim, self.num_codes_d, self.tau).forward()
                y_hats = torch.concat([y_hats, y_t_hat.unsqueeze(0)], dim=0)
                


        return y_hats
    

# load the model here
dsca_net = DSCANet(num_codes = num_per_codes, num_codes_d=num_codes, tau=0.6)
y_hats = dsca_net.forward(x_dem_new, x_per_new, x_d_new)

  y_hats = torch.tensor(y_t_hat, dtype=torch.float32).unsqueeze(0)


In [498]:
dsca_net

DSCANet(
  (embedding): Linear(in_features=9, out_features=128, bias=True)
  (embedding_d): Linear(in_features=400, out_features=128, bias=True)
  (sigmoid): Sigmoid()
  (tanh): Tanh()
  (relu): ReLU()
)

In [494]:
y_hats.size()

torch.Size([76, 400])

In [497]:
DSCANet.parameters()

TypeError: parameters() missing 1 required positional argument: 'self'

In [500]:
criterion = nn.BCELoss(reduction='none')
optimizer = torch.optim.Adam(list(dsca_net.parameters(), CA.parameters()), lr=1e-2)

AttributeError: 'function' object has no attribute 'parameters'

In [None]:
def train(model, train_loader, val_loader, n_epochs):
    
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0
        for x_dem_new, x_per_new, x_d_new in train_loader:
            """
            TODO:
                1. zero grad
                2. model forward
                3. calculate loss
                4. loss backward
                5. optimizer step
            """
            loss = None
            # your code here
            optimizer.zero_grad()
            
            ## model forward
            y_hat = model(x, masks, rev_x, rev_masks)
            
            ##loss
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        p, r, f, roc_auc = eval_model(model, val_loader)
        print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'
              .format(epoch+1, p, r, f, roc_auc))

In [382]:
x_per_new.size()

torch.Size([76, 4])

In [360]:
print(x_ex.size())
print(s_ex.size())

torch.Size([409])
torch.Size([409])


In [100]:
comb = x_d_ex.view(10, -1, 128)
combsum = torch.sum(comb, dim=1)

In [102]:
combsum * x_d

torch.Size([10, 128])

In [103]:
x_c_ex.size()

torch.Size([10, 47, 128])

In [96]:
print(x_d_ex.size())
print(x_c_ex.size())

torch.Size([10, 42, 48, 128])
torch.Size([10, 47, 128])


In [84]:
x_d_ex[0][0][0].size()

torch.Size([128])

In [88]:
print(x_c_ex[0][0].size())
print(x_d_ex[0][0][0][0])

torch.Size([128])
tensor(0.9721, grad_fn=<SelectBackward0>)


In [91]:
torch.sum(x_d_ex[0][0][0]) * x_c_ex[0][0]

tensor([ -0.1053,  11.7656,  -0.9722,  -4.5396,   1.8796,  -4.2402,   1.3459,
          2.6512,  -2.2932,  -3.7640,   1.5582,  -3.1277,  -1.8824,  -5.7328,
          3.5307,   2.6421,  -3.5459,  14.6715,   0.1037,   3.1923,  -4.5923,
         -2.3548,   3.7755,  -1.0293,   2.9087,   0.4739,  -5.3497,  -6.9707,
         -0.5234,   6.3263,   1.4085,   0.4351,   6.2411,  -6.4672,  -2.6802,
         -3.0452,  -2.7206,   1.9852,   4.9574,  -1.2803,   6.8643,   3.8397,
         -7.8684,   5.9907,   9.8121,   1.5851,   0.9734,   2.2264,   3.3719,
          2.6987,   1.6941,  -3.0267,   5.8485,  -2.0138,  -4.8148,  -9.3796,
          1.6449,   7.1585,  -5.9781,   6.6120,   3.1153,   1.9886,   4.7200,
         11.3641,   2.7315,   0.1468,  -1.7767,  -8.4672,  -7.4049,  -0.9031,
         -6.7645,   1.0623,  -6.6412,   4.6745,   1.0223,  -3.9003,   1.1705,
          3.8033,   4.1859,  -7.3584,  -8.8363,   8.7551,   4.4871,   8.2050,
         -3.4502,  -5.3145,  -8.8944,   0.0268,  -0.7932,   0.73

In [90]:
torch.sum(x_c_ex[0][0]) * x_d_ex[0][0][0]

tensor([ 3.0352,  1.2070,  4.3272,  4.2509,  1.3044,  0.7701, -2.1011, -3.1854,
         0.1768, -3.2213, -4.4997,  2.2523,  0.6664, -1.6208, -4.0750, -4.5078,
        -0.6047,  4.6722,  0.8967, -1.4821,  7.2332, -5.6826, -0.7302, -5.0149,
         0.3665,  3.0565, -1.5576, -1.2695, -2.9509, -2.9607,  8.2866, -2.2148,
         4.4720,  4.6280, -3.8665, -3.1833, -3.4565, -0.9128, -1.2805, -0.5828,
         1.7183,  2.8481,  2.3334,  2.5972, -0.2080,  0.6289,  0.7385,  2.1967,
        -3.3458, -0.9485,  2.9139, -2.2417, -2.0779,  0.6835,  6.1068, -0.7912,
         1.9938, -2.1672, -1.1477,  0.9477, -1.8482,  1.4520, -1.0124,  3.1347,
         1.2956,  1.2019,  7.7229,  6.2986,  2.1932, -6.4995,  0.7766, -1.5119,
         1.7411, -1.5819, -1.3898, -1.6854, -0.5595, -0.2571, -3.3016, -4.5502,
        -0.4487,  0.5043, -3.6371, -3.7618, -3.4027, -1.9484, -0.4875, -3.1030,
         1.3803, -4.2961,  2.6533, -0.1317, -1.6301,  5.6923, -0.8562, -0.4400,
         0.0173, -3.8527,  0.6336, -0.15

In [74]:
## 47, 128 짜리 하나랑. 
x_d_ex[0][0].size()

torch.Size([48, 128])

In [104]:
print(x_d_ex.size())
print(x_c_ex.size())



torch.Size([10, 42, 48, 128])
torch.Size([10, 47, 128])
