In [1]:
import pickle
import random

import numpy as np
import torch
import torch.nn as nn

from src.GRAM.gram_helpers import calculate_dimSize, get_rootCode, build_tree
from src.KAME.kame_helpers import codes2ancestors, leaf2ancestors, load_data, pad_matrix
from src.KAME.kame_module import KAME
from src.KAME.kame_module import KAME as model

In [2]:
# --- Input artefacts (all share the same prefix inside OUTPUT_DIR) ---
OUTPUT_DIR = 'outputs'
tree_file = f'{OUTPUT_DIR}/mimic'
seq_file = f'{tree_file}.seqs'
label_file = f'{tree_file}.3digitICD9.seqs'

# --- Model hyper-parameters ---
embd_dim_size = 100
attn_dim_size = 100
rnn_dim_size = 100
g_dim_size = 100

# --- Reproducibility ---
# Model weight initialisation and GRU dropout are random; seed torch (CPU and CUDA) and numpy
# before anything stochastic runs so that Restart & Run All reproduces the same results.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Device selection ---
# Use the first GPU when one is available, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if is_cuda else "cpu")

In [3]:
# Display the device chosen in the config cell (cuda:0 when a GPU is available, else cpu).
device

device(type='cuda', index=0)

In [4]:
# Vocabulary sizes derived from the preprocessed files:
#   inputDimSize - number of distinct input codes in the visit sequences
#   numClass     - number of distinct label codes
#   numAncestors - codes between inputDimSize and the root code of the level-2 tree, inclusive
#                  (presumably ancestors are numbered after all leaf codes -- TODO confirm)
inputDimSize = calculate_dimSize(seq_file)
numClass = calculate_dimSize(label_file)
numAncestors = get_rootCode(f'{tree_file}.level2.pk') - inputDimSize + 1

In [5]:
# Sanity-check the three vocabulary sizes (input codes, ancestor codes, label classes).
(inputDimSize, numAncestors, numClass)

(4894, 668, 942)

In [6]:
# Collect (leaves, ancestors) from every leaf-tree level file, deepest level (5) first,
# into two flat lists that are later passed to the KAME constructor.
leaves_list = []
ancestors_list = []
for level in range(5, 0, -1):
    leaves, ancestors = build_tree(f'{tree_file}.level{level}.pk')
    leaves_list.extend(leaves)
    ancestors_list.extend(ancestors)

# Visit sequences and their label sequences (3-digit ICD-9 codes, per the file name).
# `with` closes the handles; the original `pickle.load(open(...))` leaked them.
# NOTE: pickle.load can execute arbitrary code -- only load files produced by this pipeline.
with open(seq_file, 'rb') as fh:
    seqs = pickle.load(fh)
with open(label_file, 'rb') as fh:
    labels = pickle.load(fh)

In [7]:
# Same flattening as for the leaf trees, but over the '.a_level' files, levels 4 down to 1.
# Presumably these describe internal (non-leaf) nodes and their ancestors -- TODO confirm.
internal_list, internal_ancestors_list = [], []

for i in reversed(range(1, 5)):
    leaves, ancestors = build_tree(f'{tree_file}.a_level{i}.pk')
    internal_list.extend(leaves)
    internal_ancestors_list.extend(ancestors)

In [10]:
# Build the KAME model from the tree structures and the hyper-parameters in the config cell.
# The class is referenced as `KAME`, not through the `model` alias. Rebinding `model` to the
# instance would make a second run of this cell call the instance instead of the class, so
# the cell could not be re-run.
model = KAME(leaves_list, ancestors_list, internal_list, internal_ancestors_list,
             inputDimSize, numAncestors,
             embd_dim_size, attn_dim_size, rnn_dim_size, g_dim_size, numClass, device)

In [11]:
# Move the parameters to the selected device; the returned module is displayed as the cell output.
model.to(device=device)

KAME(
  (linear): Linear(in_features=100, out_features=100, bias=True)
  (fc): Linear(in_features=200, out_features=942, bias=True)
  (embed_init): Embedding(5562, 100)
  (dag_attention): DAGAttention(
    (linear1): Linear(in_features=200, out_features=100, bias=True)
    (linear2): Linear(in_features=100, out_features=1, bias=True)
  )
  (gru): GRU(100, 100, num_layers=2, batch_first=True, dropout=0.2)
  (embed_a): Embedding(669, 100)
)

In [12]:
# Build the leaf -> ancestors lookup from the tree files; it is passed to load_data below.
# NOTE(review): the structure of the returned object is defined in src.KAME.kame_helpers,
# which is not visible here.
leaf2ans = leaf2ancestors(tree_file)

In [13]:
print('Loading data ... ')
train_set, valid_set, test_set = load_data(seqs, labels, leaf2ans, inputDimSize)

# Split name -> split data. Each split is indexed as [0], [1], [2] (inputs, ancestor indices,
# labels) by the batching cell below -- presumably that triple; TODO confirm in load_data.
data_dict = {'train': train_set, 'val': valid_set, 'test': test_set}
print('done!!')

Loading data ... 
done!!


In [14]:
# Walk the training split in a shuffled mini-batch order, pad each batch, and print its shapes.
# The loop intentionally leaves the LAST padded batch bound to x, f, y, mask, lengths, which
# the following cells inspect. The shuffle uses a dedicated seeded generator, so that batch
# is reproducible across kernel restarts.
batch_size = 100
shuffle_seed = 0
data_set = data_dict['train']
n_samples = len(data_set[0])
n_batches = int(np.ceil(n_samples / batch_size))

batch_order = list(range(n_batches))
random.Random(shuffle_seed).shuffle(batch_order)

# Iterate over data.
for index in batch_order:
    start, stop = index * batch_size, (index + 1) * batch_size
    batchX = data_set[0][start:stop]
    batchF = data_set[1][start:stop]
    batchY = data_set[2][start:stop]
    x, f, y, mask, lengths = pad_matrix(batchX, batchF, batchY, inputDimSize, numAncestors, numClass)
    print(x.shape, f.shape, y.shape)

(100, 1, 4894) (100, 1, 73) (100, 1, 942)
(100, 1, 4894) (100, 1, 72) (100, 1, 942)
(100, 1, 4894) (100, 1, 73) (100, 1, 942)
(100, 3, 4894) (100, 3, 69) (100, 3, 942)
(100, 1, 4894) (100, 1, 80) (100, 1, 942)
(100, 1, 4894) (100, 1, 76) (100, 1, 942)
(100, 1, 4894) (100, 1, 73) (100, 1, 942)
(100, 1, 4894) (100, 1, 73) (100, 1, 942)
(100, 2, 4894) (100, 2, 81) (100, 2, 942)
(100, 1, 4894) (100, 1, 68) (100, 1, 942)
(100, 5, 4894) (100, 5, 80) (100, 5, 942)
(100, 1, 4894) (100, 1, 77) (100, 1, 942)
(100, 3, 4894) (100, 3, 80) (100, 3, 942)
(100, 1, 4894) (100, 1, 80) (100, 1, 942)
(77, 33, 4894) (77, 33, 71) (77, 33, 942)
(100, 1, 4894) (100, 1, 74) (100, 1, 942)
(100, 1, 4894) (100, 1, 65) (100, 1, 942)
(100, 3, 4894) (100, 3, 65) (100, 3, 942)
(100, 1, 4894) (100, 1, 78) (100, 1, 942)
(100, 1, 4894) (100, 1, 77) (100, 1, 942)
(100, 1, 4894) (100, 1, 68) (100, 1, 942)
(100, 1, 4894) (100, 1, 80) (100, 1, 942)
(100, 1, 4894) (100, 1, 69) (100, 1, 942)
(100, 7, 4894) (100, 7, 76) (100, 

In [15]:
# Inspect the padded `f` array of the last batch produced by the loop above.
# The trailing zeros appear to be padding (the mask cell below treats f == 0 as masked).
f

array([[[227, 231, 236, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0]],

       [[236, 240, 241, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0]],

       [[105, 106, 241, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0]],

       ...,

       [[103, 105, 107, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0]],

       [[227, 358, 231, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0]],

       [[548, 549, 553, ...,   0,   0,   0],
        [  0,   0,   0, ...,   0,   0,   0]]])

In [16]:
# Convert the last padded batch from NumPy to tensors on the model's device, then run one
# forward pass. `mask` is passed unchanged, exactly as returned by pad_matrix.
batch_x, batch_f, batch_y = (torch.from_numpy(arr).to(device) for arr in (x, f, y))
output = model(batch_x, batch_f, mask)

In [19]:
# Compare the model output shape with the NumPy and tensor forms of the ancestor-index batch.
(output.shape, f.shape, batch_f.shape)

(torch.Size([100, 2, 942]), (100, 2, 70), torch.Size([100, 2, 70]))

In [None]:
# Inspect the weight matrix of the model's `embed_a` embedding (Embedding(669, 100) per the repr above).
model.embed_a.weight

In [None]:
# Scratch re-computation of the ancestor scores.
# NOTE(review): `out` and `l` are not defined in any visible cell, so this cell raises
# NameError on a fresh kernel. The reshape requires `out` to be rank-3, and `l` must broadcast
# against shape (size[0], size[1], 1, size[2]). Presumably `out` is an RNN output
# (batch, visits, hidden) and `l` holds per-ancestor embeddings
# (batch, visits, n_ancestors, hidden) -- TODO confirm and define both explicitly.
size = out.size()
# Insert a singleton axis at position 2 so `out` can broadcast against `l` along that axis.
out_re = out.reshape(size[0], size[1],1,size[2])
# Elementwise product summed over the last axis: one score per position of `l`'s axis 2
# (presumably unnormalised attention scores; they are softmaxed in a later cell).
hl = (out_re * l).sum(dim=-1)


In [None]:
# True where batch_f holds a non-zero entry. Zeros appear to be padding (see the `f` dump
# above) -- TODO confirm against pad_matrix.
mask_ans = batch_f.gt(0)

VERY_NEGATIVE_NUMBER = -1e30
# Additive float64 mask: 0 at kept positions and VERY_NEGATIVE_NUMBER at padded ones, so a
# subsequent softmax gives padded positions (numerically) zero weight.
mask_rank = (~mask_ans).double() * VERY_NEGATIVE_NUMBER

In [None]:
# Add the padding mask to the scores in place.
# NOTE(review): not idempotent -- re-running this cell adds the mask a second time. It also
# mixes hl's dtype with the float64 mask_rank; presumably hl is float32 -- confirm.
hl += mask_rank

In [None]:
# Normalise the masked scores over the last axis, then insert a singleton axis at position 3
# so the weights can broadcast against `l` in the weighted sum below.
weights = hl.softmax(dim=-1)
x2 = torch.unsqueeze(weights, 3)

In [None]:
# Check the shape of the broadcastable attention weights.
x2.shape

In [None]:
# Attention-weighted sum of `l` over axis 2 (presumably the ancestor axis -- TODO confirm).
k = torch.sum(x2 * l, dim=2)

In [None]:
# Concatenate `out` with the attention-weighted context `k` along the feature (last) axis.
s = torch.cat((out, k), dim=-1)

In [None]:
# Check the shape of the concatenated features.
s.shape