<a href="https://colab.research.google.com/github/NoCodeProgram/deepLearning/blob/main/transformer/multiheadAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/NoCodeProgram/deepLearning.git

Cloning into 'deepLearning'...
remote: Enumerating objects: 254, done.[K
remote: Counting objects: 100% (115/115), done.[K
remote: Compressing objects: 100% (113/113), done.[K
remote: Total 254 (delta 42), reused 0 (delta 0), pack-reused 139[K
Receiving objects: 100% (254/254), 12.36 MiB | 20.38 MiB/s, done.
Resolving deltas: 100% (78/78), done.


In [2]:
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Read the text file
with open('deepLearning/transformer/shakespeare.txt', 'r') as file:
    text = file.read()

# Tokenize the text  (this is very simple tokenizer, in reality you would use a more advanced one)
tokenizer = get_tokenizer('basic_english')
tokens = tokenizer(text)
unique_tokens = set(tokens)


In [3]:
stoi = { s:i for i,s in enumerate(unique_tokens)}
itos = { i:s for i,s in enumerate(unique_tokens)}
print(stoi)
print(itos)

vocab_size = len(unique_tokens)
print(vocab_size)

3129


In [4]:
sentence = "i love you all"
indices = [stoi[word] for word in sentence.split()]
print(indices)

import torch.nn as nn

embedding_dim = 20
embedding = nn.Embedding(vocab_size, embedding_dim)

embedded_sentence = embedding(torch.tensor(indices))
print(embedded_sentence)


[1400, 658, 321, 2661]
tensor([[ 0.0925,  0.0993, -0.2679, -0.2897,  0.6791, -1.3325, -1.4117, -0.9931,
         -0.2453, -0.3506,  1.8642,  0.6862, -0.7704, -0.1396, -1.7060,  1.4248,
          1.4760, -0.6616, -0.3940,  1.2568],
        [ 0.4818, -1.4555,  0.0211, -1.7342,  0.1788, -0.0343, -0.7334,  0.3300,
         -0.7932, -1.3668,  0.3013,  0.5729, -0.9761,  0.6527,  0.1741, -0.1306,
         -0.4450,  2.0865, -0.5614, -1.0796],
        [-0.2783, -0.8294, -0.0323, -0.1246,  0.6529,  0.6033, -0.1477, -0.2540,
          1.6059,  1.4985, -0.4159,  0.5053, -0.4023, -0.3188, -0.3874, -1.4405,
         -1.8073, -0.6821,  0.0633, -0.1453],
        [-0.0342, -0.1321,  0.2908,  1.5125,  1.0318, -0.5913,  0.1666, -0.5430,
         -0.5951, -0.6156,  0.9003, -0.0048,  1.9140, -0.8457, -2.3570, -0.3435,
          2.4975, -1.0575,  0.2997, -0.4541]], grad_fn=<EmbeddingBackward0>)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, atten_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, atten_dim, bias=False)
        self.key = nn.Linear(embed_dim, atten_dim, bias=False)
        self.value = nn.Linear(embed_dim, atten_dim, bias=False)

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        scores = torch.matmul(query, key.transpose(-2, -1))
        scores = scores / key.size(-1)**0.5

        attention_weights = F.softmax(scores, dim=-1)
        weighted_values = torch.matmul(attention_weights, value)

        return weighted_values

In [6]:
class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        attention_dim = embed_dim // num_heads
        self.attentions = nn.ModuleList([SelfAttention(embed_dim, attention_dim) for _ in range(num_heads)])
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        head_outputs = []
        for attention in self.attentions:
            head_output = attention(x)
            head_outputs.append(head_output)

        concatenated_heads = torch.cat(head_outputs, dim=-1)
        print("concatenated_heads", concatenated_heads.shape)
        output = self.fc(concatenated_heads)
        print("output", output.shape)
        return output


In [7]:
num_heads = 4

output = MultiheadAttention(embedding_dim, num_heads)(embedded_sentence)
print("output shape", output.shape)


concatenated_heads torch.Size([4, 20])
output torch.Size([4, 20])
output shape torch.Size([4, 20])
