# Self-Attention

## Embedding an Input Sentence

For simplicity, the dictionary `dc` is restricted to the words that occur in the input sentence; in real-world applications, the vocabulary is much larger.

In [3]:
sentence = 'Life is short, eat dessert first'

dc = {s:i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}
dc #This is the dictionary (vocabulary)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

Next, we convert the sentence into a sequence of integers.

In [4]:
import torch

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
sentence_int

tensor([0, 4, 5, 2, 1, 3])

We can then use an embedding layer to encode this integer-vector representation of the input sentence into real-valued embedding vectors.

In [5]:
torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16) # Here 6 words, and each word is represented by a 16 dimensional vector
embedded_sentence = embed(sentence_int).detach() # detach() returns a tensor cut off from the autograd graph, so no gradients are tracked for it
print(embedded_sentence)
print(embedded_sentence.shape)



tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])
torch.Size([6, 16])
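An embedding layer is essentially a trainable lookup table: each integer ID selects one row of its weight matrix. The sketch below, which reuses the notebook's sizes and seed but is otherwise a standalone illustration, checks that `embed(ids)` equals plain row indexing into `embed.weight`, and also equals multiplying a one-hot encoding of the IDs by that weight matrix.

```python
import torch

torch.manual_seed(123)
vocab_size, embed_dim = 6, 16
embed = torch.nn.Embedding(vocab_size, embed_dim)

token_ids = torch.tensor([0, 4, 5, 2, 1, 3])  # integer-encoded sentence

# Lookup: output row i is row token_ids[i] of embed.weight
looked_up = embed(token_ids)
indexed = embed.weight[token_ids]
print(torch.equal(looked_up, indexed))

# Equivalent view: a one-hot matrix times the weight matrix selects the same rows
one_hot = torch.nn.functional.one_hot(token_ids, num_classes=vocab_size).float()
via_matmul = one_hot @ embed.weight
print(torch.allclose(looked_up, via_matmul))
```

The lookup form is what is used in practice, since it avoids materializing the large one-hot matrix.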


## Weight Matrices

Self-attention uses $W_q$, $W_k$, and $W_v$ to project the embedded sentence into query, key, and value vectors.

The respective query, key and value sequences are obtained via matrix multiplication between the weight matrices W and the embedded inputs x:

Query sequence: $q^{(i)} = W_q x^{(i)}$ for $i ∈ [1,T]$

Key sequence: $k^{(i)} = W_k x^{(i)}$ for $i ∈ [1,T]$

Value sequence: $v^{(i)} = W_v x^{(i)}$ for $i ∈ [1,T]$

The index i refers to the token index position in the input sequence, which has length T.

Note the shapes of the projection matrices:

$W_q$ and $W_k$ have shape $d_k \times d$ (queries and keys must have the same dimension, because we take their dot product)

$W_v$ has shape $d_v \times d$

$d$ = size of each word vector $x$ (here, $16$)

For this code, $d_q = d_k = 24$ and $d_v = 28$.

In [6]:
torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 24, 24, 28

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

In [7]:
W_query.shape, W_key.shape, W_value.shape

(torch.Size([24, 16]), torch.Size([24, 16]), torch.Size([28, 16]))

## Unnormalized Attention Weight Computation

Suppose we want to compute the attention (context) vector for the 2nd input element. The 2nd input element then acts as the query.

In [8]:
x_2 = embedded_sentence[1] # get the 2nd vector among the 6 in embedded_sentence

query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


We then generalize this to compute the keys and values for all inputs, since we need them in the next step to compute the unnormalized attention weights $\omega$.

In [9]:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print(keys.shape)
print(values.shape)

torch.Size([6, 24])
torch.Size([6, 28])
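The expression `W_key.matmul(embedded_sentence.T).T` applies $k^{(i)} = W_k x^{(i)}$ to every token at once; it is the same as the row-wise form $K = X W_k^\top$. The sketch below uses a random stand-in `X` for `embedded_sentence` (an assumption, not the notebook's values) and checks the per-token loop, the transpose form, the row-wise form, and a bias-free `torch.nn.Linear` (whose weight is stored as out_features × in_features) against each other.

```python
import torch

torch.manual_seed(123)
T, d, d_k = 6, 16, 24
X = torch.randn(T, d)        # stand-in for embedded_sentence (T tokens, d features)
W_key = torch.rand(d_k, d)   # same (d_k x d) layout as in the text

keys_loop = torch.stack([W_key.matmul(X[i]) for i in range(T)])  # k^(i) = W_k x^(i), one token at a time
keys_transpose = W_key.matmul(X.T).T   # the form used in the notebook
keys_rowwise = X @ W_key.T             # equivalent row-wise form: K = X W_k^T

print(torch.allclose(keys_loop, keys_transpose, atol=1e-5))
print(torch.allclose(keys_rowwise, keys_transpose, atol=1e-5))

# A bias-free Linear layer stores its weight as (out_features x in_features): the same layout
proj = torch.nn.Linear(d, d_k, bias=False)
with torch.no_grad():
    proj.weight.copy_(W_key)
print(torch.allclose(proj(X), keys_rowwise, atol=1e-5))
```

This is why many implementations simply use `torch.nn.Linear(d, d_k, bias=False)` layers for the query, key, and value projections.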


### Finally computing unnormalized attention weights

<img src="./assets/unnormalized.png" width="40%">

As illustrated in the figure above, we compute $\omega_{ij}$ as the dot product between the query and key sequences, $\omega_{ij} = q^{(i)\top} k^{(j)}$.

In [10]:
# For example, omega_24: query of the 2nd input (x_2 = embedded_sentence[1]) dotted with keys[4] (0-based index 4)
print(f"Query_2 shape: {query_2.shape}")
print(f"Keys shape: {keys.shape}")
print(f"Keys[4] shape: {keys[4].shape}")
omega_24 = query_2.dot(keys[4])
print(omega_24)

Query_2 shape: torch.Size([24])
Keys shape: torch.Size([6, 24])
Keys[4] shape: torch.Size([24])
tensor(11.1466, grad_fn=<DotBackward0>)


In [11]:
omega_2 = query_2.matmul(keys.T)
print(omega_2)

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)
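`omega_2` is one row of a full $T \times T$ matrix of unnormalized weights. A minimal sketch, assuming random stand-in inputs rather than the notebook's tensors, computes all rows at once as $\Omega = Q K^\top$ and checks that row 1 matches the single-query computation and that entry `[1, 4]` matches the individual dot product.

```python
import torch

torch.manual_seed(123)
T, d, d_k = 6, 16, 24
X = torch.randn(T, d)                       # stand-in for embedded_sentence
W_query, W_key = torch.rand(d_k, d), torch.rand(d_k, d)

queries = X @ W_query.T   # (T, d_k), row i is q^(i)
keys = X @ W_key.T        # (T, d_k), row j is k^(j)

omega = queries @ keys.T  # (T, T), omega[i, j] = q^(i) . k^(j)

omega_2 = queries[1] @ keys.T  # single-query version (2nd token -> index 1)
print(omega.shape)
print(torch.allclose(omega[1], omega_2, atol=1e-4))
print(torch.allclose(omega[1, 4], queries[1].dot(keys[4]), atol=1e-4))
```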


## Computing the Attention Scores

Next, we normalize the unnormalized attention weights $\omega$ to obtain the normalized attention weights $\alpha$ by applying the softmax function.

Before the softmax, the scores are scaled by $\frac{1}{\sqrt{d_k}}$, i.e. $\alpha_{2j} = \mathrm{softmax}_j\left(\omega_{2j} / \sqrt{d_k}\right)$. The magnitude of dot products tends to grow with the dimension $d_k$; without scaling, large scores push the softmax into a saturated regime where its output is nearly one-hot and its gradients are very small.

<img src="./assets/normalized.png" width="40%">


In [12]:
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k ** 0.5, dim=0)
print(attention_weights_2)


tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)
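The effect of the $\frac{1}{\sqrt{d_k}}$ factor can be checked empirically. Assuming query and key components are independent with mean 0 and variance 1 (a simplifying assumption, not true of the notebook's `torch.rand` weights), the variance of $q \cdot k$ is about $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1. The last lines show how larger logits push softmax toward a one-hot output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_samples = 10_000

for d_k in (4, 64, 1024):
    q = torch.randn(n_samples, d_k)   # components with mean 0, variance 1 (assumption)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)        # n_samples independent dot products
    scaled = dots / d_k ** 0.5
    print(f"d_k={d_k:5d}  var(q.k)={dots.var().item():8.1f}  var(q.k/sqrt(d_k))={scaled.var().item():.2f}")

# Larger logits push softmax toward a one-hot vector (saturation)
logits = torch.tensor([1.0, 2.0, 3.0])
print(F.softmax(logits, dim=0))
print(F.softmax(logits * 10, dim=0))
```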


## Computing the Context Vector

The context vector $z^{(2)}$ is an attention-weighted version of the original query input $x^{(2)}$: a weighted sum of all value vectors, $z^{(2)} = \sum_{j=1}^{T} \alpha_{2j} v^{(j)}$.

Through the attention weights, it incorporates all the other input elements as its context.

<img src="./assets/context.png" width="45%">

In [13]:
values.shape

torch.Size([6, 28])

In [14]:
attention_weights_2.shape

torch.Size([6])

In [16]:
context_vector_2 = attention_weights_2.matmul(values) # interpret attention_weights_2 as row vector here (shape 1 x 6)
print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084], grad_fn=<SqueezeBackward4>)
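The matrix product `attention_weights_2.matmul(values)` is exactly the weighted sum $\sum_j \alpha_{2j} v^{(j)}$. The sketch below, with random stand-ins for the values and the attention weights (assumptions, not the notebook's tensors), compares an explicit loop over the value vectors with the matrix product.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
T, d_v = 6, 28
values = torch.randn(T, d_v)                             # stand-in for the (T x d_v) value matrix
attention_weights_2 = F.softmax(torch.randn(T), dim=0)   # stand-in for alpha_2 (sums to 1)

# z^(2) = sum_j alpha_{2j} v^(j): explicit weighted sum over the value vectors
z_loop = sum(attention_weights_2[j] * values[j] for j in range(T))

# Same thing as a (1 x T) @ (T x d_v) product, as in the notebook
z_matmul = attention_weights_2.matmul(values)

print(z_matmul.shape)
print(torch.allclose(z_loop, z_matmul, atol=1e-5))
```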


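The dimensions chosen here are arbitrary: the context vector has $d_v = 28$ dimensions, which need not match the input embedding size $d = 16$.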

<img src="./assets/self_attention_overview.png" width="45%">
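The steps above compute the context vector for one token. A minimal sketch of the same computation for all $T$ tokens at once is $Z = \mathrm{softmax}\left(Q K^\top / \sqrt{d_k}\right) V$. The function name `self_attention` and the random inputs are illustrative assumptions; the check at the end confirms that row 1 reproduces the step-by-step result for the 2nd token.

```python
import torch
import torch.nn.functional as F

def self_attention(X, W_query, W_key, W_value):
    """Hypothetical helper: scaled dot-product self-attention for all T tokens.

    X: (T, d); W_query, W_key: (d_k, d); W_value: (d_v, d). Returns (T, d_v).
    """
    Q = X @ W_query.T
    K = X @ W_key.T
    V = X @ W_value.T
    d_k = K.shape[-1]
    omega = Q @ K.T                                # (T, T) unnormalized attention weights
    alpha = F.softmax(omega / d_k ** 0.5, dim=-1)  # each row sums to 1
    return alpha @ V                               # (T, d_v) context vectors, one per token

torch.manual_seed(123)
T, d, d_q, d_k, d_v = 6, 16, 24, 24, 28
X = torch.randn(T, d)
W_query, W_key, W_value = torch.rand(d_q, d), torch.rand(d_k, d), torch.rand(d_v, d)

Z = self_attention(X, W_query, W_key, W_value)
print(Z.shape)

# Row 1 matches the step-by-step computation for the 2nd token
query_2 = W_query @ X[1]
keys, values = X @ W_key.T, X @ W_value.T
alpha_2 = F.softmax(query_2 @ keys.T / d_k ** 0.5, dim=0)
print(torch.allclose(Z[1], alpha_2 @ values, atol=1e-4))
```

PyTorch 2.0 and later also provide `torch.nn.functional.scaled_dot_product_attention`, which performs the softmax-weighted step from precomputed `Q`, `K`, `V` (scaling by $1/\sqrt{d_k}$ by default); the projections remain the caller's job.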

# Multi-Head Attention

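Multi-head attention is simply self-attention replicated across several heads, each with its own query, key, and value projection matrices, and computed in parallel.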

<img src="./assets/multi_headed_attention.png" width="45%">


Let the number of heads be $h = 3$.

In [17]:
h = 3
multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d))

In [18]:
x_2 = embedded_sentence[1]

multihead_query_2 = multihead_W_query.matmul(x_2)
multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)
print(multihead_query_2.shape)
print(multihead_key_2.shape)
print(multihead_value_2.shape)

torch.Size([3, 24])
torch.Size([3, 24])
torch.Size([3, 28])


The above computes the query, key, and value vectors only for $x^{(2)}$. To compute the attention for $x^{(2)}$, we still need the keys and values of all input elements.

In [19]:
multihead_W_key.shape, embedded_sentence.shape

(torch.Size([3, 24, 16]), torch.Size([6, 16]))

In [20]:
stacked_inputs = embedded_sentence.T.repeat(h, 1, 1)
stacked_inputs.shape # one copy of the (d x T) inputs per head

torch.Size([3, 16, 6])

In [21]:
multihead_W_key.shape, multihead_W_value.shape

(torch.Size([3, 24, 16]), torch.Size([3, 28, 16]))

In [22]:
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs) # use batch matrix multiplication
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])
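`torch.bmm` multiplies the $i$-th matrix of the first batch with the $i$-th matrix of the second, so here each head's key matrix is applied to its own copy of the inputs. The sketch below, using random stand-ins for `embedded_sentence` and the head weights (assumptions), checks this against an explicit per-head loop and against `torch.matmul` broadcasting, which avoids building `stacked_inputs` at all.

```python
import torch

torch.manual_seed(123)
h, T, d, d_k = 3, 6, 16, 24
X = torch.randn(T, d)                      # stand-in for embedded_sentence
multihead_W_key = torch.rand(h, d_k, d)    # one (d_k x d) key matrix per head

stacked_inputs = X.T.repeat(h, 1, 1)       # (h, d, T): one copy of the inputs per head
keys_bmm = torch.bmm(multihead_W_key, stacked_inputs)  # (h, d_k, T)

# Same result, one head at a time
keys_loop = torch.stack([multihead_W_key[i] @ X.T for i in range(h)])

# Broadcasting matmul avoids the explicit copy: (h, d_k, d) @ (d, T) -> (h, d_k, T)
keys_broadcast = multihead_W_key @ X.T

print(keys_bmm.shape)
print(torch.allclose(keys_bmm, keys_loop, atol=1e-5), torch.allclose(keys_bmm, keys_broadcast, atol=1e-5))
```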


The first dimension indexes the three attention heads; the second and third dimensions are the projection size ($d_k$ or $d_v$) and the number of words, respectively. We swap the last two dimensions purely for easier interpretability.

In [23]:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])


This yields tensors whose last two dimensions follow the same (words, features) layout as the original input sequence, `embedded_sentence`.

In [25]:
multihead_query_2.shape

torch.Size([3, 24])

In [26]:
multihead_keys.shape

torch.Size([3, 6, 24])

In [27]:
multihead_keys[:, 4].shape

torch.Size([3, 24])

In [28]:
# omega_24 for each head: the dot product of that head's query for x_2 with the same head's key at index 4
# (element-wise product summed over the feature dimension) -> expected shape [3, 1]
omega_24_multihead = (multihead_query_2 * multihead_keys[:, 4]).sum(dim=-1, keepdim=True)
print(f"Shape of omega_24_multihead: {omega_24_multihead.shape}")

Shape of omega_24_multihead: torch.Size([3, 1])


TODO: Finish Multi-head attention implementation. 

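For each head, we then compute the unnormalized attention weights $\omega$, scale them by $\frac{1}{\sqrt{d_k}}$ and apply softmax to obtain the attention weights $\alpha$, and use $\alpha$ to weight that head's values. This yields an $h \times d_v$ (here $3 \times d_v$) dimensional context vector $z^{(2)}$ for input element $x^{(2)}$, one $d_v$-dimensional context vector per head.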

<img src="./assets/self_attention.png" width="45%">
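A minimal sketch of these remaining multi-head steps for $x^{(2)}$, assuming random stand-ins for `embedded_sentence` and the per-head weights (not the notebook's tensors). `torch.einsum` performs the per-head dot products and weighted sums; the final concatenation into one $h \cdot d_v$ vector follows standard multi-head attention, where it is usually followed by an output projection (omitted here).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
h, T, d, d_q, d_k, d_v = 3, 6, 16, 24, 24, 28
X = torch.randn(T, d)                          # stand-in for embedded_sentence
multihead_W_query = torch.rand(h, d_q, d)
multihead_W_key = torch.rand(h, d_k, d)
multihead_W_value = torch.rand(h, d_v, d)

x_2 = X[1]
multihead_query_2 = multihead_W_query @ x_2                      # (h, d_q)
multihead_keys = (multihead_W_key @ X.T).permute(0, 2, 1)        # (h, T, d_k)
multihead_values = (multihead_W_value @ X.T).permute(0, 2, 1)    # (h, T, d_v)

# Unnormalized weights per head: omega_2[head, j] = q_head^(2) . k_head^(j)
omega_2 = torch.einsum("hd,htd->ht", multihead_query_2, multihead_keys)  # (h, T)
alpha_2 = F.softmax(omega_2 / d_k ** 0.5, dim=-1)                        # each head's row sums to 1

# Context vector per head: weighted sum of that head's values
z_2 = torch.einsum("ht,htv->hv", alpha_2, multihead_values)  # (h, d_v)
print(z_2.shape)

# Concatenate the heads into a single h * d_v vector
z_2_concat = z_2.reshape(-1)
print(z_2_concat.shape)
```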

While the value dimension is often chosen to match the query and key dimensions (as in PyTorch's `torch.nn.MultiheadAttention` class), we can choose an arbitrary size for the value dimension $d_v$.

# Cross Attention

<img src="./assets/cross_attention_transformer.png" width="45%">

In cross-attention, we mix or combine two different input sequences. In the case of the original transformer architecture above, that's the sequence returned by the encoder module on the left and the input sequence being processed by the decoder part on the right.

<img src="./assets/cross_attention.png" width="45%">

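$x_1$ and $x_2$ can have different numbers of elements, but their embedding dimensions must match (in this setup, all projection matrices take $d$-dimensional inputs). In the original transformer, the queries come from the decoder, and the keys and values come from the encoder.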
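A minimal sketch of cross-attention for whole sequences, under the assumptions that both sequences are random stand-ins with the same $d = 16$ and that the helper name `cross_attention` is hypothetical: queries are projected from `x_1`, keys and values from `x_2`, so the attention matrix has shape $T_1 \times T_2$ and the output has one context vector per element of `x_1`. Passing the same sequence twice recovers self-attention.

```python
import torch
import torch.nn.functional as F

def cross_attention(x_1, x_2, W_query, W_key, W_value):
    """Hypothetical helper: queries from x_1 (T_1, d); keys/values from x_2 (T_2, d)."""
    Q = x_1 @ W_query.T            # (T_1, d_k)
    K = x_2 @ W_key.T              # (T_2, d_k)
    V = x_2 @ W_value.T            # (T_2, d_v)
    alpha = F.softmax(Q @ K.T / K.shape[-1] ** 0.5, dim=-1)  # (T_1, T_2), rows sum to 1
    return alpha @ V, alpha        # (T_1, d_v) context vectors and the weights

torch.manual_seed(123)
d, d_k, d_v = 16, 24, 28
W_query, W_key, W_value = torch.rand(d_k, d), torch.rand(d_k, d), torch.rand(d_v, d)

x_1 = torch.randn(6, d)   # e.g. decoder-side sequence (provides the queries)
x_2 = torch.randn(8, d)   # e.g. encoder output (provides keys and values): same d, different length

Z, alpha = cross_attention(x_1, x_2, W_query, W_key, W_value)
print(alpha.shape, Z.shape)

# Self-attention is the special case x_2 = x_1
Z_self, _ = cross_attention(x_1, x_1, W_query, W_key, W_value)
print(Z_self.shape)
```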

### Previous Code

In [29]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]
print("embedded_sentence.shape:", embedded_sentence.shape)

d_q, d_k, d_v = 24, 24, 28

W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)

x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
print("query.shape", query_2.shape)

keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

embedded_sentence.shape: torch.Size([6, 16])
query.shape torch.Size([24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])


### Part that changes

We now have a second input sequence. (We consider a second sentence with 8 elements instead of 6.)

In [30]:
embedded_sentence_2 = torch.rand(8, 16) # 2nd input sequence

keys = W_key.matmul(embedded_sentence_2.T).T
values = W_value.matmul(embedded_sentence_2.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([8, 24])
values.shape: torch.Size([8, 28])


Compared to self-attention, the keys and values now have 8 instead of 6 rows. Everything else stays the same.
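Continuing the walk-through for the single query $x^{(2)}$, the rest of the computation is unchanged; only the lengths follow the second sequence. In this sketch `embedded_sentence` is a random stand-in (an assumption), and the weights are regenerated with the same shapes as above, so the numbers will not match the notebook's.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
d, d_q, d_k, d_v = 16, 24, 24, 28
W_query, W_key, W_value = torch.rand(d_q, d), torch.rand(d_k, d), torch.rand(d_v, d)

embedded_sentence = torch.randn(6, 16)    # stand-in for the first sequence
embedded_sentence_2 = torch.rand(8, 16)   # second sequence (8 elements)

query_2 = W_query.matmul(embedded_sentence[1])      # the query still comes from the first sequence
keys = W_key.matmul(embedded_sentence_2.T).T        # (8, 24)
values = W_value.matmul(embedded_sentence_2.T).T    # (8, 28)

omega_2 = query_2.matmul(keys.T)                    # one score per element of the second sequence
attention_weights_2 = F.softmax(omega_2 / d_k ** 0.5, dim=0)
context_vector_2 = attention_weights_2.matmul(values)

print(attention_weights_2.shape)
print(context_vector_2.shape)
```

The attention weights now have 8 entries, one per element of the second sequence, while the context vector keeps its $d_v = 28$ dimensions.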