# Single Self Attention Head
Given a set of tokens, self-attention works out how much each token in a sequence depends on the others. This notebook starts with the most basic case of a decoder attention block (one in which a token cannot see the tokens that come after it in the sequence) and gradually improves it into one that can be used in a Transformer.

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)

<torch._C.Generator at 0x7f2044171d10>

# Weakest Case
The code below creates an example input of shape `[B, T, C]`, where `B` is the batch size (the number of sequences in a batch), `T` is the number of tokens in each sequence, and `C` is the dimensionality of each token.

In a decoder attention head the model can't see into the future, so each token may only gather information from itself and the tokens that come before it. In this weakest case, this is achieved by averaging inside a pair of `for` loops.

Each row `t` of `x_bag_of_words[0]` contains the average of rows `0` to `t` (inclusive) of `x[0]`, so the final row is the average of all the rows of `x[0]`.

In [56]:
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

x_bag_of_words = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        x_prev = x[b, :t+1] # (t+1, C): rows 0..t inclusive
        x_bag_of_words[b, t] = torch.mean(x_prev, dim=0)

display(x[0])
display(x_bag_of_words[0])

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
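As a cross-check on the loop, the same running average can be computed with a cumulative sum along the token dimension, divided by the number of rows seen so far. This is a minimal sketch using `torch.cumsum`; it is an alternative that the rest of this notebook does not use, and the function names `running_mean_cumsum` and `running_mean_loop` are hypothetical.

```python
import torch

def running_mean_cumsum(x):
    # x: (B, T, C). Sum along the token dimension, then divide row t by t + 1.
    T = x.shape[1]
    counts = torch.arange(1, T + 1, dtype=x.dtype).view(1, T, 1)  # (1, T, 1)
    return torch.cumsum(x, dim=1) / counts

def running_mean_loop(x):
    # Same computation as the cell above, kept here for comparison.
    B, T, C = x.shape
    out = torch.zeros((B, T, C))
    for b in range(B):
        for t in range(T):
            out[b, t] = torch.mean(x[b, :t + 1], dim=0)
    return out

torch.manual_seed(0)
x_demo = torch.randn(4, 8, 2)
print(torch.allclose(running_mean_cumsum(x_demo), running_mean_loop(x_demo), atol=1e-6))  # True
```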

# Improved Case
The above implementation relies on Python `for` loops, which are slow. We can make the code much more efficient using matrix multiplication. The improved version is built up gradually across the following cells.

The code below shows a basic matrix multiplication $C = A \times B$. Because every entry of $A$ is 1, each row of $C$ contains the column sums of $B$.

In [57]:
# Basic mat mul
torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b

print("a=")
print(a)
print('---')
print("b=")
print(b)
print('---')
print("c=")
print(c)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
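Writing out the first row of `c` by hand shows where the numbers come from: each entry is the sum of one column of `b`, because every entry of `a` is 1. Since all rows of `a` are identical, every row of `c` is the same.

$$
c_{0,0} = 1\cdot 2 + 1\cdot 6 + 1\cdot 6 = 14, \qquad c_{0,1} = 1\cdot 7 + 1\cdot 4 + 1\cdot 5 = 16
$$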


Recall that in the `for` loop implementation, we only want to incorporate data from the token being processed and those that come before it. To achieve this with matrix multiplication, we use a lower-triangular matrix. In the code below, $C$ is again given by $A \times B$, but because $A$ is lower-triangular, row $i$ of $C$ only sums rows $0$ to $i$ of $B$.


In [58]:
# Mat mul with lower triangular matrix
torch.manual_seed(42)
new_a = torch.tril(torch.ones(3, 3))
new_c = new_a @ b

print("new a=")
print(new_a)
print('---')
print("b=")
print(b)
print('---')
print("old c=")
print(c)
print('---')
print("new c=")
print(new_c)

new a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
old c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
---
new c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
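The lower-triangular product is exactly a cumulative sum down the rows of `b`: row $i$ of `new_c` is $\sum_{j \le i} b_j$ (for example $2 + 6 = 8$ and $7 + 4 = 11$ in the second row). The short check below reuses the same `b` values; the comparison against `torch.cumsum` is an illustration added here rather than part of the original derivation.

```python
import torch

b = torch.tensor([[2., 7.],
                  [6., 4.],
                  [6., 5.]])
new_a = torch.tril(torch.ones(3, 3))

# Lower-triangular matmul: row i sums rows 0..i of b
new_c = new_a @ b
# The same thing expressed as a running sum down dimension 0
cumulative = torch.cumsum(b, dim=0)

print(new_c)
print(torch.equal(new_c, cumulative))  # True
```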


However, the code above only sums the rows; it doesn't average them as the `for` loop example does. We fix this by dividing each row of $A$ by its sum, so that each row of weights sums to 1.

In [59]:
# Mat mul with lower triangular matrix
torch.manual_seed(42)
newer_a = torch.tril(torch.ones(3, 3))
newer_a = newer_a / torch.sum(newer_a, dim=1, keepdim=True)
newer_c = newer_a @ b

print("newer a=")
print(newer_a)
print('---')
print("b=")
print(b)
print('---')
print("newer c=")
print(newer_c)
print('---')
print("original c=")
print(c)

newer a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
newer c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
---
original c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])
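The averaging can be checked by hand. Dividing each row of the lower-triangular matrix by its row sum turns row $i$ into weights of $\frac{1}{i+1}$, so the last row of `newer_c` is simply the original `c` row divided by 3:

$$
\begin{aligned}
\text{row } 1: &\quad \tfrac{1}{2}(2 + 6,\; 7 + 4) = (4,\; 5.5) \\
\text{row } 2: &\quad \tfrac{1}{3}(2 + 6 + 6,\; 7 + 4 + 5) \approx (4.6667,\; 5.3333)
\end{aligned}
$$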


## Comparing with the For Loop


In [60]:
def for_loop_based(x):
    B, T, C = x.shape
    x_bag_of_words = torch.zeros((B, T, C))
    for b in range(B):
        for t in range(T):
            x_prev = x[b, :t+1] # (t+1, C): rows 0..t inclusive
            x_bag_of_words[b, t] = torch.mean(x_prev, dim=0)

    return x_bag_of_words


def matrix_based(x):
    B, T, C = x.shape
    weights = torch.tril(torch.ones(T, T))
    weights = weights / torch.sum(weights, dim=1, keepdim=True)
    return weights @ x  # (T, T) @ (B, T, C) broadcasts to (B, T, C)


B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

y1 = for_loop_based(x)
y2 = matrix_based(x)

print("y1=")
print(y1[0])
print('---')
print("y2=")
print(y2[0])

print(f"Tensors are equal: {torch.allclose(y1, y2)}")


y1=
tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])
---
y2=
tensor([[ 1.9269,  1.4873],
        [ 1.4138, -0.3091],
        [ 1.1687, -0.6176],
        [ 0.8657, -0.8644],
        [ 0.5422, -0.3617],
        [ 0.3864, -0.5354],
        [ 0.2272, -0.5388],
        [ 0.1027, -0.3762]])
Tensors are equal: True
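To see why the matrix version is preferred, the sketch below times both functions on a larger input. The two functions are re-declared so the block runs on its own, the input sizes are arbitrary choices for the comparison, and actual timings depend on the hardware, so no numbers are quoted here.

```python
import time
import torch

def for_loop_based(x):
    B, T, C = x.shape
    out = torch.zeros((B, T, C))
    for b in range(B):
        for t in range(T):
            out[b, t] = torch.mean(x[b, :t + 1], dim=0)
    return out

def matrix_based(x):
    B, T, C = x.shape
    weights = torch.tril(torch.ones(T, T))
    weights = weights / torch.sum(weights, dim=1, keepdim=True)
    return weights @ x  # (T, T) @ (B, T, C) broadcasts to (B, T, C)

x_big = torch.randn(8, 64, 32)  # arbitrary sizes for the comparison

start = time.perf_counter()
y_loop = for_loop_based(x_big)
loop_time = time.perf_counter() - start

start = time.perf_counter()
y_mat = matrix_based(x_big)
mat_time = time.perf_counter() - start

print(f"for loop: {loop_time * 1e3:.2f} ms, matrix: {mat_time * 1e3:.2f} ms")
print(f"Results match: {torch.allclose(y_loop, y_mat, atol=1e-5)}")
```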


## Softmax
Finally, we can replace the simple averaging with a softmax. This lets us move away from weighting every previous token equally, towards a technique that assigns each token its own weight, while the softmax guarantees that the weights in each row stay normalised (non-negative and summing to 1). For now, however, the scores fed into the softmax are all zero, so it produces the same results as `matrix_based`.

Note: For the softmax to assign a weight of $0$ to future tokens, we mask their scores to `-inf`, since $e^{-\infty} = 0$.
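A one-row illustration of the masking trick: a position set to `-inf` contributes nothing to either the numerator or the denominator of the softmax, so the remaining positions share all of the weight equally when their scores are equal.

```python
import torch
import torch.nn.functional as F

# Scores for the token at position 1 in a length-3 sequence: position 2 is in the future
scores = torch.tensor([0.0, 0.0, float('-inf')])
weights = F.softmax(scores, dim=-1)
print(weights)  # tensor([0.5000, 0.5000, 0.0000])
```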

In [61]:
def softmax_based(x):
    B, T, C = x.shape
    tril = torch.tril(torch.ones(T, T))
    weights = torch.zeros((T,T))
    weights = weights.masked_fill(tril == 0, float('-inf'))
    weights = F.softmax(weights, dim=-1)
    return weights @ x


B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

y1 = for_loop_based(x)
y2 = matrix_based(x)
y3 = softmax_based(x)

print("y1=")
print(y1[0])
print('---')
print("y2=")
print(y2[0])
print('---')
print("y3=")
print(y3[0])

print(f"Tensors are equal: {torch.allclose(y1, y2) and torch.allclose(y1, y3)}")

y1=
tensor([[1.4451, 0.8564],
        [1.8316, 0.6898],
        [1.3366, 0.3941],
        [0.7388, 0.6151],
        [0.5566, 0.5968],
        [0.4733, 0.5684],
        [0.4878, 0.3955],
        [0.1510, 0.2522]])
---
y2=
tensor([[1.4451, 0.8564],
        [1.8316, 0.6898],
        [1.3366, 0.3941],
        [0.7388, 0.6151],
        [0.5566, 0.5968],
        [0.4733, 0.5684],
        [0.4878, 0.3955],
        [0.1510, 0.2522]])
---
y3=
tensor([[1.4451, 0.8564],
        [1.8316, 0.6898],
        [1.3366, 0.3941],
        [0.7388, 0.6151],
        [0.5566, 0.5968],
        [0.4733, 0.5684],
        [0.4878, 0.3955],
        [0.1510, 0.2522]])
Tensors are equal: True
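The point of switching to softmax is that the zeros in `weights` can later be replaced by arbitrary scores. The sketch below uses random scores (a stand-in assumption for the data-dependent scores introduced next) to show that the masked softmax still produces causal weights: the upper triangle stays at zero and every row sums to 1, but the weights are no longer equal.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 4
tril = torch.tril(torch.ones(T, T))

# Random stand-in scores instead of all zeros
scores = torch.randn(T, T)
weights = scores.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

print(weights)
print(weights.sum(dim=-1))  # each row sums to 1, up to floating-point error
```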


# Going to Self Attention!
The next block of code builds on the `softmax_based` function and takes us to something that looks like a basic self-attention block.

Rather than using a weight matrix built from averages, which gives every previous token equal weight, we want the strength of the relationship between tokens to be data-driven.

For example, in the sequence `['t', 'h', 'e', '<s>', 'c', 'a', 't', 's']`, when computing self-attention for the `<s>` token, the `'t'` and `'e'` tokens will most likely deserve different weights. However, the mechanisms we implemented so far give every previous token the same weight, so they can't account for this.

To overcome this issue, we introduce three vectors:
1. Query
2. Key
3. Value

These are computed by applying separate linear transformations (via `nn.Linear` with `bias=False`) to the input `x`. The query vector roughly represents a token saying *"What am I looking for?"*, whilst the key vector roughly represents a token saying *"This is what I contain"*. Finally, the value vector roughly represents the token saying *"If you find me interesting, this is what I will give you"*. Rather than passing on `x` itself, each token passes on its value vector `v` (so `x` can be thought of as the token's "private" or "hidden" information, and `v` as what it chooses to share).

Because each of these vectors is produced by its own `nn.Linear` module, the projections are learned via backpropagation.
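A minimal sketch of the three projections, using the same `nn.Linear(C, head_size, bias=False)` construction as the self-attention cell. Each module stores a weight matrix of shape `(head_size, C)` and no bias; applied to a `(B, T, C)` input, it acts on the last dimension and returns `(B, T, head_size)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q, k, v = query(x), key(x), value(x)
print(query.weight.shape)  # torch.Size([16, 32])
print(query.bias)          # None
print(q.shape, k.shape, v.shape)
```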

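The weight tensor we created in previous cells is now generated from the dot products of the query vectors with the key vectors, i.e. $QK^T$. (Note: the `transpose(-2, -1)` call below swaps only the last two dimensions, because we are working with batched tensors.) If a query and a key are well aligned, their dot product is high. The softmax then normalises each row of scores so that the weights lie between 0 and 1 and sum to 1. However, if the scores passed to the softmax have a large magnitude (for example, one very high value with the rest much lower), the output converges towards a one-hot vector, and each token effectively attends to a single other token. Assuming the query and key entries have zero mean and unit variance, their dot product has a variance of roughly $d_k$ (the head size), so the scores are scaled by $\frac{1}{\sqrt{d_k}}$ to keep their variance near 1 and the weights more evenly spread. This gives the scaled dot-product attention of the Transformer paper:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$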
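The variance argument behind scaling by one over the square root of the head size can be checked numerically. This sketch assumes query and key entries drawn independently from a standard normal distribution (an idealisation; learned projections need not satisfy it). The second half shows why the scaling matters for softmax: multiplying the same scores by 8 pushes the output much closer to one-hot.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
head_size = 16
q = torch.randn(4, 256, head_size)
k = torch.randn(4, 256, head_size)

raw = q @ k.transpose(-2, -1)      # (4, 256, 256)
scaled = raw * head_size**-0.5

print(f"var(raw)    = {raw.var().item():.2f}")     # roughly head_size
print(f"var(scaled) = {scaled.var().item():.2f}")  # roughly 1

# Larger-magnitude scores push softmax towards a one-hot vector
scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
print(F.softmax(scores, dim=-1))
print(F.softmax(scores * 8, dim=-1))
```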

Given that we still want a decoder self attention block, we'll mask out future tokens using the same lower triangular matrix as we did before. 

The parameter `head_size` ($d_k$) corresponds to the dimension of the `query` and `key` vectors. The Transformer paper denotes the dimension of the `value` vectors separately as $d_v$, which need not equal $d_k$, but for simplicity's sake we keep them the same in this implementation.
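To illustrate the $d_k$/$d_v$ distinction, the sketch below gives the value projection its own size. `head_size_v` is a hypothetical name introduced only for this example. The query and key sizes must match because they are combined in a dot product, whereas the output simply takes on the value size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 4, 8, 32
head_size = 16    # d_k: shared by query and key
head_size_v = 24  # d_v: hypothetical, chosen only to differ from d_k

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size_v, bias=False)

x = torch.randn(B, T, C)
q, k, v = query(x), key(x), value(x)

weights = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, T)
tril = torch.tril(torch.ones(T, T))
weights = F.softmax(weights.masked_fill(tril == 0, float('-inf')), dim=-1)
out = weights @ v                                     # (B, T, head_size_v)
print(out.shape)  # torch.Size([4, 8, 24])
```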


In [62]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16 #d_k in the paper
query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

def self_attention(x):
    B, T, C = x.shape
    
    k = key(x) # (B, T, head_size)
    q = query(x) # (B, T, head_size)

    weights = q @ k.transpose(-2,-1) * (head_size**-0.5) # (B, T, 16) @ (B, 16, T) = (B, T, T)

    # Make it a decoder by masking out future tokens
    tril = torch.tril(torch.ones(T, T))
    weights = weights.masked_fill(tril == 0, float('-inf'))
    
    weights = F.softmax(weights, dim=-1)
    
    v = value(x) # (B, T, head_size)
    out = weights @ v

    return out





out = self_attention(x)
out.shape

torch.Size([4, 8, 16])
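The cell above relies on global `query`, `key` and `value` modules. A common way to package the same computation for use inside a Transformer is an `nn.Module` that owns its projections and stores the causal mask as a buffer. The class name `Head` and the `block_size` argument (the maximum sequence length) are assumptions made for this sketch, not part of the notebook above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One causal self-attention head (a sketch, not a reference implementation)."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # A buffer moves with the module (e.g. .to(device)) but is not a trainable parameter
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        weights = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # Slice the mask to the current length so T may be smaller than block_size
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        v = self.value(x)  # (B, T, head_size)
        return weights @ v  # (B, T, head_size)

torch.manual_seed(0)
head = Head(n_embd=32, head_size=16, block_size=8)
x = torch.randn(4, 8, 32)
out = head(x)
print(out.shape)  # torch.Size([4, 8, 16])
```

Because of the mask, changing the last token leaves the outputs for all earlier positions unchanged, which is a quick way to check that the head really is causal.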