In [1]:
import torch
import torch.nn as nn
import math
from transformers import BertTokenizer
from bertviz.transformers_neuron_view import BertModel, BertConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# model config and load
max_length = 256  # token budget used for tokenizer truncation further down
model_name = 'bert-base-uncased'
# NOTE: this standalone config is kept only so the `config` name stays defined.
# It is deliberately NOT handed to the model: the recorded `model.config` shows
# max_position_embeddings=512 and output_hidden_states=false, i.e. these
# settings never reached the loaded model, and later cells rely on that layout.
config = BertConfig.from_pretrained(model_name, output_attentions=True,
                                    output_hidden_states=True,
                                    return_dict=True)
tokenizer = BertTokenizer.from_pretrained(model_name)
config.max_position_embeddings = max_length
# Load through the class: `BertModel(config).from_pretrained(...)` first built a
# randomly initialised model from `config` and then threw it away.
# output_attentions=True is requested explicitly because later cells read the
# per-layer 'attn' / 'queries' / 'keys' entries from the model output.
model = BertModel.from_pretrained(model_name, output_attentions=True)
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (den

In [4]:
# Show the configuration attached to the loaded model (not the standalone `config` object).
model.config

{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": true,
  "output_hidden_states": false,
  "pad_token_id": 0,
  "torchscript": false,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [5]:
# Width of one attention head: the hidden size split evenly across the heads.
attn_head_size = model.config.hidden_size // model.config.num_attention_heads
attn_head_size

64

In [6]:
# Display the module tree of the encoder sub-module.
model.encoder

BertEncoder(
  (layer): ModuleList(
    (0-11): 12 x BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): BertSelfOutput(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (LayerNorm): BertLayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (intermediate): BertIntermediate(
        (dense): Linear(in_features=768, out_features=3072, bias=True)
      )
      (output): BertOutput(
        (dense): Linear(in_features=3072, out_features=768, bias=True)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
)

In [3]:
# data
from sklearn.datasets import fetch_20newsgroups

# Fetch the training split, then tokenize only its first document as a
# one-item batch of PyTorch tensors, truncated to `max_length` tokens.
news_data = fetch_20newsgroups(subset='train')
input_data = tokenizer(
    news_data['data'][:1],
    max_length=max_length,
    truncation=True,
    padding=True,
    return_tensors='pt',
)

In [8]:
# Show which tensors the tokenizer returned and the shape of the token ids.
print(input_data.keys(), input_data['input_ids'].shape, sep='\n')

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
torch.Size([1, 201])


In [4]:
# model output
# Inference only: disable autograd so no graph is kept for the forward pass.
# Tensors in `outputs` therefore carry no grad_fn; their values are unchanged.
with torch.no_grad():
    outputs = model(**input_data)

In [None]:
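# outputs[0]: shape [1, 201, 768]
# outputs[1]: shape [1, 768]
# outputs[2]: sequence of 12 dicts (presumably one per encoder layer) with keys 'attn', 'queries', 'keys'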

In [12]:
# Print the name and shape of every tensor stored in each entry of outputs[2].
for layer_record in outputs[2]:
    for tensor_name, tensor in layer_record.items():
        print(tensor_name, tensor.shape)

attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries torch.Size([1, 12, 201, 64])
keys torch.Size([1, 12, 201, 64])
attn torch.Size([1, 12, 201, 201])
queries tor

In [10]:
# List the keys of the first dict in the last element of `outputs`.
outputs[-1][0].keys()

dict_keys(['attn', 'queries', 'keys'])

In [11]:
# Shape of the 'attn' tensor held in the first dict of the last element of `outputs`.
outputs[-1][0]['attn'].shape

torch.Size([1, 12, 201, 201])

In [12]:
# Run only the embedding sub-module; its output is the input to the
# hand-written attention computations in the following cells.
emd_output = model.embeddings(
    input_data['input_ids'],
    token_type_ids=input_data['token_type_ids'],
)
emd_output.shape

torch.Size([1, 201, 768])

In [14]:
# Hand-computed queries for layer 0, head 0: project the embeddings of the
# first batch item with the first `attn_head_size` output units of the
# layer-0 query Linear (weight rows and matching bias entries).
q_linear = model.encoder.layer[0].attention.self.query
Q_first_head_first_layer = (
    emd_output[0] @ q_linear.weight[:attn_head_size].T
    + q_linear.bias[:attn_head_size]
)
Q_first_head_first_layer

tensor([[ 0.7090, -0.1532, -0.0324,  ..., -0.1861, -1.1897, -0.3917],
        [ 0.9676,  0.0748,  0.2025,  ...,  0.7521,  0.4138,  0.1224],
        [ 0.8280, -0.0809,  0.5322,  ...,  0.1444,  0.3582, -0.0987],
        ...,
        [ 0.9873, -0.7626,  0.3701,  ..., -0.7065,  0.7049, -0.3984],
        [ 1.0845, -0.9601,  0.4284,  ..., -0.8086,  0.8632, -0.7072],
        [ 0.3943, -0.4631, -0.3239,  ..., -0.1489, -0.8662, -0.1141]],
       grad_fn=<AddBackward0>)

In [15]:
# Hand-computed keys for layer 0, head 0: same recipe as the queries, using the
# first `attn_head_size` output units of the layer-0 key Linear.
k_linear = model.encoder.layer[0].attention.self.key
K_first_head_first_layer = (
    emd_output[0] @ k_linear.weight[:attn_head_size].T
    + k_linear.bias[:attn_head_size]
)
K_first_head_first_layer

tensor([[ 1.2797,  0.2204,  0.2408,  ...,  0.5843, -0.4191,  0.9194],
        [-0.5851, -0.5868, -0.1607,  ...,  1.0473,  0.0725, -0.3103],
        [-0.5147, -0.5780,  0.0156,  ...,  0.4617,  0.3585, -0.3786],
        ...,
        [-0.6989,  1.2598,  0.1314,  ...,  0.0398,  0.6206, -1.6981],
        [-1.0770,  1.2705, -0.0528,  ..., -0.2308,  0.5101, -1.7183],
        [-1.3563,  1.7390, -0.0489,  ..., -0.6310,  0.2154,  0.9688]],
       grad_fn=<AddBackward0>)

In [18]:
# Scaled dot-product attention for layer 0, head 0: Q·Kᵀ divided by
# sqrt(head size), then a softmax across the last (key) dimension.
# No attention mask is applied in this hand-written version.
scaled_logits = (Q_first_head_first_layer @ K_first_head_first_layer.T) / math.sqrt(attn_head_size)
attn_scores = torch.softmax(scaled_logits, dim=-1)
attn_scores

tensor([[0.0053, 0.0109, 0.0052,  ..., 0.0039, 0.0036, 0.0144],
        [0.0086, 0.0041, 0.0125,  ..., 0.0045, 0.0041, 0.0071],
        [0.0051, 0.0043, 0.0046,  ..., 0.0043, 0.0045, 0.0031],
        ...,
        [0.0010, 0.0023, 0.0055,  ..., 0.0012, 0.0018, 0.0011],
        [0.0010, 0.0023, 0.0057,  ..., 0.0012, 0.0017, 0.0007],
        [0.0022, 0.0056, 0.0063,  ..., 0.0045, 0.0048, 0.0015]],
       grad_fn=<SoftmaxBackward0>)

In [19]:
# Attention matrix reported by the model: first dict of outputs[-1], batch item 0, head 0.
# Displayed for a by-eye comparison with the hand-computed `attn_scores`.
outputs[-1][0]['attn'][0, 0, :, :]

tensor([[0.0053, 0.0109, 0.0052,  ..., 0.0039, 0.0036, 0.0144],
        [0.0086, 0.0041, 0.0125,  ..., 0.0045, 0.0041, 0.0071],
        [0.0051, 0.0043, 0.0046,  ..., 0.0043, 0.0045, 0.0031],
        ...,
        [0.0010, 0.0023, 0.0055,  ..., 0.0012, 0.0018, 0.0011],
        [0.0010, 0.0023, 0.0057,  ..., 0.0012, 0.0017, 0.0007],
        [0.0022, 0.0056, 0.0063,  ..., 0.0045, 0.0048, 0.0015]],
       grad_fn=<SliceBackward0>)

In [21]:
# Hand-computed values for layer 0, head 0, using the first `attn_head_size`
# output units of the layer-0 value Linear.
v_linear = model.encoder.layer[0].attention.self.value
V_first_head_first_layer = (
    emd_output[0] @ v_linear.weight[:attn_head_size].T
    + v_linear.bias[:attn_head_size]
)
# Weight the value vectors by the attention weights to get this head's output.
attn_emb = attn_scores @ V_first_head_first_layer
attn_emb.shape

torch.Size([201, 64])