In [23]:
import ipynb.fs.full.ExtractData as extractData
import torch
from torch import nn
import torch.nn.functional as F
import math
from torch_geometric.utils import add_self_loops, softmax, scatter_
import inspect

## Ontology Embedding
This notebook contains the methods for building the ontology embeddings used by G-BERT. Most of the code is adapted from the G-BERT GitHub repository.

In [17]:
def glorot(tensor):
    """Fill ``tensor`` in place with Glorot/Xavier-uniform values.

    Values are drawn from U(-a, a) with a = sqrt(6 / (fan_a + fan_b)), where the
    two fan sizes are the sizes of the last two dimensions of ``tensor``.

    Args:
        tensor: a tensor with at least two dimensions, or ``None``. ``None`` is
            ignored so optional parameters can be passed unconditionally.
    """
    # The None check must come first: the size lookups below would otherwise
    # raise AttributeError on None.
    if tensor is None:
        return
    stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
    tensor.data.uniform_(-stdv, stdv)

In [18]:
def zeros(tensor):
    """Overwrite every element of ``tensor`` with 0 in place.

    Passing ``None`` is a no-op, so optional parameters (e.g. a disabled bias)
    can be handed over unconditionally.
    """
    if tensor is None:
        return
    tensor.data.fill_(0)

In [19]:
class GATConv(nn.Module):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)}.

    The outputs of the attention heads are always concatenated, so every node
    gets ``heads * out_channels`` output features.

    Edge convention: for ``edge_index = [row, col]``, node ``row[e]`` aggregates
    the message computed from node ``col[e]``.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample per attention head.
        heads (int, optional): Number of multi-head-attentions. (default:
            :obj:`1`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. Only applied while the
            module is in training mode. (default: :obj:`0`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 heads=1,
                 negative_slope=0.2,
                 dropout=0,
                 bias=True):
        super(GATConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.weight = nn.Parameter(torch.Tensor(in_channels, heads * out_channels))
        self.att = nn.Parameter(torch.Tensor(1, heads, 2 * out_channels))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
        else:
            # Keep the attribute (as None) so `update` can test it; zeros()
            # in reset_parameters ignores None.
            self.register_parameter('bias', None)

        # Argument names `propagate` must supply to `message` / `update`.
        # inspect.signature on a bound method already omits `self`
        # (inspect.getargspec no longer exists in Python 3.11+); `update`
        # additionally receives the aggregated messages as its first argument.
        self.message_args = list(inspect.signature(self.message).parameters)
        self.update_args = list(inspect.signature(self.update).parameters)[1:]

        self.reset_parameters()

    def propagate(self, aggr, edge_index, **kwargs):
        r"""The initial call to start propagating messages.
        Takes in an aggregation scheme (:obj:`"add"`, :obj:`"mean"` or
        :obj:`"max"`), the edge indices, and all additional data which is
        needed to construct messages and to update node embeddings.

        Message arguments named ``<name>_i`` are gathered from ``kwargs[<name>]``
        at ``edge_index[0]``; ``<name>_j`` at ``edge_index[1]``; anything else
        is passed through from ``kwargs`` unchanged. Messages are then reduced
        onto ``edge_index[0]``."""

        assert aggr in ['add', 'mean', 'max']
        kwargs['edge_index'] = edge_index

        # Number of output rows for the scatter; taken from the gathered tensor.
        size = None
        message_args = []
        for arg in self.message_args:
            if arg.endswith('_i'):
                tmp = kwargs[arg[:-2]]
                size = tmp.size(0)
                message_args.append(tmp[edge_index[0]])
            elif arg.endswith('_j'):
                tmp = kwargs[arg[:-2]]
                size = tmp.size(0)
                message_args.append(tmp[edge_index[1]])
            else:
                message_args.append(kwargs[arg])

        update_args = [kwargs[arg] for arg in self.update_args]

        out = self.message(*message_args)
        out = scatter_(aggr, out, edge_index[0], dim_size=size)
        out = self.update(out, *update_args)

        return out

    def reset_parameters(self):
        """Glorot-initialise the projection and attention weights; zero the bias."""
        glorot(self.weight)
        glorot(self.att)
        zeros(self.bias)

    def forward(self, x, edge_index):
        """Apply one attention layer.

        Args:
            x: node features; multiplied by ``weight``, so presumably
                ``(num_nodes, in_channels)`` — TODO confirm at call sites.
            edge_index: ``[row, col]`` index tensor; self-loops are added here.

        Returns:
            Tensor of shape ``(num_nodes, heads * out_channels)``.
        """
        num_nodes = x.size(0)
        # Depending on the torch_geometric version, add_self_loops returns either
        # a bare edge_index or an (edge_index, edge_weight) tuple; accept both.
        looped = add_self_loops(edge_index, num_nodes=num_nodes)
        edge_index = looped[0] if isinstance(looped, tuple) else looped
        x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)
        return self.propagate('add', edge_index, x=x, num_nodes=num_nodes)

    def message(self, x_i, x_j, edge_index, num_nodes):
        """Weight each source-node message ``x_j`` by its attention coefficient."""
        # Compute attention coefficients: one score per (edge, head).
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalise over each target node's incoming edges. num_nodes is passed by
        # keyword because some torch_geometric versions take `ptr` positionally here.
        alpha = softmax(alpha, edge_index[0], num_nodes=num_nodes)

        # Respect train/eval mode; the bare F.dropout default is training=True.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.view(-1, self.heads, 1)

    def update(self, aggr_out):
        """Concatenate the heads and add the (optional) bias."""
        aggr_out = aggr_out.view(-1, self.heads * self.out_channels)

        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

In [5]:
def expand_level2():
    """Map each ICD-9 category code to the label of the range containing it.

    Every table entry is either a single category (mapped to itself) or an
    inclusive ``start-end`` range that is expanded into one key per number.
    Keys are formatted as ``V`` + 2 digits, ``E`` + 3 digits, or 3 plain
    digits. Entries later in the table overwrite earlier ones, so specific
    sub-ranges listed after a broad one (e.g. ``'V01-V09'`` after
    ``'V01-V91'``) take precedence.

    Returns:
        dict[str, str]: category code -> range label.
    """
    ranges = ['001-009', '010-018', '020-027', '030-041', '042', '045-049', '050-059', '060-066', '070-079', '080-088', '090-099', '100-104', '110-118', '120-129', '130-136', '137-139', '140-149', '150-159', '160-165', '170-176', '176', '179-189', '190-199', '200-208', '209', '210-229', '230-234', '235-238', '239', '240-246', '249-259', '260-269', '270-279', '280-289', '290-294', '295-299', '300-316', '317-319', '320-327', '330-337', '338', '339', '340-349', '350-359', '360-379', '380-389', '390-392', '393-398', '401-405', '410-414', '415-417', '420-429', '430-438', '440-449', '451-459', '460-466', '470-478', '480-488', '490-496', '500-508', '510-519', '520-529', '530-539', '540-543', '550-553', '555-558', '560-569', '570-579', '580-589', '590-599', '600-608', '610-611', '614-616', '617-629', '630-639', '640-649', '650-659', '660-669', '670-677', '678-679', '680-686', '690-698', '700-709', '710-719', '720-724', '725-729', '730-739', '740-759', '760-763', '764-779', '780-789', '790-796', '797-799', '800-804', '805-809', '810-819', '820-829', '830-839', '840-848', '850-854', '860-869', '870-879', '880-887', '890-897', '900-904', '905-909', '910-919', '920-924', '925-929', '930-939', '940-949', '950-957', '958-959', '960-979', '980-989', '990-995', '996-999', 'V01-V91', 'V01-V09', 'V10-V19', 'V20-V29', 'V30-V39', 'V40-V49', 'V50-V59', 'V60-V69', 'V70-V82', 'V83-V84', 'V85', 'V86', 'V87', 'V88', 'V89', 'V90', 'V91', 'E000-E899', 'E000', 'E001-E030', 'E800-E807', 'E810-E819', 'E820-E825', 'E826-E829', 'E830-E838', 'E840-E845', 'E846-E849', 'E850-E858', 'E860-E869', 'E870-E876', 'E878-E879', 'E880-E888', 'E890-E899', 'E900-E909', 'E910-E915', 'E916-E928', 'E929', 'E930-E949', 'E950-E959', 'E960-E969', 'E970-E978', 'E980-E989', 'E990-E999']

    # Letter prefix -> key format; purely numeric ranges fall back to "%03d".
    prefixed_formats = {'V': "V%02d", 'E': "E%03d"}

    expanded = {}
    for label in ranges:
        bounds = label.split('-')
        if len(bounds) == 1:
            expanded[label] = label
            continue

        lead = label[0]
        key_format = prefixed_formats.get(lead, "%03d")
        # Strip the letter prefix (if any) before parsing the range bounds.
        offset = 1 if lead in prefixed_formats else 0
        first = int(bounds[0][offset:])
        last = int(bounds[1][offset:])
        for number in range(first, last + 1):
            expanded[key_format % number] = label

    return expanded


In [6]:
def build_icd9_tree(codes):
    """Build the ICD-9 hierarchy path for every diagnosis code.

    Each path is ``[code, category, range_label, 'icd9_root']``. The category
    is the first 4 characters for codes starting with ``'E'`` and the first 3
    otherwise. The range label is looked up in ``expand_level2()``, and a
    category missing from that table raises KeyError.

    Args:
        codes: iterable of ICD-9 code strings (presumably without the decimal
            point, given the prefix slicing — confirm against the data source).

    Returns:
        (tree, vocabulary): ``tree`` holds one 4-element path per code.
        ``vocabulary`` lists every distinct node label once, in first-seen
        order. Its length is used as the graph's node count, so it must not
        contain duplicates.
    """
    icd9_tree = []
    vocabulary = []
    seen = set()

    root_node = 'icd9_root'
    level3_dict = expand_level2()
    for code in codes:
        level2 = code[:4] if code[0] == 'E' else code[:3]
        sample = [code, level2, level3_dict[level2], root_node]

        # Deduplicate while keeping a reproducible (first-seen) order.
        for node in sample:
            if node not in seen:
                seen.add(node)
                vocabulary.append(node)
        icd9_tree.append(sample)

    return icd9_tree, vocabulary

In [7]:
def build_atc_tree(codes):
    """Build the ATC hierarchy path for every drug code.

    Each path is ``[code, code[:4], code[:3], code[:1], 'atc_root']``, i.e.
    successively shorter prefixes of the code up to a shared root.

    Args:
        codes: iterable of ATC code strings.

    Returns:
        (tree, vocabulary): ``tree`` holds one 5-element path per code.
        ``vocabulary`` lists every distinct node label once, in first-seen
        order. The order is deterministic across interpreter runs, unlike
        ``list(set(...))`` of strings, whose order depends on hash
        randomisation. Vocabulary positions become graph-node/embedding row
        indices, so a stable order keeps them reproducible.
    """
    atc_tree = []
    vocabulary = []
    seen = set()

    root_node = 'atc_root'
    for code in codes:
        sample = [code, code[:4], code[:3], code[:1], root_node]
        atc_tree.append(sample)

        for node in sample:
            if node not in seen:
                seen.add(node)
                vocabulary.append(node)

    return atc_tree, vocabulary


In [8]:
def build_stage_one_edges(tree, vocabulary):
    """Edge list connecting every node on a path to its immediate parent.

    For each path ``[n0, n1, ..., nk]`` in ``tree``, the edge
    ``(idx(n_{i+1}), idx(n_i))`` is emitted: the parent index goes in ``row``
    and the child index in ``col`` (children -> ancestor). Duplicate edges are
    collapsed.

    Args:
        tree: list of node-label paths, leaf first.
        vocabulary: list of node labels. A label's index is its first
            position in this list.

    Returns:
        ``[row, col]``, two equal-length lists of int indices.

    Raises:
        ValueError: if a label in ``tree`` is not in ``vocabulary``.
    """
    # Build the label -> index map once instead of calling list.index per word.
    # setdefault keeps the first occurrence, matching list.index semantics.
    position = {}
    for idx, word in enumerate(vocabulary):
        position.setdefault(word, idx)

    edge_idx = set()
    for sample in tree:
        try:
            sample_idx = [position[word] for word in sample]
        except KeyError as err:
            # Preserve the ValueError that list.index used to raise.
            raise ValueError('%r is not in vocabulary' % (err.args[0],)) from None

        for child, parent in zip(sample_idx, sample_idx[1:]):
            # only direct children -> ancestor
            edge_idx.add((parent, child))

    edge_idx = list(edge_idx)
    row = [r for r, _ in edge_idx]
    col = [c for _, c in edge_idx]
    return [row, col]

In [9]:
def build_stage_two_edges(tree, vocabulary):
    """Edge list connecting every leaf to all of its ancestors.

    For each path ``[n0, n1, ..., nk]`` in ``tree``, the edges
    ``(idx(n0), idx(n_i))`` for ``i >= 1`` are emitted: the leaf index goes in
    ``row`` and the ancestor index in ``col`` (ancestors -> leaf). Duplicate
    edges are collapsed.

    Args:
        tree: list of node-label paths, leaf first.
        vocabulary: list of node labels. A label's index is its first
            position in this list.

    Returns:
        ``[row, col]``, two equal-length lists of int indices.

    Raises:
        ValueError: if a label in ``tree`` is not in ``vocabulary``.
    """
    # One-time label -> index map (first occurrence wins, like list.index).
    position = {}
    for idx, word in enumerate(vocabulary):
        position.setdefault(word, idx)

    edge_idx = set()
    for sample in tree:
        try:
            sample_idx = [position[word] for word in sample]
        except KeyError as err:
            # Preserve the ValueError that list.index used to raise.
            raise ValueError('%r is not in vocabulary' % (err.args[0],)) from None

        # only ancestors -> leaf node
        leaf = sample_idx[0]
        edge_idx.update((leaf, ancestor) for ancestor in sample_idx[1:])

    edge_idx = list(edge_idx)
    row = [r for r, _ in edge_idx]
    col = [c for _, c in edge_idx]
    return [row, col]

In [10]:
class OntologyEmbedding(nn.Module):
    """Graph-attention embedding of the codes of an ontology tree.

    The tree returned by ``build_tree_func(all_codes)`` is turned into two edge
    sets: stage one (child -> parent) and stage two (ancestors -> leaf). A
    learnable embedding per graph node is passed through the same ``GATConv``
    twice, once per edge set. ``forward`` returns the rows for ``all_codes``
    in the order they were given.

    Args:
        all_codes: list of leaf codes. It is iterated twice (tree building and
            index mapping), so it should not be a one-shot iterator.
        build_tree_func: callable returning ``(tree, vocabulary)`` for the codes.
        embedding_dim (int, optional): size of every node embedding.
            (default: 300)
        heads (int, optional): number of attention heads; must divide
            ``embedding_dim``. (default: 4)

    Raises:
        ValueError: if ``embedding_dim`` is not divisible by ``heads``.
    """

    def __init__(self, all_codes, build_tree_func, embedding_dim=300, heads=4):
        super(OntologyEmbedding, self).__init__()
        if embedding_dim % heads != 0:
            raise ValueError('embedding_dim (%d) must be divisible by heads (%d)'
                             % (embedding_dim, heads))

        # initial tree edges
        tree, vocabulary = build_tree_func(all_codes)
        stage_one_edges = build_stage_one_edges(tree, vocabulary)
        stage_two_edges = build_stage_two_edges(tree, vocabulary)

        # Explicit long dtype: torch.tensor([[], []]) would otherwise be float.
        self.edges1 = torch.tensor(stage_one_edges, dtype=torch.long)
        self.edges2 = torch.tensor(stage_two_edges, dtype=torch.long)
        self.graph_vocab = vocabulary

        # construct model; concatenated heads give embedding_dim output features,
        # so the layer can be applied to its own output.
        self.g = GATConv(in_channels=embedding_dim,
                         out_channels=embedding_dim // heads,
                         heads=heads)

        # tree embedding: one row per graph node (vocabulary position)
        num_nodes = len(vocabulary)
        self.embedding = nn.Parameter(torch.Tensor(num_nodes, embedding_dim))

        # idx mapping: FROM leaf node in graph vocab TO position in all_codes.
        # A dict avoids repeated list.index scans; first occurrence wins as before.
        position = {}
        for idx, word in enumerate(self.graph_vocab):
            position.setdefault(word, idx)
        self.idx_mapping = [position[code] for code in all_codes]

        self.init_params()

    def forward(self):
        """Return the GAT-refined embeddings of ``all_codes``, one row per code."""
        emb = self.embedding

        # Stage one first (children -> parents), then stage two (ancestors -> leaves).
        emb = self.g(self.g(emb, self.edges1.to(emb.device)), self.edges2.to(emb.device))

        return emb[self.idx_mapping]

    def init_params(self):
        """Glorot-initialise the node embedding table."""
        glorot(self.embedding)

In [11]:
class FuseEmbeddings(nn.Module):
    """Combine special-token, ontology (drug and condition) and type embeddings.

    The lookup table built in ``forward`` is laid out as:
    rows 0-2 are learnable special embeddings, followed by the drug embeddings
    in ``all_drugs`` order, then the condition embeddings in ``all_conditions``
    order. Every row has 300 features.

    Args:
        all_conditions: ICD-9 codes fed to ``build_icd9_tree``.
        all_drugs: ATC codes fed to ``build_atc_tree``.
        vocab_size: accepted for interface compatibility; not used here.
            NOTE(review): presumably ``3 + len(all_drugs) + len(all_conditions)``
            — confirm with the caller.
    """

    def __init__(self, all_conditions, all_drugs, vocab_size):
        super(FuseEmbeddings, self).__init__()
        # NOTE(review): what the 3 special rows stand for (e.g. pad/cls/mask) is
        # defined by the tokenizer elsewhere — confirm.
        self.special_embedding = nn.Parameter(torch.Tensor(3, 300))
        self.drug_embedding = OntologyEmbedding(all_drugs, build_atc_tree)
        self.conditions_embedding = OntologyEmbedding(all_conditions, build_icd9_tree)

        self.init_params()
        # Two token types; nn.Embedding initialises its own weights.
        self.type_embedding = nn.Embedding(2, 300)

    def forward(self, input_ids, input_types=None):
        """Look up ``input_ids`` in the fused table and add the type embedding.

        Args:
            input_ids: integer index tensor into the fused table.
            input_types: integer tensor of type ids (0 or 1) with the same shape
                as ``input_ids``. Defaults to all zeros. This assumes
                ``input_ids`` is a tensor; previously ``None`` crashed inside
                ``nn.Embedding``.

        Returns:
            Tensor of shape ``input_ids.shape + (300,)``.
        """
        concat_embeddings = torch.cat([self.special_embedding,
                                       self.drug_embedding(),
                                       self.conditions_embedding()], dim=0)
        if input_types is None:
            input_types = torch.zeros_like(input_ids, dtype=torch.long)
        return concat_embeddings[input_ids] + self.type_embedding(input_types)

    def init_params(self):
        """Glorot-initialise the special-token embeddings."""
        glorot(self.special_embedding)