# **GRAM: Graph-based Attention Model for Healthcare Representation Learning**
Edward Choi, Mohammad Taha Bahadori, Le Song, Walter F. Stewart, Jimeng Sun

[KDD 2017](https://dl.acm.org/doi/10.1145/3097983.3098126)

Hyperparameters reported in the original paper:

1.   GRU Hidden Layer Dimension: 128
2.   Attention Layer Dimension: 128
3.   Dropout Rate: 0.5
4.   Code Embedding Dimension: 128
5.   Batch Size: 100
6.   Number of Epochs: 100
7.   L2 Regularization Coefficient: 0.001
8.   GloVe weighting-function cutoff $x_{max}$: 100
9.   GloVe weighting-function exponent $\alpha$: 0.75

PyTorch Implementation by [Leisheng Yu](https://github.com/ThunderbornSakana) (leisheng.yu@alumni.emory.edu)

# **Diagnosis Prediction -- Multi-label Binary Prediction Task**




## **Package Setup**

In [1]:
import os
import pickle as pickle
import numpy as np
from datetime import datetime
import pandas as pd
import scipy.sparse as sps
import torch
from copy import deepcopy
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from collections import OrderedDict
import torch.utils.data as data
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import ndcg_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
import random
import warnings
warnings.filterwarnings("ignore")

## **Load Data**

In [2]:
# Binary Format Combo for GRAM
def _load_pickle(path):
  """Unpickle and return the object stored at ``path`` (trusted local files only)."""
  with open(path, 'rb') as handle:
    return pickle.load(handle)

binary_train_codes_x = _load_pickle('../MIMIC3_data/Binary_Data_Format/binary_train_codes_x.pkl')
binary_test_codes_x = _load_pickle('../MIMIC3_data/Binary_Data_Format/binary_test_codes_x.pkl')

# Labels and per-patient visit counts for both splits.
train_codes_y = np.load('../MIMIC3_data/Binary_Data_Format/train_codes_y.npy')
train_visit_lens = np.load('../MIMIC3_data/Binary_Data_Format/train_visit_lens.npy')
test_codes_y = np.load('../MIMIC3_data/Binary_Data_Format/test_codes_y.npy')
test_visit_lens = np.load('../MIMIC3_data/Binary_Data_Format/test_visit_lens.npy')

# Code hierarchy and adjacency information.
code_levels = np.load('../MIMIC3_data/code_related/code_levels.npy')
patient_code_adj = np.load('../MIMIC3_data/code_related/patient_code_adj.npy')
code_code_adj = np.load('../MIMIC3_data/code_related/code_code_adj.npy')

code_map = _load_pickle('../MIMIC3_data/code_related/code_map.pkl')

In [3]:
def transform_and_pad_input(x):
  """Convert each patient's visit array to float32 and zero-pad along the visit axis.

  Returns a single tensor of shape (num_patients, max_num_visits, ...).
  """
  as_float = [torch.tensor(patient).to(torch.float32) for patient in x]
  return pad_sequence(as_float, batch_first=True, padding_value=0)

In [4]:
# Pad both splits' visit sequences and wrap the multi-label targets as tensors.
padded_X_train, padded_X_test = (transform_and_pad_input(split) for split in (binary_train_codes_x, binary_test_codes_x))
trans_y_train, trans_y_test = (torch.tensor(y) for y in (train_codes_y, test_codes_y))

In [5]:
class MyData(data.Dataset):
    """Map-style dataset yielding (sequence, label, length) triples.

    The three containers are indexed in lockstep; the dataset's length is
    the length of ``data_seq``.
    """

    def __init__(self, data_seq, data_label, data_len):
        self.data_seq, self.data_label, self.data_len = data_seq, data_label, data_len

    def __len__(self):
        return len(self.data_seq)

    def __getitem__(self, idx):
        fields = (self.data_seq, self.data_label, self.data_len)
        return tuple(field[idx] for field in fields)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## **Model Starts**

In [None]:
# 2.4 Initializing Basic Embeddings -- GloVe
# Augment all visits
# code_num_in_levels = (np.max(code_levels, axis=0)).tolist()
# fourth_level_num = code_num_in_levels[3]
# third_level_num = code_num_in_levels[2]
# second_level_num = code_num_in_levels[1]
# augmented_list = []
# for i in range(len(train_codes_x)):
#   for ii in range(train_visit_lens[i]):
#     temp_visit = train_codes_x[i][ii]
#     temp_visit2 = temp_visit[temp_visit > 0]
#     cleaned_visit = temp_visit2.tolist()
#     prev_lens = len(cleaned_visit)
#     for iii in range(prev_lens):
#       first_idx = code_levels[cleaned_visit[iii]-1][2]
#       second_idx = code_levels[cleaned_visit[iii]-1][1]
#       third_idx = code_levels[cleaned_visit[iii]-1][0]
#       cleaned_visit = cleaned_visit + [fourth_level_num+first_idx, fourth_level_num+third_level_num+second_idx, fourth_level_num+third_level_num+second_level_num+third_idx]
#     augmented_list.append(cleaned_visit)

# Creating the Co-occurrence matrix M
# total_code_num = sum(code_num_in_levels)
# M = np.zeros((total_code_num, total_code_num))
# for length in range(1, 36):
#   for width in range(length, total_code_num):
#     for visit_count in range(len(augmented_list)):
#       one_visit = augmented_list[visit_count]
#       M[length][width] += one_visit.count(length+1) * one_visit.count(width+1)
#   print(length)

# Training the embedding vectors using M
f_M = np.load("../MIMIC3_data/code_related/f_M.npy")
M_log = np.load("../MIMIC3_data/code_related/M_log.npy")
# Clamp negative log co-occurrences to 0. A single vectorized mask assignment
# replaces an O(n^2) Python double loop; NaN entries are left untouched, as
# before, because NaN < 0 is False.
M_log[M_log < 0] = 0

class GlovePretrain(nn.Module):
  """GloVe-style objective for learning one embedding per code.

      J = sum_ij f_M[i, j] * (e_i . e_j + b_i + b_j - M_log[i, j])^2

  A single embedding table ``e`` and bias vector ``b`` are used for both
  sides of each pair.

  Args:
    f_M: (n, n) per-pair weights.
    M_log: (n, n) log co-occurrence targets.
    embedding_dim: size of each code embedding (default 128).
  """

  def __init__(self, f_M, M_log, embedding_dim=128):
    super(GlovePretrain, self).__init__()
    # Non-persistent buffers: moved by .to(device) together with the module,
    # but kept out of the state_dict (as the former plain attributes were).
    self.register_buffer("f_M", torch.as_tensor(f_M), persistent=False)
    self.register_buffer("M_log", torch.as_tensor(M_log), persistent=False)
    self.Mlens = len(f_M)
    self.b = torch.nn.Parameter(torch.zeros(self.Mlens))
    self.e = torch.nn.Embedding(self.Mlens, embedding_dim)

  def forward(self):
    """Return the scalar weighted least-squares loss J."""
    weights = self.e.weight
    dot = torch.matmul(weights, weights.t())
    # bias[i, j] = b_i + b_j
    bias = self.b.unsqueeze(1) + self.b.unsqueeze(0)
    # Uses the module's own inputs, not notebook-level globals.
    return torch.sum(self.f_M * torch.square(dot + bias - self.M_log))

# Move the GloVe inputs to the training device.
f_M = torch.from_numpy(f_M).to(device)
M_log = torch.from_numpy(M_log).to(device)

glove = GlovePretrain(f_M, M_log).to(device)
glove.train()
optimizer = torch.optim.Adagrad(glove.parameters(), lr=0.05)
num_iterations = 25000
for iteration in range(num_iterations):
  loss = glove()
  # Backward and optimize
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  if iteration % 2000 == 0:
    print(loss.item())
# Export a detached copy: downstream models receive plain initial values
# rather than a view of (and autograd link to) the GloVe module's parameter.
pretrained_embeddings = glove.e.weight.detach().clone()

In [7]:
class GRAM(nn.Module):
  """GRAM (Choi et al., KDD 2017) for multi-label diagnosis prediction.

  Every leaf code i gets a final embedding g_i that is an attention-weighted
  sum of the basic embeddings e_j of the code itself and its ancestors:

      score_ij = u_a^T tanh(W_a [e_i; e_j] + b_a)
      alpha_i  = softmax_j(score_ij)
      g_i      = sum_j alpha_ij * e_j

  A visit (a row of ``x`` over leaf codes) is embedded as tanh(x @ G). The
  visit sequence runs through a GRU, and the hidden state at each patient's
  last real visit is mapped to ``output_dim`` softmax scores.

  Args:
    code_levels: integer tensor (num_leaf_codes, num_levels) of 1-based code
      indices per level; the last column is used as the leaf (query) level.
      Stored as a plain attribute, so it must already be on the model's device.
    code_num_in_levels: number of codes at each level.
    code_dims: embedding size per level; all entries must be equal because
      embeddings from different levels are summed.
    hidden_dim: GRU hidden size.
    layer_dim: number of stacked GRU layers.
    dropout_prob: dropout between GRU layers.
    output_dim: number of predicted codes.
    pretrained_embeddings: (sum(code_num_in_levels), dim) tensor whose rows
      are laid out as [last level, ..., level 1, level 0]; copied (not
      aliased) into the per-level embedding tables.
  """

  def __init__(self, code_levels, code_num_in_levels, code_dims, hidden_dim, layer_dim, dropout_prob, output_dim, pretrained_embeddings):
    super(GRAM, self).__init__()
    self.level_num = len(code_num_in_levels)
    self.code_levels = code_levels
    self.level_embeddings = nn.ModuleList([nn.Embedding(code_num, code_dim) for code_num, code_dim in zip(code_num_in_levels, code_dims)])
    # Equip with the pretrained embeddings: the pretrained table stores the
    # leaf level first, then the coarser levels in decreasing order.
    # detach().clone() gives each table its own storage instead of aliasing
    # (and writing back into) the source tensor.
    offset = 0
    for level in reversed(range(self.level_num)):
      count = code_num_in_levels[level]
      self.level_embeddings[level].weight = nn.Parameter(pretrained_embeddings[offset:offset + count].detach().clone())
      offset += count
    leaf_dim = code_dims[-1]
    # Attention MLP: u_a^T tanh(W_a [e_i; e_j] + b_a)
    self.attention = nn.Linear(leaf_dim * 2, leaf_dim)
    self.u_a = nn.Linear(leaf_dim, 1, bias=False)
    # GRU layers for processing sequences
    self.hidden_dim = hidden_dim
    self.layer_dim = layer_dim
    self.gru = nn.GRU(leaf_dim, hidden_dim, layer_dim, batch_first=True, dropout=dropout_prob)
    self.fc = nn.Linear(hidden_dim, output_dim)
    self.softmax0 = torch.nn.Softmax(dim=1)
    # Explicit dim=1: the implicit-dim form is deprecated and resolves to 1
    # for the 2-D (batch, output_dim) input used here.
    self.softmax = torch.nn.Softmax(dim=1)
    self.tanh = torch.nn.Tanh()

  def forward(self, x, x_len):
    """Predict diagnosis-code probabilities for a padded batch of patients.

    Args:
      x: (batch, max_visits, num_leaf_codes) float tensor of visits,
        zero-padded after each patient's last visit.
      x_len: number of real visits per patient (CPU tensor or list).

    Returns:
      (batch, output_dim) tensor of softmax probabilities.
    """
    # Basic embedding of each leaf code's ancestor at every level
    # (code_levels is 1-based, hence the -1).
    embeddings = [self.level_embeddings[level](self.code_levels[:, level] - 1) for level in range(self.level_num)]
    leaf = embeddings[-1]
    # One score per (leaf, level) pair. The tanh matters: without it the
    # score is linear, the e_i term is constant across levels and cancels in
    # the softmax, and attention would ignore the leaf code entirely.
    scores = torch.cat([self.u_a(self.tanh(self.attention(torch.cat((leaf, emb), dim=1)))) for emb in embeddings], dim=1)
    score_matrix = self.softmax0(scores)  # (num_leaf_codes, level_num)
    # Final code embeddings G: attention-weighted sum over levels.
    stacked = torch.stack(embeddings, dim=1)  # (num_leaf_codes, level_num, dim)
    new_emb_matrix = (score_matrix.unsqueeze(-1) * stacked).sum(dim=1)
    # Get visit embeddings
    x = self.tanh(torch.matmul(x, new_emb_matrix))
    # Zero initial hidden state, on the same device/dtype as the input.
    h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
    # Feed into GRU, skipping padded visits.
    x_packed = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=False)
    out, _ = self.gru(x_packed, h0)
    out, out_lengths = pad_packed_sequence(out, batch_first=True)
    # Hidden state at each patient's last real visit: (batch, hidden_dim).
    last = (out_lengths - 1).to(out.device)
    out = out[torch.arange(out.size(0), device=out.device), last]
    # Last layer with softmax
    return self.softmax(self.fc(out))

## **Training Loop**

In [8]:
# Per-level code counts (code indices are 1-based, so the max is the count).
code_num_in_levels = np.max(code_levels, axis=0).tolist()
code_levels2 = torch.from_numpy(code_levels).to(device)
# BCELoss needs the output width to match the label width, so derive it from
# the labels instead of hard-coding it (previously 4880).
num_label_codes = train_codes_y.shape[-1]
model = GRAM(code_levels2, code_num_in_levels, [128, 128, 128, 128], 128, 2, 0.5, num_label_codes, pretrained_embeddings)
model = model.to(device)

In [9]:
# This is for diagnosis prediction
def evaluate_model(pred, label, k1, k2, k3, k4, k5, k6):
  """Compute top-k precision, recall and NDCG at six cut-offs.

  Args:
    pred: (num_patients, num_codes) tensor of predicted scores.
    label: (num_patients, num_codes) multi-hot tensor of true codes.
    k1, k2, k3, k4, k5, k6: the cut-offs.

  Returns:
    18-tuple (p@k1, r@k1, ndcg@k1, p@k2, r@k2, ndcg@k2, ..., ndcg@k6).
    Precision and recall are averaged over patients (sklearn
    ``average='samples'``); NDCG is computed from the raw scores.
  """
  scores = pred.detach().cpu()
  score_array = scores.numpy()
  true = label.detach().cpu().numpy()
  results = []
  for k in (k1, k2, k3, k4, k5, k6):
    # Per-patient mask with 1 only at that patient's own top-k codes. A mask
    # shared across patients (and never reset) would accumulate every earlier
    # patient's top-k codes and corrupt precision/recall.
    top_idx = torch.topk(scores, k, dim=1).indices
    pred_k = torch.zeros(scores.shape, dtype=torch.int64).scatter_(1, top_idx, 1).numpy()
    results.append(precision_score(true, pred_k, average='samples'))
    results.append(recall_score(true, pred_k, average='samples'))
    results.append(ndcg_score(true, score_array, k=k))
  return tuple(results)

In [10]:
# Initialize evaluation record lists: precision (p), recall (r) and NDCG (n)
# at each of the six cut-offs, each one an independent empty list.
(metric_p1_list, metric_p2_list, metric_p3_list,
 metric_p4_list, metric_p5_list, metric_p6_list) = ([] for _ in range(6))
(metric_r1_list, metric_r2_list, metric_r3_list,
 metric_r4_list, metric_r5_list, metric_r6_list) = ([] for _ in range(6))
(metric_n1_list, metric_n2_list, metric_n3_list,
 metric_n4_list, metric_n5_list, metric_n6_list) = ([] for _ in range(6))

In [None]:
# Training mode
model.train()
# Loss and optimizer (learning rate); weight_decay is the L2 coefficient.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.001)
# Data loaders: shuffled mini-batches for training; the whole test set as a
# single batch, built once instead of on every epoch.
training_data = MyData(padded_X_train, trans_y_train, train_visit_lens)
train_loader = DataLoader(training_data, batch_size=100, shuffle=True)
test_data = MyData(padded_X_test, trans_y_test, test_visit_lens)
test_loader = DataLoader(test_data, batch_size=len(padded_X_test), shuffle=False)
total_step = len(train_loader)
# Record lists in the exact order evaluate_model returns its metrics:
# (precision, recall, NDCG) for each cut-off.
metric_lists = [
  metric_p1_list, metric_r1_list, metric_n1_list,
  metric_p2_list, metric_r2_list, metric_n2_list,
  metric_p3_list, metric_r3_list, metric_n3_list,
  metric_p4_list, metric_r4_list, metric_n4_list,
  metric_p5_list, metric_r5_list, metric_n5_list,
  metric_p6_list, metric_r6_list, metric_n6_list,
]
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
  for i, (patients, labels, seq_lengths) in enumerate(train_loader):
    patients = patients.to(device)
    labels = labels.to(device)
    # Forward pass
    outputs = model(patients, seq_lengths)
    loss = criterion(outputs, labels.to(torch.float32))
    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Tracking
    if (i+1) % 20 == 0:
      print('Epoch: [{}/{}], Step: [{}/{}], Loss: {}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
  # Evaluate after every epoch. no_grad: no autograd graph is needed (or
  # built) for the full-test-set forward pass.
  model.eval()
  with torch.no_grad():
    for patients, labels, seq_lengths in test_loader:
      pred = model(patients.to(device), seq_lengths)
      # Subject to Change! @k for evaluation
      metrics = evaluate_model(pred, labels.to(device), 5, 10, 15, 20, 25, 30)
      for metric_list, value in zip(metric_lists, metrics):
        metric_list.append(value)
  model.train()

# **Mortality Prediction -- Binary Prediction Task**

## **Package Setup**

In [1]:
import os
import pickle as pickle
import numpy as np
from datetime import datetime
import pandas as pd
import scipy.sparse as sps
import torch
from copy import deepcopy
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from collections import OrderedDict
import torch.utils.data as data
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import ndcg_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
import random
import warnings
warnings.filterwarnings("ignore")

## **Load Data**

In [2]:
# Binary Format Combo for GRAM
def _load_pickle(path):
  """Unpickle and return the object stored at ``path`` (trusted local files only)."""
  with open(path, 'rb') as handle:
    return pickle.load(handle)

binary_train_codes_x = _load_pickle('../MIMIC3_data/Binary_Data_Format/binary_train_codes_x.pkl')
binary_test_codes_x = _load_pickle('../MIMIC3_data/Binary_Data_Format/binary_test_codes_x.pkl')

# Visit counts and mortality labels for both splits.
train_visit_lens = np.load('../MIMIC3_data/Binary_Data_Format/train_visit_lens.npy')
train_mort = np.load('../MIMIC3_data/Binary_Data_Format/train_mort.npy')
test_visit_lens = np.load('../MIMIC3_data/Binary_Data_Format/test_visit_lens.npy')
test_mort = np.load('../MIMIC3_data/Binary_Data_Format/test_mort.npy')

# Code hierarchy and adjacency information.
code_levels = np.load('../MIMIC3_data/code_related/code_levels.npy')
patient_code_adj = np.load('../MIMIC3_data/code_related/patient_code_adj.npy')
code_code_adj = np.load('../MIMIC3_data/code_related/code_code_adj.npy')

code_map = _load_pickle('../MIMIC3_data/code_related/code_map.pkl')

In [3]:
def transform_and_pad_input(x):
  """Convert each patient's visit array to float32 and zero-pad along the visit axis.

  Returns a single tensor of shape (num_patients, max_num_visits, ...).
  """
  as_float = [torch.tensor(patient).to(torch.float32) for patient in x]
  return pad_sequence(as_float, batch_first=True, padding_value=0)

In [4]:
# Pad both splits' visit sequences and wrap the mortality labels as tensors.
padded_X_train, padded_X_test = (transform_and_pad_input(split) for split in (binary_train_codes_x, binary_test_codes_x))
trans_y_train, trans_y_test = (torch.tensor(y) for y in (train_mort, test_mort))

In [5]:
class MyData(data.Dataset):
    """Map-style dataset yielding (sequence, label, length) triples.

    The three containers are indexed in lockstep; the dataset's length is
    the length of ``data_seq``.
    """

    def __init__(self, data_seq, data_label, data_len):
        self.data_seq, self.data_label, self.data_len = data_seq, data_label, data_len

    def __len__(self):
        return len(self.data_seq)

    def __getitem__(self, idx):
        fields = (self.data_seq, self.data_label, self.data_len)
        return tuple(field[idx] for field in fields)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## **Model Starts**

In [None]:
# 2.4 Initializing Basic Embeddings -- GloVe
# Augment all visits
# code_num_in_levels = (np.max(code_levels, axis=0)).tolist()
# fourth_level_num = code_num_in_levels[3]
# third_level_num = code_num_in_levels[2]
# second_level_num = code_num_in_levels[1]
# augmented_list = []
# for i in range(len(train_codes_x)):
#   for ii in range(train_visit_lens[i]):
#     temp_visit = train_codes_x[i][ii]
#     temp_visit2 = temp_visit[temp_visit > 0]
#     cleaned_visit = temp_visit2.tolist()
#     prev_lens = len(cleaned_visit)
#     for iii in range(prev_lens):
#       first_idx = code_levels[cleaned_visit[iii]-1][2]
#       second_idx = code_levels[cleaned_visit[iii]-1][1]
#       third_idx = code_levels[cleaned_visit[iii]-1][0]
#       cleaned_visit = cleaned_visit + [fourth_level_num+first_idx, fourth_level_num+third_level_num+second_idx, fourth_level_num+third_level_num+second_level_num+third_idx]
#     augmented_list.append(cleaned_visit)

# Creating the Co-occurrence matrix M
# total_code_num = sum(code_num_in_levels)
# M = np.zeros((total_code_num, total_code_num))
# for length in range(1, 36):
#   for width in range(length, total_code_num):
#     for visit_count in range(len(augmented_list)):
#       one_visit = augmented_list[visit_count]
#       M[length][width] += one_visit.count(length+1) * one_visit.count(width+1)
#   print(length)

# Training the embedding vectors using M
f_M = np.load("../MIMIC3_data/code_related/f_M.npy")
M_log = np.load("../MIMIC3_data/code_related/M_log.npy")
# Clamp negative log co-occurrences to 0. A single vectorized mask assignment
# replaces an O(n^2) Python double loop; NaN entries are left untouched, as
# before, because NaN < 0 is False.
M_log[M_log < 0] = 0

class GlovePretrain(nn.Module):
  """GloVe-style objective for learning one embedding per code.

      J = sum_ij f_M[i, j] * (e_i . e_j + b_i + b_j - M_log[i, j])^2

  A single embedding table ``e`` and bias vector ``b`` are used for both
  sides of each pair.

  Args:
    f_M: (n, n) per-pair weights.
    M_log: (n, n) log co-occurrence targets.
    embedding_dim: size of each code embedding (default 128).
  """

  def __init__(self, f_M, M_log, embedding_dim=128):
    super(GlovePretrain, self).__init__()
    # Non-persistent buffers: moved by .to(device) together with the module,
    # but kept out of the state_dict (as the former plain attributes were).
    self.register_buffer("f_M", torch.as_tensor(f_M), persistent=False)
    self.register_buffer("M_log", torch.as_tensor(M_log), persistent=False)
    self.Mlens = len(f_M)
    self.b = torch.nn.Parameter(torch.zeros(self.Mlens))
    self.e = torch.nn.Embedding(self.Mlens, embedding_dim)

  def forward(self):
    """Return the scalar weighted least-squares loss J."""
    weights = self.e.weight
    dot = torch.matmul(weights, weights.t())
    # bias[i, j] = b_i + b_j
    bias = self.b.unsqueeze(1) + self.b.unsqueeze(0)
    # Uses the module's own inputs, not notebook-level globals.
    return torch.sum(self.f_M * torch.square(dot + bias - self.M_log))

# Move the GloVe inputs to the training device.
f_M = torch.from_numpy(f_M).to(device)
M_log = torch.from_numpy(M_log).to(device)

glove = GlovePretrain(f_M, M_log).to(device)
glove.train()
optimizer = torch.optim.Adagrad(glove.parameters(), lr=0.05)
num_iterations = 25000
for iteration in range(num_iterations):
  loss = glove()
  # Backward and optimize
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  if iteration % 2000 == 0:
    print(loss.item())
# Export a detached copy: downstream models receive plain initial values
# rather than a view of (and autograd link to) the GloVe module's parameter.
pretrained_embeddings = glove.e.weight.detach().clone()

In [7]:
class GRAM(nn.Module):
  """GRAM (Choi et al., KDD 2017) for binary mortality prediction.

  Every leaf code i gets a final embedding g_i that is an attention-weighted
  sum of the basic embeddings e_j of the code itself and its ancestors:

      score_ij = u_a^T tanh(W_a [e_i; e_j] + b_a)
      alpha_i  = softmax_j(score_ij)
      g_i      = sum_j alpha_ij * e_j

  A visit (a row of ``x`` over leaf codes) is embedded as tanh(x @ G). The
  visit sequence runs through a GRU, and the hidden state at each patient's
  last real visit is mapped to ``output_dim`` sigmoid outputs.

  Args:
    code_levels: integer tensor (num_leaf_codes, num_levels) of 1-based code
      indices per level; the last column is used as the leaf (query) level.
      Stored as a plain attribute, so it must already be on the model's device.
    code_num_in_levels: number of codes at each level.
    code_dims: embedding size per level; all entries must be equal because
      embeddings from different levels are summed.
    hidden_dim: GRU hidden size.
    layer_dim: number of stacked GRU layers.
    dropout_prob: dropout between GRU layers.
    output_dim: number of sigmoid outputs.
    pretrained_embeddings: (sum(code_num_in_levels), dim) tensor whose rows
      are laid out as [last level, ..., level 1, level 0]; copied (not
      aliased) into the per-level embedding tables. When omitted, the
      module-level ``pretrained_embeddings`` name is used (the former behavior).
  """

  def __init__(self, code_levels, code_num_in_levels, code_dims, hidden_dim, layer_dim, dropout_prob, output_dim, pretrained_embeddings=None):
    super(GRAM, self).__init__()
    if pretrained_embeddings is None:
      # Backward compatibility: this class used to read the notebook-level
      # ``pretrained_embeddings`` global directly.
      pretrained_embeddings = globals().get("pretrained_embeddings")
      if pretrained_embeddings is None:
        raise ValueError("pretrained_embeddings must be passed or defined by the GloVe cell")
    self.level_num = len(code_num_in_levels)
    self.code_levels = code_levels
    self.level_embeddings = nn.ModuleList([nn.Embedding(code_num, code_dim) for code_num, code_dim in zip(code_num_in_levels, code_dims)])
    # Equip with the pretrained embeddings: the pretrained table stores the
    # leaf level first, then the coarser levels in decreasing order.
    # detach().clone() gives each table its own storage instead of aliasing
    # (and writing back into) the source tensor.
    offset = 0
    for level in reversed(range(self.level_num)):
      count = code_num_in_levels[level]
      self.level_embeddings[level].weight = nn.Parameter(pretrained_embeddings[offset:offset + count].detach().clone())
      offset += count
    leaf_dim = code_dims[-1]
    # Attention MLP: u_a^T tanh(W_a [e_i; e_j] + b_a)
    self.attention = nn.Linear(leaf_dim * 2, leaf_dim)
    self.u_a = nn.Linear(leaf_dim, 1, bias=False)
    # GRU layers for processing sequences
    self.hidden_dim = hidden_dim
    self.layer_dim = layer_dim
    self.gru = nn.GRU(leaf_dim, hidden_dim, layer_dim, batch_first=True, dropout=dropout_prob)
    self.fc = nn.Linear(hidden_dim, output_dim)
    self.softmax0 = torch.nn.Softmax(dim=1)
    self.sigmoid = torch.nn.Sigmoid()
    self.tanh = torch.nn.Tanh()

  def forward(self, x, x_len):
    """Predict the outcome probability for a padded batch of patients.

    Args:
      x: (batch, max_visits, num_leaf_codes) float tensor of visits,
        zero-padded after each patient's last visit.
      x_len: number of real visits per patient (CPU tensor or list).

    Returns:
      (batch, output_dim) tensor of sigmoid probabilities.
    """
    # Basic embedding of each leaf code's ancestor at every level
    # (code_levels is 1-based, hence the -1).
    embeddings = [self.level_embeddings[level](self.code_levels[:, level] - 1) for level in range(self.level_num)]
    leaf = embeddings[-1]
    # One score per (leaf, level) pair. The tanh matters: without it the
    # score is linear, the e_i term is constant across levels and cancels in
    # the softmax, and attention would ignore the leaf code entirely.
    scores = torch.cat([self.u_a(self.tanh(self.attention(torch.cat((leaf, emb), dim=1)))) for emb in embeddings], dim=1)
    score_matrix = self.softmax0(scores)  # (num_leaf_codes, level_num)
    # Final code embeddings G: attention-weighted sum over levels.
    stacked = torch.stack(embeddings, dim=1)  # (num_leaf_codes, level_num, dim)
    new_emb_matrix = (score_matrix.unsqueeze(-1) * stacked).sum(dim=1)
    # Get visit embeddings
    x = self.tanh(torch.matmul(x, new_emb_matrix))
    # Zero initial hidden state, on the same device/dtype as the input.
    h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
    # Feed into GRU, skipping padded visits.
    x_packed = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=False)
    out, _ = self.gru(x_packed, h0)
    out, out_lengths = pad_packed_sequence(out, batch_first=True)
    # Hidden state at each patient's last real visit: (batch, hidden_dim).
    last = (out_lengths - 1).to(out.device)
    out = out[torch.arange(out.size(0), device=out.device), last]
    # Last layer with sigmoid
    return self.sigmoid(self.fc(out))

## **Training Loop**

In [8]:
# Per-level code counts (code indices are 1-based, so the max is the count).
code_num_in_levels = np.max(code_levels, axis=0).tolist()
code_levels2 = torch.from_numpy(code_levels).to(device)
model = GRAM(
  code_levels2,
  code_num_in_levels,
  code_dims=[128] * 4,
  hidden_dim=128,
  layer_dim=2,
  dropout_prob=0.5,
  output_dim=1,  # a single mortality probability per patient
).to(device)

In [9]:
# Initialize evaluation record lists: one ROC-AUC and one accuracy value per evaluation.
auc_list, acc_list = [], []

In [None]:
# Training mode
model.train()
# Loss and optimizer (learning rate); weight_decay is the L2 coefficient.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.001)
# Data loaders: the training set as one shuffled full batch; the test set as
# a single batch, built once instead of on every epoch.
training_data = MyData(padded_X_train, trans_y_train, train_visit_lens)
train_loader = DataLoader(training_data, batch_size=len(padded_X_train), shuffle=True)
test_data = MyData(padded_X_test, trans_y_test, test_visit_lens)
test_loader = DataLoader(test_data, batch_size=len(padded_X_test), shuffle=False)
total_step = len(train_loader)
# Train the model
num_epochs = 100
for epoch in range(num_epochs):
  for i, (patients, labels, seq_lengths) in enumerate(train_loader):
    patients = patients.to(device)
    labels = labels.to(device)
    # Forward pass; flatten (batch, 1) -> (batch,) to match the labels.
    outputs = model(patients, seq_lengths).view(-1)
    loss = criterion(outputs, labels.to(torch.float32))
    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Tracking (every step)
    print('Epoch: [{}/{}], Step: [{}/{}], Loss: {}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
  # Evaluate after every epoch. no_grad: no autograd graph is needed (or
  # built) for the full-test-set forward pass.
  model.eval()
  with torch.no_grad():
    for patients, labels, seq_lengths in test_loader:
      pred_auc = model(patients.to(device), seq_lengths).view(-1).cpu().numpy()
      # Threshold probabilities at 0.5 for accuracy.
      pred_acc = np.round(pred_auc)
      y_true = labels.numpy()
      auc_list.append(roc_auc_score(y_true, pred_auc))
      acc_list.append(accuracy_score(y_true, pred_acc))
  model.train()