# Embedding an Input Sentence

For simplicity, the dictionary `dc` is restricted to the words that occur in the input sentence. In real-world applications, the vocabulary would be far larger.

In [34]:
sentence = 'Life is short, eat dessert first'

dc = {s:i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}
dc #This is the dictionary (vocabulary)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

Next, we use this dictionary to convert the sentence into a sequence of integers.

In [35]:
import torch

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
sentence_int

tensor([0, 4, 5, 2, 1, 3])
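
The mapping can also be inverted to recover the words from the integer sequence, which is a quick way to check that the encoding is right. A minimal sketch in plain Python; the names `inv_dc` and `decoded` are illustrative and not part of the original notebook.

```python
sentence = 'Life is short, eat dessert first'
words = sentence.replace(',', '').split()

# Same vocabulary construction as above: sorted words mapped to indices
dc = {s: i for i, s in enumerate(sorted(words))}
sentence_int = [dc[s] for s in words]

# Hypothetical helper mapping: index -> word
inv_dc = {i: s for s, i in dc.items()}
decoded = [inv_dc[i] for i in sentence_int]

print(sentence_int)  # [0, 4, 5, 2, 1, 3]
print(decoded)       # ['Life', 'is', 'short', 'eat', 'dessert', 'first']
```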

We can now use an embedding layer to encode the integer-vector representation of the input sentence as a sequence of real-valued embedding vectors.

In [36]:
torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)  # 6 words in the vocabulary, each represented by a 16-dimensional vector
embedded_sentence = embed(sentence_int).detach()  # detach() returns a tensor that autograd does not track
print(embedded_sentence)
print(embedded_sentence.shape)



tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])
torch.Size([6, 16])
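
An embedding layer is essentially a lookup table: calling it with integer indices returns the corresponding rows of its weight matrix. The sketch below illustrates this with the same layer size and indices as above; the names `via_layer` and `via_indexing` are illustrative only.

```python
import torch

torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)  # 6 vocabulary entries, 16-dimensional vectors
sentence_int = torch.tensor([0, 4, 5, 2, 1, 3])

# Calling the layer selects rows of embed.weight by index
via_layer = embed(sentence_int).detach()
via_indexing = embed.weight[sentence_int].detach()

print(via_layer.shape)
print(torch.equal(via_layer, via_indexing))  # True
```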


# Weight Matrices

Self-attention uses three weight matrices, $W_q$, $W_k$, and $W_v$, to project the embedded sentence into query, key, and value vectors.

The respective query, key, and value sequences are obtained via matrix multiplication between the weight matrices $W$ and the embedded inputs $x$:

Query sequence: $q^{(i)} = W_q x^{(i)}$ for $i \in [1, T]$

Key sequence: $k^{(i)} = W_k x^{(i)}$ for $i \in [1, T]$

Value sequence: $v^{(i)} = W_v x^{(i)}$ for $i \in [1, T]$

The index i refers to the token index position in the input sequence, which has length T.

Note the shapes of the projection matrices:

$W_q$ and $W_k$ both have shape $d_k \times d$ (since $d_q = d_k$)

$W_v$ has shape $d_v \times d$

Here, $d$ is the size of each word vector $x$ (in this example, $16$).

For this code, $d_q = d_k = 24$ and $d_v = 28$.
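
The reason $d_q$ and $d_k$ must match, while $d_v$ is free to differ, is that the attention score is a dot product between a query and a key. The following sketch illustrates this with a random stand-in vector `x` (an assumption, not the embedded sentence from above); `score` and `mismatch_raises` are illustrative names.

```python
import torch

torch.manual_seed(123)
d, d_q, d_k, d_v = 16, 24, 24, 28

x = torch.randn(d)            # stand-in for one embedded token (assumption)
W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)

q, k, v = W_query @ x, W_key @ x, W_value @ x
print(q.shape, k.shape, v.shape)

# The attention score is a dot product of q and k, so d_q must equal d_k
score = torch.dot(q, k)

# A dot product between vectors of different lengths (24 vs. 28) fails,
# which is why only d_v may be chosen independently
try:
    torch.dot(q, v)
    mismatch_raises = False
except RuntimeError:
    mismatch_raises = True
print(mismatch_raises)
```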

In [37]:
torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 24, 24, 28

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

In [38]:
W_query.shape, W_key.shape, W_value.shape

(torch.Size([24, 16]), torch.Size([24, 16]), torch.Size([28, 16]))

# Unnormalized Attention Weight Computation

Suppose we want to compute the attention vector for the second input element. The second input element then acts as the query.

In [39]:
x_2 = embedded_sentence[1] # get the 2nd vector among the 6 in embedded_sentence

query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


We can then generalize this to compute the key and value vectors for all inputs, since we will need them in the next step when we compute the unnormalized attention weights $\omega$.

In [40]:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print(keys.shape)
print(values.shape)

torch.Size([6, 24])
torch.Size([6, 28])
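
The double transpose above can be avoided by using the identity $(W X^\top)^\top = X W^\top$. The sketch below checks this on a random stand-in for the embedded sentence (an assumption, so the numbers differ from the notebook); `keys_a`, `keys_b`, and `key_2` are illustrative names.

```python
import torch

torch.manual_seed(123)
embedded_sentence = torch.randn(6, 16)  # stand-in for the 6 x 16 embedded sentence (assumption)
W_key = torch.rand(24, 16)

# Form used above: project the columns of X^T, then transpose back
keys_a = W_key.matmul(embedded_sentence.T).T
# Equivalent form: (W X^T)^T = X W^T
keys_b = embedded_sentence @ W_key.T
# Row 1 is the key of the 2nd token, computed on its own
key_2 = W_key.matmul(embedded_sentence[1])

print(keys_b.shape)
print(torch.allclose(keys_a, keys_b, atol=1e-4))
print(torch.allclose(keys_a[1], key_2, atol=1e-4))
```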


### Computing the unnormalized attention weights

<img src="assets/image.png" width="40%">

As illustrated in the figure above, we compute $\omega_{i,j}$ as the dot product between the $i$-th query and the $j$-th key: $\omega_{i,j} = q^{(i)\top} k^{(j)}$.

In [41]:
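# For example, omega_24: query_2 (the 2nd token) against keys[4]
# Note: keys is indexed from 0, so keys[4] is the key of the 5th token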
print(f"Query_2 shape: {query_2.shape}")
print(f"Keys shape: {keys.shape}")
print(f"Keys[4] shape: {keys[4].shape}")
omega_24 = query_2.dot(keys[4])
print(omega_24)

Query_2 shape: torch.Size([24])
Keys shape: torch.Size([6, 24])
Keys[4] shape: torch.Size([24])
tensor(11.1466, grad_fn=<DotBackward0>)


In [42]:
query_2

tensor([ 0.8982,  0.1030,  0.4428,  0.6328, -1.7003,  1.3489, -0.3082, -0.5900,
        -0.9257, -0.7688,  1.8828, -1.6065, -0.8011, -0.4114, -0.6116,  1.3902,
        -0.1460,  0.0244, -0.5577,  1.5972, -2.2190, -0.0214,  0.2002,  1.3752],
       grad_fn=<MvBackward0>)

In [43]:
keys[4]

tensor([ 1.1230,  1.3014,  0.7475,  0.2554, -2.3979, -0.9883, -1.1096, -1.3873,
         0.9164, -2.3064, -2.7067, -3.1677, -1.4181, -1.0188, -0.8252,  1.0323,
        -2.0219, -0.7073, -0.7288, -2.5216, -2.8680, -0.9919, -0.9798, -1.2008],
       grad_fn=<SelectBackward0>)
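
The dot product used for `omega_24` is simply the sum of the element-wise products of the two vectors. A small sketch with random stand-in vectors (assumptions, not the values printed above) makes this explicit; `via_dot` and `via_sum` are illustrative names.

```python
import torch

torch.manual_seed(123)
q = torch.randn(24)  # stand-in for query_2 (assumption)
k = torch.randn(24)  # stand-in for keys[4] (assumption)

via_dot = q.dot(k)       # built-in dot product
via_sum = (q * k).sum()  # element-wise product, then sum

print(torch.allclose(via_dot, via_sum, atol=1e-4))
```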

In [44]:
omega_2 = query_2.matmul(keys.T)
print(omega_2)

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)
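
The same computation extends to every query at once: stacking all queries into a matrix gives a $T \times T$ matrix of unnormalized weights whose row for the second token matches `omega_2`. This is a sketch on random stand-in data (an assumption); `queries` and `omega` are illustrative names that do not appear in the notebook.

```python
import torch

torch.manual_seed(123)
embedded_sentence = torch.randn(6, 16)  # stand-in for the embedded sentence (assumption)
W_query = torch.rand(24, 16)
W_key = torch.rand(24, 16)

queries = embedded_sentence @ W_query.T  # shape [6, 24], one query per token
keys = embedded_sentence @ W_key.T       # shape [6, 24], one key per token

# omega[i, j] is the dot product of query i and key j
omega = queries @ keys.T                 # shape [6, 6]

# Row 1 reproduces the single-query computation for the 2nd token
omega_2 = queries[1].matmul(keys.T)

print(omega.shape)
print(torch.allclose(omega[1], omega_2, atol=1e-3))
```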
