# Attention Mechanisms and Transformers

This notebook demonstrates how attention mechanisms can be implemented and how they are used within transformer architectures. We will develop an understanding from first principles using PyTorch (no prior knowledge is required).

Attention mechanisms in deep learning aim to map a basic sequence of [word embeddings](https://en.wikipedia.org/wiki/Word_embedding) into another sequence of embeddings that represent each word conditioned on the 'derived context' of that word within the text.

We could express this mapping mathematically as, $\textbf{x} \to \textbf{z} = f(\textbf{x})$, where $\textbf{x} = (\vec{x_{1}}, ..., \vec{x_{N}})$, $\textbf{z} = (\vec{z_{1}}, ..., \vec{z_{N}})$, $\vec{x}$ and $\vec{z}$ are individual embedding vectors, and $N$ is the number of tokens in the sequence. The goal of attention is to learn $f$ from data.

## Importing PyTorch

We will mostly be using PyTorch like NumPy (to create and manipulate tensors), but we will also use one or two modules from its neural networks module.

In [82]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

## Example Inputs

Let's start with a sentence.

In [48]:
sentence = "Sorry Dave, I'm afraid I can't do that."

We then [tokenise](https://en.wikipedia.org/wiki/Lexical_analysis#Tokenization) our sentence into a sequence of integer values (one for each word), using a dummy tokenisation function that assigns a random integer to each word.

In [49]:
def tokenise(text: str, vocab_size: int) -> torch.Tensor:
    """Dummy text tokeniser."""
    words = text.split(" ")
    return torch.randint(0, vocab_size, [len(words)])


VOCAB_SIZE = 20000

tokenised_sentence = tokenise(sentence, VOCAB_SIZE)
n_tokens = len(tokenised_sentence)

tokenised_sentence

tensor([17733, 13358,  3783,  4256,  6665, 10168,  2821, 16004])
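The dummy tokeniser returns random integers, so the same word can map to different ids on different calls. As a minimal sketch of what a deterministic word-level tokeniser could look like, the example below builds a vocabulary from the text itself. `build_vocab` and `tokenise_with_vocab` are hypothetical helper names introduced for illustration, and real tokenisers typically work on sub-word units rather than whitespace-separated words.

```python
from typing import Dict

import torch


def build_vocab(text: str) -> Dict[str, int]:
    """Map each distinct whitespace-separated word to an integer id (hypothetical helper)."""
    vocab: Dict[str, int] = {}
    for word in text.split(" "):
        if word not in vocab:
            vocab[word] = len(vocab)
    return vocab


def tokenise_with_vocab(text: str, vocab: Dict[str, int]) -> torch.Tensor:
    """Look up the id of every word in the text (hypothetical helper)."""
    return torch.tensor([vocab[word] for word in text.split(" ")])


example_sentence = "Sorry Dave, I'm afraid I can't do that."
example_vocab = build_vocab(example_sentence)
example_ids = tokenise_with_vocab(example_sentence, example_vocab)
print(example_ids)
```

With this approach a repeated word always receives the same id, which is what an embedding layer needs in order to learn one vector per word.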

We then embed each token as a vector in an embedding space.

In [50]:
EMBEDDING_DIM = 32

embedding_layer = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)

embedded_tokens = embedding_layer(tokenised_sentence)
embedded_tokens.shape

torch.Size([8, 32])

Note that these embeddings will need to be learnt when training any model that uses an embedding layer. We can compute the number of parameters in this layer as follows,

In [52]:
n_embedding_params = sum(p.numel() for p in embedding_layer.parameters())
print(f"number of embedding parameters = {n_embedding_params:,}")

number of embedding parameters = 640,000
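The count follows directly from the shape of the embedding table: one row of `EMBEDDING_DIM` weights per vocabulary entry, so 20,000 × 32 = 640,000. A small self-contained check, assuming the same sizes as above (the lower-case names are local to this sketch):

```python
import torch.nn as nn

vocab_size = 20000
embedding_dim = 32

layer = nn.Embedding(vocab_size, embedding_dim)

# The only parameter is the (vocab_size, embedding_dim) lookup table.
expected_params = vocab_size * embedding_dim
actual_params = sum(p.numel() for p in layer.parameters())
print(expected_params, actual_params)
```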


## Basic Self-Attention

One approach to computing attention is to express the new 'contextual embeddings' as a weighted linear combination of the input embeddings - e.g., $\vec{x_{i}} \to \vec{z_{i}} = \sum_{j=1}^{N}{a_{ij} \times \vec{x_{j}}}$. We can then focus on strategies for computing the weights.

A sensible approach to computing the weights is to use the vector [dot product](https://en.wikipedia.org/wiki/Dot_product) between the embedding vectors - e.g., $a_{ij} = \vec{x_{i}}^{T} \cdot \vec{x_{j}}$. This will lead to weights that are higher for embedding vectors that are closer to one another in the embedding space (i.e., are semantically closer), and vice versa. We can compute these weights as follows,

In [53]:
attn_weights = torch.empty(n_tokens, n_tokens)

for i in range(n_tokens):
    for j in range(n_tokens):
        attn_weights[i, j] = torch.dot(embedded_tokens[i], embedded_tokens[j])

attn_weights.shape

torch.Size([8, 8])

The attention weights can also be computed more efficiently using matrix multiplication.

In [54]:
attn_weights_matmul = torch.matmul(embedded_tokens, embedded_tokens.T)

And we can verify that the two approaches are equivalent.

In [87]:
torch.allclose(attn_weights_matmul, attn_weights, atol=1e-6)

True
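To make the link between dot products and similarity concrete, here is a small sketch with hand-picked two-dimensional vectors (illustrative values, not taken from the embedding layer above). Vectors pointing in nearly the same direction produce a large positive score, while vectors pointing in opposite directions produce a negative one.

```python
import torch

toy_vectors = torch.tensor([
    [1.0, 0.0],   # reference vector
    [0.9, 0.1],   # points in nearly the same direction as the reference
    [-1.0, 0.0],  # points in the opposite direction
])

# Entry [i, j] is the dot product between vectors i and j.
toy_scores = toy_vectors @ toy_vectors.T
print(toy_scores[0])
```

Note that the dot product also grows with vector length, so it measures alignment weighted by magnitude rather than distance alone.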

In practice, the weights are scaled by the square root of the embedding dimension and then normalised so that each row sums to one using the [softmax function](https://en.wikipedia.org/wiki/Softmax_function). Steps like these make models easier to train by keeping the magnitude of the gradients used within [stochastic gradient descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) under control. For more insight into this, refer to [2].

In [85]:
attn_weights_norm = F.softmax(attn_weights / math.sqrt(EMBEDDING_DIM), dim=1)

Verify that the rows sum to one.

In [86]:
attn_weights_norm.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)
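One way to see why the scaling helps: if the components of two vectors are independent with zero mean and unit variance, their dot product has a variance equal to the embedding dimension, so unscaled scores grow with the dimension and push the softmax towards a one-hot output. Dividing by the square root of the dimension brings the variance back to roughly one. The sketch below checks this empirically with standard normal vectors, which is an illustrative assumption (learnt embeddings need not follow this distribution).

```python
import math

import torch

torch.manual_seed(0)

dim = 32
n_samples = 10_000

a = torch.randn(n_samples, dim)
b = torch.randn(n_samples, dim)

# Row-wise dot products between paired random vectors.
raw_scores = (a * b).sum(dim=1)
scaled_scores = raw_scores / math.sqrt(dim)

print(f"variance of raw scores    = {raw_scores.var().item():.2f}")     # close to dim
print(f"variance of scaled scores = {scaled_scores.var().item():.2f}")  # close to 1
```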

We can now compute the final context-aware embeddings.

In [91]:
context_weighted_embeddings = torch.matmul(attn_weights_norm, embedded_tokens)
context_weighted_embeddings.shape

torch.Size([8, 32])

Let's verify that the embeddings are working as we expect by computing one manually.

In [11]:
context_weighted_embeddings_3 = (
    attn_weights_norm[3, 0] * embedded_tokens[0] +
    attn_weights_norm[3, 1] * embedded_tokens[1] +
    attn_weights_norm[3, 2] * embedded_tokens[2] +
    attn_weights_norm[3, 3] * embedded_tokens[3] +
    attn_weights_norm[3, 4] * embedded_tokens[4] +
    attn_weights_norm[3, 5] * embedded_tokens[5] +
    attn_weights_norm[3, 6] * embedded_tokens[6] +
    attn_weights_norm[3, 7] * embedded_tokens[7]
)

context_weighted_embeddings_3

tensor([-0.5963, -0.2296,  0.3928, -1.0474, -1.5370, -0.4205, -0.5213,  0.6360,
         0.2418, -0.9080, -0.4428, -0.5328, -0.6034, -1.2227,  0.2397,  0.3942,
        -0.7501,  0.3245, -0.0228,  0.1374, -0.2983, -0.2089,  0.6657, -0.1347,
        -0.3584,  1.4458,  1.4827, -0.3337, -1.1085,  0.4091,  1.9441, -0.3650],
       grad_fn=<AddBackward0>)

And we can verify the output against the matrix multiplication computed above.

In [88]:
torch.allclose(context_weighted_embeddings_3, context_weighted_embeddings[3], atol=1e-6)

True

### Causal Masking

You may have noticed that the embedding vector for the first word, $\vec{x_{1}}$, is mapped to a vector, $\vec{z_{1}}$, that is a function of the embedding vectors for words that come after the first word. This isn't a problem if all we're interested in doing is creating embeddings (or sequences) based on whole passages of text. It does pose a problem, however, if we're trying to develop a model that can generate new sequences given an initial sequence (or prompt). This problem is solved by using causal masking.

Causal masking matrices can be constructed to flag which attention weights should be set to zero as they contain information that would break the causal relationships between embeddings. For example, in our setup we would have,

In [13]:
causal_mask = torch.triu(torch.full((n_tokens, n_tokens), True), diagonal=1)
causal_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False]])

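We apply the mask to the raw attention weights by filling the flagged positions with $-\infty$, so that they become exactly zero once the softmax is applied (filling with zero instead would still leave them with a non-zero weight, because $e^{0} = 1$).

In [90]:
causal_attn_weights = attn_weights.masked_fill(causal_mask, float("-inf"))
causal_attn_weights[0]

tensor([36.5003,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
       grad_fn=<SelectBackward0>)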

With the scaling and normalisation applied as before.

In [94]:
causal_attn_weights_norm = F.softmax(causal_attn_weights / math.sqrt(EMBEDDING_DIM), dim=1)
causal_attn_weights_norm.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)

From which we can compute causal context-aware embeddings.

In [95]:
causal_context_weighted_embeddings = torch.matmul(causal_attn_weights_norm, embedded_tokens)
causal_context_weighted_embeddings.shape

torch.Size([8, 32])

And we can demonstrate the causal structure explicitly.

In [17]:
causal_context_weighted_embeddings_3 = (
    causal_attn_weights_norm[3, 0] * embedded_tokens[0] +
    causal_attn_weights_norm[3, 1] * embedded_tokens[1] +
    causal_attn_weights_norm[3, 2] * embedded_tokens[2] +
    causal_attn_weights_norm[3, 3] * embedded_tokens[3]
)

torch.allclose(causal_context_weighted_embeddings_3, causal_context_weighted_embeddings[3])

True
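A quick way to test causality is to perturb a later token and confirm that the earlier outputs do not move. The sketch below does this for a small random input, masking the raw scores with `-inf` before the softmax so that masked positions receive exactly zero weight. `causal_self_attention` is a hypothetical helper written for this check, not part of PyTorch.

```python
import math

import torch
import torch.nn.functional as F


def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Unparametrised causal self-attention (hypothetical helper for this check)."""
    n, dim = x.shape
    scores = x @ x.T / math.sqrt(dim)
    future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    # exp(-inf) = 0, so masked positions receive zero weight after the softmax.
    weights = F.softmax(scores.masked_fill(future, float("-inf")), dim=1)
    return weights @ x


torch.manual_seed(0)
x = torch.randn(5, 4)

x_perturbed = x.clone()
x_perturbed[-1] += 10.0  # change only the final token

out = causal_self_attention(x)
out_perturbed = causal_self_attention(x_perturbed)

# Every row before the final one should be unaffected by the change.
print(torch.allclose(out[:-1], out_perturbed[:-1]))
```

The first output row is also a useful sanity check: the first token can only attend to itself, so its output should equal its input.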

## Parametrised Self-Attention

Making attention amenable to learning.

### Queries, Keys and Values

Let...

In [18]:
u_q = torch.rand(n_tokens, n_tokens, requires_grad=True)
u_k = torch.rand(n_tokens, n_tokens, requires_grad=True)
u_v = torch.rand(n_tokens, n_tokens, requires_grad=True)

Such that,

In [19]:
q = torch.matmul(u_q, embedded_tokens)
k = torch.matmul(u_k, embedded_tokens)
v = torch.matmul(u_v, embedded_tokens)

q.shape == embedded_tokens.shape

True

Recompute the context-aware embeddings.

In [20]:
attn_weights_param = torch.empty(n_tokens, n_tokens)

for i in range(n_tokens):
    for j in range(n_tokens):
        attn_weights_param[i, j] = torch.dot(q[i], k[j])

attn_weights_param_norm = F.softmax(attn_weights_param / math.sqrt(EMBEDDING_DIM), dim=1)
context_weighted_embeddings_param = torch.matmul(attn_weights_param_norm, v)

context_weighted_embeddings_param.shape

torch.Size([8, 32])

And verify.

In [21]:
context_weighted_embeddings_param_3 = (
    attn_weights_param_norm[3, 0] * v[0] +
    attn_weights_param_norm[3, 1] * v[1] +
    attn_weights_param_norm[3, 2] * v[2] +
    attn_weights_param_norm[3, 3] * v[3] +
    attn_weights_param_norm[3, 4] * v[4] +
    attn_weights_param_norm[3, 5] * v[5] +
    attn_weights_param_norm[3, 6] * v[6] +
    attn_weights_param_norm[3, 7] * v[7]
)

torch.allclose(context_weighted_embeddings_param_3, context_weighted_embeddings_param[3])

True
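The mixing matrices above have shape `(n_tokens, n_tokens)`, which ties the number of parameters to the sequence length. In the original transformer formulation the learnt projections instead act on the embedding dimension, i.e. $Q = XW_{Q}$, $K = XW_{K}$, $V = XW_{V}$, so the same weights can be reused for sequences of any length. A minimal sketch of that variant, using bias-free `nn.Linear` layers and a random stand-in for the embedded tokens (the names `w_q`, `w_k`, `w_v` and `x` are illustrative):

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

n_tokens, embedding_dim = 8, 32
x = torch.randn(n_tokens, embedding_dim)  # stand-in for embedded tokens

# One learnable (embedding_dim x embedding_dim) projection each for Q, K and V.
w_q = nn.Linear(embedding_dim, embedding_dim, bias=False)
w_k = nn.Linear(embedding_dim, embedding_dim, bias=False)
w_v = nn.Linear(embedding_dim, embedding_dim, bias=False)

q, k, v = w_q(x), w_k(x), w_v(x)

weights = F.softmax(q @ k.T / math.sqrt(embedding_dim), dim=1)
z = weights @ v
print(z.shape)

# The same projections apply unchanged to a longer sequence.
x_longer = torch.randn(n_tokens + 4, embedding_dim)
print(w_q(x_longer).shape)
```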

### Multi-Head Attention

Putting it all together.

In [22]:
def attention(
        query: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        causal_masking: bool = False
    ) -> torch.Tensor:
    """Compute single attention head."""
    n_tokens, embedding_dim = query.shape
    # Scaled dot-product scores between every query and every key.
    attn_weights = torch.matmul(query, keys.T) / math.sqrt(embedding_dim)
    if causal_masking:
        # Mask future positions before the softmax so that each row still sums to one.
        mask = torch.triu(torch.full((n_tokens, n_tokens), True), diagonal=1)
        attn_weights = attn_weights.masked_fill(mask, float("-inf"))
    attn_weights_norm = attn_weights.softmax(dim=1)
    context_weighted_embeddings = torch.matmul(attn_weights_norm, values)
    return context_weighted_embeddings


attn_head_out = attention(q, k, v)
attn_head_out.shape

torch.Size([8, 32])

Multiple heads...

In [23]:
def multi_head_attention(
    x_q: torch.Tensor,
    x_k: torch.Tensor,
    x_v: torch.Tensor,
    n_heads: int,
    causal_masking: bool = False
) -> torch.Tensor:
    """Computing multiple attention heads."""
    # Take the shapes from the inputs rather than from a global variable.
    n_tokens, embedding_dim = x_q.shape

    # Randomly initialised parameters (these would be learnt in a real model).
    u_q = torch.rand(n_heads, n_tokens, n_tokens, requires_grad=True)
    u_k = torch.rand(n_heads, n_tokens, n_tokens, requires_grad=True)
    u_v = torch.rand(n_heads, n_tokens, n_tokens, requires_grad=True)
    w_out = torch.rand(n_heads*embedding_dim, embedding_dim, requires_grad=True)

    # Compute each head, passing the masking flag through, and concatenate the results.
    attn_head_outputs = torch.concat(
        [
            attention(u_q[h] @ x_q, u_k[h] @ x_k, u_v[h] @ x_v, causal_masking)
            for h in range(n_heads)
        ],
        dim=1
    )

    return torch.matmul(attn_head_outputs, w_out)


multi_head_attn_out = multi_head_attention(embedded_tokens, embedded_tokens, embedded_tokens, n_heads=3)
multi_head_attn_out.shape

torch.Size([8, 32])
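For comparison, PyTorch ships a built-in `nn.MultiheadAttention` module. It differs from the function above in that it projects along the embedding dimension and splits that dimension across the heads (so `embed_dim` must be divisible by `num_heads`), rather than giving every head the full embedding. A minimal usage sketch with a random stand-in for the embedded tokens:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_tokens, embedding_dim, n_heads = 8, 32, 4

# Default layout is (sequence, batch, embedding); we use a batch of one.
x = torch.randn(n_tokens, 1, embedding_dim)

mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=n_heads)

# Boolean mask: True marks positions that may NOT be attended to.
causal_mask = torch.triu(torch.ones(n_tokens, n_tokens, dtype=torch.bool), diagonal=1)

mha_out, mha_weights = mha(x, x, x, attn_mask=causal_mask)
print(mha_out.shape)      # (sequence, batch, embedding)
print(mha_weights.shape)  # (batch, sequence, sequence), averaged over heads
```

The returned weights make it easy to confirm that the masked (future) positions receive no attention.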

## Transformers

Assembling a transformer decoder network to demonstrate the use of multi-head attention blocks.

In [24]:
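def transformer_decoder_layer(
        src_embedding: torch.Tensor,
        target_embedding: torch.Tensor,
        n_heads: int,
        causal_masking: bool = False
) -> torch.Tensor:
    """Assemble a transformer decoder layer from multi-head attention blocks."""
    embedding_dim = target_embedding.shape[1]

    # Self-attention over the target sequence, then residual connection and layer norm.
    x1 = multi_head_attention(
        target_embedding, target_embedding, target_embedding, n_heads, causal_masking
    )
    x1 = F.layer_norm(x1 + target_embedding, (embedding_dim,))

    # Cross-attention: queries from the target, keys and values from the source.
    x2 = multi_head_attention(x1, src_embedding, src_embedding, n_heads)
    x2 = F.layer_norm(x2 + x1, (embedding_dim,))

    # Position-wise feed-forward network with the non-linearity between the two layers.
    linear_1 = nn.Linear(embedding_dim, 2*embedding_dim)
    linear_2 = nn.Linear(2*embedding_dim, embedding_dim)
    x3 = linear_2(F.relu(linear_1(x2)))

    return F.layer_norm(x3 + x2, (embedding_dim,))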


transformer_output = transformer_decoder_layer(embedded_tokens, embedded_tokens, n_heads=2)
transformer_output.shape

torch.Size([8, 32])
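PyTorch also provides `nn.TransformerDecoderLayer`, which bundles masked self-attention, cross-attention over the encoder output, and a position-wise feed-forward network, each followed by a residual connection and layer normalisation. The sketch below runs it on random stand-ins for the target and source embeddings. The sizes are chosen to match this notebook, and setting `dim_feedforward` to twice the embedding dimension mirrors the hidden width used above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_tokens, embedding_dim, n_heads = 8, 32, 2

# Default layout is (sequence, batch, embedding); we use a batch of one.
target = torch.randn(n_tokens, 1, embedding_dim)
memory = torch.randn(n_tokens, 1, embedding_dim)  # stand-in for encoder output

decoder_layer = nn.TransformerDecoderLayer(
    d_model=embedding_dim,
    nhead=n_heads,
    dim_feedforward=2 * embedding_dim,
    dropout=0.0,  # disabled so the output is deterministic
)

# Boolean mask: True marks future positions that are hidden from each token.
target_mask = torch.triu(torch.ones(n_tokens, n_tokens, dtype=torch.bool), diagonal=1)

decoder_out = decoder_layer(target, memory, tgt_mask=target_mask)
print(decoder_out.shape)
```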

## Useful Resources

1. [PyTorch docs](https://pytorch.org/docs/stable/index.html)
2. [The Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/)