For simplicity, we are going to:
- Use a single sentence, not a full corpus.
- Use a small embedding dimension (3). This lets us examine individual vectors without filling the entire page.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Seed PyTorch's global RNG so the random embedding and weight
# initializations below are reproducible across runs.
torch.manual_seed(123)

<torch._C.Generator at 0x10f2cf190>

### Embedding an input sentence

In [3]:
# The single toy input used throughout this notebook.
sentence = "Life is short, eat dessert first"

In [4]:
# The vocabulary is restricted to the distinct words of the sentence (commas
# stripped). Deduplicating keeps each word once so indices are contiguous;
# sorting makes the word -> index assignment deterministic.
words = sorted(set(sentence.replace(",", "").split()))

In [5]:
# Pair each vocabulary word with its position in the sorted vocabulary.
word_to_idx = dict(zip(words, range(len(words))))

In [6]:
# Encode the sentence as a 1-D tensor of vocabulary indices, one per token
# (commas are stripped before splitting on whitespace).
sentence_int = torch.tensor(
    [word_to_idx[token] for token in sentence.replace(",", "").split()]
)

In [7]:
# Embedding table sized like a realistic vocabulary, with tiny 3-dim vectors
# so individual embeddings are easy to inspect.
vocab_size = 50_000
embed = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=3)

In [8]:
# Look up the embedding vector of each token index; the result has one
# row per token, shape (6, 3) here.
embedded_sentence = embed(sentence_int)

In [9]:
embedded_sentence  # display the embedded tokens, one row per token

tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]], grad_fn=<EmbeddingBackward0>)

In [10]:
embedded_sentence.shape  # (number of tokens, embedding dim)

torch.Size([6, 3])

### Defining weight matrices

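- Usually a linear layer (`nn.Linear`) is used to create the query, key, and value projections; here we instead use plain weight matrices wrapped in `nn.Parameter`.
- The output dimension of the value projection (`d_v`) need not match that of the query and key projections; it can be arbitrary. The query and key dimensions (`d_q`, `d_k`), however, must be equal, because their dot product produces the attention scores.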

In [10]:
# Projection sizes: queries and keys share a size (their dot product forms
# the scores); the value size is chosen independently.
d_q = d_k = 2
d_v = 4
# Input embedding size, read off the embedded sentence's second dimension.
d = embedded_sentence.size(1)

In [11]:
# Random projection matrices mapping each d-dim embedding to query (d_q),
# key (d_k) and value (d_v) space. Statement order matters: each
# torch.randn call draws from the seeded global RNG, so reordering these
# lines would change the sampled weights.
w_query = torch.nn.Parameter(torch.randn(d, d_q))
w_key = torch.nn.Parameter(torch.randn(d, d_k))
w_value = torch.nn.Parameter(torch.randn(d, d_v))

In [12]:
# Project every token embedding into key space and value space
# (one row per token in each result).
keys = torch.matmul(embedded_sentence, w_key)
values = torch.matmul(embedded_sentence, w_value)

In [13]:
# Let's compute the attention weights for a single input element

In [14]:
# Pick the token at index 1 ("is") and project it to a query; unsqueeze(0)
# turns the vector into a (1, d) row, so the query is a (1, d_q) matrix.
x_1 = embedded_sentence[1]
query_1 = x_1.unsqueeze(0) @ w_query

In [15]:
# Unnormalized attention scores: dot product of query_1 with every key,
# giving one score per token (shape (1, 6) here).
omega_1 = query_1 @ keys.T

In [16]:
omega_1  # raw (unnormalized) attention scores

tensor([[-0.6150,  2.4277, -0.9584,  0.2260,  1.2082,  0.5242]],
       grad_fn=<MmBackward0>)

In [22]:
# Scale the scores by sqrt(d_k) and normalize them over the keys (last
# dim) with softmax, so the weights for this query sum to 1.
attention_weights = F.softmax(omega_1 / d_k ** 0.5, dim=-1)

In [23]:
attention_weights.shape  # one weight per token for the single query

torch.Size([1, 6])

In [24]:
values.shape  # (number of tokens, d_v)

torch.Size([6, 4])

In [25]:
# Context vector for the selected token: attention-weighted sum of all
# value vectors, shape (1, d_v).
context_vector_1 = attention_weights @ values

In [26]:
context_vector_1.shape  # (1, d_v)

torch.Size([1, 4])

### Self Attention
- Now let's summarize the previous section in a single `SelfAttention` class

In [28]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out_kq: output size of the query and key projections; also the
            ``d_k`` used for the ``sqrt(d_k)`` score scaling.
        d_out_v: output size of the value projection, i.e. the size of each
            returned context vector.
    """

    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        # Parameters are drawn in query, key, value order from the global RNG;
        # keep this order so seeded runs reproduce the same weights.
        self.w_query = nn.Parameter(torch.randn(d_in, d_out_kq))
        self.w_key = nn.Parameter(torch.randn(d_in, d_out_kq))
        self.w_value = nn.Parameter(torch.randn(d_in, d_out_v))

    def forward(self, x):
        """Return one context vector per input position.

        Args:
            x: input embeddings whose last dimension is ``d_in``, either
                ``(seq_len, d_in)`` or batched ``(..., seq_len, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``d_out_v``.
        """
        queries = x @ self.w_query
        keys = x @ self.w_key
        values = x @ self.w_value

        # transpose(-2, -1) swaps only the (seq_len, d_out_kq) dims, so this
        # works for unbatched and batched input alike; `.T` reverses *all*
        # dims and fails (and is deprecated) for tensors that are not 2-D.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_k) and normalize each query's scores over the keys.
        attn_weights = torch.softmax(attn_scores / self.d_out_kq ** 0.5, dim=-1)

        return attn_weights @ values

In [29]:
# Input embeddings are 3-dim; queries/keys are 2-dim, context vectors 4-dim.
self_attn = SelfAttention(d_in=3, d_out_kq=2, d_out_v=4)

In [32]:
# Apply self-attention to every token at once: one context vector per token.
sa_out = self_attn(embedded_sentence)

In [33]:
print(sa_out.shape)  # (number of tokens, d_out_v)

torch.Size([6, 4])
