In [18]:
import torch
import torch.nn as nn

In [19]:
# Toy input sequence: six rows (tokens), each a 3-dimensional embedding vector.
data = [
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
]
inputs = torch.tensor(data)  # 6 x 3: one row per token

## A simple self-attention mechanism without trainable weights

In [20]:
# Unnormalized attention scores: dot product of every token with every token
# (no pre-allocation needed -- the matmul creates a correctly sized tensor for
# any sequence length).
attn_scores = inputs @ inputs.T
# Row-wise softmax turns each token's scores into weights that sum to 1.
attn_weights = torch.softmax(attn_scores, dim=-1)
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(attn_weights.shape[0]))
# Each context vector is the attention-weighted sum of all input vectors.
context_vectors = attn_weights @ inputs

print(f"Context vector for all input tokens:\n {context_vectors}")

Context vector for all input tokens:
 tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## Implementing self-attention with trainable weights

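Initialize three trainable weight matrices, `W_query`, `W_key`, and `W_value`, which project each `d_in`-dimensional input vector to a `d_out`-dimensional query, key, and value vector. Unlike the version above, the attention scores are divided by $\sqrt{d_{out}}$ before the softmax so that large dot products do not push the softmax into regions with vanishing gradients.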

In [21]:
# Embedding size of each input token, and the size of the projected
# query/key/value vectors.
d_in, d_out = inputs.size(1), 2

In [31]:
class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention with raw ``nn.Parameter`` projections.

    The query/key/value projection matrices have shape ``(d_in, d_out)`` and
    are initialised with ``torch.rand`` (uniform on [0, 1)).

    Args:
        d_in: size of the last dimension of the input ``x``.
        d_out: size of the projected query/key/value vectors and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Created in query -> key -> value order; changing the order would
        # change which random numbers each matrix receives under a fixed seed.
        self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        self.w_key = nn.Parameter(torch.rand(d_in, d_out))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return one context vector per token of ``x``.

        Args:
            x: tensor of shape ``(..., num_tokens, d_in)``; leading batch
                dimensions are supported.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = x @ self.w_query
        keys = x @ self.w_key
        values = x @ self.w_value
        # Transpose only the last two dims: `keys.T` reverses *all* dims,
        # which is wrong (and deprecated) for batched 3-D inputs.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) so the softmax does not saturate as d_out grows.
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        return attention_weights @ values

class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention using ``nn.Linear`` projections.

    ``nn.Linear`` supplies its own weight initialisation and, when
    ``qkv_bias`` is true, a learnable bias for each projection.

    Args:
        d_in: size of the last dimension of the input ``x``.
        d_out: size of the projected query/key/value vectors and of the output.
        qkv_bias: whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Created in query -> key -> value order; changing the order would
        # change which random numbers each layer receives under a fixed seed.
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return one context vector per token of ``x``.

        Args:
            x: tensor of shape ``(..., num_tokens, d_in)``; leading batch
                dimensions are supported.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = self.w_query(x)
        keys = self.w_key(x)
        values = self.w_value(x)
        # Transpose only the last two dims: `keys.T` reverses *all* dims,
        # which is wrong (and deprecated) for batched 3-D inputs.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) so the softmax does not saturate as d_out grows.
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        return attention_weights @ values

In [40]:
# Each version is instantiated right after its own seed so the random weight
# initialisation (and therefore the printed output) is reproducible.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
# Call the module itself rather than `.forward()`: `nn.Module.__call__`
# runs any registered hooks before/after dispatching to `forward`.
print(f"Forward pass on inputs in version 1:\n {sa_v1(inputs)}")

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(f"Forward pass on inputs in version 2:\n {sa_v2(inputs)}")

Forward pass on inputs in version 1:
 tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
Forward pass on inputs in version 2:
 tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
