Ok, so in this file I'm going to implement an attention function. What is an attention function?
- It helps us find the contextualized meaning of a certain word/token.
- Input: One query vector, many key vectors, equally many value vectors.
- Output: A weighted combination of the value vectors, i.e. a new, more richly contextualized vector for the word the query vector came from.
Also, in practice it's a good idea to compute many attention functions at once to take advantage of matrix multiplication speedup. I'll worry about that after I get the minimal version of the function working.
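
For reference, the minimal version implemented in the first code cell appears to follow the usual scaled dot-product form. The symbols are chosen here for illustration and are not taken from the notebook: q is the query vector, K and V are the matrices whose rows are the keys and values, and d is the key dimension.

```latex
\mathrm{attention}(q, K, V) = \mathrm{softmax}\!\left(\frac{q K^{\top}}{\sqrt{d}}\right) V
```

The softmax turns the scaled query-key scores into weights that sum to 1, so the output is a weighted average of the rows of V.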

In [2]:
import torch
from torch import tensor
from torch.nn.functional import softmax
from math import sqrt

In [31]:
# I assume the query and every key have the same dimension, which is provided in dim.
# keys and values are assumed to be structured such that each key/value is a row.
# There must also be an equal number of keys and values, so those matrices must have the same number of rows.
def attention(query: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, dim: int) -> torch.Tensor:
    raw_weights = query @ keys.T
    scale_factor = sqrt(dim)
    scaled_weights = softmax(raw_weights / scale_factor, dim=0)

    scaled_values = scaled_weights.view(-1, 1) * values
    contextualized_value = torch.sum(scaled_values, 0)

    return contextualized_value


In [32]:
# Define the dimension
dim = 5

# Generate random tensors for testing
query = torch.rand(dim)
keys = torch.rand(4, dim)  # 4 keys, each of dimension 'dim'
values = torch.rand(4, dim)  # 4 values, each of dimension 'dim'

# Call the attention function
result = attention(query, keys, values, dim)

# Print the tensors and result
print("Query tensor:", query)
print("Keys tensor:", keys)
print("Values tensor:", values)
print("Result:", result)

Query tensor: tensor([0.2431, 0.6326, 0.1152, 0.8423, 0.6666])
Keys tensor: tensor([[0.7123, 0.1488, 0.2927, 0.0413, 0.1383],
        [0.1572, 0.0147, 0.4294, 0.6719, 0.2709],
        [0.2583, 0.0736, 0.8304, 0.6375, 0.5369],
        [0.5114, 0.0867, 0.6610, 0.1255, 0.9041]])
Values tensor: tensor([[0.4356, 0.9212, 0.4823, 0.4276, 0.9394],
        [0.5483, 0.2654, 0.7928, 0.3881, 0.2428],
        [0.9190, 0.8477, 0.3635, 0.2128, 0.9751],
        [0.0507, 0.0645, 0.1399, 0.3570, 0.6768]])
Result: tensor([0.4976, 0.5113, 0.4364, 0.3390, 0.7064])


This is not a meaningful way of testing my code, since random inputs only show that the function runs and returns a vector of the right size. I'm in a hurry, so I'll leave this largely untested and reconsider testing later if I suspect a problem in this component. Next is the parallelized attention function, which takes many queries at once instead of just one.
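
A more meaningful check could use inputs whose answer is known in advance. The sketch below restates the single-query function in compact form, and the test tensors are made up for illustration. It covers two cases. A zero query scores every key equally, so the result should be the mean of the values. A query that strongly matches one key should return almost exactly that key's value.

```python
import torch
from math import sqrt
from torch.nn.functional import softmax

def attention(query, keys, values, dim):
    # Compact restatement of the single-query function above.
    raw_weights = query @ keys.T
    scaled_weights = softmax(raw_weights / sqrt(dim), dim=0)
    return torch.sum(scaled_weights.view(-1, 1) * values, 0)

# Illustrative inputs (not from the notebook): 4 orthogonal keys of dimension 4.
keys = torch.eye(4)
values = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, -1.0]])

# Case 1: a zero query gives equal scores, so the weights are uniform
# and the output is the plain mean of the value vectors.
uniform = attention(torch.zeros(4), keys, values, 4)
assert torch.allclose(uniform, values.mean(dim=0))

# Case 2: a query that strongly matches key 2 puts nearly all weight on value 2.
peaked = attention(100.0 * keys[2], keys, values, 4)
assert torch.allclose(peaked, values[2], atol=1e-4)
```

The values here have a different dimension from the keys, which the function allows as long as both matrices have the same number of rows.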

In [25]:
# I assume all rows have the same dimension, which is provided in dim.
# queries, keys, values are assumed to be structured such that each query/key/value is a row.
def attention(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, dim: int) -> torch.Tensor:
    raw_weights = queries @ keys.T
    scale_factor = sqrt(dim)
    scaled_weights = softmax(raw_weights / scale_factor, dim=1)

    # Now scaled_weights has shape (num_queries, num_keys): each row holds the weights produced by one query.
    # Meanwhile values has shape (num_keys, value_dim), with one value vector per row.

    reshaped_scaled_weights = scaled_weights.view(scaled_weights.shape[0], scaled_weights.shape[1], 1)
    reshaped_values = values.view(1, values.shape[0], values.shape[1])

    scaled_values = reshaped_scaled_weights * reshaped_values

    contextualized_values = torch.sum(scaled_values, 1)
    return contextualized_values
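
The reshape, multiply, and sum steps in this function compute a weighted sum of value rows for each query, which should be the same thing as a single matrix product of the weights with the values. The sketch below compares the two on random inputs. The seed and shapes are arbitrary choices for illustration.

```python
import torch
from math import sqrt

torch.manual_seed(0)  # arbitrary seed, only for repeatability
queries = torch.rand(4, 3)  # 4 queries of dimension 3
keys = torch.rand(5, 3)     # 5 keys of dimension 3
values = torch.rand(5, 3)   # 5 values of dimension 3

# One row of softmax weights per query: shape (4, 5).
weights = torch.softmax(queries @ keys.T / sqrt(3), dim=1)

# Broadcast-and-sum, as in the function above:
# (4, 5, 1) * (1, 5, 3) -> (4, 5, 3), then sum over the key axis.
broadcast_result = torch.sum(weights.unsqueeze(2) * values.unsqueeze(0), dim=1)

# The same weighted sum as one matrix product: (4, 5) @ (5, 3) -> (4, 3).
matmul_result = weights @ values

assert torch.allclose(broadcast_result, matmul_result, atol=1e-6)
```

The matrix product avoids materializing the intermediate (queries, keys, value_dim) tensor, which is the kind of speedup the parallel version is after.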


In [29]:
# %%debug
queries = torch.rand((4, 3))  # 4 queries, each of dimension 3
print(f"Queries: {queries}")
keys = torch.rand((5, 3))     # 5 keys, each of dimension 3
print(f"Keys: {keys}")
values = torch.rand((5, 3))   # 5 values, each of dimension 3
print(f"Values: {values}")

# Call the attention function
result = attention(queries, keys, values, 3)

print("Resulting contextualized values:")
print(result)

Queries: tensor([[0.5801, 0.4604, 0.9046],
        [0.3663, 0.0597, 0.2056],
        [0.0462, 0.5163, 0.8533],
        [0.4250, 0.6046, 0.0285]])
Keys: tensor([[0.3544, 0.9319, 0.0046],
        [0.1784, 0.3313, 0.1533],
        [0.8533, 0.4826, 0.3342],
        [0.7300, 0.0827, 0.6533],
        [0.7321, 0.9282, 0.1344]])
Values: tensor([[0.4662, 0.5212, 0.7306],
        [0.5503, 0.7979, 0.5419],
        [0.7162, 0.7844, 0.0497],
        [0.6775, 0.2824, 0.1225],
        [0.0245, 0.2821, 0.8845]])
Resulting contextualized values:
tensor([[0.4871, 0.5167, 0.4452],
        [0.4886, 0.5270, 0.4537],
        [0.4837, 0.5226, 0.4617],
        [0.4657, 0.5237, 0.4865]])


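Something seems to be off with this one, as the output rows all look very similar. That may not be a bug, though: with random inputs in [0, 1) and only 3 dimensions, the scaled scores are all close together, so the softmax weights are nearly uniform and every row lands near the mean of the value vectors (roughly [0.49, 0.53, 0.47] here). I think it's time to move on, so I'll plan to use the non-parallelized version and fix this one later if necessary.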
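
One way to settle this is to run the single-query version once per query and compare it with the batched result. The sketch below does that with compact restatements of both functions. The names `attention_single` and `attention_batched` are hypothetical, used only to keep the two versions apart. It also prints the largest softmax weight per query and the mean of the values, to show how flat the weights are for inputs like these.

```python
import torch
from math import sqrt

def attention_single(query, keys, values, dim):
    # Compact restatement of the single-query version.
    weights = torch.softmax(query @ keys.T / sqrt(dim), dim=0)
    return torch.sum(weights.view(-1, 1) * values, 0)

def attention_batched(queries, keys, values, dim):
    # Compact restatement of the parallel version (without the breakpoint).
    weights = torch.softmax(queries @ keys.T / sqrt(dim), dim=1)
    return torch.sum(weights.unsqueeze(2) * values.unsqueeze(0), 1)

torch.manual_seed(0)  # arbitrary seed, only for repeatability
queries, keys, values = torch.rand(4, 3), torch.rand(5, 3), torch.rand(5, 3)

batched = attention_batched(queries, keys, values, 3)
looped = torch.stack([attention_single(q, keys, values, 3) for q in queries])

# If the two versions agree row by row, the batched one is not the problem.
assert torch.allclose(batched, looped, atol=1e-6)

# With entries in [0, 1) and dim 3, every scaled score lies in [0, sqrt(3)),
# so no softmax weight can get close to 1 and the rows stay near the mean.
weights = torch.softmax(queries @ keys.T / sqrt(3), dim=1)
print("largest weight per query:", weights.max(dim=1).values)
print("mean of values:", values.mean(dim=0))
print("batched output:", batched)
```

If the assertion passes, similar-looking rows are a property of the test data rather than of the parallel implementation. Queries and keys with larger magnitudes should produce more distinct rows.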