# From Attention Mechanisms to Transformers

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

Start with some text.

In [2]:
sentence = "Sorry Dave, I'm afraid I can't do that."

Tokenise the text.

In [3]:
def tokenise(text: str) -> torch.Tensor:
    words = text.split(" ")
    return torch.randint(0, 20000, [len(words)])


tokenised_sentence = tokenise(sentence)
tokenised_sentence

tensor([14062, 19814,  1096,  2874,  9866, 17985,  3681,   685])

Convert from tokens to embeddings.

In [4]:
vocab_size = 20000
embedding_dim = 32

embedding_layer = nn.Embedding(vocab_size, embedding_dim)

embedded_tokens = embedding_layer(tokenised_sentence)
embedded_tokens.shape

torch.Size([8, 32])

## Basic Self-Attention

Compute the self-similarity matrix between all the embedding vectors.

In [5]:
n_tokens = len(tokenised_sentence)

attn_weights = torch.empty(n_tokens, n_tokens)

for i in range(n_tokens):
    for j in range(n_tokens):
        attn_weights[i, j] = torch.dot(embedded_tokens[i], embedded_tokens[j])

attn_weights.shape

torch.Size([8, 8])

This can also be done with matrix multiplication.

In [6]:
attn_weights_matmul = torch.matmul(embedded_tokens, embedded_tokens.T)

Which is easy to verify.

In [7]:
torch.allclose(attn_weights_matmul, attn_weights)

False

Renormalise the rows, so that they sum to one.

In [8]:
attn_weights_norm = F.softmax(attn_weights, dim=1)

Verify the output.

In [9]:
attn_weights_norm.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<SumBackward1>)

Compute context-aware embeddings.

In [10]:
context_weighted_embeddings = torch.matmul(attn_weights_norm, embedded_tokens)

context_weighted_embeddings.shape

torch.Size([8, 32])

Demonstrate manual computation.

In [11]:
context_weighted_embeddings_3 = (
    attn_weights_norm[3, 0] * embedded_tokens[0] +
    attn_weights_norm[3, 1] * embedded_tokens[1] +
    attn_weights_norm[3, 2] * embedded_tokens[2] +
    attn_weights_norm[3, 3] * embedded_tokens[3] +
    attn_weights_norm[3, 4] * embedded_tokens[4] +
    attn_weights_norm[3, 5] * embedded_tokens[5] +
    attn_weights_norm[3, 6] * embedded_tokens[6] +
    attn_weights_norm[3, 7] * embedded_tokens[7]
)

context_weighted_embeddings_3

tensor([-0.5963, -0.2296,  0.3928, -1.0474, -1.5370, -0.4205, -0.5213,  0.6360,
         0.2418, -0.9080, -0.4428, -0.5328, -0.6034, -1.2227,  0.2397,  0.3942,
        -0.7501,  0.3245, -0.0228,  0.1374, -0.2983, -0.2089,  0.6657, -0.1347,
        -0.3584,  1.4458,  1.4827, -0.3337, -1.1085,  0.4091,  1.9441, -0.3650],
       grad_fn=<AddBackward0>)

And verify the output.

In [12]:
torch.allclose(context_weighted_embeddings_3, context_weighted_embeddings[3])

True

### Causal Masking

Build a causal masking matrix.

In [13]:
causal_mask = torch.triu(torch.full((n_tokens, n_tokens), True), diagonal=1)
causal_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False]])

Apply the mask to the attention weights.

In [14]:
causal_attn_weights = attn_weights.masked_fill(causal_mask, 0.)

causal_attn_weights

tensor([[ 6.2212e+01,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-1.2375e+01,  3.6565e+01,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 7.0231e+00, -6.2132e+00,  2.8644e+01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 3.4108e+00, -2.1306e+00,  1.2248e+00,  1.9428e+01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-2.4254e+00,  1.3699e+00, -1.3215e+00, -4.4710e+00,  2.3533e+01,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-9.7495e+00,  1.0044e+01, -9.8557e+00,  5.0011e+00, -3.6843e+00,
          3.6312e+01,  0.0000e+00,  0.0000e+00],
        [-7.0159e+00, -6.5227e+00, -4.4904e+00, -1.0484e-02,  3.6967e+00,
         -6.4772e-01,  3.0431e+01,  0.0000e+00],
        [-5.0422e-01,  2.7397e+00, -4.4311e+00, -9.8397e-01, -7.0092e+00,
          1.6973e-01, -7.5941e+00,  2.7720e+01]],
       grad_fn=

Apply normalisation.

In [15]:
causal_attn_weights_norm = F.softmax(causal_attn_weights, dim=1)
causal_attn_weights_norm.sum(dim=1)

tensor([1., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<SumBackward1>)

Compute causal context-aware embeddings.

In [16]:
causal_context_weighted_embeddings = torch.matmul(causal_attn_weights_norm, embedded_tokens)

causal_context_weighted_embeddings.shape

torch.Size([8, 32])

Demonstrate causal structure by computing manually.

In [17]:
causal_context_weighted_embeddings_3 = (
    causal_attn_weights_norm[3, 0] * embedded_tokens[0] +
    causal_attn_weights_norm[3, 1] * embedded_tokens[1] +
    causal_attn_weights_norm[3, 2] * embedded_tokens[2] +
    causal_attn_weights_norm[3, 3] * embedded_tokens[3]
)

torch.allclose(causal_context_weighted_embeddings_3, causal_context_weighted_embeddings[3])

True

## Parametrised Self-Attention

Making attention ameanable to learning.

### Queries, Keys and Values

Let...

In [18]:
u_q = torch.rand(n_tokens, n_tokens, requires_grad=True)
u_k = torch.rand(n_tokens, n_tokens, requires_grad=True)
u_v = torch.rand(n_tokens, n_tokens, requires_grad=True)

Such that,

In [19]:
q = torch.matmul(u_q, embedded_tokens)
k = torch.matmul(u_k, embedded_tokens)
v = torch.matmul(u_v, embedded_tokens)

q.shape == embedded_tokens.shape

True

Recompute context aware embeddings.

In [20]:
attn_weights_param = torch.empty(n_tokens, n_tokens)

for i in range(n_tokens):
    for j in range(n_tokens):
        attn_weights_param[i, j] = torch.dot(q[i], k[j])

attn_weights_param_norm = F.softmax(attn_weights_param, dim=1)
context_weighted_embeddings_param = torch.matmul(attn_weights_param_norm, v)

context_weighted_embeddings_param.shape

torch.Size([8, 32])

And verify.

In [21]:
context_weighted_embeddings_param_3 = (
    attn_weights_param_norm[3, 0] * v[0] +
    attn_weights_param_norm[3, 1] * v[1] +
    attn_weights_param_norm[3, 2] * v[2] +
    attn_weights_param_norm[3, 3] * v[3] +
    attn_weights_param_norm[3, 4] * v[4] +
    attn_weights_param_norm[3, 5] * v[5] +
    attn_weights_param_norm[3, 6] * v[6] +
    attn_weights_param_norm[3, 7] * v[7]
)

torch.allclose(context_weighted_embeddings_param_3, context_weighted_embeddings_param[3])

True

### Multi-Head Attention

Putting it all together.

In [22]:
def attention(
        query: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        causal_masking: bool = False
    ) -> torch.Tensor:
    """Compute single attention head."""
    n_tokens, embedding_dim = query.shape
    attn_weights_norm = torch.matmul(query, keys.T).softmax(dim=1)
    if causal_masking:
        mask = torch.triu(torch.full((n_tokens, n_tokens), True), diagonal=1)
        attn_weights_norm = attn_weights_norm.masked_fill(mask, 0.)
    context_weighted_embeddings = torch.matmul(attn_weights_norm, values)
    return context_weighted_embeddings


attn_head_out = attention(q, k, v)
attn_head_out.shape

torch.Size([8, 32])

Multiple heads...

In [23]:
def multi_head_attention(
    x_q: torch.Tensor,
    x_k: torch.Tensor,
    x_v: torch.Tensor,
    n_heads: int,
    causal_masking: bool = False
) -> torch.Tensor:
    """Computing multiple attention heads."""
    n_tokens, embedding_dim = embedded_tokens.shape

    u_q = torch.rand(n_heads, n_tokens, n_tokens, requires_grad=True)
    u_k = torch.rand(n_heads, n_tokens, n_tokens, requires_grad=True)
    u_v = torch.rand(n_heads, n_tokens, n_tokens, requires_grad=True)
    w_out = torch.rand(n_heads*embedding_dim, embedding_dim, requires_grad=True)

    attn_head_outputs = torch.concat(
        [attention(u_q[h] @ x_q, u_k[h] @ x_k, u_v[h] @ x_v) for h in range(n_heads)],
         dim=1
    )

    return torch.matmul(attn_head_outputs, w_out)


multi_head_attn_out = multi_head_attention(embedded_tokens, embedded_tokens, embedded_tokens, n_heads=3)
multi_head_attn_out.shape

torch.Size([8, 32])

## Transformers

Assembling a transforner decoder network to demonstrate the use of multi-head attention blocks.

In [24]:
def transformer_decoder_layer(
        src_embedding: torch.Tensor,
        target_embedding: torch.Tensor,
        n_heads: int,
        causal_masking: bool = False
) -> torch.Tensor:
    """Assemble a transformer decoder layer from """
    x1 = multi_head_attention(target_embedding, target_embedding, target_embedding, n_heads)
    x1 = F.layer_norm(x1 + target_embedding, x1.shape)
    x2 = multi_head_attention(src_embedding, src_embedding, x1, n_heads)
    x2 = F.layer_norm(x2 + x1, x1.shape)

    linear_1 = nn.Linear(embedding_dim, 2*embedding_dim)
    linear_2 = nn.Linear(2*embedding_dim, embedding_dim)

    return F.relu(linear_2(linear_1(x2)))


transformer_output = transformer_decoder_layer(embedded_tokens, embedded_tokens, n_heads=2)
transformer_output.shape

torch.Size([8, 32])