In [None]:
# CONFIGURE TRAINING SETTINGS

In [None]:
# torchmetrics is only referenced by the commented-out BLEU/WER/CER import in the next cell.
# NOTE(review): consider `%pip install -q torchmetrics==<version>` so the install targets this kernel and is pinned.
!pip install torchmetrics

In [None]:
import re
import json
import nltk
import math
import torch
import random
import torch.nn as nn
from enum import Enum
from tqdm import tqdm
from pathlib import Path
from google.colab import drive
from tokenizers import Tokenizer
from nltk.corpus import stopwords
from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader, random_split
# from torchmetrics.text import BLEUScore, WordErrorRate, CharErrorRate



In [None]:
# Record keys of the parallel corpus and the fixed model sequence length.
SRC_LANG, TGT_LANG = "en", "am"
SEQ_LEN = 52

# Project artefacts live under this Google Drive folder.
PROJECT_DIR = "/content/drive/My Drive/Ashu_NLP"
DATASET_PATH = PROJECT_DIR + "/data/parallel-corpus-en-am-v3.5.json"

# Mount Drive; force_remount=True remounts even if it is already mounted.
drive.mount("/content/drive", force_remount=True)

# NLTK resources used by the preprocessing classes.
for _nltk_resource in ("punkt", "stopwords", "wordnet"):
    nltk.download(_nltk_resource)
del _nltk_resource

In [None]:
class TextPreprocessingPipeline(ABC):
    """Abstract base for the language-specific preprocessors in this notebook.

    Stores the trained tokenizer; subclasses implement ``preprocess``.
    """

    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    def tokenize(self, text):
        """Return the word tokens of ``text`` as produced by NLTK ``word_tokenize``."""
        return word_tokenize(text)

    @abstractmethod
    def preprocess(self, text: str, encode=True) -> str:
        """Clean ``text``; the subclasses here return token ids when ``encode`` is true, else the cleaned text."""

In [None]:
import re
import json
import nltk
import math
import torch
import random
import torch.nn as nn
from enum import Enum
from tqdm import tqdm
from pathlib import Path
from google.colab import drive
from tokenizers import Tokenizer
from nltk.corpus import stopwords
from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader, random_split
# from torchmetrics.text import BLEUScore, WordErrorRate, CharErrorRate
import pickle

# Connect to Google Drive (force_remount=True remounts even if Drive is already mounted).
drive.mount("/content/drive", force_remount=True)
PROJECT_DIR = "/content/drive/My Drive/Ashu_NLP"

# Download necessary NLTK data. 'punkt_tab' is what `word_tokenize` loads on newer
# NLTK releases (3.8.2+); older releases are expected to just report it as unknown.
for _nltk_resource in ('punkt', 'punkt_tab', 'stopwords', 'wordnet'):
    nltk.download(_nltk_resource)
del _nltk_resource

# Reproducibility: seed PyTorch (CPU and, if present, all CUDA devices) and Python's RNG.
SEED = 42
torch.manual_seed(SEED)
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    torch.cuda.manual_seed_all(SEED)
else:
    DEVICE = torch.device('cpu')
random.seed(SEED)

BATCH_SIZE = 64
EPOCHS = 5
LEARNING_RATE = 1e-04
MAX_SEQ_LEN = 52
MODEL_DIM = 512
NUM_LAYERS = 6
NUM_HEADS = 32
DROPOUT_RATE = 0.1
FF_DIM = 2048
SOURCE_LANG = "en"
TARGET_LANG = "am"
MODEL_DIR = f"{PROJECT_DIR}/models"
MODEL_PREFIX = "custom_model_"
PRELOAD_SUFFIX = ""
TOKENIZER_DIR = f"{PROJECT_DIR}/tokenizers"
TOKENIZER_PREFIX = "tokenizer-{0}-v3.5-12k.json"
LOG_DIR = f"{PROJECT_DIR}/logs/custom_model"
DATA_PATH = f"{PROJECT_DIR}/data/parallel-corpus-en-am-v3.5.json"

# NOTE(review): the dataset and model cells read SEQ_LEN, SRC_LANG, TGT_LANG, D_MODEL,
# HEADS, DROPOUT and DFF, which this cell did not define. They are aliased to the
# constants above on the assumption that these are the intended values - confirm.
SEQ_LEN = MAX_SEQ_LEN
SRC_LANG = SOURCE_LANG
TGT_LANG = TARGET_LANG
D_MODEL = MODEL_DIM
HEADS = NUM_HEADS
DROPOUT = DROPOUT_RATE
DFF = FF_DIM
# Multi-head attention splits MODEL_DIM evenly across heads.
assert MODEL_DIM % NUM_HEADS == 0, "MODEL_DIM must be divisible by NUM_HEADS"

def get_model_path(suffix: str):
    """Return the checkpoint path ``<MODEL_DIR>/<MODEL_PREFIX><suffix>.pkl``."""
    file_name = "{}{}.pkl".format(MODEL_PREFIX, suffix)
    return "{}/{}".format(MODEL_DIR, file_name)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import pickle

class AmharicTextPreprocessor(TextPreprocessingPipeline):
    """Cleans Amharic text and optionally encodes it with ``self.tokenizer``.

    ``preprocess`` runs, in order:
      1. abbreviation expansion on the text as written. Several keys
         (ሃ/አለቃ, ሃ/ስላሴ, ጽ/ቤት, ተ/ሃይማኖት, ዓም, ዓ.ዓ) contain letters that
         step 2 rewrites, so they can only match before normalisation. The
         expansions inserted here are normalised by step 2 as well.
      2. character normalisation (homophone letters, labialised spellings);
      3. ``replace_abbreviations`` again, which catches variant spellings of
         keys that normalise to the listed form (e.g. ክ/ሃገር -> ክ/ሀገር). It
         also turns ASCII punctuation, ።/፤, Ethiopic numerals and digits into spaces;
      4. remaining punctuation (including Ethiopic ፡ ፣ ፥ ፦ ፧ ፨ ፠) -> space;
      5. removal of everything outside U+1200-U+137F except whitespace.
    Whitespace runs are then collapsed and the result is stripped.
    """

    # Abbreviation -> expansion, applied in insertion order as whole words.
    _ABBREVIATIONS = {
        "ት/ቤት": "ትምህርት ቤት",
        "ት/ርት": "ትምህርት",
        "ት/ክፍል": "ትምህርት ክፍል",
        "ሃ/አለቃ": "ሃምሳ አለቃ",
        "ሃ/ስላሴ": "ሃይለ ስላሴ",
        "ደ/ዘይት": "ደብረ ዘይት",
        "ደ/ታቦር": "ደብረ ታቦር",
        "መ/ር": "መምህር",
        "መ/ቤት": "መስሪያ ቤት",
        "መ/አለቃ": "መቶ አለቃ",
        "ክ/ከተማ": "ክፍለ ከተማ",
        "ክ/ሀገር": "ክፍለ ሀገር",
        "ወ/ር": "",
        "ወ/ሮ": "ወይዘሮ",
        "ወ/ሪት": "ወይዘሪት",
        "ወ/ስላሴ": "ወልደ ስላሴ",
        "ፍ/ስላሴ": "ፍቅረ ስላሴ",
        "ፍ/ቤት": "ፍርድ ቤት",
        "ጽ/ቤት": "ጽህፈት ቤት",
        "ሲ/ር": "",
        "ፕ/ር": "ፕሮፌሰር",
        "ጠ/ሚንስትር": "ጠቅላይ ሚኒስተር",
        "ጠ/ሚ": "ጠቅላይ ሚኒስተር",
        "ዶ/ር": "ዶክተር",
        "ገ/ገዮርጊስ": "ገብረ ገዮርጊስ",
        "ቤ/ክርስትያን": "ቤተ ክርስትያን",
        "ም/ስራ": "",
        "ም/ቤት": "ምክር ቤት",  # was "ምክር ቤተ" (typo)
        "ተ/ሃይማኖት": "ተክለ ሃይማኖት",
        "ሚ/ር": "ሚኒስትር",
        "ኮ/ል": "ኮሎኔል",
        "ሜ/ጀነራል": "ሜጀር ጀነራል",
        "ብ/ጀነራል": "ብርጋደር ጀነራል",
        "ሌ/ኮለኔል": "ሌተናንት ኮለኔል",
        "ሊ/መንበር": "ሊቀ መንበር",
        "አ/አ": "ኣዲስ ኣበባ",
        "አ.አ": "ኣዲስ ኣበባ",
        "ር/መምህር": "ርዕሰ መምህር",
        "ፕ/ት": "",
        "ዓም": "ዓመተ ምህረት",
        "ዓ.ዓ": "ዓመተ ዓለም",
    }

    # (pattern, replacement) pairs applied sequentially; order matters
    # (e.g. ዓ/ኣ/ዐ are folded to አ before the labialisation rules look for [ዋአ]).
    _CHAR_RULES = (
        ('[ሃኅኃሐሓኻ]', 'ሀ'),
        ('[ሑኁዅ]', 'ሁ'),
        ('[ኂሒኺ]', 'ሂ'),
        ('[ኌሔዄ]', 'ሄ'),
        ('[ሕኅ]', 'ህ'),
        ('[ኆሖኾ]', 'ሆ'),
        ('[ሠ]', 'ሰ'),
        ('[ሡ]', 'ሱ'),
        ('[ሢ]', 'ሲ'),
        ('[ሣ]', 'ሳ'),
        ('[ሤ]', 'ሴ'),
        ('[ሥ]', 'ስ'),
        ('[ሦ]', 'ሶ'),
        ('[ዓኣዐ]', 'አ'),
        ('[ዑ]', 'ኡ'),
        ('[ዒ]', 'ኢ'),
        ('[ዔ]', 'ኤ'),
        ('[ዕ]', 'እ'),
        ('[ዖ]', 'ኦ'),
        ('[ጸ]', 'ፀ'),
        ('[ጹ]', 'ፁ'),
        ('[ጺ]', 'ፂ'),
        ('[ጻ]', 'ፃ'),
        ('[ጼ]', 'ፄ'),
        ('[ጽ]', 'ፅ'),
        ('[ጾ]', 'ፆ'),
        # Labialised forms written as <Cu>+ዋ/አ, e.g. በልቱዋል / በልቱአል -> በልቷል
        ('(ሉ[ዋአ])', 'ሏ'),
        ('(ሙ[ዋአ])', 'ሟ'),
        ('(ቱ[ዋአ])', 'ቷ'),
        ('(ሩ[ዋአ])', 'ሯ'),
        ('(ሱ[ዋአ])', 'ሷ'),
        ('(ሹ[ዋአ])', 'ሿ'),
        ('(ቁ[ዋአ])', 'ቋ'),
        ('(ቡ[ዋአ])', 'ቧ'),
        ('(ቹ[ዋአ])', 'ቿ'),
        ('(ሁ[ዋአ])', 'ኋ'),
        ('(ኑ[ዋአ])', 'ኗ'),
        ('(ኙ[ዋአ])', 'ኟ'),
        ('(ኩ[ዋአ])', 'ኳ'),
        ('(ዙ[ዋአ])', 'ዟ'),
        ('(ጉ[ዋአ])', 'ጓ'),
        ('(ዱ[ዋአ])', 'ዷ'),  # was 'ደ', which rewrote any ደዋ/ደአ sequence
        ('(ጡ[ዋአ])', 'ጧ'),
        ('(ጩ[ዋአ])', 'ጯ'),
        ('(ጹ[ዋአ])', 'ጿ'),  # NOTE(review): never fires, ጹ is already folded to ፁ above
        ('(ፉ[ዋአ])', 'ፏ'),
        ('[ቊ]', 'ቁ'),
        ('[ኵ]', 'ኩ'),
    )

    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__(tokenizer)

    def preprocess(self, text: str, encode=True) -> str:
        """Clean ``text``; return ``tokenizer.encode(...).ids`` if ``encode`` else the cleaned string."""
        text = self._expand_abbreviations(text)
        text = self.normalize_char_discrepancies(text)
        text = self.replace_abbreviations(text)
        text = self.remove_punctuation_and_special_chars(text)
        text = self.remove_non_amharic_chars_and_numbers(text)
        text = re.sub(r'\s+', ' ', text).strip()

        if encode:
            return self.tokenizer.encode(text).ids
        return text

    def _expand_abbreviations(self, text: str) -> str:
        """Replace each listed abbreviation, matched as a whole word, with its expansion."""
        for abbreviation, expansion in self._ABBREVIATIONS.items():
            text = re.sub(rf'\b{re.escape(abbreviation)}\b', expansion, text)
        return text

    def replace_abbreviations(self, text: str) -> str:
        """Expand abbreviations, then turn ASCII punctuation, ።/፤, Ethiopic numerals and digits into single spaces."""
        text = self._expand_abbreviations(text)
        text = re.sub(r'[.\?\"\',/#!$%^&*;:፤።{}=\-_`~()፩፪፫፬፭፮፯፰፱፲፳፴፵፶፷፸፹፺፻0-9]+', ' ', text)
        return re.sub(r'\s{2,}', ' ', text)

    def normalize_char_discrepancies(self, text: str) -> str:
        """Fold alternative spellings of the same sound onto one canonical character."""
        for pattern, replacement in self._CHAR_RULES:
            text = re.sub(pattern, replacement, text)
        return text

    def remove_punctuation_and_special_chars(self, text: str) -> str:
        """Replace punctuation and symbols with a space.

        Ethiopic punctuation (፣ ፥ ፦ ፧ ፨ ፠) is included: it lies inside
        U+1200-U+137F and would otherwise survive the final filtering step.
        """
        return re.sub(r'[!@#$%^&*()…\[\]{};:"“”«»‹›’‘\/<>?|`´~\\=\+፡፣፥፦፧፨፠;]+', ' ', text)

    def remove_non_amharic_chars_and_numbers(self, text: str) -> str:
        """Drop ASCII letters/digits, then anything outside U+1200-U+137F that is not whitespace."""
        rm_num_and_ascii = re.sub('[A-Za-z0-9]', '', text)
        return re.sub('[^\\u1200-\\u137F\\s]+', '', rm_num_and_ascii)

class EnglishTextPreprocessor(TextPreprocessingPipeline):
    """Cleans English text and optionally encodes it with ``self.tokenizer``.

    ``preprocess``: lowercase -> expand abbreviations -> punctuation to spaces ->
    keep only ASCII letters and whitespace (runs of whitespace collapsed).
    ``remove_stopwords`` and ``lemmatize`` are helpers that ``preprocess`` does not apply.
    """

    # Lowercase abbreviation -> expansion (preprocess lowercases before expanding).
    _ABBREVIATIONS = {
        "i.e.": "that is",
        "e.g.": "for example",
        "etc.": "and so on",
        "mr.": "mister",
        "mrs.": "missus",
        "dr.": "doctor",
        "st.": "saint",
        "ave.": "avenue",
        "apt.": "apartment",
        "dept.": "department",
        "univ.": "university",
        "prof.": "professor",
        "jr.": "junior",
        "sr.": "senior",
        "co.": "company",
        "corp.": "corporation",
        "inc.": "incorporated",
        "est.": "established",
        "jan.": "january",
        "feb.": "february",
        "mar.": "march",
        "apr.": "april",
        "jun.": "june",
        "jul.": "july",
        "aug.": "august",
        "sep.": "september",
        "oct.": "october",
        "nov.": "november",
        "dec.": "december",
    }

    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__(tokenizer)
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()

    def preprocess(self, text: str, encode=True) -> str:
        """Clean ``text``; return ``tokenizer.encode(...).ids`` if ``encode`` else the cleaned string."""
        text = text.lower()
        text = self.replace_english_abbreviations(text)
        text = self.remove_punctuation_and_special_chars(text)
        text = self.remove_non_english_chars_and_numbers(text)

        if encode:
            return self.tokenizer.encode(text).ids
        return text

    def remove_stopwords(self, words):
        """Return ``words`` without NLTK English stopwords."""
        return [word for word in words if word not in self.stop_words]

    def lemmatize(self, words):
        """Return the WordNet lemma of each word in ``words``."""
        return [self.lemmatizer.lemmatize(word) for word in words]

    def replace_english_abbreviations(self, text: str) -> str:
        """Expand the listed abbreviations when they appear as standalone tokens.

        Lookarounds are used instead of ``\\b``. Every key ends in '.', and a ``\\b``
        after '.' only matches when a word character follows, so "dr. smith"
        was never expanded.
        """
        for abbreviation, expansion in self._ABBREVIATIONS.items():
            pattern = rf'(?<!\w){re.escape(abbreviation)}(?!\w)'
            text = re.sub(pattern, expansion, text)
        return text

    def remove_non_english_chars_and_numbers(self, text: str) -> str:
        """Replace everything except ASCII letters and whitespace (digits included) with a space, then collapse space runs."""
        text = re.sub(r'[^a-zA-Z\s]', ' ', text)
        return re.sub(r'\s{2,}', ' ', text)

    def remove_punctuation_and_special_chars(self, text: str) -> str:
        """Replace runs of punctuation/symbols with a single space."""
        return re.sub(r'[!@#$%^&*()…\[\]{};:"›’‘"’\/<>?|`´~\\=\+፡;]+', ' ', text)

def evaluate_model(model, val_loader, criterion, device):
    """Average ``criterion`` over ``val_loader`` without tracking gradients.

    Args:
        model: module called as ``model(inputs)``; switched to eval mode and left there.
        val_loader: sized iterable yielding ``(inputs, targets)`` pairs.
        criterion: loss called as ``criterion(outputs, targets)``.
        device: device both tensors are moved to before the forward pass.

    Returns:
        float: sum of the per-batch ``loss.item()`` values divided by ``len(val_loader)``.
    """
    model.eval()
    per_batch_losses = []
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            per_batch_losses.append(criterion(model(inputs), targets).item())
    return sum(per_batch_losses) / len(val_loader)

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, model_save_path, log_dir=None):
    """Train ``model`` for ``num_epochs``, checkpointing whenever validation loss improves.

    Args:
        model: module called as ``model(inputs)``.
        train_loader / val_loader: sized iterables yielding ``(inputs, targets)`` pairs.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: device every batch is moved to.
        model_save_path: where ``save_model`` writes the best model.
        log_dir: TensorBoard log directory; ``None`` keeps ``SummaryWriter``'s default.

    Returns:
        float: best validation loss seen (``inf`` if no epoch ran).
    """
    writer = SummaryWriter(log_dir=log_dir)
    best_val_loss = float('inf')

    # try/finally so the event file is flushed and closed even if training is interrupted.
    try:
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0

            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            avg_train_loss = running_loss / len(train_loader)
            val_loss = evaluate_model(model, val_loader, criterion, device)

            writer.add_scalars('Loss', {'train': avg_train_loss, 'val': val_loss}, epoch)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_model(model, model_save_path)

            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}')
    finally:
        writer.close()

    return best_val_loss

def save_model(model, path):
    """Pickle ``model`` to ``path``, creating missing parent directories first.

    NOTE(review): pickling the whole module ties the file to this notebook's class
    definitions; ``torch.save(model.state_dict(), path)`` is more portable. Only
    unpickle files you trust (pickle can execute arbitrary code on load).
    """
    # MODEL_DIR may not exist yet on a fresh Drive; open() would raise FileNotFoundError.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(model, f)
    print(f'Model saved to {path}')


DEFINE DATA PREPROCESSING PIPELINE

In [None]:
class PreprocessingPipeline(ABC):
    """Common base for ``EnglishPreprocessor`` / ``AmharicPreprocessor``.

    Keeps the tokenizer used to encode cleaned text and offers NLTK word splitting.
    """

    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    def tokenize(self, text):
        """Split ``text`` into words using NLTK ``word_tokenize``."""
        return word_tokenize(text)

    @abstractmethod
    def preprocess(self, text: str, encode=True) -> str:
        """Return cleaned ``text`` (the subclasses here return token ids instead when ``encode`` is true)."""


class AmharicPreprocessor(PreprocessingPipeline):
    """Cleans Amharic text and optionally encodes it with ``self.tokenizer``.

    ``preprocess`` runs, in order:
      1. abbreviation expansion on the text as written. Keys such as ሃ/አለቃ,
         ጽ/ቤት and ዓ.ዓ contain letters that step 2 rewrites; after normalisation
         ዓ.ዓ would even read as the Addis Ababa key አ.አ. The expansions inserted
         here are normalised by step 2 too.
      2. character-level normalisation (e.g. ጸሀይ / ፀሐይ, labialised spellings);
      3. ``normalize_abbreviations`` again, which catches variant spellings of
         keys that normalise to the listed form (e.g. ክ/ሃገር -> ክ/ሀገር). It also
         replaces ASCII punctuation, ።/፤, Ethiopic numerals and digits with spaces;
      4. remaining punctuation, including Ethiopic ፡ ፣ ፥ ፦ ፧ ፨ ፠, -> space
         (not '', which glued the neighbouring words together);
      5. removal of everything outside U+1200-U+137F except whitespace.
    Whitespace runs are then collapsed and the result is stripped.
    """

    # Abbreviation -> expansion, applied in insertion order as whole words.
    _ABBREVIATIONS = {
        "ት/ቤት": "ትምህርት ቤት",
        "ት/ርት": "ትምህርት",
        "ት/ክፍል": "ትምህርት ክፍል",
        "ሃ/አለቃ": "ሃምሳ አለቃ",
        "ሃ/ስላሴ": "ሃይለ ስላሴ",
        "ደ/ዘይት": "ደብረ ዘይት",
        "ደ/ታቦር": "ደብረ ታቦር",
        "መ/ር": "መምህር",
        "መ/ቤት": "መስሪያ ቤት",
        "መ/አለቃ": "መቶ አለቃ",
        "ክ/ከተማ": "ክፍለ ከተማ",
        "ክ/ሀገር": "ክፍለ ሀገር",
        "ወ/ር": "",
        "ወ/ሮ": "ወይዘሮ",
        "ወ/ሪት": "ወይዘሪት",
        "ወ/ስላሴ": "ወልደ ስላሴ",
        "ፍ/ስላሴ": "ፍቅረ ስላሴ",
        "ፍ/ቤት": "ፍርድ ቤት",
        "ጽ/ቤት": "ጽህፈት ቤት",
        "ሲ/ር": "",
        "ፕ/ር": "ፕሮፌሰር",
        "ጠ/ሚንስትር": "ጠቅላይ ሚኒስተር",
        "ጠ/ሚ": "ጠቅላይ ሚኒስተር",
        "ዶ/ር": "ዶክተር",
        "ገ/ገዮርጊስ": "ገብረ ገዮርጊስ",
        "ቤ/ክርስትያን": "ቤተ ክርስትያን",
        "ም/ስራ": "",
        "ም/ቤት": "ምክር ቤት",  # was "ምክር ቤተ" (typo)
        "ተ/ሃይማኖት": "ተክለ ሃይማኖት",
        "ሚ/ር": "ሚኒስትር",
        "ኮ/ል": "ኮሎኔል",
        "ሜ/ጀነራል": "ሜጀር ጀነራል",
        "ብ/ጀነራል": "ብርጋደር ጀነራል",
        "ሌ/ኮለኔል": "ሌተናንት ኮለኔል",
        "ሊ/መንበር": "ሊቀ መንበር",
        "አ/አ": "ኣዲስ ኣበባ",
        "አ.አ": "ኣዲስ ኣበባ",
        "ር/መምህር": "ርዕሰ መምህር",
        "ፕ/ት": "",
        "ዓም": "ዓመተ ምህረት",
        "ዓ.ዓ": "ዓመተ ዓለም",
    }

    # (pattern, replacement) pairs applied sequentially; order matters
    # (e.g. ዓ/ኣ/ዐ are folded to አ before the labialisation rules look for [ዋአ]).
    _CHAR_RULES = (
        ('[ሃኅኃሐሓኻ]', 'ሀ'),
        ('[ሑኁዅ]', 'ሁ'),
        ('[ኂሒኺ]', 'ሂ'),
        ('[ኌሔዄ]', 'ሄ'),
        ('[ሕኅ]', 'ህ'),
        ('[ኆሖኾ]', 'ሆ'),
        ('[ሠ]', 'ሰ'),
        ('[ሡ]', 'ሱ'),
        ('[ሢ]', 'ሲ'),
        ('[ሣ]', 'ሳ'),
        ('[ሤ]', 'ሴ'),
        ('[ሥ]', 'ስ'),
        ('[ሦ]', 'ሶ'),
        ('[ዓኣዐ]', 'አ'),
        ('[ዑ]', 'ኡ'),
        ('[ዒ]', 'ኢ'),
        ('[ዔ]', 'ኤ'),
        ('[ዕ]', 'እ'),
        ('[ዖ]', 'ኦ'),
        ('[ጸ]', 'ፀ'),
        ('[ጹ]', 'ፁ'),
        ('[ጺ]', 'ፂ'),
        ('[ጻ]', 'ፃ'),
        ('[ጼ]', 'ፄ'),
        ('[ጽ]', 'ፅ'),
        ('[ጾ]', 'ፆ'),
        # Labialised forms such as በልቱዋል or በልቱአል -> በልቷል
        ('(ሉ[ዋአ])', 'ሏ'),
        ('(ሙ[ዋአ])', 'ሟ'),
        ('(ቱ[ዋአ])', 'ቷ'),
        ('(ሩ[ዋአ])', 'ሯ'),
        ('(ሱ[ዋአ])', 'ሷ'),
        ('(ሹ[ዋአ])', 'ሿ'),
        ('(ቁ[ዋአ])', 'ቋ'),
        ('(ቡ[ዋአ])', 'ቧ'),
        ('(ቹ[ዋአ])', 'ቿ'),
        ('(ሁ[ዋአ])', 'ኋ'),
        ('(ኑ[ዋአ])', 'ኗ'),
        ('(ኙ[ዋአ])', 'ኟ'),
        ('(ኩ[ዋአ])', 'ኳ'),
        ('(ዙ[ዋአ])', 'ዟ'),
        ('(ጉ[ዋአ])', 'ጓ'),
        ('(ዱ[ዋአ])', 'ዷ'),  # was 'ደ', which rewrote any ደዋ/ደአ sequence
        ('(ጡ[ዋአ])', 'ጧ'),
        ('(ጩ[ዋአ])', 'ጯ'),
        ('(ጹ[ዋአ])', 'ጿ'),  # NOTE(review): never fires, ጹ is already folded to ፁ above
        ('(ፉ[ዋአ])', 'ፏ'),
        ('[ቊ]', 'ቁ'),  # ቁ can be written as ቊ
        ('[ኵ]', 'ኩ'),  # ኩ can also be written as ኵ
    )

    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__(tokenizer)

    def preprocess(self, text: str, encode=True) -> str:
        """Clean ``text``; return ``tokenizer.encode(...).ids`` if ``encode`` else the cleaned string."""
        text = self._expand_abbreviations(text)
        text = self.normalize_char_level_missmatch(text)
        text = self.normalize_abbreviations(text)
        text = self.remove_punc_and_special_chars(text)
        text = self.remove_ascii_and_numbers(text)
        text = re.sub(r'\s+', ' ', text).strip()

        if encode:
            return self.tokenizer.encode(text).ids
        return text

    def _expand_abbreviations(self, text: str) -> str:
        """Replace each listed abbreviation, matched as a whole word, with its expansion."""
        for abbreviation, expansion in self._ABBREVIATIONS.items():
            text = re.sub(rf'\b{re.escape(abbreviation)}\b', expansion, text)
        return text

    def normalize_abbreviations(self, text: str) -> str:
        """Expand abbreviations, then turn punctuation, Ethiopic numerals (፩-፻) and digits into single spaces."""
        text = self._expand_abbreviations(text)
        # Full numeral list: the previous class repeated ፮ and ፵, so ፯ and ፶ were never removed.
        text = re.sub(r'[.\?"\',/#!$%^&*;:፤።{}=\-_`~()፩፪፫፬፭፮፯፰፱፲፳፴፵፶፷፸፹፺፻0-9]', ' ', text)
        return re.sub(r'\s{2,}', ' ', text)

    def normalize_char_level_missmatch(self, text: str) -> str:
        """Fold alternative spellings of the same sound (e.g. ጸሀይ and ፀሐይ) onto one form."""
        for pattern, replacement in self._CHAR_RULES:
            text = re.sub(pattern, replacement, text)
        return text

    def remove_punc_and_special_chars(self, text: str) -> str:  # punct in amh = ፡።፤;፦፧፨፠፣
        """Replace punctuation and special characters with a space."""
        return re.sub(r'[!@#$%^«»&*()…\[\]{};“”›’‘"\':,.‹/<>?\\|`´~\-=+፡።፤;፦፥፧፨፠፣]+', ' ', text)

    def remove_ascii_and_numbers(self, text: str) -> str:
        """Drop ASCII letters/digits, then anything outside U+1200-U+137F that is not whitespace."""
        rm_num_and_ascii = re.sub(r'[A-Za-z0-9]', '', text)
        return re.sub(r'[^\u1200-\u137F\s]+', '', rm_num_and_ascii)


class EnglishPreprocessor(PreprocessingPipeline):
    """Cleans English (source-side) text and optionally encodes it with ``self.tokenizer``.

    ``preprocess``: lowercase -> expand abbreviations -> punctuation to spaces ->
    keep only ASCII letters and whitespace (runs of whitespace collapsed).
    ``tokenize``, ``remove_stopwords`` and ``lemmatize`` are available but are
    deliberately not applied by ``preprocess``.
    """

    # Lowercase abbreviation -> expansion (preprocess lowercases before expanding).
    _ABBREVIATIONS = {
        "i.e.": "that is",
        "e.g.": "for example",
        "etc.": "and so on",
        "mr.": "mister",
        "mrs.": "missus",
        "dr.": "doctor",
        "st.": "saint",
        "ave.": "avenue",
        "apt.": "apartment",
        "dept.": "department",
        "univ.": "university",
        "prof.": "professor",
        "jr.": "junior",
        "sr.": "senior",
        "co.": "company",
        "corp.": "corporation",
        "inc.": "incorporated",
        "est.": "established",
        "jan.": "january",
        "feb.": "february",
        "mar.": "march",
        "apr.": "april",
        "jun.": "june",
        "jul.": "july",
        "aug.": "august",
        "sep.": "september",
        "oct.": "october",
        "nov.": "november",
        "dec.": "december",
    }

    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__(tokenizer)
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()

    def preprocess(self, text: str, encode=True) -> str:
        """Clean ``text``; return ``tokenizer.encode(...).ids`` if ``encode`` else the cleaned string."""
        text = text.lower()
        text = self.normalize_english_abbreviations(text)
        text = self.remove_punc_and_special_chars(text)
        text = self.remove_non_english_and_numbers(text)

        if encode:
            return self.tokenizer.encode(text).ids
        return text

    def remove_stopwords(self, words):
        """Return ``words`` without NLTK English stopwords."""
        return [word for word in words if word not in self.stop_words]

    def lemmatize(self, words):
        """Return the WordNet lemma of each word in ``words``."""
        return [self.lemmatizer.lemmatize(word) for word in words]

    def normalize_english_abbreviations(self, text: str) -> str:
        """Expand the listed abbreviations when they appear as standalone tokens.

        Lookarounds are used instead of ``\\b``. Every key ends in '.', and a ``\\b``
        after '.' only matches when a word character follows, so "dr. smith"
        was never expanded.
        """
        for abbreviation, expansion in self._ABBREVIATIONS.items():
            pattern = rf'(?<!\w){re.escape(abbreviation)}(?!\w)'
            text = re.sub(pattern, expansion, text)
        return text

    def remove_non_english_and_numbers(self, text: str) -> str:
        """Replace everything except ASCII letters and whitespace (digits included) with a space, then collapse space runs."""
        text = re.sub(r'[^a-zA-Z\s]', ' ', text)
        return re.sub(r'\s{2,}', ' ', text)

    def remove_punc_and_special_chars(self, text: str) -> str:
        """Replace each punctuation/special character with a space (raw string: no invalid-escape warnings)."""
        return re.sub(r'[!@#$%^&*()…\[\]{};“”›’‘"\':,.‹/<>?\\|`´~\-=+፡;]', ' ', text)

DATASET PREPARATION

In [None]:
class ParallelTextDataset(Dataset):
    """Parallel corpus turned into fixed-length (``SEQ_LEN``) tensors for a seq2seq Transformer.

    Each record of ``dataset`` is a dict whose ``SRC_LANG`` / ``TGT_LANG`` entries hold
    the source and target sentences. Sentences longer than the available room are
    truncated: SEQ_LEN - 2 source ids (SOS + EOS added), SEQ_LEN - 1 target ids
    (SOS or EOS added).
    """

    def __init__(self, dataset: list[dict], src_tokenizer: Tokenizer, tgt_tokenizer: Tokenizer) -> None:
        super().__init__()
        self.dataset = dataset

        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer

        self.src_preprocessor = EnglishPreprocessor(src_tokenizer)
        self.tgt_preprocessor = AmharicPreprocessor(tgt_tokenizer)

        # Source-side special tokens, each shape (1,) (original attribute names kept).
        self.sos_token = self._special_token(src_tokenizer, "[SOS]")
        self.eos_token = self._special_token(src_tokenizer, "[EOS]")
        self.pad_token = self._special_token(src_tokenizer, "[PAD]")

        # Target-side special tokens: the two tokenizers are separate objects, so
        # their ids need not coincide with the source ones.
        self.tgt_sos_token = self._special_token(tgt_tokenizer, "[SOS]")
        self.tgt_eos_token = self._special_token(tgt_tokenizer, "[EOS]")
        self.tgt_pad_token = self._special_token(tgt_tokenizer, "[PAD]")

    @staticmethod
    def _special_token(tokenizer: Tokenizer, token: str) -> torch.Tensor:
        """Return ``token``'s id as an int64 tensor of shape (1,); raise if the vocab lacks it."""
        token_id = tokenizer.token_to_id(token)
        if token_id is None:
            raise ValueError(f"Tokenizer vocabulary has no {token} token")
        return torch.tensor([token_id], dtype=torch.int64)

    def __len__(self):
        return len(self.dataset)

    def batch_iterator(self, batch_size: int) -> DataLoader:
        """Return a shuffling DataLoader over this dataset."""
        return DataLoader(self, batch_size, shuffle=True)

    @staticmethod
    def lookback_mask(size: int) -> torch.Tensor:
        """Causal mask: bool tensor (1, size, size), True on and below the diagonal.

        [[
          [1, 0, ... , 0],
          [1, 1, ... , 0],
          [1, 1, ... , 0],
          [1, 1, ... , 1]
        ]]
        """
        return torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int) == 0

    def __getitem__(self, index) -> dict:
        src_tgt_pair = self.dataset[index]
        src_text = src_tgt_pair[SRC_LANG]
        tgt_text = src_tgt_pair[TGT_LANG]

        # Truncate so padding never goes negative (which previously produced
        # sequences longer than SEQ_LEN and mismatched masks).
        src_token_ids = list(self.src_preprocessor.preprocess(src_text))[: SEQ_LEN - 2]
        tgt_token_ids = list(self.tgt_preprocessor.preprocess(tgt_text))[: SEQ_LEN - 1]

        src_padding = SEQ_LEN - len(src_token_ids) - 2
        tgt_padding = SEQ_LEN - len(tgt_token_ids) - 1

        # (seq_len,): [SOS] src [EOS] [PAD]...
        encoder_input = torch.cat([
            self.sos_token,
            torch.tensor(src_token_ids, dtype=torch.int64),
            self.eos_token,
            self.pad_token.repeat(src_padding),
        ])

        # (seq_len,): [SOS] tgt [PAD]...
        decoder_input = torch.cat([
            self.tgt_sos_token,
            torch.tensor(tgt_token_ids, dtype=torch.int64),
            self.tgt_pad_token.repeat(tgt_padding),
        ])

        # (seq_len,): tgt [EOS] [PAD]... (decoder_input shifted left by one)
        label = torch.cat([
            torch.tensor(tgt_token_ids, dtype=torch.int64),
            self.tgt_eos_token,
            self.tgt_pad_token.repeat(tgt_padding),
        ])

        return {
            # (seq_len,)
            "encoder_input": encoder_input,

            # (seq_len,)
            "decoder_input": decoder_input,

            # (seq_len,) --> (1, 1, seq_len); 1 where not padding
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),

            # (1, seq_len) & (1, seq_len, seq_len) --> (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.tgt_pad_token).unsqueeze(0).int() & self.lookback_mask(SEQ_LEN),

            # (seq_len,)
            "label": label,

            "src_text": src_text,
            "tgt_text": tgt_text
        }

DEFINE TRANSFORMER MODEL

In [None]:
class WordEmbedding(nn.Module):
    """Token-id lookup table whose output is scaled by ``sqrt(d_model)``.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding width. Defaults to the notebook-level ``D_MODEL`` so
            existing ``WordEmbedding(vocab_size)`` calls behave exactly as before.
    """

    def __init__(self, vocab_size: int, d_model: int = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = D_MODEL if d_model is None else d_model
        self.embedding: nn.Embedding = nn.Embedding(vocab_size, self.d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed integer token ids.

        Args:
            x (torch.Tensor): integer ids, typically (batches, seq_len).

        Returns:
            torch.Tensor: shape ``x.shape + (d_model,)``, e.g. (batches, seq_len, d_model).
        """
        # Call the module rather than `.forward` so registered hooks still run.
        return self.embedding(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings, then applies dropout.

        PE(pos, 2i)     = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))

    Args:
        seq_len: maximum supported sequence length (default: notebook ``SEQ_LEN``).
        d_model: model width (default: notebook ``D_MODEL``); odd widths are supported.
        dropout: dropout probability (default: notebook ``DROPOUT``).
    """

    def __init__(self, seq_len: int = None, d_model: int = None, dropout: float = None) -> None:
        super().__init__()
        self.seq_len = SEQ_LEN if seq_len is None else seq_len
        self.d_model = D_MODEL if d_model is None else d_model

        # Randomly zeroes elements to combat overfitting.
        self.dropout = nn.Dropout(DROPOUT if dropout is None else dropout)

        pe = torch.zeros(self.seq_len, self.d_model)                              # (seq_len, d_model)
        pos = torch.arange(self.seq_len, dtype=torch.float).unsqueeze(1)          # (seq_len, 1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float) * -(math.log(10000.0) / self.d_model)
        )                                                                         # (ceil(d_model / 2),)

        pe[:, 0::2] = torch.sin(pos * div_term)
        # Odd columns number floor(d_model / 2): trim div_term so odd d_model works.
        pe[:, 1::2] = torch.cos(pos * div_term[: self.d_model // 2])

        # (1, seq_len, d_model); a buffer, so it moves with the module and is never trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings to ``x``.

        Args:
            x (torch.Tensor): (batches, seq_len, d_model) with seq_len <= self.seq_len.

        Returns:
            torch.Tensor: (batches, seq_len, d_model)
        """
        assert x.shape[1] <= self.seq_len, f"Input sequence length exceeds the position encoder's max sequence length  `{self.seq_len}`"
        return self.dropout(x + self.pe[:, :x.shape[1], :])


class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward sublayer: Linear(D_MODEL->DFF) -> ReLU -> Dropout -> Linear(DFF->D_MODEL)."""

    def __init__(self) -> None:
        super().__init__()
        # Construction order kept (linear_1, dropout, linear_2) so seeded init is unchanged.
        self.linear_1 = nn.Linear(D_MODEL, DFF).to(DEVICE)
        self.dropout = nn.Dropout(DROPOUT)
        self.linear_2 = nn.Linear(DFF, D_MODEL).to(DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batches, seq_len, d_model) -> (batches, seq_len, dff) -> (batches, seq_len, d_model)."""
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)


class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with HEADS heads of width D_MODEL // HEADS.

    The most recent attention weights are kept in ``self.attention_scores``.
    """

    def __init__(self) -> None:
        assert D_MODEL % HEADS == 0, "d_model is not divisible by heads"

        super().__init__()

        self.d_k = D_MODEL // HEADS

        # Bias-free projections, created in the same order as before (W_q, W_k, W_v, W_o).
        self.W_q = nn.Linear(D_MODEL, D_MODEL, bias=False).to(DEVICE)
        self.W_k = nn.Linear(D_MODEL, D_MODEL, bias=False).to(DEVICE)
        self.W_v = nn.Linear(D_MODEL, D_MODEL, bias=False).to(DEVICE)
        self.W_o = nn.Linear(D_MODEL, D_MODEL, bias=False).to(DEVICE)
        self.dropout = nn.Dropout(DROPOUT)

    @staticmethod
    def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dropout: nn.Dropout=None, mask: torch.Tensor=None) -> tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention.

        Args:
            query, key, value (torch.Tensor): (batches, heads, seq_len, d_k)
            dropout (nn.Dropout): optional dropout applied to the attention weights.
            mask (torch.Tensor): positions where ``mask == 0`` are blocked
                (padding and/or future tokens).

        Returns:
            (output, weights): (batches, heads, seq_len, d_k) and (batches, heads, seq_len, seq_len).
        """
        scale = math.sqrt(query.shape[-1])
        scores = (query @ key.transpose(-2, -1)) / scale

        if mask is not None:
            # A large negative score becomes ~0 probability after softmax.
            scores = scores.masked_fill(mask == 0, -1e09)

        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        return weights @ value, weights

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        """(batches, seq_len, d_model) -> (batches, heads, seq_len, d_k)."""
        batches, length = projected.shape[0], projected.shape[1]
        return projected.view(batches, length, HEADS, self.d_k).transpose(1, 2)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Attend from ``q`` over ``k``/``v`` (each (batches, seq_len, d_model)); returns (batches, seq_len, d_model)."""
        query = self._split_heads(self.W_q(q))
        key = self._split_heads(self.W_k(k))
        value = self._split_heads(self.W_v(v))

        heads_out, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, self.dropout, mask)

        # (batches, heads, seq_len, d_k) -> (batches, seq_len, heads * d_k)
        merged = heads_out.transpose(1, 2).contiguous()
        merged = merged.view(merged.shape[0], -1, HEADS * self.d_k)

        return self.W_o(merged)


class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    The sublayer is any callable mapping a tensor to a tensor of the same shape
    (an attention closure or a feed-forward module).
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names/order define the state_dict keys; keep them stable for checkpoints.
        self.dropout = nn.Dropout(DROPOUT)
        self.norm = nn.LayerNorm(D_MODEL, device=DEVICE)

    def forward(self, x: torch.Tensor, sublayer: nn.Module) -> torch.Tensor:
        normalized = self.norm(x)
        update = self.dropout(sublayer(normalized))
        return x + update


class EncoderBlock(nn.Module):
    """One encoder layer: residual self-attention followed by a residual feed-forward block."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # [0] wraps self-attention, [1] wraps the feed-forward block
        self.residual_connections = nn.ModuleList([ResidualConnection(), ResidualConnection()])

    def forward(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        attention_residual, feed_forward_residual = self.residual_connections
        attended = attention_residual(x, lambda h: self.self_attention_block(h, h, h, src_mask))
        return feed_forward_residual(attended, self.feed_forward_block)


class Encoder(nn.Module):
    """Stack of encoder blocks applied in order, followed by a final LayerNorm."""

    def __init__(self, encoder_blocks: nn.ModuleList) -> None:
        super().__init__()
        self.encoder_blocks = encoder_blocks
        self.norm = nn.LayerNorm(D_MODEL, device=DEVICE)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        hidden = x
        for encoder_block in self.encoder_blocks:
            hidden = encoder_block(hidden, mask)
        return self.norm(hidden)


class DecoderBlock(nn.Module):
    """One decoder layer: residual masked self-attention, residual cross-attention over the
    encoder output, then a residual feed-forward block."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # NOTE(review): not referenced by forward(); dropout is applied inside ResidualConnection
        self.dropout = nn.Dropout(DROPOUT)
        self.residual_connections = nn.ModuleList([ResidualConnection(), ResidualConnection(), ResidualConnection()])

    # src_mask masks the encoder (source-language) positions, tgt_mask the decoder (target-language) positions
    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
        self_attn_res, cross_attn_res, feed_forward_res = self.residual_connections
        hidden = self_attn_res(x, lambda h: self.self_attention_block(h, h, h, tgt_mask))
        hidden = cross_attn_res(hidden, lambda h: self.cross_attention_block(h, encoder_output, encoder_output, src_mask))
        return feed_forward_res(hidden, self.feed_forward_block)


class Decoder(nn.Module):
    """Stack of decoder blocks applied in order, followed by a final LayerNorm."""

    def __init__(self, decoder_blocks: nn.ModuleList) -> None:
        super().__init__()
        self.decoder_blocks = decoder_blocks
        self.norm = nn.LayerNorm(D_MODEL, device=DEVICE)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
        hidden = x
        for decoder_block in self.decoder_blocks:
            hidden = decoder_block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)


class ProjectionLayer(nn.Module):
    """Linear map from the model dimension to raw (unnormalised) target-vocabulary logits."""

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(D_MODEL, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batches, seq_len, d_model) --> (batches, seq_len, vocab_size)
        logits = self.proj(x)
        return logits

class MtTransformerModel(nn.Module):
    """Encoder-decoder transformer for machine translation.

    Combines source/target embeddings and positional encodings, an encoder stack, a decoder
    stack and a projection onto the target vocabulary. ``encode``/``decode``/``project`` are
    exposed separately so training and greedy decoding can drive them step by step.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: WordEmbedding, tgt_embed: WordEmbedding, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        # Attribute names and registration order determine state_dict keys; keep them stable.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Embed and position-encode source token ids, then run the encoder stack."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
        """Embed and position-encode target token ids, then run the decoder stack against ``encoder_output``."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x: torch.Tensor):
        """Map decoder states to target-vocabulary logits."""
        return self.projection_layer(x)

    @staticmethod
    def build(
        src_vocab_size: int,
        tgt_vocab_size: int
    ):
        """Assemble a model with N_BLOCKS encoder and decoder blocks and Xavier-initialised weights."""
        # Sub-modules are created in the same order as the reference construction so any
        # RNG-dependent default initialisation (e.g. biases, which are not re-initialised below)
        # is reproduced exactly.
        src_embed, tgt_embed = WordEmbedding(src_vocab_size), WordEmbedding(tgt_vocab_size)
        src_pos, tgt_pos = PositionalEncoding(), PositionalEncoding()

        encoder_blocks = [
            EncoderBlock(MultiHeadAttentionBlock(), FeedForwardBlock())
            for _ in range(N_BLOCKS)
        ]
        decoder_blocks = [
            DecoderBlock(MultiHeadAttentionBlock(), MultiHeadAttentionBlock(), FeedForwardBlock())
            for _ in range(N_BLOCKS)
        ]

        transformer = MtTransformerModel(
            Encoder(nn.ModuleList(encoder_blocks)),
            Decoder(nn.ModuleList(decoder_blocks)),
            src_embed,
            tgt_embed,
            src_pos,
            tgt_pos,
            ProjectionLayer(tgt_vocab_size),
        )

        # Xavier-initialise every parameter with more than one dimension (weight matrices, embedding tables)
        for weight in (param for param in transformer.parameters() if param.dim() > 1):
            nn.init.xavier_uniform_(weight)

        return transformer

DEFINE THE INFERENCE ENGINE


In [None]:
class MtInferenceEngine:
    """Translate English text with a trained ``MtTransformerModel`` using greedy decoding.

    ``top_k`` and ``nucleus_threshold`` are stored on the instance but are not used by the
    greedy decoder implemented in ``translate_raw``.
    """

    def __init__(self, model: MtTransformerModel, src_tokenizer: Tokenizer, tgt_tokenizer: Tokenizer, top_k: int= 5, nucleus_threshold=10) -> None:
        self.model = model
        self.top_k = top_k
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.nucleus_threshold = nucleus_threshold
        # NOTE(review): these ids come from the *source* tokenizer and are not used by
        # translate_raw, which looks up [SOS]/[EOS] in the target tokenizer.
        self.sos_token = torch.tensor([self.src_tokenizer.token_to_id("[SOS]")], dtype=torch.int64, device=DEVICE)  # (1,)
        self.eos_token = torch.tensor([self.src_tokenizer.token_to_id("[EOS]")], dtype=torch.int64, device=DEVICE)  # (1,)
        self.pad_token = torch.tensor([self.src_tokenizer.token_to_id("[PAD]")], dtype=torch.int64, device=DEVICE)  # (1,)
        self.model.eval()

    def translate(self, source_text: str, max_len: int) -> str:
        """Translate ``source_text`` and return the decoded target-language string.

        Args:
            source_text: English input sentence.
            max_len: maximum decoded length, counting the leading [SOS] token.

        Returns:
            The string produced by ``translate_raw``.
        """
        # Run the input through the same ParallelTextDataset pipeline used for training.
        dataset = ParallelTextDataset(
            dataset=[{"en": source_text, "am":"" }],
            src_tokenizer=self.src_tokenizer,
            tgt_tokenizer=self.tgt_tokenizer
        )
        batch_iterator = iter(dataset.batch_iterator(1))
        batch = next(batch_iterator)

        encoder_input = batch["encoder_input"].to(DEVICE)       # (1, seq_len)
        encoder_mask = batch["encoder_mask"].to(DEVICE)         # (1, 1, 1, seq_len)
        decoder_mask = batch["decoder_mask"].to(DEVICE)         # (1, 1, seq_len, seq_len)

        return self.translate_raw(encoder_input, encoder_mask, decoder_mask, max_len)

    @torch.no_grad()
    def translate_raw(self, encoder_input: torch.Tensor, encoder_mask: torch.Tensor, decoder_mask: torch.Tensor, max_len: int) -> str:
        """Greedy autoregressive decoding with ``self.model``.

        Args:
            encoder_input: source token ids, expected shape (1, seq_len).
            encoder_mask: source mask as produced by ParallelTextDataset.
            decoder_mask: accepted for interface compatibility; a causal mask for the
                current prefix is rebuilt at every step instead.
            max_len: maximum decoded length, counting the leading [SOS] token.

        Returns:
            The target tokenizer's decoding of the generated ids. Decoding stops after
            [EOS] is emitted or once ``max_len`` ids have been produced.
        """
        sos_idx = self.tgt_tokenizer.token_to_id('[SOS]')
        eos_idx = self.tgt_tokenizer.token_to_id('[EOS]')

        # Precompute the encoder output once and reuse it for every decoding step
        encoder_output = self.model.encode(encoder_input, encoder_mask)

        # Initialize the decoder input with the sos token
        decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(encoder_input).to(DEVICE)

        # The EOS check lives inside the loop: the next token only exists after a decoding step.
        while decoder_input.size(1) < max_len:
            # Causal mask for the current prefix length
            step_mask = ParallelTextDataset.lookback_mask(decoder_input.size(1)).type_as(encoder_mask).to(DEVICE)

            decoder_out = self.model.decode(encoder_output, encoder_mask, decoder_input, step_mask)   # (1, seq_len, d_model)

            # Hidden state of the last position only
            last_token_vec = decoder_out[:, -1]                         # (1, d_model)
            last_token_logits = self.model.project(last_token_vec)      # (1, d_model) --> (1, tgt_vocab_size)
            last_token_prob = torch.softmax(last_token_logits, dim=1)

            # Greedily pick the one with the highest probability
            _, next_token = torch.max(last_token_prob, dim=1)

            # Append to the decoder input for the subsequent iterations
            decoder_input = torch.cat([
                decoder_input,
                torch.empty(1, 1).type_as(encoder_input).fill_(next_token.item()).to(DEVICE)
            ], dim=1)

            if next_token.item() == eos_idx:
                break

        # Remove the batch dimension
        decoder_input = decoder_input.squeeze(0)                                    # (generated_len,)
        return self.tgt_tokenizer.decode(decoder_input.detach().cpu().tolist())

START TRAINING THE MODEL

In [None]:
TOKENIZER_FOLDER = f"{PROJECT_DIR}/tokenizers"
TOKENIZER_BASENAME = "tokenizer-{0}-v3.5-12k.json"
import os


def get_tokenizer(lang: str, basename: str = TOKENIZER_BASENAME) -> Tokenizer:
    """Load the saved word-level tokenizer for ``lang`` from TOKENIZER_FOLDER.

    Args:
        lang: language code substituted into ``basename`` (e.g. "en", "am").
        basename: file-name template with a ``{0}`` placeholder for the language.

    Returns:
        The tokenizer, with truncation enabled at ``SEQ_LEN - 2`` tokens.

    Raises:
        FileNotFoundError: if the tokenizer file does not exist.
    """
    tokenizer_path = os.path.join(TOKENIZER_FOLDER, basename.format(lang))
    if not os.path.exists(tokenizer_path):
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")

    loaded = Tokenizer.from_file(tokenizer_path)
    # Two positions are left free (presumably for the [SOS]/[EOS] markers added by the dataset)
    loaded.enable_truncation(max_length=SEQ_LEN - 2)
    return loaded

def get_dataset(
    crop_ratio: float = 0.10,
    val_ratio: float = 0.10,
    test_ratio: float = 0.10,
) -> tuple[ParallelTextDataset, ParallelTextDataset, ParallelTextDataset]:
    """Load the parallel corpus, sample a random subset and split it into disjoint splits.

    Previously the train, validation and test datasets were all built from the same list,
    so the validation and test losses were computed on training data.

    Args:
        crop_ratio: fraction of the full corpus to keep (random sample), in (0, 1].
        val_ratio: fraction of the kept subset used for validation.
        test_ratio: fraction of the kept subset used for testing.

    Returns:
        (train_dataset, val_dataset, test_dataset) sharing the same tokenizers.

    Raises:
        ValueError: if the ratios are out of range or leave no room for training data.
    """
    if not 0 < crop_ratio <= 1:
        raise ValueError(f"crop_ratio must be in (0, 1], got {crop_ratio}")
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio >= 1:
        raise ValueError(f"val_ratio and test_ratio must be >= 0 and sum to < 1, got {val_ratio} and {test_ratio}")

    with open(DATASET_PATH, 'r', encoding='utf-8') as data:
        dataset = json.load(data)

    cropped_size = int(crop_ratio * len(dataset))
    # random.sample returns the chosen items in random order, so contiguous slices below
    # form a random, non-overlapping split.
    cropped_dataset = random.sample(dataset, cropped_size)

    val_size = int(val_ratio * cropped_size)
    test_size = int(test_ratio * cropped_size)
    train_size = cropped_size - val_size - test_size

    train_split = cropped_dataset[:train_size]
    val_split = cropped_dataset[train_size:train_size + val_size]
    test_split = cropped_dataset[train_size + val_size:]

    src_tokenizer = get_tokenizer(SRC_LANG)
    tgt_tokenizer = get_tokenizer(TGT_LANG)

    train_dataset = ParallelTextDataset(train_split, src_tokenizer, tgt_tokenizer)
    val_dataset = ParallelTextDataset(val_split, src_tokenizer, tgt_tokenizer)
    test_dataset = ParallelTextDataset(test_split, src_tokenizer, tgt_tokenizer)

    return train_dataset, val_dataset, test_dataset


def get_model(src_vocab_size: int, tgt_vocab_size):
    """Build a freshly initialised MtTransformerModel for the given vocabulary sizes."""
    model_kwargs = dict(src_vocab_size=src_vocab_size, tgt_vocab_size=tgt_vocab_size)
    return MtTransformerModel.build(**model_kwargs)

@torch.no_grad()
def validate(model: MtTransformerModel, val_batch_iterator: DataLoader, loss_func: nn.CrossEntropyLoss, max_batches: "int | None" = 2) -> float:
    """Estimate the validation loss on the first ``max_batches`` batches of ``val_batch_iterator``.

    Puts ``model`` into evaluation mode and leaves it there; the caller is responsible for
    switching back to training mode.

    Args:
        model: the transformer to evaluate.
        val_batch_iterator: iterable of batches with encoder/decoder inputs, masks and labels.
        loss_func: loss applied to flattened logits vs. flattened labels.
        max_batches: maximum number of batches to evaluate (default 2, the previous hard-coded
            cap); ``None`` evaluates the whole iterator.

    Returns:
        Mean loss over the evaluated batches, or ``nan`` if the iterator yields no batches.
    """
    model.eval()

    val_losses = []
    for batch in val_batch_iterator:
        # Retrieve the data points from the current batch
        encoder_input = batch["encoder_input"].to(DEVICE)       # (batches, seq_len)
        decoder_input = batch["decoder_input"].to(DEVICE)       # (batches, seq_len)
        encoder_mask = batch["encoder_mask"].to(DEVICE)         # (batches, 1, 1, seq_len)
        decoder_mask = batch["decoder_mask"].to(DEVICE)         # (batches, 1, seq_len, seq_len)
        label: torch.Tensor = batch['label'].to(DEVICE)         # (batches, seq_len)

        encoder_output = model.encode(encoder_input, encoder_mask)                                  # (batches, seq_len, d_model)
        decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)    # (batches, seq_len, d_model)
        proj_output: torch.Tensor = model.project(decoder_output)                                   # (batches, seq_len, tgt_vocab_size)

        # Flatten using the projection's own vocab dimension rather than a global dataset object
        loss: torch.Tensor = loss_func(
            proj_output.view(-1, proj_output.size(-1)),     # (batches*seq_len, tgt_vocab_size)
            label.view(-1)                                  # (batches*seq_len,)
        )
        val_losses.append(loss.item())

        if max_batches is not None and len(val_losses) >= max_batches:
            break

    if not val_losses:
        return float('nan')
    return sum(val_losses) / len(val_losses)


def train(model: MtTransformerModel, train_dataset: ParallelTextDataset, val_dataset: ParallelTextDataset, val_every: int = 200) -> None:
    """Train ``model`` on ``train_dataset``, periodically measuring the loss on ``val_dataset``.

    Reads the module-level settings TB_LOG_DIR, INIT_LR, PRELOAD_MODEL_SUFFIX, BATCH_SIZE,
    EPOCHS and DEVICE. Losses are logged to TensorBoard. At the end of an epoch a checkpoint
    is written (via ``get_weights_file_path``) only when that epoch's average training loss
    is the best seen so far in this call.

    Args:
        model: the transformer to optimise in place.
        train_dataset: dataset providing training batches and the target tokenizer.
        val_dataset: dataset used by ``validate`` for the periodic validation loss.
        val_every: run validation every ``val_every`` optimiser steps (default 200).

    Raises:
        ValueError: if ``val_every`` is less than 1.
    """
    if val_every < 1:
        raise ValueError(f"val_every must be >= 1, got {val_every}")

    writer = SummaryWriter(TB_LOG_DIR)
    optimizer = torch.optim.Adam(model.parameters(), lr=INIT_LR, eps=1e-09)

    initial_epoch = 0
    global_step = 0
    if PRELOAD_MODEL_SUFFIX:
        model_filename = get_weights_file_path(PRELOAD_MODEL_SUFFIX)
        print(f"Preloading model {model_filename}")

        # map_location lets a checkpoint saved on one device (e.g. GPU) be resumed on another
        state = torch.load(model_filename, map_location=DEVICE)
        initial_epoch = state["epoch"] + 1
        global_step = state["global_step"]

        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])

    tgt_tokenizer = train_dataset.tgt_tokenizer
    tgt_vocab_size = tgt_tokenizer.get_vocab_size()

    # Labels are target-language token ids, so padding must be identified with the target tokenizer
    loss_func = nn.CrossEntropyLoss(ignore_index=tgt_tokenizer.token_to_id('[PAD]'), label_smoothing=0.1).to(DEVICE)

    train_loader = train_dataset.batch_iterator(BATCH_SIZE)
    val_batch_iterator = val_dataset.batch_iterator(BATCH_SIZE)

    best_train_loss = float('inf')
    try:
        for epoch in range(initial_epoch, EPOCHS):
            # A fresh progress bar around the loader each epoch (re-wrapping the previous bar would nest them)
            progress = tqdm(train_loader, desc=f"Processing epoch {epoch: 02d}", colour="BLUE")

            train_losses = []
            val_losses = []  # only losses actually measured during this epoch
            for batch in progress:
                # validate() switches the model to eval mode, so switch back before every step
                model.train()

                encoder_input = batch["encoder_input"].to(DEVICE)       # (batches, seq_len)
                decoder_input = batch["decoder_input"].to(DEVICE)       # (batches, seq_len)
                encoder_mask = batch["encoder_mask"].to(DEVICE)         # (batches, 1, 1, seq_len)
                decoder_mask = batch["decoder_mask"].to(DEVICE)         # (batches, 1, seq_len, seq_len)
                label: torch.Tensor = batch['label'].to(DEVICE)         # (batches, seq_len)

                encoder_output = model.encode(encoder_input, encoder_mask)                                  # (batches, seq_len, d_model)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)    # (batches, seq_len, d_model)
                proj_output: torch.Tensor = model.project(decoder_output)                                   # (batches, seq_len, tgt_vocab_size)

                train_loss: torch.Tensor = loss_func(
                    proj_output.view(-1, tgt_vocab_size),   # (batches*seq_len, tgt_vocab_size)
                    label.view(-1)                          # (batches*seq_len,)
                )
                train_loss_value = train_loss.item()

                if global_step % val_every == 0:
                    val_loss = validate(model, val_batch_iterator, loss_func)
                    val_losses.append(val_loss)
                    writer.add_scalars("Cross-Entropy-Loss", { "Training": train_loss_value, "Validation": val_loss }, global_step)
                else:
                    writer.add_scalars("Cross-Entropy-Loss", { "Training": train_loss_value }, global_step)
                writer.flush()

                progress.set_postfix({"train_loss": f"{train_loss_value:6.3f}"})

                train_loss.backward()
                optimizer.step()
                # Clear gradients so they do not accumulate into the next step
                optimizer.zero_grad()

                train_losses.append(train_loss_value)
                global_step += 1

            if not train_losses:
                # Empty loader: nothing to average or checkpoint
                continue
            if not val_losses:
                # No validation step fell inside this epoch; measure once so the average is real
                val_losses.append(validate(model, val_batch_iterator, loss_func))

            current_avg_train_loss = sum(train_losses) / len(train_losses)
            current_avg_val_loss = sum(val_losses) / len(val_losses)

            # Checkpoint only when the epoch's average training loss improves
            if current_avg_train_loss < best_train_loss:
                best_train_loss = current_avg_train_loss

                model_filename = get_weights_file_path(f"epoch-{epoch:02d}_avgTrainLoss-{current_avg_train_loss:6.3f}_avgValLoss-{current_avg_val_loss:6.3f}_batch-{BATCH_SIZE}_init_lr-{INIT_LR:.0e}")
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "global_step": global_step
                }, model_filename)
    finally:
        writer.close()

In [None]:
import os

def get_weights_file_path(suffix: str):
    """Return the checkpoint path ``MODEL_DIR/<MODEL_PREFIX><suffix>.pkl``."""
    file_name = f"{MODEL_PREFIX}{suffix}.pkl"
    return os.path.join(MODEL_DIR, file_name)

In [None]:
# Training driver: load the dataset splits, build the transformer and run the training loop.
# NOTE(review): BATCH_SIZE, EPOCHS, MODEL_DIR and MODEL_PREFIX are read by train() but are not
# set in this cell; they are assumed to be defined in an earlier cell — confirm.
print(f"Training started on `{DEVICE}` device")
# NOTE(review): these repeat the values set in the configuration cell at the top of the
# notebook; keep the two copies in sync (or remove one).
DATASET_PATH = f"{PROJECT_DIR}/data/parallel-corpus-en-am-v3.5.json"
SRC_LANG = "en"
TGT_LANG = "am"
SEQ_LEN = 52

train_dataset, val_dataset, test_dataset = get_dataset()

# Model/optimisation hyper-parameters. The model classes and train() above read these
# as module-level globals at construction/call time, so they must be set before get_model().
D_MODEL = 512
N_BLOCKS = 6
HEADS = 8  # Reduced from 32 to 8 for better performance
DROPOUT = 0.1
DFF = 2048
TB_LOG_DIR = f"{PROJECT_DIR}/logs/custom_model"
INIT_LR = 1e-4
PRELOAD_MODEL_SUFFIX = ""  # empty -> train from scratch; set to a checkpoint suffix to resume


model = get_model(train_dataset.src_tokenizer.get_vocab_size(), train_dataset.tgt_tokenizer.get_vocab_size()).to(DEVICE)

train(model, train_dataset, val_dataset)

Processing epoch  0:   0%|[34m          [0m| 1/509 [00:41<5:54:45, 41.90s/it, train_loss=9.927]

In [None]:
# Evaluate the trained model's cross-entropy on the held-out test split.
print(f"Testing started on `{DEVICE}` device")
model.eval()

# Labels are target-language token ids, so the padding id must come from the target tokenizer
loss_func = nn.CrossEntropyLoss(ignore_index=test_dataset.tgt_tokenizer.token_to_id('[PAD]'), label_smoothing=0.1).to(DEVICE)

# Use the test split (this cell previously iterated val_dataset despite its label)
batch_iterator = tqdm(test_dataset.batch_iterator(BATCH_SIZE), desc=f"Evaluating model on test dataset", colour="GREEN")
losses = []
# Inference only: no_grad avoids building and holding an autograd graph for every batch
with torch.no_grad():
    for batch in batch_iterator:
        encoder_input = batch["encoder_input"].to(DEVICE)       # (batches, seq_len)
        decoder_input = batch["decoder_input"].to(DEVICE)       # (batches, seq_len)
        encoder_mask = batch["encoder_mask"].to(DEVICE)         # (batches, 1, 1, seq_len)
        decoder_mask = batch["decoder_mask"].to(DEVICE)         # (batches, 1, seq_len, seq_len)
        label: torch.Tensor = batch['label'].to(DEVICE)         # (batches, seq_len)

        encoder_output = model.encode(encoder_input, encoder_mask)                                  # (batches, seq_len, d_model)
        decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)    # (batches, seq_len, d_model)
        proj_output: torch.Tensor = model.project(decoder_output)                                   # (batches, seq_len, tgt_vocab_size)

        # Compute the test loss
        test_loss: torch.Tensor = loss_func(
            proj_output.view(-1, proj_output.size(-1)),     # (batches*seq_len, tgt_vocab_size)
            label.view(-1)                                  # (batches*seq_len,)
        )

        batch_iterator.set_postfix({"test_loss": f"{test_loss.item():6.3f}"})
        losses.append(test_loss.item())

# Guard against an empty test loader instead of dividing by zero
avg_loss = sum(losses) / len(losses) if losses else float('nan')
print(f"\nTesting finished with an average cross entropy of {avg_loss}")

In [None]:
# Interactive greedy translation of a single user-provided sentence (at most 10 ids incl. [SOS]).
inference_engine = MtInferenceEngine(
    model,
    src_tokenizer=train_dataset.src_tokenizer,
    tgt_tokenizer=train_dataset.tgt_tokenizer,
)
user_input = input("Enter a short english sentence: ")
prediction = inference_engine.translate(user_input, max_len=10)
print(f"\n Predicted: {prediction}")

In [None]:
import os
import torch
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Build the path from the shared project directory instead of repeating the literal
model_filename = f"{PROJECT_DIR}/models/tmodel_epoch-100.pt"

# torch.save does not create missing parent directories
os.makedirs(os.path.dirname(model_filename), exist_ok=True)

# Save the model's state dictionary
torch.save(model.state_dict(), model_filename)