## A simple self-attention mechanism without trainable weights

In [2]:
import torch

In [3]:
# Six input vectors with three components each (one vector per row).
data = [
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
]
inputs = torch.tensor(data)

Calculating the context vector for the second input vector.

In [4]:
# The second input vector acts as the query.
query = inputs[1]

# Attention scores: dot product of the query with every input vector.
attn_score_2 = torch.zeros(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_score_2[i] = torch.dot(x_i, query)

# Softmax normalizes the scores into weights that sum to 1.
# (No pre-allocation needed: softmax returns a new tensor.)
attn_weights_2 = torch.softmax(attn_score_2, dim=0)
print(f"Attention scores: {attn_score_2}")
print(f"Attention weights: {attn_weights_2}")

# Context vector: sum of the input vectors, each scaled by its attention weight.
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i
print(context_vec_2)

Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor([0.4419, 0.6515, 0.5683])


Calculating the context vectors for all of the input vectors.

In [5]:
# Pairwise dot products between all input vectors: entry (i, j) is the
# score of input j for query i. The result is a fresh tensor, so no
# pre-allocated (and size-hard-coded) placeholder is needed.
attn_scores = inputs @ inputs.T
# Normalize along the last dimension so each row of weights sums to 1.
attn_weights = torch.softmax(attn_scores, dim=-1)
# Row i is the attention-weighted sum of the input vectors for query i.
context_vectors = attn_weights @ inputs

print(context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## Implementing self-attention with trainable weights

In [6]:
# Embedding sizes for the projections.
d_in = inputs.shape[1]  # input embedding size, taken from the inputs (d=3)
d_out = 2               # output embedding size (d=2)

# Second input vector, projected on its own in the cells below.
x_2 = inputs[1]

## Initializing the three weight matrices: W_query, W_key, W_value

In [7]:
# Fix the RNG seed so the random weight matrices are reproducible.
torch.manual_seed(123)

# Three (d_in x d_out) matrices of uniform random values with gradient
# tracking disabled. The generator is consumed in order, so the random
# draws happen as query, key, value.
w_query, w_key, w_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [8]:
# Project the second input vector with each of the three weight matrices.
query_2 = torch.matmul(x_2, w_query)
key_2 = torch.matmul(x_2, w_key)
value_2 = torch.matmul(x_2, w_value)

In [9]:
# Display the query vector obtained by projecting the second input with w_query.
query_2

tensor([0.4306, 1.4551])