# Attention and Transformers

We will demonstrate how attention mechanisms work, how they can be implemented, and how they are used within models based on transformer architectures. We will develop an understanding from first principles, using PyTorch to create and manipulate tensors.

Ultimately, we're aiming to demystify what's happening within PyTorch's high-level transformer modules: [torch.nn.TransformerEncoderLayer](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html#torch.nn.TransformerEncoderLayer) and [torch.nn.TransformerDecoderLayer](https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoderLayer.html#torch.nn.TransformerDecoderLayer).

Attention mechanisms aim to map [word embeddings](https://en.wikipedia.org/wiki/Word_embedding) from one vector space into another, based on the other word embeddings in the sequence. This produces context-aware embeddings.

We can express this mapping mathematically as $\textbf{x} \to \textbf{z} = f(\textbf{x})$, where $\textbf{x} = (\vec{x_{1}}, ..., \vec{x_{N}})$, $\textbf{z} = (\vec{z_{1}}, ..., \vec{z_{N}})$, each $\vec{x_{i}}$ and $\vec{z_{i}}$ is a $d$-dimensional embedding vector, and $N$ is the number of tokens in the sequence. The goal of attention is to learn $f$ from data to solve machine learning tasks such as sequence-to-sequence translation.

## Imports

We will mostly be using PyTorch like NumPy (to create and manipulate tensors), but we will also use one or two modules from its neural networks module, `torch.nn`.

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

## Example Text

Let's start with a sentence.

In [2]:
sentence = "Sorry Dave, I'm afraid I can't do that."

We then [tokenize](https://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) our sentence into a sequence of integer values (one for each word), using an imaginary tokenization function.

In [3]:
def tokenize(text: str, vocab_size: int) -> torch.Tensor:
    """Dummy text tokenizer."""
    words = text.split(" ")
    return torch.randint(0, vocab_size, [len(words)])


VOCAB_SIZE = 20000

tokenized_sentence = tokenize(sentence, VOCAB_SIZE)
n_tokens = len(tokenized_sentence)
tokenized_sentence

tensor([10277, 18871, 14910, 13181,  2829, 19980,  9604, 10053])
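The dummy tokenizer above returns random integers, so the same word can receive a different id on each call. As a minimal sketch of what a deterministic word-level tokenizer might look like, the block below builds a vocabulary from the text itself. The helpers `build_vocab` and `tokenize_with_vocab` are hypothetical names introduced here for illustration; they are not part of the notebook or of PyTorch, and real tokenizers typically work on sub-word units instead.

```python
import torch


def build_vocab(text: str) -> dict[str, int]:
    """Map each distinct whitespace-separated word to an integer id."""
    vocab: dict[str, int] = {}
    for word in text.split(" "):
        if word not in vocab:
            vocab[word] = len(vocab)
    return vocab


def tokenize_with_vocab(text: str, vocab: dict[str, int]) -> torch.Tensor:
    """Look up the id of every word in the text."""
    return torch.tensor([vocab[word] for word in text.split(" ")])


example = "Sorry Dave, I'm afraid I can't do that."
example_vocab = build_vocab(example)
example_ids = tokenize_with_vocab(example, example_vocab)
```

Because the lookup is deterministic, calling `tokenize_with_vocab` twice on the same text gives the same ids, which is what an embedding layer needs in order to learn anything meaningful.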

And embed each token into a vector space using PyTorch's [torch.nn.Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding) module.

In [4]:
EMBEDDING_DIM = 32

embedding_layer = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
embedded_tokens = embedding_layer(tokenized_sentence)
embedded_tokens.shape

torch.Size([8, 32])

These embeddings will need to be learnt when training any model that uses an embedding layer. We can easily compute the number of parameters that need to be learnt.

In [5]:
n_embedding_params = sum(p.numel() for p in embedding_layer.parameters())
print(f"number of embedding parameters = {n_embedding_params:,}")

number of embedding parameters = 640,000
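The figure above is simply the vocabulary size multiplied by the embedding dimension, because an embedding layer is a single lookup table with one row per token id. A small self-contained check, using local names (`vocab_size`, `embedding_dim`, `layer`) chosen here for illustration:

```python
import torch.nn as nn

vocab_size, embedding_dim = 20000, 32
layer = nn.Embedding(vocab_size, embedding_dim)

# An embedding layer holds one (vocab_size x embedding_dim) weight matrix.
n_params = sum(p.numel() for p in layer.parameters())
expected = vocab_size * embedding_dim
```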


## Basic Self-Attention

An approach to computing attention is to express the new context-aware embeddings as a weighted linear combination of the input embeddings - e.g., $\vec{x_{i}} \to \vec{z_{i}} = \sum_{j=1}^{N}{a_{ij} \times \vec{x_{j}}}$.

One sensible approach to computing the weights is to use the vector [dot product](https://en.wikipedia.org/wiki/Dot_product) between the embedding vectors - e.g., $a_{ij} = x_{i}^{T} \cdot x_{j}$. This will lead to weights that are higher for embedding vectors that point in similar directions in the embedding space (i.e., are semantically closer), and lower for those that do not.

In [6]:
attn_weights = torch.empty(n_tokens, n_tokens)

for i in range(n_tokens):
    for j in range(n_tokens):
        attn_weights[i, j] = torch.dot(embedded_tokens[i], embedded_tokens[j])

attn_weights.shape

torch.Size([8, 8])

This calculation can also be computed more efficiently using matrix multiplication.

In [7]:
attn_weights_matmul = torch.matmul(embedded_tokens, embedded_tokens.T)

And we can verify that the two approaches are equivalent.

In [8]:
torch.allclose(attn_weights_matmul, attn_weights, atol=1e-6)

True
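To build some intuition for why the dot product works as a similarity score, here is a toy sketch with three hand-picked 2-dimensional vectors (the values are illustrative assumptions, not taken from the embeddings above). Note that the dot product also grows with vector length, so it measures similarity of direction and magnitude together and is not a distance in the strict sense.

```python
import torch

a = torch.tensor([1.0, 0.0])
b = torch.tensor([0.9, 0.1])   # points in nearly the same direction as a
c = torch.tensor([-1.0, 0.0])  # points in the opposite direction to a

toy_vectors = torch.stack([a, b, c])

# toy_scores[i, j] is the dot product between vector i and vector j.
toy_scores = toy_vectors @ toy_vectors.T
```

The score between `a` and `b` is larger than the score between `a` and `c`, and the matrix is symmetric because the dot product does not depend on the order of its arguments.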

In practice, the weights are scaled by the square root of the embedding dimension, and subsequently renormalised to sum to one across rows using the [softmax function](https://en.wikipedia.org/wiki/Softmax_function). Steps like these make models easier to train by normalising the magnitude of gradients used within algorithms like [stochastic gradient descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent). Refer to [3] for more insight into this and related issues.

In [9]:
attn_weights_norm = F.softmax(attn_weights / math.sqrt(EMBEDDING_DIM), dim=1)

Verify that rows sum to one.

In [10]:
attn_weights_norm.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
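One way to see why the scaling factor is the square root of the embedding dimension is to look at how the spread of raw dot products grows with dimension. The sketch below assumes, as an idealisation, that vector components are independent with zero mean and unit variance; trained embeddings will not satisfy this exactly. Under that assumption the standard deviation of a dot product is roughly $\sqrt{d}$, and dividing by $\sqrt{d}$ brings it back to roughly one.

```python
import math

import torch

torch.manual_seed(0)

n_samples = 10_000
dot_product_stds = {}

for d in (4, 64, 1024):
    q_sample = torch.randn(n_samples, d)
    k_sample = torch.randn(n_samples, d)
    raw = (q_sample * k_sample).sum(dim=1)  # one dot product per row
    scaled = raw / math.sqrt(d)
    dot_product_stds[d] = (raw.std().item(), scaled.std().item())
```

Without the scaling, large dimensions would push the softmax inputs far apart, so that most of the probability mass lands on a single token and the gradients through the softmax become very small.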

We can now compute the final context-aware embeddings.

In [11]:
context_weighted_embeddings = torch.matmul(attn_weights_norm, embedded_tokens)
context_weighted_embeddings.shape

torch.Size([8, 32])

Let's verify that these embeddings are working as we expect by computing one manually.

In [12]:
context_weighted_embeddings_3 = (
    attn_weights_norm[3, 0] * embedded_tokens[0]
    + attn_weights_norm[3, 1] * embedded_tokens[1]
    + attn_weights_norm[3, 2] * embedded_tokens[2]
    + attn_weights_norm[3, 3] * embedded_tokens[3]
    + attn_weights_norm[3, 4] * embedded_tokens[4]
    + attn_weights_norm[3, 5] * embedded_tokens[5]
    + attn_weights_norm[3, 6] * embedded_tokens[6]
    + attn_weights_norm[3, 7] * embedded_tokens[7]
)

context_weighted_embeddings_3

tensor([-0.7159,  0.4557,  0.8079, -0.3945, -0.3668,  0.9979, -0.1918, -1.1513,
        -0.8344, -0.8734,  0.7873, -0.4425, -0.7375, -0.5568,  0.1270,  0.6518,
        -0.1288, -0.5502,  0.5016,  0.0821, -0.1772,  1.1589,  0.9620,  2.0633,
         1.4493,  0.5128, -0.0210, -0.4529,  0.0912, -0.0551,  0.1342, -0.4279],
       grad_fn=<AddBackward0>)

And verify the output against the matrix multiplication computed above.

In [13]:
torch.allclose(context_weighted_embeddings_3, context_weighted_embeddings[3], atol=1e-6)

True

### Causal Masking

You may have noticed that the embedding vector for the first word, $\vec{x_{1}}$, is mapped to a vector $\vec{z_{1}}$ that is a function of embedding vectors for words that come after the first word. This isn't a problem if all we're doing is creating embeddings (or sequences) based on whole passages of text. It does pose a problem, however, if we're trying to develop a model that can generate new sequences given an initial sequence (or prompt). This problem is solved by using causal masking.

Causal masking matrices can be constructed to flag which attention weights should be set to zero so that causal relationships between embeddings aren't broken.

In [14]:
causal_mask = torch.triu(torch.full((n_tokens, n_tokens), True), diagonal=1)
causal_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False]])

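We apply this directly to the raw attention weights, replacing the flagged entries with a large negative number so that the softmax maps them to (effectively) zero.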

In [15]:
causal_attn_weights = attn_weights.masked_fill(causal_mask, -1e10)
causal_attn_weights[0]

tensor([ 4.3661e+01, -1.0000e+10, -1.0000e+10, -1.0000e+10, -1.0000e+10,
        -1.0000e+10, -1.0000e+10, -1.0000e+10], grad_fn=<SelectBackward0>)

And apply scaling and normalisation as before.

In [16]:
causal_attn_weights_norm = F.softmax(
    causal_attn_weights / math.sqrt(EMBEDDING_DIM), dim=1
)
causal_attn_weights_norm.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)

From which we can compute causal context-aware embeddings.

In [17]:
causal_context_weighted_embeddings = torch.matmul(
    causal_attn_weights_norm, embedded_tokens
)
causal_context_weighted_embeddings.shape

torch.Size([8, 32])

The integrity of the causal structure is easily demonstrated.

In [18]:
causal_context_weighted_embeddings_3 = (
    causal_attn_weights_norm[3, 0] * embedded_tokens[0]
    + causal_attn_weights_norm[3, 1] * embedded_tokens[1]
    + causal_attn_weights_norm[3, 2] * embedded_tokens[2]
    + causal_attn_weights_norm[3, 3] * embedded_tokens[3]
)

torch.allclose(
    causal_context_weighted_embeddings_3, causal_context_weighted_embeddings[3]
)

True
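A common alternative to filling masked positions with a large negative constant such as `-1e10` is to fill them with negative infinity, which the softmax maps to exactly zero regardless of how large the other scores are. The following is a minimal self-contained sketch on random data (the sizes `n` and `d` are arbitrary choices for illustration):

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 5, 8
x = torch.randn(n, d)

scores = (x @ x.T) / math.sqrt(d)
mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)

# Positions filled with -inf receive a weight of exactly zero after softmax.
masked_weights = F.softmax(scores.masked_fill(mask, float("-inf")), dim=1)
```

One caveat with this variant: a row in which every position is masked would produce NaNs. That cannot happen with a causal mask, because each token can always attend to itself.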

## Parametrised Self-Attention

Up to this point we have described a basic attention mechanism where the only parameters that can be learnt are for the initial embedding vectors. At this point the system is limited in its ability to adapt the attention mechanism to the task(s) at hand.

### Queries, Keys and Values

We begin by generalising the attention mechanism - let, $\textbf{q} = (\vec{q_{1}}, ..., \vec{q_{N}})$, $\textbf{k} = (\vec{k_{1}}, ..., \vec{k_{N}})$ and $\textbf{v} = (\vec{v_{1}}, ..., \vec{v_{N}})$ be three new sequences representing a query, keys and values respectively. In this setup, the values contain the information that we wish to access via a query that is made on a set of keys (that map to the values), such that the context-aware embeddings can now be computed as,

$$
\vec{z_{i}} = \sum_{j=1}^{N}{a_{ij} \times \vec{v_{j}}}
$$

Where $a_{ij} = q_{i}^{T} \cdot k_{j}$ - i.e., the attention weights now represent the similarity between the query and the keys.

Very often we only have a single sequence to work with, so the model will have to learn how to infer the queries, keys and values from this. We can enable this level of plasticity by defining three $N \times N$ weight matrices, $\textbf{U}_{q}$, $\textbf{U}_{k}$ and $\textbf{U}_{v}$.

In [19]:
u_q = torch.rand(n_tokens, n_tokens)
u_k = torch.rand(n_tokens, n_tokens)
u_v = torch.rand(n_tokens, n_tokens)

From which we can define the query, keys and values as functions of $\textbf{x}$.

In [20]:
q = torch.matmul(u_q, embedded_tokens)
k = torch.matmul(u_k, embedded_tokens)
v = torch.matmul(u_v, embedded_tokens)

q.shape == k.shape == v.shape == embedded_tokens.shape

True

We then recompute our parametrised attention weights using the same steps we used before.

In [21]:
attn_weights_param = torch.empty(n_tokens, n_tokens)

for i in range(n_tokens):
    for j in range(n_tokens):
        attn_weights_param[i, j] = torch.dot(q[i], k[j])

attn_weights_param_norm = F.softmax(
    attn_weights_param / math.sqrt(EMBEDDING_DIM), dim=1
)
context_weighted_embeddings_param = torch.matmul(attn_weights_param_norm, v)

context_weighted_embeddings_param.shape

torch.Size([8, 32])

And verify that the context-aware embeddings behave as we'd expect.

In [22]:
context_weighted_embeddings_param_3 = (
    attn_weights_param_norm[3, 0] * v[0]
    + attn_weights_param_norm[3, 1] * v[1]
    + attn_weights_param_norm[3, 2] * v[2]
    + attn_weights_param_norm[3, 3] * v[3]
    + attn_weights_param_norm[3, 4] * v[4]
    + attn_weights_param_norm[3, 5] * v[5]
    + attn_weights_param_norm[3, 6] * v[6]
    + attn_weights_param_norm[3, 7] * v[7]
)

torch.allclose(
    context_weighted_embeddings_param_3, context_weighted_embeddings_param[3]
)

True
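The weight matrices used above are $N \times N$ and act across token positions, which ties them to a single sequence length. For comparison, the formulation in the original Transformer paper projects along the embedding dimension instead, using $d \times d$ matrices applied to each token, so the same weights can be reused for sequences of any length. A minimal sketch of that variant follows; the names `w_q`, `w_k`, `w_v` and `self_attention` are introduced here for illustration only.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d = 32

# One learnable projection per role, each acting on the embedding dimension.
w_q = nn.Linear(d, d, bias=False)
w_k = nn.Linear(d, d, bias=False)
w_v = nn.Linear(d, d, bias=False)


def self_attention(x: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product self-attention for an input of shape (n_tokens, d)."""
    q, k, v = w_q(x), w_k(x), w_v(x)
    weights = F.softmax(q @ k.T / math.sqrt(d), dim=1)
    return weights @ v


short_out = self_attention(torch.randn(8, d))
long_out = self_attention(torch.randn(20, d))
```

The same three projections handle both the 8-token and the 20-token input, which would not be possible with weight matrices whose shape depends on the number of tokens.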

### Multi-Head Attention

<center><img src="images/attention.png" width="500"/></center>

In what follows we demonstrate how to use the parametrised attention mechanism sketched out above to develop the multi-head attention block that forms the foundation of all transformer architectures. Our aim here is purely didactic - the functions defined below won't yield anything you can train (refer to the full codebase in the `modelling` directory for this), but they do demonstrate how these algorithms are composed.

We start by encapsulating the parametrised attention mechanism within a single function.

In [23]:
def attention(
    query: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    causal_masking: bool = False,
) -> torch.Tensor:
    """Compute single attention head."""
    n_tokens, embedding_dim = query.shape
    attn_weights = torch.matmul(query, keys.T) / math.sqrt(embedding_dim)
    if causal_masking:
        mask = torch.triu(torch.full((n_tokens, n_tokens), True), diagonal=1)
        attn_weights = attn_weights.masked_fill(mask, -1e10)
    attn_weights_norm = attn_weights.softmax(dim=1)
    context_weighted_embeddings = torch.matmul(attn_weights_norm, values)
    return context_weighted_embeddings


attn_head_out = attention(q, k, v)
attn_head_out.shape

torch.Size([8, 32])
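As a sanity check on the single-head computation, the sketch below compares a hand-written version against PyTorch's built-in `F.scaled_dot_product_attention`. This assumes PyTorch 2.0 or later, where that function is available; `manual_attention` is a stand-alone re-implementation written for this comparison, not the notebook's `attention` function.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 8, 32
q_in, k_in, v_in = torch.randn(3, n, d).unbind(0)


def manual_attention(q, k, v, causal=False):
    """Scaled dot-product attention for 2-D (n_tokens, d) inputs."""
    scores = q @ k.T / math.sqrt(q.shape[-1])
    if causal:
        size = q.shape[0]
        mask = torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


# The built-in is documented with a leading batch dimension, so add and remove one.
builtin_out = F.scaled_dot_product_attention(
    q_in.unsqueeze(0), k_in.unsqueeze(0), v_in.unsqueeze(0), is_causal=True
).squeeze(0)
manual_out = manual_attention(q_in, k_in, v_in, causal=True)
```

The two results should agree to within floating-point tolerance, which suggests the built-in applies the same scale-mask-softmax-weight sequence described in this section.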

And use this to define an attention mechanism with multiple attention blocks or 'heads'. This enables models to learn multiple 'contexts' - different sets of keys and values - not unlike how convolutional neural networks use multiple sets of filter banks to detect features at different scales (it is likely that this analogy is what motivated this design).

In [24]:
def multi_head_attention(
    x_q: torch.Tensor,
    x_k: torch.Tensor,
    x_v: torch.Tensor,
    n_heads: int,
    causal_masking: bool = False,
) -> torch.Tensor:
    """Computing attention with multiple heads."""
    n_tokens, embedding_dim = x_q.shape

    u_q = torch.rand(n_heads, n_tokens, n_tokens)
    u_k = torch.rand(n_heads, n_tokens, n_tokens)
    u_v = torch.rand(n_heads, n_tokens, n_tokens)

    # Compute each head separately and concatenate along the embedding dimension.
    attn_head_outputs = torch.concat(
        [
            attention(u_q[h] @ x_q, u_k[h] @ x_k, u_v[h] @ x_v, causal_masking)
            for h in range(n_heads)
        ],
        dim=1,
    )

    w_out = torch.rand(n_heads * embedding_dim, embedding_dim, requires_grad=True)
    return torch.matmul(attn_head_outputs, w_out)


multi_head_attn_out = multi_head_attention(
    embedded_tokens, embedded_tokens, embedded_tokens, n_heads=3
)
multi_head_attn_out.shape

torch.Size([8, 32])

Note that `@` is a shorthand operator for matrix multiplication and that `torch.rand(n_heads, n_tokens, n_tokens)` could also be replaced with `nn.Linear`, as these matrices are equivalent to passing the inputs through a fully-connected (or dense) network layer (here applied along the token dimension and without a bias term).
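For reference, here is a minimal sketch of how the equivalent operation looks with PyTorch's own `nn.MultiheadAttention` module. The sizes are illustrative assumptions. One difference worth noting: the built-in module splits the embedding dimension across the heads (so it must be divisible by the number of heads), whereas the function above gives every head the full embedding dimension.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_tokens, embedding_dim, n_heads = 8, 32, 4

mha = nn.MultiheadAttention(embedding_dim, n_heads, batch_first=True)
x = torch.randn(1, n_tokens, embedding_dim)  # (batch, tokens, features)

# In a boolean mask, True marks positions a token is NOT allowed to attend to.
mha_causal_mask = torch.triu(
    torch.ones(n_tokens, n_tokens, dtype=torch.bool), diagonal=1
)
mha_out, mha_weights = mha(x, x, x, attn_mask=mha_causal_mask)
```

By default the returned attention weights are averaged over the heads, so `mha_weights` has one row per query token and one column per key token.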

## Transformers

<center><img src="images/encoder_decoder.png" width="500"/></center>

We now know enough to assemble the basic transformer architecture, starting with a single layer encoder.

In [25]:
def transformer_encoder_layer(
    src_embedding: torch.Tensor, n_heads: int, causal_masking: bool = False
) -> torch.Tensor:
    """Transformer encoder layer."""
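    # Self-attention sub-layer with a residual connection and layer normalisation.
    x = multi_head_attention(
        src_embedding, src_embedding, src_embedding, n_heads, causal_masking
    )
    x = F.layer_norm(x + src_embedding, x.shape[-1:])

    linear_1 = nn.Linear(EMBEDDING_DIM, 2 * EMBEDDING_DIM)
    linear_2 = nn.Linear(2 * EMBEDDING_DIM, EMBEDDING_DIM)

    # Feed-forward sub-layer: the non-linearity sits between the two linear layers.
    x = x + linear_2(F.relu(linear_1(x)))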

    return x


encoder_output = transformer_encoder_layer(embedded_tokens, n_heads=2)
encoder_output.shape

torch.Size([8, 32])
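For comparison, a minimal sketch of the corresponding built-in module, `nn.TransformerEncoderLayer`, configured with the same embedding dimension, number of heads and feed-forward width as the function above. Dropout is set to zero here only to keep the example deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_tokens, embedding_dim = 8, 32

encoder_layer = nn.TransformerEncoderLayer(
    d_model=embedding_dim,
    nhead=2,
    dim_feedforward=2 * embedding_dim,
    dropout=0.0,
    batch_first=True,
)
encoder_in = torch.randn(1, n_tokens, embedding_dim)  # (batch, tokens, features)
encoder_out = encoder_layer(encoder_in)

# Unlike the didactic function, every weight here is a registered parameter.
n_encoder_params = sum(p.numel() for p in encoder_layer.parameters())
```

The output has the same shape as the input, which is what allows several encoder layers to be stacked one after another.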

And then a single layer decoder.

In [26]:
def transformer_decoder_layer(
    src_embedding: torch.Tensor,
    target_embedding: torch.Tensor,
    n_heads: int,
    causal_masking: bool = False,
) -> torch.Tensor:
    """Transformer decoder layer."""
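    # Self-attention over the target sequence (masked when causal_masking=True).
    x = multi_head_attention(
        target_embedding, target_embedding, target_embedding, n_heads, causal_masking
    )
    x = F.layer_norm(x + target_embedding, x.shape[-1:])
    # Cross-attention: queries come from the decoder, keys and values from the source.
    x = x + multi_head_attention(x, src_embedding, src_embedding, n_heads)
    x = F.layer_norm(x, x.shape[-1:])

    linear_1 = nn.Linear(EMBEDDING_DIM, 2 * EMBEDDING_DIM)
    linear_2 = nn.Linear(2 * EMBEDDING_DIM, EMBEDDING_DIM)

    # Feed-forward sub-layer: the non-linearity sits between the two linear layers.
    x = x + linear_2(F.relu(linear_1(x)))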

    return x


decoder_output = transformer_decoder_layer(embedded_tokens, embedded_tokens, n_heads=2)
decoder_output.shape

torch.Size([8, 32])
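To illustrate what causal masking buys us in a decoder, the sketch below uses PyTorch's `nn.TransformerDecoderLayer` with a causal target mask and checks that changing the final target token leaves the outputs for all earlier tokens untouched. The random `memory` tensor is a stand-in for an encoder output, and all sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_tokens, embedding_dim = 8, 32

decoder_layer = nn.TransformerDecoderLayer(
    d_model=embedding_dim,
    nhead=2,
    dim_feedforward=2 * embedding_dim,
    dropout=0.0,
    batch_first=True,
).eval()

memory = torch.randn(1, n_tokens, embedding_dim)  # stand-in encoder output
target = torch.randn(1, n_tokens, embedding_dim)
tgt_mask = torch.triu(
    torch.ones(n_tokens, n_tokens, dtype=torch.bool), diagonal=1
)

with torch.no_grad():
    out = decoder_layer(target, memory, tgt_mask=tgt_mask)

    # Change only the final target token and run the layer again.
    perturbed = target.clone()
    perturbed[:, -1] += 1.0
    out_perturbed = decoder_layer(perturbed, memory, tgt_mask=tgt_mask)

earlier_unchanged = torch.allclose(out[:, :-1], out_perturbed[:, :-1], atol=1e-5)
last_changed = not torch.allclose(out[:, -1], out_perturbed[:, -1], atol=1e-5)
```

Both flags should come out as `True`: earlier positions cannot see the modified token, while the final position can. This is the property that makes it safe to train a generative model on whole sequences at once.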

## Where to go from Here

Now that we have a basic insight into attention and transformers we will be using PyTorch's `torch.nn.TransformerEncoderLayer` and `torch.nn.TransformerDecoderLayer` modules in subsequent notebooks, to compose and train transformer-based models for tackling NLP tasks (e.g., text generation).

## Useful Resources

1. [Introduction to PyTorch](https://alexioannides.com/data-science-and-ml-notebook/pytorch/)
2. [PyTorch docs](https://pytorch.org/docs/stable/index.html)
3. [The Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)