## Attention

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

> How is the output vector $y_i$ produced in self-attention?

To produce the output vector $y_i$, the self-attention operation simply takes a weighted average over all the input vectors:

$y_i = \sum_{j} w_{ij} x_j$

where $j$ indexes over the whole sequence and the weights sum to one over all $j$. The weight $w_{ij}$ is not a learned parameter, as in a normal neural network; it is derived from a function of $x_i$ and $x_j$. The simplest option for this function is the dot product:

$w'_{ij} = x_i^{T}x_j$

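Note that $x_i$ is the input vector at the same position as the current output vector $y_i$. The raw score $w'_{ij}$ can be any real number, so a softmax over $j$ is applied to obtain weights $w_{ij}$ that are positive and sum to one.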
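The weighted average described above can be sketched in a few lines of PyTorch. This is a minimal illustration, assuming a toy sequence of random vectors and a softmax that turns the raw dot products into weights summing to one; the names `basic_self_attention`, `seq_len`, and `dim` are made up for this example.

```python
import torch
import torch.nn.functional as F

def basic_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Parameter-free self-attention over a (seq_len, dim) tensor."""
    raw = x @ x.T                      # raw[i, j] = x_i . x_j, shape (seq_len, seq_len)
    weights = F.softmax(raw, dim=-1)   # each row sums to one over j
    return weights @ x                 # y_i = sum_j w_ij x_j

torch.manual_seed(0)
seq_len, dim = 3, 4                    # illustrative sizes
x = torch.randn(seq_len, dim)
y = basic_self_attention(x)
print(y.shape)  # torch.Size([3, 4])
```

Nothing here is learned: the output depends only on the input vectors, which is why practical attention layers add the key, query, and value projections.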


## Scaled Dot-Product Attention

The authors call their attention "Scaled Dot-Product Attention" (Figure 2). The input consists of queries and keys of dimension $d_k$ and values of dimension $d_v$ (in practice, these dimensions are often the same). We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.

> What is a good intuition for the attention mechanism?

Attention mimics the retrieval of a `value` $v_i$ for a `query` $q$ based on a `key` $k_i$ in a database. Attention can therefore be written as a probabilistic lookup in the database:

$\mathrm{Attention}(q, k, v) = \sum_{i} \mathrm{similarity}(q, k_i) \times v_i$

In practice, the scores are scaled and passed through a softmax, which makes the operation more effective when attention layers are stacked:

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$
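As a rough sketch of the formula above, the function below computes scaled dot-product attention and then uses it as a soft database lookup. The function name `scaled_dot_product_attention_sketch` and the toy keys and values are assumptions made for illustration, not the authors' code.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention_sketch(q, k, v):
    """softmax(q k^T / sqrt(d_k)) v for tensors shaped (..., T, d)."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (..., T_q, T_k)
    weights = F.softmax(scores, dim=-1)                # each row sums to one
    return weights @ v, weights

# Toy "database": three orthogonal keys, each paired with a 2-d value.
keys = 10.0 * torch.eye(3)            # large norm makes the lookup nearly hard
values = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
query = keys[1:2]                     # a query that matches the second key
out, weights = scaled_dot_product_attention_sketch(query, keys, values)
print(weights)  # almost all weight on the second key
print(out)      # close to the second value, [0., 1.]
```

With less sharply separated keys the weights spread out, and the output becomes a blend of several values rather than a single retrieved one.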


> Why is the $\sqrt{d_k}$ scaling applied?

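If the components of $q$ and $k$ are independent with zero mean and unit variance, their dot product has variance $d_k$. Dividing by $\sqrt{d_k}$ brings the scores back to unit variance. This helps the softmax: without the scaling, large scores push it toward a one-hot output with very small gradients, whereas unit-variance scores keep the weights more diffuse. The softmax then maps the scores to $[0,1]$ and ensures that they sum to 1 over the whole sequence.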
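The effect of the scaling can be checked numerically. The sketch below assumes query and key components drawn independently from a standard normal distribution (an idealization, not something the trained projections guarantee) and compares the variance of raw and scaled dot products, then the peakedness of the resulting softmax. The variable names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 10_000                                   # number of random (q, k) pairs per setting
variances = {}
for d_k in (4, 64, 256):
    q = torch.randn(n, d_k)                  # components ~ N(0, 1)
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=-1)               # n dot products
    variances[d_k] = (dots.var().item(), (dots / d_k ** 0.5).var().item())
    print(d_k, variances[d_k])               # raw variance is near d_k, scaled is near 1

# Effect on the softmax for one query against 8 keys with d_k = 256.
d_k = 256
scores = torch.randn(8, d_k) @ torch.randn(d_k)
peak_raw = F.softmax(scores, dim=-1).max().item()
peak_scaled = F.softmax(scores / d_k ** 0.5, dim=-1).max().item()
print(peak_raw, peak_scaled)                 # the unscaled softmax is more peaked
```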


In [10]:
# Simple character-level tokenization with small code books
import string
import torch
import torch.nn as nn
from torch.nn import functional as F

vocab = string.whitespace + string.ascii_letters + string.digits + string.punctuation
vocab_size = len(vocab)

# create a mapping of vocab to integers
vtoi = { ch:i for i,ch in enumerate(vocab)}
itov = { i:ch for i,ch in enumerate(vocab)}
# encode takes a string and returns a list of integer tokens
def encode(v):
    return [vtoi[ch] for ch in v]

# decode takes a list of integer tokens and returns a list of characters
def decode(t):
    return [itov[i] for i in t]

text = "Attention is all you need"
text_enc =  encode(text)
text_dec = decode(text_enc)
data = torch.LongTensor(text_enc)

# Let's construct the simplest transformer model, meant for visualization only
block_size = 4 # maximum context length for prediction
n_embd = 5 # embedding dimension (size of each embedding vector)
batch_size = 2 # number of sequences per batch

# We need two embedding tables: one for tokens, with vocab_size rows, and another
# for positions, with block_size rows. Each row is an embedding vector of
# size n_embd.
token_embedding_table = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(block_size, n_embd)

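def get_batch():
    # Sample random start offsets. Note that this draws block_size (not batch_size)
    # sequences, so x and y have shape (block_size, block_size).
    ix = torch.randint(0, len(text_enc)-block_size-1, (block_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]) # inputs
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]) # targets: inputs shifted by one
    return x,y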

x, y = get_batch()
print("Token:", token_embedding_table, "Position:", position_embedding_table)
B, T = x.shape
emb_tok = token_embedding_table(x)
emb_pos = position_embedding_table(torch.arange(T))
x.shape, emb_tok.shape, emb_pos.shape

Token: Embedding(100, 5) Position: Embedding(4, 5)


(torch.Size([4, 4]), torch.Size([4, 4, 5]), torch.Size([4, 5]))
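The shapes printed above can be reproduced in isolation. This sketch assumes the same sizes as the cell (a vocabulary of 100, `n_embd = 5`, `block_size = 4`) and uses random token ids in place of the encoded text; it shows how the (T, C) position embeddings broadcast over the batch dimension when added to the (B, T, C) token embeddings.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, n_embd, block_size = 100, 5, 4   # sizes taken from the cell above
B, T = 4, block_size                         # four sequences of four tokens

token_embedding_table = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(block_size, n_embd)

idx = torch.randint(0, vocab_size, (B, T))   # random ids stand in for encoded text
emb_tok = token_embedding_table(idx)                  # (B, T, C)
emb_pos = position_embedding_table(torch.arange(T))   # (T, C)
x = emb_tok + emb_pos                                 # broadcasts to (B, T, C)
print(x.shape)  # torch.Size([4, 4, 5])
```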

In [23]:
# A single head of self-attention
head_size = 6

x = emb_pos + emb_tok # (B, T, C); emb_pos broadcasts over the batch dimension
B,T,C = x.shape
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
print("keys:", key, "queries:", query)
k = key(x) # B, T, head_size
q = query(x) # B, T, head_size
print("x:", x.shape, "key_x:", k.shape, "query_x:", q.shape)
W = q @ k.transpose(-2, -1) # transpose T and head_size so that (B,T,h) @ (B,h,T) -> (B,T,T)
W = W * head_size ** -0.5 # divide by sqrt(d_k), with d_k = head_size, to keep the scores near unit variance

# Causal mask: position t may only attend to positions <= t
tril = torch.tril(torch.ones(T,T))
W = W.masked_fill(tril == 0, float('-inf')) # masked scores become 0 after the softmax
W = F.softmax(W, dim=-1) # each row sums to 1
v = value(x) # B, T, head_size
out = W @ v # (B,T,T) @ (B,T,h) -> (B,T,h)



keys: Linear(in_features=5, out_features=6, bias=False) queries: Linear(in_features=5, out_features=6, bias=False)
x: torch.Size([4, 4, 5]) key_x: torch.Size([4, 4, 6]) query_x: torch.Size([4, 4, 6])
tensor(-0.4661, grad_fn=<MeanBackward0>) tensor(1.3166, grad_fn=<StdBackward0>)
tensor(-0.5325, grad_fn=<MeanBackward0>) tensor(1.2047, grad_fn=<StdBackward0>)
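One common way to package the steps in the cell above is a small `nn.Module`. The sketch below is an illustration under assumptions, not part of the original notebook: the class name `Head`, its constructor arguments, and the random input are all hypothetical. It scales the scores by the square root of the key dimension (`head_size`), following the scaled dot-product formula, and then checks that the causal mask really prevents information from flowing backwards.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """One head of causal self-attention (illustrative sketch)."""

    def __init__(self, n_embd: int, head_size: int, block_size: int):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask stored as a buffer, so it is not a trained parameter.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)   # each (B, T, head_size)
        w = q @ k.transpose(-2, -1) * k.size(-1) ** -0.5      # (B, T, T), divided by sqrt(d_k)
        w = w.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        w = F.softmax(w, dim=-1)
        return w @ v                                          # (B, T, head_size)

torch.manual_seed(0)
head = Head(n_embd=5, head_size=6, block_size=4)
x = torch.randn(2, 4, 5)              # random stand-in for token + position embeddings
out = head(x)

# Causality check: perturbing the last position must not change earlier outputs.
x_perturbed = x.clone()
x_perturbed[:, -1, :] += 1.0
out_perturbed = head(x_perturbed)
print(out.shape)
print(torch.allclose(out[:, :-1], out_perturbed[:, :-1]))
```

Registering the mask as a buffer keeps it out of the optimizer while still moving it along with the module when the device changes.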
