
In [1]:
from transformers import AutoTokenizer
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show

Since BertViz needs to tap into the attention layers of the model, we instantiate our BERT checkpoint with the model class from BertViz and then use the show() function to generate the interactive visualization.

In [2]:
model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)




In [4]:
# text = "time flies like an arrow"
text = 'I had an enjoyable two weeks in London'
show(model, "bert", tokenizer, text, display_mode="light", layer=0, head=8)


In the visualization, the values of the query and key vectors are represented as vertical bands, where the intensity of each band corresponds to the magnitude. The connecting lines are weighted according to the attention between the tokens.

# Implement Encoder

In [5]:
text = "time flies like an arrow"


In [6]:
# add_special_tokens=False excludes the [CLS] and [SEP] tokens
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
inputs.input_ids

tensor([[ 2051, 10029,  2066,  2019,  8612]])

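Create an embedding layer with the vocabulary size and hidden size taken from the bert-base-uncased config. Note that nn.Embedding is randomly initialized here; the pretrained embedding weights are not loaded.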

In [7]:
from torch import nn
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_ckpt)
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

Pass the input IDs through this embedding layer to get one vector per token

In [9]:
inputs_embeds = token_emb(inputs.input_ids)
inputs_embeds.size() # [batch_size, seq_len, hidden_dim]

torch.Size([1, 5, 768])
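
As a small illustration of what the lookup above does, here is a minimal sketch with hypothetical toy sizes (a vocabulary of 10 and a hidden size of 4, not the BERT values). It shows that calling an nn.Embedding on token IDs is equivalent to indexing rows of its weight matrix, so a repeated ID always gets the same vector regardless of its position or context.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only to keep the example small
toy_emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
toy_ids = torch.tensor([[2, 7, 7, 0]])  # [batch_size, seq_len]

looked_up = toy_emb(toy_ids)       # [1, 4, 4], one row per token ID
indexed = toy_emb.weight[toy_ids]  # same rows via plain tensor indexing

print(looked_up.shape)
print(torch.equal(looked_up, indexed))
# The two occurrences of ID 7 map to the same vector
print(torch.equal(looked_up[0, 1], looked_up[0, 2]))
```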

Compute the attention scores as Q @ K.T / sqrt(dim_k), using the embeddings themselves as query, key, and value

In [11]:
import torch
from math import sqrt

query = key = value = inputs_embeds
dim_k = key.size(-1)

In [12]:
query.shape, key.transpose(1,2).shape

(torch.Size([1, 5, 768]), torch.Size([1, 768, 5]))

In [13]:
scores = torch.bmm(query, key.transpose(1,2)) / sqrt(dim_k)
scores.size()

torch.Size([1, 5, 5])
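
The notebook does not say why the scores are divided by sqrt(dim_k). The usual motivation is that the dot product of two vectors with roughly unit-variance components has a variance that grows with the dimension, which would push softmax toward very peaked outputs. The sketch below checks this empirically; the random vectors and the sample count n_pairs are assumptions for illustration and are not the notebook's embeddings.

```python
import torch
from math import sqrt

torch.manual_seed(0)

dim_k = 768     # same hidden size as above
n_pairs = 2000  # hypothetical number of random query/key pairs

q = torch.randn(n_pairs, dim_k)
k = torch.randn(n_pairs, dim_k)

raw = (q * k).sum(dim=-1)   # unscaled dot products, one per pair
scaled = raw / sqrt(dim_k)  # scaled as in the attention scores

# Variance of the raw dot products is on the order of dim_k;
# after scaling it is on the order of 1.
print(raw.var().item(), scaled.var().item())
```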

Apply softmax along the last dimension so that each row of attention weights sums to 1

In [14]:
import torch.nn.functional as F

weights = F.softmax(scores, dim=-1)
weights.sum(dim=-1)

tensor([[1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

Multiply the attention weights by the value matrix

In [16]:
attn_outputs = torch.bmm(weights, value)
attn_outputs.shape

torch.Size([1, 5, 768])

Put these steps into a function

In [17]:
def scaled_dot_product_attention(query, key, value):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)
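
A quick usage sketch for the function above, repeated here so the cell runs on its own. The inputs are random tensors standing in for inputs_embeds (an assumption, to avoid downloading the checkpoint). The second call uses a hypothetical query of a different length to show that the output length follows the query, while key and value must share a length.

```python
import torch
import torch.nn.functional as F
from math import sqrt

def scaled_dot_product_attention(query, key, value):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)

torch.manual_seed(0)

# Stand-in for inputs_embeds: [batch_size, seq_len, hidden_dim]
x = torch.randn(1, 5, 768)
out = scaled_dot_product_attention(x, x, x)
print(out.shape)  # torch.Size([1, 5, 768])

# Hypothetical query with 3 tokens attending over the 5 key/value tokens
q = torch.randn(1, 3, 768)
out_cross = scaled_dot_product_attention(q, x, x)
print(out_cross.shape)  # torch.Size([1, 3, 768])
```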