In [22]:
from transformers import AutoTokenizer
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show

In [23]:
# Checkpoint and example sentence shared by the rest of the notebook.
model_ckpt = "bert-base-uncased"
text = "time flies like an arrow"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)

# Render bertviz's neuron view for head 8 of layer 0 on the example sentence.
show(
    model,
    "bert",
    tokenizer,
    text,
    display_mode="light",
    layer=0,
    head=8,
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [3]:
# Tokenize without the special tokens so the IDs correspond only to the words of `text`.
inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt")
inputs["input_ids"]

tensor([[ 2051, 10029,  2066,  2019,  8612]])

In [4]:
from torch import nn
from transformers import AutoConfig

In [5]:
import torch  # not yet imported at this point in the notebook

# Seed the global RNG so the randomly initialized embedding table (and every
# tensor printed further down) is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

config = AutoConfig.from_pretrained(model_ckpt)
# Fresh (untrained) lookup table sized to match the checkpoint's vocabulary and hidden size.
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

In [6]:
# Look up the (randomly initialized) embedding vector for each token ID.
inputs_embeds = token_emb(inputs["input_ids"])
inputs_embeds.shape

torch.Size([1, 5, 768])

In [8]:
import torch
from math import sqrt


tensor([[[-0.0897,  0.6301,  1.2020,  ..., -1.1901, -0.1277, -0.4695],
         [ 1.6142, -0.3258, -0.3409,  ..., -0.3297, -0.0606, -1.0925],
         [ 0.0075,  0.2705,  0.3665,  ..., -0.6179, -0.5777, -1.3844],
         [-1.6967, -1.0106, -0.1175,  ..., -0.6843,  0.2953, -1.3628],
         [-0.0459, -0.8149, -0.6277,  ...,  0.9132,  0.5041,  0.6363]]],
       grad_fn=<EmbeddingBackward0>)

In [9]:
# Simplest form of self-attention: queries, keys and values are all the raw embeddings.
query = key = value = inputs_embeds
dim_k = key.shape[-1]
# Pairwise dot-product similarity between tokens, scaled by sqrt(d_k).
scores = torch.bmm(query, key.transpose(-2, -1)) / sqrt(dim_k)
scores.shape

torch.Size([1, 5, 5])

In [18]:
import torch.nn.functional as F
# Turn each row of scores into attention weights; every row should sum to 1.
weights = scores.softmax(dim=-1)
weights.sum(-1)

tensor([[1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

In [19]:
# Each output vector is the attention-weighted combination of the value vectors.
attn_outputs = weights.bmm(value)
attn_outputs.size()

torch.Size([1, 5, 768])

In [20]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Works for any number of leading batch dimensions (e.g. ``(batch, seq, dim)``
    or ``(batch, heads, seq, dim)``), since it multiplies over the last two axes.

    Args:
        query: tensor of shape ``(..., seq_q, d_k)``.
        key: tensor of shape ``(..., seq_k, d_k)``.
        value: tensor of shape ``(..., seq_k, d_v)``.
        mask: optional tensor broadcastable to ``(..., seq_q, seq_k)``. Positions
            where ``mask == 0`` receive zero attention weight. A row that is
            fully masked yields NaN, because softmax over all ``-inf`` is undefined.

    Returns:
        Tensor of shape ``(..., seq_q, d_v)``.
    """
    # Scale by the key dimension, matching the step-by-step computation above.
    dim_k = key.size(-1)
    # transpose(-2, -1) + matmul generalizes bmm/transpose(1, 2) beyond 3-D inputs.
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    if mask is not None:
        # -inf becomes exactly 0 after softmax, removing masked positions.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [21]:
# Apply the reusable helper to the same embeddings used in the manual walkthrough.
# Because query == key, each token's dot product with itself dominates its row,
# which is why the displayed result matches inputs_embeds shown earlier.
query = key = value = inputs_embeds
attention = scaled_dot_product_attention(query=query, key=key, value=value)
attention

tensor([[[-0.0897,  0.6301,  1.2020,  ..., -1.1901, -0.1277, -0.4695],
         [ 1.6142, -0.3258, -0.3409,  ..., -0.3297, -0.0606, -1.0925],
         [ 0.0075,  0.2705,  0.3665,  ..., -0.6179, -0.5777, -1.3844],
         [-1.6967, -1.0106, -0.1175,  ..., -0.6843,  0.2953, -1.3628],
         [-0.0459, -0.8149, -0.6277,  ...,  0.9132,  0.5041,  0.6363]]],
       grad_fn=<BmmBackward0>)