In [4]:
from pyhealth.datasets import MIMIC3Dataset


In [5]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch_geometric.utils as utils

# Single seed shared by every random number generator used in this notebook.
seed = 24

# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so setting it here
# is visible to processes launched from this kernel, not to the kernel itself.
os.environ["PYTHONHASHSEED"] = str(seed)

# Seed PyTorch, NumPy and the stdlib generator with the same value.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [6]:

from edge_index import *

In [7]:
from pyhealth.medcode import InnerMap
# Load the two code vocabularies used below: "ICD9CM" is bound to ICD9 and
# "ATC" is bound to ATC (same order as the names on the left-hand side).
ICD9, ATC = (InnerMap.load(vocabulary) for vocabulary in ("ICD9CM", "ATC"))

In [8]:
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import drug_recommendation_mimic3_fn

# Build the dataset from the synthetic MIMIC-III files hosted in the PyHealth
# bucket, reading the diagnosis, procedure and prescription tables and mapping
# NDC drug codes to ATC with target level 3.
mimic3_ds = MIMIC3Dataset(
    root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    code_mapping={
        "NDC": ("ATC", {"target_kwargs": {"level": 3}}),
    },
)

# Print the base-dataset statistics, then derive the drug-recommendation samples.
mimic3_ds.stat()
dataset = mimic3_ds.set_task(task_fn=drug_recommendation_mimic3_fn)

Generating samples for drug_recommendation_mimic3_fn:  11%|██▎                 | 5711/49993 [00:00<00:00, 57092.90it/s]


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 49993
	- Number of visits: 52769
	- Number of visits per patient: 1.0555
	- Number of events per visit in DIAGNOSES_ICD: 9.1038
	- Number of events per visit in PROCEDURES_ICD: 3.2186
	- Number of events per visit in PRESCRIPTIONS: 32.9969



Generating samples for drug_recommendation_mimic3_fn: 100%|███████████████████| 49993/49993 [00:00<00:00, 52836.05it/s]


In [137]:
# Largest number of entries in any sample's "drugs" list (0 if there are no samples).
m = max((len(sample["drugs"]) for sample in dataset.samples), default=0)

m

96

In [192]:
# get_edge_index presumably comes from the star import of `edge_index` — confirm.
# It yields the code->index map plus the child and parent edge lists for ATC.
(ATC_map,
 ATC_child_edge_index,
 ATC_parent_edge_index) = get_edge_index(ATC)

In [193]:
# get_edge_index presumably comes from the star import of `edge_index` — confirm.
# It yields the code->index map plus the child and parent edge lists for ICD9.
(ICD9_map,
 ICD9_child_edge_index,
 ICD9_parent_edge_index) = get_edge_index(ICD9)

In [217]:
# Spot check: standardize() rewrites the dotless code "1363" into the dotted
# form shown in the recorded output below.
ICD9.standardize("1363")

'136.3'

In [195]:
# Convert each raw edge list to a tensor and add self loops to it.
# Each name is rebound to the *return value* of add_self_loops, which later
# cells index with [0] to get the edge tensor.
# NOTE(review): this rebinding is not idempotent — re-running the cell feeds the
# already-converted result back into torch.Tensor(...).
(
    ICD9_child_edge_index,
    ICD9_parent_edge_index,
    ATC_child_edge_index,
    ATC_parent_edge_index,
) = (
    utils.add_self_loops(torch.Tensor(raw_edges))
    for raw_edges in (
        ICD9_child_edge_index,
        ICD9_parent_edge_index,
        ATC_child_edge_index,
        ATC_parent_edge_index,
    )
)


In [196]:
ICD9_size = len(ICD9_map)
ATC_size = len(ATC_map)


def _edges_to_adjacency(edge_index_with_loops, num_nodes):
    """Build a square sparse COO adjacency matrix from an add_self_loops result.

    Args:
        edge_index_with_loops: the value returned by ``utils.add_self_loops``;
            element 0 is the ``(2, num_edges)`` edge tensor.
        num_nodes: side length of the resulting ``(num_nodes, num_nodes)`` matrix.

    Returns:
        Sparse COO tensor holding a 1.0 at every edge coordinate.
    """
    edge_index = edge_index_with_loops[0]
    edge_weights = torch.ones(edge_index.size()[1])
    return torch.sparse_coo_tensor(edge_index, edge_weights, (num_nodes, num_nodes))


ICD9_child_adj = _edges_to_adjacency(ICD9_child_edge_index, ICD9_size)
ICD9_parent_adj = _edges_to_adjacency(ICD9_parent_edge_index, ICD9_size)
# Fix: the ATC matrices were previously sized with ICD9_size (copy-paste slip),
# leaving ATC_size unused. They are now sized by the ATC vocabulary.
# NOTE(review): assumes every ATC edge index is < len(ATC_map) — confirm.
ATC_child_adj = _edges_to_adjacency(ATC_child_edge_index, ATC_size)
ATC_parent_adj = _edges_to_adjacency(ATC_parent_edge_index, ATC_size)


In [338]:
# Inspect element 0 of the add_self_loops result for the ATC parent graph:
# the 2-row edge tensor (note the float values in the recorded output).
ATC_parent_edge_index[0]

tensor([[  14.,   15.,   16.,  ..., 6437., 6438., 6439.],
        [   0.,    0.,    0.,  ..., 6437., 6438., 6439.]])

In [359]:
def x2sparce_step1(x,parent_adj,child_adj,idx_map,code_bank):
    """Build the step-1 sparse matrix for one list of codes.

    For every code in ``x``: standardize it, look up its row in ``parent_adj``,
    and for every column index ``p`` stored in that row add one entry
    ``(p, c)`` for every column index ``c`` stored in row ``p`` of ``child_adj``.
    Duplicate coordinates are kept, so they add up when the matrix is used.

    Args:
        x: iterable of raw code strings.
        parent_adj: 2-D sparse COO tensor indexed by code index.
        child_adj: 2-D sparse COO tensor indexed by code index.
        idx_map: mapping from standardized code to its integer index; its
            length sets the matrix side.
        code_bank: object exposing ``standardize(code)``.

    Returns:
        ``(len(idx_map), len(idx_map))`` sparse COO tensor of int64 ones. An
        empty ``x`` yields a matrix with no stored entries (previously it
        raised because the index placeholder did not match the empty values).

    Raises:
        KeyError: if a standardized code is missing from ``idx_map``.
    """
    size = len(idx_map)
    rows = []
    cols = []
    for code in x:
        code_idx = idx_map[code_bank.standardize(code)]
        # Column indices stored in this code's row of the parent adjacency.
        for parent_idx in parent_adj[code_idx]._indices()[0].tolist():
            # Column indices stored in that row of the child adjacency.
            for child_idx in child_adj[parent_idx]._indices()[0].tolist():
                rows.append(parent_idx)
                cols.append(child_idx)
    # Explicit int64 tensors keep the dtype stable even when there are no entries.
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.ones(len(rows), dtype=torch.long)
    return torch.sparse_coo_tensor(indices, values, size=(size, size))


def x2sparce_step2(x,parent_adj,child_adj,idx_map,code_bank):
    """Build the step-2 sparse matrix for one list of codes.

    For every code in ``x``: standardize it, look up its row in ``parent_adj``,
    and add one entry ``(code_index, p)`` for every column index ``p`` stored
    in that row. Duplicate coordinates are kept, so they add up when the
    matrix is used.

    Args:
        x: iterable of raw code strings.
        parent_adj: 2-D sparse COO tensor indexed by code index.
        child_adj: unused here; kept so both step builders share a signature.
        idx_map: mapping from standardized code to its integer index; its
            length sets the matrix side.
        code_bank: object exposing ``standardize(code)``.

    Returns:
        ``(len(idx_map), len(idx_map))`` sparse COO tensor of int64 ones. An
        empty ``x`` yields a matrix with no stored entries (previously it
        raised because the index placeholder did not match the empty values).

    Raises:
        KeyError: if a standardized code is missing from ``idx_map``.
    """
    size = len(idx_map)
    rows = []
    cols = []
    for code in x:
        code_idx = idx_map[code_bank.standardize(code)]
        # Column indices stored in this code's row of the parent adjacency.
        parent_indices = parent_adj[code_idx]._indices()[0].tolist()
        rows.extend([code_idx] * len(parent_indices))
        cols.extend(parent_indices)
    # Explicit int64 tensors keep the dtype stable even when there are no entries.
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.ones(len(rows), dtype=torch.long)
    # Sanity print (entry count vs. value count); the recorded cell output
    # in this notebook comes from this line.
    print(len(rows),len(values))
    return torch.sparse_coo_tensor(indices, values, size=(size, size))
    

def parse_samples(input_data, type_of_data,parent_adj,child_adj,idx_map,code_bank):
    """Stack one sparse code matrix per visit into a single 3-D sparse tensor.

    NOTE(review): despite the generic name this uses the step-2 builder
    (``x2sparce_step2``). If a step-1 variant was intended it should call
    ``x2sparce_step1`` instead — confirm with the author.

    Args:
        input_data: iterable of visits; each visit is indexed with
            ``type_of_data`` to obtain its list of codes.
        type_of_data: key selecting the code list inside each visit.
        parent_adj, child_adj, idx_map, code_bank: forwarded unchanged to
            ``x2sparce_step2``.

    Returns:
        Sparse COO tensor of shape ``(num_visits, len(idx_map), len(idx_map))``.

    Raises:
        ValueError: if ``input_data`` contains no visits (previously this
            surfaced as an UnboundLocalError).
    """
    per_visit = [
        x2sparce_step2(visit[type_of_data],parent_adj,child_adj,idx_map,code_bank).unsqueeze(0)
        for visit in input_data
    ]
    if not per_visit:
        raise ValueError("input_data must contain at least one visit")
    # One concatenation instead of re-concatenating the growing tensor per visit.
    return torch.cat(per_visit, dim=0)

def parse_samples_step2(input_data, type_of_data,parent_adj,child_adj,idx_map,code_bank):
    """Stack the step-2 sparse matrix of every visit into one 3-D sparse tensor.

    Args:
        input_data: iterable of visits; each visit is indexed with
            ``type_of_data`` to obtain its list of codes.
        type_of_data: key selecting the code list inside each visit.
        parent_adj, child_adj, idx_map, code_bank: forwarded unchanged to
            ``x2sparce_step2``.

    Returns:
        Sparse COO tensor of shape ``(num_visits, len(idx_map), len(idx_map))``.

    Raises:
        ValueError: if ``input_data`` contains no visits (previously this
            surfaced as an UnboundLocalError).
    """
    visit_matrices = []
    for visit in input_data:
        codes = visit[type_of_data]
        matrix = x2sparce_step2(codes,parent_adj,child_adj,idx_map,code_bank)
        # Leading axis of size 1 so the matrices can be joined along dim 0.
        visit_matrices.append(matrix.unsqueeze(0))
    if not visit_matrices:
        raise ValueError("input_data must contain at least one visit")
    # One concatenation instead of re-concatenating the growing tensor per visit.
    return torch.cat(visit_matrices, dim=0)

# Smoke test of the step-2 batching on the ICD9 ontology.
# NOTE(review): `test` is not defined in any cell shown in this notebook; it
# must already exist in the kernel, so this fails on Restart & Run All.
parse_samples_step2(test,"ICD9_CODE",ICD9_parent_adj,ICD9_child_adj,ICD9_map,ICD9)


43 43
40 40
27 27
29 29
47 47
27 27
66 66
45 45
27 27
24 24
23 23
48 48
18 18
55 55
34 34
16 16
59 59
47 47
68 68
93 93
38 38
11 11
52 52
41 41
56 56
48 48
23 23
13 13
53 53
42 42
63 63
31 31


tensor(indices=tensor([[    0,     0,     0,  ...,    31,    31,    31],
                       [ 5873,  5873,  5873,  ...,  9952,  9952,  9952],
                       [  306, 16383,   928,  ...,  4416,   928,  9952]]),
       values=tensor([1, 1, 1,  ..., 1, 1, 1]),
       size=(32, 17736, 17736), nnz=1307, layout=torch.sparse_coo)

In [395]:
class GAT(torch.nn.Module):
    """Project per-visit code lists through two ontology-derived sparse matrices.

    For every visit in a batch the module builds two
    ``(num_codes, num_codes)`` sparse matrices from the visit's codes (see
    ``x2sparce_step1`` and ``x2sparce_step2``), multiplies the first with a
    learned ``(num_codes, output_dim)`` weight matrix, multiplies the second
    with that product, and applies a ReLU.

    Args:
        child_adj: 2-D sparse COO tensor; row ``i`` is read through
            ``child_adj[i]._indices()``. Built outside the class.
        parent_adj: 2-D sparse COO tensor read the same way. Built outside
            the class.
        code2idx: mapping from standardized code to integer index; its length
            is the number of codes in the tree (ATC or ICD9).
        code_bank: object exposing ``standardize(code)``.
        input_name: key under which each visit stores its list of codes.
    """

    def __init__(self,child_adj,parent_adj,code2idx,code_bank,input_name):
        super().__init__()
        # input_dim is the number of codes in the tree (ATC or ICD9);
        # output_dim is fixed at 100 for tractability.
        self.input_dim = len(code2idx)
        self.output_dim = 100
        self.weights_step1 = nn.Parameter(torch.Tensor(self.input_dim, self.output_dim))
        nn.init.xavier_uniform_(self.weights_step1)
        # The former `self.weights_step1.type(torch.DoubleTensor)` call was
        # removed: it returned a converted copy that was thrown away, so the
        # parameter has always stayed float32.

        # Adjacency matrices generated outside of the class.
        self.child_adj = child_adj
        self.parent_adj = parent_adj
        # Code vocabulary generated outside of the class.
        self.code2idx = code2idx
        self.code_bank = code_bank
        self.relu = nn.ReLU()
        # Kept as an attribute for compatibility; see the note in forward().
        self.sm = nn.Softmax(dim=1)

        # Key of the code list inside each visit of the batch.
        self.input_name = input_name

    def forward(self,x):
        """Encode a batch of visits.

        Args:
            x: sequence of visits; ``visit[self.input_name]`` is a list of codes.

        Returns:
            Dense float tensor of shape ``(len(x), num_codes, output_dim)``
            with non-negative entries.

        Raises:
            ValueError: if ``x`` contains no visits.
        """
        # The per-visit graphs must stay sparse: dense
        # (batch, num_codes, num_codes) tensors would not fit in memory.
        step1 = self.parse_samples_step1(x)
        step2 = self.parse_samples_step2(x)

        per_visit = []
        for i in range(step1.shape[0]):
            hidden = torch.matmul(step1[i].float(), self.weights_step1)
            per_visit.append(torch.matmul(step2[i].float(), hidden))
        # Stacking avoids writing in place into a buffer tracked by autograd.
        h = torch.stack(per_visit)

        # NOTE(review): the previous code set exact zeros to -1e9 and ran
        # self.sm(h), then immediately overwrote that result with
        # self.relu(h). The value returned was therefore always relu(h)
        # (relu maps both 0 and -1e9 to 0), which is what is kept here. If a
        # softmax attention was actually intended, this is where it belongs.
        return self.relu(h)

    def _row_indices(self, adj, row):
        """Return the column indices stored in ``adj[row]`` as a list of ints."""
        return adj[row]._indices()[0].tolist()

    def _as_sparse(self, rows, cols):
        """Build the (num_codes, num_codes) sparse matrix with a 1 at each (row, col)."""
        size = len(self.code2idx)
        # Explicit int64 tensors keep the dtype stable even with no entries.
        indices = torch.tensor([rows, cols], dtype=torch.long)
        values = torch.ones(len(rows), dtype=torch.long)
        return torch.sparse_coo_tensor(indices, values, size=(size, size))

    def _stack_visits(self, input_data, build):
        """Apply ``build`` to every visit's code list and join the results on dim 0."""
        per_visit = [build(visit[self.input_name]).unsqueeze(0) for visit in input_data]
        if not per_visit:
            raise ValueError("input_data must contain at least one visit")
        return torch.cat(per_visit, dim=0)

    def x2sparce_step1(self, x):
        """Build the step-1 sparse matrix for one list of codes.

        For every code, each column index ``p`` stored in the code's row of
        ``parent_adj`` contributes one entry ``(p, c)`` per column index ``c``
        stored in row ``p`` of ``child_adj``. Duplicates are kept and add up
        when the matrix is used. An empty ``x`` yields a matrix with no
        stored entries.
        """
        rows = []
        cols = []
        for code in x:
            code_idx = self.code2idx[self.code_bank.standardize(code)]
            for parent_idx in self._row_indices(self.parent_adj, code_idx):
                for child_idx in self._row_indices(self.child_adj, parent_idx):
                    rows.append(parent_idx)
                    cols.append(child_idx)
        return self._as_sparse(rows, cols)

    def x2sparce_step2(self,x):
        """Build the step-2 sparse matrix for one list of codes.

        For every code, one entry ``(code_index, p)`` is added per column
        index ``p`` stored in the code's row of ``parent_adj``. Duplicates
        are kept and add up when the matrix is used. An empty ``x`` yields a
        matrix with no stored entries.
        """
        rows = []
        cols = []
        for code in x:
            code_idx = self.code2idx[self.code_bank.standardize(code)]
            parent_indices = self._row_indices(self.parent_adj, code_idx)
            rows.extend([code_idx] * len(parent_indices))
            cols.extend(parent_indices)
        return self._as_sparse(rows, cols)

    def parse_samples_step1(self, input_data):
        """Return the step-1 matrices of all visits as one (visits, n, n) sparse tensor."""
        return self._stack_visits(input_data, self.x2sparce_step1)

    def parse_samples_step2(self,input_data):
        """Return the step-2 matrices of all visits as one (visits, n, n) sparse tensor."""
        return self._stack_visits(input_data, self.x2sparce_step2)

In [400]:
# Encoder over the ICD9 ontology that reads each visit's "ICD9_CODE" list.
diag_ont = GAT(ICD9_child_adj,ICD9_parent_adj,ICD9_map,ICD9,"ICD9_CODE")
# Call the module itself rather than .forward() so that nn.Module hooks run.
# NOTE(review): `test` is not defined in any cell shown in this notebook; it
# must already exist in the kernel for this cell to run.
diag_ont(test).max()

tensor(1.9343, grad_fn=<MaxBackward1>)

In [74]:

import pickle


# SECURITY: pickle.load can execute arbitrary code embedded in the file, so
# only load this pickle if it comes from a trusted source.
# The path is relative to the kernel's current working directory.
with open('data-single-visit.pkl', 'rb') as f:
    data = pickle.load(f)

In [183]:
# Spot check: look up the entry for ICD9 code "1363" (recorded output below).
ICD9.lookup("1363")

'Pneumocystosis'

In [180]:
# Display the unpickled object; the recorded output shows a table with the
# columns SUBJECT_ID, HADM_ID, ICD9_CODE and ATC4.
data

Unnamed: 0,SUBJECT_ID,HADM_ID,ICD9_CODE,ATC4
0,4,185777,"[042, 1363, 7994, 2763, 7907, 5715, 04111, V09...","[A10AE, R05DB, R05DA, P01BA, A12BA, N03AE, A07..."
1,6,107064,"[40391, 4440, 9972, 2766, 2767, 2859, 2753, V1...","[A06AA, L04AA, A12CA, C07AB, C08CA, B01AB, D11..."
2,8,159514,"[V3001, 7706, 7746, V290, V502, V053]","[J01CA, D06AX]"
3,9,150750,"[431, 5070, 4280, 5849, 2765, 4019]","[B05CX, A12CA, J01MA, A01AB, C02DD, N01AX, A02..."
4,12,112213,"[1570, 57410, 9971, 4275, 99811, 4019, 5680, 5...","[B01AB, A12CA, B05CX, N01AX, A02BC, A12AA, A01..."
...,...,...,...,...
30740,99985,176670,"[0389, 51881, 48241, 4870, 78552, V4281, 99592...","[B01AB, A12CA, N02BE, A04AA, J01MA, A06AA, C10..."
30741,99991,151118,"[56211, 0389, 5570, 5849, 99592, 56081, 78959,...","[B01AB, A12CA, N02BE, B05CX, A01AB, C02DB, A12..."
30742,99992,197084,"[9999, 56881, 5772, 2851, 5849, 72992, 53081, ...","[N02BE, B05CX, N05BA, B03BB, A02BC, N06AX, C01..."
30743,99995,137810,"[4414, 42833, 99812, 2851, 4241, 25000, 99811,...","[B01AB, N02BE, C02DB, C07AB, B05CX, A07AA, C01..."
