# Foundations of Self-Attention and Multi-Head Attention (Chapter 16 - Initial Focus)

---

This notebook introduces the revolutionary **Self-Attention** mechanism, the core component of the Transformer architecture, and marks the beginning of **Chapter 16: Transformers - Attention Beyond Recurrence**. Unlike RNNs (Chapter 15), which process a sequence one element at a time, Self-Attention processes all elements in parallel and computes the dependencies between every pair of positions simultaneously.

### 1. Preparing the Input for Attention 🧠

* **Input Embedding:** The notebook starts with a simple input sequence (a sentence represented by 8 token indices) and uses an **`nn.Embedding`** layer to convert these indices into dense, continuous 16-dimensional vectors. These embeddings are the input to every attention calculation that follows.
* **Query, Key, and Value Projections (Q, K, V):** The central concept of attention is that the input embedding ($X$) is transformed into three different representations using three distinct weight matrices ($W_Q$, $W_K$, $W_V$):
    * **Query ($Q$):** Represents the token that is "asking"; it is compared against the keys of all tokens.
    * **Key ($K$):** Represents each token as something to be matched; the dot product of a query with a key gives the attention score.
    * **Value ($V$):** The content that is summed up according to the attention weights.
* **Implementation:** The notebook first computes a simplified, parameter-free version of self-attention directly on the raw embeddings (scores $\omega_{ij} = x_i^\top x_j$). It then creates three random $d \times d$ matrices with `torch.rand` (`U_query`, `U_key`, `U_value`, standing in for learned weights) and multiplies them with the embeddings to obtain the queries, keys, and values.
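In PyTorch, the three projections are commonly written as bias-free `nn.Linear` layers. The sketch below assumes `d_model = 16` (as in the notebook) and, as an illustrative assumption, `d_k = d_v = 8`, to show that the projected size does not have to match the embedding size. The input is random stand-in data.

```python
import torch
from torch import nn

torch.manual_seed(0)

seq_len, d_model = 8, 16   # 8 tokens with 16-dimensional embeddings
d_k, d_v = 8, 8            # assumed projection sizes; they may differ from d_model

X = torch.randn(seq_len, d_model)  # stand-in for the embedded sentence

# One bias-free linear layer per projection; each computes X @ W.T
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_v, bias=False)

Q, K, V = W_Q(X), W_K(X), W_V(X)
print(Q.shape, K.shape, V.shape)  # torch.Size([8, 8]) torch.Size([8, 8]) torch.Size([8, 8])
print(torch.allclose(Q, X @ W_Q.weight.T, atol=1e-6))  # True
```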

### 2. The Scaled Dot-Product Attention Mechanism

The core calculation follows the formula: $\text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d_k}})V$

* **Scoring ($QK^T$):** This step computes the **raw attention scores** by taking the dot product of the Query matrix with the transpose of the Key matrix. Entry $(i, j)$ measures how relevant token $j$ is to token $i$.
* **Scaling:** The scores are divided by the square root of the key dimension ($\sqrt{d_k}$). Without this **scaling factor**, the dot products grow in magnitude with $d_k$ (for independent components with zero mean and unit variance, their variance is exactly $d_k$), which pushes the softmax into saturated regions with tiny gradients.
* **Softmax:** The scaled scores are passed through the softmax function (`nn.functional.softmax` in the notebook), applied row by row. This converts the raw scores into the **Attention Weights**: each row is a probability distribution that sums to 1.
* **Final Output ($Z$):** The final output is computed by multiplying the attention weights by the **Value** matrix. This produces the **Context Vectors** ($Z$), where the representation of each token is a weighted sum of the values of all tokens in the sequence.
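The four steps above fit in a few lines of PyTorch. This is a minimal sketch for a single sequence; `attention` is a hypothetical helper name, and the inputs are random stand-ins. The last lines compare it with `torch.nn.functional.scaled_dot_product_attention`, which is available in PyTorch 2.0 and later and uses the same $1/\sqrt{d_k}$ scale by default.

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    """Softmax(Q K^T / sqrt(d_k)) V for Q: (T_q, d_k), K: (T_k, d_k), V: (T_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # scaled raw scores, (T_q, T_k)
    weights = F.softmax(scores, dim=-1)                 # each row sums to 1
    return weights @ V, weights                         # context vectors Z, (T_q, d_v)

torch.manual_seed(0)
Q, K, V = torch.randn(8, 16), torch.randn(8, 16), torch.randn(8, 16)
Z, weights = attention(Q, K, V)
print(Z.shape)  # torch.Size([8, 16])

# Fused PyTorch implementation of the same formula; a leading batch dimension
# is added for the call and removed again afterwards.
Z_ref = F.scaled_dot_product_attention(
    Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0)
).squeeze(0)
print(torch.allclose(Z, Z_ref, atol=1e-5))  # True
```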

### 3. Introduction to Multi-Head Attention

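The notebook then introduces **Multi-Head Attention**, which runs several attention mechanisms in parallel:

* **Heads:** Instead of a single set of $Q$, $K$, and $V$ projections, the model uses $h$ independent sets, called "heads." Each head can learn to focus on a different kind of dependency between tokens (for example, one head on syntactic relations, another on semantic ones). In the standard Transformer, each head projects into a smaller subspace of dimension $d_{model}/h$; the notebook keeps things simple and gives every head a full $d \times d$ projection.
* **Implementation:** The notebook stores the per-head matrices ($W_{Q_i}$, $W_{K_i}$, $W_{V_i}$) as tensors of shape $(h, d, d)$ and projects all heads at once with batched matrix multiplication (`torch.bmm`).
* **Concatenation:** The outputs of all heads are concatenated and passed through a final linear layer to produce the final context vector. The notebook demonstrates this step with randomly generated head outputs as placeholders.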
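The notebook gives every head a full $d \times d$ projection. The original Transformer and PyTorch's `nn.MultiheadAttention` instead project once to `d_model` features and split them into `h` heads of size `d_model / h`. The sketch below writes that variant out step by step and checks it against `nn.MultiheadAttention` by reusing its weights. The sizes (8 tokens, `d_model = 16`, 4 heads) are assumptions for illustration, and `manual_mha` is a hypothetical helper.

```python
import math
import torch
from torch import nn

def manual_mha(x, mha, h):
    """Multi-head self-attention written out step by step, reusing mha's weights."""
    B, T, E = x.shape
    head_dim = E // h
    # Packed input projection: rows [0:E] hold W_Q, [E:2E] W_K, [2E:3E] W_V
    W_q, W_k, W_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
    q, k, v = x @ W_q.T + b_q, x @ W_k.T + b_k, x @ W_v.T + b_v

    # Split the feature dimension into h heads: (B, T, E) -> (B, h, T, head_dim)
    def split(t):
        return t.view(B, T, h, head_dim).transpose(1, 2)

    q, k, v = split(q), split(k), split(v)

    # Scaled dot-product attention, computed independently in every head
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
    z = weights @ v                               # (B, h, T, head_dim)

    # Concatenate the heads back to (B, T, E), then apply the output projection
    z = z.transpose(1, 2).reshape(B, T, E)
    return mha.out_proj(z)

torch.manual_seed(0)
T, d_model, h = 8, 16, 4           # assumed sizes: 4 heads of size 4
mha = nn.MultiheadAttention(d_model, num_heads=h, batch_first=True)
x = torch.randn(1, T, d_model)     # (batch, seq_len, d_model)

with torch.no_grad():
    ours = manual_mha(x, mha, h)
    ref, _ = mha(x, x, x)          # self-attention: query = key = value = x

print(torch.allclose(ours, ref, atol=1e-5))  # True
```

Splitting keeps the total cost close to that of a single full-width head, which is why it is the usual design choice.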

This notebook provides a detailed, step-by-step mathematical breakdown of the attention mechanism, the building block that lets Transformers do away with the recurrent, one-position-at-a-time processing of RNNs.

```python
import torch
from torch import nn
```

```python
# Token indices of an 8-word input sentence
sentence = torch.tensor([0, 7, 1, 2, 5, 6, 4, 3])
```
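The cell above hard-codes the token indices. As a hypothetical illustration (the original sentence is not shown here), one way to obtain exactly these indices is to map each distinct word of an eight-word sentence to its rank in a sorted vocabulary:

```python
import torch

# Hypothetical sentence, chosen only to illustrate the word -> index mapping
text = "Can you help me to translate this sentence"

# Vocabulary: every distinct word mapped to its position in sorted order
# (Python sorts uppercase before lowercase, so "Can" gets index 0)
vocab = {word: i for i, word in enumerate(sorted(text.split()))}

sentence = torch.tensor([vocab[word] for word in text.split()])
print(sentence)  # tensor([0, 7, 1, 2, 5, 6, 4, 3])
```

This vocabulary has 8 entries, while the notebook's embedding table has 10 rows; a table larger than the vocabulary simply leaves some rows unused.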

```python
torch.manual_seed(123)
embed = nn.Embedding(10, 16)                  # table of 10 tokens x 16 dimensions
embedded_sentence = embed(sentence).detach()  # detach: no gradients needed here
embedded_sentence.shape
```

```
torch.Size([8, 16])
```
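`nn.Embedding` is a lookup table: calling it returns the rows of its `weight` matrix selected by the indices. The check below, which rebuilds the notebook's embedding layer, shows two equivalent views: direct indexing and multiplying one-hot vectors by the weight matrix.

```python
import torch
from torch import nn

torch.manual_seed(123)
embed = nn.Embedding(10, 16)                 # 10-row table of 16-dim vectors
sentence = torch.tensor([0, 7, 1, 2, 5, 6, 4, 3])

looked_up = embed(sentence)                  # (8, 16)

# Equivalent 1: index the weight matrix directly
by_index = embed.weight[sentence]

# Equivalent 2: one-hot rows (8, 10) times the weight matrix (10, 16)
one_hot = nn.functional.one_hot(sentence, num_classes=10).float()
by_matmul = one_hot @ embed.weight

print(torch.allclose(looked_up, by_index), torch.allclose(looked_up, by_matmul))  # True True
```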

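```python
# Unnormalized attention scores: omega[i, j] is the dot product of x_i and x_j
omega = torch.empty(8, 8)
for i, x_i in enumerate(embedded_sentence):
    for j, x_j in enumerate(embedded_sentence):
        omega[i, j] = torch.dot(x_i, x_j)
```

```python
# The same scores as a single matrix product: X X^T
omega_mat = embedded_sentence.matmul(embedded_sentence.T)
```

```python
torch.allclose(omega_mat, omega)
```

```
True
```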
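Because these raw scores use the embeddings directly, $\omega_{ij} = x_i^\top x_j = \omega_{ji}$: the unnormalized score from token $i$ to token $j$ always equals the score from $j$ to $i$. A quick check with random stand-in embeddings and two arbitrary random matrices (both assumptions) shows that separate query and key projections remove this constraint, since $QK^\top = X U_q^\top U_k X^\top$ is generally not symmetric:

```python
import torch

torch.manual_seed(0)
X = torch.randn(8, 16)                  # stand-in for the embedded sentence

omega = X @ X.T                         # raw scores without projections
print(torch.allclose(omega, omega.T, atol=1e-6))    # True: omega[i, j] == omega[j, i]

# Different query and key projections
U_q, U_k = torch.rand(16, 16), torch.rand(16, 16)
scores = (X @ U_q.T) @ (X @ U_k.T).T
print(torch.allclose(scores, scores.T, atol=1e-6))  # False: no longer symmetric
```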

```python
# Normalize each row of scores into attention weights
attention_weights = nn.functional.softmax(omega, dim=1)
attention_weights.shape
```

```
torch.Size([8, 8])
```

```python
attention_weights.sum(dim=1)
```

```
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
```
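`softmax(omega, dim=1)` normalizes each row separately: $\alpha_{ij} = \exp(\omega_{ij}) / \sum_k \exp(\omega_{ik})$. A hand-written version makes that explicit. It is a sketch: `softmax_rows` is a hypothetical helper, not a PyTorch function, and the score matrix is a random stand-in.

```python
import torch

def softmax_rows(scores):
    """Row-wise softmax: exp(s_ij) / sum_k exp(s_ik) (hypothetical helper)."""
    # Subtracting each row's maximum cancels out in the ratio but keeps exp()
    # from overflowing when scores are large.
    shifted = scores - scores.max(dim=1, keepdim=True).values
    exp = shifted.exp()
    return exp / exp.sum(dim=1, keepdim=True)

torch.manual_seed(0)
omega = torch.randn(8, 8) * 10    # stand-in score matrix with fairly large entries
manual = softmax_rows(omega)
print(torch.allclose(manual, torch.nn.functional.softmax(omega, dim=1), atol=1e-6))  # True
```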

```python
# Context vector of the second token (index 1): weighted sum of all embeddings
x_2 = embedded_sentence[1, :]
context_vec2 = torch.zeros(x_2.shape)
for j in range(8):
    x_j = embedded_sentence[j, :]
    context_vec2 += attention_weights[1, j] * x_j
context_vec2
```

```
tensor([-9.3975e-01, -4.6856e-01,  1.0311e+00, -2.8192e-01,  4.9373e-01,
        -1.2896e-02, -2.7327e-01, -7.6358e-01,  1.3958e+00, -9.9543e-01,
        -7.1287e-04,  1.2449e+00, -7.8077e-02,  1.2765e+00, -1.4589e+00,
        -2.1601e+00])
```

```python
# All context vectors at once: attention weights times embeddings
context_vectors = torch.matmul(attention_weights, embedded_sentence)
```

```python
torch.allclose(context_vec2, context_vectors[1])
```

```
True
```

```python
# Random projection matrices standing in for learned W_Q, W_K, W_V
torch.manual_seed(123)
d = embedded_sentence.shape[1]
U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)
```

```python
# Query, key, and value of the second token
x_2 = embedded_sentence[1]
query_2 = U_query.matmul(x_2)
key_2 = U_key.matmul(x_2)
value_2 = U_value.matmul(x_2)
```

```python
# Keys and values of all tokens at once (one row per token)
keys = U_key.matmul(embedded_sentence.T).T
values = U_value.matmul(embedded_sentence.T).T
```

```python
print(torch.allclose(key_2, keys[1]))
print(torch.allclose(value_2, values[1]))
```

```
True
True
```
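The double transpose in `U_key.matmul(embedded_sentence.T).T` projects every token at once, because $(U X^\top)^\top = X U^\top$. The check below, on random stand-in tensors, shows the more direct form `X @ U.T`, and that a bias-free `nn.Linear` whose `weight` is set to `U` computes the same thing.

```python
import torch
from torch import nn

torch.manual_seed(0)
X = torch.randn(8, 16)            # stand-in for embedded_sentence (one token per row)
U_key = torch.rand(16, 16)

keys_a = U_key.matmul(X.T).T      # notebook form: project columns, transpose back
keys_b = X @ U_key.T              # same result without the double transpose

linear = nn.Linear(16, 16, bias=False)
with torch.no_grad():
    linear.weight.copy_(U_key)    # nn.Linear computes X @ weight.T
keys_c = linear(X)

print(torch.allclose(keys_a, keys_b, atol=1e-5),
      torch.allclose(keys_a, keys_c, atol=1e-5))  # True True
```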


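```python
# Unnormalized score between the query of token 2 and the key of token 3 (index 2)
omega_23 = query_2.dot(keys[2])
omega_23
```

```
tensor(14.3667)
```

```python
# Scores of query 2 against all keys
omega_2 = query_2.matmul(keys.T)
omega_2
```

```
tensor([-25.1623,   9.3602,  14.3667,  32.1482,  53.8976,  46.6626,  -1.2131,
        -32.9392])
```

```python
# Scale by sqrt(d), then normalize into attention weights
attention_weights_2 = nn.functional.softmax(omega_2 / d ** 0.5, dim=0)
attention_weights_2
```

```
tensor([2.2317e-09, 1.2499e-05, 4.3696e-05, 3.7242e-03, 8.5596e-01, 1.4026e-01,
        8.8897e-07, 3.1935e-10])
```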
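The cell above divides `omega_2` by `d ** 0.5` before the softmax. If the components of a query and a key are independent with mean 0 and variance 1, then $\operatorname{Var}(q^\top k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k$, so the raw scores spread out as $\sqrt{d_k}$. A quick simulation with standard-normal vectors (an assumption; learned projections are not exactly Gaussian) shows the spread of raw and scaled scores:

```python
import torch

torch.manual_seed(0)
n_samples = 20_000

for d_k in (16, 64, 256):
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    raw = (q * k).sum(dim=1)        # n_samples independent dot products
    scaled = raw / d_k ** 0.5
    # std(raw) should come out close to sqrt(d_k), std(scaled) close to 1
    print(f"d_k={d_k:4d}  std(raw)={raw.std().item():6.2f}  std(scaled)={scaled.std().item():.2f}")
```

Without the scaling, larger $d_k$ produces ever wider score ranges, which drives the softmax toward a near one-hot output with very small gradients.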

```python
# Context vector of token 2: attention-weighted sum of the values
context_vec2 = attention_weights_2.matmul(values)
context_vec2
```

```
tensor([-1.2226, -3.4387, -4.3928, -5.2125, -1.1249, -3.3041, -1.4316, -3.2765,
        -2.5114, -2.6105, -1.5793, -2.8433, -2.4142, -0.3998, -1.9917, -3.3499])
```

```python
torch.manual_seed(123)
d = embedded_sentence.shape[1]
one_U_query = torch.rand(d, d)   # single-head query matrix, shape (d, d)
```

```python
# h = 8 heads, each with its own d x d query, key, and value matrix
h = 8
multihead_U_query = torch.rand(h, d, d)
multihead_U_key = torch.rand(h, d, d)
multihead_U_value = torch.rand(h, d, d)
```

```python
# Project token 2 with every head at once: (h, d, d) @ (d,) -> (h, d)
multihead_query_2 = multihead_U_query.matmul(x_2)
multihead_key_2 = multihead_U_key.matmul(x_2)
multihead_value_2 = multihead_U_value.matmul(x_2)
print(multihead_query_2.shape)
print(multihead_query_2[3])      # query of token 2 in head 4 (index 3)
```

```
torch.Size([8, 16])
tensor([-1.1907, -1.8279, -1.5237,  0.2820, -1.6944, -1.6440, -0.7748, -1.2830,
        -1.6806, -0.1230, -0.4647, -0.0889, -2.0839, -2.7138, -0.7688, -1.9872])
```


```python
# One copy of the transposed input per head: (h, d, seq_len)
stacked_input = embedded_sentence.T.repeat(h, 1, 1)
stacked_input.shape
```

```
torch.Size([8, 16, 8])
```

```python
# Batched matmul: keys for every head and token, (h, d, seq_len)
multihead_keys = torch.bmm(multihead_U_key, stacked_input)
multihead_keys.shape
```

```
torch.Size([8, 16, 8])
```

```python
# Rearrange to (h, seq_len, d)
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_keys.shape
```

```
torch.Size([8, 8, 16])
```

```python
# Key of token 2 (index 1) in head 3 (index 2)
multihead_keys[2, 1]
```

```
tensor([ 0.0533, -0.2590, -0.5376, -0.8360,  0.1815, -1.0017, -0.9257, -1.4889,
        -1.6172, -0.2682,  2.2755, -0.0882, -0.1427,  0.3652, -0.4133, -1.3387])
```

```python
multihead_values = multihead_U_value.bmm(stacked_input).permute(0, 2, 1)
multihead_values.shape
```

```
torch.Size([8, 8, 16])
```
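`repeat` materializes `h` copies of the input before `torch.bmm` is called. Because `torch.matmul` broadcasts over leading batch dimensions, the same per-head keys can be computed without the copies, or with `torch.einsum`. A sketch with random stand-in tensors of the notebook's shapes:

```python
import torch

torch.manual_seed(0)
h, d, T = 8, 16, 8
X = torch.randn(T, d)               # stand-in for embedded_sentence
U_key = torch.rand(h, d, d)         # one d x d key matrix per head

# Notebook approach: explicit copies, batched matmul, then permute to (h, T, d)
stacked = X.T.repeat(h, 1, 1)       # (h, d, T)
keys_bmm = torch.bmm(U_key, stacked).permute(0, 2, 1)

# Broadcasting: (h, d, d) @ (d, T) applies every head's matrix to the same X
keys_bcast = U_key.matmul(X.T).transpose(1, 2)

# einsum: keys[h, t, i] = sum_j U_key[h, i, j] * X[t, j]
keys_einsum = torch.einsum("hij,tj->hti", U_key, X)

print(keys_bmm.shape)  # torch.Size([8, 8, 16])
print(torch.allclose(keys_bmm, keys_bcast, atol=1e-5),
      torch.allclose(keys_bmm, keys_einsum, atol=1e-5))  # True True
```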

```python
# Placeholder: random values standing in for the h per-head context vectors of token 2
multihead_z_2 = torch.rand(h, d)
```
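The cell above fills `multihead_z_2` with random numbers as a placeholder for the eight per-head context vectors of the second token. Computing them properly repeats the single-head steps (score, scale, softmax, weighted sum) once per head. A minimal sketch with random stand-in tensors; the shapes follow the notebook, the values do not.

```python
import torch

torch.manual_seed(0)
h, d, T = 8, 16, 8
X = torch.randn(T, d)                      # stand-in for embedded_sentence
U_query = torch.rand(h, d, d)
U_key = torch.rand(h, d, d)
U_value = torch.rand(h, d, d)

# Per-head projections of every token: (h, T, d)
queries = U_query.matmul(X.T).transpose(1, 2)
keys = U_key.matmul(X.T).transpose(1, 2)
values = U_value.matmul(X.T).transpose(1, 2)

# Scores of token 2 (index 1) against all tokens, separately in every head: (h, T)
scores_2 = (queries[:, 1:2, :] @ keys.transpose(1, 2)).squeeze(1)
weights_2 = torch.softmax(scores_2 / d ** 0.5, dim=-1)

# Weighted sum of each head's values: (h, T, 1) * (h, T, d) summed over T -> (h, d)
multihead_z_2 = (weights_2.unsqueeze(-1) * values).sum(dim=1)
print(multihead_z_2.shape)  # torch.Size([8, 16])
```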

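```python
# Concatenate the heads (flatten to 8*16 = 128 values) and project back to 16 dimensions
linear = nn.Linear(8*16, 16)
context_vector_2 = linear(multihead_z_2.flatten())
context_vector_2.shape
```

```
torch.Size([16])
```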
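Flattening the `(8, 16)` head outputs row by row is the same as concatenating the eight head vectors into one 128-dimensional vector. Splitting the linear layer's `(16, 128)` weight into one `16 x 16` block per head gives an equivalent reading: the output projection adds up a separate linear map of each head's output, $W_O [z_1; \dots; z_h] = \sum_i W_O^{(i)} z_i$. A sketch with random stand-in head outputs:

```python
import torch
from torch import nn

torch.manual_seed(0)
h, d = 8, 16
z = torch.randn(h, d)                          # stand-in per-head context vectors
linear = nn.Linear(h * d, d)

out = linear(z.flatten())                      # concatenate heads, then project: (16,)

# Flattening equals explicit concatenation of the head vectors
concat = torch.cat([z[i] for i in range(h)])   # (128,)
print(torch.equal(concat, z.flatten()))        # True

# Same output as a sum of per-head projections W_O^(i) z_i plus the bias
blocks = linear.weight.split(d, dim=1)         # h blocks, each (16, 16)
out_sum = sum(W_i @ z[i] for i, W_i in enumerate(blocks)) + linear.bias
print(torch.allclose(out, out_sum, atol=1e-5))  # True
```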