# Sources
1. https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

In [1]:
sentence = 'Life is short, eat dessert first'

# Vocabulary: drop the comma, split on whitespace, sort the words and number
# them in that order (word -> integer index).
dc = dict(
    (word, index)
    for index, word in enumerate(sorted(sentence.replace(',', '').split()))
)
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


In [2]:
import torch

# Encode the sentence (comma removed) as a 1-D tensor of vocabulary indices,
# one index per word in sentence order.
sentence_int = torch.tensor(
    list(map(dc.__getitem__, sentence.replace(',', '').split()))
)
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


In [3]:
# Fix the RNG seed so the randomly initialised embedding table is reproducible.
torch.manual_seed(123)
# One 16-dimensional embedding row per vocabulary entry. The number of rows is
# taken from the vocabulary itself (len(dc) == 6 for this sentence) rather than
# being hard-coded, so the table stays in sync if the sentence changes.
embed = torch.nn.Embedding(len(dc), 16)
# Look up one embedding row per token; detach() returns a tensor that is not
# tracked by autograd.
embeded_sentence = embed(sentence_int).detach()

In [4]:
# Sanity check: the detached embedding lookup is a plain torch.Tensor.
type(embeded_sentence)

torch.Tensor

In [5]:
# Display the embedded sentence: one row per token (6), one column per
# embedding dimension (16).
embeded_sentence

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

In [19]:
# d is the embedding width (size of dimension 1); it is the input size of the
# query/key/value projections defined below.
d = embeded_sentence.size(1)
embeded_sentence.shape

torch.Size([6, 16])

In [7]:
# Display the embedding width taken from embeded_sentence.shape[1].
d

16

In [8]:
# Output widths ("output channels") of the query, key and value projections.
# d_q and d_k are given the same value; d_v is set independently.
d_q = d_k = 24
d_v = 28

In [9]:
# Randomly initialised projection matrices, each of shape (output width, d).
# They are created in query, key, value order so the random draws happen in
# that sequence.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_out, d)) for d_out in (d_q, d_k, d_v)
)

# Query = "is" - the second word in our sentence

In [10]:
# Embedding of the second token (row index 1) and its query, key and value
# projections.
x_2 = embeded_sentence[1]
query_2 = W_query @ x_2
key_2 = W_key @ x_2
value_2 = W_value @ x_2

# One shape per line, in query, key, value order.
print(query_2.shape, key_2.shape, value_2.shape, sep='\n')

torch.Size([24])
torch.Size([24])
torch.Size([28])


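# query(token) = W_q x embedding(token) = a linear layer with an arbitrary number of output channels.
1. If we think of the token embedding vector as the input, multiplying by W_q creates **W_q.shape[0]** output channels, just like a single linear (fully connected) layer without bias or activation. **W_q.shape[1]** is the number of input channels and **W_q.shape[0]** is the number of output channels.
2. We create 3 such layers: Q, K and V.
3. Q and K need to have the same number of output channels, because each query is compared with each key through a dot product, which is only defined for vectors of equal length. V can have a different width.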

In [11]:
# Project all 6 tokens at once. W @ embeded_sentence.T yields one token per
# column, so the result is transposed back to one token per row:
# (6, 24) for the keys and (6, 28) for the values.
keys = (W_key @ embeded_sentence.T).T
values = (W_value @ embeded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])


# Attention Weight!
## The unnormalized weight: $\omega_{ij} = {q^{(i)}}^{\top} k^{(j)}$, where $q^{(i)}$ is the query of the current token $i$ and $k^{(j)}$ is the key of another token $j$

Interestingly, the dot product of the current token's query and another token's key is a similarity score: it equals the cosine similarity of the two vectors multiplied by their lengths, $q \cdot k = \lVert q \rVert \, \lVert k \rVert \cos\theta$. So the current token attends most to tokens whose keys are similar to its query! Interestingly, **there is probably only one basic way to express a relationship between elements in a NN: a dot-product (cosine-like) similarity**. So, to model relationships between users and products, we need to create two new embeddings of the same length, so that we can measure their similarity. The more products we have, the more channels we need. These embeddings are learned jointly because they need to capture the relationship along with each entity's identity.

In [19]:
# Unnormalized attention score between the query of the second token (query_2,
# row index 1) and the key stored at row index 4, i.e. the fifth token.
omega_24 = query_2 @ keys[4]
omega_24

tensor(2.3206, grad_fn=<DotBackward0>)

In [20]:
# Scores of query_2 against every token's key in one product. keys holds one
# token per row, so it is transposed to line the key vectors up as columns.
omega_2 = query_2 @ keys.T

In [22]:
# query_2 and keys[1] (the key of the same token) do not currently have the
# highest similarity, because the Q and K projections are not trained yet.
omega_2

tensor([ -7.0847,  -4.5398,   3.9887,  10.2379,   2.3206, -10.5434],
       grad_fn=<SqueezeBackward3>)

## Normalizing attention weights
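We can normalize the attention weights by applying softmax directly. However, the authors of "Attention Is All You Need" first scale the scores by $1/\sqrt{d_k}$. **This ensures that the Euclidean length of the weight vectors stays at approximately the same magnitude.** How? If the components of a query and a key are independent with zero mean and unit variance, their dot product is a sum of $d_k$ such terms and therefore has variance $d_k$; dividing by $\sqrt{d_k}$ brings the variance back to 1, so the scores do not grow with $d_k$ and the softmax does not saturate.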

In [36]:
import torch.nn.functional as F
# Attention vector for query 2: the raw scores are divided by sqrt(d_k) and
# passed through a softmax over the token dimension.
# Two variants were tried and left disabled: dividing by d_k itself, and
# dividing by an arbitrary constant (100), which was noted as not good.
attention_weights_2 = F.softmax(omega_2.div(d_k ** 0.5), dim=0)
attention_weights_2

tensor([0.0185, 0.0312, 0.1778, 0.6368, 0.1265, 0.0092],
       grad_fn=<SoftmaxBackward0>)

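## The attention-weight vector has one entry per token in the input sequence (6 here), not one per vocabulary entry. The two sizes only coincide because every word in this sentence is unique.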

# Now the context vector
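The context vector is the most confusing part. Each token's value vector is multiplied by its corresponding attention weight, and the weighted vectors are then summed column by column, so the context vector is the attention-weighted sum of all value vectors: $z^{(2)} = \sum_j \alpha_{2j} \, v^{(j)}$. It is not obvious what each component accumulates. **But more related tokens will have more information added to the context.** The context vector feels like a continuous variable, but it can blend concepts.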


In [38]:
# Attention-weighted sum of the value rows: (6,) @ (6, 28) -> (28,).
context_vector_2 = attention_weights_2 @ values
context_vector_2

tensor([-0.9495, -1.4345, -2.0504, -0.3737, -1.5098, -0.5921, -0.4289, -1.9790,
        -1.7937, -0.7146, -0.9926, -2.0061, -2.1961, -1.7174, -1.0732, -0.7900,
        -1.7367, -2.2095, -0.9344, -1.5299, -0.2828, -0.5350, -1.7285, -1.5485,
        -0.2043, -0.7109, -1.5165, -1.5167], grad_fn=<SqueezeBackward3>)

In [12]:
# Display the value vectors: one row per token (6 rows, 28 columns).
values

tensor([[-0.8441,  0.2829,  1.7343,  1.6112,  2.1563,  1.1398,  1.6928,  1.5736,
          1.7709,  0.9618,  1.3077,  0.2716,  0.3070,  0.3427,  2.4012,  1.9869,
         -1.1107,  0.6782, -0.2181,  0.8178,  0.5018,  0.9887,  1.3350,  0.1589,
         -0.3449,  0.9065,  1.6519, -0.3440],
        [ 0.5165,  0.2638,  0.1946,  0.1296, -0.2176, -1.2548, -0.9272, -1.3402,
         -0.4107, -0.0859,  1.0926,  0.4078, -0.6770,  0.1110, -1.1055,  0.3156,
         -0.3169,  0.7937, -1.1166,  3.0497, -0.2863,  1.5513,  2.7004,  0.5483,
         -2.4544, -1.5389, -0.4168,  0.2455],
        [-1.4350, -3.0582, -1.3735, -1.0167, -0.9396, -2.5408, -2.1351, -1.8701,
         -1.9994, -3.7609, -3.8755, -3.1365, -2.1639, -3.0949, -3.7118, -1.8682,
         -1.8869, -1.7023, -1.4043, -4.1602, -3.5326, -1.8202, -3.1335, -1.7162,
         -2.0997, -1.8854, -2.1002, -3.4872],
        [-1.0645, -0.9245, -3.0223, -0.6932, -2.1795,  0.1399,  0.0170, -2.8164,
         -2.1397,  0.7122, -0.7365, -1.5728, -2.6054

# Multihead attention
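So far, we built a single head that produces a single context vector per token. Multi-head attention repeats that single-head construction several times, each head with its own independently initialised W_Q, W_K and W_V, so it produces one context vector per head. The same input is fed to (i.e. copied for) every head.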

In [14]:
heads = 3
# One randomly initialised (output width, d) projection per head, stacked along
# a leading "head" dimension. Created in Q, K, V order; d_k == d_q.
multihead_W_Q, multihead_W_K, multihead_W_V = (
    torch.nn.Parameter(torch.rand(heads, d_out, d)) for d_out in (d_q, d_k, d_v)
)

In [18]:
# multihead_W_Q is (heads, d_q, d). matmul treats the first dimension as a
# batch dimension and broadcasts the 1-D x_2 across it, giving one d_q-sized
# query per head.
multihead_query_2 = multihead_W_Q @ x_2

In [17]:
# Display the per-head queries for the second token: one row per head (3 rows
# of 24 values).
multihead_query_2

tensor([[-0.5317, -1.6069,  1.1446, -1.7651, -0.0069, -1.2229,  0.6265, -1.1083,
         -0.5696, -0.2894,  0.9853, -1.1417, -0.7089, -1.5938,  0.9208,  0.1191,
          0.5370,  0.1285, -1.1273,  0.3900,  1.6581,  0.2398, -0.4409,  0.8581],
        [-1.7130,  0.7300,  0.3016,  1.2672, -0.4968,  1.7345, -0.4737, -0.1441,
         -1.5281,  0.1814,  2.1658,  0.0782,  1.1053, -0.2787, -0.8150,  1.2533,
         -1.6784, -0.6629, -0.7785, -0.4442, -0.7335, -1.0979, -1.6565, -0.4303],
        [-0.2142,  0.3233, -0.9046, -1.1118,  1.2510,  1.1929, -0.8451,  1.8892,
         -0.6453,  0.0580, -0.4397, -0.2499,  1.5047,  0.9832, -1.5491, -0.4635,
          0.5074,  1.5554, -2.2974, -0.0805, -0.3201,  1.0347,  1.1994,  0.1928]],
       grad_fn=<UnsafeViewBackward>)

In [22]:
# Small random 3-D tensor (2 x 3 x 5) used to see how transposing behaves on
# tensors with more than two dimensions.
batch = torch.rand((2, 3, 5))
print(batch)

tensor([[[0.8873, 0.2034, 0.9871, 0.1758, 0.6914],
         [0.8859, 0.6605, 0.8328, 0.6707, 0.6894],
         [0.9387, 0.4778, 0.4763, 0.7615, 0.2538]],

        [[0.9377, 0.7955, 0.9131, 0.1981, 0.6997],
         [0.8676, 0.3539, 0.2717, 0.6077, 0.2121],
         [0.2421, 0.4025, 0.9509, 0.3354, 0.6794]]])


In [23]:

# Reverse the dimension order, (2, 3, 5) -> (5, 3, 2). permute is used instead
# of batch.T because Tensor.T is deprecated for tensors that are not 2-D; the
# printed result is the same.
print(batch.permute(2, 1, 0))

tensor([[[0.8873, 0.9377],
         [0.8859, 0.8676],
         [0.9387, 0.2421]],

        [[0.2034, 0.7955],
         [0.6605, 0.3539],
         [0.4778, 0.4025]],

        [[0.9871, 0.9131],
         [0.8328, 0.2717],
         [0.4763, 0.9509]],

        [[0.1758, 0.1981],
         [0.6707, 0.6077],
         [0.7615, 0.3354]],

        [[0.6914, 0.6997],
         [0.6894, 0.2121],
         [0.2538, 0.6794]]])


In [24]:
# Shape after reversing the dimension order of the (2, 3, 5) tensor. permute is
# used instead of batch.T because Tensor.T is deprecated for tensors that are
# not 2-D.
print(batch.permute(2, 1, 0).shape)

torch.Size([5, 3, 2])


### Quick explanation of `torch.Tensor.repeat` / `np.tile`: https://www.sharpsightlabs.com/blog/numpy-tile/

In [28]:
# Three repeat factors on the 2-D (16, 6) transpose: a new leading dimension is
# added and the matrix is copied 3 times along it -> (3, 16, 6).
embeded_sentence.T.repeat(3, 1, 1).shape

torch.Size([3, 16, 6])

In [29]:
# Two repeat factors on the same (16, 6) matrix: no new dimension is added; the
# rows are tiled 3 times along the first dimension -> (48, 6).
embeded_sentence.T.repeat(3, 1).shape

torch.Size([48, 6])