# Attention examples
Anthony Gitter  
October 13, 2022  
Explore the attention calculations from the lecture slides

In [1]:
import numpy as np
from scipy.special import softmax

Define the query vector. This could be a query matrix with multiple queries as shown later.

In [2]:
# Query vector for "mammal" (3 components).
mammal = np.array((8.7, 3.2, 4.1))

Initially define the key vectors individually so we can see how they correspond to the keys in the slides. These must have the same length as the query.

In [3]:
# Key vector for "kitten".
kitten = np.array((9.1, 1.0, 2.1))

In [4]:
# Key vector for "lizard".
lizard = np.array((0.1, 7.5, 4.3))

In [5]:
# Key vector for "salmon".
salmon = np.array((1.3, 5.5, 8.2))

In [6]:
# Key vector for "whale".
whale = np.array((7.6, 2.4, 4.0))

In [7]:
# Key vector for "wolf".
wolf = np.array((8.5, 2.7, 2.7))

Compute and inspect the scaling factor we will use later.

In [8]:
# Scaling factor sqrt(d), where d is the length of each query/key vector.
# Read the size of the last axis rather than using len(): for a 1-D query the
# two agree, but for a 2-D matrix of queries len() would count the number of
# queries instead of the vector length.
scaling = np.sqrt(mammal.shape[-1])
print(scaling)

1.7320508075688772


Combine all of the keys into the keys matrix.

In [9]:
# Stack the key vectors as rows: row i of `animals` is the i-th key listed here.
animals = np.stack(
    (kitten, lizard, salmon, whale, wolf),
    axis=0,
)

In [10]:
# Display the keys matrix: one row per key vector, in the order they were stacked.
animals

array([[9.1, 1. , 2.1],
       [0.1, 7.5, 4.3],
       [1.3, 5.5, 8.2],
       [7.6, 2.4, 4. ],
       [8.5, 2.7, 2.7]])

Inspect the query-key dot products before scaling and taking the softmax

In [11]:
# Unscaled similarity scores: the dot product of the query with each key (row of `animals`).
mammal @ animals.T

array([90.98, 42.5 , 62.53, 90.2 , 93.66])

In [12]:
# Scaled similarity scores. Divide by the `scaling` factor computed earlier
# rather than re-deriving the square root inline, so this cell uses the same
# value as the softmax cells that follow.
(mammal @ animals.T) / scaling

array([52.52732749, 24.53738644, 36.10171233, 52.07699428, 54.07462621])

Compute the similarity weights

In [13]:
# Softmax of the scaled query-key scores gives one weight per key; the weights sum to 1.
weights = softmax((mammal @ animals.T) / scaling)
print(weights)

[1.57823895e-01 1.10228985e-13 1.16042942e-08 1.00599432e-01
 7.41576662e-01]


Store the values that correspond to each key

In [14]:
# Value vectors, one 4-component row per key; the trailing comment is each row's label.
values = np.array(
    [
        [3.4, 1.3, 0.4, 9.8],  # furry
        [7.5, 3.9, 4.1, 0.2],  # scaly
        [8.3, 2.8, 2.3, 0.1],  # slippery
        [1.6, 8.4, 9.9, 3.4],  # huge
        [2.2, 9.4, 8.7, 1.1],  # ferocious
    ]
)

Compute the final output value for our query `mammal`

In [15]:
# Attention output: the weighted sum of the value rows, using the current `weights`.
output = weights @ values
print(output)

[2.32902909 8.02102694 7.51078092 2.70444657]


Try it with a different query

In [16]:
# A second query vector, "reptile", with the same length as the keys.
reptile = np.array((2.1, 9.9, 1.6))

In [17]:
# Same weight computation as before, now with `reptile` as the query.
# Note: this rebinds `weights`, replacing the weights computed for `mammal`.
weights = softmax((reptile @ animals.T) / scaling)
print(weights)

[5.25436708e-13 9.98297480e-01 1.70251120e-03 1.47297680e-09
 7.33251915e-09]


In [18]:
# Weighted sum of the value rows using the most recently computed `weights`.
output = weights @ values
print(output)

[7.50136196 3.89812728 4.09693552 0.19982976]


What if we have a query matrix instead?

In [19]:
# Query matrix: one query per row (row 0 = mammal, row 1 = reptile).
query = np.stack(
    (mammal, reptile),
    axis=0,
)
print(query)

[[8.7 3.2 4.1]
 [2.1 9.9 1.6]]


In [20]:
# Apply the softmax along axis=1 so each row (one query) is normalized on its
# own, instead of normalizing over every entry of the score matrix at once.
weights = softmax((query @ animals.T) / scaling, axis=1)
print(weights)

[[1.57823895e-01 1.10228985e-13 1.16042942e-08 1.00599432e-01
  7.41576662e-01]
 [5.25436708e-13 9.98297480e-01 1.70251120e-03 1.47297680e-09
  7.33251915e-09]]


In [21]:
# Each output row is the updated value for the query in the same row of `query`:
#   row 0 -> 'mammal'
#   row 1 -> 'reptile'
output = weights @ values
print(output)

[[2.32902909 8.02102694 7.51078092 2.70444657]
 [7.50136196 3.89812728 4.09693552 0.19982976]]
