In [1]:
import torch

## Simple self-attention mechanism without trainable weights

In [2]:
# Toy input matrix: six rows, each a 3-dimensional vector.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])

In [3]:
# Raw scores: dot product of every input row with every input row.
atten_score = torch.matmul(inputs, inputs.T)
# Softmax over the last axis so each row of weights sums to 1.
atten_weights = atten_score.softmax(dim=-1)
# Each context vector is the weighted sum of all input rows.
all_context_vecs = torch.matmul(atten_weights, inputs)
all_context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

## Implementing self-attention with trainable weights

### Computing attention weights step by step

In [4]:
# The second input row (index 1) serves as the worked example.
x_2 = inputs[1]
# Input width is read from the tensor; the projection width is picked by hand.
d_in, d_out = inputs.size(1), 2

In [6]:
# Fix the RNG so the randomly initialised weights are reproducible.
torch.manual_seed(123)

# Three (d_in, d_out) trainable matrices. The generator is consumed left to
# right, so the random draws happen in query -> key -> value order.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out)) for _ in range(3)
)

In [7]:
# Project the second input vector through the query weights.
query_2 = torch.matmul(x_2, W_query)
query_2

tensor([0.4306, 1.4551], grad_fn=<SqueezeBackward4>)

In [8]:
# Project every input row through the key and value weights.
keys = inputs.matmul(W_key)
value = inputs.matmul(W_value)

# One key vector per input row.
keys.shape

torch.Size([6, 2])

In [9]:
# Show the key vectors obtained by multiplying `inputs` with `W_key`.
keys

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]], grad_fn=<MmBackward0>)

In [10]:
# Key vector belonging to the second input row (index 1).
keys_2 = keys[1]
# Scalar score of query 2 against key 2.
atten_score_22 = query_2.dot(keys_2)

In [11]:
# Show the dot product of `query_2` and `keys_2`.
atten_score_22

tensor(1.8524, grad_fn=<DotBackward0>)

In [12]:
# Scores of query 2 against every key at once (one entry per row of `keys`).
# NOTE(review): this rebinds `atten_score_22`, previously the scalar from
# torch.dot(query_2, keys_2), to a vector. A distinct name would keep earlier
# cells re-runnable, but later cells may rely on this one - confirm first.
atten_score_22 = torch.matmul(query_2, keys.T)
atten_score_22

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
       grad_fn=<SqueezeBackward4>)