In [1]:
import transformers as t
import numpy as np
import torch

In [2]:
from bertviz.transformers_neuron_view import BertModel

# Checkpoint name and example sentence reused by every later cell.
model_ckpt = "bert-base-uncased"
text = "Infinity war is a great movie"

# The tokenizer comes from transformers; BertModel is bertviz's own
# implementation, imported above for use with bertviz's neuron view.
tokenizer = t.AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)

In [3]:
from bertviz.neuron_view import show

# Neuron view of layer 0, head 8 for `text`, using the light theme.
show(
    model,
    "bert",
    tokenizer,
    text,
    display_mode="light",
    layer=0,
    head=8,
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [4]:
# Encode without special tokens so only the sentence's own tokens appear
# (no [CLS]/[SEP] ids in the output below).
inputs = tokenizer(
    text,
    add_special_tokens=False,
    return_tensors="pt",
)
inputs

{'input_ids': tensor([[15579,  2162,  2003,  1037,  2307,  3185]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [5]:
# Seed torch's RNG before the first randomly initialised layer is built, so the
# embedding values (and every score/weight derived from them below) are
# reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

config = t.AutoConfig.from_pretrained(model_ckpt)
# A freshly initialised (untrained) token-embedding table sized from the
# checkpoint's config; it is NOT loaded from the pretrained weights.
token_emb = torch.nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

In [6]:
# One hidden_size-dim vector per token id -> (batch, seq_len, hidden_size).
inputs_embeds = token_emb(inputs["input_ids"])
inputs_embeds.shape

torch.Size([1, 6, 768])

In [7]:
import math

# Simplest possible self-attention: no learned projections, so queries, keys
# and values are all the raw token embeddings.
query = key = value = inputs_embeds
# Scaled dot-product scores, one row per query token: (batch, seq_len, seq_len).
scores = (
    torch.bmm(query, key.transpose(1, 2))
    / math.sqrt(key.size(-1))
)
scores.shape

torch.Size([1, 6, 6])

In [8]:
# Raw scores: the diagonal (each token against itself) dominates because query
# and key are the very same tensor here.
scores

tensor([[[25.9591,  0.4989,  0.8043, -0.4755,  0.9081, -2.8453],
         [ 0.4989, 26.3869, -1.1003,  0.6698, -1.1015, -1.5000],
         [ 0.8043, -1.1003, 26.6376, -1.3577, -1.2122,  1.3938],
         [-0.4755,  0.6698, -1.3577, 27.2753,  0.1375,  0.1525],
         [ 0.9081, -1.1015, -1.2122,  0.1375, 28.0936, -0.5468],
         [-2.8453, -1.5000,  1.3938,  0.1525, -0.5468, 26.1624]]],
       grad_fn=<DivBackward0>)

In [9]:
from torch.nn import functional as F

# Normalise each row of scores into attention weights that sum to 1.
weights = scores.softmax(dim=-1)
weights

tensor([[[1.0000e+00, 8.7652e-12, 1.1896e-11, 3.3082e-12, 1.3198e-11,
          3.0932e-13],
         [5.7147e-12, 1.0000e+00, 1.1547e-12, 6.7798e-12, 1.1534e-12,
          7.7423e-13],
         [6.0359e-12, 8.9862e-13, 1.0000e+00, 6.9471e-13, 8.0352e-13,
          1.0883e-11],
         [8.8716e-13, 2.7886e-12, 3.6717e-13, 1.0000e+00, 1.6376e-12,
          1.6624e-12],
         [1.5614e-12, 2.0930e-13, 1.8736e-13, 7.2249e-13, 1.0000e+00,
          3.6446e-13],
         [2.5243e-13, 9.6909e-13, 1.7504e-11, 5.0590e-12, 2.5139e-12,
          1.0000e+00]]], grad_fn=<SoftmaxBackward0>)

In [10]:
# Each output row is the attention-weighted mix of the value vectors.
attn_outputs = torch.bmm(weights, value)
attn_outputs.shape

torch.Size([1, 6, 768])

In [12]:
class AttentionHead(torch.nn.Module):
    """A single self-attention head.

    Projects the hidden state into query/key/value spaces of size ``head_dim``
    and applies scaled dot-product attention.

    Args:
        embed_dim: size of the last dimension of the input hidden state.
        head_dim: size of the query/key/value projections (and of the output).
    """

    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = torch.nn.Linear(embed_dim, head_dim)
        self.k = torch.nn.Linear(embed_dim, head_dim)
        self.v = torch.nn.Linear(embed_dim, head_dim)

    def _scaled_dot_product_attention(self, query, key, value):
        """Return ``softmax(query @ key^T / sqrt(d_k)) @ value``.

        ``matmul`` plus swapping the last two axes of ``key`` (instead of
        ``bmm``/``transpose(1, 2)``) works for any number of leading batch
        dimensions, including unbatched (seq_len, head_dim) input.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(key.size(-1))
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, value)

    def forward(self, hidden_state):
        """Map (..., seq_len, embed_dim) -> (..., seq_len, head_dim)."""
        return self._scaled_dot_product_attention(
            self.q(hidden_state), self.k(hidden_state), self.v(hidden_state)
        )

In [13]:
class MultiHeadAttention(torch.nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead`` modules.

    Each of ``config.num_attention_heads`` heads projects to
    ``hidden_size // num_attention_heads`` dims; the head outputs are
    concatenated back to ``hidden_size`` and mixed by a final linear layer.

    Raises:
        ValueError: if ``config.hidden_size`` is not divisible by
            ``config.num_attention_heads`` (the concatenated heads would not
            match the input size of ``output_linear``).
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        # Fail fast here instead of with an opaque shape error in forward().
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"hidden_size ({embed_dim}) must be divisible by "
                f"num_attention_heads ({num_heads})"
            )
        head_dim = embed_dim // num_heads
        self.heads = torch.nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.output_linear = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        """Map (batch, seq_len, hidden_size) to a tensor of the same shape."""
        x = torch.cat([head(hidden_state) for head in self.heads], dim=-1)
        return self.output_linear(x)

In [14]:
# Run a freshly initialised multi-head attention over the token embeddings.
mutliheadattn = MultiHeadAttention(config=config)
attn_output = mutliheadattn(inputs_embeds)
attn_output.shape

torch.Size([1, 6, 768])

In [15]:
class FeedForward(torch.nn.Module):
    """Position-wise feed-forward sub-layer.

    Expands each position from ``hidden_size`` to ``intermediate_size``,
    applies GELU, projects back to ``hidden_size`` and applies dropout with
    probability ``config.hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.linear_1 = torch.nn.Linear(hidden, inner)
        self.linear_2 = torch.nn.Linear(inner, hidden)
        self.g = torch.nn.GELU()
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Apply linear -> GELU -> linear -> dropout to ``x``."""
        expanded = self.g(self.linear_1(x))
        return self.dropout(self.linear_2(expanded))

In [16]:
# The feed-forward sub-layer keeps the (batch, seq_len, hidden_size) shape.
feed_forward = FeedForward(config=config)
ff_outputs = feed_forward(attn_output)
ff_outputs.shape

torch.Size([1, 6, 768])

In [17]:
class TransformerEncoderLayer(torch.nn.Module):
    """One pre-layer-norm encoder layer.

    Each sub-layer (multi-head attention, then feed-forward) normalises its
    input first and adds its output back onto the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.layer_norm_1 = torch.nn.LayerNorm(width)
        self.layer_norm_2 = torch.nn.LayerNorm(width)
        # Attribute name (including its spelling) kept so state_dict keys
        # stay unchanged.
        self.mutliheadattn = MultiHeadAttention(config=config)
        self.feed_forward = FeedForward(config=config)

    def forward(self, x):
        """Run attention then feed-forward, each as ``x + sublayer(norm(x))``."""
        attn_in = self.layer_norm_1(x)
        residual = x + self.mutliheadattn(attn_in)
        ff_in = self.layer_norm_2(residual)
        return residual + self.feed_forward(ff_in)

In [18]:
# One full encoder layer (attention + feed-forward with residuals).
encoder_layer = TransformerEncoderLayer(config=config)
encoder_layer(inputs_embeds).shape

torch.Size([1, 6, 768])

In [19]:
class Embeddings(torch.nn.Module):
    """Token + learned absolute position embeddings, then LayerNorm and dropout.

    Args:
        config: must provide ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings`` and ``hidden_dropout_prob``;
            ``layer_norm_eps`` is used when present (default 1e-12).
    """

    def __init__(self, config):
        super().__init__()
        self.token_embeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings,
                                                      config.hidden_size)
        self.layer_norm = torch.nn.LayerNorm(
            config.hidden_size, eps=getattr(config, "layer_norm_eps", 1e-12)
        )
        # Use the configured dropout rate (as FeedForward does) rather than
        # nn.Dropout's default p=0.5.
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Embed ``input_ids`` (batch, seq_len) -> (batch, seq_len, hidden_size).

        Raises:
            ValueError: if seq_len exceeds ``max_position_embeddings``.
        """
        seq_length = input_ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_position_embeddings "
                f"({max_positions})"
            )
        # Create position ids on the same device as the inputs so the module
        # also works after being moved off the CPU.
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)
        embeddings = self.token_embeddings(input_ids) + self.position_embeddings(position_ids)
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)

In [20]:
# Token + position embeddings computed directly from the input ids.
embedding_layer = Embeddings(config=config)
embedding_layer(inputs["input_ids"]).shape

torch.Size([1, 6, 768])

In [21]:
class TransformerEncoder(torch.nn.Module):
    """An ``Embeddings`` layer followed by ``config.num_hidden_layers`` encoder layers."""

    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config=config)
        stack = [TransformerEncoderLayer(config=config)
                 for _ in range(config.num_hidden_layers)]
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, x):
        """Embed token ids ``x`` and pass the result through every layer in order."""
        hidden = self.embeddings(x)
        for layer_module in self.layers:
            hidden = layer_module(hidden)
        return hidden

In [22]:
# Full encoder: embeddings followed by config.num_hidden_layers encoder layers.
encoder = TransformerEncoder(config=config)
encoder(inputs["input_ids"]).shape

torch.Size([1, 6, 768])