In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
# Single-head causal self-attention on one short sentence, step by step.
# Fix the RNG seed so the randomly initialised layer weights are reproducible.
torch.manual_seed(1234)

# data
text = "Hi! My name is Matheus."
# Token ids for `text`. There are 9 ids; presumably " Matheus" is split into
# several sub-word ids (7011, 383, 355) -- TODO confirm against the tokenizer used.
tokens = [13347, 0, 3092, 836, 374, 7011, 383, 355, 13]

# parameters
vocab_size = 1 + max(tokens)  # embedding table must be able to index the largest id
emb_dim = 5                   # length of the vector that represents each token
context = len(tokens)         # number of positions the mask is built for

# layers (bias-free projections, all emb_dim -> emb_dim)
# The order of construction determines which seeded random draws each layer gets.
embedding = nn.Embedding(vocab_size, emb_dim)
query, key, value = (nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(3))

# causal mask: lower-triangular ones, so position i can only see positions <= i
ones = torch.ones(context, context, dtype=torch.float)
mask = ones.tril()

# forward pass
t_tokens = torch.tensor(tokens).unsqueeze(0)  # [9] -> [1,9] (add batch axis)
x = embedding(t_tokens)                       # [1,9] -> [1,9,5] embedding vectors

B, T, C = x.shape
Q, K, V = query(x), key(x), value(x)  # each [1,9,5] -> [1,9,5]

# scaled dot-product scores: [1,9,5] @ [1,5,9] -> [1,9,9], scaled by 1/sqrt(C)
QK = torch.matmul(Q, K.transpose(-2, -1)) * C ** -0.5
# positions where the mask is 0 (the "future") get -inf, i.e. zero weight after softmax
attention = QK.masked_fill(mask[:T, :T] == 0, float("-inf"))
# softmax over the last axis (key positions): every query row sums to 1
attention = F.softmax(attention, dim=-1)

out = torch.matmul(attention, V)  # [1,9,9] @ [1,9,5] -> [1,9,5]

print(out.shape)  # shape of the new, context-aware representation

torch.Size([1, 9, 5])
