# **Collaborative Graph Learning with Auxiliary Text for Temporal Event Prediction in Healthcare**

Chang Lu, Chandan K. Reddy, Prithwish Chakraborty, Samantha Kleinberg, Yue Ning

[IJCAI 2021](https://www.ijcai.org/proceedings/2021/0486.pdf)

Hyperparameters Used in the Original Paper:

1.   Learning Rate: 0.001
2.   Number of Epochs: 200
3.   Batch Size: 32
4.   $d_c$: 32
5.   $d_p$: 16
6.   GRU Hidden Layer Dimension: 200
7.   Graph Layer Number $L$: 2
8.   $d_{c}^{(1)}$: 64
9.   $d_{p}^{(1)}$: 32
10.  $d_{c}^{(2)}$: 128
11.  Attention Dimension: 32

PyTorch Implementation by [Leisheng Yu](https://github.com/ThunderbornSakana) (leisheng.yu@alumni.emory.edu)

Code adapted from https://github.com/LuChang-CS/CGL

The auxiliary-text component of the original model is not included (for fair comparison).

# **Diagnosis Prediction -- Multi-label Binary Prediction Task**

## **Package Setup**

In [1]:
import os
import pickle as pickle
import numpy as np
from datetime import datetime
import pandas as pd
import scipy.sparse as sps
import torch
from copy import deepcopy
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from collections import OrderedDict
import torch.utils.data as data
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import ndcg_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
import random
import warnings
warnings.filterwarnings("ignore")

## **Load Data**

In [2]:
# Load the train/test arrays stored in the non-binary data format used for CGL.
train_codes_x = np.load('../MIMIC3_data/Nonbinary_Data_Format/train_codes_x.npy')
train_codes_y = np.load('../MIMIC3_data/Nonbinary_Data_Format/train_codes_y.npy')
train_visit_lens = np.load('../MIMIC3_data/Nonbinary_Data_Format/train_visit_lens.npy')
test_codes_x = np.load('../MIMIC3_data/Nonbinary_Data_Format/test_codes_x.npy')
test_codes_y = np.load('../MIMIC3_data/Nonbinary_Data_Format/test_codes_y.npy')
test_visit_lens = np.load('../MIMIC3_data/Nonbinary_Data_Format/test_visit_lens.npy')

# Code-related side information: per-code level ids and the two adjacency matrices.
code_code_adj = np.load('../MIMIC3_data/code_related/code_code_adj.npy')
patient_code_adj = np.load('../MIMIC3_data/code_related/patient_code_adj.npy')
code_levels = np.load('../MIMIC3_data/code_related/code_levels.npy')
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('../MIMIC3_data/code_related/code_map.pkl', mode='rb') as f13:
    code_map = pickle.load(f13)

In [3]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Wrap each NumPy array as a tensor and place it on the selected device.
train_x, train_y, train_lens = (
    torch.from_numpy(array).to(device)
    for array in (train_codes_x, train_codes_y, train_visit_lens)
)
test_x, test_y, test_lens = (
    torch.from_numpy(array).to(device)
    for array in (test_codes_x, test_codes_y, test_visit_lens)
)

In [5]:
class MyData(data.Dataset):
    """Dataset that serves (sequence, label, length) triples by index.

    The three containers are indexed in lockstep; the dataset length is the
    length of ``data_seq``.
    """

    def __init__(self, data_seq, data_label, data_len):
        self.data_seq, self.data_label, self.data_len = data_seq, data_label, data_len

    def __len__(self):
        return len(self.data_seq)

    def __getitem__(self, idx):
        # Same position in all three containers, returned as a tuple.
        fields = (self.data_seq, self.data_label, self.data_len)
        return tuple(field[idx] for field in fields)

## **Model Starts**

In [6]:
def sequence_mask(lengths, maxlen):
    """Build a boolean mask marking the valid (non-padded) steps of each sequence.

    Args:
        lengths: 1-D tensor (or array-like) holding the valid length of each sequence.
        maxlen: number of columns in the mask (padded sequence length).

    Returns:
        Bool tensor of shape ``(len(lengths), maxlen)`` where entry ``[b, t]`` is
        True iff ``t < lengths[b]``. The mask is created on the same device as
        ``lengths`` instead of relying on a module-level ``device`` global.
    """
    lengths = torch.as_tensor(lengths)
    # 1-based step numbers 1..maxlen; a step is valid when its number <= length.
    positions = torch.arange(1, maxlen + 1, device=lengths.device)
    return positions.unsqueeze(0) <= lengths.unsqueeze(1)

def masked_softmax(inputs, mask):
    """Softmax over the last dimension restricted to positions where ``mask`` is non-zero.

    The stabilising maximum is taken over the unmasked positions only. Taking it
    over every position (as before) lets a masked entry dominate the shift, so
    strongly negative valid scores underflow in ``exp`` and every weight in the
    row collapses to ~0 instead of summing to 1.

    Args:
        inputs: float tensor of scores.
        mask: tensor broadcastable to ``inputs``; zero/False marks ignored positions.

    Returns:
        Tensor shaped like ``inputs``. Masked positions are 0; a row with no
        valid position is all zeros.
    """
    valid = mask != 0
    # Max over valid entries only; masked entries cannot win with -inf.
    row_max = inputs.masked_fill(~valid, float("-inf")).max(dim=-1, keepdim=True)[0]
    # Reset masked entries to 0 so exp() stays finite even for fully masked rows
    # (where row_max is -inf); they are multiplied by the mask below anyway.
    shifted = (inputs - row_max).masked_fill(~valid, 0.0)
    exp = torch.exp(shifted) * mask
    result = exp / (torch.sum(exp, dim=-1, keepdim=True) + 1e-12)
    return result

In [7]:
# 3.2 The Proposed Model - CGL
# Hierarchical Embedding for Medical Codes
class HierarchicalEmbedding(nn.Module):
    """Embed every code by concatenating one learned vector per hierarchy level.

    Args:
        code_levels: integer tensor indexed as ``[code, level]`` holding 1-based
            ids of each code at each level.
        code_num_in_levels: number of embedding rows to allocate per level.
        code_dims: embedding width per level.
    """

    def __init__(self, code_levels, code_num_in_levels, code_dims):
        super().__init__()
        self.level_num = len(code_num_in_levels)
        self.code_levels = code_levels
        tables = []
        for vocab_size, width in zip(code_num_in_levels, code_dims):
            tables.append(nn.Embedding(vocab_size, width))
        self.level_embeddings = nn.ModuleList(tables)

    def forward(self, input=None):
        """Return the embedding matrix for all codes; ``input`` is ignored.

        Output shape is ``(code_num, sum(code_dims))``.
        """
        per_level = []
        for level in range(self.level_num):
            # Stored ids are 1-based; embedding tables are 0-based.
            zero_based_ids = self.code_levels[:, level] - 1
            per_level.append(self.level_embeddings[level](zero_based_ids))
        return torch.cat(per_level, dim=1)

In [8]:
# Graph Representation + Collaborative Graph Learning
class GraphConvBlock(torch.nn.Module):
    """One graph-convolution step: aggregate, then Linear -> BatchNorm1d -> ReLU.

    Args:
        node_type: ``'code'`` enables the extra weighted self-aggregation term;
            any other value skips it.
        input_dim: width of the incoming node embeddings.
        output_dim: width of the produced node embeddings.
        adj: adjacency matrix multiplied with the neighbour embeddings.
    """

    def __init__(self, node_type, input_dim, output_dim, adj):
        super().__init__()
        self.node_type = node_type
        self.adj = adj
        self.dense = torch.nn.Linear(input_dim, output_dim)
        self.bn = torch.nn.BatchNorm1d(output_dim)
        self.activation = torch.nn.ReLU()

    def forward(self, embedding, embedding_neighbor, weight_decay=None):
        """Combine own and neighbour embeddings and apply the dense block.

        ``weight_decay`` is required when ``node_type == 'code'``; it is a
        matrix applied to ``embedding`` and added to the aggregate.
        """
        aggregated = embedding + torch.matmul(self.adj, embedding_neighbor)
        if self.node_type == 'code':
            assert weight_decay is not None
            aggregated = aggregated + torch.matmul(weight_decay, embedding)
        return self.activation(self.bn(self.dense(aggregated)))

In [9]:
class GraphLearning(torch.nn.Module):
    """Two-layer collaborative graph learning over patient and code nodes.

    Layer 1 updates both patient and code embeddings; layer 2 updates only the
    code embeddings, which are the module's output.

    Args:
        patient_dim: width of the input patient embeddings.
        code_dim: width of the input code embeddings.
        patient_code_adj: patient-by-code adjacency matrix (its transpose is
            used for the code side).
        code_code_adj: square code-by-code adjacency matrix.
        patient_hidden_dims: patient widths per layer (only index 0 is read).
        code_hidden_dims: code widths per layer (indices 0 and 1 are read).
    """

    def __init__(self, patient_dim, code_dim, patient_code_adj, code_code_adj, patient_hidden_dims, code_hidden_dims):
        super().__init__()
        # Graph structure
        self.patient_code_adj = patient_code_adj
        self.code_patient_adj = patient_code_adj.t()
        self.code_code_adj = code_code_adj
        self.code_num = code_code_adj.shape[0]
        self.sigma = torch.nn.Sigmoid()
        # Learnable ontology-weight parameters, one value per code
        self.miu = torch.nn.Parameter(torch.randn(self.code_num), requires_grad=True)
        self.theta = torch.nn.Parameter(torch.randn(self.code_num), requires_grad=True)
        # Graph layer 1: projections between the two embedding spaces + conv blocks
        patient_hidden, code_hidden_1, code_hidden_2 = (
            patient_hidden_dims[0], code_hidden_dims[0], code_hidden_dims[1])
        self.c2p_dense_1 = torch.nn.Linear(code_dim, patient_dim)
        self.p2c_dense_1 = torch.nn.Linear(patient_dim, code_dim)
        self.patient_block_1 = GraphConvBlock('patient', patient_dim, patient_hidden, self.patient_code_adj)
        self.code_block_1 = GraphConvBlock('code', code_dim, code_hidden_1, self.code_patient_adj)
        # Graph layer 2: only the code side is updated
        self.p2c_dense_2 = torch.nn.Linear(patient_hidden, code_hidden_1)
        self.code_block_2 = GraphConvBlock('code', code_hidden_1, code_hidden_2, self.code_patient_adj)

    def forward(self, patient_embeddings, code_embeddings):
        """Return the code embeddings after both graph layers."""
        # miu and theta broadcast against the last dimension of code_code_adj.
        ontology_weights = torch.sigmoid(self.miu * self.code_code_adj + self.theta)
        # Layer 1: both updates read the *input* embeddings of the other node type.
        codes_in_patient_space = self.c2p_dense_1(code_embeddings)
        patients_in_code_space = self.p2c_dense_1(patient_embeddings)
        next_patients = self.patient_block_1(patient_embeddings, codes_in_patient_space)
        next_codes = self.code_block_1(code_embeddings, patients_in_code_space, ontology_weights)
        # Layer 2: refine the codes with the layer-1 patient embeddings.
        patients_in_code_space = self.p2c_dense_2(next_patients)
        return self.code_block_2(next_codes, patients_in_code_space, ontology_weights)

In [10]:
# Temporal Learning for Visits
class VisitEmbedding(torch.nn.Module):
    """Turn each visit into the mean embedding of the codes recorded in it."""

    def __init__(self, max_seq_len):
        super().__init__()
        self.max_seq_len = max_seq_len

    def forward(self, code_embeddings, visit_codes, visit_lens):
        """Average code embeddings per visit and zero out padded visits.

        Args:
            code_embeddings: matrix indexed by 0-based code id.
            visit_codes: (batch_size, max_seq_len, max_code_num_in_a_visit)
                1-based code ids, with 0 marking padding.
            visit_lens: number of valid visits per patient.

        Returns:
            (batch_size, max_seq_len, code_dim) visit embeddings.
        """
        shifted_codes = visit_codes - 1
        is_padding = shifted_codes == -1
        # Padding slots look up row 0; their contribution is masked out below.
        lookup_ids = shifted_codes.masked_fill(is_padding, 0)
        code_mask = (~is_padding).float()
        codes_per_visit = code_mask.sum(dim=-1).unsqueeze(-1)
        # (batch_size, max_seq_len, max_code_num_in_a_visit, code_dim)
        gathered = code_embeddings[lookup_ids] * code_mask.unsqueeze(-1)
        assert not torch.isnan(gathered).any()
        # Empty visits would divide by zero; their sum is 0 so any divisor works.
        safe_counts = codes_per_visit.masked_fill(codes_per_visit == 0, 1)
        visits_embeddings = gathered.sum(dim=-2) / safe_counts  # (batch_size, max_seq_len, code_dim)
        assert not torch.isnan(visits_embeddings).any()
        visit_mask = sequence_mask(visit_lens, self.max_seq_len)  # (batch_size, max_seq_len)
        return visits_embeddings * visit_mask.unsqueeze(-1)

In [11]:
class Attention(torch.nn.Module):
    """Masked attention pooling over the sequence dimension.

    Args:
        input_dim: width of the input vectors.
        attention_dim: width of the intermediate projection.
    """

    def __init__(self, input_dim, attention_dim):
        super().__init__()
        self.attention_dim = attention_dim
        self.u_omega = torch.nn.Parameter(torch.randn(attention_dim), requires_grad=True)
        self.w_omega = torch.nn.Parameter(torch.randn(input_dim, attention_dim), requires_grad=True)

    def forward(self, x, mask):
        """Pool ``x`` (batch_size, max_seq_len, hidden) into (batch_size, hidden).

        Returns the pooled vectors and the attention weights
        (batch_size, max_seq_len).
        """
        projected = torch.matmul(x, self.w_omega)
        scores = torch.matmul(projected, self.u_omega).view(x.shape[:-1])  # (batch_size, max_seq_len)
        scores = scores * mask
        alphas = masked_softmax(scores, mask)
        pooled = (x * alphas.unsqueeze(-1)).sum(dim=-2)
        return pooled, alphas

In [12]:
class TemporalEmbedding(torch.nn.Module):
    """GRU over a patient's visit sequence followed by masked attention pooling.

    Args:
        input_dim: width of each visit embedding.
        rnn_dims: list of RNN widths; only the last entry is used (single GRU layer).
        attention_dim: width of the attention projection.
        max_seq_len: padded number of visits per patient.
    """

    def __init__(self, input_dim, rnn_dims, attention_dim, max_seq_len):
        super(TemporalEmbedding, self).__init__()
        # Inputs arrive as (batch_size, max_seq_len, input_dim). Without
        # batch_first=True the GRU treats dim 0 as time, so the recurrence would
        # run across the patients of a batch instead of across each patient's visits.
        self.rnn_layers = torch.nn.GRU(input_dim, rnn_dims[-1], batch_first=True)
        self.attention = Attention(rnn_dims[-1], attention_dim)
        self.max_seq_len = max_seq_len

    def forward(self, embeddings, lens):
        """Encode visit embeddings into one vector per patient.

        Args:
            embeddings: (batch_size, max_seq_len, input_dim) visit embeddings.
            lens: number of valid visits per patient.

        Returns:
            Tuple of pooled outputs (batch_size, rnn_dims[-1]) and attention
            weights (batch_size, max_seq_len).
        """
        seq_mask = sequence_mask(lens, self.max_seq_len)
        outputs, _ = self.rnn_layers(embeddings)
        # Zero the hidden states produced at padded time steps.
        outputs = outputs * seq_mask.unsqueeze(-1)  # (batch_size, max_seq_len, rnn_dim[-1])
        outputs, alphas = self.attention(outputs, seq_mask)  # (batch_size, rnn_dim[-1])
        return outputs, alphas

In [13]:
class CGL(torch.nn.Module):
    """CGL model for diagnosis prediction (text component omitted).

    Pipeline: hierarchical code embeddings -> collaborative graph learning ->
    per-visit mean embedding -> GRU + attention -> linear layer + softmax.
    Tensors built here are placed on the module-level ``device``.

    Args:
        code_map: mapping whose length is the number of output codes.
        code_levels: NumPy array indexed as ``[code, level]`` with 1-based level ids.
        patient_code_adj: NumPy patient-by-code adjacency matrix.
        code_code_adj: NumPy code-by-code adjacency matrix (binarised with ``> 0``).
        num_train_sample: number of patient embeddings to allocate.
        max_admission_num: padded number of visits per patient.
    """

    def __init__(self, code_map, code_levels, patient_code_adj, code_code_adj, num_train_sample, max_admission_num):
        super(CGL, self).__init__()
        # Hierarchical Embedding for Medical Codes
        code_num_in_levels = (np.max(code_levels, axis=0)).tolist()
        code_levels = torch.from_numpy(code_levels).to(device)
        code_dims = [32] * code_levels.shape[1]
        self.hier_embed_layer = HierarchicalEmbedding(code_levels, code_num_in_levels, code_dims)
        # Initialize Patient Embeddings
        self.user_emb = torch.nn.Embedding(num_train_sample, 16)
        # Collaborative Graph Learning
        patient_code_adj = torch.from_numpy(patient_code_adj).to(device).to(torch.float32)
        code_code_adj = torch.from_numpy(code_code_adj).to(device)
        code_code_adj = (code_code_adj > 0).float()
        self.graph_convolution_layer = GraphLearning(
            patient_dim=16,
            code_dim=sum(code_dims),
            patient_code_adj=patient_code_adj,
            code_code_adj=code_code_adj,
            patient_hidden_dims=[32],
            code_hidden_dims=[64, 128])
        # Temporal Learning for Visits
        self.visit_embedding_layer = VisitEmbedding(max_admission_num)
        self.visit_temporal_embedding_layer = TemporalEmbedding(128, [200], 32, max_admission_num)
        # Output FC Layer
        self.output_layer = torch.nn.Linear(200, len(code_map))
        # Explicit class axis: the implicit-dim form of nn.Softmax is deprecated.
        # For the 2-D (batch, num_codes) logits used here this is the same axis
        # the implicit form selected.
        # NOTE(review): softmax outputs are trained with BCELoss against
        # multi-label targets in the training cell; a per-code sigmoid may have
        # been intended -- confirm before changing, since it alters the model.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, visit_codes, visit_lens):
        """Predict code probabilities for a batch of patients.

        Args:
            visit_codes: (batch_size, max_seq_len, max_code_num_in_a_visit)
                1-based code ids with 0 as padding.
            visit_lens: number of valid visits per patient.

        Returns:
            Tuple ``(output, code_embeddings)``: the (batch_size, len(code_map))
            softmax output and the code embeddings produced by the graph layers.
        """
        # Get Hierarchical Embedding for Medical Codes
        code_embeddings = self.hier_embed_layer(None)
        # Get Patient Embeddings (all rows, independent of the current batch)
        patient_embeddings = self.user_emb.weight
        # Collaborative Graph Learning
        assert not torch.isnan(code_embeddings).any()
        code_embeddings = self.graph_convolution_layer(patient_embeddings, code_embeddings)
        # Temporal Learning for Visits
        visits_embeddings = self.visit_embedding_layer(code_embeddings, visit_codes, visit_lens)
        assert not torch.isnan(visits_embeddings).any()
        visit_output, alpha_visit = self.visit_temporal_embedding_layer(visits_embeddings, visit_lens)
        assert not torch.isnan(visit_output).any()
        # Output
        output = self.softmax(self.output_layer(visit_output))
        return output, code_embeddings

## **Training Loop**

In [14]:
# Build the model (one patient embedding per training sample; sequence length
# taken from the second dimension of train_x) and move it to the device.
model = CGL(
    code_map,
    code_levels,
    patient_code_adj,
    code_code_adj,
    num_train_sample=len(train_x),
    max_admission_num=train_x.size(1),
).to(device)

In [15]:
# This is for diagnosis prediction
def evaluate_model(pred, label, k1, k2, k3, k4, k5, k6):
    """Compute precision@k, recall@k and NDCG@k for six cut-offs.

    Fixes a bug in the previous version, which allocated the top-k indicator
    vector once and kept setting entries to 1 across samples, so every sample's
    prediction also contained the top-k codes of all earlier samples.

    Args:
        pred: per-sample score tensors (a 2-D tensor or a sequence of 1-D tensors).
        label: per-sample multi-hot ground-truth tensors, aligned with ``pred``.
        k1..k6: the six cut-off values.

    Returns:
        Tuple of 18 floats ordered as
        ``(p@k1, r@k1, ndcg@k1, p@k2, r@k2, ndcg@k2, ..., p@k6, r@k6, ndcg@k6)``.
    """
    sample_scores = [pred[i].cpu().detach() for i in range(len(pred))]
    true_rows = [label[i].cpu().detach().tolist() for i in range(len(pred))]
    score_rows = [scores.tolist() for scores in sample_scores]

    metrics = []
    for k in (k1, k2, k3, k4, k5, k6):
        top_k_rows = []
        for scores in sample_scores:
            # Fresh indicator per sample so predictions never leak between samples.
            indicator = torch.zeros_like(scores)
            indicator[torch.topk(scores, k).indices] = 1
            top_k_rows.append(indicator.tolist())
        metrics.append(precision_score(true_rows, top_k_rows, average='samples'))
        metrics.append(recall_score(true_rows, top_k_rows, average='samples'))
        metrics.append(ndcg_score(true_rows, score_rows, k=k))
    return tuple(metrics)

In [16]:
# Per-epoch evaluation history: one independent list per metric and cut-off.
(metric_p1_list, metric_p2_list, metric_p3_list,
 metric_p4_list, metric_p5_list, metric_p6_list) = ([] for _ in range(6))  # precision@k
(metric_r1_list, metric_r2_list, metric_r3_list,
 metric_r4_list, metric_r5_list, metric_r6_list) = ([] for _ in range(6))  # recall@k
(metric_n1_list, metric_n2_list, metric_n3_list,
 metric_n4_list, metric_n5_list, metric_n6_list) = ([] for _ in range(6))  # ndcg@k

In [None]:
# Train the diagnosis model and evaluate on the test split after every epoch.
# Training mode
model.train()
# Loss and optimizer (learning rate)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Initialize data loader for training
training_data = MyData(train_x, train_y, train_lens)
train_loader = DataLoader(training_data, batch_size=32, shuffle=True)
total_step = len(train_loader)
# Initialize data loader for testing once (it does not change between epochs).
# The whole test split is one batch, so shuffling is unnecessary.
test_data = MyData(test_x, test_y, test_lens)
test_loader = DataLoader(test_data, batch_size=len(test_x), shuffle=False)
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    for i, (patients, labels, seq_lengths) in enumerate(train_loader):
        patients = patients.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs, learned_code = model(patients, seq_lengths)
        loss = criterion(outputs, labels.to(torch.float32))
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Tracking
        if (i+1) % 100 == 0:
            print('Epoch: [{}/{}], Step: [{}/{}], Loss: {}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
    if (epoch+1) % 1 == 0:  # evaluate every epoch
        model.eval()
        # No gradients are needed for evaluation; this avoids building an
        # autograd graph over the entire test split.
        with torch.no_grad():
            for (patients, labels, seq_lengths) in test_loader:
                # NOTE(review): learned_code is the code-embedding matrix returned
                # by the last training batch's forward pass (computed in train
                # mode, before the final optimizer step). It is reused here rather
                # than recomputed -- confirm this is intended.
                visits_embeddings = model.visit_embedding_layer(learned_code, patients, seq_lengths)
                assert not torch.isnan(visits_embeddings).any()
                visit_output, alpha_visit = model.visit_temporal_embedding_layer(visits_embeddings, seq_lengths)
                assert not torch.isnan(visit_output).any()
                pred = model.softmax(model.output_layer(visit_output))
                # Subject to Change! @k for evaluation
                (metric_p1, metric_r1, metric_n1,
                 metric_p2, metric_r2, metric_n2,
                 metric_p3, metric_r3, metric_n3,
                 metric_p4, metric_r4, metric_n4,
                 metric_p5, metric_r5, metric_n5,
                 metric_p6, metric_r6, metric_n6) = evaluate_model(pred, labels, 5, 10, 15, 20, 25, 30)
                # Precision@k history
                metric_p1_list.append(metric_p1)
                metric_p2_list.append(metric_p2)
                metric_p3_list.append(metric_p3)
                metric_p4_list.append(metric_p4)
                metric_p5_list.append(metric_p5)
                metric_p6_list.append(metric_p6)
                # Recall@k history
                metric_r1_list.append(metric_r1)
                metric_r2_list.append(metric_r2)
                metric_r3_list.append(metric_r3)
                metric_r4_list.append(metric_r4)
                metric_r5_list.append(metric_r5)
                metric_r6_list.append(metric_r6)
                # NDCG@k history
                metric_n1_list.append(metric_n1)
                metric_n2_list.append(metric_n2)
                metric_n3_list.append(metric_n3)
                metric_n4_list.append(metric_n4)
                metric_n5_list.append(metric_n5)
                metric_n6_list.append(metric_n6)
        model.train()

# **Mortality Prediction -- Binary Prediction Task**

## **Package Setup**

In [1]:
import os
import pickle as pickle
import numpy as np
from datetime import datetime
import pandas as pd
import scipy.sparse as sps
import torch
from copy import deepcopy
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.init as init
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from collections import OrderedDict
import torch.utils.data as data
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import ndcg_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
import random
import warnings
warnings.filterwarnings("ignore")

## **Load Data**

In [2]:
# Load the train/test arrays stored in the non-binary data format used for CGL.
train_codes_x = np.load('../MIMIC3_data/Nonbinary_Data_Format/train_codes_x.npy')
train_mort = np.load('../MIMIC3_data/Nonbinary_Data_Format/train_mort.npy')
train_visit_lens = np.load('../MIMIC3_data/Nonbinary_Data_Format/train_visit_lens.npy')
test_codes_x = np.load('../MIMIC3_data/Nonbinary_Data_Format/test_codes_x.npy')
test_mort = np.load('../MIMIC3_data/Nonbinary_Data_Format/test_mort.npy')
test_visit_lens = np.load('../MIMIC3_data/Nonbinary_Data_Format/test_visit_lens.npy')

# Code-related side information: per-code level ids and the two adjacency matrices.
code_code_adj = np.load('../MIMIC3_data/code_related/code_code_adj.npy')
patient_code_adj = np.load('../MIMIC3_data/code_related/patient_code_adj.npy')
code_levels = np.load('../MIMIC3_data/code_related/code_levels.npy')
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('../MIMIC3_data/code_related/code_map.pkl', mode='rb') as f13:
    code_map = pickle.load(f13)

In [3]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Wrap each NumPy array as a tensor and place it on the selected device.
train_x, train_y, train_lens = (
    torch.from_numpy(array).to(device)
    for array in (train_codes_x, train_mort, train_visit_lens)
)
test_x, test_y, test_lens = (
    torch.from_numpy(array).to(device)
    for array in (test_codes_x, test_mort, test_visit_lens)
)

In [5]:
class MyData(data.Dataset):
    """Dataset that serves (sequence, label, length) triples by index.

    The three containers are indexed in lockstep; the dataset length is the
    length of ``data_seq``.
    """

    def __init__(self, data_seq, data_label, data_len):
        self.data_seq, self.data_label, self.data_len = data_seq, data_label, data_len

    def __len__(self):
        return len(self.data_seq)

    def __getitem__(self, idx):
        # Same position in all three containers, returned as a tuple.
        fields = (self.data_seq, self.data_label, self.data_len)
        return tuple(field[idx] for field in fields)

## **Model Starts**

In [6]:
def sequence_mask(lengths, maxlen):
    """Build a boolean mask marking the valid (non-padded) steps of each sequence.

    Args:
        lengths: 1-D tensor (or array-like) holding the valid length of each sequence.
        maxlen: number of columns in the mask (padded sequence length).

    Returns:
        Bool tensor of shape ``(len(lengths), maxlen)`` where entry ``[b, t]`` is
        True iff ``t < lengths[b]``. The mask is created on the same device as
        ``lengths`` instead of relying on a module-level ``device`` global.
    """
    lengths = torch.as_tensor(lengths)
    # 1-based step numbers 1..maxlen; a step is valid when its number <= length.
    positions = torch.arange(1, maxlen + 1, device=lengths.device)
    return positions.unsqueeze(0) <= lengths.unsqueeze(1)

def masked_softmax(inputs, mask):
    """Softmax over the last dimension restricted to positions where ``mask`` is non-zero.

    The stabilising maximum is taken over the unmasked positions only. Taking it
    over every position (as before) lets a masked entry dominate the shift, so
    strongly negative valid scores underflow in ``exp`` and every weight in the
    row collapses to ~0 instead of summing to 1.

    Args:
        inputs: float tensor of scores.
        mask: tensor broadcastable to ``inputs``; zero/False marks ignored positions.

    Returns:
        Tensor shaped like ``inputs``. Masked positions are 0; a row with no
        valid position is all zeros.
    """
    valid = mask != 0
    # Max over valid entries only; masked entries cannot win with -inf.
    row_max = inputs.masked_fill(~valid, float("-inf")).max(dim=-1, keepdim=True)[0]
    # Reset masked entries to 0 so exp() stays finite even for fully masked rows
    # (where row_max is -inf); they are multiplied by the mask below anyway.
    shifted = (inputs - row_max).masked_fill(~valid, 0.0)
    exp = torch.exp(shifted) * mask
    result = exp / (torch.sum(exp, dim=-1, keepdim=True) + 1e-12)
    return result

In [7]:
# 3.2 The Proposed Model - CGL
# Hierarchical Embedding for Medical Codes
class HierarchicalEmbedding(nn.Module):
    """Embed every code by concatenating one learned vector per hierarchy level.

    Args:
        code_levels: integer tensor indexed as ``[code, level]`` holding 1-based
            ids of each code at each level.
        code_num_in_levels: number of embedding rows to allocate per level.
        code_dims: embedding width per level.
    """

    def __init__(self, code_levels, code_num_in_levels, code_dims):
        super().__init__()
        self.level_num = len(code_num_in_levels)
        self.code_levels = code_levels
        tables = []
        for vocab_size, width in zip(code_num_in_levels, code_dims):
            tables.append(nn.Embedding(vocab_size, width))
        self.level_embeddings = nn.ModuleList(tables)

    def forward(self, input=None):
        """Return the embedding matrix for all codes; ``input`` is ignored.

        Output shape is ``(code_num, sum(code_dims))``.
        """
        per_level = []
        for level in range(self.level_num):
            # Stored ids are 1-based; embedding tables are 0-based.
            zero_based_ids = self.code_levels[:, level] - 1
            per_level.append(self.level_embeddings[level](zero_based_ids))
        return torch.cat(per_level, dim=1)

In [8]:
# Graph Representation + Collaborative Graph Learning
class GraphConvBlock(torch.nn.Module):
    """One graph-convolution step: aggregate, then Linear -> BatchNorm1d -> ReLU.

    Args:
        node_type: ``'code'`` enables the extra weighted self-aggregation term;
            any other value skips it.
        input_dim: width of the incoming node embeddings.
        output_dim: width of the produced node embeddings.
        adj: adjacency matrix multiplied with the neighbour embeddings.
    """

    def __init__(self, node_type, input_dim, output_dim, adj):
        super().__init__()
        self.node_type = node_type
        self.adj = adj
        self.dense = torch.nn.Linear(input_dim, output_dim)
        self.bn = torch.nn.BatchNorm1d(output_dim)
        self.activation = torch.nn.ReLU()

    def forward(self, embedding, embedding_neighbor, weight_decay=None):
        """Combine own and neighbour embeddings and apply the dense block.

        ``weight_decay`` is required when ``node_type == 'code'``; it is a
        matrix applied to ``embedding`` and added to the aggregate.
        """
        aggregated = embedding + torch.matmul(self.adj, embedding_neighbor)
        if self.node_type == 'code':
            assert weight_decay is not None
            aggregated = aggregated + torch.matmul(weight_decay, embedding)
        return self.activation(self.bn(self.dense(aggregated)))

In [9]:
class GraphLearning(torch.nn.Module):
    """Two-layer collaborative graph learning over patient and code nodes.

    Layer 1 updates both patient and code embeddings; layer 2 updates only the
    code embeddings, which are the module's output.

    Args:
        patient_dim: width of the input patient embeddings.
        code_dim: width of the input code embeddings.
        patient_code_adj: patient-by-code adjacency matrix (its transpose is
            used for the code side).
        code_code_adj: square code-by-code adjacency matrix.
        patient_hidden_dims: patient widths per layer (only index 0 is read).
        code_hidden_dims: code widths per layer (indices 0 and 1 are read).
    """

    def __init__(self, patient_dim, code_dim, patient_code_adj, code_code_adj, patient_hidden_dims, code_hidden_dims):
        super().__init__()
        # Graph structure
        self.patient_code_adj = patient_code_adj
        self.code_patient_adj = patient_code_adj.t()
        self.code_code_adj = code_code_adj
        self.code_num = code_code_adj.shape[0]
        self.sigma = torch.nn.Sigmoid()
        # Learnable ontology-weight parameters, one value per code
        self.miu = torch.nn.Parameter(torch.randn(self.code_num), requires_grad=True)
        self.theta = torch.nn.Parameter(torch.randn(self.code_num), requires_grad=True)
        # Graph layer 1: projections between the two embedding spaces + conv blocks
        patient_hidden, code_hidden_1, code_hidden_2 = (
            patient_hidden_dims[0], code_hidden_dims[0], code_hidden_dims[1])
        self.c2p_dense_1 = torch.nn.Linear(code_dim, patient_dim)
        self.p2c_dense_1 = torch.nn.Linear(patient_dim, code_dim)
        self.patient_block_1 = GraphConvBlock('patient', patient_dim, patient_hidden, self.patient_code_adj)
        self.code_block_1 = GraphConvBlock('code', code_dim, code_hidden_1, self.code_patient_adj)
        # Graph layer 2: only the code side is updated
        self.p2c_dense_2 = torch.nn.Linear(patient_hidden, code_hidden_1)
        self.code_block_2 = GraphConvBlock('code', code_hidden_1, code_hidden_2, self.code_patient_adj)

    def forward(self, patient_embeddings, code_embeddings):
        """Return the code embeddings after both graph layers."""
        # miu and theta broadcast against the last dimension of code_code_adj.
        ontology_weights = torch.sigmoid(self.miu * self.code_code_adj + self.theta)
        # Layer 1: both updates read the *input* embeddings of the other node type.
        codes_in_patient_space = self.c2p_dense_1(code_embeddings)
        patients_in_code_space = self.p2c_dense_1(patient_embeddings)
        next_patients = self.patient_block_1(patient_embeddings, codes_in_patient_space)
        next_codes = self.code_block_1(code_embeddings, patients_in_code_space, ontology_weights)
        # Layer 2: refine the codes with the layer-1 patient embeddings.
        patients_in_code_space = self.p2c_dense_2(next_patients)
        return self.code_block_2(next_codes, patients_in_code_space, ontology_weights)

In [10]:
# Temporal Learning for Visits
class VisitEmbedding(torch.nn.Module):
    """Turn each visit into the mean embedding of the codes recorded in it."""

    def __init__(self, max_seq_len):
        super().__init__()
        self.max_seq_len = max_seq_len

    def forward(self, code_embeddings, visit_codes, visit_lens):
        """Average code embeddings per visit and zero out padded visits.

        Args:
            code_embeddings: matrix indexed by 0-based code id.
            visit_codes: (batch_size, max_seq_len, max_code_num_in_a_visit)
                1-based code ids, with 0 marking padding.
            visit_lens: number of valid visits per patient.

        Returns:
            (batch_size, max_seq_len, code_dim) visit embeddings.
        """
        shifted_codes = visit_codes - 1
        is_padding = shifted_codes == -1
        # Padding slots look up row 0; their contribution is masked out below.
        lookup_ids = shifted_codes.masked_fill(is_padding, 0)
        code_mask = (~is_padding).float()
        codes_per_visit = code_mask.sum(dim=-1).unsqueeze(-1)
        # (batch_size, max_seq_len, max_code_num_in_a_visit, code_dim)
        gathered = code_embeddings[lookup_ids] * code_mask.unsqueeze(-1)
        assert not torch.isnan(gathered).any()
        # Empty visits would divide by zero; their sum is 0 so any divisor works.
        safe_counts = codes_per_visit.masked_fill(codes_per_visit == 0, 1)
        visits_embeddings = gathered.sum(dim=-2) / safe_counts  # (batch_size, max_seq_len, code_dim)
        assert not torch.isnan(visits_embeddings).any()
        visit_mask = sequence_mask(visit_lens, self.max_seq_len)  # (batch_size, max_seq_len)
        return visits_embeddings * visit_mask.unsqueeze(-1)

In [11]:
class Attention(torch.nn.Module):
    """Masked attention pooling over the sequence dimension.

    Args:
        input_dim: width of the input vectors.
        attention_dim: width of the intermediate projection.
    """

    def __init__(self, input_dim, attention_dim):
        super().__init__()
        self.attention_dim = attention_dim
        self.u_omega = torch.nn.Parameter(torch.randn(attention_dim), requires_grad=True)
        self.w_omega = torch.nn.Parameter(torch.randn(input_dim, attention_dim), requires_grad=True)

    def forward(self, x, mask):
        """Pool ``x`` (batch_size, max_seq_len, hidden) into (batch_size, hidden).

        Returns the pooled vectors and the attention weights
        (batch_size, max_seq_len).
        """
        projected = torch.matmul(x, self.w_omega)
        scores = torch.matmul(projected, self.u_omega).view(x.shape[:-1])  # (batch_size, max_seq_len)
        scores = scores * mask
        alphas = masked_softmax(scores, mask)
        pooled = (x * alphas.unsqueeze(-1)).sum(dim=-2)
        return pooled, alphas

In [12]:
class TemporalEmbedding(torch.nn.Module):
    """GRU over a patient's visit sequence followed by masked attention pooling.

    Args:
        input_dim: width of each visit embedding.
        rnn_dims: list of RNN widths; only the last entry is used (single GRU layer).
        attention_dim: width of the attention projection.
        max_seq_len: padded number of visits per patient.
    """

    def __init__(self, input_dim, rnn_dims, attention_dim, max_seq_len):
        super(TemporalEmbedding, self).__init__()
        # Inputs arrive as (batch_size, max_seq_len, input_dim). Without
        # batch_first=True the GRU treats dim 0 as time, so the recurrence would
        # run across the patients of a batch instead of across each patient's visits.
        self.rnn_layers = torch.nn.GRU(input_dim, rnn_dims[-1], batch_first=True)
        self.attention = Attention(rnn_dims[-1], attention_dim)
        self.max_seq_len = max_seq_len

    def forward(self, embeddings, lens):
        """Encode visit embeddings into one vector per patient.

        Args:
            embeddings: (batch_size, max_seq_len, input_dim) visit embeddings.
            lens: number of valid visits per patient.

        Returns:
            Tuple of pooled outputs (batch_size, rnn_dims[-1]) and attention
            weights (batch_size, max_seq_len).
        """
        seq_mask = sequence_mask(lens, self.max_seq_len)
        outputs, _ = self.rnn_layers(embeddings)
        # Zero the hidden states produced at padded time steps.
        outputs = outputs * seq_mask.unsqueeze(-1)  # (batch_size, max_seq_len, rnn_dim[-1])
        outputs, alphas = self.attention(outputs, seq_mask)  # (batch_size, rnn_dim[-1])
        return outputs, alphas

In [13]:
class CGL(torch.nn.Module):
    """CGL model for mortality prediction (text component omitted).

    Pipeline: hierarchical code embeddings -> collaborative graph learning ->
    per-visit mean embedding -> GRU + attention -> linear layer + sigmoid.
    Tensors built here are placed on the module-level ``device``.

    Args:
        code_map: accepted for a uniform constructor signature; not read here.
        code_levels: NumPy array indexed as ``[code, level]`` with 1-based level ids.
        patient_code_adj: NumPy patient-by-code adjacency matrix.
        code_code_adj: NumPy code-by-code adjacency matrix (binarised with ``> 0``).
        num_train_sample: number of patient embeddings to allocate.
        max_admission_num: padded number of visits per patient.
    """

    def __init__(self, code_map, code_levels, patient_code_adj, code_code_adj, num_train_sample, max_admission_num):
        super().__init__()
        # --- hierarchical code embeddings ---
        level_sizes = np.max(code_levels, axis=0).tolist()
        level_ids = torch.from_numpy(code_levels).to(device)
        code_dims = [32 for _ in range(level_ids.shape[1])]
        self.hier_embed_layer = HierarchicalEmbedding(level_ids, level_sizes, code_dims)
        # --- patient embeddings, one row per training sample ---
        self.user_emb = torch.nn.Embedding(num_train_sample, 16)
        # --- collaborative graph learning ---
        patient_code = torch.from_numpy(patient_code_adj).to(device).to(torch.float32)
        code_code = (torch.from_numpy(code_code_adj).to(device) > 0).float()
        self.graph_convolution_layer = GraphLearning(
            patient_dim=16,
            code_dim=sum(code_dims),
            patient_code_adj=patient_code,
            code_code_adj=code_code,
            patient_hidden_dims=[32],
            code_hidden_dims=[64, 128],
        )
        # --- temporal learning over visits ---
        self.visit_embedding_layer = VisitEmbedding(max_admission_num)
        self.visit_temporal_embedding_layer = TemporalEmbedding(128, [200], 32, max_admission_num)
        # --- binary output head ---
        self.output_layer = torch.nn.Linear(200, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, visit_codes, visit_lens):
        """Predict a mortality probability for each patient in the batch.

        Returns:
            Tuple ``(output, code_embeddings)``: the (batch_size, 1) sigmoid
            output and the code embeddings produced by the graph layers.
        """
        initial_codes = self.hier_embed_layer(None)
        assert not torch.isnan(initial_codes).any()
        # All patient rows take part in graph learning, independent of the batch.
        code_embeddings = self.graph_convolution_layer(self.user_emb.weight, initial_codes)
        visits_embeddings = self.visit_embedding_layer(code_embeddings, visit_codes, visit_lens)
        assert not torch.isnan(visits_embeddings).any()
        visit_output, alpha_visit = self.visit_temporal_embedding_layer(visits_embeddings, visit_lens)
        assert not torch.isnan(visit_output).any()
        logits = self.output_layer(visit_output)
        return self.sigmoid(logits), code_embeddings

## **Training Loop**

In [14]:
# Build the model (one patient embedding per training sample; sequence length
# taken from the second dimension of train_x) and move it to the device.
model = CGL(
    code_map,
    code_levels,
    patient_code_adj,
    code_code_adj,
    num_train_sample=len(train_x),
    max_admission_num=train_x.size(1),
).to(device)

In [15]:
# Per-epoch evaluation history (AUC and accuracy), kept as two separate lists.
auc_list, acc_list = [], []

In [None]:
# Train the mortality model and evaluate on the test split after every epoch.
# Training mode
model.train()
# Loss and optimizer (learning rate)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Initialize data loader for training
training_data = MyData(train_x, train_y, train_lens)
train_loader = DataLoader(training_data, batch_size=32, shuffle=True)
total_step = len(train_loader)
# Initialize data loader for testing once (it does not change between epochs).
# The whole test split is one batch, so shuffling is unnecessary.
test_data = MyData(test_x, test_y, test_lens)
test_loader = DataLoader(test_data, batch_size=len(test_x), shuffle=False)
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    for i, (patients, labels, seq_lengths) in enumerate(train_loader):
        patients = patients.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs, learned_code = model(patients, seq_lengths)
        # Flatten (batch, 1) predictions to (batch,) to match the label shape.
        outputs = torch.reshape(outputs, (len(outputs),))
        loss = criterion(outputs, labels.to(torch.float32))
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Tracking
        if (i+1) % 100 == 0:
            print('Epoch: [{}/{}], Step: [{}/{}], Loss: {}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
    if (epoch+1) % 1 == 0:  # evaluate every epoch
        model.eval()
        # No gradients are needed for evaluation; this avoids building an
        # autograd graph over the entire test split.
        with torch.no_grad():
            for (patients, labels, seq_lengths) in test_loader:
                # NOTE(review): learned_code is the code-embedding matrix returned
                # by the last training batch's forward pass (computed in train
                # mode, before the final optimizer step). It is reused here rather
                # than recomputed -- confirm this is intended.
                visits_embeddings = model.visit_embedding_layer(learned_code, patients, seq_lengths)
                assert not torch.isnan(visits_embeddings).any()
                visit_output, alpha_visit = model.visit_temporal_embedding_layer(visits_embeddings, seq_lengths)
                assert not torch.isnan(visit_output).any()
                pred = model.sigmoid(model.output_layer(visit_output))
                pred = torch.reshape(pred, (len(pred),))
                pred_auc = pred.detach().cpu().numpy()
                # Round probabilities to hard 0/1 predictions for accuracy.
                pred_acc = np.round(pred_auc)
                auc_list.append(roc_auc_score(labels.detach().cpu().numpy(), pred_auc))
                acc_list.append(accuracy_score(labels.detach().cpu().numpy(), pred_acc))
        model.train()