<img src="imgs/image.png" width=500>

In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import seaborn as sns
import torchtext
import matplotlib.pyplot as plt

In [2]:
print(torch.cuda.is_available())

False


## Embeddings

```INPUT (0, 2, 11, 24, 123, 1)``` - token idx

```OUTPUT ([0.533, ... , 0.123], [0.627, ... , 0.156], ... , [0.724, ... , 0.976])``` - token embedding 


In [3]:
class Tranformer_config():
    def __init__(self) -> None:
        self.emb_dim = 512
        self.ffn_dim = 2048
        self.num_heads = 8
        self.num_layers = 6
        self.dropout = 0.1
        self.max_len = 256
        self.batch_size = 1
        self.lr = 1e-4
        self.epochs = 10
        self.vocab_size = 50257
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
class Embedding_with_pe(nn.Module):
    def __init__(self, vocab_size, sequence_length, embedding_dim):
        """
        input: tensor of tokens (batch_size, sequence_length, embedding_dim)
        """
        super().__init__()
        self.sequence_range = torch.arange(sequence_length).unsqueeze(0)
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_encoding = nn.Embedding(sequence_length, embedding_dim)
    
    def forward(self, x):
        pe = self.position_encoding(self.sequence_range)
        embedding = self.embeddings(x)
        return embedding + pe
        

# embed = Embedding_with_pe(vocab_size=10, sequence_lenght=5, embedding_dim=3)
# embed(torch.tensor([[1, 2, 0, 0, 0]]))
# embed.sequence_range
# embed(torch.tensor([[1,2,3,4]])).shape

In [5]:
# tembed = Embedding_with_pe(vocab_size=20, embedding_dim=10, sequence_lenght=4)
# embed(torch.randint(0, 20, [2, 4])) 

## MHA + Add & Norm

In [6]:
class MultiHeadAttention_AddNorm(nn.Module):
    def __init__(self, embedding_dim, n_heads, dropout=0.1):
        super().__init__()
        self.MHA = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=n_heads, dropout=dropout, batch_first=True)
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, Q, K, V, mask=None):
        attn_output, _ = self.MHA(Q, K, V, attn_mask=mask)
        attn_output = self.dropout(attn_output)
        output = self.layer_norm(Q + attn_output)
        return output

# seq_len = 100
# emb_dim = 160
# n_heads = 8 
# batch = 32

# mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# # print(mask)

# mha = MultiHeadAttention_AddNorm(emb_dim, n_heads, dropout=0)
# output, attention = mha(torch.randn(batch, seq_len, emb_dim), mask=None)
# print(output.shape)

## Feed Forward + Add & Norm

In [7]:
class FF_AddNorm(nn.Module):
    def __init__(self, embedding_dim, ffn_dim, dropout=0.1):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(embedding_dim, ffn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_dim, embedding_dim),
            nn.Dropout(dropout)
        )
        self.layer_norm = nn.LayerNorm(embedding_dim)
        
    def forward(self, x):
        ff_output = self.ff(x)
        output = self.layer_norm(x + ff_output)
        return output


# seq_len = 100
# emb_dim = 160
# n_heads = 8 
# batch = 32

# ff = FF_AddNorm(emb_dim, 512, 0.1)
# output = ff(torch.randn(batch, seq_len, emb_dim))

# print(output.shape)

In [8]:
class EncoderLayer(nn.Module):
    def __init__(self, embedding_dim, n_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention_AddNorm(embedding_dim, n_heads, dropout)
        self.ffn = FF_AddNorm(embedding_dim, ffn_dim, dropout)
        
    def forward(self, x, mask=None):
        x = self.attention(x, x, x, mask)
        x = self.ffn(x)
        return x
    
class Encoder(nn.Module):
    def __init__(self, n_blocks, embedding_dim, n_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(embedding_dim, n_heads, ffn_dim, dropout) for _ in range(n_blocks)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return x
    
# n_blocks = 6
# embedding_dim = 512 
# n_heads = 8
# ffn_dim = 2048

# encoder = Encoder(n_blocks, embedding_dim, n_heads, ffn_dim)
# output = encoder(torch.randn(32, 100, 512))

# print(output.shape)

In [9]:
class DecoderLayer(nn.Module):
    def __init__(self, embedding_dim, n_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention_AddNorm(embedding_dim, n_heads, dropout)
        self.cross_attention = MultiHeadAttention_AddNorm(embedding_dim, n_heads, dropout)
        self.ffn = FF_AddNorm(embedding_dim, ffn_dim, dropout)
        
    def forward(self, x, context, target_mask=None, padding_mask=None):
        x = self.self_attention(x, x, x, target_mask)
        x = self.cross_attention(x, context, context, padding_mask)
        x = self.ffn(x)
        return x
class Decoder(nn.Module):
    def __init__(self, n_blocks, embedding_dim, n_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(embedding_dim, n_heads, ffn_dim, dropout) for _ in range(n_blocks)])

    def forward(self, x, context, target_mask=None, padding_mask=None):
        for layer in self.layers:
            x = layer(x, context, target_mask, padding_mask)
        return x
    
# decoder = Decoder(n_blocks, embedding_dim, n_heads, ffn_dim)
# output = decoder(torch.randn(32, 100, 512), torch.randn(32, 100, 512))

In [11]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding = Embedding_with_pe(config.vocab_size, config.max_len, config.emb_dim)
        self.encoder = Encoder(config.num_encoder_layers, config.emb_dim, config.num_heads, config.ffn_dim, config.dropout)
        self.decoder = Decoder(config.num_decoder_layers, config.emb_dim, config.num_heads, config.ffn_dim, config.dropout)
        self.linear = nn.Linear(config.emb_dim, config.vocab_size)
    
    def forward(self, original, target, target_mask=None, padding_mask=None):
        original = self.embedding(original)
        target = self.embedding(target)
        encoder_out = self.encoder(original)
        decoder_out = self.decoder(target, encoder_out, target_mask, padding_mask)
        output = self.linear(decoder_out)
        
        # print("embdedding original", "-" * 100, "\n", original)
        # print("embdedding target", "-" * 100, "\n", target)
        # print("encoder_out", "-" * 100, "\n", encoder_out)
        # print("decoder_out", "-" * 100, "\n", decoder_out) 
        # print("output", "-" * 100, "\n", output)
        
        return output
    
    
class Tranformer_config():
    def __init__(self) -> None:
        self.emb_dim = 150
        self.max_len = 200
        self.vocab_size = 50257
        
        self.num_encoder_layers = 6
        self.num_decoder_layers = 6
        self.num_heads = 2
        
        self.ffn_dim = 2048
        self.dropout = 0.1
        
        self.batch_size = 1
        self.lr = 1e-4
        self.epochs = 10
        
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
        
config = Tranformer_config()
transformer = Transformer(config)
output = transformer(torch.randint(0, 50257, [1, 200]), torch.randint(0, 50257, [1, 200]))
# print(output.shape)

embdedding original ---------------------------------------------------------------------------------------------------- 
 tensor([[[ 0.1353, -0.5520, -0.2578,  ..., -3.0753,  0.1462, -2.6390],
         [-1.6448, -2.4551, -2.1762,  ...,  0.5636, -1.2615, -2.9863],
         [ 0.6010,  0.5996, -0.0621,  ..., -0.0647,  0.5029, -0.0619],
         ...,
         [ 0.4387, -2.4390,  0.3713,  ...,  0.3974, -1.2088,  1.1083],
         [-0.0412, -0.4828, -1.4971,  ..., -2.5934,  1.4022, -0.4571],
         [ 2.1110,  0.1720,  0.6893,  ..., -1.2799,  0.1511, -0.3131]]],
       grad_fn=<AddBackward0>)
embdedding target ---------------------------------------------------------------------------------------------------- 
 tensor([[[ 1.1583, -1.4360, -1.0492,  ..., -2.1828,  1.5707, -0.3493],
         [-1.1209, -2.8678, -2.5269,  ...,  0.6984, -1.8902, -2.4611],
         [ 1.2762,  0.5925,  0.2262,  ...,  2.8743, -0.5552,  2.1794],
         ...,
         [ 0.1383, -1.0510,  0.2185,  ..., -0.7502,  0.4