<a href="https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap12/12_1_Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Notebook 12.1: Self Attention**

This notebook builds a self-attention mechanism from scratch, as discussed in section 12.2 of the book.

Work through the cells below, running each cell in turn. In various places you will see the word "TODO". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.

Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.



In [1]:
import numpy as np
import matplotlib.pyplot as plt

The self-attention mechanism takes $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ and returns $N$ outputs $\mathbf{x}'_{n}\in \mathbb{R}^{D}$.
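For reference, here is a sketch of the computation carried out in the cells below, assuming the notation of section 12.2 of the book, where $a[\mathbf{x}_m,\mathbf{x}_n]$ is the attention weight that output $n$ places on input $m$. Each output is a weighted sum of value vectors, with weights given by a softmax over key-query dot products:

$$
\begin{aligned}
\mathbf{q}_n &= \boldsymbol{\beta}_q + \boldsymbol{\Omega}_q\mathbf{x}_n, \quad
\mathbf{k}_m = \boldsymbol{\beta}_k + \boldsymbol{\Omega}_k\mathbf{x}_m, \quad
\mathbf{v}_m = \boldsymbol{\beta}_v + \boldsymbol{\Omega}_v\mathbf{x}_m \\
a[\mathbf{x}_m,\mathbf{x}_n] &= \frac{\exp\left[\mathbf{k}_m^{T}\mathbf{q}_n\right]}{\sum_{m'=1}^{N}\exp\left[\mathbf{k}_{m'}^{T}\mathbf{q}_n\right]} \\
\mathbf{x}'_n &= \sum_{m=1}^{N} a[\mathbf{x}_m,\mathbf{x}_n]\,\mathbf{v}_m
\end{aligned}
$$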



In [2]:
# Set seed so we get the same random numbers
np.random.seed(3)
# Number of inputs
N = 3
# Number of dimensions of each input
D = 4
# Create an empty list
all_x = []
# Create elements x_n and append to list
for n in range(N):
  all_x.append(np.random.normal(size=(D,1)))
# Print out the list
print(all_x)


[array([[ 1.78862847],
       [ 0.43650985],
       [ 0.09649747],
       [-1.8634927 ]]), array([[-0.2773882 ],
       [-0.35475898],
       [-0.08274148],
       [-0.62700068]]), array([[-0.04381817],
       [-0.47721803],
       [-1.31386475],
       [ 0.88462238]])]


We'll also need the weights and biases for the keys, queries, and values (equations 12.2 and 12.4).

In [3]:
# Set seed so we get the same random numbers
np.random.seed(0)

# Choose random values for the parameters
omega_q = np.random.normal(size=(D,D))
omega_k = np.random.normal(size=(D,D))
omega_v = np.random.normal(size=(D,D))
beta_q = np.random.normal(size=(D,1))
beta_k = np.random.normal(size=(D,1))
beta_v = np.random.normal(size=(D,1))

Now let's compute the queries, keys, and values for each input.

In [5]:
# Make three lists to store queries, keys, and values
all_queries = []
all_keys = []
all_values = []
# For every input
for x in all_x:
  # Compute the query, key, and value for this input
  # (each is a linear transformation of x plus a bias)
  query = omega_q @ x + beta_q
  key = omega_k @ x + beta_k
  value = omega_v @ x + beta_v

  all_queries.append(query)
  all_keys.append(key)
  all_values.append(value)

We'll need a softmax function (equation 12.5). Here it takes a list of arbitrary numbers and returns an array of the same length whose elements are non-negative and sum to one.


In [6]:
def softmax(items_in):
    # Convert to a NumPy array
    items = np.array(items_in, dtype=float)
    
    # Numerical stability trick:
    # Subtract the largest value before exponentiating
    # (prevents overflow for large numbers)
    items_exp = np.exp(items - np.max(items))
    
    # Normalize so they sum to 1
    items_out = items_exp / np.sum(items_exp)
    
    return items_out
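The max-subtraction trick works because softmax is unchanged when the same constant $c$ is added to every input: the factor $e^{c}$ cancels between numerator and denominator. The sketch below checks this, repeating the `softmax` above so the block runs on its own; `softmax_naive` is a hypothetical version without the trick, included only for comparison.

```python
import numpy as np

def softmax(items_in):
    # Same stable softmax as above: subtract the max before exponentiating
    items = np.array(items_in, dtype=float)
    items_exp = np.exp(items - np.max(items))
    return items_exp / np.sum(items_exp)

def softmax_naive(items_in):
    # Hypothetical naive version without the max-subtraction trick
    items = np.array(items_in, dtype=float)
    items_exp = np.exp(items)
    return items_exp / np.sum(items_exp)

small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]   # the same values shifted by 999

# On small inputs the two versions agree
print(np.allclose(softmax(small), softmax_naive(small)))

# Shifting every input by a constant leaves the stable softmax unchanged
print(np.allclose(softmax(large), softmax(small)))

# The naive version overflows: exp(1000) is inf, and inf / inf is nan
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(large))
```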


Now compute the self-attention outputs:

In [7]:
# Create empty list for output
all_x_prime = []

# For each output n
for n in range(N):
    q_n = all_queries[n]

    # 1) dot-products k_m^T q_n for all m
    all_km_qn = []
    for key in all_keys:
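        # key.T @ q_n is a 1x1 array; .item() extracts the scalar k_m^T q_n
        # (calling float() on it is deprecated in recent NumPy versions)
        dot_product = (key.T @ q_n).item()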
        all_km_qn.append(dot_product)

    # 2) softmax to get attention weights over m
    attention = softmax(all_km_qn)

    print("Attentions for output ", n)
    print(attention)

    # 3) weighted sum of values
    x_prime = np.zeros((D,1))
    for m, w in enumerate(attention):
        x_prime += w * all_values[m]

    all_x_prime.append(x_prime)

# Print out the true values to check that your implementation is correct
print("x_prime_0_calculated:", all_x_prime[0].transpose())
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_calculated:", all_x_prime[1].transpose())
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_calculated:", all_x_prime[2].transpose())
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")


Attentions for output  0
[1.24326146e-13 9.98281489e-01 1.71851130e-03]
Attentions for output  1
[2.79525306e-12 5.85506360e-03 9.94144936e-01]
Attentions for output  2
[0.00505708 0.00654776 0.98839516]
x_prime_0_calculated: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_1_calculated: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_2_calculated: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]
x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]
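To make the three steps in the loop above concrete, here is a hand-checkable case with made-up numbers (not the notebook's data): one query and two keys chosen so that the dot products are $0$ and $\ln 3$. The softmax weights are then $1/(1+3)=0.25$ and $3/(1+3)=0.75$, so the output is $0.25\,\mathbf{v}_1 + 0.75\,\mathbf{v}_2$.

```python
import numpy as np

# Made-up query, keys, and values (column vectors, as in the notebook)
q = np.array([[1.0], [0.0]])
keys = [np.array([[0.0], [0.0]]),            # k_1^T q = 0
        np.array([[np.log(3.0)], [0.0]])]    # k_2^T q = ln 3
values = [np.array([[1.0], [0.0]]),
          np.array([[0.0], [4.0]])]

# 1) dot products k_m^T q; .item() extracts the scalar from each 1x1 result
scores = np.array([(k.T @ q).item() for k in keys])

# 2) softmax over m: exp(0) = 1 and exp(ln 3) = 3, giving weights 1/4 and 3/4
weights = np.exp(scores - scores.max())
weights /= weights.sum()

# 3) weighted sum of the values: 0.25 * [1, 0] + 0.75 * [0, 4] = [0.25, 3.0]
x_prime = sum(w * v for w, v in zip(weights, values))
print(weights)
print(x_prime.T)
```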




Now let's compute the same thing using matrix calculations. We'll store the $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ in the columns of a $D\times N$ matrix $\mathbf{X}$ and use equations 12.6, 12.7, and 12.8.
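As a summary of what the functions below compute (a sketch, assuming $\mathbf{1}$ is an $N\times 1$ vector of ones so that $\boldsymbol{\beta}\mathbf{1}^{T}$ copies a bias into every column, and that $\text{Softmax}[\cdot]$ acts independently on each column):

$$
\begin{aligned}
\mathbf{Q} &= \boldsymbol{\beta}_q\mathbf{1}^{T} + \boldsymbol{\Omega}_q\mathbf{X}, \quad
\mathbf{K} = \boldsymbol{\beta}_k\mathbf{1}^{T} + \boldsymbol{\Omega}_k\mathbf{X}, \quad
\mathbf{V} = \boldsymbol{\beta}_v\mathbf{1}^{T} + \boldsymbol{\Omega}_v\mathbf{X} \\
\mathbf{X}' &= \mathbf{V}\,\text{Softmax}\left[\mathbf{K}^{T}\mathbf{Q}\right]
\end{aligned}
$$

Entry $(m,n)$ of $\mathbf{K}^{T}\mathbf{Q}$ is $\mathbf{k}_m^{T}\mathbf{q}_n$, so normalizing each column normalizes over the keys $m$ for a fixed query $n$, exactly as in the loop above.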

Note: The book uses column vectors (for compatibility with the rest of the text), but in the wider literature it is more common to store the inputs in the rows of a matrix; in that case the computation is the same, but all the matrices are transposed and the operations proceed in reverse order.
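A quick numerical check of this note, using random placeholder data and omitting the biases for brevity: with inputs stored as the rows of an $N\times D$ matrix, the same computation becomes a row-wise softmax of $\mathbf{Q}\mathbf{K}^{T}$ multiplied on the right by $\mathbf{V}$, and the result is the transpose of the column-convention result. The function names `attention_cols` and `attention_rows` are hypothetical, and the variables use distinct names so the notebook's own `X` and weights are not overwritten.

```python
import numpy as np

def softmax_cols(S):
    # Softmax applied independently to each column
    E = np.exp(S - S.max(axis=0, keepdims=True))
    return E / E.sum(axis=0, keepdims=True)

def softmax_rows(S):
    # Softmax applied independently to each row
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def attention_cols(X_cols, w_q, w_k, w_v):
    # Book convention: X_cols is D x N, one input per column
    Q, K, V = w_q @ X_cols, w_k @ X_cols, w_v @ X_cols
    return V @ softmax_cols(K.T @ Q)            # D x N

def attention_rows(X_rows, w_q, w_k, w_v):
    # Row convention: X_rows is N x D, one input per row, so every
    # matrix is transposed and the products appear in reverse order
    Q, K, V = X_rows @ w_q.T, X_rows @ w_k.T, X_rows @ w_v.T
    return softmax_rows(Q @ K.T) @ V            # N x D

# Random placeholder data (D = 4, N = 3)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(4, 3))
w_q, w_k, w_v = (rng.normal(size=(4, 4)) for _ in range(3))

out_cols = attention_cols(X_demo, w_q, w_k, w_v)
out_rows = attention_rows(X_demo.T, w_q, w_k, w_v)
print(np.allclose(out_rows, out_cols.T))
```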

In [10]:
def softmax_cols(data_in):
    # Ensure float array
    data = np.array(data_in, dtype=float)

    # For numerical stability: subtract max per column
    max_per_col = np.max(data, axis=0, keepdims=True)

    # Exponentiate
    exp_values = np.exp(data - max_per_col)

    # Normalize so each column sums to 1
    denom = np.sum(exp_values, axis=0, keepdims=True)
    softmax = exp_values / denom

    return softmax


In [11]:
def self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    D, N = X.shape
    ones_row = np.ones((1, N))  # to broadcast biases across columns

    # 1) Compute queries, keys, values (each D x N)
    Q = omega_q @ X + beta_q @ ones_row
    K = omega_k @ X + beta_k @ ones_row
    V = omega_v @ X + beta_v @ ones_row

    # 2) Dot products: K^T Q -> (N x D)(D x N) = N x N, with entries k_m^T q_n
    scores = K.T @ Q

    # 3) Column-wise softmax to get attention over keys for each query
    A = softmax_cols(scores)   # each column sums to 1

    # 4) Weight values by attention
    X_prime = V @ A            # (D x N)(N x N) = D x N

    return X_prime


In [12]:
# Copy data into matrix
X = np.zeros((D, N))
X[:,0] = np.squeeze(all_x[0])
X[:,1] = np.squeeze(all_x[1])
X[:,2] = np.squeeze(all_x[2])

# Run the self attention mechanism
X_prime = self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime)

[[ 0.94744244  1.64201168  1.61949281]
 [-0.24348429 -0.08470004 -0.06641533]
 [-0.91310441  4.02764044  3.96863308]
 [-0.44522983  2.18690791  2.15858316]]


If you did this correctly, the values should be the same as above.

TODO: Print out the attention matrix.

You will see that the values are quite extreme: in each column, one entry is very close to one and the others are very close to zero. Now we'll fix this problem by using scaled dot-product attention.
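Since `self_attention` only returns $\mathbf{X}'$, one way to approach this TODO is a variant that also returns the attention matrix. The sketch below assumes that approach; `self_attention_with_weights` is a hypothetical name, not from the book. The inputs and parameters are regenerated with the same seeds as above so the block runs on its own, and the biases are added by NumPy broadcasting (a $D\times 1$ array plus a $D\times N$ array), which is equivalent to multiplying by a row of ones. Column $n$ of the printed matrix should match the "Attentions for output n" lines printed earlier.

```python
import numpy as np

def softmax_cols(data):
    # Column-wise softmax with max-subtraction for stability
    exp_values = np.exp(data - np.max(data, axis=0, keepdims=True))
    return exp_values / np.sum(exp_values, axis=0, keepdims=True)

def self_attention_with_weights(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    # Same computation as self_attention, but also returns the attention matrix
    Q = omega_q @ X + beta_q     # broadcasting adds beta_q to every column
    K = omega_k @ X + beta_k
    V = omega_v @ X + beta_v
    A = softmax_cols(K.T @ Q)    # column n holds the weights for output n
    return V @ A, A

# Regenerate the inputs and parameters with the notebook's seeds
D, N = 4, 3
np.random.seed(3)
X = np.hstack([np.random.normal(size=(D, 1)) for _ in range(N)])
np.random.seed(0)
omega_q = np.random.normal(size=(D, D))
omega_k = np.random.normal(size=(D, D))
omega_v = np.random.normal(size=(D, D))
beta_q = np.random.normal(size=(D, 1))
beta_k = np.random.normal(size=(D, 1))
beta_v = np.random.normal(size=(D, 1))

X_prime, A = self_attention_with_weights(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)
print(A)
```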

In [13]:
def scaled_dot_product_self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    D, N = X.shape
    ones_row = np.ones((1, N))

    # 1) Q, K, V  (each D x N)
    Q = omega_q @ X + beta_q @ ones_row
    K = omega_k @ X + beta_k @ ones_row
    V = omega_v @ X + beta_v @ ones_row

    # 2) Scores K^T Q (N x N), then 3) scale by sqrt(D)
    scores = (K.T @ Q) / np.sqrt(D)

    # 4) Column-wise softmax -> attention A (N x N)
    A = softmax_cols(scores)

    # 5) Weighted sum: X' = V A  (D x N)
    X_prime = V @ A
    return X_prime


In [14]:
# Run the self attention mechanism
X_prime = scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime)

[[ 0.97411966  1.59622051  1.32638014]
 [-0.23738409 -0.09516106  0.13062402]
 [-0.72333202  3.70194096  3.02371664]
 [-0.34413007  2.01339538  1.6902419 ]]
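To see the effect of the $1/\sqrt{D}$ scaling, one can compare the largest weight in each column of the attention matrix with and without it. Dividing every score by a constant greater than one moves all scores closer to the column maximum, so the softmax becomes less peaked and the largest weight can only decrease (or stay the same if the column is already uniform). A sketch with random placeholder scores standing in for $\mathbf{K}^{T}\mathbf{Q}$, not the notebook's values:

```python
import numpy as np

def softmax_cols(data):
    # Column-wise softmax with max-subtraction for stability
    exp_values = np.exp(data - np.max(data, axis=0, keepdims=True))
    return exp_values / np.sum(exp_values, axis=0, keepdims=True)

rng = np.random.default_rng(1)
D, N = 4, 3
# Placeholder scores; the factor 5 mimics the large unscaled dot products
scores_demo = 5.0 * rng.normal(size=(N, N))

A_unscaled = softmax_cols(scores_demo)
A_scaled = softmax_cols(scores_demo / np.sqrt(D))

print("max weight per column, unscaled:", A_unscaled.max(axis=0))
print("max weight per column, scaled:  ", A_scaled.max(axis=0))
```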


TODO -- Investigate whether the self-attention mechanism is covariant with respect to permutation.
If it is, when we permute the columns of the input matrix $\mathbf{X}$, the columns of the output matrix $\mathbf{X}'$ will also be permuted.
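A short argument suggests what to expect. This is a sketch, assuming $\mathbf{P}$ is an $N\times N$ permutation matrix, so that $\mathbf{P}\mathbf{P}^{T}=\mathbf{I}$ and $\mathbf{1}^{T}\mathbf{P}=\mathbf{1}^{T}$. Replacing $\mathbf{X}$ by $\mathbf{X}\mathbf{P}$ gives:

$$
\begin{aligned}
\mathbf{Q} &\rightarrow \boldsymbol{\beta}_q\mathbf{1}^{T} + \boldsymbol{\Omega}_q\mathbf{X}\mathbf{P} = \mathbf{Q}\mathbf{P}, \quad
\mathbf{K} \rightarrow \mathbf{K}\mathbf{P}, \quad
\mathbf{V} \rightarrow \mathbf{V}\mathbf{P} \\
\mathbf{K}^{T}\mathbf{Q} &\rightarrow \mathbf{P}^{T}\mathbf{K}^{T}\mathbf{Q}\mathbf{P} \\
\text{Softmax}\left[\mathbf{P}^{T}\mathbf{K}^{T}\mathbf{Q}\mathbf{P}\right] &= \mathbf{P}^{T}\,\text{Softmax}\left[\mathbf{K}^{T}\mathbf{Q}\right]\mathbf{P} \\
\mathbf{X}' &\rightarrow \mathbf{V}\mathbf{P}\mathbf{P}^{T}\,\text{Softmax}\left[\mathbf{K}^{T}\mathbf{Q}\right]\mathbf{P} = \mathbf{X}'\mathbf{P}
\end{aligned}
$$

The third line holds because the column-wise softmax treats each column separately (so permuting columns commutes with it) and permuting the entries within a column permutes its softmax in the same way. Dividing the scores by $\sqrt{D}$ does not affect the argument, since every score is scaled by the same constant.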


In [None]:
# Define a permutation of the 3 input columns, e.g. swap columns 0 and 2
P = np.array([[0,0,1],
              [0,1,0],
              [1,0,0]])

# Permute inputs
X_perm = X @ P

# Run through attention
Xp_original = scaled_dot_product_self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)
Xp_permuted = scaled_dot_product_self_attention(X_perm, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Check covariance
print("Are they the same (up to permutation)?",
      np.allclose(Xp_permuted, Xp_original @ P))
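The cell above tests a single permutation. A slightly broader sketch checks all $3! = 6$ column permutations, building each permutation matrix by reordering the columns of the identity with `itertools.permutations`. The data are random placeholders rather than the notebook's values, and `sdpa_sketch` is a hypothetical standalone copy of `scaled_dot_product_self_attention` (with biases added by broadcasting) so the block runs on its own.

```python
import itertools
import numpy as np

def softmax_cols(data):
    # Column-wise softmax with max-subtraction for stability
    exp_values = np.exp(data - np.max(data, axis=0, keepdims=True))
    return exp_values / np.sum(exp_values, axis=0, keepdims=True)

def sdpa_sketch(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    # Scaled dot-product self-attention, column convention (X is D x N)
    D, N = X.shape
    Q = omega_q @ X + beta_q
    K = omega_k @ X + beta_k
    V = omega_v @ X + beta_v
    return V @ softmax_cols((K.T @ Q) / np.sqrt(D))

rng = np.random.default_rng(2)
D_rand, N_rand = 4, 3
X_rand = rng.normal(size=(D_rand, N_rand))
omegas = [rng.normal(size=(D_rand, D_rand)) for _ in range(3)]
betas = [rng.normal(size=(D_rand, 1)) for _ in range(3)]
params = (*omegas, *betas)   # omega_v, omega_q, omega_k, beta_v, beta_q, beta_k

out_original = sdpa_sketch(X_rand, *params)
results = []
for perm in itertools.permutations(range(N_rand)):
    # Column j of this matrix is e_{perm[j]}, so X @ P_perm reorders the columns of X
    P_perm = np.eye(N_rand)[:, list(perm)]
    out_permuted = sdpa_sketch(X_rand @ P_perm, *params)
    results.append(np.allclose(out_permuted, out_original @ P_perm))
print(results)
```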
