<a href="https://colab.research.google.com/github/Angus-Eastell/Intro_to_AI/blob/main/10_3_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Attention

Consider some inputs and attention weights

In [1]:
import torch as t
t.manual_seed(0)

H = 10            #Embedding dimension (i.e. dimension of inputs)
S = 4             #Sequence length
D = 3             #Dimension of keys/values
Dv = 5            #Dimension of queries

X = t.randn(S, H) #Example inputs
Wq = t.randn(H, D)
Wk = t.randn(H, D)
Wv = t.randn(H, Dv)

Compute the queries, keys and values

In [2]:
#Your code here

Q = X @ Wq
K = X @ Wk
V = X @ Wv

In [None]:
# @title Answer
Q = X @ Wq
K = X @ Wk
V = X @ Wv

Write a function that computes the attention matrix,
\begin{align}
  A_{ij}(\mathbf{Q}, \mathbf{K}) = \frac{\exp(\mathbf{Q}_{i, :} \cdot \mathbf{K}_{j, :})}{\sum_k \exp(\mathbf{Q}_{i, :} \cdot \mathbf{K}_{k, :})}
\end{align}
Check that the results of your function have the right shape, and sum to $1$ in the right way.

In [47]:
#Your code here

# @ computes dot product
# summing over last dimensions which corresponds to k
# need to keep original dimensions to ensure tensor keeps the same shape
A = t.exp(Q @ K.mT)/ t.sum(t.exp(Q @ K.mT), dim = -1, keepdim= True)

# From solution
#Results should be 4 x 4:
assert A.shape == (S, S)
#And columns should sum to 1:
print(A.sum(-1))
print(A)

tensor([1., 1., 1., 1.])
tensor([[2.4771e-14, 2.7799e-12, 1.0000e+00, 2.0112e-15],
        [7.8475e-16, 4.0728e-13, 1.0000e+00, 1.2259e-10],
        [3.9596e-03, 3.9879e-03, 1.3989e-04, 9.9191e-01],
        [5.4816e-09, 1.9935e-12, 8.3131e-18, 1.0000e+00]])


In [None]:
# @title Answer
def compute_A(K, Q):
    exp_QK = (Q@K.mT).exp()
    return exp_QK / exp_QK.sum(-1, keepdim=True)

A = compute_A(K, Q)

#Results should be 4 x 4:
assert A.shape == (S, S)
#And columns should sum to 1:
print(A.sum(-1))

tensor([1., 1., 1., 1.])


Write a function that computes the causal attention matrix,
\begin{align}
  A_{ij}(\mathbf{Q}, \mathbf{K}) = \frac{\mathbb{I}(j \leq i) \exp(\mathbf{Q}_{i, :} \cdot \mathbf{K}_{j, :})}{\sum_k \mathbb{I}(k \leq i) \exp(\mathbf{Q}_{i, :} \cdot \mathbf{K}_{k, :})}
\end{align}
Check that the results of your function have the right shape, and sum to $1$ in the right way.

In [46]:
#Your code here
shape_mask =(t.exp(Q @ K.mT)).shape
# shape of mask needs to be 4 by 4
mask = t.zeros(shape_mask)
i, j = t.arange(0,mask.shape[0]), t.arange(0,mask.shape[1])
for index1,val1 in enumerate(i):
  for index2,val2 in enumerate(j):
    if val2 <= val1:
      mask[index1, index2] = 1

causal_A = (mask * t.exp(Q @ K.mT))/ t.sum((mask * t.exp(Q @ K.mT)), dim = -1, keepdim= True)

# From solution
#Results should be 4 x 4:
assert causal_A.shape == (S, S)
#And columns should sum to 1:
print(causal_A.sum(-1))

print(causal_A)


tensor([1., 1., 1., 1.])
tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.9231e-03, 9.9808e-01, 0.0000e+00, 0.0000e+00],
        [4.8960e-01, 4.9310e-01, 1.7297e-02, 0.0000e+00],
        [5.4816e-09, 1.9935e-12, 8.3131e-18, 1.0000e+00]])


In [45]:
# @title Answer
def compute_causal_A(K, Q):
    range = t.arange(S)
    i = range[:, None]
    j = range[None, :]
    mask = j <= i

    exp_QK = (Q@K.mT).exp()
    masked_exp_QK = mask * exp_QK
    return masked_exp_QK / masked_exp_QK.sum(-1, keepdim=True)

causal_A = compute_causal_A(K, Q)
#Results should be 4 x 4:
assert causal_A.shape == (S, S)
#And columns should sum to 1:
print(causal_A.sum(-1))

tensor([1., 1., 1., 1.])


Write a function that computes the output of causal self-attention:
\begin{align}
\text{self-attention}_{i, :}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) &= \sum_j A_{ij}(\mathbf{Q}, \mathbf{K}) V_{j, :}
\end{align}

In [56]:
#Your code here
# use mat mul to find weighted sum
attension = causal_A @ V
print(attension.shape)
attension

torch.Size([4, 5])


tensor([[-0.7919, -2.3897,  3.8101,  2.2223, -0.2126],
        [-1.0591,  1.0445,  3.9767,  1.7151,  1.5959],
        [-0.8514, -0.6575,  3.8082,  1.9692,  0.6545],
        [ 0.3252,  4.1818, -2.1640,  0.4850,  4.6732]])

In [58]:
# @title Answer
def compute_self_attention(Q, K, V):
    return compute_causal_A(Q, K) @ V
self_attention = compute_self_attention(Q, K, V)
print(self_attention)
assert self_attention.shape == (S, Dv)

tensor([[-0.7919, -2.3897,  3.8101,  2.2223, -0.2126],
        [-0.7953, -2.3463,  3.8122,  2.2159, -0.1897],
        [-0.8133, -2.1141,  3.8235,  2.1816, -0.0674],
        [ 0.3252,  4.1817, -2.1640,  0.4850,  4.6732]])
