In [1]:
import torch

#Simple self-attention without trainable weights

#Six input embeddings, one row per token, three features each
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55]
])

#The second input - [0.55, 0.87, 0.66] - acts as the query; queries, keys and values are all the raw embeddings here
query_input = inputs[1]

#Unnormalised similarity of the query with every input: one dot product (Q.K) per row.
#Note: the SelfAttention_V1 class further down calls these raw values "attention scores";
#the variable name is kept because later cells read it.
attention_weight_2 = torch.stack([torch.dot(query_input, row) for row in inputs])

print(f"The attention weight for input 2 is: {attention_weight_2}")

The attention weight for input 2 is: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [2]:
#Normalise the raw dot products for input 2 with softmax so that they are all positive and sum to 1
attention_scores_2 = attention_weight_2.softmax(dim=0)
print(f"The attention scores are: {attention_scores_2}")

The attention scores are: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [3]:
#Let's move on to the context vector for input 2.
#The context vector is the sum of all input vectors, each weighted by its normalised attention value for input 2
#normalised values = softmax(Q.K), context vector = normalised values . V
#No unsqueeze/squeeze is needed: a 1-D tensor matrix-multiplied with a 2-D tensor already gives a 1-D result,
#and unsqueeze()/squeeze() return new tensors anyway, so calling them without using the result does nothing.
context_vector_2 = attention_scores_2 @ inputs
print(f"Shape of the context vector is: {context_vector_2.shape}")

#This can also be done in the following way: accumulate the weighted rows one at a time
context_vector_2_sim = torch.zeros(query_input.shape)

for weight, x_i in zip(attention_scores_2, inputs):
    context_vector_2_sim += weight*x_i

#Both routes must agree on the values, not only on the shape
assert torch.allclose(context_vector_2, context_vector_2_sim)

print(f"Shape of the contenxt_vector without matmul is: {context_vector_2_sim.shape}")
    

Shape of the context vector is: torch.Size([3])
Shape of the contenxt_vector without matmul is: torch.Size([3])


In [5]:
import torch.nn as nn

#Embedding size of each input vector (number of columns of the input matrix)
dim_in = inputs.size(1)
#Size of the projected query/key/value vectors
dim_out = 2

#Compute self-attention with weights 
class SelfAttention_V1(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    Three bias-free linear layers project the input into queries, keys and
    values; the output is the softmax-weighted sum of the values.

    Args:
        d_in: Size of the last axis of the input. If omitted, the
            notebook-level ``dim_in`` is used.
        d_out: Size of the projected query/key/value vectors. If omitted,
            the notebook-level ``dim_out`` is used.
    """

    def __init__(self, d_in=None, d_out=None):
        super().__init__()
        # Fall back to the notebook globals so the existing no-argument
        # construction keeps working unchanged.
        d_in = dim_in if d_in is None else d_in
        d_out = dim_out if d_out is None else d_out
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, input):
        """Compute one context vector per input position.

        Args:
            input: Tensor whose last axis has size ``d_in``; the second-to-last
                axis indexes the positions that attend to each other. Any
                leading axes are treated as batch axes.

        Returns:
            Tensor with the same leading axes as ``input`` and a last axis of
            size ``d_out``.
        """
        query = self.W_query(input)
        key = self.W_key(input)
        value = self.W_value(input)

        # Length of each key vector. Indexing the last axis (rather than axis 1)
        # keeps this correct when a batch axis is present.
        d_k = key.shape[-1]

        # Swap only the last two axes; `.T` would reverse every axis of a
        # batched tensor. For 2-D input this is identical to `key.T`.
        attention_scores = query @ key.transpose(-2, -1)
        print(f"Attention scores are: {attention_scores}")
        # Scale by sqrt(d_k) before the softmax, then normalise each row to sum to 1.
        attention_weights = torch.softmax(attention_scores/d_k**0.5, dim=-1)
        print(f"Attention scores when normalized ----> Attention weights: {attention_weights}")
        context_vector = attention_weights @ value

        return context_vector
    
#The projection weights are initialised randomly; seed the global RNG first so the printed numbers are the same on every run
torch.manual_seed(123)
selfattention_v1 = SelfAttention_V1()
context_vector = selfattention_v1(inputs)
print(f"The context vector is: {context_vector}")

Attention scores are: tensor([[-0.0134, -0.1257, -0.1229, -0.0847, -0.0387, -0.1151],
        [-0.1208, -0.1911, -0.1837, -0.1217, -0.0004, -0.1987],
        [-0.1140, -0.1839, -0.1770, -0.1173, -0.0016, -0.1907],
        [-0.0930, -0.1204, -0.1152, -0.0755,  0.0096, -0.1291],
        [ 0.0394, -0.0039, -0.0051, -0.0054, -0.0244,  0.0058],
        [-0.1613, -0.1956, -0.1869, -0.1219,  0.0216, -0.2122]],
       grad_fn=<MmBackward0>)
Attention scores when normalized ----> Attention weights: tensor([[0.1750, 0.1617, 0.1620, 0.1664, 0.1719, 0.1629],
        [0.1683, 0.1601, 0.1609, 0.1682, 0.1832, 0.1593],
        [0.1685, 0.1603, 0.1611, 0.1681, 0.1824, 0.1596],
        [0.1659, 0.1627, 0.1633, 0.1680, 0.1784, 0.1617],
        [0.1712, 0.1661, 0.1659, 0.1659, 0.1637, 0.1672],
        [0.1642, 0.1603, 0.1613, 0.1689, 0.1869, 0.1584]],
       grad_fn=<SoftmaxBackward0>)
The context vector is: tensor([[ 0.0274, -0.0842],
        [ 0.0325, -0.0891],
        [ 0.0322, -0.0888],
        [ 0.03