# Understanding The Attention Mechanism

<div style='width:75%'>
    <p>
    In this notebook, we aim to understand the key mechanism that makes Transformer models so powerful:
    attention, the heart and soul of a Transformer. The goal is to understand the inner workings of the
    attention mechanism and how tensors flow through the operation, from embedding an input sequence
    to computing the attention weights, so that we grasp both why and how it works.
    </p>
    <img src="../images/transformer.png" width="250"/>
</div>

In [7]:
import numpy as np

import torch
import torch.nn.functional as F
from torch import nn

# seeding our environment
torch.manual_seed(123)

<torch._C.Generator at 0x108e9ac30>

## Embedding an Input Sequence

In [8]:
sentence = 'wow that is crazy'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'crazy': 0, 'is': 1, 'that': 2, 'wow': 3}


In [9]:
encoded_sentence = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(encoded_sentence)

tensor([3, 2, 1, 0])


In [10]:
embedding_size = 16
vocab_size = len(dc)  # nn.Embedding needs one row per vocabulary entry (not per token in the sentence)
embed = torch.nn.Embedding(vocab_size, embedding_size)
embedded_sentence = embed(encoded_sentence).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692]])
torch.Size([4, 16])
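Under the hood, `nn.Embedding` is a lookup table: `embed(idx)` returns row `idx` of `embed.weight`, which is the same as multiplying one-hot vectors by that weight matrix. The sketch below checks this equivalence. The sizes mirror this notebook (a 4-word vocabulary and 16-dimensional embeddings), but the freshly seeded layer is a stand-in rather than the exact `embed` object above.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(123)

vocab_size, embedding_size = 4, 16
embed = nn.Embedding(vocab_size, embedding_size)
encoded_sentence = torch.tensor([3, 2, 1, 0])

# Direct lookup: row i of embed.weight for each token index i
looked_up = embed(encoded_sentence)

# Same result as (4 x vocab_size) one-hot rows times the (vocab_size x 16) weight matrix
one_hot = F.one_hot(encoded_sentence, num_classes=vocab_size).float()
via_matmul = one_hot @ embed.weight

print(looked_up.shape)                         # torch.Size([4, 16])
print(torch.allclose(looked_up, via_matmul))   # True
```

The lookup form is what PyTorch actually uses, since it avoids materializing a one-hot matrix whose width equals the vocabulary size.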


## Scaled Dot-Product Attention

<img src="../images/scaled-dot-prod-attention.png" width="400"/>

We begin by initializing the weight matrices $W_q$, $W_k$, and $W_v$, where $W_q$ and $W_k$ are both of size $d_{model} \times d_k$ and $W_v$ is of size $d_{model} \times d_v$.

The variables $d_{model}, d_k, d_v$ are defined as:
1. $d_{model}$ - the length of the embedding vector
2. $d_k$ - the dimension of each query and key vector (the number of columns of $Q$ and $K$)
3. $d_v$ - the dimension of each value vector (the number of columns of $V$)

In [11]:
d_model = embedding_size # dimension of each embedding vector (not the sequence length)
d_k = 64 # dimension for queries and keys
d_v = 64 # dimension for values

W_q = nn.Parameter(torch.randn(d_model, d_k))
W_k = nn.Parameter(torch.randn(d_model, d_k))
W_v = nn.Parameter(torch.randn(d_model, d_v))

print(f"Weight Matrices\n==========\nQueries: {W_q.shape}\nKeys: {W_k.shape}\nValues: {W_v.shape}\n")
print(f"Embeddings: {embedded_sentence.shape}")

Weight Matrices
==========
Queries: torch.Size([16, 64])
Keys: torch.Size([16, 64])
Values: torch.Size([16, 64])

Embeddings: torch.Size([4, 16])


In [12]:
Q = torch.matmul(embedded_sentence, W_q)
K = torch.matmul(embedded_sentence, W_k)
V = torch.matmul(embedded_sentence, W_v)

print(f"Resulting Matrices\n==========\nqueries: {Q.shape}\nkeys: {K.shape}\nvalues: {V.shape}")

Resulting Matrices
==========
queries: torch.Size([4, 64])
keys: torch.Size([4, 64])
values: torch.Size([4, 64])
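Multiplying the embeddings by $W_q$, $W_k$, and $W_v$ is a linear projection with no bias term, so each one can also be expressed with `nn.Linear(..., bias=False)`. Here is a minimal sketch of that equivalence for the query projection, assuming a random 4 x 16 stand-in for `embedded_sentence`. Note that `nn.Linear` stores its weight transposed, with shape `(out_features, in_features)`.

```python
import torch
from torch import nn

torch.manual_seed(123)
d_model, d_k = 16, 64

x = torch.randn(4, d_model)                    # stand-in for the 4 x 16 embedded sentence
W_q = nn.Parameter(torch.randn(d_model, d_k))

# nn.Linear computes x @ weight.T, so copy the transpose of W_q into it
proj = nn.Linear(d_model, d_k, bias=False)
with torch.no_grad():
    proj.weight.copy_(W_q.T)

Q_manual = x @ W_q
Q_linear = proj(x)

print(Q_linear.shape)                                  # torch.Size([4, 64])
print(torch.allclose(Q_manual, Q_linear, atol=1e-5))   # True
```

Library implementations of attention typically use `nn.Linear` layers like this, which is why their projection weights appear with the `(d_k, d_model)` orientation.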


$$
\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$
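The $\sqrt{d_k}$ in the denominator keeps the scores from growing with the dimension. If the entries of a query $q$ and a key $k$ are independent with mean 0 and variance 1, then $\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$, so dividing by $\sqrt{d_k}$ brings the variance back to about 1 and keeps the softmax out of its saturated, nearly one-hot regime. The quick simulation below checks this numerically under that unit-variance assumption; the sample size and seed are arbitrary.

```python
import torch

torch.manual_seed(0)
d_k = 64
n_samples = 100_000

# Queries and keys with independent, zero-mean, unit-variance entries
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

dots = (q * k).sum(dim=-1)     # raw dot products q . k
scaled = dots / d_k ** 0.5     # dot products divided by sqrt(d_k)

print(f"var(q . k)            = {dots.var().item():.2f}")    # close to d_k = 64
print(f"var(q . k / sqrt(dk)) = {scaled.var().item():.2f}")  # close to 1
```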

In [13]:
scores = torch.matmul(Q, K.T) / np.sqrt(d_k) # unnormalized attention weights
attn_weights = F.softmax(scores, dim=-1) # normalized attention weights
print(attn_weights.shape)
print(attn_weights)

torch.Size([4, 4])
tensor([[5.6082e-22, 1.0000e+00, 3.3372e-23, 3.4169e-08],
        [1.2046e-12, 3.2347e-16, 1.0000e+00, 1.9651e-10],
        [7.2092e-20, 7.5277e-13, 5.3729e-13, 1.0000e+00],
        [6.3403e-10, 2.0965e-06, 2.5648e-09, 1.0000e+00]],
       grad_fn=<SoftmaxBackward0>)
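The cell above stops at the attention weights; the last step of the formula multiplies them by $V$, giving one $d_v$-dimensional context vector per token. The sketch below finishes the computation and, as a cross-check, compares it with PyTorch's `F.scaled_dot_product_attention` (available in PyTorch 2.0 and later), whose default scale is $1/\sqrt{d_k}$. It uses random stand-ins for the embeddings and weight matrices rather than the notebook's tensors. One likely reason the rows above are almost one-hot: with weights drawn from `torch.randn`, the entries of $Q$ and $K$ have a variance of roughly $d_{model} = 16$ rather than 1, so even after dividing by $\sqrt{d_k}$ the scores are spread widely and the softmax saturates.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(123)
seq_len, d_model, d_k, d_v = 4, 16, 64, 64

# Stand-ins for the notebook's embedded_sentence and W_q, W_k, W_v
x = torch.randn(seq_len, d_model)
W_q = torch.randn(d_model, d_k)
W_k = torch.randn(d_model, d_k)
W_v = torch.randn(d_model, d_v)

Q, K, V = x @ W_q, x @ W_k, x @ W_v

# softmax(Q K^T / sqrt(d_k)) V, written out step by step
scores = Q @ K.T / math.sqrt(d_k)
attn_weights = F.softmax(scores, dim=-1)   # each row sums to 1
context = attn_weights @ V                 # (seq_len, d_v): one context vector per token

# PyTorch's built-in version; a leading batch dimension is added, then removed
fused = F.scaled_dot_product_attention(
    Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0)
).squeeze(0)

print(context.shape)                              # torch.Size([4, 64])
print(torch.allclose(context, fused, atol=1e-3))  # should print True
```

Each context vector is a weighted average of the value vectors, with the weights taken from the corresponding row of `attn_weights`.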


## BertViz

In [14]:
from IPython.display import clear_output
!pip install bertviz
!pip install jupyterlab
!pip install ipywidgets
clear_output()

Load the pre-trained model. We're using Google's BERT (`bert-base-uncased`), but you can substitute other pre-trained or fine-tuned models.

In [None]:
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import head_view
utils.logging.set_verbosity_error()  # Suppress standard warnings
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)

Now we can tokenize the input, run it through the model, and visualize the resulting attention weights with `head_view`.

In [23]:
inputs = tokenizer.encode("the apple is red", return_tensors='pt')
outputs = model(inputs)
attention = outputs.attentions  # populated because output_attentions=True; one tensor per layer
tokens = tokenizer.convert_ids_to_tokens(inputs[0])
head_view(attention, tokens)


In [24]:
inputs = tokenizer.encode(sentence, return_tensors='pt')
outputs = model(inputs)
attention = outputs.attentions  # populated because output_attentions=True; one tensor per layer
tokens = tokenizer.convert_ids_to_tokens(inputs[0])
head_view(attention, tokens)

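`head_view` draws one attention matrix per head because BERT uses multi-head attention: `bert-base-uncased` has 12 layers with 12 heads each, and each head runs the scaled dot-product attention from the first part of this notebook on its own 64-dimensional slice of the 768-dimensional hidden state. The attention tensors BertViz receives have shape `(batch, num_heads, seq_len, seq_len)` for each layer. The sketch below shows the head split on toy sizes (4 heads over a 16-dimensional model, chosen only for illustration), producing weights with the same layout minus the batch dimension. It omits the output projection that real multi-head attention applies after concatenating the heads, and it scales the random weights by $1/\sqrt{d_{model}}$ purely to keep the scores in a moderate range.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(123)
seq_len, d_model, n_heads = 4, 16, 4
d_head = d_model // n_heads   # 4 dimensions per head (illustrative choice)

x = torch.randn(seq_len, d_model)   # stand-in for an embedded sentence
W_q = torch.randn(d_model, d_model) / math.sqrt(d_model)
W_k = torch.randn(d_model, d_model) / math.sqrt(d_model)
W_v = torch.randn(d_model, d_model) / math.sqrt(d_model)


def split_heads(t):
    # (seq_len, d_model) -> (n_heads, seq_len, d_head)
    return t.view(seq_len, n_heads, d_head).transpose(0, 1)


Q, K, V = (split_heads(x @ W) for W in (W_q, W_k, W_v))

# Scaled dot-product attention, computed independently for every head
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)   # (n_heads, seq_len, seq_len)
attn_weights = F.softmax(scores, dim=-1)
per_head = attn_weights @ V                             # (n_heads, seq_len, d_head)

# Concatenate the heads back into a single (seq_len, d_model) matrix
merged = per_head.transpose(0, 1).reshape(seq_len, d_model)

print(attn_weights.shape)   # torch.Size([4, 4, 4])
print(merged.shape)         # torch.Size([4, 16])
```

Each of the `n_heads` slices of `attn_weights` is a separate `seq_len x seq_len` matrix, which is what `head_view` renders as one color per head.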

# Acknowledgements

- [Visualizing Attention, a Transformer's Heart](https://www.youtube.com/watch?v=eMlx5fFNoYc) by 3Blue1Brown
- [Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html) by Sebastian Raschka
- [Self Attention in Transformer Neural Networks (with Code!)](https://www.youtube.com/watch?v=QCJQG4DuHT0&list=PLTl9hO2Oobd97qfWC40gOSU8C0iu0m2l4) by CodeEmporium
- [BertViz](https://github.com/jessevig/bertviz) by Jesse Vig, used here to visualize the attention matrices