# Sources
1. https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

In [1]:
sentence = 'Life is short, eat dessert first'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


In [2]:
import torch

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


In [6]:
torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)
embeded_sentence = embed(sentence_int).detach()

In [4]:
type(embeded_sentence)

torch.Tensor

In [7]:
embeded_sentence

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

In [8]:
d = embeded_sentence.shape[1]

In [9]:
d

16

In [23]:
d_q, d_k, d_v = 24, 24, 28 # number of output channels of the Q, K, and V nets

In [11]:
W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

# Query = "is" - the second word in our sentence

In [13]:
x_2 = embeded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


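# query(token) = W_q x embedding(token) = a linear layer with an arbitrary number of output channels
1. If we think of the token embedding vector as the input, multiplying by W_q creates **W_q.shape[0]** output channels, just like a bias-free linear layer (a single-layer FFN without an activation). **W_q.shape[1]** is the number of input channels and **W_q.shape[0]** is the number of output channels.
2. We create 3 such layers: Q, K, and V.
3. Q and K need to have the same number of output channels, because the attention score is a dot product between a query and a key, which only works for vectors of equal length. V can have a different size.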
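The linear-layer view in the list above can be checked directly. This is a minimal sketch, assuming stand-in tensors `W_q` and `x` (random values, not the notebook's `W_query` and `x_2`) and a `torch.nn.Linear` layer created without bias:

```python
import torch

torch.manual_seed(0)
d, d_q = 16, 24                      # input and output channels, as in the cells above

W_q = torch.rand(d_q, d)             # stand-in for W_query (hypothetical values)
x = torch.randn(d)                   # stand-in for one token embedding

# A bias-free linear layer; its weight has shape (out_features, in_features)
linear = torch.nn.Linear(d, d_q, bias=False)
with torch.no_grad():
    linear.weight.copy_(W_q)         # reuse W_q as the layer's weight

q_matmul = W_q.matmul(x)             # the notebook's way of computing a query
q_linear = linear(x).detach()        # the same projection through nn.Linear

print(q_matmul.shape)
print(torch.allclose(q_matmul, q_linear, atol=1e-5))
```

Both paths should give the same 24-dimensional vector, which is why Q, K, and V can be read as three small projection layers applied to the same embedding.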

In [17]:
keys = W_key.matmul(embeded_sentence.T).T # convert the 6 tokens to their keys; the keys come out as columns, so we transpose them to get 6x24
values = W_value.matmul(embeded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])


# Attention Weight!
## The unnormalized weight: w_i_j = query(current_token i).T x key(other_token j)

Interestingly, the dot product of the current token's query and another token's key is a similarity score. It is closely related to cosine similarity, but it is not normalized by the vector lengths, so it grows with both the alignment and the magnitudes of the two vectors. So, the current token attends to tokens whose keys align with its query! Dot-product similarity is probably the most common way to express relationships between elements in a NN. So, to model relationships between users and products, we would create two new embeddings of the same length, so that we can measure their similarity. My guess is that the more products we have, the more channels we need. These embeddings are learned jointly because they need to capture the relationship along with the entity identity.
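To see how the dot product relates to cosine similarity, here is a small sketch with two made-up vectors `q` and `k` (hypothetical values, not taken from the notebook). The key points in the same direction as the query but is twice as long:

```python
import torch
import torch.nn.functional as F

q = torch.tensor([1.0, 2.0, 3.0])    # hypothetical query
k = torch.tensor([2.0, 4.0, 6.0])    # hypothetical key: same direction, twice the length

dot = q.dot(k)                           # unnormalized score, as used for omega
cos = F.cosine_similarity(q, k, dim=0)   # direction only, lengths divided out

print(dot)
print(cos)
# The dot product is the cosine similarity times the two vector lengths
print(torch.isclose(dot, cos * q.norm() * k.norm()))
```

The cosine only measures direction, while the dot product also rewards long vectors, so a trained model can raise a score either by aligning query and key or by making them larger.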

In [19]:
omega_24 = query_2.dot(keys[4]) # second token to 5th token
omega_24

tensor(2.3206, grad_fn=<DotBackward0>)

In [20]:
omega_2 = query_2.matmul(keys.T) # we saved the keys row-wise for tokens. each row -> a token

In [22]:
omega_2 # we see that query_2 and keys[1] (the same token) currently do not have the highest similarity; the scores are arbitrary because our Q and K networks are not trained yet.

tensor([ -7.0847,  -4.5398,   3.9887,  10.2379,   2.3206, -10.5434],
       grad_fn=<SqueezeBackward3>)

## Normalizing attention weights
We can normalize the attention weights with softmax directly. However, the authors of "Attention Is All You Need" first scale the scores by 1/sqrt(d_k). This keeps the Euclidean length of the weight vectors at approximately the same magnitude, and it stops the dot products from growing with d_k and saturating the softmax.

In [36]:
import torch.nn.functional as F
# attention_weights_2_no_scaling = F.softmax(omega_2 / d_k, dim=0)
attention_weights_2 = F.softmax(omega_2 / d_k ** 0.5, dim=0)
# attention_weights_2_arbitary = F.softmax(omega_2 / 100, dim=0) # an arbitrary scale is not good!
# print(attention_weights_2_no_scaling, attention_weights_2, attention_weights_2_arbitary)
attention_weights_2

tensor([0.0185, 0.0312, 0.1778, 0.6368, 0.1265, 0.0092],
       grad_fn=<SoftmaxBackward0>)
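
The effect of the scaling factor can be sketched by comparing softmax with and without it. This assumes the scores are copied by hand from the `omega_2` output above (so no gradients are attached) and uses `d_k = 24` as defined earlier:

```python
import torch
import torch.nn.functional as F

# Scores copied from the omega_2 output above; d_k as defined earlier
omega = torch.tensor([-7.0847, -4.5398, 3.9887, 10.2379, 2.3206, -10.5434])
d_k = 24

unscaled = F.softmax(omega, dim=0)               # softmax on the raw scores
scaled = F.softmax(omega / d_k ** 0.5, dim=0)    # scaled as in the cell above

print(unscaled)
print(scaled)
print(unscaled.max() > scaled.max())             # is the unscaled version more peaked?
```

Without scaling, nearly all of the weight lands on the single highest-scoring token. With scaling, the same ranking is kept but the other tokens still receive a visible share.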

# Now the context vector
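The context vector is the most confusing part. Each token's value vector is multiplied by its corresponding attention weight, and the weighted vectors are then summed column-wise. The result is an attention-weighted average of the value vectors (the weights sum to 1), so tokens that are more related to the query contribute more of their information to the context.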
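
A minimal sketch of that weighted sum, assuming the weights are copied by hand from the `attention_weights_2` output above and using a random stand-in for the notebook's `values` tensor (so the numbers in the result are illustrative only):

```python
import torch

torch.manual_seed(0)

# Weights copied from the attention_weights_2 output above
attention_weights = torch.tensor([0.0185, 0.0312, 0.1778, 0.6368, 0.1265, 0.0092])
values = torch.rand(6, 28)           # stand-in for `values`: 6 tokens x d_v = 28

# Step by step: scale each token's value row by its weight, then sum over tokens
weighted_rows = attention_weights[:, None] * values    # shape (6, 28)
context_manual = weighted_rows.sum(dim=0)               # shape (28,)

# The same computation as a single matrix product
context = attention_weights.matmul(values)              # (6,) @ (6, 28) -> (28,)

print(context.shape)
print(torch.allclose(context, context_manual, atol=1e-6))
```

The context vector has `d_v` entries, one per value channel, and each entry is dominated by the tokens with the largest attention weights.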