In [43]:
from nltk import word_tokenize
import gensim.downloader as api
import torch

In [44]:
# Pretrained GloVe word vectors from the gensim-data catalogue ("glove-twitter-25":
# 25-dimensional vectors, per the model name). gensim downloads the model on first use
# (network required) and reuses its local copy on later runs.
model = api.load(name="glove-twitter-25")

In [45]:
sentence = "Your journey starts with one step"
# Lowercase first, presumably to match the model's vocabulary casing; a token the model
# doesn't know cannot be embedded in the next cell.
# NOTE: word_tokenize needs NLTK's Punkt tokenizer data. If it raises LookupError, run
# nltk.download("punkt") once (newer NLTK releases ask for "punkt_tab" instead).
normalized_sentence = sentence.lower()
tokens = word_tokenize(normalized_sentence)
tokens

['your', 'journey', 'starts', 'with', 'one', 'step']

In [46]:
# Look up one pretrained vector per token -> (num_tokens, embedding_dim) float tensor.
# Check the vocabulary first so that every out-of-vocabulary token is reported at once,
# instead of failing on the first missing key inside gensim.
oov_tokens = [tok for tok in tokens if tok not in model]
if oov_tokens:
    raise KeyError(f"tokens missing from the GloVe vocabulary: {oov_tokens}")

embeddings = torch.tensor(model[tokens])
assert embeddings.shape[0] == len(tokens)
embeddings.shape

torch.Size([6, 25])

In [51]:
# Seed torch's RNG so that the randomly initialised projection weights, and every
# attention score / context vector computed from them below, are reproducible
# under Restart & Run All.
SEED = 123
torch.manual_seed(SEED)

d_in = embeddings.shape[-1]  # size of each input embedding (last axis of `embeddings`)
d_out = 2                    # size of the query/key/value projections

# Trainable projection matrices W_q, W_k, W_v (no bias). nn.Linear stores its weight
# as (out_features, in_features) = (d_out, d_in) and computes x @ weight.T.
w_queries = torch.nn.Linear(d_in, d_out, bias=False)
w_keys = torch.nn.Linear(d_in, d_out, bias=False)
w_values = torch.nn.Linear(d_in, d_out, bias=False)
w_values.weight.shape

torch.Size([2, 25])

In [56]:
# Project every token embedding into query, key and value space. Each nn.Linear computes
# embeddings @ W.T: (num_tokens, d_in) @ (d_in, d_out) -> (num_tokens, d_out), i.e. one
# row per token; for this sentence (6, 25) @ (25, 2) -> (6, 2).
# The projection weights are still random here and only become meaningful once trained.
# The same weights are applied to every input; they are not specific to this sentence.
queries = w_queries(embeddings)  # what each token "looks for" when it attends to others
keys = w_keys(embeddings)        # what each token "offers" to be matched against queries
values = w_values(embeddings)    # the content each token contributes to context vectors
assert queries.shape == keys.shape == values.shape == (embeddings.shape[0], d_out)
values.shape

torch.Size([6, 2])

In [57]:
# SCALED DOT-PRODUCT ATTENTION

# Raw scores: entry [i, j] is the dot product of query i with key j, i.e. how well
# token i's query matches token j's key. These are unnormalised (any real number).
# transpose(-2, -1) (rather than .T) also works if a batch dimension is added later.
attention_scores = queries @ keys.transpose(-2, -1)

# Scale by sqrt(d_k), where d_k = d_out is the key dimension. For vectors with roughly
# unit-variance components, a d_k-dimensional dot product has variance ~d_k. Dividing by
# sqrt(d_k) brings it back to ~1, so the softmax does not saturate into near one-hot
# rows (with vanishing gradients) as d_k grows.
# The softmax runs over the last (key) axis: row i becomes a probability distribution
# describing how much token i attends to each token j.
attention_weights = torch.softmax(attention_scores / d_out**0.5, dim=-1)
assert torch.allclose(attention_weights.sum(dim=-1), torch.ones(attention_weights.shape[:-1]))
attention_weights

tensor([[0.1770, 0.1408, 0.1560, 0.1981, 0.1814, 0.1466],
        [0.1430, 0.1570, 0.2047, 0.1667, 0.1642, 0.1645],
        [0.1701, 0.1706, 0.1612, 0.1639, 0.1656, 0.1686],
        [0.1669, 0.1487, 0.1690, 0.1857, 0.1754, 0.1542],
        [0.1866, 0.1420, 0.1442, 0.1991, 0.1825, 0.1457],
        [0.1508, 0.1555, 0.1921, 0.1720, 0.1676, 0.1620]],
       grad_fn=<SoftmaxBackward0>)


In [58]:
# Context vectors: row i is the attention-weighted sum of all value vectors,
#   context_vectors[i] = sum_j attention_weights[i, j] * values[j],
# so each token's new representation mixes in information from every token in
# proportion to how strongly it attends to that token.
context_vectors = torch.matmul(attention_weights, values)
context_vectors.shape

torch.Size([6, 2])

In [None]:
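# CAUSAL ATTENTION (masked attention)
# A variant of self-attention in which each token may only attend to itself and the
# tokens before it, never to future tokens.
# The code is the same as for self-attention, except that an attention mask hides the
# upper triangle of the attention-weight matrix (x = kept weight, 0 = masked future position):
# x 0 0 0
# x x 0 0
# x x x 0
# x x x x
# In practice the masked scores are set to -inf before the softmax, so each row still sums to 1.
# This replaces the full matrix used in plain self-attention:
# x x x x
# x x x x
# x x x x
# x x x x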