## This notebook is an implementation of transformers from scratch for learning purposes - Work in progress 

In [3]:
import torch
from math import sqrt
from pathlib import Path
import torch.nn as nn
# Decided to use the tokenizers library for BPE tokenization to understand how tokenization works under the hood
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder

In [9]:
class Embedding(torch.nn.Module):
    """Byte-level BPE tokenizer bundled with a learnable token-embedding lookup.

    The tokenizer has to be trained with `train_tokenizer` before `embed` is
    useful, since the embedding table is sized from the tokenizer vocabulary.
    """

    def __init__(self):
        # nn.Module bookkeeping has to exist before submodules are assigned,
        # otherwise parameters()/state_dict()/to() break once this object is
        # registered inside another module.
        super().__init__()
        self.bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        self.bpe_tokenizer.pre_tokenizer = ByteLevel()
        self.bpe_tokenizer.decoder = ByteLevelDecoder()
        self.trainer = BpeTrainer(
            vocab_size=50_000,
            min_frequency=2,
            special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"],
            initial_alphabet=ByteLevel.alphabet(),
        )
        # Embedding tables are created lazily (the vocabulary size is only known
        # after training) and kept here so their weights persist across calls
        # and are registered as parameters of this module.
        self.embedding_tables = torch.nn.ModuleDict()

    def train_tokenizer(self, sequences_files: list[str]):
        """Train the BPE tokenizer on the given text files and save it to ./tokenizer.json."""
        self.bpe_tokenizer.train(sequences_files, self.trainer)
        self.bpe_tokenizer.save("./tokenizer.json")

    def encode_word(self, word: str):
        """Tokenize `word` and return the tokenizer's encoding object."""
        return self.bpe_tokenizer.encode(word)

    def embed(self, tokenized_sequence, d_model: int):
        """Look up the embedding vectors for an encoded sequence.

        Args:
            tokenized_sequence: object exposing the token ids as `.ids`
                (what `encode_word` returns).
            d_model: width of each embedding vector.

        Returns:
            Tensor of shape (len(ids), d_model).

        Note:
            The table for a given (vocab_size, d_model) pair is created on first
            use, so an optimizer built before the first call will not see its
            weights.
        """
        vocab_size = self.bpe_tokenizer.get_vocab_size()
        key = f"{vocab_size}x{d_model}"
        if key not in self.embedding_tables:
            self.embedding_tables[key] = torch.nn.Embedding(vocab_size, d_model)
        table = self.embedding_tables[key]
        # Explicit long dtype so an empty id list still yields a valid index tensor.
        ids = torch.tensor(tokenized_sequence.ids, dtype=torch.long, device=table.weight.device)
        return table(ids)
        
# Train the tokenizer on the local corpus, then embed a sample sentence.
embedding = Embedding()
embedding.train_tokenizer(["./input_text.txt"])
embedded = embedding.embed(embedding.encode_word("hello world!"), 512)


In [41]:
embedded

tensor([[ 1.2551, -0.1232, -0.4561,  ...,  0.6302, -2.2152, -0.1945],
        [-0.5921,  0.5571, -0.4458,  ..., -3.0437,  1.0927,  0.3457],
        [ 0.9480,  0.2918,  1.0361,  ..., -0.5419, -1.7925, -0.2153],
        ...,
        [ 0.2529,  0.1792,  0.6107,  ...,  0.0608,  0.0175,  0.7894],
        [ 2.2248,  0.0859,  0.3144,  ...,  1.0642,  0.9147,  0.5261],
        [-1.7264, -0.5665,  0.1764,  ..., -1.3539,  1.6781,  0.8746]],
       grad_fn=<EmbeddingBackward0>)

In [78]:
# Here we make sure to specify the dtype when creating the tensors to avoid casting to float64 and having everything go slower
# Will add max_len_seq later to compare speed
class PositionalEncoding(nn.Module):
    """Adds the fixed sinusoidal positional encoding to a sequence of embeddings.

    PE[pos, 2k]   = sin(pos / 10000^(2k / d_model))
    PE[pos, 2k+1] = cos(pos / 10000^(2k / d_model))
    """

    def __init__(self, d_model):
        super().__init__()
        # Kept for reference only; forward reads the width from its input.
        self.d_model = d_model

    def forward(self, embedded):
        """Return `embedded` plus its positional encoding.

        Args:
            embedded: tensor of shape (seq_length, d_model) or
                (..., seq_length, d_model); leading dims are broadcast over.

        Returns:
            Tensor with the same shape as `embedded`.
        """
        seq_length, d_model = embedded.shape[-2], embedded.shape[-1]
        pos = torch.arange(0, seq_length, device=embedded.device, dtype=embedded.dtype).unsqueeze(1) # (seq_length, 1)
        i = torch.arange(0, d_model, 2, device=embedded.device, dtype=embedded.dtype) # (ceil(d_model/2),)
        # div_term = 10000^(-i/d_model), computed in log space
        div_term = torch.exp(-(torch.log(torch.tensor(10_000.0, device=embedded.device, dtype=embedded.dtype)) * (i/d_model)))
        # Multiply (not divide): div_term already holds the inverse frequency.
        angles = pos * div_term # (seq_length, ceil(d_model/2))
        positional_encoding = torch.zeros(seq_length, d_model, device=embedded.device, dtype=embedded.dtype)
        positional_encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns.
        positional_encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return embedded + positional_encoding

In [11]:
# Instantiate the module before using it; the width is taken from the embeddings.
positional_encoding = PositionalEncoding(d_model=embedded.shape[-1])
encoded_input = positional_encoding(embedded)

In [12]:
encoded_input.shape

torch.Size([11, 512])

In [None]:
class AttentionLayer(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    Args:
        dim_k: width of the key projection.
        dim_q: width of the query projection (must equal dim_k so Q @ K^T is defined).
        dim_v: width of the value projection, which is also the output width.
        dim_d: width of the input vectors.
    """

    def __init__(self, dim_k, dim_q, dim_v, dim_d):
        super().__init__()
        if dim_q != dim_k:
            raise ValueError(f"dim_q ({dim_q}) must equal dim_k ({dim_k})")
        self.W_Q = torch.nn.Parameter(torch.randn(dim_d, dim_q))
        self.W_K = torch.nn.Parameter(torch.randn(dim_d, dim_k))
        self.W_V = torch.nn.Parameter(torch.randn(dim_d, dim_v))
        self.dim_d = dim_d
        # Scores are dot products over dim_k components, so that is the width to scale by.
        self.scale = sqrt(dim_k)

    def forward(self, x):
        """Apply self-attention to x of shape (..., seq_length, dim_d).

        Returns:
            Tensor of shape (..., seq_length, dim_v).
        """
        Q = x @ self.W_Q
        K = x @ self.W_K
        V = x @ self.W_V
        # transpose(-2, -1) instead of .T so leading batch dims are left alone.
        scores = Q @ K.transpose(-2, -1) / self.scale
        # Normalise over the key axis: each query's weights sum to 1.
        attention_scores = torch.softmax(scores, dim=-1)
        return attention_scores @ V

In [None]:
class MultiHeadAttentionLayer(torch.nn.Module):
    """Multi-head scaled dot-product self-attention with a learned output projection.

    Args:
        dim_k: per-head key width. Defaults to dim_d // num_heads.
        dim_q: per-head query width (must equal dim_k). Defaults to dim_d // num_heads.
        dim_v: per-head value width. Defaults to dim_d // num_heads.
        dim_d: model width (input and output feature size).
        num_heads: number of attention heads.
    """

    def __init__(self, dim_k=None, dim_q=None, dim_v=None, dim_d=512, num_heads=8):
        super().__init__()
        if dim_k is None or dim_q is None or dim_v is None:
            head_dim = dim_d // num_heads
            dim_k = head_dim if dim_k is None else dim_k
            dim_q = head_dim if dim_q is None else dim_q
            dim_v = head_dim if dim_v is None else dim_v
        if dim_q != dim_k:
            raise ValueError(f"dim_q ({dim_q}) must equal dim_k ({dim_k})")
        self.W_Q = nn.Parameter(torch.randn(num_heads, dim_d, dim_q))
        self.W_K = nn.Parameter(torch.randn(num_heads, dim_d, dim_k))
        self.W_V = nn.Parameter(torch.randn(num_heads, dim_d, dim_v))
        # Output projection is a registered parameter so it is learned and stays
        # the same between forward passes; it maps the concatenated heads back
        # to the model width.
        self.W = nn.Parameter(torch.randn(num_heads * dim_v, dim_d))
        self.dim_d = dim_d
        # Scores are dot products over dim_k components.
        self.scale = sqrt(dim_k)

    def forward(self, x):
        """Apply multi-head self-attention to x of shape (seq_length, dim_d).

        Returns:
            (projected_attention, K, V) where projected_attention is
            (seq_length, dim_d) and K / V are the per-head keys and values of
            shape (num_heads, seq_length, dim_k / dim_v).
        """
        Q = x @ self.W_Q # (num_heads, seq_length, dim_q)
        K = x @ self.W_K # (num_heads, seq_length, dim_k)
        V = x @ self.W_V # (num_heads, seq_length, dim_v)
        attention_scores = torch.softmax(Q @ K.transpose(-2, -1) / self.scale, dim=-1)
        attention = attention_scores @ V # (num_heads, seq_length, dim_v)
        attention = attention.transpose(0, 1) # (seq_length, num_heads, dim_v)
        flattened_attention = attention.reshape(attention.shape[0], -1) # (seq_length, num_heads * dim_v)
        projected_attention = flattened_attention @ self.W
        return projected_attention, K, V

In [None]:

multi_head_attention_layer = MultiHeadAttentionLayer(dim_k=512, dim_q=512, dim_v=512, dim_d=512, num_heads=8)
# forward returns (projected_attention, K, V); keep only the attention output.
# Calling the module (rather than .forward) also runs any registered hooks.
output_layer, _, _ = multi_head_attention_layer(encoded_input)

In [29]:
print(output_layer.shape)

torch.Size([11, 512])


In [None]:
class ResidualLayer(nn.Module):
    """Skip connection: sums a sublayer's output with the input that fed it."""

    def forward(self, sublayer_output, previous_layer_output):
        """Return sublayer_output + previous_layer_output."""
        summed = sublayer_output + previous_layer_output
        return summed



# Post Normalization is used here, like in the original Transformer paper
class LayerAddAndNormLayer(nn.Module):
    """Residual add followed by layer normalization over the last dimension.

    Args:
        dim_d: size of the last (feature) dimension.
        e: small constant added to the variance before the square root.
    """

    def __init__(self, dim_d, e = 1e-6):
        super().__init__()
        self.epsilon = e
        self.dim_d = dim_d
        # Learnable per-feature scale and shift, starting as the identity.
        self.gamma = torch.nn.Parameter(torch.ones(self.dim_d))
        self.beta = torch.nn.Parameter(torch.zeros(self.dim_d))
        self.residual = ResidualLayer()

    def normalization(self, x):
        """Normalize x over its last dim, then apply gamma and beta."""
        mu = x.mean(dim=-1, keepdim=True)
        # Population (biased) variance
        sigma_sq = x.var(dim=-1, unbiased=False, keepdim=True)
        x_hat = (x - mu) / torch.sqrt(sigma_sq + self.epsilon)
        return self.gamma * x_hat + self.beta

    def forward(self, sublayer_output, previous_layer_output):
        """Return normalization(previous_layer_output + sublayer_output)."""
        summed = self.residual(previous_layer_output, sublayer_output)
        return self.normalization(summed)


In [None]:
class FullyConnectedLayer(nn.Module):
    """Feed-forward stack: Linear + ReLU per hidden width, then a final Linear.

    Args:
        in_dim: input feature size.
        hidden_dims: iterable of hidden layer widths, in order.
        out_dim: output feature size.
    """

    def __init__(self, in_dim, hidden_dims, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dims = hidden_dims
        self.out_dim = out_dim

        # Chain the widths so consecutive pairs give each hidden Linear's (in, out).
        widths = [in_dim, *hidden_dims]
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # No activation after the last projection.
        stack.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run x through the stack."""
        return self.layers(x)

In [None]:
class MaskedMultiHeadAttentionLayer(torch.nn.Module):
    """Causal multi-head self-attention: position i only attends to positions <= i.

    Args:
        dim_d: model width (input and output feature size).
        num_heads: number of attention heads; each head has width dim_d // num_heads.
    """

    def __init__(self, dim_d, num_heads):
        super().__init__()
        self.dim_k = self.dim_v = self.dim_q = dim_d // num_heads
        self.W_Q = nn.Parameter(torch.randn(num_heads, dim_d, self.dim_q))
        self.W_K = nn.Parameter(torch.randn(num_heads, dim_d, self.dim_k))
        self.W_V = nn.Parameter(torch.randn(num_heads, dim_d, self.dim_v))
        self.dim_d = dim_d
        # Scores are dot products over dim_k components, so scale by sqrt(dim_k).
        self.scale = sqrt(self.dim_k)
        # Project the concatenated heads back to the model width (not a fixed 512).
        self.W = nn.Parameter(torch.randn(num_heads * self.dim_v, dim_d))

    def forward(self, x):
        """Apply causal self-attention to x of shape (seq_length, dim_d).

        Returns:
            Tensor of shape (seq_length, dim_d).
        """
        Q = x @ self.W_Q # (num_heads, seq_length, dim_q)
        K = x @ self.W_K # (num_heads, seq_length, dim_k)
        V = x @ self.W_V # (num_heads, seq_length, dim_v)
        scores = Q @ K.transpose(-2, -1) / self.scale # (num_heads, seq_length, seq_length)
        seq_length = x.shape[0]
        # True strictly above the diagonal, i.e. where the key is later than the query.
        future = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        # -inf becomes exactly 0 after the softmax, so future tokens get no weight.
        scores = scores.masked_fill(future, float('-inf'))
        attention_scores = torch.softmax(scores, dim=-1)
        attention = attention_scores @ V # (num_heads, seq_length, dim_v)
        attention = attention.transpose(0, 1) # (seq_length, num_heads, dim_v)
        flattened_attention = attention.reshape(attention.shape[0], -1)
        projected_attention = flattened_attention @ self.W
        return projected_attention
    

class MultiHeadCrossAttentionLayer(nn.Module):
    """Multi-head cross-attention: queries from x, keys and values from the encoder output.

    Args:
        dim_d: model width (feature size of x, encoder_output and the result).
        num_heads: number of attention heads; each head has width dim_d // num_heads.
    """

    def __init__(self, dim_d=512, num_heads=8):
        super().__init__()
        self.dim_k = self.dim_v = self.dim_q = dim_d // num_heads
        self.W_Q = nn.Parameter(torch.randn(num_heads, dim_d, self.dim_q))
        self.W_K = nn.Parameter(torch.randn(num_heads, dim_d, self.dim_k))
        self.W_V = nn.Parameter(torch.randn(num_heads, dim_d, self.dim_v))
        self.dim_d = dim_d
        # Scores are dot products over dim_k components, so scale by sqrt(dim_k).
        self.scale = sqrt(self.dim_k)
        # Project the concatenated heads back to the model width (not a fixed 512).
        self.W = nn.Parameter(torch.randn(num_heads * self.dim_v, dim_d))

    def forward(self, x, encoder_output):
        """Attend from x (target_length, dim_d) over encoder_output (source_length, dim_d).

        Returns:
            Tensor of shape (target_length, dim_d).
        """
        Q = x @ self.W_Q # (num_heads, target_length, dim_q)
        K = encoder_output @ self.W_K # (num_heads, source_length, dim_k)
        V = encoder_output @ self.W_V # (num_heads, source_length, dim_v)
        attention_scores = torch.softmax(Q @ K.transpose(-2, -1) / self.scale, dim=-1)
        attention = attention_scores @ V # (num_heads, target_length, dim_v)
        attention = attention.transpose(0, 1) # (target_length, num_heads, dim_v)
        flattened_attention = attention.reshape(attention.shape[0], -1)
        projected_attention = flattened_attention @ self.W
        return projected_attention

In [None]:
# self.embedding = Embedding()
# self.positional_encoding = PositionalEncoding(d_model=dim_d)
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention and a feed-forward net, each followed by add & norm.

    Args:
        dim_d: model width.
        num_heads: number of attention heads; each head has width dim_d // num_heads.
    """

    def __init__(self, dim_d = 512, num_heads=8):
        super().__init__()
        # Pass the per-head widths explicitly: the attention layer's signature
        # takes dim_k / dim_q / dim_v in addition to dim_d and num_heads.
        head_dim = dim_d // num_heads
        self.multi_head_attention = MultiHeadAttentionLayer(
            dim_k=head_dim,
            dim_q=head_dim,
            dim_v=head_dim,
            dim_d=dim_d,
            num_heads=num_heads,
        )
        self.fully_connected_layer = FullyConnectedLayer(in_dim = dim_d, hidden_dims = [2048], out_dim = dim_d)
        # Not used by forward (the add & norm layers carry their own residual);
        # kept so existing attribute access keeps working.
        self.residual = ResidualLayer()
        self.layer_add_and_norm_1 = LayerAddAndNormLayer(dim_d=dim_d)
        self.layer_add_and_norm_2 = LayerAddAndNormLayer(dim_d=dim_d)

    def forward(self, x):
        """Run x through attention + add & norm, then feed-forward + add & norm."""
        # The attention layer also returns its K and V, which are not needed here.
        attention_output, _, _ = self.multi_head_attention(x)
        x = self.layer_add_and_norm_1(attention_output, x)
        fc_output = self.fully_connected_layer(x)
        x = self.layer_add_and_norm_2(fc_output, x)
        return x
    
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention and a
    feed-forward net, each followed by add & norm.

    Args:
        dim_d: model width.
        num_heads: number of attention heads.
    """

    def __init__(self, dim_d = 512, num_heads=8):
        super().__init__()
        self.masked_multi_head_attention = MaskedMultiHeadAttentionLayer(
            dim_d=dim_d, num_heads=num_heads
        )
        self.layer_add_and_norm_1 = LayerAddAndNormLayer(dim_d=dim_d)
        self.layer_add_and_norm_2 = LayerAddAndNormLayer(dim_d=dim_d)
        self.layer_add_and_norm_3 = LayerAddAndNormLayer(dim_d=dim_d)
        self.multi_head_cross_attention = MultiHeadCrossAttentionLayer(
            dim_d=dim_d, num_heads=num_heads
        )
        self.fully_connected_layer = FullyConnectedLayer(
            in_dim=dim_d, hidden_dims=[2048], out_dim=dim_d
        )

    def forward(self, x, encoder_output):
        """Run one decoder layer on x, attending over encoder_output."""
        # 1) masked self-attention over the target sequence
        hidden = self.layer_add_and_norm_1(self.masked_multi_head_attention(x), x)
        # 2) cross-attention over the encoder output
        attended = self.multi_head_cross_attention(hidden, encoder_output)
        hidden = self.layer_add_and_norm_2(attended, hidden)
        # 3) feed-forward network
        hidden = self.layer_add_and_norm_3(self.fully_connected_layer(hidden), hidden)
        return hidden
        
 

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model assembled from the blocks above.

    Args:
        num_encoder_layers: number of stacked EncoderBlock layers.
        num_heads: attention heads per block.
        dim_d: model width.
        num_decoder_layers: number of stacked DecoderBlock layers.
        vocab_size: output size of the final linear layer.
    """

    def __init__(self, num_encoder_layers, num_heads, dim_d, num_decoder_layers, vocab_size):
        super().__init__()
        self.dim_d = dim_d
        self.embedding = Embedding()
        self.position_encoding = PositionalEncoding(d_model=dim_d)
        encoder_blocks = [
            EncoderBlock(dim_d=dim_d, num_heads=num_heads)
            for _ in range(num_encoder_layers)
        ]
        self.encoder_layers = nn.Sequential(*encoder_blocks)
        decoder_blocks = [
            DecoderBlock(dim_d=dim_d, num_heads=num_heads)
            for _ in range(num_decoder_layers)
        ]
        self.decoder_layers = nn.ModuleList(decoder_blocks)
        self.linear_layer = nn.Linear(in_features=dim_d, out_features=vocab_size)
        self.softmax_layer = nn.Softmax(dim = -1)

    def forward(self, input_sequence, target_sequence):
        """Return the softmax over the linear layer's output for each target position.

        Both arguments are passed to `self.embedding.embed`.
        """
        # Embed and position-encode the source, then the target.
        source = self.position_encoding(
            self.embedding.embed(input_sequence, self.dim_d)
        )
        target = self.position_encoding(
            self.embedding.embed(target_sequence, self.dim_d)
        )

        memory = self.encoder_layers(source)

        # Every decoder block receives the same encoder output.
        hidden = target
        for block in self.decoder_layers:
            hidden = block(hidden, memory)

        logits = self.linear_layer(hidden)
        return self.softmax_layer(logits)


  