In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import math
from ipynb.fs.full.OntologyEmbedding import FuseEmbeddings

## The Transformer

**Layer Normalization**: a technique commonly used in neural networks for normalizing the inputs to a layer.
<br/>
<br/>
$$ \mathrm{LayerNorm}(x_i) = weight * \hat{x}_i + bias $$
where
$$ \hat{x}_{i,k} = {{x_{i,k} - \mu_i} \over {\sqrt{\sigma_i^2 + \epsilon}}} $$
and
$$ \mu_i = {1 \over K}{\sum_{k=1}^{K} x_{i,k}}, \qquad \sigma_i^2 = {1 \over K}{\sum_{k=1}^{K} (x_{i,k}-\mu_i)^2}$$

<br/> 
reference: 
<br/> 
1. https://arxiv.org/pdf/1607.06450.pdf

In [3]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learned scale and shift.

    Each vector along the last axis is shifted to zero mean and divided by
    the square root of its (biased) variance plus ``eps``; the result is then
    multiplied element-wise by ``weight`` and offset by ``bias``.

    Args:
        hidden_size: Length of the last dimension of the inputs; also the
            length of the ``weight`` and ``bias`` parameters.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        # Scale starts at 1 and shift at 0, i.e. the affine part is initially a no-op.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Return ``weight * normalized(x) + bias`` with the same shape as ``x``."""
        centered = x - torch.mean(x, dim=-1, keepdim=True)
        # Mean of squared deviations: the biased (divide-by-K) variance estimate.
        variance = torch.mean(torch.pow(centered, 2), dim=-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias

**Sublayer Connection**: a residual connection wrapped around a sub-layer. The input x is normalized using LayerNorm and then passed into the sub-layer. Dropout is applied to the output of the sub-layer before it is added back to the original (un-normalized) input x.
<br/>
<br/>
Each layer has two sublayers: 
1. the Multi-Head Attention Layer 
2. the Position-wise Fully Connected Feed-Forward Network Layer

<br/>
reference:
<br/>
1. http://nlp.seas.harvard.edu/annotated-transformer/  

In [4]:
class SublayerConnection(nn.Module):
    """Residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    The input is layer-normalized before it is handed to the sub-layer
    (normalization happens on the way in, not on the sum).

    Args:
        hidden_size: Size of the last dimension, forwarded to ``LayerNorm``.
        dropout_prob: Dropout probability applied to the sub-layer output.
    """

    def __init__(self, hidden_size, dropout_prob):
        super().__init__()
        self.norm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (any callable on a tensor) with a residual connection."""
        normalized = self.norm(x)
        transformed = self.dropout(sublayer(normalized))
        # The residual adds the raw input, not the normalized one.
        return x + transformed

**Scaled-Dot Product Attention**: an attention mechanism where the dot products are scaled down by $\sqrt{d_k}$ where $d_k$ is the dimension of Q and K.

$$ Attention(Q,K,V) = softmax({{QK^T} \over {\sqrt{d_k}}})V $$
reference: 
1. https://arxiv.org/pdf/1706.03762.pdf

In [5]:
class Attention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``."""

    def forward(self, query, key, value, mask=None, dropout=None):
        """Compute attention-weighted values.

        Args:
            query: Tensor whose last dimension is the per-head feature size d_k.
            key: Tensor whose last two dimensions are transposed and
                matrix-multiplied with ``query``.
            value: Tensor matrix-multiplied by the attention weights.
            mask: Optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` are excluded from attention.
            dropout: Optional callable (e.g. ``nn.Dropout``) applied to the
                attention weights.

        Returns:
            Tuple ``(output, attn)``: the weighted values and the attention
            weights (after dropout, if given).
        """
        # d_k is the feature dimension (last axis) of the query. The previous
        # code read shape[0], i.e. the batch size, so the scaling factor
        # depended on how many examples were in the batch.
        d_k = query.size(-1)
        scores = torch.matmul(query, torch.transpose(key, -2, -1)) / (d_k ** 0.5)

        if mask is not None:
            # A large negative score drives the softmax weight to ~0.
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            attn = dropout(attn)

        output = torch.matmul(attn, value)

        return output, attn
    

**Multi-Head Attention**: GBERT specifies 4 attention heads
<br/> 
$$ MultiHead(Q,K,V) = Concat(head_1,..., head_h)W^O $$
where
$$ head_i = Attention(Q{W_i}^Q, K{W_i}^K, V{W_i}^V) $$

reference: 
1. https://arxiv.org/pdf/1706.03762.pdf

In [14]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, concatenate, project.

    Args:
        hidden_size: Model dimension; must be divisible by ``num_heads``.
        dropout_prob: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads. Defaults to 4, the value that
            was previously hard-coded.
    """

    def __init__(self, hidden_size, dropout_prob, num_heads=4):
        super().__init__()
        assert hidden_size % num_heads == 0, (
            "hidden_size must be divisible by num_heads"
        )

        self.d_k = hidden_size // num_heads  # per-head feature size
        self.h = num_heads

        # Three bias-free projections, used in order for query, key and value.
        self.linear_layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(3)]
        )
        self.output_linear = nn.Linear(hidden_size, hidden_size)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query, key, value: Tensors whose first dimension is the batch and
                whose last dimension is ``hidden_size``.
            mask: Optional mask forwarded unchanged to the attention module
                (assumed broadcastable to (batch, heads, seq, seq) - confirm
                against the caller).

        Returns:
            Tensor of shape (batch, seq, hidden_size).
        """
        batch_size = query.size(0)

        # Project, then split the feature axis into heads:
        # (batch, seq, hidden) -> (batch, heads, seq, d_k).
        query, key, value = [
            layer(tensor).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for layer, tensor in zip(self.linear_layers, (query, key, value))
        ]

        x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # Merge the heads back: (batch, heads, seq, d_k) -> (batch, seq, hidden).
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)

**Position-wise Feed Forward**: consists of two linear transformations with an activation step in between.
<br/>
<br/>
In the paper *Attention Is All You Need*, ReLU is used as the activation, but the G-BERT GitHub repository uses the Gaussian Error Linear Unit (GeLU) as the activation step instead. This is approximated as
$$ GeLU(x) = 0.5x(1+tanh[\sqrt{2/\pi}(x+0.044715x^3)]) $$
references: 
1. https://arxiv.org/pdf/1606.08415.pdf  
2. https://arxiv.org/pdf/1706.03762.pdf
3. http://nlp.seas.harvard.edu/annotated-transformer/

In [72]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network with a tanh-approximated GELU in between.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Feature size of the hidden layer.
        dropout_prob: Dropout probability applied after the activation.
    """

    def __init__(self, hidden_size, intermediate_size, dropout_prob):
        super().__init__()
        self.w_1 = nn.Linear(hidden_size, intermediate_size)
        self.w_2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Return ``w_2(dropout(gelu(w_1(x))))``."""
        hidden = self.w_1(x)
        # GELU(h) ~= 0.5 * h * (1 + tanh(sqrt(2/pi) * (h + 0.044715 * h^3)))
        inner = hidden + 0.044715 * torch.pow(hidden, 3)
        gate = 1 + torch.tanh(math.sqrt(2 / math.pi) * inner)
        activated = 0.5 * hidden * gate
        return self.w_2(self.dropout(activated))

**Transformer Block**: this acts as the encoder layer and uses the two sub-layers, multi-head attention and the position-wise feed-forward network (FFN), to encode the input.

reference:
1. http://nlp.seas.harvard.edu/annotated-transformer/

In [73]:
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention, then feed-forward, each with a residual.

    Args:
        hidden_size: Model dimension.
        intermediate_size: Hidden size of the feed-forward sub-layer.
        dropout_prob: Dropout probability shared by all sub-modules and by the
            final dropout on the block output.
    """

    def __init__(self, hidden_size, intermediate_size, dropout_prob):
        super().__init__()
        self.attention = MultiHeadedAttention(hidden_size, dropout_prob)
        self.feed_forward = PositionwiseFeedForward(hidden_size, intermediate_size, dropout_prob)
        self.input_sublayer = SublayerConnection(hidden_size, dropout_prob)
        self.output_sublayer = SublayerConnection(hidden_size, dropout_prob)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x, mask):
        """Encode ``x``; ``mask`` is passed through to the attention sub-layer."""

        def self_attention(y):
            # Call the module itself rather than ``.forward`` so that
            # nn.Module.__call__ runs any registered hooks.
            return self.attention(y, y, y, mask=mask)

        x = self.input_sublayer(x, self_attention)
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

**Bert Embeddings**: BERT embeddings usually include word (token) embeddings, token-type (segment) embeddings, and position embeddings. G-BERT does not include position embeddings because its inputs are medical codes, which do not have a specific order within a given visit.
<br/>
<br/>
references:
1. https://www.ijcai.org/proceedings/2019/0825.pdf
2. https://arxiv.org/pdf/1810.04805.pdf

In [74]:
class BertEmbeddings(nn.Module):
    """Sum of two learned embeddings followed by LayerNorm and dropout.

    No position embedding is used.

    NOTE(review): ``segment_embeddings`` is the table indexed by ``input_ids``
    (``vocab_size`` rows) and ``token_embeddings`` is the 2-row table indexed
    by ``token_ids``. The names look swapped relative to common BERT naming,
    but they are kept because they determine the state-dict keys.

    Args:
        vocab_size: Number of rows in the ``input_ids`` embedding table.
        hidden_size: Embedding dimension.
        dropout_prob: Dropout probability applied to the normalized sum.
    """

    def __init__(self, vocab_size, hidden_size, dropout_prob):
        super().__init__()
        self.segment_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.token_embeddings = nn.Embedding(2, hidden_size)

        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, input_ids, token_ids=None):
        """Embed ``input_ids``; ``token_ids`` defaults to all zeros when omitted."""
        type_ids = torch.zeros_like(input_ids) if token_ids is None else token_ids
        combined = self.segment_embeddings(input_ids) + self.token_embeddings(type_ids)
        return self.dropout(self.LayerNorm(combined))

**Bert Model**: this class brings it all together. In G-BERT, ontology embeddings are used, so this class can be initialized with either the regular Bert Embeddings or the ontology embeddings, depending on the `useGraph` input.

In [15]:
class BertModel(nn.Module):
    """Embedding layer followed by a stack of transformer encoder blocks.

    Args:
        vocab_size: Vocabulary size passed to the embedding layer.
        hidden_size: Model dimension of the transformer blocks.
        dropout_prob: Dropout probability used throughout.
        useGraph: If truthy, use ``FuseEmbeddings`` (ontology embeddings);
            otherwise use ``BertEmbeddings``.
        all_conditions: Forwarded to ``FuseEmbeddings`` when ``useGraph`` is set.
        all_drugs: Forwarded to ``FuseEmbeddings`` when ``useGraph`` is set.
        intermediate_size: Hidden size of each block's feed-forward layer.
            Defaults to 300, the value that was previously hard-coded.
        num_layers: Number of transformer blocks. Defaults to 2, the number
            that was previously hard-coded.
    """

    def __init__(self, vocab_size, hidden_size, dropout_prob, useGraph, all_conditions=None, all_drugs=None,
                 intermediate_size=300, num_layers=2):
        super(BertModel, self).__init__()
        if useGraph:
            # NOTE(review): assumes FuseEmbeddings produces vectors of size
            # hidden_size - confirm in OntologyEmbedding.
            self.embedding = FuseEmbeddings(all_conditions, all_drugs, vocab_size)
        else:
            self.embedding = BertEmbeddings(vocab_size, hidden_size, dropout_prob)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(hidden_size, intermediate_size, dropout_prob)
            for _ in range(num_layers)
        ])

        # Re-initialize every sub-module (including those inside the embedding).
        self.apply(self.init_bert_weights)

    def init_bert_weights(self, module):
        '''
        Initialize one sub-module; intended to be used with ``nn.Module.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02^2), Linear biases
        are zeroed, and LayerNorm is reset to weight 1 / bias 0.

        Taken from https://github.com/huggingface/transformers/blob/78b7debf56efb907c6af767882162050d4fbb294/src/transformers/modeling_utils.py#L1596
        '''
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, input_ids, token_ids):
        """Encode a batch of id sequences.

        Args:
            input_ids: 2-D integer tensor (batch, seq) of input ids.
            token_ids: Passed to the embedding layer together with ``input_ids``.

        Returns:
            Tuple ``(x, x[:, 0])``: the hidden states for every position and
            the hidden state of the first position.
        """
        # Positions with id <= 1 are masked out as attention keys (presumably
        # padding/special ids - confirm against the vocabulary). The result has
        # shape (batch, 1, seq, seq) so it broadcasts across heads.
        mask = (input_ids > 1).unsqueeze(1).repeat(1, input_ids.size(1), 1).unsqueeze(1)

        x = self.embedding(input_ids, token_ids)

        for transformer in self.transformer_blocks:
            # Call the module (not ``.forward``) so nn.Module hooks are honoured.
            x = transformer(x, mask)

        return x, x[:, 0]