<a href="https://colab.research.google.com/github/ainsley-snell/Data_Mining_CS290/blob/main/multiple_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn

In [15]:
d_in=5
d_out=9
num_heads=3
dropout=0
context_length=3
batch_size=2
qkv_bias= False
head_dim = d_out // num_heads

# Creating input batch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89, 0.73, 0.12],
   [0.55, 0.87, 0.66, 0.65, 0.67],
   [0.57, 0.85, 0.64, 0.32, 0.13]]
)

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

b, num_tokens, d_in = batch.shape
x= batch

torch.Size([2, 3, 5])


In [16]:
# Linear prjections: transforming each token into query, key, and value vectors


# This creates linear layers that change tokens from d_in to d_out
W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

In [33]:
# Takes linear layers just created and applies them to each token
# creating three new vectors per token which are all size of d_out

queries = W_query(x)
keys = W_key(x)
values = W_value(x)

queries.shape
queries
keys.shape
keys
values.shape
values

tensor([[[-0.1009, -0.1575,  0.5585, -0.2063,  0.1611, -0.2919,  0.0823,
           0.3024,  0.0795],
         [-0.0973, -0.2511,  0.0275, -0.0710,  0.0127, -0.0363,  0.0175,
           0.2394,  0.2842],
         [ 0.0160, -0.3157,  0.0161, -0.0586,  0.1638, -0.0202, -0.1174,
           0.2069,  0.1298]],

        [[-0.1009, -0.1575,  0.5585, -0.2063,  0.1611, -0.2919,  0.0823,
           0.3024,  0.0795],
         [-0.0973, -0.2511,  0.0275, -0.0710,  0.0127, -0.0363,  0.0175,
           0.2394,  0.2842],
         [ 0.0160, -0.3157,  0.0161, -0.0586,  0.1638, -0.0202, -0.1174,
           0.2069,  0.1298]]], grad_fn=<UnsafeViewBackward0>)

In [34]:
# Splitting into multiple heads or mutliple attentions
# The d_out is split so there are multiple heads, in which the model can look at different parts of the input in parallel
queries = queries.view(batch_size, context_length, num_heads, head_dim)
keys = keys.view(batch_size, context_length, num_heads, head_dim)
values = values.view(batch_size, context_length, num_heads, head_dim)
queries.shape
queries
keys.shape
keys
values.shape
values

tensor([[[[-0.1009, -0.1575,  0.5585],
          [-0.2063,  0.1611, -0.2919],
          [ 0.0823,  0.3024,  0.0795]],

         [[-0.0973, -0.2511,  0.0275],
          [-0.0710,  0.0127, -0.0363],
          [ 0.0175,  0.2394,  0.2842]],

         [[ 0.0160, -0.3157,  0.0161],
          [-0.0586,  0.1638, -0.0202],
          [-0.1174,  0.2069,  0.1298]]],


        [[[-0.1009, -0.1575,  0.5585],
          [-0.2063,  0.1611, -0.2919],
          [ 0.0823,  0.3024,  0.0795]],

         [[-0.0973, -0.2511,  0.0275],
          [-0.0710,  0.0127, -0.0363],
          [ 0.0175,  0.2394,  0.2842]],

         [[ 0.0160, -0.3157,  0.0161],
          [-0.0586,  0.1638, -0.0202],
          [-0.1174,  0.2069,  0.1298]]]], grad_fn=<ViewBackward0>)

In [19]:
# Transposing the query, key, and value tensors
# Swaps the second and third dimensions so matrices can be multiplied
queries = queries.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

queries.shape
queries
keys.shape
keys
values.shape
values

torch.Size([2, 3, 3, 3])

In [35]:
# Creating attention scores with dot product of queries and keys for each head
attn_scores = queries @ keys.transpose(2, 3)
attn_scores.shape
attn_scores

tensor([[[[ 0.0337,  0.1558,  0.2383],
          [-0.1445,  0.0005, -0.1732],
          [ 0.2203, -0.1687,  0.1164]],

         [[ 0.3503, -0.0333,  0.3555],
          [-0.3120,  0.0748, -0.3921],
          [-0.0432, -0.1391,  0.3092]],

         [[ 0.1931, -0.0354,  0.2057],
          [-0.1632,  0.0893, -0.1318],
          [-0.0037,  0.0610,  0.1624]]],


        [[[ 0.0337,  0.1558,  0.2383],
          [-0.1445,  0.0005, -0.1732],
          [ 0.2203, -0.1687,  0.1164]],

         [[ 0.3503, -0.0333,  0.3555],
          [-0.3120,  0.0748, -0.3921],
          [-0.0432, -0.1391,  0.3092]],

         [[ 0.1931, -0.0354,  0.2057],
          [-0.1632,  0.0893, -0.1318],
          [-0.0037,  0.0610,  0.1624]]]], grad_fn=<UnsafeViewBackward0>)

In [23]:
# Creating a mask so tokens can't look ahead, it can only look to whats before it
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
mask_bool = mask.bool()[:num_tokens, :num_tokens]

In [42]:
# Softmax
# converts attention scores to attention weights
# each row sums to 1
attn_scores.masked_fill_(mask_bool, -torch.inf)

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights.shape
attn_weights
row_sums = attn_weights.sum(dim=-1)
row_sums

tensor([[[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]]], grad_fn=<SumBackward1>)

In [44]:
# Context vectors
# Multiplies attention weights with values to get weighted sum of values for each token in each head
# Transposes so that heads can be combined later
context_vec = (attn_weights @ values).transpose(1, 2)
context_vec.shape
context_vec

tensor([[[[-0.1009, -0.1575,  0.5585],
          [-0.0973, -0.2511,  0.0275],
          [ 0.0160, -0.3157,  0.0161]],

         [[-0.1558,  0.0084,  0.1155],
          [-0.0827, -0.1045, -0.0080],
          [-0.0240, -0.0585, -0.0034]],

         [[-0.0687,  0.0934,  0.1460],
          [-0.0451,  0.0172,  0.1077],
          [-0.0555,  0.0262,  0.0440]]],


        [[[-0.1009, -0.1575,  0.5585],
          [-0.0973, -0.2511,  0.0275],
          [ 0.0160, -0.3157,  0.0161]],

         [[-0.1558,  0.0084,  0.1155],
          [-0.0827, -0.1045, -0.0080],
          [-0.0240, -0.0585, -0.0034]],

         [[-0.0687,  0.0934,  0.1460],
          [-0.0451,  0.0172,  0.1077],
          [-0.0555,  0.0262,  0.0440]]]], grad_fn=<TransposeBackward0>)

In [45]:
# Combines heads back together into a single vector per token
context_vec = context_vec.contiguous().view(b, num_tokens, d_out)
context_vec.shape
context_vec

tensor([[[-0.1009, -0.1575,  0.5585, -0.0973, -0.2511,  0.0275,  0.0160,
          -0.3157,  0.0161],
         [-0.1558,  0.0084,  0.1155, -0.0827, -0.1045, -0.0080, -0.0240,
          -0.0585, -0.0034],
         [-0.0687,  0.0934,  0.1460, -0.0451,  0.0172,  0.1077, -0.0555,
           0.0262,  0.0440]],

        [[-0.1009, -0.1575,  0.5585, -0.0973, -0.2511,  0.0275,  0.0160,
          -0.3157,  0.0161],
         [-0.1558,  0.0084,  0.1155, -0.0827, -0.1045, -0.0080, -0.0240,
          -0.0585, -0.0034],
         [-0.0687,  0.0934,  0.1460, -0.0451,  0.0172,  0.1077, -0.0555,
           0.0262,  0.0440]]], grad_fn=<ViewBackward0>)