In [13]:
import torch
import torch.nn as nn
def softmax_native(input):
    max_val = torch.max(input)
    val = input - max_val
    val = torch.exp(val)
    val = val/torch.sum(val)
    return val

class Simple_Attention(nn.Module):
    def __init__(self, d_x, d_y, qkv_bias=False):
        super().__init__()
        self.q_w = nn.Linear(d_x, d_y, bias=qkv_bias)
        self.k_w = nn.Linear(d_x, d_y, bias=qkv_bias)
        self.v_w = nn.Linear(d_x, d_y, bias=qkv_bias)
    
    def forward(self, token_input):
        q_w = self.q_w(token_input)
        k_w = self.k_w(token_input)
        v_w = self.v_w(token_input)
        attention_weight = q_w @ k_w.T
        attention_score = torch.softmax(attention_weight/k_w.shape[-1]**2, dim=-1)
        context_vectors = attention_score @ v_w
        return context_vectors



In [14]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)
d_x = inputs.shape[1]
d_y = 2

In [19]:
torch.manual_seed(789)
simple_attention = Simple_Attention(d_x, d_y)
context_vectors = simple_attention(inputs)
print(context_vectors)

tensor([[-0.0760,  0.0682],
        [-0.0763,  0.0679],
        [-0.0763,  0.0679],
        [-0.0767,  0.0672],
        [-0.0769,  0.0670],
        [-0.0765,  0.0675]], grad_fn=<MmBackward0>)
