<a href="https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap12/12_1_Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Notebook 12.1: Self Attention**

This notebook builds a self-attention mechanism from scratch, as discussed in section 12.2 of the book.

Work through the cells below, running each cell in turn. In various places you will see the words "TO DO". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.

Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.



In [1]:
import numpy as np
import matplotlib.pyplot as plt

The self-attention mechanism maps $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ and returns $N$ outputs $\mathbf{x}'_{n}\in \mathbb{R}^{D}$.  



In [2]:
# Set seed so we get the same random numbers
np.random.seed(3)
# Number of inputs
N = 3
# Number of dimensions of each input
D = 4
# Create an empty list
all_x = []
# Create elements x_n and append to list
for n in range(N):
  all_x.append(np.random.normal(size=(D,1)))
# Print out the list
print(all_x)


[array([[ 1.78862847],
       [ 0.43650985],
       [ 0.09649747],
       [-1.8634927 ]]), array([[-0.2773882 ],
       [-0.35475898],
       [-0.08274148],
       [-0.62700068]]), array([[-0.04381817],
       [-0.47721803],
       [-1.31386475],
       [ 0.88462238]])]


We'll also need the weights and biases for the keys, queries, and values (equations 12.2 and 12.4)

In [3]:
# Set seed so we get the same random numbers
np.random.seed(0)

# Choose random values for the parameters
omega_q = np.random.normal(size=(D,D))
omega_k = np.random.normal(size=(D,D))
omega_v = np.random.normal(size=(D,D))
beta_q = np.random.normal(size=(D,1))
beta_k = np.random.normal(size=(D,1))
beta_v = np.random.normal(size=(D,1))

Now let's compute the queries, keys, and values for each input

In [4]:
# Make three lists to store queries, keys, and values
all_queries = []
all_keys = []
all_values = []
# For every input
for x in all_x:
  # TODO -- compute the keys, queries and values.
  # Replace these three lines
  query = np.dot(omega_q, x)
  key = np.dot(omega_k, x)
  value = np.dot(omega_v, x)


  all_queries.append(query)
  all_keys.append(key)
  all_values.append(value)

We'll need a softmax function (equation 12.5) -- here, it will take a list of arbitrary numbers and return a list where the elements are non-negative and sum to one


In [5]:
import numpy as np

def softmax(items_in):
    # Exponentiate each item in the input list
    exp_items = np.exp(items_in)
    
    # Divide each exponentiated item by the sum of all exponentiated items
    items_out = exp_items / np.sum(exp_items)
    
    return items_out


Now compute the self attention values:

In [6]:
import numpy as np

# Assume `all_queries`, `all_keys`, and `all_values` are defined, as well as `N` (number of queries) and `D` (dimension).
# `all_keys` and `all_values` should be lists of numpy arrays where each array has shape (D, 1).
# `N` represents the number of queries (outputs we want to generate).

# Create an empty list for outputs
all_x_prime = []

# For each output (query)
for n in range(N):
    # Get the query vector for the nth output
    query = all_queries[n]
    
    # Create list to store dot products of query `n` with all keys
    all_km_qn = []
    
    # Compute the dot products of the query with each key
    for key in all_keys:
        # Compute dot product between the query and each key
        dot_product = np.dot(query.T, key).item()  # .item() converts the result to a scalar
        
        # Store the computed dot product
        all_km_qn.append(dot_product)

    # Compute attention weights using softmax
    attention = softmax(all_km_qn)
    
    # Print attention result (should be positive and sum to one)
    print("Attentions for output", n)
    print(attention)

    # Compute a weighted sum of all the values according to the attention weights
    x_prime = np.zeros((D, 1))  # Initialize the weighted sum as a zero vector of shape (D, 1)
    
    # Sum the values weighted by attention
    for i, value in enumerate(all_values):
        x_prime += attention[i] * value  # Multiply each value by its attention weight and accumulate

    # Store the computed x_prime
    all_x_prime.append(x_prime)

# Print calculated and true values to check correctness
print("x_prime_0_calculated:", all_x_prime[0].T)
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_calculated:", all_x_prime[1].T)
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_calculated:", all_x_prime[2].T)
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")


Attentions for output 0
[2.51131327e-10 9.99871742e-01 1.28257326e-04]
Attentions for output 1
[7.05249423e-08 7.32497391e-02 9.26750190e-01]
Attentions for output 2
[9.92197921e-01 6.37007031e-04 7.16507185e-03]
x_prime_0_calculated: [[ 0.87981225 -0.54621062 -0.28669931 -0.08670637]]
x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_1_calculated: [[ 1.52832666 -0.39795456  4.32644077  2.37090289]]
x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_2_calculated: [[-2.74393541  3.22286389 -6.21268498 -2.63336572]]
x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]


Now let's compute the same thing, but using matrix calculations.  We'll store the $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ in the columns of a $D\times N$ matrix, using equations 12.6 and 12.7/8.

Note:  The book uses column vectors (for compatibility with the rest of the text), but in the wider literature it is more normal to store the inputs in the rows of a matrix;  in this case, the computation is the same, but all the matrices are transposed and the operations proceed in the reverse order.

In [7]:
# Define softmax operation that works independently on each column
def softmax_cols(data_in):
  # Exponentiate all of the values
  exp_values = np.exp(data_in) ;
  # Sum over columns
  denom = np.sum(exp_values, axis = 0);
  # Replicate denominator to N rows
  denom = np.matmul(np.ones((data_in.shape[0],1)), denom[np.newaxis,:])
  # Compute softmax
  softmax = exp_values / denom
  # return the answer
  return softmax

In [11]:
import numpy as np

def self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    """
    Compute self-attention for a sequence in matrix form.

    Parameters:
    X       - Input sequence matrix of shape (N, D), where N is the number of inputs and D is the input dimension.
    omega_v - Weight matrix for values, shape (D, D).
    omega_q - Weight matrix for queries, shape (D, D).
    omega_k - Weight matrix for keys, shape (D, D).
    beta_v  - Bias vector for values, shape (D,).
    beta_q  - Bias vector for queries, shape (D,).
    beta_k  - Bias vector for keys, shape (D,).

    Returns:
    X_prime - Output sequence matrix of shape (N, D) after self-attention is applied.
    """
    # Step 1: Compute Queries, Keys, and Values
    Q = X @ omega_q + beta_q.reshape(1, -1)  # Query matrix   
    K = X @ omega_k + beta_k.reshape(1, -1)  # Key matrix
    V = X @ omega_v + beta_v.reshape(1, -1)  # Value matrix


    # Step 2: Compute Dot Products (Attention Scores)
    # (N, D) x (D, N) -> (N, N)
    attention_scores = Q @ K.T

    # Step 3: Apply Softmax to Calculate Attention Weights
    # Use softmax along each row to normalize the attention scores
    attention_weights = softmax(attention_scores)

    # Step 4: Weight Values by Attention Weights
    # (N, N) x (N, D) -> (N, D)
    X_prime = attention_weights @ V

    return X_prime


In [12]:
# Corrected dimensions for X
X = np.zeros((N, D))

# Assign values to each row of X
X[0, :] = np.squeeze(all_x[0])
X[1, :] = np.squeeze(all_x[1])
X[2, :] = np.squeeze(all_x[2])

# Run the self-attention mechanism
X_prime = self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime)


[[ 0.04111728  0.119473    0.06520696 -0.19119696]
 [ 0.00421243 -0.55397478  0.34775034 -0.69706642]
 [-0.04731911 -0.7636419   0.2919839  -0.57878701]]


If you did this correctly, the values should be the same as above.

TODO:  

Print out the attention matrix
You will see that the values are quite extreme (one is very close to one and the others are very close to zero.  Now we'll fix this problem by using scaled dot-product attention.

In [13]:
def scaled_dot_product_self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    """
    Compute scaled dot-product self-attention for a sequence in matrix form.

    Parameters:
    X       - Input sequence matrix of shape (N, D), where N is the number of inputs and D is the input dimension.
    omega_v - Weight matrix for values, shape (D, D).
    omega_q - Weight matrix for queries, shape (D, D).
    omega_k - Weight matrix for keys, shape (D, D).
    beta_v  - Bias vector for values, shape (D,).
    beta_q  - Bias vector for queries, shape (D,).
    beta_k  - Bias vector for keys, shape (D,).

    Returns:
    X_prime - Output sequence matrix of shape (N, D) after self-attention is applied.
    """
    # Step 1: Compute Queries, Keys, and Values
    Q = X @ omega_q + beta_q.reshape(1, -1)  # Query matrix (N, D)
    K = X @ omega_k + beta_k.reshape(1, -1)  # Key matrix (N, D)
    V = X @ omega_v + beta_v.reshape(1, -1)  # Value matrix (N, D)

    # Step 2: Compute Dot Products (Attention Scores)
    attention_scores = Q @ K.T  # (N, N)

    # Step 3: Scale the Dot Products
    d_k = Q.shape[1]  # Dimensionality of the queries (or keys)
    scaled_scores = attention_scores / np.sqrt(d_k)

    # Step 4: Apply Softmax to Calculate Attention Weights
    attention_weights = softmax(scaled_scores)  # (N, N)

    # Step 5: Compute Weighted Sum of Values
    X_prime = attention_weights @ V  # (N, D)

    return X_prime

In [14]:
# Run the self attention mechanism
X_prime = scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime)

[[ 0.06844147  0.19846437  0.11080554 -0.30884349]
 [ 0.0437153  -0.29627934  0.32854197 -0.68906003]
 [-0.0146878  -0.48757334  0.24051127 -0.48450209]]


TODO -- Investigate whether the self-attention mechanism is covariant with respect to permutation.
If it is, when we permute the columns of the input matrix $\mathbf{X}$, the columns of the output matrix $\mathbf{X}'$ will also be permuted.
