In [1]:
import torch

print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())

True
True


In [2]:
from transformers import AutoTokenizer
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show

model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)
text = "time files like an arrow"
show(model, "bert", tokenizer, text, display_mode="light", layer=0, head=8)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [5]:
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
print(inputs.input_ids)

tensor([[2051, 6764, 2066, 2019, 8612]])


In [7]:
from torch import nn
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_ckpt)
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
print(config)
print(token_emb)

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.27.4",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

Embedding(30522, 768)


In [10]:
inputs_embeds = token_emb(inputs.input_ids)
print(inputs_embeds)
print(inputs_embeds.size())

tensor([[[ 1.2226, -0.9145,  0.7991,  ...,  1.0869,  1.6828,  0.8517],
         [ 0.1741,  0.1626,  0.2582,  ..., -0.7076, -0.1016,  0.3266],
         [ 0.7566,  0.1454, -1.0454,  ...,  1.0995, -0.0073,  0.4968],
         [ 0.0120,  0.5308, -0.0789,  ..., -0.1254,  0.2728, -1.0377],
         [ 0.5253,  0.3071, -0.7226,  ..., -0.5924,  0.3080, -0.7447]]],
       grad_fn=<EmbeddingBackward0>)
torch.Size([1, 5, 768])


In [13]:
import torch
from math import sqrt

query = key = value = inputs_embeds
dim_k = key.size(-1)
scores = torch.bmm(query, key.transpose(1,2)) / sqrt(dim_k)
print(scores)
print(scores.size())

tensor([[[29.4782, -0.1990, -1.0939, -1.1332,  1.0315],
         [-0.1990, 26.0226, -0.2681, -0.3369,  1.0151],
         [-1.0939, -0.2681, 29.1612,  0.0959, -0.8135],
         [-1.1332, -0.3369,  0.0959, 27.1834,  0.8669],
         [ 1.0315,  1.0151, -0.8135,  0.8669, 26.3061]]],
       grad_fn=<DivBackward0>)
torch.Size([1, 5, 5])


In [16]:
import torch.nn.functional as F

weights = F.softmax(scores, dim = -1)
print(weights)
print(weights.sum(dim=-1))

tensor([[[1.0000e+00, 1.2922e-13, 5.2806e-14, 5.0774e-14, 4.4234e-13],
         [4.0935e-12, 1.0000e+00, 3.8202e-12, 3.5661e-12, 1.3784e-11],
         [7.2501e-14, 1.6557e-13, 1.0000e+00, 2.3828e-13, 9.5973e-14],
         [5.0383e-13, 1.1171e-12, 1.7221e-12, 1.0000e+00, 3.7231e-12],
         [1.0553e-11, 1.0381e-11, 1.6677e-12, 8.9513e-12, 1.0000e+00]]],
       grad_fn=<SoftmaxBackward0>)
tensor([[1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)


In [17]:
attn_outputs = torch.bmm(weights, value)
print(attn_outputs)
print(attn_outputs.shape)

tensor([[[ 1.2226, -0.9145,  0.7991,  ...,  1.0869,  1.6828,  0.8517],
         [ 0.1741,  0.1626,  0.2582,  ..., -0.7076, -0.1016,  0.3266],
         [ 0.7566,  0.1454, -1.0454,  ...,  1.0995, -0.0073,  0.4968],
         [ 0.0120,  0.5308, -0.0789,  ..., -0.1254,  0.2728, -1.0377],
         [ 0.5253,  0.3071, -0.7226,  ..., -0.5924,  0.3080, -0.7447]]],
       grad_fn=<BmmBackward0>)
torch.Size([1, 5, 768])
