In [1]:
import torch
# Toy 3-dimensional embeddings for the six tokens of "Your journey starts with one step".
# Row i of `inputs` is the embedding x^(i+1).
token_embeddings = [
    ("Your",    [0.43, 0.15, 0.89]),  # x^1
    ("journey", [0.55, 0.87, 0.66]),  # x^2
    ("starts",  [0.57, 0.85, 0.64]),  # x^3
    ("with",    [0.22, 0.58, 0.33]),  # x^4
    ("one",     [0.77, 0.25, 0.10]),  # x^5
    ("step",    [0.05, 0.80, 0.55]),  # x^6
]
inputs = torch.tensor([embedding for _, embedding in token_embeddings])
print(inputs.shape)  # (num_tokens, embedding_dim)

torch.Size([6, 3])


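### Query, key, and value weight matrices (W_q, W_k, W_v)

These matrices project each input embedding into its query, key, and value vectors, from which the context vectors are computed. In a real model they are trainable parameters whose weights are learned during training; in this walkthrough they are created with `requires_grad=False`, so they stay fixed.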

In [3]:
# Seed the RNG so the random projection matrices below (and every attention
# weight derived from them) are reproducible under Restart & Run All.
torch.manual_seed(123)

d_in = inputs.shape[1]  # embedding size of each input token
d_out = 2               # size of each query / key / value vector

# Query, key and value projection matrices, each of shape (d_in, d_out).
# requires_grad=False means autograd does NOT track or update these here
# (presumably to keep this walkthrough's printed outputs simple); in a
# trainable attention layer they would require gradients.
W_q = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_k = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_v = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

In [4]:
# Show the shape of the inputs and of each projection matrix.
for caption, tensor in (
    ("Input : ", inputs),
    ("Query : ", W_q),
    ("Key : ", W_k),
    ("Value : ", W_v),
):
    print(caption, tensor.shape)


Input :  torch.Size([6, 3])
Query :  torch.Size([3, 2])
Key :  torch.Size([3, 2])
Value :  torch.Size([3, 2])


In [6]:
# Project every token embedding into query, key and value space:
# (num_tokens, d_in) @ (d_in, d_out) -> (num_tokens, d_out)
Q, K, V = (torch.matmul(inputs, weight) for weight in (W_q, W_k, W_v))

print(*(projection.shape for projection in (Q, K, V)), sep="\n")

torch.Size([6, 2])
torch.Size([6, 2])
torch.Size([6, 2])


In [10]:
# Scaled dot-product attention: raw scores are divided by sqrt(d_k) before the
# softmax (the scaling used in "Attention Is All You Need").
d_k = K.shape[1]
score = torch.matmul(Q, K.T)
scaled_score = score / d_k ** 0.5
attention_weights = scaled_score.softmax(dim=-1)  # each row sums to 1
print(score.shape, attention_weights.shape)

torch.Size([6, 6]) torch.Size([6, 6])


In [11]:
# Each context vector is the attention-weighted sum of all value vectors.
context_vector = torch.matmul(attention_weights, V)

In [18]:
# 1s strictly above the diagonal mark the "future" positions a causal mask hides.
# Size the mask from the actual number of tokens instead of a hard-coded 6.
n_tokens = inputs.shape[0]
torch.triu(torch.ones(n_tokens, n_tokens), diagonal=1)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [43]:
# Causal (masked) self-attention: token i may only attend to tokens 0..i.
score = Q @ K.T
# Lower-triangular ones mark the allowed (query, key) pairs; sized from the
# actual score matrix rather than a hard-coded 6.
mask = torch.tril(torch.ones(score.shape[-2], score.shape[-1]))
# Disallowed positions get -inf, so softmax gives them exactly zero weight.
score = score.masked_fill(mask == 0, -torch.inf)
attention_weights = torch.softmax(score / K.shape[-1] ** 0.5, dim=-1)
# Sanity check: every row is still a probability distribution after masking.
assert torch.allclose(attention_weights.sum(dim=-1), torch.ones(score.shape[0]))
attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3739, 0.6261, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2279, 0.3787, 0.3934, 0.0000, 0.0000, 0.0000],
        [0.2112, 0.2811, 0.2870, 0.2207, 0.0000, 0.0000],
        [0.1642, 0.2050, 0.2084, 0.1688, 0.2537, 0.0000],
        [0.1257, 0.1841, 0.1894, 0.1325, 0.2650, 0.1033]])

In [44]:
# Duplicate the 6-token sequence to form a batch of two identical examples,
# giving shape (batch_size=2, num_tokens=6, embedding_dim=3).
batch = torch.stack([inputs] * 2)  # dim=0 is torch.stack's default
print(batch.shape)

torch.Size([2, 6, 3])


In [58]:
# Batched causal self-attention using bias-free nn.Linear projections.
# Re-seed so the Linear layers' random initialisation is reproducible.
torch.manual_seed(123)
W_q = torch.nn.Linear(d_in, d_out, bias=False)
W_k = torch.nn.Linear(d_in, d_out, bias=False)
W_v = torch.nn.Linear(d_in, d_out, bias=False)

b, num_tokens, d_in = batch.shape  # (batch_size, num_tokens, embedding_dim)

keys = W_k(batch)      # (b, num_tokens, d_out)
queries = W_q(batch)   # (b, num_tokens, d_out)
values = W_v(batch)    # (b, num_tokens, d_out)

# transpose(1, 2) swaps only the token and feature axes, keeping the batch axis first.
attn_scores = queries @ keys.transpose(1, 2)  # (b, num_tokens, num_tokens)

# The (num_tokens, num_tokens) mask broadcasts over the batch dimension.
mask = torch.tril(torch.ones(num_tokens, num_tokens))
attn_scores = attn_scores.masked_fill(mask == 0, -torch.inf)

# Complete the computation the same way as the unbatched causal cell:
# scale by sqrt(d_k), softmax over the key axis, then weight the values.
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(b, num_tokens))
context_vecs = attn_weights @ values  # (b, num_tokens, d_out)

print(attn_weights.shape)
print(context_vecs.shape)



torch.Size([2, 6, 6])
