# Embedding an Input Sentence

For simplicity, our dictionary `dict` is restricted here to the words that occur in the input sentence. In a real-world application, we would consider all words in the training dataset (typical vocabulary sizes range from 30k to 50k).



In [3]:
sentence = "Life is short, eat dessert first"

# Create the dictionary: map each unique word (comma removed, sorted) to an integer index
dict = {s: i for i, s in enumerate(sorted(sentence.replace(",", "").split()))}

dict

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
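To extend the toy dictionary above to a whole training corpus, one option is to collect the distinct words of every sentence before sorting them. The sketch below assumes the same preprocessing as the cell above (commas stripped, whitespace splitting); `build_vocab` and the two-sentence `corpus` are hypothetical names used only for illustration. Note that `sorted` orders uppercase letters before lowercase ones, which is why `'Life'` receives index 0.

```python
def build_vocab(sentences):
    """Hypothetical helper: map every distinct word in a corpus to an integer index."""
    words = set()
    for s in sentences:
        # Same preprocessing as above: drop commas, split on whitespace
        words.update(s.replace(",", "").split())
    # Sorting makes the index assignment deterministic
    return {w: i for i, w in enumerate(sorted(words))}


corpus = ["Life is short, eat dessert first", "Life is good"]  # toy corpus
vocab = build_vocab(corpus)
print(vocab)
# {'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'good': 4, 'is': 5, 'short': 6}
```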

### Assign the Index to Each Word

In [4]:
import torch

sentence_idx = torch.tensor([dict[s] for s in sentence.replace(',', '').split()])
sentence_idx

tensor([0, 4, 5, 2, 1, 3])
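The index sequence can be mapped back to words by inverting the dictionary, a quick sanity check that the encoding is lossless for this sentence. The sketch below is plain Python; `word_to_idx` and `idx_to_word` are illustrative names (the notebook itself uses `dict`), and `set` is used so that repeated words would not create gaps in the indices.

```python
sentence = "Life is short, eat dessert first"
tokens = sentence.replace(",", "").split()

# Forward mapping, built the same way as the notebook's `dict`
word_to_idx = {s: i for i, s in enumerate(sorted(set(tokens)))}
# Inverse mapping: index -> word
idx_to_word = {i: s for s, i in word_to_idx.items()}

ids = [word_to_idx[w] for w in tokens]
decoded = [idx_to_word[i] for i in ids]
print(ids)      # [0, 4, 5, 2, 1, 3]
print(decoded)  # ['Life', 'is', 'short', 'eat', 'dessert', 'first']
```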

### Word Embedding
Here, we use a 16-dimensional embedding, so each of the 6 input words is represented by a 16-dimensional vector.

In [6]:
torch.manual_seed(123)
# Embedding table: 6 vocabulary entries, each mapped to a 16-dimensional vector
embedder = torch.nn.Embedding(6, 16)
embedded_sentence = embedder(sentence_idx).detach()

print(embedded_sentence.shape)

torch.Size([6, 16])
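`torch.nn.Embedding` behaves like a lookup table: selecting row `i` of its weight matrix gives the same result as multiplying a one-hot vector for index `i` with that matrix. The sketch below checks this with `torch.nn.functional.one_hot`; it reuses the seed, sizes, and indices from the cells above but is otherwise self-contained.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
embedder = torch.nn.Embedding(6, 16)               # 6 vocabulary entries, 16 dims each
sentence_idx = torch.tensor([0, 4, 5, 2, 1, 3])    # indices from the cell above

# Direct lookup: row i of embedder.weight for each index i
lookup = embedder(sentence_idx)                    # (6, 16)

# Equivalent matrix product: each one-hot row selects one row of the weight matrix
one_hot = F.one_hot(sentence_idx, num_classes=6).float()  # (6, 6)
via_matmul = one_hot @ embedder.weight                    # (6, 16)

print(torch.allclose(lookup, via_matmul))  # True
```

The lookup is what makes embeddings cheap: it avoids materializing the one-hot matrix, whose width equals the vocabulary size.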


### Define Weight Matrices

In [12]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 24, 24, 28

W_query = torch.nn.Parameter(torch.randn(d_q, d))
W_key = torch.nn.Parameter(torch.randn(d_k, d))
W_value = torch.nn.Parameter(torch.randn(d_v, d))

W_query.shape, W_key.shape, W_value.shape

(torch.Size([24, 16]), torch.Size([24, 16]), torch.Size([28, 16]))
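Storing each weight matrix with shape `(d_out, d)` and multiplying as `W @ x` matches the layout of `torch.nn.Linear`, whose `weight` has shape `(out_features, in_features)` and which computes `x @ weight.T`. The sketch below, using random tensors rather than the notebook's variables, shows that a bias-free `nn.Linear` holding the same matrix produces the same projection; it is an alternative way to write this step, not what the notebook uses.

```python
import torch

torch.manual_seed(123)
d, d_q = 16, 24
W_query = torch.randn(d_q, d)   # same layout as in the cell above: (d_q, d)
x = torch.randn(d)              # a single 16-dimensional input embedding

# Bias-free linear layer; its weight is stored as (out_features, in_features)
proj = torch.nn.Linear(d, d_q, bias=False)
with torch.no_grad():
    proj.weight.copy_(W_query)

print(proj.weight.shape)                                # torch.Size([24, 16])
print(torch.allclose(proj(x), W_query @ x, atol=1e-5))  # True
```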

### Computing the Unnormalized Attention Weights
We pick the second word, $x^{(2)}$, as an example:

In [8]:
x_2 = embedded_sentence[1]
query_2 = W_query @ x_2
key_2 = W_key @ x_2
value_2 = W_value @ x_2

query_2.shape, key_2.shape, value_2.shape

(torch.Size([24]), torch.Size([24]), torch.Size([28]))

We can then generalize this to compute the remaining key and value vectors for all inputs, since we will need them in the next step when we compute the unnormalized attention weights $\omega$:

In [9]:
keys = (W_key @ embedded_sentence.T).T
values = (W_value @ embedded_sentence.T).T

keys.shape, values.shape

(torch.Size([6, 24]), torch.Size([6, 28]))
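Since $(W X^\top)^\top = X W^\top$, the double transpose in the cell above can be dropped. The sketch below uses random stand-ins for `embedded_sentence`, `W_key`, and `W_query` (same shapes) to check that both forms agree and that row `i` of the result is the key of the `i`-th word; the same pattern gives all queries at once.

```python
import torch

torch.manual_seed(123)
X = torch.randn(6, 16)         # stand-in for embedded_sentence: (T, d)
W_key = torch.randn(24, 16)    # stand-in for W_key: (d_k, d)
W_query = torch.randn(24, 16)  # stand-in for W_query: (d_q, d)

keys_transposed = (W_key @ X.T).T  # form used in the notebook
keys = X @ W_key.T                 # equivalent, without the double transpose
queries = X @ W_query.T            # all queries at once: (6, 24)

print(keys.shape, queries.shape)                         # torch.Size([6, 24]) torch.Size([6, 24])
print(torch.allclose(keys, keys_transposed, atol=1e-5))  # True
print(torch.allclose(keys[1], W_key @ X[1], atol=1e-5))  # True
```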


As illustrated in the figure above, we compute $\omega_{ij}$ as the dot product between the query and key vectors, $\omega_{ij} = {q^{(i)}}^\top k^{(j)}$.
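Stacking the key vectors as the rows of a matrix $K$ (which is what `keys` holds, with $T = 6$ rows), all unnormalized weights for query $i$ can be written as a single vector-matrix product; this is the form computed by `query_2 @ keys.T` in the notebook:

$$
\boldsymbol{\omega}^{(i)} = \big(\omega_{i1}, \ldots, \omega_{iT}\big) = {q^{(i)}}^\top K^\top,
\qquad
K = \begin{bmatrix} {k^{(1)}}^\top \\ \vdots \\ {k^{(T)}}^\top \end{bmatrix} \in \mathbb{R}^{T \times d_k}
$$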



In [13]:
# Unnormalized attention weight between the query of the 2nd word and the key of the 5th input word (index 4)
omega_24 = query_2.dot(keys[4])
omega_24

tensor(-98.1709, grad_fn=<DotBackward0>)

In [17]:
# Unnormalized attention weights of query_2 with respect to all 6 keys
omega_2 = query_2 @ keys.T
omega_2, omega_2.shape

(tensor([  83.1533,   95.5014, -100.8583,   63.5880,  -98.1709,    9.3997],
        grad_fn=<SqueezeBackward3>),
 torch.Size([6]))
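The raw scores span roughly -100 to 95 here. The division by $\sqrt{d_k}$ in the next step addresses this: if query and key entries are independent with zero mean and unit variance, their dot product has variance $d_k$, and scaling by $\sqrt{d_k}$ brings it back to about 1, which keeps the softmax from saturating. The sketch below checks this empirically under the assumption of i.i.d. standard-normal vectors (the notebook's actual queries and keys are not standard normal).

```python
import torch

torch.manual_seed(123)
d_k, n = 24, 10_000

# n independent query/key pairs with i.i.d. standard-normal entries (an assumption)
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=1)   # n raw dot products
scaled = dots / d_k ** 0.5  # scaled as in the softmax step

print(dots.var().item())    # close to d_k = 24
print(scaled.var().item())  # close to 1
```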

In [18]:
# Calculate the normalized attention weights: scale by sqrt(d_k), then apply softmax
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k ** 0.5, dim=0)
attention_weights_2

tensor([7.4329e-02, 9.2430e-01, 3.6185e-18, 1.3699e-03, 6.2628e-18, 2.1523e-08],
       grad_fn=<SoftmaxBackward0>)
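`F.softmax` turns the scaled scores into positive weights that sum to 1. A minimal plain-Python sketch of the same computation, applied to the `omega_2` values printed above, follows; subtracting the maximum before exponentiating is a standard guard against overflow and does not change the result. The helper name `softmax` is hypothetical, and because the inputs are copied with 4 decimals the weights match the tensor above only approximately.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)                          # shift by the max to avoid overflow
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Unnormalized weights omega_2 copied from the output above
omega_2 = [83.1533, 95.5014, -100.8583, 63.5880, -98.1709, 9.3997]
d_k = 24

weights = softmax([w / math.sqrt(d_k) for w in omega_2])
print([round(w, 4) for w in weights])
print(sum(weights))
```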

In [19]:
# Context vector: attention-weighted sum of all value vectors
context_vector_2 = attention_weights_2 @ values

context_vector_2.shape

torch.Size([28])
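The steps for $x^{(2)}$ generalize to every word at once: stacking all queries, keys, and values as the rows of $Q$, $K$, and $V$ gives all context vectors as $\mathrm{softmax}(QK^\top/\sqrt{d_k})\,V$, with the softmax applied row-wise. The sketch below uses random stand-ins for the notebook's tensors (same sizes) and checks that row 2 (index 1) of the batched result equals the single-query computation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
T, d, d_q, d_k, d_v = 6, 16, 24, 24, 28
X = torch.randn(T, d)          # stand-in for embedded_sentence
W_query = torch.randn(d_q, d)
W_key = torch.randn(d_k, d)
W_value = torch.randn(d_v, d)

Q = X @ W_query.T              # (T, d_q)
K = X @ W_key.T                # (T, d_k)
V = X @ W_value.T              # (T, d_v)

# Row i holds the normalized attention weights of query i over all T keys
A = F.softmax(Q @ K.T / d_k ** 0.5, dim=1)  # (T, T)
context = A @ V                             # (T, d_v): one context vector per word

# Single-query version for the second word, mirroring the notebook cells
x_2 = X[1]
omega_2 = (W_query @ x_2) @ K.T
context_vector_2 = F.softmax(omega_2 / d_k ** 0.5, dim=0) @ V

print(context.shape)                                            # torch.Size([6, 28])
print(torch.allclose(context[1], context_vector_2, atol=1e-3))  # True
```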

In [21]:
head = 3

# (3, 24, 16)
multihead_W_query = torch.nn.Parameter(torch.randn(head, d_q, d))
multihead_W_key = torch.nn.Parameter(torch.randn(head, d_k, d))
# (3, 28, 16)
multihead_W_value = torch.nn.Parameter(torch.randn(head, d_v, d))

In [29]:
# q, k, v for x_2 in each head
multihead_query_2 = multihead_W_query @ x_2  # (3, 24)
multihead_key_2 = multihead_W_key @ x_2  # (3, 24)
multihead_value_2 = multihead_W_value @ x_2  # (3, 28)

# x_2 attends to every word, so we also need the keys and values of all tokens;
# first, we repeat the (transposed) input embeddings once per head
stacked_inputs = embedded_sentence.T.repeat(head, 1, 1)  # (3, 16, 6)

multihead_keys = torch.bmm(multihead_W_key, stacked_inputs)  # (3, 24, 6)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)  # (3, 28, 6)


In [33]:
# Let x_2 attend to every token -> unnormalized attention scores per head
multihead_attention_unnormalized_score_2 = torch.bmm(
    multihead_query_2.unsqueeze(dim=1),  # (3, 1, 24)
    multihead_keys,                      # (3, 24, 6)
).squeeze()  # (3, 6)

# Scale by sqrt(d_k) and normalize over the tokens, separately for each head
multihead_attention_normalized_score_2 = F.softmax(
    multihead_attention_unnormalized_score_2 / d_k ** 0.5, dim=1
)  # (3, 6)

multihead_attention_normalized_score_2.shape

torch.Size([3, 6])

In [39]:
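# Context vector of x_2 in each head: attention-weighted sum of that head's value vectors
multihead_context_vector_2 = torch.bmm(
    multihead_attention_normalized_score_2.unsqueeze(1),  # (3, 1, 6)
    multihead_values.permute(0, 2, 1),                    # (3, 6, 28)
).squeeze()  # (3, 28)
multihead_context_vector_2.shape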

torch.Size([3, 28])
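Each head is an independent copy of the single-head computation with its own weight matrices. As a sanity check, the sketch below (random stand-ins with the notebook's shapes) runs the single-head steps in a Python loop over the heads and compares the stacked result with the batched `torch.bmm` version; the resulting `(3, 28)` tensor holds one context vector per head for $x^{(2)}$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
head, T, d, d_q, d_k, d_v = 3, 6, 16, 24, 24, 28
X = torch.randn(T, d)            # stand-in for embedded_sentence
W_q = torch.randn(head, d_q, d)  # stand-ins for the multihead weight tensors
W_k = torch.randn(head, d_k, d)
W_v = torch.randn(head, d_v, d)
x_2 = X[1]

# Batched version, following the notebook
q = W_q @ x_2                    # (3, 24)
keys = W_k @ X.T                 # (3, 24, 6)
values = W_v @ X.T               # (3, 28, 6)
scores = torch.bmm(q.unsqueeze(1), keys).squeeze(1)  # (3, 6)
weights = F.softmax(scores / d_k ** 0.5, dim=1)      # (3, 6)
batched = torch.bmm(weights.unsqueeze(1), values.permute(0, 2, 1)).squeeze(1)  # (3, 28)

# Loop version: one single-head attention step per head
per_head = []
for h in range(head):
    q_h = W_q[h] @ x_2                                  # (24,)
    k_h = (W_k[h] @ X.T).T                              # (6, 24)
    v_h = (W_v[h] @ X.T).T                              # (6, 28)
    w_h = F.softmax((q_h @ k_h.T) / d_k ** 0.5, dim=0)  # (6,)
    per_head.append(w_h @ v_h)                          # (28,)
looped = torch.stack(per_head)   # (3, 28)

print(batched.shape)                              # torch.Size([3, 28])
print(torch.allclose(batched, looped, atol=1e-3)) # True
```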