# Self-Attention with trainable weights

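Each embedded input token is projected into three vectors: a query, a key and a value.
Each token's query is then compared with the keys of all tokens by taking dot products, like before, but now using the projected vectors instead of the raw embeddings.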

## This has to be done for every input token

X = [[1,2,3],[4,5,6],...], shape = data_size x no_embedding_dims

X[1] = [1, 2, 3], shape = 1 x no_embedding_dims (the first token; the notation here is 1-based)

Three weight matrices Wq, Wk and Wv map X[1] to its query, key and value vectors:
Wk = key weights, shape = no_embedding_dims x output_dims (e.g. 2)
Wv = value weights, shape = same as Wk
Wq = query weights, shape = same as Wk

k = key
k[1] = X[1] @ Wk, shape = 1 x output_dims
v = value
v[1] = X[1] @ Wv, shape = same as k[1]
q = query
q[1] = X[1] @ Wq, shape = same as k[1]

In [10]:
import torch

# Toy embeddings for the six tokens of the example sentence
# "Your journey starts with one step": one row per token, three values each.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

In [37]:
# Use the second row of `inputs` (the token "journey", x^2) as the running example.
x2 = inputs[1]
# Bare expression: the notebook displays the shape of the selected embedding.
x2.shape

torch.Size([3])

In [38]:
# Fix the RNG seed so the randomly drawn projection matrices are reproducible.
torch.manual_seed(123)

# Draw the query, key and value projection matrices, in that order (the order
# decides which random numbers each matrix receives). Each matrix maps an
# embedding of size x2.shape[0] to 2 output dimensions; requires_grad=False
# because these weights are only used for a manual walk-through.
weight_query, weight_key, weight_value = (
    torch.nn.Parameter(torch.rand(x2.shape[0], 2), requires_grad=False)
    for _ in range(3)
)

weight_key

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])

In [39]:
# Project the second token's embedding into its query, key and value vectors.
query2 = torch.matmul(x2, weight_query)
key2 = torch.matmul(x2, weight_key)
value2 = torch.matmul(x2, weight_value)

query2

tensor([0.4306, 1.4551])

Do the same for all input tokens at once:

In [40]:
# Project every token in a single matrix product: row i of each result is
# the key (respectively value) vector of token i.
keys = torch.matmul(inputs, weight_key)
values = torch.matmul(inputs, weight_value)

## attention score 

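attention_scores_22 = dot product between the query vector of token 2 and the key vector of token 2. In general, the i-th entry of attention_scores_2 is the dot product of token 2's query with the key of the i-th token.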

In [41]:
# Unnormalised attention score of token 2 with respect to itself:
# dot product of its query with its own key (row 1 of `keys`).
attention_scores_22 = torch.dot(query2, keys[1])
attention_scores_22

tensor(1.8524)

In [44]:
# Scores of token 2's query against the keys of all tokens in one product:
# entry i is the dot product of query2 with keys[i].
attention_scores_2 = torch.matmul(query2, keys.t())
attention_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

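The attention scores are turned into attention weights by scaling them by the square root of the key dimension and applying softmax. The context vector is then the weighted sum of the value vectors, using these attention weights.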

# SelfAttentionV1

In [45]:
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    of shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the query, key and value vectors (and of each
            returned context vector).
    """

    def __init__(self, d_in: int, d_out: int) -> None:
        super().__init__()

        # The creation order (query, key, value) decides which random numbers
        # each matrix receives under a fixed seed, so it is kept stable.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with
                leading batch dimensions ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two dimensions. `.T` reverses *all* dimensions
        # (and is deprecated for non-2-D tensors), which breaks batched input.
        attention_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_out) before the softmax, then normalise each
        # query's scores over the tokens it attends to (last dimension).
        scale = keys.shape[-1] ** 0.5
        attention_weights = torch.softmax(attention_scores / scale, dim=-1)

        # Weighted sum of the value vectors.
        context_vectors = attention_weights @ values
        return context_vectors

In [47]:
# Seed before constructing the module: its weights are drawn with torch.rand,
# so the seed fixes both the weights and the output shown below.
torch.manual_seed(123)

# d_in=3 (embedding size of `inputs`), d_out=2.
selfAttention = SelfAttentionV1(3, 2)
# Bare expression: the notebook displays the context vectors, one row per token.
selfAttention(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

# SelfAttentionV2

In [50]:
class SelfAttentionV2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    The query, key and value projections are ``nn.Linear`` layers, which
    bring their own weight initialisation and an optional bias term.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the query, key and value vectors (and of each
            returned context vector).
        qkv_bias: Whether the three projection layers include a bias.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False) -> None:
        super().__init__()

        # The creation order (query, key, value) decides which random numbers
        # each layer receives under a fixed seed, so it is kept stable.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with
                leading batch dimensions ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        values = self.W_value(x)
        queries = self.W_query(x)

        # Swap only the last two dimensions. `.T` reverses *all* dimensions
        # (and is deprecated for non-2-D tensors), which breaks batched input.
        attention_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_out) before the softmax, then normalise each
        # query's scores over the tokens it attends to (last dimension).
        scale = keys.shape[-1] ** 0.5
        attention_weights = torch.softmax(attention_scores / scale, dim=-1)

        # Weighted sum of the value vectors.
        context_vectors = attention_weights @ values
        return context_vectors

In [51]:
# Seed before constructing the module so the nn.Linear layers' initial
# weights, and therefore the output shown below, are reproducible.
torch.manual_seed(789)
# d_in=3 (embedding size of `inputs`), d_out=2; qkv_bias keeps its default (False).
selfAttention = SelfAttentionV2(3, 2)
# Bare expression: the notebook displays the context vectors, one row per token.
selfAttention(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)