In this notebook I'm updating my simpler attention function to incorporate masking, and testing it more thoroughly. After that, I need to check that the parallelized attention function works, and add masking and matching tests for it as well.

In [1]:
import torch
from torch import tensor
from torch.nn.functional import softmax
from math import sqrt

For masking to work, I need to know the position of the current (query) token relative to the positions of all the key/value tokens. This information is probably easier to get at in the parallelized version, since the query, key, and value matrices share the same indexing relative to the original input (for example, the 5th value vector and the 5th query vector both come from the 5th input token). In this single-query version I need to pass in the position of the query token manually, as an index into the key/value rows. Passing in position 5 means that the query token's own key and value sit on row 5, so I can mask the weights of every key/value row after that. In the parallelized function I think this will look more like masking out everything above the diagonal of the raw weights matrix.
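
To make the "above the diagonal" idea concrete, here is a minimal sketch of a causal mask built with `torch.tril`. The sequence length `n` and the `scores` matrix are made-up stand-ins for illustration, not values from this notebook.

```python
import torch

n = 4  # hypothetical sequence length
# Stand-in for queries @ keys.T: row i holds the raw weights for query i
scores = torch.arange(n * n, dtype=torch.float32).view(n, n)

# Lower-triangular mask: entry (i, j) is True when key j is at or before query i
mask = torch.tril(torch.ones(n, n, dtype=torch.bool))

# Positions the query is not allowed to see are set to -inf before the softmax
masked = scores.masked_fill(~mask, float("-inf"))
weights = torch.softmax(masked, dim=1)

print(mask)
print(weights)
```

Each row of `weights` still sums to 1, but all weight on later positions is exactly 0, so query 0 can only attend to key 0.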

In [5]:
# I assume all rows have the same dimension, which is provided in dim.
# keys and values are structured so that each key/value is a row.
# There must be one value per key, so both matrices need the same number of rows.
def attention(query: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, dim: int, query_pos: int) -> torch.Tensor:
    # query is a 1-D vector, so this yields one raw weight per key
    raw_weights = query @ keys.T

    # Masking: keys positioned after the query get a raw weight of -inf,
    # which softmax turns into a weight of exactly 0
    raw_weights[query_pos + 1:] = float('-inf')

    scale_factor = sqrt(dim)
    scaled_weights = softmax(raw_weights / scale_factor, dim=0)

    scaled_values = scaled_weights.view(-1, 1) * values
    contextualized_value = torch.sum(scaled_values, 0)

    return contextualized_value


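Simple test case that I did by hand. With query_pos = 1 (nothing masked) the expected result is approximately [3, 2], and with query_pos = 0 only the first value vector is visible, so the expected result is exactly [0, 5]. That is in fact what we see!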

In [25]:
Q = tensor([1, 3]).float()
K = tensor([[2, 1], [3, 5]]).float()
V = tensor([[0, 5], [3, 2]]).float()
d = 2
print(attention(Q, K, V, d, 1))
print(attention(Q, K, V, d, 0))

tensor([2.9997, 2.0003])
tensor([0., 5.])
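
The hand calculation for the unmasked case (query_pos = 1) can be reproduced step by step. This is only a sketch of the arithmetic, reusing the Q, K, V values from the test above; the intermediate variable names are my own.

```python
import torch
from math import sqrt

Q = torch.tensor([1.0, 3.0])
K = torch.tensor([[2.0, 1.0], [3.0, 5.0]])
V = torch.tensor([[0.0, 5.0], [3.0, 2.0]])

raw = Q @ K.T                            # dot products: 1*2 + 3*1 = 5 and 1*3 + 3*5 = 18
scaled = raw / sqrt(2)                   # divide by sqrt(dim) with dim = 2
weights = torch.softmax(scaled, dim=0)   # the second key takes almost all of the weight
result = weights @ V                     # weighted sum of the value rows

print(raw)
print(weights)
print(result)
```

Because the second weight is very close to 1, the result lands very close to the second value vector [3, 2].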


Now I'll do the same thing with my parallel attention function. The version below already includes masking, implemented as a lower-triangular mask over the raw weights matrix.

In [69]:
def par_attention(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, dim: int) -> torch.Tensor:
    raw_weights = queries @ keys.T

    # Causal mask: row i (query i) may only attend to columns j <= i,
    # so everything above the diagonal is set to -inf before the softmax.
    mask = torch.tril(torch.ones_like(raw_weights), diagonal=0)
    raw_weights = raw_weights.masked_fill(mask == 0, float('-inf'))
    print(raw_weights)  # debug output: the masked raw weights

    scale_factor = sqrt(dim)
    scaled_weights = softmax(raw_weights / scale_factor, dim=1)

    # now scaled weights is a matrix where each row represents the scaled weights produced based on a given query.
    # meanwhile values just has a value vector on each row.

    reshaped_scaled_weights = scaled_weights.view(scaled_weights.shape[0], scaled_weights.shape[1], 1)
    reshaped_values = values.view(1, values.shape[0], values.shape[1])

    scaled_values = reshaped_scaled_weights * reshaped_values

    contextualized_values = torch.sum(scaled_values, 1)
    return contextualized_values
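
The reshape, broadcast, and sum at the end of the function computes a weighted sum of value rows for each query, which should be the same thing as a plain matrix product. A small sketch to check that claim, using random stand-in tensors (the shapes and names here are illustrative assumptions):

```python
import torch

torch.manual_seed(0)
weights = torch.softmax(torch.randn(3, 3), dim=1)  # stand-in for scaled_weights
values = torch.randn(3, 2)                         # stand-in for values

# Broadcast-and-sum: (3, 3, 1) * (1, 3, 2) -> (3, 3, 2), then sum over the key axis
broadcast = (weights.view(3, 3, 1) * values.view(1, 3, 2)).sum(dim=1)

# The same weighted sum written as a matrix product
matmul = weights @ values

print(torch.allclose(broadcast, matmul, atol=1e-6))
```

The matrix product avoids materializing the intermediate 3-D tensor, which may matter for longer sequences.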

In [70]:
Q = tensor([[1, 3], [1, 1]]).float()
K = tensor([[2, 1], [3, 5]]).float()
V = tensor([[0, 5], [3, 2]]).float()
d = 2
par_attention(Q, K, V, d)


tensor([[5., -inf],
        [3., 8.]])


tensor([[0.0000, 5.0000],
        [2.9150, 2.0850]])

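That looks right to me, which is great! The first query can only see the first key, so its output is exactly the first value vector [0, 5], while the second query sees both keys and produces a softmax-weighted mix of the two value vectors. I now have masking implemented and a much more thoroughly validated parallel attention function.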
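
As an extra sanity check, the result can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True`. This sketch assumes PyTorch 2.0 or later; `causal_attention` is a hypothetical compact restatement of the masked attention above (without the debug print), not the notebook's own function.

```python
import torch
import torch.nn.functional as F
from math import sqrt

def causal_attention(queries, keys, values):
    # Compact restatement of masked attention (hypothetical helper)
    scores = queries @ keys.T
    mask = torch.tril(torch.ones_like(scores)).bool()
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores / sqrt(queries.shape[-1]), dim=1) @ values

Q = torch.tensor([[1.0, 3.0], [1.0, 1.0]])
K = torch.tensor([[2.0, 1.0], [3.0, 5.0]])
V = torch.tensor([[0.0, 5.0], [3.0, 2.0]])

mine = causal_attention(Q, K, V)

# The built-in expects leading batch/head dimensions, so add them and strip them again.
# Its default scale is 1 / sqrt(embedding dim), matching sqrt(dim) with dim = 2 here.
reference = F.scaled_dot_product_attention(
    Q[None, None], K[None, None], V[None, None], is_causal=True
)[0, 0]

print(mine)
print(reference)
```

If the two printed tensors agree, the masking and scaling match the built-in's conventions for this input.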

Is the function differentiable?
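
One numerical way to approach this question is `torch.autograd.gradcheck`, which compares autograd's analytic gradients against finite-difference estimates. The sketch below makes some assumptions: it uses double-precision inputs (which gradcheck's default tolerances are designed for) and a hypothetical `causal_attention` helper that restates the masked attention without the debug print.

```python
import torch
from math import sqrt

def causal_attention(queries, keys, values):
    # Compact restatement of masked attention (hypothetical helper)
    scores = queries @ keys.T
    mask = torch.tril(torch.ones_like(scores)).bool()
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores / sqrt(queries.shape[-1]), dim=1) @ values

torch.manual_seed(0)
q = torch.randn(5, 4, dtype=torch.double, requires_grad=True)
k = torch.randn(5, 4, dtype=torch.double, requires_grad=True)
v = torch.randn(5, 4, dtype=torch.double, requires_grad=True)

# Returns True when analytic and numerical gradients agree; raises an error otherwise
ok = torch.autograd.gradcheck(causal_attention, (q, k, v))
print(ok)
```

A passing check only says the gradients are consistent at these random inputs; it is not a proof for every input.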

In [74]:
# Define toy input tensors (ensure that they require gradients)
queries = torch.randn(5, 4, requires_grad=True)
keys = torch.randn(5, 4, requires_grad=True)
values = torch.randn(5, 4, requires_grad=True)

# Call your attention function
output = par_attention(queries, keys, values, 4)

# Sum up the elements of the output to get a scalar
# This is a simple stand-in for a loss function
loss = output.sum()

# Backpropagate
loss.backward()

# Check the gradients
print("Gradients w.r.t queries:", queries.grad)
print("Gradients w.r.t keys:", keys.grad)
print("Gradients w.r.t values:", values.grad)
print("output:", output)


tensor([[-9.3423e-01,        -inf,        -inf,        -inf,        -inf],
        [ 3.0195e+00, -1.0212e+00,        -inf,        -inf,        -inf],
        [-9.1091e-03,  4.1238e+00,  1.1084e+00,        -inf,        -inf],
        [ 1.8217e+00, -1.1473e+00,  8.7295e-01,  8.9629e-01,        -inf],
        [-7.9833e+00,  1.7338e-01,  1.2350e+00, -1.0137e+00, -9.6294e+00]],
       grad_fn=<MaskedFillBackward0>)
Gradients w.r.t queries: tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [-0.4473,  0.9911,  0.4480, -0.6566],
        [-0.1640,  0.9723,  0.1193, -0.7024],
        [-0.9393,  0.4948,  0.9826, -0.9392],
        [ 0.3207,  0.2942, -0.4090, -0.4319]])
Gradients w.r.t keys: tensor([[-3.1726e-01,  1.3945e-02,  3.3322e-01, -4.2003e-01],
        [-2.8608e-01,  6.0633e-01, -3.5469e-01, -8.7983e-02],
        [ 7.4231e-01, -4.8737e-01, -4.9657e-02,  3.4402e-01],
        [-1.3873e-01, -1.3295e-01,  7.1092e-02,  1.6410e-01],
        [-2.4486e-04,  4.6582e-05,  2.9947e-05, -1.1360e-04]