# Attention examples
Anthony Gitter  
October 18, 2022  
Explore the attention calculations from the lecture slides.

Recall from [Vaswani et al. Attention Is All You Need](https://papers.nips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html):
> An attention function can be described as mapping a query and a set of key-value pairs to an output,
where the query, keys, values, and output are all vectors. The output is computed as a weighted sum
of the values, where the weight assigned to each value is computed by a compatibility function of the
query with the corresponding key.

The goal of attention and multi-headed self-attention is to learn good word embeddings based on word-word relationships in the training data.
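
For reference, the paper quoted above instantiates this description as scaled dot-product attention. The symbols below follow the paper's notation rather than anything defined in this notebook: $Q$, $K$, and $V$ are the query, key, and value matrices, and $d_k$ is the dimension of each query and key.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

The softmax is applied to each row of $QK^T/\sqrt{d_k}$, so every query gets its own set of weights over the keys.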

In [1]:
import numpy as np
from scipy.special import softmax

## Attention warmup
First we show the calculations for the mammal attention examples in the slides.

Define the query vector. This could be a query matrix with multiple queries as shown later.

In [2]:
mammal = np.array([8.7, 3.2, 4.1])

Initially define the key vectors individually so we can see how they correspond to the keys in the slides. These must have the same length as the query.

In [3]:
kitten = np.array([9.1, 1.0, 2.1])

In [4]:
lizard = np.array([0.1, 7.5, 4.3])

In [5]:
salmon = np.array([1.3, 5.5, 8.2])

In [6]:
whale = np.array([7.6, 2.4, 4.0])

In [7]:
wolf = np.array([8.5, 2.7, 2.7])

Compute and inspect the scaling factor we will use later.

In [8]:
scaling = np.sqrt(len(mammal))
print(scaling)

1.7320508075688772


Combine all of the keys into the keys matrix.

In [9]:
animal_names = ['kitten', 'lizard', 'salmon', 'whale', 'wolf']
animals = np.stack([kitten, lizard, salmon, whale, wolf], axis=0)

In [10]:
animals

array([[9.1, 1. , 2.1],
       [0.1, 7.5, 4.3],
       [1.3, 5.5, 8.2],
       [7.6, 2.4, 4. ],
       [8.5, 2.7, 2.7]])

Inspect the query-key dot products, first on their own and then after scaling, before taking the softmax.

In [11]:
mammal@animals.T

array([90.98, 42.5 , 62.53, 90.2 , 93.66])

In [12]:
mammal@animals.T/np.sqrt(len(mammal))

array([52.52732749, 24.53738644, 36.10171233, 52.07699428, 54.07462621])
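
As a rough illustration of what the scaling does, the sketch below compares the softmax of the unscaled and scaled dot products for the `mammal` query. The names `unscaled_weights` and `scaled_weights` are introduced here for illustration only; the query and keys are copied from the cells above.

```python
import numpy as np
from scipy.special import softmax

# Query and keys copied from the cells above
mammal = np.array([8.7, 3.2, 4.1])
animals = np.array([[9.1, 1.0, 2.1],   # kitten
                    [0.1, 7.5, 4.3],   # lizard
                    [1.3, 5.5, 8.2],   # salmon
                    [7.6, 2.4, 4.0],   # whale
                    [8.5, 2.7, 2.7]])  # wolf

scores = mammal @ animals.T
unscaled_weights = softmax(scores)
scaled_weights = softmax(scores / np.sqrt(len(mammal)))

# Without scaling, more of the weight piles onto the single best match (wolf)
print(np.round(unscaled_weights, 3))
print(np.round(scaled_weights, 3))
```

Dividing by the square root of the dimension shrinks the gaps between scores, so the softmax spreads its weight a little more evenly across similar keys.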

Compute the similarity weights.

In [13]:
weights = softmax(mammal@animals.T/scaling)
print(weights)

[1.57823895e-01 1.10228985e-13 1.16042942e-08 1.00599432e-01
 7.41576662e-01]
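
The softmax step can also be written out by hand, which may make the weights above less mysterious. This is a minimal sketch assuming the same `mammal` query and keys; `manual_weights` is a name introduced here, and subtracting the maximum score is a common numerical-stability trick rather than something the slides require.

```python
import numpy as np

# Query and keys copied from the cells above
mammal = np.array([8.7, 3.2, 4.1])
animals = np.array([[9.1, 1.0, 2.1],   # kitten
                    [0.1, 7.5, 4.3],   # lizard
                    [1.3, 5.5, 8.2],   # salmon
                    [7.6, 2.4, 4.0],   # whale
                    [8.5, 2.7, 2.7]])  # wolf

scores = mammal @ animals.T / np.sqrt(len(mammal))

# Subtracting the maximum leaves the result unchanged but keeps exp() from overflowing
exp_scores = np.exp(scores - scores.max())
manual_weights = exp_scores / exp_scores.sum()
print(manual_weights)
```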


Store the values that correspond to each key.

In [14]:
values = np.array([[3.4,1.3,0.4,9.8], # furry
                  [7.5,3.9,4.1,0.2], # scaly
                  [8.3,2.8,2.3,0.1], # slippery
                  [1.6,8.4,9.9,3.4], # huge
                  [2.2,9.4,8.7,1.1]]) # ferocious

Compute the final output value for our query `mammal`.

In [15]:
output = weights@values
print(output)

[2.32902909 8.02102694 7.51078092 2.70444657]
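
The matrix product `weights@values` is shorthand for a weighted sum of the value vectors. The sketch below spells that sum out as a loop, using the weights printed earlier for `mammal`; `manual_output` is a name introduced here for illustration.

```python
import numpy as np

# Attention weights for the `mammal` query, copied from the printed output above
weights = np.array([1.57823895e-01, 1.10228985e-13, 1.16042942e-08,
                    1.00599432e-01, 7.41576662e-01])
values = np.array([[3.4, 1.3, 0.4, 9.8],   # furry
                   [7.5, 3.9, 4.1, 0.2],   # scaly
                   [8.3, 2.8, 2.3, 0.1],   # slippery
                   [1.6, 8.4, 9.9, 3.4],   # huge
                   [2.2, 9.4, 8.7, 1.1]])  # ferocious

# Accumulate weight * value one key at a time
manual_output = np.zeros(values.shape[1])
for weight, value in zip(weights, values):
    manual_output += weight * value

print(manual_output)
```

Because wolf carries most of the weight, the result sits closest to the last row of `values`.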


Try it with a different query.

In [16]:
reptile = np.array([2.1, 9.9, 1.6])

In [17]:
weights = softmax(reptile@animals.T/scaling)
print(weights)

[5.25436708e-13 9.98297480e-01 1.70251120e-03 1.47297680e-09
 7.33251915e-09]


In [18]:
output = weights@values
print(output)

[7.50136196 3.89812728 4.09693552 0.19982976]


What if we have a query matrix instead?

In [19]:
query = np.stack([mammal, reptile], axis=0)
print(query)

[[8.7 3.2 4.1]
 [2.1 9.9 1.6]]


In [20]:
# Take the softmax over each row, not the entire matrix
weights = softmax(query@animals.T/scaling, axis=1)
print(weights)

[[1.57823895e-01 1.10228985e-13 1.16042942e-08 1.00599432e-01
  7.41576662e-01]
 [5.25436708e-13 9.98297480e-01 1.70251120e-03 1.47297680e-09
  7.33251915e-09]]
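
To see why the `axis=1` argument matters, the following sketch compares a row-wise softmax with a softmax over the whole matrix. The `scores` matrix is made up for illustration and does not come from the slides.

```python
import numpy as np
from scipy.special import softmax

# Two rows of made-up scores, one row per query
scores = np.array([[1.0, 2.0, 3.0],
                   [1.0, 1.0, 1.0]])

per_row = softmax(scores, axis=1)  # each row is normalized separately
whole_matrix = softmax(scores)     # default axis=None normalizes over all entries

print(per_row.sum(axis=1))       # every row sums to 1
print(whole_matrix.sum(axis=1))  # rows no longer sum to 1, only the grand total does
```

With the row-wise version, each query gets a proper set of weights over the keys regardless of how many other queries are in the matrix.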


In [21]:
# First row is the updated value for 'mammal'
# Second row is the updated value for 'reptile'
animal_output = weights@values
print(f'Updated mammal embedding: {animal_output[0]}')
print(f'Updated reptile embedding: {animal_output[1]}')

Updated mammal embedding: [2.32902909 8.02102694 7.51078092 2.70444657]
Updated reptile embedding: [7.50136196 3.89812728 4.09693552 0.19982976]


## Attention function
Define an `attention` function to simplify these calculations. The outputs are the same as the above.

In [22]:
def attention(queries, keys, values):
    """Scaled dot-product attention for 2D query, key, and value matrices."""
    # Assume all inputs are 2D matrices
    assert queries.ndim == 2
    assert keys.ndim == 2
    assert values.ndim == 2

    # Scale by the square root of the query/key dimension
    scale = np.sqrt(queries.shape[1])
    # Take the softmax over each row, not the entire matrix
    weights = softmax(queries@keys.T/scale, axis=1)
    return weights@values

In [23]:
# The output from the attention function matches the output above
animal_output2 = attention(query, animals, values)
print(f'Updated mammal embedding: {animal_output2[0]}')
print(f'Updated reptile embedding: {animal_output2[1]}')

# Test that animal_output and animal_output2 are very similar
assert np.isclose(animal_output, animal_output2).all()

Updated mammal embedding: [2.32902909 8.02102694 7.51078092 2.70444657]
Updated reptile embedding: [7.50136196 3.89812728 4.09693552 0.19982976]
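
It can help to track the matrix shapes through the attention calculation. The sketch below restates the same calculation on random inputs purely to show the shapes; the sizes and the `rng` generator are assumptions made for illustration.

```python
import numpy as np
from scipy.special import softmax

def attention(queries, keys, values):
    # Same calculation as the `attention` function defined above
    scale = np.sqrt(queries.shape[1])
    weights = softmax(queries @ keys.T / scale, axis=1)
    return weights @ values

rng = np.random.default_rng(0)
queries = rng.random((2, 3))  # 2 queries of dimension 3
keys = rng.random((5, 3))     # 5 keys, same dimension as the queries
values = rng.random((5, 4))   # 5 values (one per key) of dimension 4

output = attention(queries, keys, values)
print(output.shape)  # one row per query, one column per value dimension
```

The output dimension is set by the values, not by the queries or keys, and each output row is a weighted average of the value rows.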


## Self-attention
Self-attention is not very useful until we introduce multi-headed attention, but we can already see how to use the `attention` function to compute it.

We can compute self-attention between the animals and themselves. The animals act as the queries, keys, and values. We obtain an updated representation of each animal based on its similarity to the other animals.

In [24]:
animal_self_output = attention(animals, animals, animals)

Inspect the original animal embeddings and the updated animal embeddings. Most have not changed much, although lizard has shifted noticeably toward salmon.

In [25]:
for i in range(len(animal_names)):
    print(f'{animal_names[i]}\toriginal {animals[i]} and updated {animal_self_output[i]} embeddings')

kitten	original [9.1 1.  2.1] and updated [8.97593633 1.33207376 2.22679209] embeddings
lizard	original [0.1 7.5 4.3] and updated [0.99832734 6.00278776 7.21956386] embeddings
salmon	original [1.3 5.5 8.2] and updated [1.29999732 5.50000446 8.1999913 ] embeddings
whale	original [7.6 2.4 4. ] and updated [8.47958483 2.29781683 2.78497945] embeddings
wolf	original [8.5 2.7 2.7] and updated [8.6669283  2.1237928  2.54756204] embeddings
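
One way to understand these updates is to look at the self-attention weights themselves. The sketch below recomputes them from the animal embeddings; `self_weights` is a name introduced here for illustration.

```python
import numpy as np
from scipy.special import softmax

animal_names = ['kitten', 'lizard', 'salmon', 'whale', 'wolf']
animals = np.array([[9.1, 1.0, 2.1],   # kitten
                    [0.1, 7.5, 4.3],   # lizard
                    [1.3, 5.5, 8.2],   # salmon
                    [7.6, 2.4, 4.0],   # whale
                    [8.5, 2.7, 2.7]])  # wolf

self_weights = softmax(animals @ animals.T / np.sqrt(animals.shape[1]), axis=1)

# Row i shows how much animal i attends to each animal, itself included
for name, row in zip(animal_names, self_weights):
    print(f'{name}\t{np.round(row, 3)}')
```

With these embeddings, salmon attends almost entirely to itself, whereas lizard places most of its weight on salmon, which is consistent with the lizard embedding moving toward the salmon embedding.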


## Multi-headed self-attention
Self-attention becomes useful in its multi-headed form, where each head transforms the same original embeddings with its own trainable $W^Q$, $W^K$, and $W^V$ matrices to project the embeddings before the attention calculations. These $W$ matrices are updated during training with stochastic gradient descent, so the model can learn projections for which the self-attention calculations produce meaningful updated embeddings.

Here the $W$ matrices are made up and not meaningful, but they illustrate the calculations. We also demonstrate only a single attention head, whereas multi-headed attention would have many heads, each with its own trainable copies of the $W$ matrices.

In [26]:
np.random.seed(775)
w_q = np.random.rand(3,2)
w_k = np.random.rand(3,2)
w_v = np.random.rand(3,2)

Each $W$ is a random projection here, whereas it would typically be trained and meaningful. It reduces the animal embedding dimension from 3 to 2.

In [27]:
w_q

array([[0.37573502, 0.24683143],
       [0.24614648, 0.38234858],
       [0.03510655, 0.89775264]])

In [28]:
w_q.shape

(3, 2)

In [29]:
animals.shape

(5, 3)

The three different $W$ matrices give three different 2-dimensional representations of the animals.

In [30]:
animals@w_q

array([[3.73905893, 4.51379518],
       [2.03463026, 6.75263382],
       [2.13013489, 9.78536968],
       [3.58676392, 6.38456605],
       [3.95313086, 5.55434048]])

In [31]:
animals@w_k

array([[ 8.96561697,  2.8778713 ],
       [ 9.18598435,  6.01216844],
       [12.06063652,  5.9414747 ],
       [10.5822552 ,  4.0030275 ],
       [10.27859483,  4.02849108]])

In [32]:
animals@w_v

array([[2.18065523, 5.67484759],
       [4.44534671, 5.48037139],
       [7.14038542, 7.99398679],
       [3.76938987, 6.79569809],
       [2.87492446, 6.39373835]])

Multi-headed self-attention calls the `attention` function with the new projections of the animal embeddings. In a Transformer, we would multiply the result by another trainable matrix $W^O$ to map back from 2 to 3 dimensions. With several attention heads, $W^O$ would instead map the concatenated head outputs back to 3 dimensions. Here, we leave the output in 2 dimensions.
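
For comparison, here is a minimal sketch of what several heads plus an output projection could look like. It is not part of the lecture example: the number of heads, the random untrained matrices, and the names `num_heads`, `d_head`, and `w_o` are all assumptions made for illustration.

```python
import numpy as np
from scipy.special import softmax

def attention(queries, keys, values):
    # Same calculation as the `attention` function defined above
    scale = np.sqrt(queries.shape[1])
    weights = softmax(queries @ keys.T / scale, axis=1)
    return weights @ values

animals = np.array([[9.1, 1.0, 2.1],   # kitten
                    [0.1, 7.5, 4.3],   # lizard
                    [1.3, 5.5, 8.2],   # salmon
                    [7.6, 2.4, 4.0],   # whale
                    [8.5, 2.7, 2.7]])  # wolf

rng = np.random.default_rng(0)
num_heads, d_model, d_head = 2, 3, 2  # hypothetical sizes for illustration

head_outputs = []
for _ in range(num_heads):
    # Each head gets its own (random, untrained) projection matrices
    w_q, w_k, w_v = (rng.random((d_model, d_head)) for _ in range(3))
    head_outputs.append(attention(animals @ w_q, animals @ w_k, animals @ w_v))

# Concatenate the heads along the feature axis, then project back to d_model
concatenated = np.concatenate(head_outputs, axis=1)  # shape (5, 4)
w_o = rng.random((num_heads * d_head, d_model))       # stand-in for W^O
multi_head_output = concatenated @ w_o                # shape (5, 3)
print(multi_head_output.shape)
```

The output projection is what lets the updated embeddings return to the original dimension no matter how many heads are concatenated.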

In [33]:
animal_multi_output = attention(animals@w_q, animals@w_k, animals@w_v)

Using random $W$ matrices was a bad idea! These updated embeddings don't mean anything. But we can see how the operations connect together.

In [34]:
for i in range(len(animal_names)):
    print(f'{animal_names[i]}\toriginal {animals[i]} and updated {animal_multi_output[i]} embeddings')

kitten	original [9.1 1.  2.1] and updated [7.1384725  7.99233055] embeddings
lizard	original [0.1 7.5 4.3] and updated [7.08124031 7.93886421] embeddings
salmon	original [1.3 5.5 8.2] and updated [7.08371845 7.94113509] embeddings
whale	original [7.6 2.4 4. ] and updated [7.1378387  7.99162334] embeddings
wolf	original [8.5 2.7 2.7] and updated [7.13919142 7.99289749] embeddings
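
The updated embeddings are nearly identical across animals, and inspecting the attention weights for this head suggests why. The sketch below copies the projected queries and keys from the `animals@w_q` and `animals@w_k` outputs above; `head_weights` is a name introduced here for illustration.

```python
import numpy as np
from scipy.special import softmax

# Projected queries and keys copied from the outputs of animals@w_q and animals@w_k
projected_queries = np.array([[3.73905893, 4.51379518],
                              [2.03463026, 6.75263382],
                              [2.13013489, 9.78536968],
                              [3.58676392, 6.38456605],
                              [3.95313086, 5.55434048]])
projected_keys = np.array([[ 8.96561697, 2.8778713 ],
                           [ 9.18598435, 6.01216844],
                           [12.06063652, 5.9414747 ],
                           [10.5822552 , 4.0030275 ],
                           [10.27859483, 4.02849108]])

scale = np.sqrt(projected_queries.shape[1])
head_weights = softmax(projected_queries @ projected_keys.T / scale, axis=1)
print(np.round(head_weights, 3))
```

With these random projections, every row puts nearly all of its weight on the third key (salmon), so every updated embedding ends up close to salmon's projected value, the third row of `animals@w_v`.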
