# BertEncoder Self Attention

$$
Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V
$$

In [102]:
import torch
from torch import nn
import math
from transformers import BertTokenizer, BertModel, BertConfig

## model config and load

In [103]:
model_name = 'bert-base-uncased'
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True, return_dict=True)
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name, config=config)
model = model.eval()

In [104]:
model.config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_attentions": true,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.33.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [105]:
att_head_size = model.config.hidden_size / model.config.num_attention_heads
print(att_head_size)
att_head_size = int(att_head_size)

64.0


In [106]:
model.encoder.layer[0].attention.self.query.weight.T[:, 64:128]

tensor([[-0.0112, -0.0324, -0.0615,  ..., -0.0383,  0.0031,  0.0059],
        [ 0.0260, -0.0067, -0.0616,  ...,  0.1097,  0.0029, -0.0540],
        [-0.0169,  0.0232,  0.0068,  ...,  0.0124, -0.0168,  0.0301],
        ...,
        [ 0.1083,  0.0056,  0.0968,  ...,  0.0188, -0.0171,  0.0141],
        [-0.0436, -0.1032, -0.1035,  ...,  0.0138, -0.0488, -0.0453],
        [-0.0611,  0.0224, -0.0320,  ...,  0.0376,  0.0186, -0.0482]],
       grad_fn=<SliceBackward0>)

## data

In [107]:
from sklearn.datasets import fetch_20newsgroups
newsgroups_train = fetch_20newsgroups(subset='train')
inputs_tests = tokenizer(newsgroups_train['data'][:1], 
                         truncation=True, padding=True, return_tensors='pt')
inputs_tests

{'input_ids': tensor([[  101,  2013,  1024,  3393,  2099,  2595,  3367,  1030, 11333,  2213,
          1012,  8529,  2094,  1012,  3968,  2226,  1006,  2073,  1005,  1055,
          2026,  2518,  1007,  3395,  1024,  2054,  2482,  2003,  2023,   999,
          1029,  1050,  3372,  2361,  1011, 14739,  1011,  3677,  1024, 10958,
          2278,  2509,  1012, 11333,  2213,  1012,  8529,  2094,  1012,  3968,
          2226,  3029,  1024,  2118,  1997,  5374,  1010,  2267,  2380,  3210,
          1024,  2321,  1045,  2001,  6603,  2065,  3087,  2041,  2045,  2071,
          4372,  7138,  2368,  2033,  2006,  2023,  2482,  1045,  2387,  1996,
          2060,  2154,  1012,  2009,  2001,  1037,  1016,  1011,  2341,  2998,
          2482,  1010,  2246,  2000,  2022,  2013,  1996,  2397, 20341,  1013,
          2220, 17549,  1012,  2009,  2001,  2170,  1037,  5318,  4115,  1012,
          1996,  4303,  2020,  2428,  2235,  1012,  1999,  2804,  1010,  1996,
          2392, 21519,  2001,  3584,  

In [108]:
inputs_tests['input_ids'].shape

torch.Size([1, 201])

## model output

In [109]:
model_output = model(**inputs_tests)
model_output.keys()

odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states', 'attentions'])

In [110]:
model_output['attentions'][0].shape
# 12 is the number of attention heads

torch.Size([1, 12, 201, 201])

In [111]:
model_output['attentions'][0][0, 0, :, :].sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 

## from scratch

In [112]:
emb_output = model.embeddings(inputs_tests['input_ids'], inputs_tests['token_type_ids'])
emb_output.shape

torch.Size([1, 201, 768])

In [113]:
model.encoder.layer[0]

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [114]:
# emb_output[0].shape: 201*768
# query.weight.T.shape: 768*768, query.weight.T[:, :att_head_size]: 768*64
Q_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.query.weight.T[:, :att_head_size] \
                            + model.encoder.layer[0].attention.self.query.bias[:att_head_size]
# 201*64
Q_first_head_first_layer.shape

torch.Size([201, 64])

In [115]:
# 201*64
K_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.key.weight.T[:, :att_head_size] \
                            + model.encoder.layer[0].attention.self.key.bias[:att_head_size]
K_first_head_first_layer.shape

torch.Size([201, 64])

In [117]:
# (201*64)*(64*201) ==> 201*201
attn_scores = torch.nn.Softmax(dim=-1)(Q_first_head_first_layer @ K_first_head_first_layer.T / math.sqrt(att_head_size))
attn_scores.shape

torch.Size([201, 201])

In [118]:
attn_scores

tensor([[0.0053, 0.0109, 0.0052,  ..., 0.0039, 0.0036, 0.0144],
        [0.0086, 0.0041, 0.0125,  ..., 0.0045, 0.0041, 0.0071],
        [0.0051, 0.0043, 0.0046,  ..., 0.0043, 0.0045, 0.0031],
        ...,
        [0.0010, 0.0023, 0.0055,  ..., 0.0012, 0.0018, 0.0011],
        [0.0010, 0.0023, 0.0057,  ..., 0.0012, 0.0017, 0.0007],
        [0.0022, 0.0056, 0.0063,  ..., 0.0045, 0.0048, 0.0015]],
       grad_fn=<SoftmaxBackward0>)

In [119]:
attn_scores.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 

In [121]:
V_first_head_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.value.weight.T[:, :att_head_size] \
                            + model.encoder.layer[0].attention.self.value.bias[:att_head_size]
V_first_head_first_layer.shape

torch.Size([201, 64])

In [122]:
attn_emb = attn_scores @ V_first_head_first_layer
attn_emb.shape

torch.Size([201, 64])

## all

In [123]:
Q_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.query.weight.T \
                            + model.encoder.layer[0].attention.self.query.bias
K_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.key.weight.T \
                            + model.encoder.layer[0].attention.self.key.bias
V_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.value.weight.T \
                            + model.encoder.layer[0].attention.self.value.bias
Q_first_layer.shape, K_first_layer.shape, V_first_layer.shape

(torch.Size([201, 768]), torch.Size([201, 768]), torch.Size([201, 768]))

In [124]:
scores = torch.nn.Softmax(dim=-1)(Q_first_layer @ K_first_layer.T / math.sqrt(64))

In [125]:
scores.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 

In [126]:
scores @ V_first_layer[:, :64]

tensor([[ 2.3786e+00, -9.7945e-02, -2.8436e-01,  ...,  9.6968e-02,
         -1.8537e-01,  2.1132e-01],
        [-6.2984e-01,  1.1037e+00, -2.0763e-04,  ..., -3.1767e-01,
          7.2197e-02, -1.0126e-01],
        [ 3.7932e-02,  2.9540e-01, -3.6121e-01,  ...,  1.6694e-01,
         -9.9637e-02, -2.0356e-01],
        ...,
        [-7.0623e-01, -3.0113e-01,  9.4959e-02,  ..., -1.7217e-01,
          1.8647e-01, -4.4743e-01],
        [ 6.2045e-02, -4.0626e-02,  1.6757e-01,  ...,  2.2690e-01,
          1.4684e-01,  1.3029e-01],
        [ 2.3748e+00, -9.7801e-02, -2.8357e-01,  ...,  9.7254e-02,
         -1.8481e-01,  2.1127e-01]], grad_fn=<MmBackward0>)