In [None]:
!pip install transformers scikit-learn bertviz

In [2]:
import torch
from torch import nn
import math
from bertviz.transformers_neuron_view import BertModel, BertConfig
from transformers import BertTokenizer

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

# model config and load

In [None]:
max_length = 256
model_name = 'bert-base-uncased'
config = BertConfig.from_pretrained(model_name, output_attentions = True,
                                    output_hidden_states=True,
                                    return_dict=True)
tokenizer = BertTokenizer.from_pretrained(model_name)
config.max_position_embeddings = max_length

model = BertModel.from_pretrained(model_name).to(device)
model.eval()

In [5]:
model.config

{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": true,
  "output_hidden_states": false,
  "pad_token_id": 0,
  "torchscript": false,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [6]:
model.encoder.layer[0].attention.self.query

Linear(in_features=768, out_features=768, bias=True)

# data

In [7]:
from sklearn.datasets import fetch_20newsgroups

newsgroups_train = fetch_20newsgroups(subset='train')

inputs_tests = tokenizer(newsgroups_train['data'][:1], truncation=True,padding=True, max_length=max_length,
                         return_tensors='pt').to(device)

In [8]:
inputs_tests.keys()


dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

# model output

In [9]:
model_output = model(**inputs_tests)

In [10]:
len(model_output)


3

In [11]:
model_output[0]

tensor([[[ 0.0085, -0.2443,  0.0346,  ..., -0.2983,  0.2613,  0.4714],
         [-0.0798, -0.1472,  0.1433,  ...,  0.3298,  0.7872,  0.5355],
         [-0.7091,  0.1773,  0.4971,  ..., -0.5104,  0.2784,  0.5167],
         ...,
         [ 0.7488,  0.0424,  1.0689,  ...,  0.0818, -0.2160,  0.2466],
         [ 0.3391, -0.0718,  0.4097,  ...,  0.5932, -0.5949, -0.1425],
         [ 0.2651,  0.3685,  0.2441,  ...,  0.1301, -0.5060, -0.2754]]],
       grad_fn=<AddBackward0>)

In [12]:
model_output[1]

tensor([[-0.6022, -0.4841, -0.9692,  0.4348,  0.8144, -0.1743, -0.0952,  0.2833,
         -0.8750, -0.9999, -0.7018,  0.9676,  0.9616,  0.6243,  0.7338, -0.3455,
          0.3768, -0.5210,  0.1832,  0.9012,  0.5294,  1.0000, -0.1947,  0.3778,
          0.4275,  0.9923, -0.6664,  0.8672,  0.8796,  0.5362, -0.1150,  0.2813,
         -0.9835, -0.1881, -0.9788, -0.9836,  0.4402, -0.4884, -0.0377,  0.1147,
         -0.7458,  0.3006,  1.0000, -0.1123,  0.5225, -0.1872, -1.0000,  0.3246,
         -0.5525,  0.9579,  0.9226,  0.9629,  0.1030,  0.3061,  0.4857, -0.3409,
         -0.1420, -0.0370, -0.2785, -0.4145, -0.6839,  0.4322, -0.9142, -0.7302,
          0.9266,  0.9311, -0.2464, -0.3378, -0.0248, -0.0636,  0.4120,  0.2519,
         -0.5571, -0.8798,  0.7321,  0.3363, -0.7292,  1.0000,  0.0283, -0.9513,
          0.9648,  0.8054,  0.6048, -0.4400,  0.5410, -1.0000,  0.6302, -0.0897,
         -0.9770,  0.3008,  0.6549, -0.1141,  0.9319,  0.6333, -0.4074, -0.5968,
         -0.3185, -0.9316, -

In [14]:
model_output[2][-1]['attn'][0,0,:,:]

torch.Size([201, 201])

# 4 from scratch

In [15]:
emb_output = model.embeddings(inputs_tests['input_ids'], inputs_tests['token_type_ids'])

In [17]:
emb_output.shape

torch.Size([1, 201, 768])

In [18]:
model.encoder.layer[0]

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [19]:
emb_output[0].shape
att_head_size = 64

In [20]:
Q_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.query.weight.T[:,:att_head_size] + model.encoder.layer[0].attention.self.query.bias[:att_head_size]

In [21]:
k_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.key.weight.T[:,:att_head_size] + model.encoder.layer[0].attention.self.key.bias[:att_head_size]

In [22]:
attn_scores = torch.nn.Softmax(dim=-1)(torch.matmul(Q_first_head_first_layer, k_first_head_first_layer.transpose(-1, -2)) / math.sqrt(att_head_size))

In [23]:
attn_scores

tensor([[0.0053, 0.0109, 0.0052,  ..., 0.0039, 0.0036, 0.0144],
        [0.0086, 0.0041, 0.0125,  ..., 0.0045, 0.0041, 0.0071],
        [0.0051, 0.0043, 0.0046,  ..., 0.0043, 0.0045, 0.0031],
        ...,
        [0.0010, 0.0023, 0.0055,  ..., 0.0012, 0.0018, 0.0011],
        [0.0010, 0.0023, 0.0057,  ..., 0.0012, 0.0017, 0.0007],
        [0.0022, 0.0056, 0.0063,  ..., 0.0045, 0.0048, 0.0015]],
       grad_fn=<SoftmaxBackward0>)

In [24]:
V_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.value.weight.T[:,:att_head_size] + model.encoder.layer[0].attention.self.value.bias[:att_head_size]

In [25]:
attn_embed = torch.matmul(attn_scores, V_first_head_first_layer)

In [26]:
attn_embed

tensor([[-4.5640e-01,  4.6211e-02,  4.3913e-02,  ..., -2.0099e-02,
         -1.2756e-02,  6.4255e-03],
        [-4.5674e-01,  3.4322e-02,  3.2707e-02,  ..., -4.9206e-02,
          1.4976e-02, -3.0628e-02],
        [-4.9474e-01, -2.9540e-04, -7.5376e-04,  ..., -2.0035e-02,
          1.7146e-02, -3.0126e-02],
        ...,
        [-3.7991e-01,  5.2831e-02,  2.2534e-02,  ..., -1.8338e-02,
         -6.9509e-02,  2.1317e-02],
        [-3.8071e-01,  4.0900e-02,  2.8770e-02,  ..., -2.1192e-02,
         -5.2893e-02,  1.9734e-02],
        [-4.7131e-01,  1.0947e-01,  1.1631e-02,  ..., -3.4542e-02,
         -2.3753e-02, -5.0505e-03]], grad_fn=<MmBackward0>)