# Experiments with Attention Mechanisms

Some simple PyTorch manipulations based on the math in [Attention is All You Need](https://arxiv.org/abs/1706.03762).


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Scaled Dot Product Attention

In [15]:
# Parameters
B, T, C = 4, 8, 32  # batch size, tokens per sequence, input features per token
retrieval_dims = 16  # query/key dimension (d_k in the paper)
value_dims = 16  # value dimension (d_v in the paper)

# Seed the global RNG so the random input below, and the nn.Linear weights
# initialised in later cells, are the same on every run of the notebook
torch.manual_seed(0)

# Input: uniform random values in [0, 1)
x = torch.rand(B, T, C)
x.shape

torch.Size([4, 8, 32])

In [57]:
# Project the input tensor onto lower-dimensional query/key/value spaces
query_projection = nn.Linear(C, retrieval_dims, bias=False)
key_projection = nn.Linear(C, retrieval_dims, bias=False)
value_projection = nn.Linear(C, value_dims, bias=False)

query: torch.Tensor = query_projection(x)  # (B, T, retrieval_dims)
key: torch.Tensor = key_projection(x)  # (B, T, retrieval_dims)
value: torch.Tensor = value_projection(x)  # (B, T, value_dims)

# Raw attention scores for each batch element:
# wei[b, i, j] = dot_product(query[b, i], key[b, j])
wei = query @ key.transpose(-2, -1)  # (B, T, T)

# Scale by 1/sqrt(retrieval_dims) to keep the variance of the scores from
# growing with the key dimension
wei = wei * retrieval_dims**-0.5

# Causal mask: True strictly above the diagonal (key position j > query position i).
# Filling those scores with -inf gives them zero weight after softmax, so each token
# attends only to itself and previous tokens in the sequence.
mask = torch.tril(torch.ones(T, T)) == 0  # (T, T), broadcast over the batch
wei = wei.masked_fill(mask, float("-inf"))
wei = F.softmax(wei, dim=-1)  # (B, T, T); each row sums to 1

# Each output row is a weighted average of the value vectors
output = wei @ value  # (B, T, value_dims)
print(f"{output.shape=}\n{(B, T, value_dims)=}")

# Sanity checks
assert torch.allclose(
    output[:, 0], value[:, 0]
), "The first token should only attend to itself"
assert not torch.allclose(
    output[:, 1:], value[:, 1:]
), "Future tokens attend to more than just themselves"

output.shape=torch.Size([4, 8, 8])
(B, T, value_dims)=(4, 8, 8)


## Multi-Head Attention

In [56]:
# Operate on 4 heads in parallel. Each head deals with C // heads (= 8 here)
# dimensional keys, queries, and values, so the concatenated heads span C dims.
heads = 4
assert C % heads == 0, "C must be divisible by heads"
value_dims = retrieval_dims = C // heads

In [72]:
# Project the input with one wide linear layer per role. Each output is then
# split into `heads` chunks (one retrieval_dims / value_dims space per head).
query_projection = nn.Linear(C, retrieval_dims * heads, bias=False)
key_projection = nn.Linear(C, retrieval_dims * heads, bias=False)
value_projection = nn.Linear(C, value_dims * heads, bias=False)

query: torch.Tensor = query_projection(x)  # (B, T, heads * retrieval_dims)
key: torch.Tensor = key_projection(x)  # (B, T, heads * retrieval_dims)
value: torch.Tensor = value_projection(x)  # (B, T, heads * value_dims)

query = query.view(B, T, heads, retrieval_dims)
key = key.view(B, T, heads, retrieval_dims)
value = value.view(B, T, heads, value_dims)

# Per-head attention scores:
# wei[b, h, i, j] = dot_product(query[b, i, h], key[b, j, h])
wei = torch.einsum("bihk,bjhk->bhij", query, key)  # (B, heads, T, T)

# Scale by 1/sqrt(retrieval_dims), exactly as in the single-head case
wei = wei * retrieval_dims**-0.5

# Causal mask, broadcast over batch and heads
mask = torch.tril(torch.ones((T, T))) == 0  # (T, T)
wei = wei.masked_fill(mask, float("-inf"))
wei = F.softmax(wei, dim=-1)  # (B, heads, T, T)

# Weighted value vectors for each head
output = torch.einsum("bhij,bjhc->bihc", wei, value)  # (B, T, heads, value_dims)

# Concatenate the heads. `reshape` returns a view when the einsum result's memory
# layout allows it and copies otherwise, so no explicit .contiguous() is needed.
# NOTE: the paper additionally applies an output projection (W^O) here; omitted.
output = output.reshape(B, T, heads * value_dims)

# Sanity checks
assert torch.allclose(
    output[:, 0], value[:, 0].reshape(B, -1)
), "The first token should only attend to itself in every head"
assert not torch.allclose(
    output[:, 1:], value[:, 1:].reshape(B, T - 1, -1)
), "Later tokens should attend to more than just themselves"