In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F

<h2> Word Embeddings </h2>

<h2> Tokenization </h2>

In [73]:
class Tokenizer:
    """Word-level tokenizer mapping whitespace-separated words to integer ids.

    Call ``generate_vocab`` on a corpus first, then ``tokenize`` to turn texts
    into fixed-length lists of ids.

    Attributes:
        vocab: sorted list of the unique preprocessed words seen by
            ``generate_vocab``; a word's id is its index in this list.
        max_len: word count of the longest preprocessed text in the corpus;
            every sequence returned by ``tokenize`` has exactly this length.
        vocab_size: ``len(vocab) + 1``; the extra id ``vocab_size - 1`` is
            used as the padding token.
    """

    def __init__(self):

        self.vocab = None
        self.max_len = None
        self.vocab_size = None

    def preprocess(self, text):
        """Return ``text`` with only alphanumeric/whitespace characters kept, lowercased."""

        # Drop punctuation and symbols; digits are kept because isalnum() accepts them.
        text = ''.join(ch for ch in text if ch.isalnum() or ch.isspace())

        # Lowercase after filtering (same order as before) so casing does not split words.
        return text.lower()

    def generate_vocab(self, texts):
        """Build ``vocab``, ``max_len`` and ``vocab_size`` from an iterable of strings."""

        preprocessed = [self.preprocess(text) for text in texts]

        # default=0 so an empty corpus yields an empty vocabulary instead of max() raising.
        self.max_len = max((len(seq.split()) for seq in preprocessed), default=0)

        words = " ".join(preprocessed).split()

        # Sorted for a deterministic word -> id assignment across runs.
        self.vocab = sorted(set(words))

        # +1 reserves the last id for padding.
        self.vocab_size = len(self.vocab) + 1

    def tokenize(self, texts):
        """Convert each text to a list of exactly ``max_len`` word ids.

        Shorter texts are right-padded with ``vocab_size - 1``; longer texts
        are truncated to ``max_len`` so every row has the same length.

        Raises:
            RuntimeError: if ``generate_vocab`` has not been called.
            ValueError: if a text contains a word not in ``vocab``.
        """

        if self.vocab is None:
            raise RuntimeError("generate_vocab must be called before tokenize")

        # O(1) lookups instead of list.index(), which scans the vocabulary per word.
        word_to_id = {word: idx for idx, word in enumerate(self.vocab)}
        pad_id = self.vocab_size - 1

        total_tokens = []

        for text in texts:

            words = self.preprocess(text).split()

            try:
                tokens = [word_to_id[word] for word in words]
            except KeyError as exc:
                raise ValueError(f"word {exc.args[0]!r} is not in the vocabulary") from None

            # Truncate, then pad, so the output is always exactly max_len long.
            tokens = tokens[:self.max_len]
            tokens += [pad_id] * (self.max_len - len(tokens))

            total_tokens.append(tokens)

        return total_tokens

texts = [
    "I am a student",
    "I am a teacher",
    "I am a doctor",
    "I am a programmer",
    "The quick brown fox jumps over the lazy dog"
]

tokenizer = Tokenizer()

tokenizer.generate_vocab(texts)

print(tokenizer.vocab)

tokens = tokenizer.tokenize(texts)

print(tokens)

print(tokenizer.max_len, [len(t) for t in tokens])

# Every sequence must have length max_len so the batch can be stacked into one tensor below.
assert all(len(t) == tokenizer.max_len for t in tokens), "token sequences have unequal lengths"

['a', 'am', 'brown', 'doctor', 'dog', 'fox', 'i', 'jumps', 'lazy', 'over', 'programmer', 'quick', 'student', 'teacher', 'the']
[[6, 1, 0, 12, 15, 15, 15, 15, 15], [6, 1, 0, 13, 15, 15, 15, 15, 15], [6, 1, 0, 3, 15, 15, 15, 15, 15], [6, 1, 0, 10, 15, 15, 15, 15, 15], [14, 11, 2, 5, 7, 9, 14, 8, 4]]
9 [9, 9, 9, 9, 9]


<h2> Embedding Layer </h2>

In [74]:
class EmbeddingLayer(nn.Module):
    """Thin wrapper around ``nn.Embedding`` mapping integer token ids to dense vectors.

    Args:
        vocab_size: number of distinct token ids (rows of the embedding table).
        embedding_dim: length of each embedding vector.
        padding_idx: optional id forwarded to ``nn.Embedding``; its vector
            starts at zero and receives no gradient. ``None`` (the default)
            keeps every id trainable, matching the previous behaviour.
    """

    def __init__(self, vocab_size, embedding_dim, padding_idx=None):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)

    def forward(self, x):
        """Look up ids ``x`` (integer tensor); returns shape ``x.shape + (embedding_dim,)``."""

        return self.embedding(x)
    

embedding_dim = 10

# Seed so the randomly initialised embedding weights (and all downstream outputs) are reproducible.
torch.manual_seed(0)

print(tokenizer.vocab_size)
print(tokens)

embedding_layer = EmbeddingLayer(tokenizer.vocab_size, embedding_dim)

# as_tensor accepts both the list from tokenize() and an already-converted tensor,
# so re-running this cell works; dtype=long avoids the float32 round-trip of torch.Tensor(...).
tokens = torch.as_tensor(tokens, dtype=torch.long)

embeddings = embedding_layer(tokens)

print(embeddings)

16
[[6, 1, 0, 12, 15, 15, 15, 15, 15], [6, 1, 0, 13, 15, 15, 15, 15, 15], [6, 1, 0, 3, 15, 15, 15, 15, 15], [6, 1, 0, 10, 15, 15, 15, 15, 15], [14, 11, 2, 5, 7, 9, 14, 8, 4]]
tensor([[[ 0.3276,  0.2250,  0.4126,  0.3643,  0.1932,  1.5503,  0.0839,
           0.4204,  0.6827, -0.4537],
         [ 0.0082, -0.6683,  1.0538,  1.8405,  0.2528, -1.9614,  0.0260,
          -1.1189, -0.6230, -0.2543],
         [-0.5169,  0.2792, -1.3118,  0.3313, -0.7539, -0.0291,  1.9711,
          -0.2230,  0.0837, -1.6425],
         [ 0.6643, -1.3734, -0.9785, -1.9924, -0.5671,  1.0187,  0.4730,
           0.8144, -1.4467,  0.5170],
         [ 1.0626, -0.1594, -0.7191,  0.0479,  2.0683,  2.2863,  0.9569,
          -0.9163, -0.3195, -1.4508],
         [ 1.0626, -0.1594, -0.7191,  0.0479,  2.0683,  2.2863,  0.9569,
          -0.9163, -0.3195, -1.4508],
         [ 1.0626, -0.1594, -0.7191,  0.0479,  2.0683,  2.2863,  0.9569,
          -0.9163, -0.3195, -1.4508],
         [ 1.0626, -0.1594, -0.7191,  0.0479,  2

<h2> The Attention Layer </h2>

In [75]:
class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    A single linear layer projects the input to queries, keys and values,
    then the block computes ``softmax(Q K^T / sqrt(d_model)) V``.

    Args:
        d_model: feature size of the input; also the size of Q, K, V and of the output.
    """

    def __init__(self, d_model):
        super().__init__()

        # One fused projection whose output is split into Q, K and V in forward().
        self.linear = nn.Linear(d_model, 3*d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor whose last dimension is ``d_model``, e.g. ``(batch, seq, d_model)``
                or unbatched ``(seq, d_model)``.
            mask: optional boolean tensor broadcastable to ``(..., seq, seq)``;
                ``True`` marks key positions a query may attend to. A query row
                with no ``True`` entry produces NaN.

        Returns:
            Tensor with the same shape as ``x``.
        """

        q, k, v = self.linear(x).chunk(3, dim=-1)

        # Scale by 1/sqrt(d) so the logits' magnitude doesn't grow with d_model and saturate softmax.
        # transpose(-2, -1) works for any number of leading (batch) dimensions.
        scores = (q @ k.transpose(-2, -1)) * q.size(-1) ** -0.5

        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))

        weights = F.softmax(scores, dim=-1)

        return weights @ v
    

# Run self-attention over the embedded batch; the output keeps the input's shape.
attention_layer = Attention(d_model=embedding_dim)

# NOTE: despite its name, attn_logits holds the attention outputs (softmax weights @ V), not pre-softmax logits.
attn_logits = attention_layer(embeddings)

(embeddings.shape, attn_logits.shape)

(torch.Size([5, 9, 10]), torch.Size([5, 9, 10]))