In [2]:
import math
import torch
import torch.nn as nn

In [3]:
from torchtext.data import get_tokenizer

# Tokenise one example sentence; the displayed result shows the tokens
# lower-cased with the final period split off as its own token.
sentence = "The animal did not cross the street because it was tired."
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(sentence)
tokens

['the',
 'animal',
 'did',
 'not',
 'cross',
 'the',
 'street',
 'because',
 'it',
 'was',
 'tired',
 '.']

In [4]:
from torchtext.vocab import build_vocab_from_iterator

# The builder consumes an iterable of token lists; this corpus holds exactly
# one sentence.
corpus = [tokens]
vocab = build_vocab_from_iterator(corpus)

1lines [00:00, 3960.63lines/s]


In [5]:
# Show the integer id the vocabulary assigns to each token, in sentence order.
for word in tokens:
    word_id = vocab[word]
    print(f"{word} : {word_id}")

the : 2
animal : 4
did : 7
not : 9
cross : 6
the : 2
street : 10
because : 5
it : 8
was : 12
tired : 11
. : 3


In [6]:
# Map every token to its vocabulary id and pack the ids into a 1-D int64
# tensor in a single step (the name is never bound to a plain list).
token_ids = torch.tensor([vocab[word] for word in tokens], dtype=torch.long)
token_ids

tensor([ 2,  4,  7,  9,  6,  2, 10,  5,  8, 12, 11,  3])

In [7]:
# Seed the global RNG before the first random draw so the embedding table
# (and every random tensor created by later cells) is identical on each
# Restart & Run All.
torch.manual_seed(0)

vocab_size = len(vocab)
embed_dim = 64

# One embed_dim-wide vector per vocabulary entry; indexing with token_ids
# yields one row per token of the sentence.
embedding = nn.Embedding(vocab_size, embed_dim)
X = embedding(token_ids)

In [8]:
# Sanity check: one embed_dim-wide row per token.
assert X.shape == (len(tokens), embed_dim)
X.shape

torch.Size([12, 64])

In [None]:
def positional_encoding(seq_len, d_model):
    """Build the sinusoidal positional-encoding table.

    For position ``pos`` and even column index ``i`` the table holds
    ``sin(pos * 10000 ** (-i / d_model))`` in column ``i`` and the cosine of
    the same angle in column ``i + 1``.

    Args:
        seq_len: Number of positions (rows) to encode.
        d_model: Width of each encoding (columns). May be odd, in which case
            the last sine column simply has no cosine partner.

    Returns:
        A float tensor of shape ``(seq_len, d_model)``.
    """
    PE = torch.zeros(seq_len, d_model)

    # Column vector of positions so it broadcasts against the frequency row.
    position = torch.arange(0, seq_len).unsqueeze(1)

    # One frequency per even column index: exp(-i * ln(10000) / d_model).
    div_term = torch.exp(
        torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
    )

    # Shape (seq_len, ceil(d_model / 2)): one angle per position and frequency.
    angles = position * div_term

    PE[:, 0::2] = torch.sin(angles)
    # There are only floor(d_model / 2) odd columns, so trim the angles to
    # match; for even d_model the slice is a no-op.
    PE[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    return PE

In [10]:
# Rebuild X from the embedding lookup instead of updating it in place, so
# re-running this cell can never add the positional encoding twice. The width
# comes from embed_dim rather than a repeated literal.
X = embedding(token_ids) + positional_encoding(len(tokens), embed_dim)
# For batched input shaped (batch_size, seq_len, d_model), append
# .unsqueeze(0) to the encoding to give it an explicit batch axis.

In [11]:
# Adding the positional encoding must leave the shape unchanged.
assert X.shape == (len(tokens), embed_dim)
X.shape

torch.Size([12, 64])

In [12]:
# Multi-head attention hyper-parameters.
num_heads = 8

# The model width must equal the embedding width, so derive it from embed_dim
# instead of repeating the literal. It is kept as a 0-dim long tensor because
# later cells call torch.sqrt on d_model and d_k directly.
d_model = torch.tensor(embed_dim, dtype=torch.long)

# Each head receives an equal slice of the model width.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
d_k = d_v = d_model // num_heads

seq_len = X.shape[0]

In [13]:
# Draw the query, key, value and output projection matrices, in that order.
# Every entry is a standard-normal sample divided by sqrt(d_model).
init_scale = torch.sqrt(d_model)
W_Q, W_K, W_V, W_O = (
    torch.randn(d_model, d_model) / init_scale for _ in range(4)
)

In [14]:
# Each projection matrix is square: d_model x d_model.
assert W_Q.shape == (int(d_model), int(d_model))
W_Q.shape

torch.Size([64, 64])

In [15]:
# Project the position-aware embeddings into query, key and value spaces.
Q, K, V = (X @ weight for weight in (W_Q, W_K, W_V))

In [16]:
# The projection keeps one d_model-wide row per token.
assert Q.shape == (seq_len, int(d_model))
Q.shape

torch.Size([12, 64])

In [17]:
def _split_heads(projected, n_heads):
    """Split a (seq_len, d_model) projection into per-head slices.

    Args:
        projected: 2-D tensor of shape (seq_len, d_model).
        n_heads: Number of attention heads; the width is divided evenly
            between them.

    Returns:
        Tensor of shape (n_heads, seq_len, d_model // n_heads).

    Raises:
        ValueError: If ``projected`` is not 2-D. This happens when the cell is
            re-run on tensors that were already split; reshaping those again
            would silently scramble the data.
    """
    if projected.dim() != 2:
        raise ValueError(
            "expected a 2-D (seq_len, d_model) tensor; "
            "re-run the Q/K/V projection cell before splitting heads again"
        )
    rows, width = projected.shape
    # Carve the feature axis into n_heads chunks, then move the head axis first.
    return projected.reshape(rows, n_heads, width // n_heads).transpose(1, 0)


Q = _split_heads(Q, num_heads)
K = _split_heads(K, num_heads)
V = _split_heads(V, num_heads)

In [18]:
# After the split the layout is (num_heads, seq_len, d_k).
assert Q.shape == (num_heads, seq_len, int(d_k))
Q.shape

torch.Size([8, 12, 8])


In [19]:
# Softmax over the last axis: applied to a score matrix, every row is
# normalised to sum to 1.
softmax = nn.Softmax(dim=-1)

In [20]:
# Scaled dot-product attention, evaluated independently for every head.
heads, attention_weights = [], []

# The scaling factor sqrt(d_k) does not depend on the head, so compute it once.
score_scale = torch.sqrt(d_k)

for h in range(num_heads):
    q_h, k_h, v_h = Q[h], K[h], V[h]
    # Row i holds how strongly token i attends to every token.
    weights = softmax(q_h @ k_h.T / score_scale)
    attention_weights.append(weights)
    # Weighted sum of the value vectors for this head.
    heads.append(weights @ v_h)


In [45]:
# Inspect the container and its first element: container type, head count,
# element type and the per-head weight-matrix shape, one per line.
first_head_weights = attention_weights[0]
print(
    type(attention_weights),
    len(attention_weights),
    type(first_head_weights),
    first_head_weights.shape,
    sep="\n",
)


<class 'list'>
8
<class 'torch.Tensor'>
torch.Size([12, 12])


In [25]:
# Every head produces one d_v-wide row per token.
assert heads[0].shape == (seq_len, int(d_v))
heads[0].shape

torch.Size([12, 8])

In [26]:
# Join the per-head outputs side by side along the feature axis, using the
# canonical torch.cat / dim spelling.
concat = torch.cat(heads, dim=1)

In [27]:
# Concatenating all heads restores the full model width.
assert concat.shape == (seq_len, int(d_model))
concat.shape

torch.Size([12, 64])

In [28]:
# Final linear projection that mixes the concatenated head outputs.
output = concat @ W_O

In [29]:
# The attention block maps (seq_len, d_model) back to (seq_len, d_model).
assert output.shape == (seq_len, int(d_model))
output.shape

torch.Size([12, 64])