In [2]:
import transformers
import torch
import torch.nn as nn
import torch.optim as opt
from torch.nn import functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader

from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Device identifier used by this notebook when it creates tensors explicitly.
DEVICE = "cpu"

In [4]:


# Toy parallel corpus: ten English sentences and ten Ukrainian sentences.
# The two lists are consumed pairwise by index (entry i of one list is paired
# with entry i of the other), so they must keep the same length and order.
text_dataset_en = [
    "The quick brown fox jumps over the lazy dog and runs through the forest.",
    "She sells sea shells by the seashore and enjoys watching the waves crash.",
    "He plays the guitar and sings beautiful songs under the shining moon.",
    "They build sandcastles on the beach and fly colorful kites in the sky.",
    "We explore ancient ruins and discover hidden treasures in mysterious caves.",
    "The sun sets behind the mountains, painting the sky with hues of pink and orange.",
    "A gentle breeze rustles the leaves on the trees as birds sing their sweet melodies.",
    "Children laugh and play in the park, enjoying the warmth of a sunny day.",
    "In the distance, a train whistle blows, signaling the arrival of a new adventure.",
    "As night falls, stars twinkle in the dark sky, guiding travelers on their journey.",
]

# NOTE(review): the last two Ukrainian sentences do not look like literal
# translations of their English counterparts (extra leading phrase in the
# ninth, "in the evenings" vs. "as night falls" in the tenth) -- confirm with
# a native speaker before using this corpus for anything beyond a smoke test.
text_dataset_ua = [
    "Швидкий коричневий лис стрибає через лінивого собаку та біжить через ліс.",
    "Вона продає морські ракушки біля морського узбережжя та насолоджується спостереженням за хвилями.",
    "Він грає на гітарі та співає прекрасні пісні під сяючою місяцем.",
    "Вони будують піщані замки на пляжі та літають різнокольоровими паперовими зміями на небі.",
    "Ми досліджуємо стародавні руїни та відкриваємо приховані скарби в таємничих печерах.",
    "Сонце заходить за гори, розфарбовуючи небо рожевим та помаранчевим.",
    "Легкий бриз шелестить листям на деревах, поки птахи співають свої чарівні мелодії.",
    "Діти сміються та граються в парку, насолоджуючись теплотою сонячного дня.",
    "Удачний випадок, здалеку чути свисток потяга, сигналізуючий про прибуття нової пригоди.",
    "По вечорах зірки мерехтять на темному небі, направляючи подорожуючих на їхній шлях.",
]

In [5]:
# Load one pretrained tokenizer per language: the cased English BERT
# checkpoint for the source side and the cased multilingual BERT checkpoint
# for the Ukrainian side. The English one is loaded first.
tokenizer_en, tokenizer_ua = (
    BertTokenizer.from_pretrained("bert-base-cased"),
    BertTokenizer.from_pretrained("bert-base-multilingual-cased"),
)

def tokenize_data(en_data, ua_data, tokenizer_en=None, tokenizer_ua=None, max_length=30):
    """Tokenize two lists of sentences into fixed-length encodings.

    Args:
        en_data: Iterable of English sentences (strings).
        ua_data: Iterable of Ukrainian sentences (strings).
        tokenizer_en: Callable tokenizer for ``en_data``. When ``None``, the
            pretrained ``"bert-base-cased"`` ``BertTokenizer`` is loaded.
        tokenizer_ua: Callable tokenizer for ``ua_data``. When ``None``, the
            pretrained ``"bert-base-multilingual-cased"`` ``BertTokenizer``
            is loaded.
        max_length: Length every sentence is padded or truncated to.

    Returns:
        A pair ``(en_tokenized_data, ua_tokenized_data)``: two lists holding
        one tokenizer output per input sentence, in input order.
    """
    # Resolve each default on its own. Previously the defaults were loaded
    # only when BOTH tokenizers were missing, so supplying just one of them
    # ended in calling ``None``.
    if tokenizer_en is None:
        tokenizer_en = BertTokenizer.from_pretrained("bert-base-cased")
    if tokenizer_ua is None:
        tokenizer_ua = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

    def _encode(tokenizer, texts):
        # One call per sentence, so every encoding is returned separately.
        return [
            tokenizer(
                text,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            for text in texts
        ]

    en_tokenized_data = _encode(tokenizer_en, en_data)
    ua_tokenized_data = _encode(tokenizer_ua, ua_data)
    return en_tokenized_data, ua_tokenized_data

In [6]:
# Tokenize both corpora with the tokenizers loaded earlier in the notebook.
en_tokenized_data, ua_tokenized_data = tokenize_data(text_dataset_en, text_dataset_ua, tokenizer_en, tokenizer_ua)
# NOTE(review): as the printed output shows, "input_ids" carries a leading axis
# of size 1, so `[:10]` slices that axis and the whole 30-id row is displayed
# rather than the first ten ids. Index with `[0, :10]` if ten ids were intended.
en_tokenized_data[0]["input_ids"][:10]

tensor([[  101,  1109,  3613,  3058, 17594, 15457,  1166,  1103, 16688,  3676,
          1105,  2326,  1194,  1103,  3304,   119,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]])

## Creating a custom dataset

In [7]:
class TranslationDataset(Dataset):
    """Dataset of paired (source, target) tokenized sentences.

    Both corpora are tokenized once, at construction time, through
    ``tokenize_data``. Item ``idx`` pairs the ``idx``-th source encoding with
    the ``idx``-th target encoding.
    """

    def __init__(self, src_data, trg_data, src_tokenizer, trg_tokenizer, max_length=30):
        super().__init__()
        # Tokenize the data up front so __getitem__ is a plain lookup.
        tokenized = tokenize_data(src_data, trg_data, src_tokenizer, trg_tokenizer, max_length)
        self.src_tokenized_data, self.trg_tokenized_data = tokenized

    def __len__(self):
        """Number of items, taken from the source side."""
        return len(self.src_tokenized_data)

    def __getitem__(self, idx):
        """Return ``(src_sample, trg_sample)`` for position ``idx``.

        Each sample is a dict with exactly the keys ``"input_ids"`` and
        ``"attention_mask"``, copied from the stored tokenizer output.
        """
        wanted = ("input_ids", "attention_mask")
        src_encoding = self.src_tokenized_data[idx]
        trg_encoding = self.trg_tokenized_data[idx]
        return (
            {key: src_encoding[key] for key in wanted},
            {key: trg_encoding[key] for key in wanted},
        )

In [8]:
# English is the source language, Ukrainian the target.
translation_dataset = TranslationDataset(
    src_data=text_dataset_en,
    trg_data=text_dataset_ua,
    src_tokenizer=tokenizer_en,
    trg_tokenizer=tokenizer_ua,
)

In [9]:
# Wrap the dataset in a loader that reshuffles the samples on every pass.
batch_size = 32
data_loader = DataLoader(
    dataset=translation_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Training

In [277]:
class Head(nn.Module):
    """Single scaled dot-product attention head.

    Args:
        n_embed: Size of the last dimension of the incoming q / k / v tensors.
        head_size: Output size of the query / key / value projections (d_k).
        block_size: Largest sequence length the causal mask is built for.
        masked: When True, query position ``i`` may only attend to key
            positions ``<= i`` (decoder-style self-attention).
    """

    def __init__(self, n_embed, head_size, block_size, masked=False):
        super().__init__()
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.masked = masked
        # Lower-triangular matrix of ones; a buffer so it moves with the module.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, q, k, v):
        """Attend from ``q`` over ``k`` / ``v``.

        Args:
            q: Tensor ``(..., T_q, n_embed)`` the queries are projected from.
            k: Tensor ``(..., T_k, n_embed)`` the keys are projected from.
            v: Tensor ``(..., T_k, n_embed)`` the values are projected from.

        Returns:
            Tensor ``(..., T_q, head_size)``: for each query position, the
            attention-weighted average of the projected values.
        """
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # Scale by sqrt(d_k), where d_k is the width of the PROJECTED keys
        # (head_size), not the width of the unprojected input.
        d_k = k.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) * d_k ** -0.5

        # Apply causal masking if this head is used in a decoder.
        if self.masked:
            t_q, t_k = scores.shape[-2], scores.shape[-1]
            scores = scores.masked_fill(self.tril[:t_q, :t_k] == 0, float("-inf"))

        # Normalise over the key axis (last dim) so every query's weights sum
        # to 1. Normalising over dim=1 mixed information across query
        # positions and defeated the causal mask.
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, v)
    
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, followed by a linear mix.

    Args:
        n_embed: Width of the incoming tensors and of the returned tensor.
        n_heads: Number of ``Head`` modules to create.
        head_size: Output width of each individual head.
        block_size: Forwarded to every head (size of its causal mask).
        masked: Forwarded to every head (enables causal masking).
    """

    def __init__(self, n_embed, n_heads, head_size, block_size, masked=False):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embed, head_size, block_size, masked) for _ in range(n_heads)])
        # The concatenated head outputs are n_heads * head_size wide. That
        # equals n_embed only when n_embed is divisible by n_heads, so size
        # the projection from what it actually receives.
        self.projection = nn.Linear(n_heads * head_size, n_embed)

    def forward(self, q, k, v):
        """Run every head on ``(q, k, v)``, concatenate on the last axis, project to ``n_embed``."""
        concatenated = torch.cat([head(q, k, v) for head in self.heads], dim=-1)
        return self.projection(concatenated)
    
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: expand to ``4 * n_embed``, ReLU, project back.

    Args:
        n_embed: Width of the input and of the output.
        dropout_p: Dropout probability applied to the output while the module
            is in training mode. A falsy value (``0``, ``0.0``, ``None``)
            disables dropout.
    """

    def __init__(self, n_embed, dropout_p):
        super().__init__()
        # dropout_p used to be accepted and silently ignored. The extra layer
        # has no parameters, so state_dict keys ("net.0.*", "net.2.*") are
        # unchanged.
        regulariser = nn.Dropout(dropout_p) if dropout_p else nn.Identity()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            regulariser,
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)
    
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sub-layer's output is added to its input and the sum is
    layer-normalised (normalisation after the residual add).

    Args:
        n_embed: Width of the tensors flowing through the block.
        n_heads: Number of attention heads; each gets ``n_embed // n_heads`` channels.
        block_size: Forwarded to the attention module.
        dropout_p: Forwarded to the feed-forward module.
    """

    def __init__(self, n_embed, n_heads, block_size, dropout_p):
        super().__init__()
        self.sa = MultiHeadAttention(n_embed, n_heads, n_embed // n_heads, block_size)
        self.ffwd = FeedForward(n_embed, dropout_p)
        self.norm1 = nn.LayerNorm(n_embed)
        self.norm2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward sub-layers."""
        # Sub-layer 1: self-attention (x serves as query, key and value).
        attended = self.sa(x, x, x)
        x = self.norm1(x + attended)
        # Sub-layer 2: position-wise feed-forward.
        transformed = self.ffwd(x)
        return self.norm2(x + transformed)
    
class Encoder(nn.Module):
    """Token + learned position embeddings followed by a stack of encoder blocks.

    Args:
        vocab_size: Number of rows in the token embedding table.
        n_embed: Embedding width, used throughout the stack.
        n_heads: Attention heads per block.
        block_size: Number of rows in the position embedding table, i.e. the
            longest sequence that can be embedded.
        n_layers: Number of ``EncoderBlock`` modules stacked in sequence.
        dropout_p: Forwarded to every block.
    """

    def __init__(self, vocab_size, n_embed, n_heads, block_size, n_layers, dropout_p):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.pos_embedding = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[
            EncoderBlock(n_embed, n_heads, block_size, dropout_p) for _ in range(n_layers)
        ])

    def forward(self, x):
        """Encode token ids.

        Args:
            x: Integer tensor of token ids with exactly two dimensions ``(B, T)``.

        Returns:
            Output of the last block for the summed token and position embeddings.
        """
        B, T = x.shape
        token_embed = self.token_embedding(x)
        # Build the position indices on the input's own device instead of a
        # notebook-level constant, so the module keeps working after the model
        # and its inputs are moved to another device.
        positions = torch.arange(T, device=x.device)
        positional_embed = self.pos_embedding(positions)
        # (B, T, n_embed) + (T, n_embed): broadcast over the batch axis.
        embedding = token_embed + positional_embed
        return self.blocks(embedding)
    
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, feed-forward.

    Each sub-layer's output is added to its input and the sum is
    layer-normalised.

    Args:
        n_embed: Width of the tensors flowing through the block.
        n_heads: Number of attention heads; each gets ``n_embed // n_heads`` channels.
        block_size: Forwarded to both attention modules.
        dropout_p: Forwarded to the feed-forward module.
    """

    def __init__(self, n_embed, n_heads, block_size, dropout_p):
        super().__init__()
        per_head = n_embed // n_heads
        # Self-attention is created with masked=True; attention over the
        # encoder output keeps the default (unmasked).
        self.sa = MultiHeadAttention(n_embed, n_heads, per_head, block_size, masked=True)
        self.encoder_decoder_attention = MultiHeadAttention(n_embed, n_heads, per_head, block_size)
        self.ffwd = FeedForward(n_embed, dropout_p)
        self.norm1 = nn.LayerNorm(n_embed)
        self.norm2 = nn.LayerNorm(n_embed)
        self.norm3 = nn.LayerNorm(n_embed)

    def forward(self, x, encoder_output):
        """Return ``x`` after the three sub-layers.

        ``encoder_output`` supplies the keys and values of the second
        attention step; ``x`` supplies its queries.
        """
        self_attended = self.sa(x, x, x)
        x = self.norm1(x + self_attended)

        cross_attended = self.encoder_decoder_attention(x, encoder_output, encoder_output)
        x = self.norm2(x + cross_attended)

        transformed = self.ffwd(x)
        return self.norm3(x + transformed)
    
class Decoder(nn.Module):
    """Token + learned position embeddings, a stack of decoder blocks, and an output head.

    Args:
        vocab_size: Number of rows in the token embedding table and width of
            the returned logits.
        n_embed: Embedding width, used throughout the stack.
        n_heads: Attention heads per block.
        block_size: Number of rows in the position embedding table, i.e. the
            longest sequence that can be embedded.
        n_layers: Number of ``DecoderBlock`` modules applied in sequence.
        dropout_p: Forwarded to every block.
    """

    def __init__(self, vocab_size, n_embed, n_heads, block_size, n_layers, dropout_p):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.pos_embedding = nn.Embedding(block_size, n_embed)
        # ModuleList rather than Sequential: every block takes two inputs.
        self.blocks = nn.ModuleList([
            DecoderBlock(n_embed, n_heads, block_size, dropout_p) for _ in range(n_layers)
        ])
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, x, enc_out):
        """Decode target token ids against the encoder output.

        Args:
            x: Integer tensor of token ids with exactly two dimensions ``(B, T)``.
            enc_out: Encoder output, handed unchanged to every block.

        Returns:
            Tensor whose last dimension is ``vocab_size`` (unnormalised scores).
        """
        B, T = x.shape
        token_embed = self.token_embedding(x)
        # Build the position indices on the input's own device instead of a
        # notebook-level constant, so the module keeps working after the model
        # and its inputs are moved to another device.
        positions = torch.arange(T, device=x.device)
        pos_embed = self.pos_embedding(positions)

        # (B, T, n_embed) + (T, n_embed): broadcast over the batch axis.
        hidden = token_embed + pos_embed
        for block in self.blocks:
            hidden = block(hidden, enc_out)

        return self.lm_head(hidden)

class Transformer(nn.Module):
    """Encoder-decoder model: encode the source ids, then decode the target ids against them.

    Args:
        src_vocab_size: Vocabulary size handed to the encoder.
        trgt_vocab_size: Vocabulary size handed to the decoder.
        n_embed, n_heads, block_size, n_layers, dropout_p: Passed unchanged to
            both the encoder and the decoder.
    """

    def __init__(self, src_vocab_size, trgt_vocab_size, n_embed, n_heads, block_size, n_layers, dropout_p):
        super().__init__()
        # Everything except the vocabulary size is shared by both halves.
        shared = (n_embed, n_heads, block_size, n_layers, dropout_p)
        self.encoder = Encoder(src_vocab_size, *shared)
        self.decoder = Decoder(trgt_vocab_size, *shared)

    def forward(self, x, y):
        """Return the decoder output for target ids ``y`` given source ids ``x``."""
        return self.decoder(y, self.encoder(x))

In [270]:
# Vocabulary sizes reported by the two tokenizers (English first, then Ukrainian).
vocab_size_en, vocab_size_ua = tokenizer_en.vocab_size, tokenizer_ua.vocab_size