<a href="https://colab.research.google.com/github/mrdbourke/learn-transformers/blob/main/attention_mechanism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [WIP] Attention mechanism

**Focus:** Build intuition to build up to replicating the original Transformer paper.

### This notebook

* Recreate self-attention as per Transformer paper
* Recreate multi-head attention as per Transformer paper

### Later
* Recreate Transformer model architecture
* Train on a simple example

Sources:

* Transformer paper: https://arxiv.org/abs/1706.03762
* The annotated transformer: http://nlp.seas.harvard.edu/2018/04/01/attention.html
* https://lilianweng.github.io/posts/2018-06-24-attention/#self-attention
* https://jaykmody.com/blog/attention-intuition/
* Compact transformers - https://medium.com/pytorch/training-compact-transformers-from-scratch-in-30-minutes-with-pytorch-ff5c21668ed5
* Implemented MHA - https://nn.labml.ai/transformers/mha.html

## What we're going to do

Simple goals:

Replicate the following functions/modules as fast as possible:
* PyTorch's `scaled_dot_product_attention` - https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
  * Explain the intuition of the attention mechanism
* PyTorch's MultiHeadAttention - https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
  * Both of these are the same equations used in the Transformer paper
* Do with and without masking
* Use these to build on the next chatper: replicating the Transformer architecture
  * Focus on what the inputs and outputs should be (e.g. text/vision/audio, in essence, seq2seq)




In [1]:
import torch
from torch import nn

import torch.nn.functional as F

## Simple scaled-dot-product-attention (no mask)

Attention formula = `softmax((Q, K.T)/torch.sqrt(d_k))V`

$$
\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{\mathrm{T}}}{\sqrt{d_k}}\right) V
$$

TK:
- Explain what each of these values are

TK:
- Can I replicate this in Google Sheets?... yes I can... kind of (except for softmax, etc)
- Turn this function into the same format as the transformer paper (e.g. figure 2)

In [2]:
def attention(query, key, value):

  # Create the scale factor (sqrt(d_k))
  d_k = torch.sqrt(torch.tensor(query.shape[-1])) # torch.sqrt needs a tensor

  q_k = torch.matmul(query, key.mT) # .mT = matrix Transpose (transposes the last two dimensions)
  q_k_softmax_scale = F.softmax(q_k/torch.sqrt(d_k), dim=-1)
  q_k_softmax_scale_v = torch.matmul(q_k, value)
  return q_k_softmax_scale_v

In [3]:
torch.manual_seed(42)
x = torch.randn(3, 3)

output_custom = attention(query=x, key=x, value=x)
output_custom

tensor([[ 1.7361, -0.3428,  0.4193],
        [ 2.7884, -2.2552,  0.2486],
        [12.6584, -4.6864,  2.5056]])

In [4]:
# Does this equal PyTorch's scaled_dot_product_attention?
output_pytorch = F.scaled_dot_product_attention(query=x,
                                                key=x,
                                                value=x)
output_pytorch

tensor([[ 1.1175, -0.5276,  0.2233],
        [ 1.0066, -0.7048,  0.1397],
        [ 1.9621, -0.6285,  0.4030]])

In [5]:
# Should output true
torch.all(output_pytorch.isclose(output_custom))

tensor(False)

## TODO: What is a query, key and value?

* What if I told you you already know about the attention mechanism?... and your local cafe owner knows it very well

* Give an example of different values input and output into our attention mechansim

TK - Can I do sales of different products? Does this relate?

E.g.
* query = sales on monday
* key = product
* value = amount?

Does this work??

In [8]:
# TODO: finalize this and upload it to GitHub (if it works)
#!wget 
url = "https://www.dropbox.com/s/8heqlnrpkf7tlbq/cafe_sales_data_csv.xlsx"

In [12]:
import pandas as pd
df = pd.read_excel(url) # TODO: read_excel with _csv in the filename is confusing...
df

URLError: <urlopen error [Errno 11002] getaddrinfo failed>

In [None]:
# Create price vector
price_dict = {
    "coffee": 5,
    "bread": 8,
    "bacon": 15,
    "milk": 4,
    "bagel": 9,
    "sandwich": 12,
    "croissant": 8
}
price_vector = torch.tensor(list(price_dict.values()), dtype=torch.float32)
price_vector

tensor([ 5.,  8., 15.,  4.,  9., 12.,  8.])

In [None]:
# Create sales matrix
sales_matrix = torch.tensor(df.drop("Unnamed: 0", axis=1).values, dtype=torch.float32)
sales_matrix

tensor([[  0.,  50.,  55.,  68.,  91., 107.,  84.],
        [  0.,  20.,  22.,  25.,  12.,  40.,  49.],
        [  0.,  10.,  15.,  20.,  10.,  65.,  39.],
        [  0.,  15.,  15.,  18.,  16.,  51.,  45.],
        [  0.,  21.,   8.,  20.,  60.,  56.,  44.],
        [  0.,   9.,   8.,  50.,  18.,  62.,  50.],
        [  0.,  11.,   4.,   3.,   7.,  49.,  55.]])

In [None]:
print(f"Sales: {sales_matrix.shape} (seven products, seven days of week)")
print(f"Prices: {price_vector.shape} (seven products)")

Sales: torch.Size([7, 7]) (seven products, seven days of week)
Prices: torch.Size([7]) (seven products)


In [None]:
# Find the sales per day
price_vector.matmul(sales_matrix)

tensor([   0., 1005.,  936., 1716., 1577., 3674., 3013.])

In [None]:
price_vector.unsqueeze(1)

tensor([[ 5.],
        [ 8.],
        [15.],
        [ 4.],
        [ 9.],
        [12.],
        [ 8.]])

In [None]:
price_vector

tensor([ 5.,  8., 15.,  4.,  9., 12.,  8.])

In [None]:
# WRONG: Sales per item per day
price_vector * sales_matrix

tensor([[   0.,  400.,  825.,  272.,  819., 1284.,  672.],
        [   0.,  160.,  330.,  100.,  108.,  480.,  392.],
        [   0.,   80.,  225.,   80.,   90.,  780.,  312.],
        [   0.,  120.,  225.,   72.,  144.,  612.,  360.],
        [   0.,  168.,  120.,   80.,  540.,  672.,  352.],
        [   0.,   72.,  120.,  200.,  162.,  744.,  400.],
        [   0.,   88.,   60.,   12.,   63.,  588.,  440.]])

In [None]:
# CORRECT: Manipulate price vector before multiplying to sales matrix
price_vector.unsqueeze(1) * sales_matrix

tensor([[  0., 250., 275., 340., 455., 535., 420.],
        [  0., 160., 176., 200.,  96., 320., 392.],
        [  0., 150., 225., 300., 150., 975., 585.],
        [  0.,  60.,  60.,  72.,  64., 204., 180.],
        [  0., 189.,  72., 180., 540., 504., 396.],
        [  0., 108.,  96., 600., 216., 744., 600.],
        [  0.,  88.,  32.,  24.,  56., 392., 440.]])

In [None]:
# TK note:
# matmul = sum over dim=0 -> item sales per day
# manual setup = sum over dim=1 -> item sales per week
total_item_sales_per_week = torch.sum(price_vector.unsqueeze(1) * sales_matrix, dim=1)
total_item_sales_per_week

tensor([2275., 1344., 2385.,  640., 1881., 2364., 1032.])

In [None]:
total_sales_per_day = torch.sum(price_vector.unsqueeze(1) * sales_matrix, dim=0)
total_sales_per_day

tensor([   0., 1005.,  936., 1716., 1577., 3674., 3013.])

In [None]:
# TODO: try einsum? or einops?

TK - try create an example for attention

In [None]:
# Create a query ("What are the sales on Wednesday?")
sales_on_wednesday_vector = torch.zeros(7) # days of week
sales_on_wednesday_vector[2] = 1
sales_on_wednesday_vector

tensor([0., 0., 1., 0., 0., 0., 0.])

In [None]:
# Compare the sales matrix (key) to the query (Q * K.T)
wednesday_sales = sales_on_wednesday_vector.matmul(sales_matrix.T)
wednesday_sales

tensor([55., 22., 15., 15.,  8.,  8.,  4.])

Why scale?

Watch this... softmax blows it out of the water...

In [None]:
F.softmax(wednesday_sales, dim=0)

tensor([1.0000e+00, 4.6589e-15, 4.2484e-18, 4.2484e-18, 3.8740e-21, 3.8740e-21,
        7.0955e-23])

Now scale...

In [None]:
wednesday_sales.shape

torch.Size([7])

In [None]:
F.softmax(wednesday_sales / torch.sqrt(torch.tensor(wednesday_sales.shape[0])), dim=0)

tensor([1.0000e+00, 3.8293e-06, 2.7170e-07, 2.7170e-07, 1.9277e-08, 1.9277e-08,
        4.2507e-09])

Still blown out of the water but better... (could normalize these values first)

e.g.

```python
norm_tensor = (x - torch.min(x))/(torch.max(x) - torch.min(x))
```

In [None]:
F.softmax(wednesday_sales / torch.sqrt(torch.tensor(wednesday_sales.shape[0])), dim=0) @ price_vector

tensor(5.0000)

In [None]:
# Total sales on Wednesday
attention_to_pay_on_wednesdays = wednesday_sales @ price_vector # price_vector = value
attention_to_pay_on_wednesdays

tensor(936.)

### TODO: Try normalizing values

See:

In [None]:
# NEXT:
# Try normalizing the tensor values and see how they change/improve stability
# Softmax on values with large differences = blows larger values out of the water (e.g. 1.0 vs 1e-10... basically nothing)

In [None]:
from sklearn.preprocessing import MinMaxScaler
min_max = MinMaxScaler()
sales_matrix_normalized = min_max.fit(sales_matrix).transform(sales_matrix)
sales_matrix_normalized

array([[0.        , 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        ],
       [0.        , 0.26829268, 0.35294118, 0.33846154, 0.05952381,
        0.        , 0.22222222],
       [0.        , 0.02439024, 0.21568627, 0.26153846, 0.03571429,
        0.37313433, 0.        ],
       [0.        , 0.14634146, 0.21568627, 0.23076923, 0.10714286,
        0.1641791 , 0.13333333],
       [0.        , 0.29268293, 0.07843137, 0.26153846, 0.63095238,
        0.23880597, 0.11111111],
       [0.        , 0.        , 0.07843137, 0.72307692, 0.13095238,
        0.32835821, 0.24444444],
       [0.        , 0.04878049, 0.        , 0.        , 0.        ,
        0.13432836, 0.35555556]])

In [None]:
def min_max_normalize_tensor(x):
  """
  See: https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization)
  """

  return (x - torch.min(x)) / (torch.max(x) - torch.min(x))

In [None]:
price_vector

tensor([ 5.,  8., 15.,  4.,  9., 12.,  8.])

In [None]:
sales_matrix_normalized = min_max_normalize_tensor(sales_matrix)
price_vector_normalized = min_max_normalize_tensor(price_vector)
sales_matrix_normalized, price_vector_normalized

(tensor([[0.0000, 0.4673, 0.5140, 0.6355, 0.8505, 1.0000, 0.7850],
         [0.0000, 0.1869, 0.2056, 0.2336, 0.1121, 0.3738, 0.4579],
         [0.0000, 0.0935, 0.1402, 0.1869, 0.0935, 0.6075, 0.3645],
         [0.0000, 0.1402, 0.1402, 0.1682, 0.1495, 0.4766, 0.4206],
         [0.0000, 0.1963, 0.0748, 0.1869, 0.5607, 0.5234, 0.4112],
         [0.0000, 0.0841, 0.0748, 0.4673, 0.1682, 0.5794, 0.4673],
         [0.0000, 0.1028, 0.0374, 0.0280, 0.0654, 0.4579, 0.5140]]),
 tensor([0.0909, 0.3636, 1.0000, 0.0000, 0.4545, 0.7273, 0.3636]))

### TODO: Try standardizing

See: https://en.wikipedia.org/wiki/Feature_scaling#Standardization_(Z-score_Normalization)

In [None]:
torch.std(x, dim=0)

tensor([1.1125, 0.6311, 0.3288])

In [None]:
def standardize_tensor(x):
  """
  See: https://en.wikipedia.org/wiki/Feature_scaling#Standardization_(Z-score_Normalization)
  """
  return (x - torch.mean(x)) / torch.std(x)

In [None]:
sales_matrix_standardized = standardize_tensor(sales_matrix)
price_vector_standardized = standardize_tensor(price_vector)
price_vector_standardized

tensor([-0.9730, -0.1871,  1.6467, -1.2350,  0.0748,  0.8608, -0.1871])

In [None]:
price_vector

tensor([ 5.,  8., 15.,  4.,  9., 12.,  8.])

In [None]:
# Create a query ("What are the sales on Wednesday?")
sales_on_wednesday_vector = torch.zeros(7) # days of week
sales_on_wednesday_vector[2] = 1
sales_on_wednesday_vector

tensor([0., 0., 1., 0., 0., 0., 0.])

In [None]:
sales_matrix_standardized

tensor([[-1.1194,  0.7374,  0.9231,  1.4059,  2.2600,  2.8542,  2.0000],
        [-1.1194, -0.3767, -0.3024, -0.1910, -0.6738,  0.3661,  0.7003],
        [-1.1194, -0.7480, -0.5623, -0.3767, -0.7480,  1.2945,  0.3289],
        [-1.1194, -0.5623, -0.5623, -0.4509, -0.5252,  0.7745,  0.5517],
        [-1.1194, -0.3395, -0.8223, -0.3767,  1.1088,  0.9602,  0.5146],
        [-1.1194, -0.7852, -0.8223,  0.7374, -0.4509,  1.1830,  0.7374],
        [-1.1194, -0.7109, -0.9708, -1.0080, -0.8594,  0.7003,  0.9231]])

In [None]:
F.softmax(sales_on_wednesday_vector.unsqueeze(0).matmul(sales_matrix_standardized.T)/7, dim=1) @ price_vector_standardized

tensor([-0.0365])

In [None]:
F.softmax(sales_on_wednesday_vector.unsqueeze(0).matmul(sales_matrix.T)/7, dim=1) @ price_vector

tensor([5.0707])

In [None]:
sales_on_monday = torch.zeros(7)
sales_on_monday[0] = 1
print(sales_on_monday)
F.softmax(sales_on_monday.unsqueeze(0).matmul(sales_matrix.T)/7, dim=1) @ price_vector

tensor([1., 0., 0., 0., 0., 0., 0.])


tensor([8.7143])

In [None]:
# Non-standardize

In [None]:
sales_on_monday = torch.zeros(7)
sales_on_monday[5] = 1
print(sales_on_monday)
F.softmax(sales_on_monday.unsqueeze(0).matmul(sales_matrix_standardized.T)/7, dim=1) @ price_vector_standardized

tensor([0., 0., 0., 0., 0., 1., 0.])


tensor([-0.0193])

In [None]:
day_names = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
day_values = range(0, 7)
day_dict = dict(zip(day_values, day_names))
day_dict

{0: 'Monday',
 1: 'Tuesday',
 2: 'Wednesday',
 3: 'Thursday',
 4: 'Friday',
 5: 'Saturday',
 6: 'Sunday'}

In [None]:
# Non-standardize
for i in range(7):
  day_tensor = torch.zeros(7)
  day_tensor[i] = 1
  day_name = day_dict[i]

  print(f"\nDay name: {day_name}")
  print(f"Day tensor: {day_tensor}")
  # day_tensor_standardize = standardize_tensor(day_tensor)
  d_k = day_tensor.shape[-1]
  print(d_k)
  attn_score = F.softmax(day_tensor.unsqueeze(1).matmul(sales_matrix.T)/torch.sqrt(torch.tensor(7)), dim=1) @ price_vector
  print(f"Attention score: {attn_score}")


Day name: Monday
Day tensor: tensor([1., 0., 0., 0., 0., 0., 0.])
7


RuntimeError: ignored

In [None]:
# Standardize
# TODO: does this work for the triangle matrix?
# e.g. triangle down the left to bottom right corner
for i in range(7):
  day_tensor = torch.zeros(7)
  day_tensor[i] = 1
  day_name = day_dict[i]

  print(f"\nDay name: {day_name}")
  print(f"Day tensor: {day_tensor}")
  day_tensor_standardize = standardize_tensor(day_tensor)
  attn_score = F.softmax(day_tensor_standardize.unsqueeze(0).matmul(sales_matrix_standardized.T)/torch.sqrt(torch.tensor(7)), dim=1) @ price_vector_standardized
  print(f"Attention score: {attn_score}")

In [None]:
all_days = torch.eye(7)

attention(query=all_days,
          key=sales_matrix,
          value=price_vector)

In [None]:
all_days = torch.eye(7)
all_days_standardized = standardize_tensor(all_days)

attention(query=all_days_standardized,
          key=sales_matrix_standardized,
          value=price_vector_standardized.unsqueeze(1))

In [None]:
F.scaled_dot_product_attention(query=all_days_standardized,
                               key=sales_matrix_standardized,
                               value=price_vector_standardized.unsqueeze(1))

In [None]:
standardize_tensor_day = standardize_tensor(sales_on_monday)
standardize_tensor_day

In [None]:
# Combine with the key (<q, k.T>)
wednesday_sales = sales_on_wednesday_vector.matmul(sales_matrix_standardized.T)
wednesday_sales

In [None]:
sales_matrix

In [None]:
# Softmax on the sales (not so blown out! ... once the values were standardized)
wednesday_sales = F.softmax(wednesday_sales, dim=0)
wednesday_sales

In [None]:
# Scale on the sales
wednesday_sales = wednesday_sales/torch.sqrt(torch.tensor(wednesday_sales.shape[0]))
wednesday_sales

In [None]:
# Multiply by the value to get the attention
attention_to_pay_on_wednesdays = wednesday_sales @ price_vector_standardized
attention_to_pay_on_wednesdays

In [None]:
price_vector_standardized

In [None]:
each_day = torch.eye(7)
each_day_standardized = standardize_tensor(each_day)
each_day, each_day_standardized

In [None]:
# Attention to pay each day
each_day = torch.ones((7))

def attention(query, key, value):
  d_k = torch.tensor(query.shape[-1]) # torch.sqrt needs a tensor
  print(d_k)
  q_k = F.softmax(torch.matmul(query, key.T)/torch.sqrt(d_k), dim=-1)
  print(q_k.shape)
  return torch.matmul(q_k, value)

attention(query=each_day_standardized,
          key=sales_matrix_standardized,
          value=price_vector_standardized)

In [None]:
# Attention to pay each day
monday = torch.zeros(7)
monday[0] = 1
monday_standardized = standardize_tensor(monday)
print(monday)
print(monday_standardized)

def attention(query, key, value):
  d_k = torch.tensor(query.shape[-1]) # torch.sqrt needs a tensor
  print(d_k)
  q_k = F.softmax(torch.matmul(query, key.T)/torch.sqrt(d_k), dim=-1)
  print(q_k.shape)
  return torch.matmul(q_k, value)

attention(query=monday_standardized,
          key=sales_matrix_standardized,
          value=price_vector_standardized)

In [None]:
# UPTOHERE
# NEXT: clean up all of the above so it makes sense... in a bit of a mess now
# Less but better...

In [None]:
# query = day of week
# key = sales per day
# value = prices of products
# result/output = value of total products sold on target day (how much attention to pay to a certain day)

# TK - if you wanted to get more information, you could increase the cafe sales to (52, 7) -> sales per day for 52 weeks in a year
# -> or (5, 52, 7) (year, weeks, days) -> sales per day per week for 5 years

# TODO:

# How does this relate to attention?

# At a large enough scale, you can do this for words in sentences.
# For example, say we have 100 sentences.
# Which words mean the most to which other words?
# In a small sample like this, you might be able to design fixed values (like our coffee shop for different products).
# But with a larger scale, you might have to design different values for different contexts.
# The analogy being in a coffee shop in Australia, your pricing and products might be different to a coffee shop in Africa.
# With a large enough corpus of words, having fixed values isn't going to work.
# But the principle remains, how important is each other word to another word in a sentence?
# What should you do?
# Well, you'd never have time to assign a value for each word across a huge corpus.
# So you can make the values for each word learnable.
# Much like you might adjust your cafe prices and sales events given different days of the week.
# The sales on Monday are very low (zero) because your cafe is closed on Monday.
# Your customers don't assign much money (or attention) to your cafe on Monday's since they know it's closed.
# Much like the attention score for the word "cat" might be very low in comparison to the word "sodium metabisulfite" because the two hardly ever occur in context of each other.

# Why self-attention?

# Applying the mechanism to itself over and over for different sequences enables the system to learn from the data itself.
# As in, what words keep on showing up in the context of other words?
# The dot product/matrix multiplication will amplify larger values.
# In essence, given the query "dog" and the key of every word in the vocabulary, hopefully the model will learn to return "cat" as a likely value and "sodium metabisulfite" as a less likely value.

## TODO: Simple scaled-dot-product-attention (with mask)

Next:
- Read through GPT-from-scratch again
- Read through Facebook's xformers
- Read through Transformers from scratch blog post

--

* see: https://jaykmody.com/blog/gpt-from-scratch/#causal
* And see: https://github.com/facebookresearch/xformers/blob/main/xformers/components/attention/attention_mask.py
  * Default to causal mask: https://github.com/facebookresearch/xformers/blob/97daac83cece6d3d77bb09479777ad6e8ef7dfed/xformers/components/attention/attention_mask.py#LL74C16-L74C16 (`make_causal()`)

In [None]:
# Make causal mask, see: https://jaykmody.com/blog/gpt-from-scratch/#causal
additive_mask = torch.triu(
    # torch.ones(x.shape[0], x.shape[0]) * float("inf"),
    torch.ones(x.shape[0], x.shape[0]) * -1e10, # can use -1e10 to prevent nans
    diagonal=1
)

additive_mask

In [None]:
def attention_with_mask(query, key, value, mask=None):
  d_k = torch.tensor(query.shape[-1]) # torch.sqrt needs a tensor
  q_k = torch.matmul(query, key.T) / torch.sqrt(d_k)
  print(q_k.shape)


  print(f"q_k: {q_k}")

  # Apply attention mask
  if mask is not None:
    q_k = q_k + mask

  print(f"q_k with mask: {q_k}")

  # Softmax
  attn = F.softmax(q_k, dim=-1)

  return torch.matmul(attn, value), attn

attention_with_mask(query=x, key=x, value=x, mask=additive_mask)

## TODO: Why scaled?

TL;DR softmax can get out of hand with large values

In [None]:
small_values = torch.tensor([1, 2, 3], dtype=torch.float32) # need dtype otherwise error
big_values = small_values * 10
huge_values = big_values * 10

small_softmax = F.softmax(small_values, dim=0)
big_softmax = F.softmax(big_values, dim=0)
huge_softmax = F.softmax(huge_values, dim=0)

print(f"Small values: {small_values}\nSmall softmax: {small_softmax}\n")
print(f"Big values: {big_values}\nBig softmax: {big_softmax}\n")
print(f"Huge values: {huge_values}\nHuge softmax: {huge_softmax}\n")

## TODO: Why dot-product?

TL;DR dot product measures how closely two vectors are related

* big values = close
* negative values = far away
* zero value = same direction? (TK - fix this)

See:
* 3blue1brown on dot product - https://www.youtube.com/watch?v=LyGKycYT2v0

In [None]:
vector_1 = torch.arange(0, 1, 0.1)
vector_2 = torch.ones_like(vector_1) / 10
vector_3 = vector_1 - 0.1
vector_4 = -vector_1
print(vector_2)

import matplotlib.pyplot as plt

plt.plot(vector_1, label="vector_1")
plt.plot(vector_2, label="vector_2")
plt.plot(vector_3, label="vector_3")
plt.plot(vector_4, label="vector_4")
plt.legend();

In [None]:
cosine_sim = nn.CosineSimilarity(dim=0)
cosine_sim(vector_1, vector_1)

In [None]:
torch.dot(vector_1, vector_1)

In [None]:
torch.dot(vector_1, vector_2)

In [None]:
torch.dot(vector_1, vector_3)

In [None]:
torch.dot(vector_2, vector_3)

In [None]:
torch.dot(vector_1, vector_4)

In [None]:
# Sales example (total sales on Wednesday)
torch.dot(price_vector, sales_matrix[:, 2])

In [None]:
# Same as taking the multiple and then summing them
torch.sum(price_vector * df["Wednesday"].values, dtype=torch.float32)

In [None]:
# Compresses the information into a single number
cosine_sim(price_vector, price_vector)

In [None]:
import numpy as np
np.dot(price_vector, price_vector)

## TODO: Replicate PyTorch's `scaled_dot_product_attention`

(minus all the fancy optimizations, the library can do those for us)

See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

Also see: https://github.com/facebookresearch/xformers/blob/main/xformers/components/attention/core.py#L297

In [None]:
# Optionally use the context manager to ensure one of the fused kerenels is run
torch.manual_seed(42)

query = torch.rand(32, 8, 128, 64) # [batch_size, num_heads, sequence_length, embedding_dim]
key = torch.rand(32, 8, 128, 64)
value = torch.rand(32, 8, 128, 64)


In [None]:
output_pytorch = F.scaled_dot_product_attention(query, key, value)
print(output_pytorch.shape)
print(output_pytorch[0, 0, 0])

torch.Size([32, 8, 128, 64])
tensor([0.5430, 0.5479, 0.5143, 0.4744, 0.5149, 0.4867, 0.5063, 0.5088, 0.4863,
        0.4620, 0.4989, 0.5488, 0.4746, 0.4955, 0.5334, 0.4886, 0.5158, 0.5267,
        0.5183, 0.5251, 0.4939, 0.5092, 0.5408, 0.4267, 0.4645, 0.5221, 0.5587,
        0.4917, 0.5142, 0.4762, 0.4839, 0.4837, 0.4937, 0.4671, 0.4898, 0.5195,
        0.4942, 0.4938, 0.4783, 0.4796, 0.5454, 0.4686, 0.5112, 0.5717, 0.5081,
        0.4588, 0.5151, 0.4970, 0.4649, 0.5143, 0.5019, 0.5053, 0.4928, 0.5278,
        0.5332, 0.5121, 0.4882, 0.4992, 0.5197, 0.4865, 0.5028, 0.4908, 0.4975,
        0.4808])


In [None]:
# Does this have learnable parameters?
output_pytorch.requires_grad

False

In [None]:
def attention(query, key, value):
  d_k = torch.tensor(query.shape[-1]) # torch.sqrt needs a tensor
  print(torch.matmul(query, key.mT).shape)
  print(key.mT.shape)
  print(key.transpose(-2, -1).shape)

  # tensor.mT is equivalent to tensor.transpose(-2, -1), see: https://pytorch.org/docs/stable/tensors.html#torch.Tensor.mT
  # -> last two dimensions reversed
  q_k = F.softmax(torch.matmul(query, key.mT)/torch.sqrt(d_k), dim=-1)

  print(d_k)

  print(d_k.shape, q_k.shape, value.shape)
  print(q_k.shape, value.mT.shape)

  return torch.matmul(q_k, value)

output_custom = attention(query, key, value)
print(output_pytorch.shape)
# print(output_pytorch[0, 0, 0])

torch.Size([32, 8, 128, 128])
torch.Size([32, 8, 64, 128])
torch.Size([32, 8, 64, 128])
tensor(64)
torch.Size([]) torch.Size([32, 8, 128, 128]) torch.Size([32, 8, 128, 64])
torch.Size([32, 8, 128, 128]) torch.Size([32, 8, 64, 128])
torch.Size([32, 8, 128, 64])


In [None]:
# Assert all of the output values are close
assert torch.all(output_custom.isclose(output_pytorch))

## TODO: Why "self" attention

TL;DR given a sequence, which parts of the sequence are most important based on the sequence itself

Self-attention = based on its own input how should its representation differ

Eg the word “cup” should have a different representation given the sentences:

* “England won the World Cup”
* “I filled my cup with orange juice”

Same word, but different contexts - self-attention will adjust the weight values given the other items in the context.

## TODO: Make it learnable

* TK - make this section more clear
* TK - see section 3.2.2 Multi-Head Attention for "where the **projections** are parameter matrics `WQ`, `WK`, `WV` etc
  * Projections = learnable embedding = linear projection

Okay cool, we've replicated PyTorch's `scaled_dot_product_attention`.

But right now it's just a static operaton.

And the whole goal of machine learning is to write algortihms that *learn* over time.

So we need to make a learnable version of attention mechansim.

How?

Projections!

What?

Projections into embedding space.

An embedding is a learnable representation of something.

And so instead of a static vector representing our data, we can turn it into an embedding and create a *learnable vector* (a vector that changes over time given new information).

In [None]:
class SelfAttentionLearnable(nn.Module):
  def __init__(self, embed_dim):
    super().__init__()
    self.embed_dim = embed_dim
    self.scale = embed_dim ** -0.5 # power of -0.5 == same as square root

    # TK - one option = make one big projection (e.g. embed_dim * 3), then reshape = faster (on bigger GPUs)
    # TK - another option = make one projection per Q, K, V

    # Create a projection (learnable embedding)
    self.qkv = nn.Linear(in_features=embed_dim,
                         out_features=embed_dim * 3, # one per Q, K, V
                         bias=False)

  def forward(self, x):
    B, N, _ = x.shape

    # Do the projection
    qkv = self.qkv(x)
    print(f"qkv shape: {qkv.shape}")

    # qkv = qkv.reshape(B, N, 3, self.embed_dim).permute(1, # qkv
    #                                                    0, # batch
    #                                                    2, # num_tokens
    #                                                    3) # embed_dim

    qkv = qkv.reshape(B, N, 3, self.embed_dim).permute(2, # qkv
                                                       0, # batch
                                                       1, # num_tokens
                                                       3) # embed_dim

    print(f"qkv shape: {qkv.shape}")

    # TODO: replace the above with einops?

    q, k, v = qkv[0], qkv[1], qkv[2]

    print(f"q shape: {q.shape} | k shape: {k.shape} | v shape: {v.shape}")

    # Perform self-attention (self = the operation happens on itself)
    q_k = torch.matmul(q, k.mT)
    print(f"q_k shape: {q_k.shape}")

    q_k_scale = q_k * self.scale

    # Softmax on embedding dim (last dim)
    q_k_scale_softmax = torch.softmax(q_k_scale, dim=-1)

    # Pefrom final batch mm
    q_k_scale_softmax_v = torch.matmul(q_k_scale_softmax, v)

    print(f"q_k_scale_softmax_v output shape: {q_k_scale_softmax_v.shape}")

    # TODO: Try this with einops rearrange
    x = q_k_scale_softmax_v.transpose(1, 2).reshape(B, N, self.embed_dim)

    print(f"x output shape: {x.shape}")

    return x

embed_dim = 512
batch_size = 32
num_tokens = 128
attention_learnable = SelfAttentionLearnable(embed_dim=embed_dim)
x = torch.arange(batch_size*num_tokens*embed_dim, dtype=torch.float32).reshape(batch_size, num_tokens, embed_dim)
print(f"x input shape: {x.shape}")
x_out = attention_learnable(x)

x input shape: torch.Size([32, 128, 512])
qkv shape: torch.Size([32, 128, 1536])
qkv shape: torch.Size([3, 32, 128, 512])
q shape: torch.Size([32, 128, 512]) | k shape: torch.Size([32, 128, 512]) | v shape: torch.Size([32, 128, 512])
q_k shape: torch.Size([32, 128, 128])
q_k_scale_softmax_v output shape: torch.Size([32, 128, 512])
x output shape: torch.Size([32, 128, 512])


Now try learnable attention with einops...

In [None]:
# What about einops?
!pip install einops

In [None]:
from einops import rearrange, reduce, repeat

In [None]:
class SelfAttentionLearnable(nn.Module):
  def __init__(self, embed_dim):
    super().__init__()
    self.embed_dim = embed_dim
    self.scale = embed_dim ** -0.5 # power of -0.5 == same as square root

    # Create a projection (learnable embedding) for each input (x -> Q, K, V)
    self.q_projection = nn.Linear(in_features=embed_dim,
                                  out_features=embed_dim,
                                  bias=False)

    self.k_projection = nn.Linear(in_features=embed_dim,
                                  out_features=embed_dim,
                                  bias=False)

    self.v_projection = nn.Linear(in_features=embed_dim,
                                  out_features=embed_dim,
                                  bias=False)

  def forward(self, x):
    B, N, _ = x.shape

    # Do the projection(s)
    q = self.q_projection(x)
    k = self.k_projection(x)
    v = self.v_projection(x)

    print(f"q shape: {q.shape} | k shape: {k.shape} | v shape: {v.shape}")

    # Perform self-attention (self = the operation happens on itself)
    q_k = torch.matmul(q, k.mT)
    print(f"q_k shape: {q_k.shape}")

    q_k_scale = q_k * self.scale

    # Softmax on embedding dim (last dim)
    q_k_scale_softmax = torch.softmax(q_k_scale, dim=-1)

    # Pefrom final batch mm
    q_k_scale_softmax_v = torch.matmul(q_k_scale_softmax, v)

    print(f"q_k_scale_softmax_v output shape: {q_k_scale_softmax_v.shape}")

    # # TODO: Try this with einops rearrange
    # x = q_k_scale_softmax_v.transpose(1, 2).reshape(B, N, self.embed_dim)

    print(f"x output shape: {x.shape}")

    return x

embed_dim = 512
batch_size = 32
num_tokens = 128
attention_learnable = SelfAttentionLearnable(embed_dim=embed_dim)
x = torch.arange(batch_size*num_tokens*embed_dim, dtype=torch.float32).reshape(batch_size, num_tokens, embed_dim)
print(f"x input shape: {x.shape}")
x_out = attention_learnable(x)

x input shape: torch.Size([32, 128, 512])
q shape: torch.Size([32, 128, 512]) | k shape: torch.Size([32, 128, 512]) | v shape: torch.Size([32, 128, 512])
q_k shape: torch.Size([32, 128, 128])
q_k_scale_softmax_v output shape: torch.Size([32, 128, 512])
x output shape: torch.Size([32, 128, 512])


In [None]:
# UPTOHERE:
# Replace forward() with F.scaled_dot_product_attention...
# by the end you should know the attention formula off by heart (it's not too hard... a few variables + a few operations), getting the shapes to line up is the hard part

In [None]:
class SelfAttentionLearnableCustom(nn.Module):
  def __init__(self, embed_dim):
    super().__init__()

    self.embed_dim = embed_dim

    # Create scale
    self.scale = embed_dim ** -0.5

    # Create projections
    self.query_projection = nn.Linear(in_features=embed_dim,
                                      out_features=embed_dim,
                                      bias=False)

    self.key_projection = nn.Linear(in_features=embed_dim,
                                    out_features=embed_dim,
                                    bias=False)

    self.value_projection = nn.Linear(in_features=embed_dim,
                                      out_features=embed_dim,
                                      bias=False)

  def forward(self, x):
    query, key, value = self.query_projection(x), self.key_projection(x), self.value_projection(x)

    # Perform scaled_dot_production_attention
    attn = attention(query=query,
                     key=key,
                     value=value)

    return attn

In [None]:
class SelfAttentionLearnablePyTorch(nn.Module):
  def __init__(self, embed_dim):
    super().__init__()

    self.embed_dim = embed_dim

    # Create scale
    self.scale = embed_dim ** -0.5

    # Create projections
    self.query_projection = nn.Linear(in_features=embed_dim,
                                      out_features=embed_dim,
                                      bias=False)

    self.key_projection = nn.Linear(in_features=embed_dim,
                                    out_features=embed_dim,
                                    bias=False)

    self.value_projection = nn.Linear(in_features=embed_dim,
                                      out_features=embed_dim,
                                      bias=False)

  def forward(self, x):
    query, key, value = self.query_projection(x), self.key_projection(x), self.value_projection(x)

    # Perform scaled_dot_production_attention
    attn = F.scaled_dot_product_attention(query=query,
                                          key=key,
                                          value=value)

    return attn

In [None]:
# Make sure the outcomes are the same

torch.manual_seed(42)
self_attention_custom = SelfAttentionLearnableCustom(embed_dim=512)

torch.manual_seed(42)
self_attention_pytorch = SelfAttentionLearnablePyTorch(embed_dim=512)

x = torch.arange(batch_size*num_tokens*embed_dim, dtype=torch.float32).reshape(batch_size, num_tokens, embed_dim)
print(f"x input shape: {x.shape}")

x_out_custom = self_attention_custom(x)

x_out_pytorch = self_attention_pytorch(x)
x_out_custom.shape, x_out_pytorch.shape

x input shape: torch.Size([32, 128, 512])
torch.Size([32, 128, 128])
torch.Size([32, 512, 128])
torch.Size([32, 512, 128])
tensor(512)
torch.Size([]) torch.Size([32, 128, 128]) torch.Size([32, 128, 512])
torch.Size([32, 128, 128]) torch.Size([32, 512, 128])


(torch.Size([32, 128, 512]), torch.Size([32, 128, 512]))

In [None]:
x_out_pytorch[0][0][0]

tensor(104.0217, grad_fn=<SelectBackward0>)

In [None]:
x_out_custom[0][0][0]

tensor(104.0217, grad_fn=<SelectBackward0>)

In [None]:
torch.all(x_out_custom.isclose(x_out_pytorch))

tensor(True)

## TODO: A cool trick with `einops`

In [None]:
# What about einops?
!pip install einops

In [None]:
from einops import rearrange, reduce, repeat

print(f"Key shape: {key.shape} [batch, num_heads, input_sequence, embedding_dim]")

### The following all do the same ###

# Rearrange the shape for our use case
key_rearranged = rearrange(key, 'batch heads input embed -> batch heads embed input')

# Key tranposed
key_transposed = key.transpose(-2, -1)

# Key mT (note: use .mT rather than .T on tensors with more than two dimensions)
key_mt = key.mT

print(key_rearranged.shape, key_transposed.shape, key_mt.shape)
assert key_rearranged.shape == key_transposed.shape == key_mt.shape

## TODO: Why multi-head attention?

TL;DR more opportunities to learn (e.g. 8x64 scaled dot-product attention = better than 1*512)

TK:
- One big matrix multiplication better than lots of small ones
- Just perform a `nn.Linear()` then break it up
- Implementing multi-head attention gives you self-attention, because you just use 1 head (versus multiple)

**Goal:** Replicate PyTorch's `torch.nn.MultiheadAttention()` - https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

In [None]:
multi_head_attention_pytorch = torch.nn.MultiheadAttention(embed_dim=512,
                                                           num_heads=8,
                                                           batch_first=True, # Does your batch dimension come first?
                                                           )

attn_output, attn_output_weights = multi_head_attention_pytorch(query=x, key=x, value=x,
                                                                need_weights=True) # Return weights or not
x.shape, attn_output.shape, attn_output_weights.shape

(torch.Size([32, 128, 512]),
 torch.Size([32, 128, 512]),
 torch.Size([32, 128, 128]))

In [None]:
attn_output.requires_grad

True

In [None]:
# TK - embed_dim, num_heads = minimum viable variables (masking can come later)
class MultiheadAttentionCustom(nn.Module):
  def __init__(self,
               embed_dim,
               num_heads,
               # TK - dropout
               ):
    super().__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

    self.head_dim = embed_dim // num_heads
    self.scale = self.head_dim ** -0.5 # "to the power" is same as squareroot

    self.softmax = nn.Softmax(dim=-1) # perform softmax on embedding dim (last dim)

    # TK - see bias parameter in docs: https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
    # "bias – If specified, adds bias to input / output projection layers. Default: True."
    self.query_projection = nn.Linear(embed_dim, embed_dim, bias=True)
    self.key_projection = nn.Linear(embed_dim, embed_dim, bias=True)
    self.value_projection = nn.Linear(embed_dim, embed_dim, bias=True)

    # Project out
    self.project_out = nn.Linear(embed_dim, embed_dim, bias=True)

    # TODO: dropout
    # TODO: masking

  def forward(self, x):
    batch_size, num_tokens, embed_dim = x.shape

    # Project_in (linear)
    query, key, value = self.query_projection(x), self.key_projection(x), self.value_projection(x)

    print(f"Query shape: {query.shape} | Key shape: {key.shape} | Value shape: {value.shape}")

    # Convert to num heads
    query = query.reshape(batch_size, num_tokens, self.num_heads, self.head_dim)
    key = key.reshape(batch_size, num_tokens, self.num_heads, self.head_dim)
    value = value.reshape(batch_size, num_tokens, self.num_heads, self.head_dim)

    print(f"Query shape (heads): {query.shape} | Key shape (heads): {key.shape} | Value shape (heads): {value.shape}")

    # self-attention * heads = softmax((<q, k>)/d_k))v
    dots = torch.matmul(query, key.mT) * self.scale
    attn = self.softmax(dots)
    out = torch.matmul(attn, value)

    print(f"Out shape: {out.shape}")

    # TODO: dropout

    # Concat last two dims together
    concat = out.reshape(batch_size, num_tokens, embed_dim)
    print(f"Concat shape: {concat.shape}")

    # project_out (linear)
    x_out = self.project_out(concat)
    print(f"Projection out shape: {x_out.shape}")

    return x_out


multihead_attention_custom = MultiheadAttentionCustom(embed_dim=512, num_heads=8)
x_multihead_out_custom = multihead_attention_custom(x)


Query shape: torch.Size([32, 128, 512]) | Key shape: torch.Size([32, 128, 512]) | Value shape: torch.Size([32, 128, 512])
Query shape (heads): torch.Size([32, 128, 8, 64]) | Key shape (heads): torch.Size([32, 128, 8, 64]) | Value shape (heads): torch.Size([32, 128, 8, 64])
Out shape: torch.Size([32, 128, 8, 64])
Concat shape: torch.Size([32, 128, 512])
Projection out shape: torch.Size([32, 128, 512])


In [None]:
x = torch.arange(batch_size*num_tokens*embed_dim, dtype=torch.float32).reshape(batch_size, num_tokens, embed_dim)
print(f"x input shape: {x.shape}")

torch.manual_seed(42)
multi_head_attention_pytorch = torch.nn.MultiheadAttention(embed_dim=512,
                                                           num_heads=8,
                                                           batch_first=True, # Does your batch dimension come first?
                                                           )


torch.manual_seed(42)
multihead_attention_custom = MultiheadAttentionCustom(embed_dim=512, num_heads=8)


x_attn_output_pytorch, attn_output_weights = multi_head_attention_pytorch(query=x, key=x, value=x,
                                                                need_weights=True) # Return weights or not

x_multihead_out_custom = multihead_attention_custom(x)

x input shape: torch.Size([32, 128, 512])
Query shape: torch.Size([32, 128, 512]) | Key shape: torch.Size([32, 128, 512]) | Value shape: torch.Size([32, 128, 512])
Query shape (heads): torch.Size([32, 128, 8, 64]) | Key shape (heads): torch.Size([32, 128, 8, 64]) | Value shape (heads): torch.Size([32, 128, 8, 64])
Out shape: torch.Size([32, 128, 8, 64])
Concat shape: torch.Size([32, 128, 512])
Projection out shape: torch.Size([32, 128, 512])


In [None]:
print(x_attn_output_pytorch.shape), print(x_multihead_out_custom.shape)

torch.Size([32, 128, 512])
torch.Size([32, 128, 512])


(None, None)

In [None]:
x_attn_output_pytorch[0][0][0], x_multihead_out_custom[0][0][0]

(tensor(16859.5352, grad_fn=<SelectBackward0>),
 tensor(-68.8901, grad_fn=<SelectBackward0>))

In [None]:
# TK - make sure these are close
torch.all(x_attn_output_pytorch.isclose(x_multihead_out_custom))

tensor(False)

In [None]:
# Next:
# Go through attention = simple version
# Go through attention -> replicate PyTorch scaled_dot_product_attention
# Go through attention -> replicate masking
# Make it learnable
# Make it multi-head (multi-head can be the same as single head if you code it to be so...)