In [3]:
import torch

# Six 3-dimensional row vectors; each row is labelled with one word of the
# sentence "Your journey starts with one step" and its index x^1 .. x^6.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: Your
    [0.55, 0.87, 0.66],  # x^2: journey
    [0.57, 0.85, 0.64],  # x^3: starts
    [0.22, 0.58, 0.33],  # x^4: with
    [0.77, 0.25, 0.10],  # x^5: one
    [0.05, 0.80, 0.55],  # x^6: step
])

In [11]:
# Use the second row of `inputs` (index 1) as the query vector.
query = inputs[1]
# Score every row of `inputs` against the query with an explicit dot product
# and collect the scalar results into a single 1-D tensor.
attn_score_2 = torch.stack([torch.dot(query, x_n) for x_n in inputs])

print(attn_score_2)
# Softmax over the only axis turns the raw scores into weights that sum to 1.
print(torch.softmax(attn_score_2, dim=0))

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [15]:
# Dot product of every row of `inputs` with every other row, normalised with a
# softmax along the last axis so that each row of weights sums to 1.
attn_weight = torch.softmax(torch.matmul(inputs, inputs.T), dim=-1)
# Row i of the result is the sum of the rows of `inputs`, weighted by row i of
# `attn_weight`.
all_context_vectors = torch.matmul(attn_weight, inputs)
print(all_context_vectors)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [20]:
class SelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected by three independent linear layers (query, key and
    value). Every query is scored against every key, the scores are divided
    by the square root of the key dimension, normalised with a softmax over
    the key axis, and used to take a weighted sum of the values.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the last dimension of the projected
            query/key/value tensors, and therefore of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Creation order (query, key, value) is kept as-is so that a fixed
        # RNG seed yields the same initial weights as before.
        self.query = torch.nn.Linear(input_dim, output_dim)
        self.key = torch.nn.Linear(input_dim, output_dim)
        self.value = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Compute one context vector per input position.

        Args:
            x: Tensor of shape ``(seq_len, input_dim)``, optionally with
                leading batch dimensions: ``(..., seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(..., seq_len, output_dim)``.
        """
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        # Swap only the last two axes. `key.T` reverses *all* axes, which is
        # equivalent for 2-D input but is deprecated and produces mismatched
        # shapes as soon as a batch dimension is present.
        attn_score = query @ key.transpose(-2, -1)
        # Divide by sqrt(key dimension) before the softmax over the key axis.
        attn_weights = torch.softmax(attn_score / key.shape[-1] ** 0.5, dim=-1)
        context_vector = attn_weights @ value
        return context_vector

In [21]:
# Seed the RNG before constructing the module so that its randomly
# initialised layer weights are reproducible.
torch.manual_seed(789)
# d_in matches the 3 columns of `inputs`; d_out is the projection size.
d_in, d_out = 3, 2
attention = SelfAttention(input_dim=d_in, output_dim=d_out)

attention(inputs)

tensor([[-0.3427, -0.2720],
        [-0.3439, -0.2692],
        [-0.3438, -0.2693],
        [-0.3417, -0.2728],
        [-0.3412, -0.2736],
        [-0.3426, -0.2714]], grad_fn=<MmBackward0>)