In [1]:
# Multi-head attention is essentially several instances of causal attention running independently, each with its own weights for queries, keys and values.
# We therefore get several sets of context vectors, which we concatenate at the end: the rows stay the same (n input tokens give n output context vectors), but the number of columns increases.

# This is computationally expensive, but the performance gain is significant. Every head's separate weight matrix needs its own multiplication, so the cost increases linearly with the number of heads. Instead, we can create 3 big weight matrices at the start (one each for queries, keys and values), do just 1 multiplication per matrix, and split the results into per-head parts.

from nltk import word_tokenize
import gensim.downloader as api
import torch

# Fetch the pretrained "glove-twitter-25" word vectors through gensim's downloader;
# they are indexed by token further down to build the input embeddings.
glove_embeddings = api.load(name="glove-twitter-25")

# Use the GPU when one is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Make the selected device the default for newly created tensors and modules.
# Passing `device` (instead of a hard-coded 'cuda') keeps the notebook runnable
# on machines without CUDA.
torch.set_default_device(device)
device

device(type='cuda')

In [2]:
sentence = "Your journey starts with one step Your journey starts with one step Your journey starts with one step starts with"

# Lower-case the text before tokenizing, then look every token up in the GloVe table.
lowercased = sentence.lower()
tokens = word_tokenize(lowercased)
token_vectors = glove_embeddings[tokens]
encoded = torch.tensor(token_vectors)

# One row per token (call the number of tokens n), one column per embedding dimension.
encoded.shape

torch.Size([20, 25])

In [3]:

class MultiHeadAttention(torch.nn.Module):
    """Causal multi-head self-attention applied to fixed-size windows of tokens.

    The input sequence is cut into consecutive windows of ``context_length``
    tokens. Queries, keys and values for all heads are produced by one linear
    layer each (``embed_dim -> output_dim * num_heads``) and then split into
    heads, so every head attends independently within each window.

    Args:
        embed_dim: Size of each input token embedding.
        context_length: Number of tokens per attention window.
        output_dim: Size of the context vector produced by a single head.
        num_heads: Number of attention heads.
        batch_size: Stored on the instance for interface compatibility only;
            ``forward`` derives the number of windows from its input.
    """

    def __init__(self, embed_dim, context_length, output_dim, num_heads, batch_size):
        super().__init__()

        self.output_dim = output_dim
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.batch_size = batch_size  # not used by forward()
        self.context_length = context_length

        # Trainable projections; each maps embed_dim -> output_dim * num_heads so
        # that all heads are computed with a single matrix multiplication.
        self.w_queries = torch.nn.Linear(embed_dim, output_dim * num_heads, bias=False)
        self.w_keys = torch.nn.Linear(embed_dim, output_dim * num_heads, bias=False)
        self.w_values = torch.nn.Linear(embed_dim, output_dim * num_heads, bias=False)

    def _split_heads(self, projected):
        """Reshape (batches, context, heads*output_dim) to (batches, heads, context, output_dim)."""
        batches = projected.shape[0]
        # The last axis holds the heads side by side, so it must be split into
        # (heads, output_dim) first; only then can the head axis be moved in
        # front of the token axis. Viewing directly as (batches, heads, context,
        # output_dim) would mix values from different tokens and heads.
        per_head = projected.view(batches, self.context_length, self.num_heads, self.output_dim)
        return per_head.transpose(1, 2)

    def forward(self, embeddings):
        """Compute causal multi-head context vectors.

        Args:
            embeddings: Tensor of shape (n, embed_dim). Trailing tokens that do
                not fill a complete window of ``context_length`` are dropped.

        Returns:
            Tensor of shape (batches * context_length, num_heads * output_dim),
            where ``batches = n // context_length``. Each row is the
            concatenation of the per-head context vectors for one token.
        """
        batches = embeddings.shape[0] // self.context_length
        usable = embeddings[: batches * self.context_length]  # whole windows only
        windows = usable.reshape(batches, self.context_length, self.embed_dim)

        # batches x heads x context x output_dim
        queries = self._split_heads(self.w_queries(windows))
        keys = self._split_heads(self.w_keys(windows))
        values = self._split_heads(self.w_values(windows))

        # batches x heads x context x context; entry [.., i, j] scores query i against key j
        attention_scores = queries @ keys.transpose(-2, -1)

        # Causal mask: True strictly above the diagonal, i.e. where key j lies
        # in the future of query i. It broadcasts over batches and heads.
        causal_mask = torch.ones(
            self.context_length, self.context_length, device=attention_scores.device
        ).triu(diagonal=1).bool()
        # -inf scores become zero probability after the softmax.
        attention_scores = attention_scores.masked_fill(causal_mask, -torch.inf)

        # Normalise over the key axis (last dim) so each query's weights sum to 1.
        attention_weights = torch.softmax(attention_scores / self.output_dim**0.5, dim=-1)

        # batches x heads x context x output_dim -> batches x context x heads x output_dim
        context_vectors = (attention_weights @ values).transpose(1, 2)
        # Merge the heads and flatten the windows back into one row per token.
        return context_vectors.contiguous().view(
            batches * self.context_length, self.num_heads * self.output_dim
        )

# Demo hyper-parameters: tokens per attention window, per-head output size, head count.
context_len = 4
out_dim = 3
heads = 2

# The embedding size is taken from the last axis of the encoded sentence.
# batch_size is required by the constructor but is not used by forward().
attention_head = MultiHeadAttention(encoded.shape[-1], context_len, out_dim, heads, batch_size=6)
# Call the module itself rather than .forward() so that nn.Module's call
# machinery (including any registered hooks) is run.
context = attention_head(encoded)
encoded.shape, context.shape

(torch.Size([20, 25]), torch.Size([20, 6]))