### Queries, Keys and Values

I will omit the batch dimension and assume that both the input and the output embeddings have shape (seq_length, d_e).

To understand the mechanism better, I will work through the math, using https://arxiv.org/abs/1706.03762 as a reference:

Notation:

* Values matrix: $V \in \mathbb{R}^{seq\_length \times d_{e}}$
* Single value vector: $v_i \in \mathbb{R}^{d_{e}}$
* Keys matrix: $K \in \mathbb{R}^{seq\_length \times d_{k}}$
* Single key vector: $k_i \in \mathbb{R}^{d_{k}}$
* Queries matrix: $Q \in \mathbb{R}^{seq\_length \times d_{q}}$, with $d_{q} = d_{k}$ so that the dot products $q_i \cdot k_j$ are defined
* Single query vector: $q_i \in \mathbb{R}^{d_{q}}$


$$V = \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_{seq\_length} \end{bmatrix}, \quad K = \begin{bmatrix} k_1 \\ k_2 \\ \vdots \\ k_{seq\_length} \end{bmatrix}, \quad Q = \begin{bmatrix} q_1 \\ q_2 \\ \vdots \\ q_{seq\_length} \end{bmatrix}$$

1. **Dot product between queries and keys**:

$$ \frac{Q \cdot K^T}{\sqrt{d_{k}}} = \frac{1}{\sqrt{d_{k}}} \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & \cdots & q_1 \cdot k_{seq\_length} \\ q_2 \cdot k_1 & q_2 \cdot k_2 & \cdots & q_2 \cdot k_{seq\_length} \\ \vdots & \vdots & \ddots & \vdots \\ q_{seq\_length} \cdot k_1 & q_{seq\_length} \cdot k_2 & \cdots & q_{seq\_length} \cdot k_{seq\_length} \end{bmatrix}$$

2. **Softmax**:

$$ W = \text{softmax} \left( \frac{Q \cdot K^T}{\sqrt{d_{k}}} \right) = \begin{bmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,seq\_length} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,seq\_length} \\ \vdots & \vdots & \ddots & \vdots \\ w_{seq\_length,1} & w_{seq\_length,2} & \cdots & w_{seq\_length,seq\_length} \end{bmatrix} $$

$$\text{where} \quad \sum_{j = 1}^{seq\_length} w_{i,j} = 1 \quad \forall i $$

3. **Weighted sum of values**:

$$ X^{'} =  \begin{bmatrix} x^{'}_1 \\ x^{'}_2 \\ \vdots \\ x^{'}_{seq\_length} \end{bmatrix} = W \cdot V = \begin{bmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,seq\_length} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,seq\_length} \\ \vdots & \vdots & \ddots & \vdots \\ w_{seq\_length,1} & w_{seq\_length,2} & \cdots & w_{seq\_length,seq\_length} \end{bmatrix} \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_{seq\_length} \end{bmatrix} $$

$$\text{so} \quad x^{'}_i = \sum_{j = 1}^{seq\_length} w_{i,j} \cdot v_j $$
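
As a quick sanity check, the three steps above can be sketched on small random matrices. This is only an illustration: the sizes (`seq_length = 3`, `d_k = 4`, `d_e = 2`) and the random inputs are arbitrary assumptions, not values used elsewhere in this notebook.

```python
import math
import torch

torch.manual_seed(0)
seq_length, d_k, d_e = 3, 4, 2  # illustrative sizes (assumption)

Q = torch.randn(seq_length, d_k)
K = torch.randn(seq_length, d_k)
V = torch.randn(seq_length, d_e)

# Step 1: scaled dot products between every query and every key
scores = Q @ K.T / math.sqrt(d_k)      # shape (seq_length, seq_length)

# Step 2: softmax over each row, so every row of W sums to 1
W = torch.softmax(scores, dim=-1)

# Step 3: weighted sum of the value vectors
X_prime = W @ V                        # shape (seq_length, d_e)

# Row i of X' should match the explicit sum over j of w_ij * v_j
i = 1
x_i = sum(W[i, j] * V[j] for j in range(seq_length))
print(torch.allclose(X_prime[i], x_i))
print(W.sum(dim=-1))
```

The first print compares the matrix product with the element-wise formula for one row, and the second shows the row sums of $W$.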

In [3]:
from transformers import AutoTokenizer, AutoConfig
import torch
import torch.nn.functional as F
import math

In [9]:
model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
sentence = "time flies like an arrow"

inputs = tokenizer(sentence, return_tensors="pt", add_special_tokens=False)
print(inputs.input_ids.shape)
print(inputs.input_ids)


torch.Size([1, 5])
tensor([[ 2051, 10029,  2066,  2019,  8612]])


In [6]:
# load the config parameters from the model
config = AutoConfig.from_pretrained(model_ckpt)
# learnable embeddings that act as a lookup table
token_emb = torch.nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)
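
To make the "lookup table" description concrete, here is a minimal sketch on a toy embedding. The vocabulary size of 10, the embedding size of 4 and the token ids are made-up values for illustration only; the point is that calling the layer is equivalent to indexing rows of its weight matrix.

```python
import torch

torch.manual_seed(0)
emb = torch.nn.Embedding(10, 4)    # toy vocabulary of 10 tokens (assumption)
ids = torch.tensor([[2, 7, 2]])    # hypothetical token ids, batch of 1

out = emb(ids)                     # shape (1, 3, 4)

# Calling the layer returns the rows of the weight matrix selected by the ids
print(torch.equal(out, emb.weight[ids]))
# The same id always maps to the same vector
print(torch.equal(out[0, 0], out[0, 2]))
```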

In [7]:
inputs_embeds = token_emb(inputs.input_ids)
print(inputs_embeds.size())
print(inputs_embeds)

torch.Size([1, 5, 768])
tensor([[[ 6.6525e-01, -3.2743e-01,  4.4065e-01,  ..., -7.4415e-03,
          -1.8935e-02, -5.1604e-01],
         [-2.8991e-01,  1.1964e+00, -1.2789e+00,  ..., -1.2697e-01,
          -4.4247e-01, -1.3083e+00],
         [-3.2184e-01, -1.1409e+00,  1.2393e+00,  ..., -1.5396e-03,
          -2.1357e-03,  7.1527e-01],
         [-8.4413e-01, -9.1694e-01,  1.1774e+00,  ..., -5.1015e-01,
          -9.6122e-01,  5.2684e-01],
         [ 1.2174e+00,  1.9181e+00,  1.5767e-01,  ..., -2.4287e+00,
           2.8732e-01, -1.0570e+00]]], grad_fn=<EmbeddingBackward0>)


In [10]:
# for simplicity we compute a single attention head and assume Q = K = V = inputs_embeds,
# and we omit the positional embeddings
query = key = value = inputs_embeds
dim_k = key.size(-1)
scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(dim_k)
print(f"Size of scores: {scores.size()}\n")
weights = F.softmax(scores, dim=-1)
print(f"Size of weights: {weights.size()}")
print(f"Check if the sum of weights is 1: {weights.sum(dim=-1)}\n")
att_outputs = torch.bmm(weights, value)
print(f"Size of att_outputs: {att_outputs.size()}")
print(f"att_outputs: {att_outputs}")

Size of scores: torch.Size([1, 5, 5])

Size of weights: torch.Size([1, 5, 5])
Check if the sum of weights is 1: tensor([[1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

Size of att_outputs: torch.Size([1, 5, 768])
att_outputs: tensor([[[ 6.6525e-01, -3.2743e-01,  4.4065e-01,  ..., -7.4415e-03,
          -1.8935e-02, -5.1604e-01],
         [-2.8991e-01,  1.1964e+00, -1.2789e+00,  ..., -1.2697e-01,
          -4.4247e-01, -1.3083e+00],
         [-3.2184e-01, -1.1409e+00,  1.2393e+00,  ..., -1.5396e-03,
          -2.1357e-03,  7.1527e-01],
         [-8.4413e-01, -9.1694e-01,  1.1774e+00,  ..., -5.1015e-01,
          -9.6122e-01,  5.2684e-01],
         [ 1.2174e+00,  1.9181e+00,  1.5767e-01,  ..., -2.4287e+00,
           2.8732e-01, -1.0570e+00]]], grad_fn=<BmmBackward0>)
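
Note that the printed `att_outputs` matches `inputs_embeds` to the displayed precision. A likely explanation is that, with Q = K, each diagonal score is the squared norm of a 768-dimensional vector divided by $\sqrt{768} \approx 27.7$. `torch.nn.Embedding` initialises its weights from a standard normal distribution, so the diagonal scores are around 27.7 while the off-diagonal scores are of order 1, and the softmax puts almost all weight on the diagonal. The sketch below reproduces the effect with a random tensor standing in for `inputs_embeds` (an assumption, so that the snippet runs without downloading the model).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 5, 768)   # stand-in for inputs_embeds (assumption)

scores = torch.bmm(x, x.transpose(1, 2)) / math.sqrt(x.size(-1))
weights = F.softmax(scores, dim=-1)

# Attention weight each token assigns to itself
diag = torch.diagonal(weights, dim1=-2, dim2=-1)
print(diag)

# With near-identity weights, the output is (almost) the input
out = torch.bmm(weights, x)
print(torch.allclose(out, x, atol=1e-4))
```

In other words, without learned projections and with random embeddings, each token mostly attends to itself.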


In [13]:
config.hidden_size

768
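
The cell above sets Q = K = V for simplicity. In the paper, queries, keys and values are instead obtained by applying separate learned linear projections to the embeddings. A minimal sketch of that variant follows; the function name `scaled_dot_product_attention`, the head dimension of 64 and the random `hidden_state` are illustrative assumptions, and the projection layers are untrained.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value):
    """Same three steps as above, wrapped in a function."""
    dim_k = key.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)

torch.manual_seed(0)
embed_dim, head_dim = 768, 64                  # head_dim is an illustrative choice
hidden_state = torch.randn(1, 5, embed_dim)    # stand-in for inputs_embeds

# Separate (randomly initialised) projections for queries, keys and values
q_proj = torch.nn.Linear(embed_dim, head_dim)
k_proj = torch.nn.Linear(embed_dim, head_dim)
v_proj = torch.nn.Linear(embed_dim, head_dim)

out = scaled_dot_product_attention(
    q_proj(hidden_state), k_proj(hidden_state), v_proj(hidden_state)
)
print(out.size())
```

With these sizes the output has shape (1, 5, 64): one vector of size `head_dim` per token.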