### Queries, Keys and Values

I will omit the batch dimension and assume that the input embeddings have shape (seq_length, d_e). The output embeddings have the same shape as $V$, i.e. (seq_length, d_e).
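Dropping the batch dimension in the math loses nothing: `torch.bmm`, which the notebook cells use, applies an ordinary matrix product to each batch element independently. Here is a small sketch that checks this. It uses random tensors with hypothetical toy sizes, not the BERT dimensions:

```python
import torch

torch.manual_seed(0)
batch, seq_length, d = 2, 5, 8  # toy sizes chosen for illustration
A = torch.randn(batch, seq_length, d)
B = torch.randn(batch, d, seq_length)

# Batched product: (batch, n, d) @ (batch, d, n) -> (batch, n, n)
batched = torch.bmm(A, B)

# Same result computed one example at a time, with the batch dimension omitted
per_example = torch.stack([A[b] @ B[b] for b in range(batch)])

print(batched.shape)
print(torch.allclose(batched, per_example, atol=1e-6))
```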

I'll work through the math to understand it better (following the original Transformer paper, https://arxiv.org/abs/1706.03762):

Notation:

* Values matrix: $V \in \mathbb{R}^{seq\_length \times d_{e}}$
* Single value vector: $v_i \in \mathbb{R}^{d_{e}}$
* Keys matrix: $K \in \mathbb{R}^{seq\_length \times d_{k}}$
* Single key vector: $k_i \in \mathbb{R}^{d_{k}}$
* Queries matrix: $Q \in \mathbb{R}^{seq\_length \times d_{q}}$
* Single query vector: $q_i \in \mathbb{R}^{d_{q}}$

The dot products $q_i \cdot k_j$ below require $d_q = d_k$. Each vector is one row of its matrix, with one row per token:

$$V = \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_{seq\_length} \end{bmatrix}, \quad K = \begin{bmatrix} k_1 \\ k_2 \\ \vdots \\ k_{seq\_length} \end{bmatrix}, \quad Q = \begin{bmatrix} q_1 \\ q_2 \\ \vdots \\ q_{seq\_length} \end{bmatrix}$$

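1. **Scaled dot product between queries and keys**: entry $(i, j)$ scores how well query $q_i$ matches key $k_j$. Dividing by $\sqrt{d_k}$ keeps the scores from growing with the dimension. The paper suggests that large dot products would otherwise push the softmax into regions with extremely small gradients.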

$$ \frac{Q \cdot K^T}{\sqrt{d_{k}}} = \frac{1}{\sqrt{d_{k}}} \begin{bmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 & \cdots & q_1 \cdot k_{seq\_length} \\ q_2 \cdot k_1 & q_2 \cdot k_2 & \cdots & q_2 \cdot k_{seq\_length} \\ \vdots & \vdots & \ddots & \vdots \\ q_{seq\_length} \cdot k_1 & q_{seq\_length} \cdot k_2 & \cdots & q_{seq\_length} \cdot k_{seq\_length} \end{bmatrix}$$

2. **Softmax**, applied to each row independently:

$$ W = \text{softmax} \left( \frac{Q \cdot K^T}{\sqrt{d_{k}}} \right) = \begin{bmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,seq\_length} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,seq\_length} \\ \vdots & \vdots & \ddots & \vdots \\ w_{seq\_length,1} & w_{seq\_length,2} & \cdots & w_{seq\_length,seq\_length} \end{bmatrix} $$

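$$\text{where} \quad w_{i,j} = \frac{\exp\left(q_i \cdot k_j / \sqrt{d_k}\right)}{\sum_{l=1}^{seq\_length} \exp\left(q_i \cdot k_l / \sqrt{d_k}\right)}, \qquad \text{so} \quad \sum_{j = 1}^{seq\_length} w_{i,j} = 1 \quad \forall i$$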

3. **Weighted sum of values**:

$$ X^{'} =  \begin{bmatrix} x^{'}_1 \\ x^{'}_2 \\ \vdots \\ x^{'}_{seq\_length} \end{bmatrix} = W \cdot V = \begin{bmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,seq\_length} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,seq\_length} \\ \vdots & \vdots & \ddots & \vdots \\ w_{seq\_length,1} & w_{seq\_length,2} & \cdots & w_{seq\_length,seq\_length} \end{bmatrix} \begin{bmatrix} v_1 \\ v_2 \\ \vdots \\ v_{seq\_length} \end{bmatrix} $$

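$$\text{so,} \quad x^{'}_i = \sum_{j = 1}^{seq\_length} w_{i,j} \, v_j $$

Each output $x^{'}_i$ is therefore a weighted average of all the value vectors, with the weights taken from row $i$ of $W$.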
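The three steps above can be checked numerically on a toy example. The following is a minimal sketch using random tensors with small hypothetical sizes (`seq_length = 3`, `d_k = d_e = 4`, not the BERT dimensions). It compares each matrix formula against its element-wise definition:

```python
import math
import torch

torch.manual_seed(0)
seq_length, d_k, d_e = 3, 4, 4  # toy sizes chosen for illustration
Q = torch.randn(seq_length, d_k)
K = torch.randn(seq_length, d_k)
V = torch.randn(seq_length, d_e)

# Step 1, matrix form: all scaled dot products at once, (seq_length, seq_length)
scores = Q @ K.T / math.sqrt(d_k)
# Step 1, element form: s_ij = q_i . k_j / sqrt(d_k)
scores_loop = torch.tensor([[torch.dot(Q[i], K[j]).item() / math.sqrt(d_k)
                             for j in range(seq_length)]
                            for i in range(seq_length)])

# Step 2: softmax over each row (dim=-1)
W = torch.softmax(scores, dim=-1)
# Explicit formula: w_ij = exp(s_ij) / sum_l exp(s_il)
exp_scores = torch.exp(scores)
W_manual = exp_scores / exp_scores.sum(dim=-1, keepdim=True)

# Step 3, matrix form: X' = W V, (seq_length, d_e)
X_prime = W @ V
# Step 3, row form: x'_i = sum_j w_ij * v_j
X_rows = torch.stack([sum(W[i, j] * V[j] for j in range(seq_length))
                      for i in range(seq_length)])

print(torch.allclose(scores, scores_loop, atol=1e-6))
print(torch.allclose(W, W_manual, atol=1e-6))
print(W.sum(dim=-1))  # each row of W sums to 1
print(torch.allclose(X_prime, X_rows, atol=1e-6))
```

Passing `dim=0` to the softmax instead would normalize columns rather than rows. The weights for each query would then no longer sum to 1.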

In [1]:
from transformers import AutoTokenizer, AutoConfig
import torch
import torch.nn.functional as F
import math

In [2]:
model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
sentence = "time flies like an arrow"

inputs = tokenizer(sentence, return_tensors="pt", add_special_tokens=False)
inputs.input_ids

tensor([[ 2051, 10029,  2066,  2019,  8612]])

In [4]:
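# Load the model configuration (vocab_size, hidden_size, ...) from the checkpoint
config = AutoConfig.from_pretrained(model_ckpt)
# Learnable embedding layer that acts as a lookup table: token ID -> hidden_size vector
# (randomly initialized here, not BERT's pretrained embeddings)
token_emb = torch.nn.Embedding(config.vocab_size, config.hidden_size)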
token_emb

Embedding(30522, 768)
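`nn.Embedding` is a trainable matrix of shape `(vocab_size, hidden_size)`. Calling it on token IDs simply selects rows of that matrix. The sketch below checks this using toy sizes and hypothetical token IDs rather than BERT's 30522 × 768 table:

```python
import torch

torch.manual_seed(0)
vocab_size, hidden_size = 10, 4  # toy sizes; BERT-base uses 30522 and 768
emb = torch.nn.Embedding(vocab_size, hidden_size)

ids = torch.tensor([[2, 7, 7, 1]])  # hypothetical token IDs, batch of 1
out = emb(ids)

# Row lookup: out[0, t] is row ids[0, t] of the weight matrix
print(out.shape)
print(torch.equal(out, emb.weight[ids]))
# The same token ID always maps to the same vector
print(torch.equal(out[0, 1], out[0, 2]))
```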

In [8]:
inputs_embeds = token_emb(inputs.input_ids)
print(inputs_embeds.size())
print(inputs_embeds)

torch.Size([1, 5, 768])
tensor([[[ 0.8201,  0.6045, -0.0150,  ..., -0.5656,  0.3174,  0.8564],
         [-0.0386, -1.0203, -0.0138,  ...,  0.4973,  0.9126, -0.1514],
         [ 1.1335,  2.5236,  0.5811,  ..., -0.2125,  2.2726,  0.8796],
         [ 1.0999,  0.9886,  0.5276,  ...,  0.0387,  0.1758, -1.4119],
         [-0.2951,  1.1370, -0.4077,  ..., -0.6754,  0.4314, -0.9594]]],
       grad_fn=<EmbeddingBackward0>)
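The next cell uses the embeddings directly as queries, keys and values. In the Transformer paper, each head instead computes them with learned linear maps: $Q = X W^Q$, $K = X W^K$, $V = X W^V$. Below is a minimal sketch of one such head. It assumes a random tensor in place of `inputs_embeds` and a hypothetical head size `head_dim = 64` (768 / 12 heads, as in BERT-base):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_length, hidden_size = 1, 5, 768
head_dim = 64  # hypothetical per-head size: 768 / 12 heads

x = torch.randn(batch, seq_length, hidden_size)  # stand-in for inputs_embeds

# Learned projections (randomly initialized here); nn.Linear adds a bias by
# default, pass bias=False to match the bias-free formulas exactly
W_q = nn.Linear(hidden_size, head_dim)
W_k = nn.Linear(hidden_size, head_dim)
W_v = nn.Linear(hidden_size, head_dim)

q, k, v = W_q(x), W_k(x), W_v(x)  # each (batch, seq_length, head_dim)
scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)
weights = F.softmax(scores, dim=-1)
head_out = torch.bmm(weights, v)

print(weights.shape)   # torch.Size([1, 5, 5])
print(head_out.shape)  # torch.Size([1, 5, 64])
```

Note that $d_k$ in the scaling is now the head size, not the full hidden size.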


In [14]:
# For simplicity, compute a single attention head and skip the learned
# projections W_Q, W_K, W_V: Q = K = V = inputs_embeds
query = key = value = inputs_embeds
dim_k = key.size(-1)  # d_k = hidden_size = 768
# Batched matmul: (1, 5, 768) @ (1, 768, 5) -> (1, 5, 5) scaled scores
scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(dim_k)
print(f"Size of scores: {scores.size()}\n")
# Softmax over the last dimension, i.e. over each row of scores
weights = F.softmax(scores, dim=-1)
print(f"Size of weights: {weights.size()}")
print(f"Check if the sum of weights is 1: {weights.sum(dim=-1)}\n")
# Weighted sum of values: (1, 5, 5) @ (1, 5, 768) -> (1, 5, 768)
att_outputs = torch.bmm(weights, value)
print(f"Size of att_outputs: {att_outputs.size()}")
print(f"att_outputs: {att_outputs}")

Size of scores: torch.Size([1, 5, 5])

Size of weights: torch.Size([1, 5, 5])
Check if the sum of weights is 1: tensor([[1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

Size of att_outputs: torch.Size([1, 5, 768])
att_outputs: tensor([[[ 0.8201,  0.6045, -0.0150,  ..., -0.5656,  0.3174,  0.8564],
         [-0.0386, -1.0203, -0.0138,  ...,  0.4973,  0.9126, -0.1514],
         [ 1.1335,  2.5236,  0.5811,  ..., -0.2125,  2.2726,  0.8796],
         [ 1.0999,  0.9886,  0.5276,  ...,  0.0387,  0.1758, -1.4119],
         [-0.2951,  1.1370, -0.4077,  ..., -0.6754,  0.4314, -0.9594]]],
       grad_fn=<BmmBackward0>)
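`att_outputs` is identical to `inputs_embeds` to the printed precision. This is a side effect of setting Q = K = V. The diagonal score is $q_i \cdot q_i / \sqrt{d_k} = \lVert x_i \rVert^2 / \sqrt{d_k}$. For roughly standard-normal 768-dimensional embeddings, $\lVert x_i \rVert^2 \approx 768$, so that score is about $\sqrt{768} \approx 27.7$. The off-diagonal scores come from nearly uncorrelated vectors and are only of order 1. After the softmax, $W$ is therefore almost the identity matrix, and $x'_i \approx v_i$.

The sketch below checks this. It assumes random $\mathcal{N}(0, 1)$ embeddings, which is what `nn.Embedding` uses for initialization by default:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_length, d_k = 5, 768
x = torch.randn(1, seq_length, d_k)  # stand-in for inputs_embeds, N(0, 1) entries

scores = torch.bmm(x, x.transpose(1, 2)) / math.sqrt(d_k)
weights = F.softmax(scores, dim=-1)
att_outputs = torch.bmm(weights, x)

print(scores[0].diagonal())           # large, around sqrt(768) ~ 27.7
print(weights[0].diagonal())          # essentially 1
print((att_outputs - x).abs().max())  # tiny: outputs are almost the inputs
```

With separate learned projections for queries and keys, as in a real attention head, $q_i$ and $k_i$ are no longer the same vector, so this diagonal dominance is no longer forced.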
