# Martltsera: Data and Training

Character-level Seq2Seq Georgian Spellchecker

Martltsera ("მართლწერა", meaning "Spelling" in Georgian) uses recurrent neural networks (GRU/LSTM) to correct misspelled Georgian words. It operates at the character level, learning the intrinsic orthography of the language to fix typos, missing characters, and keyboard slips.

This project was first implemented using Python scripts in the `src/` folder. This notebook explains the reasoning behind each script.

Libraries and some config:

In [17]:
!pip install numpy torch
import json
import logging
import os
import random
import time
import torch
import torch.nn as nn


random.seed(95)  # ⚡

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%d/%m/%Y %H:%M:%S')
logger = logging.getLogger("MartltseraLogger (Train)")



Constants and hyperparameters:

In [18]:
DICTIONARY_SIZE = 50000
HIDDEN_SIZE = 512
NUM_LAYERS = 1
DROPOUT_P = 0.2
LEARNING_RATE = 0.0001
TEACHER_FORCING_RATIO = 0.5
MODEL_SAVE_PATH = "../models/Martltsera_5.pth"
SOS_token = 0
EOS_token = 1

# `get_data.py`

I needed a substantial vocabulary of correctly spelled Georgian words. I collected three JSON chunks (`wordsChunk_0.json`, `wordsChunk_1.json`, `wordsChunk_2.json`) from https://github.com/AleksandreSukh/GeorgianWordsDataBase and stored them in the `data_src/` directory. Together they contain around 270K unique words after deduplication and cleaning. I restricted the dataset to a random sample of 50K words to keep training time affordable.

### Data Generation: Creating Training Pairs

To create the training pairs required for the seq2seq model, I focused on generating synthetic misspelled versions of the collected correct Georgian words. This method was chosen because collecting real-world misspelled data paired with corrections is challenging and time-consuming, especially for a low-resource language like Georgian. Instead, synthetic generation allows for scalable, controlled creation of diverse error patterns that simulate common typographical mistakes. By programmatically introducing errors, the model learns to map noisy inputs back to clean targets, effectively internalizing Georgian orthography rules without relying on external labeled error corpora.

The key to effective synthetic errors is realism: most real typos are small perturbations (e.g., edit distance 1-2), often stemming from keyboard layouts where fingers slip to adjacent keys. To capture this, I incorporated a Georgian QWERTY keyboard neighbor map (`KEYBOARD_NEIGHBORS`), which defines nearby characters for each letter. For instance, 'ა' has neighbors ['ს', 'ქ', 'ზ'], reflecting physical proximity on the keyboard. This ensures errors feel natural, like substituting 'ს' for 'ა' due to a slight mispress, rather than arbitrary random changes that could confuse the model.

The core augmentation logic is implemented in the `adapt_or_die` function, which processes each correct word character by character:

1. **Error Probability Adjustment**: Shorter words (<7 characters) have a higher per-character error probability (15%) compared to longer ones (10%). This prevents short words from often remaining unchanged, ensuring they contribute meaningfully to training while avoiding excessive distortion in longer, more complex words.

2. **Character Filtering**: Non-Georgian characters (outside Unicode range 4304–4336, i.e., 'ა' to 'ჰ') are preserved unchanged, as the model focuses solely on Mkhedruli script corrections.

3. **Probabilistic Error Introduction**: For each Georgian character, if an error is triggered (based on the probability), a random action is selected:
   - **Substitution (60% chance)**: Replace the character with a typo from `get_typo_char`. This function favors keyboard neighbors (95% chance) for realism, falling back to a fully random Georgian character only 5% of the time to add slight variety without overcomplicating patterns.
   - **Deletion (20% chance)**: Simply skip the character, shortening the word.
   - **Insertion (20% chance)**: Keep the original character and append a typo character after it, lengthening the word.

   This distribution prioritizes substitutions, as they are the most common real-world typos (e.g., hitting the wrong key), while deletions and insertions handle cases like missed or extra keystrokes.

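4. **Guaranteed Perturbation**: Identical input-target pairs are added separately, so every input generated here should actually be misspelled. If no error was triggered during the pass, a forced substitution is applied at a random position (again using `get_typo_char` for consistency). The guarantee is best-effort: the forced substitution is skipped when the chosen position holds a non-Georgian character, and any substitution can, through the 5% random fallback in `get_typo_char`, return the original character. An occasional pair therefore comes out unchanged, as with "ცრემლიც" in the sample output below.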
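To make the effect of steps 1 and 4 concrete: if each of $n$ Georgian characters triggers an error independently with probability $p$, the word passes through the loop untouched with probability $(1-p)^n$. The sketch below only evaluates that formula with the two probabilities from step 1; the helper name `p_untouched` is illustrative and not part of `get_data.py`.

```python
def p_untouched(length: int) -> float:
    """Probability that no per-character error fires for a word of this length."""
    prob = 0.15 if length < 7 else 0.10  # same length rule as in adapt_or_die
    return (1 - prob) ** length


for n in (3, 5, 6, 7, 10, 15):
    print(f"length {n:2d}: {p_untouched(n):.3f}")
```

A 5-letter word is left untouched about 44% of the time at 15% per character (it would be about 59% at 10%), which is why the forced substitution in step 4 is needed fairly often.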

This algorithm produces varied, realistic errors that mimic human typing behaviors, such as finger slips or hurried input, while keeping most changes minimal to reflect the assignment's note on edit distances. For example, running `adapt_or_die` on test words yields pairs like:
- "გამარჯობა" $\mapsto$ "გამარჯობს" (substitution of the final 'ა' with its keyboard neighbor 'ს')
- "გაგიმარჯოს" $\mapsto$ "აგიმარჯოს" (deletion of the first 'გ')
- "ლადო" $\mapsto$ "ოადო" (substitution of 'ლ' with its keyboard neighbor 'ო')
- "ბრიუს ვეინი" $\mapsto$ "ბრჯუს ვეინი" (substitution of 'ი' with its keyboard neighbor 'ჯ'; the space is left untouched)
- And so on for the others, demonstrating small, plausible changes.
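One way to check that the generated pairs really stay within a small edit distance is to compute the Levenshtein distance between the clean and corrupted word. This is a minimal sketch, not part of the original scripts; the function name `levenshtein` is an assumption, and the word pairs are taken from the sample output in this notebook.

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance (insert, delete, substitute cost 1 each)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute (free if the characters match)
            ))
        prev = curr
    return prev[-1]


examples = [
    ("გამარჯობა", "გამარჯობს"),    # one substitution
    ("გაგიმარჯოს", "აგიმარჯოს"),   # one deletion
    ("ტელეფონი", "გწლეფნი"),       # two substitutions and one deletion
]
for clean, noisy in examples:
    print(f"{clean} -> {noisy}: distance {levenshtein(clean, noisy)}")
```

Running this over a full batch of generated pairs would give a quick histogram of edit distances to compare against the target of mostly 1-2 edits.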

Importantly, to promote model stability, I augmented the dataset with cases where the input is already correct: a random 20% of the generated pairs are duplicated as identity pairs in `train.py`, so they make up about one sixth of the final dataset. This teaches the model to output the input unchanged when no correction is needed, reducing the risk of over-correction during inference (e.g., altering valid rare words). Without this, the model might assume every input requires changes and "correct" words that were already right. This directly addresses the assignment's question: "Your dataset should include cases where the input is already correct. Why might this matter?" It ensures robustness and prevents unnecessary modifications, which matters for practical usability. Overall, this keyboard-aware approach favors data quality over quantity, as emphasized in the hints, and yields a dataset of ~60,000 pairs (50,000 corrupted plus 10,000 identity pairs) that trains the model on intrinsic Georgian patterns.
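The resulting proportions can be checked with a few lines of arithmetic. The numbers come from `DICTIONARY_SIZE` and the 20% sampling in `train.py`; the variable names are illustrative.

```python
corrupted = 50_000               # DICTIONARY_SIZE: one corrupted pair per sampled word
identity = int(0.2 * corrupted)  # extra (w, w) pairs sampled from the corrupted set
total = corrupted + identity

print(total)                       # 60000
print(round(identity / total, 3))  # 0.167
```

The total of 60,000 agrees with the "48000 train pairs, 12000 val pairs" line in the training log at the end of this notebook.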

The `get_dataset_pairs` function loads the words, shuffles them, and applies `adapt_or_die` to create (misspelled, correct) pairs, skipping words shorter than 3 characters to focus on meaningful examples.

In [19]:
ALL_GEORGIAN_CHARS = [chr(i) for i in range(4304, 4337)]

# standard Georgian QWERTY keyboard neighbors to simulate realistic finger slips
KEYBOARD_NEIGHBORS: dict[str, list[str]] = {
    'ა': ['ს', 'ქ', 'ზ'],
    'ბ': ['ვ', 'გ', 'ჰ', 'ნ'],
    'გ': ['ფ', 'ტ', 'ყ', 'ჰ', 'ბ', 'ვ'],
    'დ': ['ს', 'ე', 'რ', 'ფ', 'ც', 'ხ'],
    'ე': ['წ', 'რ', 'დ', 'ს'],
    'ვ': ['ც', 'ფ', 'გ', 'ბ'],
    'ზ': ['ა', 'ს', 'ხ'],
    'თ': ['ღ', 'ყ', 'გ', 'ფ'],  # whenever we use "Shift +", all the neighbors are also with "Shift +" (if there exists one)
    'ი': ['უ', 'ო', 'ჯ', 'ჰ'],
    'კ': ['ჯ', 'ი', 'ო', 'ლ', 'მ'],
    'ლ': ['კ', 'მ', 'ო', 'პ'],
    'მ': ['ნ', 'ჯ', 'კ', 'ლ'],
    'ნ': ['ბ', 'ჰ', 'ჯ', 'მ'],
    'ო': ['ი', 'პ', 'ლ', 'კ', 'ჯ'],
    'პ': ['ო', 'ლ'],
    'ჟ': ['ჰ', 'უ', 'ი', 'კ', 'მ', 'ნ'],
    'რ': ['ე', 'ტ', 'ფ', 'დ'],
    'ს': ['ა', 'წ', 'ე', 'დ', 'ხ', 'ზ'],
    'ტ': ['რ', 'ყ', 'გ', 'ფ'],
    'უ': ['ყ', 'ი', 'ჯ', 'ჰ'],
    'ფ': ['დ', 'რ', 'ტ', 'გ', 'ვ', 'ც'],
    'ქ': ['წ', 'ა'],
    'ღ': ['ე', 'თ', 'ფ', 'დ'],
    'ყ': ['თ', 'უ', 'ჰ', 'გ'],
    'შ': ['ა', 'წ', 'ე', 'დ', 'ხ', 'ზ'],
    'ჩ': ['ხ', 'დ', 'ფ', 'ვ'],
    'ც': ['ხ', 'დ', 'ფ', 'ვ'],
    'ძ': ['ა', 'ს', 'ხ'],
    'წ': ['ქ', 'ე', 'ს', 'ა'],
    'ჭ': ['ქ', 'ე', 'ს', 'ა'],
    'ხ': ['ა', 'ს', 'დ', 'ც'],
    'ჯ': ['ჰ', 'უ', 'ი', 'კ', 'მ', 'ნ'],
    'ჰ': ['გ', 'ყ', 'უ', 'ჯ', 'ნ', 'ბ']
}


def read_sources(path: str) -> list[str]:
    # for script version: data_src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), path))
    data_src_path = os.path.abspath(os.path.join(os.getcwd(), path))
    res: list[str] = []

    for i in range(3):
        file_path = os.path.join(data_src_path, f"wordsChunk_{i}.json")
        with open(file_path, "r", encoding="utf-8") as f:
            content = json.load(f)
            res.extend(content)

    return res


def get_typo_char(c: str) -> str:
    # 95% chance to pick a neighbor, 5% random char
    if c in KEYBOARD_NEIGHBORS and random.random() > 0.05:
        return random.choice(KEYBOARD_NEIGHBORS[c])
    return random.choice(ALL_GEORGIAN_CHARS)


def adapt_or_die(word: str) -> str:
    res = ""
    error_made = False

    # higher chance of error per char for shorter words otherwise short words often remain untouched
    prob = 0.15 if len(word) < 7 else 0.10

    for c in word:
        # ord('ა') == 4304 and ord('ჰ') == 4336
        if not (4304 <= ord(c) <= 4336):
            res += c
            continue

        if random.random() < prob:
            # 0 -> substitution p = 0.6; 1 -> deletion p = 0.2; 2 -> insertion p = 0.2
            action_roll = random.random()
            error_made = True

            if action_roll < 0.6:
                res += get_typo_char(c)  # substitution jutsu!
            elif action_roll < 0.8:
                continue  # deletion
            else:
                res += c + get_typo_char(c)  # insertion
        else:
            res += c

    # ensure the word is actually corrupted for training purposes
    if not error_made and len(res) > 0:
        idx = random.randint(0, len(res) - 1)
        temp_list = list(res)
        if 4304 <= ord(temp_list[idx]) <= 4336:
            temp_list[idx] = get_typo_char(temp_list[idx])
        res = "".join(temp_list)

    return res


def get_dataset_pairs(path: str = "../data_src", dictionary_size: int = 50000) -> list[tuple[str, str]]:
    pure_data = read_sources(path)
    random.shuffle(pure_data)

    dataset = []
    count = 0

    for word in pure_data:
        if count >= dictionary_size:
            break

        if len(word) < 3:  # skip extremely short words
            continue

        dataset.append((adapt_or_die(word), word))
        count += 1

    return dataset


# if __name__ == "__main__":
print("<== Testing adapt_or_die ==>")
test_words = ["გამარჯობა", "გაგიმარჯოს", "ლადო", "ბრიუს ვეინი", "კომპიუტერი", "სიყვარული", "მაყუთი", "წიგნიერება", "გოკუ", "ტელეფონი", "ბეტმენი"]
for w in test_words:
    print(f"{w} -> {adapt_or_die(w)}")

print("\n<== Testing dataset generation ==>")
data = get_dataset_pairs(dictionary_size=15)
for corrupted, original in data:
    print(f"Corrupted: {corrupted} | Original: {original}")

<== Testing adapt_or_die ==>
გამარჯობა -> გამარჯობს
გაგიმარჯოს -> აგიმარჯოს
ლადო -> ოადო
ბრიუს ვეინი -> ბრჯუს ვეინი
კომპიუტერი -> კჯმპიუტერი
სიყვარული -> სითვარული
მაყუთი -> მაგუთი
წიგნიერება -> წიგნიერებ
გოკუ -> გოკჰ
ტელეფონი -> გწლეფნი
ბეტმენი -> ბეტჯენი

<== Testing dataset generation ==>
Corrupted: დაცურათ | Original: დავცურავთ
Corrupted: მინდორიდ | Original: მინდორიც
Corrupted: დავუმტკიცე-მეთქაი | Original: დავუმტკიცე-მეთქი
Corrupted: ხეცურებულან | Original: შეცურებულან
Corrupted: მსმალმა | Original: მამალმა
Corrupted: ძდვენთა | Original: ძღვენთა
Corrupted: სექჭიები | Original: სექციები
Corrupted: გაევფერა | Original: გაშვერა
Corrupted: გამოსროლილმს | Original: გამოსროლილმა
Corrupted: საუთოწბლებზე | Original: საუთოებლებზე
Corrupted: აუყეხიათ | Original: აუტეხიათ
Corrupted: ცრემლიც | Original: ცრემლიც
Corrupted: შეეყალა | Original: შეეწყალა
Corrupted: გაერთიანდებისნ | Original: გაერთიანდებიან
Corrupted: ლეონოდე | Original: ლეონიდე


# `Gamarjoba.py`

This file defines the core model architecture, named "Gamarjoba" (Georgian for "hello"), which is an encoder-decoder LSTM network.

## LSTM (Long Short-Term Memory) Justification:
I chose the LSTM architecture because Georgian words have a complex structure. A single word often consists of a root with many prefixes and suffixes attached to it.
To fix a spelling error, the model needs to remember the beginning of the word (the root) to figure out the correct ending.
Standard RNNs often struggle to remember this information over long sequences because of the "vanishing gradient" problem.
LSTMs mitigate this using a cell state: a type of internal memory that acts like a highway, allowing information to travel through the network without getting lost.
This makes the LSTM well suited to learning the long and complex character patterns found in the Georgian language.
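
For reference, the standard LSTM update that the "highway" intuition refers to is given below. The notation ($W$, $U$, $b$ for the gate parameters, $\odot$ for element-wise product) is the usual textbook convention rather than anything defined in this notebook:

$$
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

The cell state $c_t$ is updated additively: when the forget gate $f_t$ stays close to 1, $c_{t-1}$ is carried forward almost unchanged, which is what lets information about the beginning of a word survive until the ending is generated.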

## Encoder-Decoder Architecture Justification:
I implemented an Encoder-Decoder structure because spelling errors often change the length of a word.
For example, if a user misses a key or types an extra one, the input length will differ from the correct output.
This architecture solves that problem by separating the task into two parts.
The Encoder reads the entire misspelled word first to capture its full context.
The Decoder then uses that information to generate the correct word character by character.
This allows the model to handle complex errors like insertions and deletions, not just simple letter replacements.

In [20]:
class EncoderLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, dropout_p: float = 0.2):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(input_size, hidden_size)  # convert character indices to dense vectors
        self.dropout = nn.Dropout(self.dropout_p)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)  # this saved me tons of time


    def forward(self, input_seq: torch.Tensor, hidden: torch.Tensor, cell: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        embedded = self.embedding(input_seq).view(1, 1, -1)  # one character index -> (seq_len=1, batch_size=1, hidden_size); -1 lets view() infer the embedding dimension
        output = self.dropout(embedded)
        output, (hidden, cell) = self.lstm(output, (hidden, cell))  # updates hidden/cell states based on current input

        return output, hidden, cell

    # LSTM processes data sequentially. Since there is no previous memory at the start, we must create one: (h_0, c_0).
    def init_hidden(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        # h_0 and c_0, each of shape (num_layers, batch_size=1, hidden_size)
        shape = (self.num_layers, 1, self.hidden_size)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)


class DecoderLSTM(nn.Module):
    def __init__(self, output_size: int, hidden_size: int, num_layers: int = 1, dropout_p: float = 0.2):
        super(DecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(output_size, hidden_size)  # embeds the previous character (or SOS)
        self.dropout = nn.Dropout(self.dropout_p)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers)  # processes the sequence step-by-step
        self.out = nn.Linear(hidden_size, output_size)  # projects hidden state to vocabulary size (logits)
        self.softmax = nn.LogSoftmax(dim=1)  # converts logits to log-probabilities for prediction


    def forward(self, input_step: torch.Tensor, hidden: torch.Tensor, cell: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        output = self.embedding(input_step).view(1, 1, -1)
        output = self.dropout(output)
        output = torch.nn.functional.relu(output)
        output, (hidden, cell) = self.lstm(output, (hidden, cell))  # updates state based on input and previous state
        prediction = self.softmax(self.out(output[0]))  # computes probability distribution over vocabulary

        return prediction, hidden, cell


class Gamarjoba(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, num_layers: int = 1, dropout_p: float = 0.2):
        super(Gamarjoba, self).__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder = EncoderLSTM(vocab_size, hidden_size, num_layers, dropout_p).to(self.device)
        self.decoder = DecoderLSTM(vocab_size, hidden_size, num_layers, dropout_p).to(self.device)
        self.vocab_size = vocab_size


    # Teacher Forcing: A training method where the model is sometimes fed the actual correct previous character instead of its own predicted guess to speed up convergence.
    def forward(self, input_tensor: torch.Tensor, target_tensor: torch.Tensor, teacher_forcing_ratio: float = 0.5) -> torch.Tensor:
        input_length = input_tensor.size(0)
        target_length = target_tensor.size(0)

        encoder_hidden, encoder_cell = self.encoder.init_hidden(self.device)  # initializes hidden state with zeros

        for i in range(input_length):
            _, encoder_hidden, encoder_cell = self.encoder(input_tensor[i], encoder_hidden, encoder_cell)  # builds context vector

        decoder_input = torch.tensor([[SOS_token]], device=self.device)  # Start-Of-Sequence token assumed 0; starts decoding
        decoder_hidden = encoder_hidden  # passes the Encoder's final memory to Decoder
        decoder_cell = encoder_cell

        outputs = torch.zeros(target_length, self.vocab_size, device=self.device)

        for i in range(target_length):
            decoder_output, decoder_hidden, decoder_cell = self.decoder(decoder_input, decoder_hidden, decoder_cell)  # predict next char
            outputs[i] = decoder_output.squeeze(0)  # stores the prediction

            use_teacher_forcing = random.random() < teacher_forcing_ratio  # decides strategy randomly

            if use_teacher_forcing:
                decoder_input = target_tensor[i]  # feeds the correct character as next input
            else:
                _, top_i = decoder_output.topk(1)  # gets the character index with the highest probability
                decoder_input = top_i.squeeze().detach()  # detach() prevents backprop through this step (treats input as constant)
                if decoder_input.item() == EOS_token:  # End-Of-Sequence token assumed 1
                    break  # rows of `outputs` after this step stay zero and add nothing to the NLL loss

        return outputs

# `train.py`

This script manages the training process, including dataset splitting, batching, and optimization.

I split the pairs into 80% train / 20% validation for monitoring overfitting. Batching is done with a size of 64 (adjustable), processing batches sequentially but computing loss per word to handle variable lengths without padding (though for efficiency, true batching with padding could be added).
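
The split and batch sizes can be worked out by hand. The sketch below assumes the 60,000 pairs reported in the training log and the call `train_model(5, 64)`; it also shows how often the logging condition `iter_count % log_interval == 0` can fire, given that `iter_count` advances by one batch at a time.

```python
import math

n_pairs = 60_000      # 50,000 corrupted + 10,000 identity pairs
batch_size = 64
log_interval = 200

train_size = int(0.8 * n_pairs)
val_size = n_pairs - train_size
batches_per_epoch = math.ceil(train_size / batch_size)

# iter_count grows in steps of batch_size, so it is a multiple of log_interval
# only at multiples of lcm(batch_size, log_interval).
samples_between_logs = batch_size * log_interval // math.gcd(batch_size, log_interval)

print(train_size, val_size, batches_per_epoch, samples_between_logs)  # 48000 12000 750 1600
```

With these settings a log line appears every 25 batches (1,600 samples), and 5 epochs amount to 240,000 samples, matching the final step count in the Colab log.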

The training loop:
- Uses Adam optimizer with `LR=0.0001` (conservative to avoid instability).
- `NLLLoss` for multi-class character prediction.
- Tracks average loss per token.
- Validates after each epoch, saving checkpoints if val loss improves.
- Early stopping with `patience=5` to prevent unnecessary computation.
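
The checkpointing and early-stopping rule from the list above can be isolated into a small helper. This is an illustrative sketch: the class `EarlyStopper` does not exist in `train.py`, which keeps the same state in the local variables `best_val_loss` and `counter`.

```python
class EarlyStopper:
    """Hypothetical helper mirroring the best_val_loss / counter logic of the training loop."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = float("inf")
        self.counter = 0

    def update(self, val_loss: float) -> tuple[bool, bool]:
        """Return (improved, should_stop) for the latest validation loss."""
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0  # any improvement resets the patience counter
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience


stopper = EarlyStopper(patience=2)
history = [stopper.update(v) for v in (1.20, 0.95, 0.97, 0.96)]
print(history)  # [(True, False), (True, False), (False, False), (False, True)]
```

Note that 0.96 does not count as an improvement even though it is lower than the previous epoch: the comparison is always against the best loss seen so far.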

Hyperparameters like `hidden_size=512` and `num_layers=1` balance capacity and speed. I added 20% correct pairs to the dataset here for stability.

This demonstrates the full workflow: splitting, batching, loss tracking, and checkpointing based on validation.

## Design Considerations for Variable Output Length
- I adopted a classic encoder-decoder (seq2seq) architecture with no forced alignment.
- The encoder (LSTM) processes the entire misspelled input and compresses it into a fixed-size context (final hidden and cell states).
- The decoder (also LSTM) then autoregressively generates the corrected word one character at a time, starting from a `<SOS>` token and stopping when it predicts `<EOS>`.
- This design allows the output length to be completely independent of input length, naturally handling insertions, deletions, and substitutions without any padding or length constraints.
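
The stopping behaviour described above can be sketched independently of PyTorch. In this sketch `step` is a stand-in for one decoder call (previous token and state in, scores and new state out), `scripted_step` is a toy replacement for the trained decoder, and `max_len` is an assumed safety cap that the notebook does not define.

```python
from typing import Callable

SOS_token, EOS_token = 0, 1


def greedy_decode(step: Callable[[int, object], tuple[list[float], object]],
                  state: object, max_len: int = 40) -> list[int]:
    """Generate token indices until <EOS> is predicted or max_len is reached."""
    tokens: list[int] = []
    prev = SOS_token
    for _ in range(max_len):
        scores, state = step(prev, state)
        prev = max(range(len(scores)), key=scores.__getitem__)  # argmax over the vocabulary
        if prev == EOS_token:
            break
        tokens.append(prev)
    return tokens


def scripted_step(prev_token: int, state: list[int]) -> tuple[list[float], list[int]]:
    """Toy decoder: replays a fixed list of token indices, ignoring prev_token."""
    target, rest = state[0], state[1:]
    scores = [0.0] * 35
    scores[target] = 1.0
    return scores, rest


print(greedy_decode(scripted_step, [12, 2, 5, 15, EOS_token]))  # [12, 2, 5, 15]
```

The output length is decided entirely by when `<EOS>` wins the argmax, not by the length of the input, which is the property the encoder-decoder design relies on.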

## Character Vocabulary and Handling the Georgian Alphabet
- The Georgian Mkhedruli script consists of $33$ modern letters, which occupy the contiguous range U+10D0 ('ა') to U+10F0 ('ჰ').
- In the provided code, `ALL_GEORGIAN_CHARS = [chr(i) for i in range(4304, 4337)]` covers exactly these 33 characters ($4304$ is 10D0 in hex; $4336$ is 10F0 in hex).
- The vocabulary size is therefore $33 + 2$ special tokens $= 35$.
- Character-to-index mapping assigns indices $2 \dots 34$ to the $33$ letters (in Unicode order), reserving $0$ for `<SOS>` and $1$ for `<EOS>`.
- During tokenization, any character not in this set is skipped (`tensor_from_word` filters with `if char in char_to_index`), so spaces, hyphens (as in "დავუმტკიცე-მეთქი"), and Latin letters are silently dropped. This is acceptable because the dataset consists of Georgian words and the model focuses exclusively on Mkhedruli script.
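
The mapping can be exercised without PyTorch. The sketch below rebuilds `char_to_index` and `index_to_char` in the way described above; the helpers `encode` and `decode` are illustrative (the notebook's own function is `tensor_from_word`, which additionally wraps the indices in a tensor).

```python
ALL_GEORGIAN_CHARS = [chr(i) for i in range(0x10D0, 0x10F0 + 1)]  # 'ა' .. 'ჰ'
SOS_token, EOS_token = 0, 1

char_to_index = {ch: i + 2 for i, ch in enumerate(ALL_GEORGIAN_CHARS)}
index_to_char = {i: ch for ch, i in char_to_index.items()}


def encode(word: str) -> list[int]:
    """Map letters to indices, skip anything outside the alphabet, append <EOS>."""
    return [char_to_index[c] for c in word if c in char_to_index] + [EOS_token]


def decode(indices: list[int]) -> str:
    """Inverse of encode: stop at <EOS> and ignore indices that are not letters."""
    chars = []
    for i in indices:
        if i == EOS_token:
            break
        if i in index_to_char:
            chars.append(index_to_char[i])
    return "".join(chars)


print(len(ALL_GEORGIAN_CHARS), len(char_to_index) + 2)  # 33 35
print(encode("ლადო"))                                   # [12, 2, 5, 15, 1]
print(decode(encode("ბრიუს ვეინი")))                    # ბრიუსვეინი
```

The last line shows a side effect of skipping unknown characters: a space inside the input disappears, so multi-word inputs would need to be split into words before they reach the model.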

## Special Tokens
- `<SOS>` (index $0$): Used as the first decoder input, both during training and at inference time. It signals the start of generation and provides a neutral starting point for the decoder LSTM. It never appears in a target sequence, so the model is never asked to predict it.
- `<EOS>` (index $1$): Appended to every target sequence during training so the model learns to predict it after the last real character. At inference, greedy decoding stops when `<EOS>` is predicted, preventing over-generation.
- No `<PAD>` token: The model processes sequences one character at a time (`batch_size=1` per word effectively, even in batches) and lengths vary naturally, so padding is unnecessary.

In [21]:
char_to_index = {char: i + 2 for i, char in enumerate(ALL_GEORGIAN_CHARS)}
index_to_char = {i + 2: char for i, char in enumerate(ALL_GEORGIAN_CHARS)}
VOCAB_SIZE = len(ALL_GEORGIAN_CHARS) + 2

def tensor_from_word(word: str, device: torch.device) -> torch.Tensor:
    indexes = [char_to_index[char] for char in word if char in char_to_index]
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def train_batch(model: Gamarjoba, optimizer: torch.optim.Optimizer, criterion: nn.NLLLoss, batch_pairs: list[tuple[str, str]], device: torch.device) -> float:
    optimizer.zero_grad()
    total_loss = torch.tensor(0.0, device=device)
    total_tokens = 0

    for input_word, target_word in batch_pairs:
        input_tensor = tensor_from_word(input_word, device)
        target_tensor = tensor_from_word(target_word, device)
        target_length = target_tensor.size(0)
        outputs = model(input_tensor, target_tensor, teacher_forcing_ratio=TEACHER_FORCING_RATIO)
        example_loss = torch.tensor(0.0, device=device)

        for i in range(target_length):
            example_loss += criterion(outputs[i].unsqueeze(0), target_tensor[i])

        total_loss += example_loss
        total_tokens += target_length

    if total_tokens > 0:
        avg_loss = total_loss / total_tokens
        avg_loss.backward()
        optimizer.step()
        return avg_loss.item()

    optimizer.step()
    return 0.0


def validate(model: Gamarjoba, criterion: nn.NLLLoss, val_pairs: list[tuple[str, str]], device: torch.device) -> float:
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    with torch.no_grad():
        for input_word, target_word in val_pairs:
            input_tensor = tensor_from_word(input_word, device)
            target_tensor = tensor_from_word(target_word, device)
            target_length = target_tensor.size(0)
            outputs = model(input_tensor, target_tensor, teacher_forcing_ratio=TEACHER_FORCING_RATIO)
            example_loss = 0.0

            for i in range(target_length):
                example_loss += criterion(outputs[i].unsqueeze(0), target_tensor[i]).item()

            total_loss += example_loss
            total_tokens += target_length

    model.train()
    return total_loss / total_tokens if total_tokens > 0 else 0.0


def get_batches(pairs: list[tuple[str, str]], batch_size: int):
    for i in range(0, len(pairs), batch_size):
        yield pairs[i:i + batch_size]


def train_model(epochs: int, batch_size: int):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    model = Gamarjoba(VOCAB_SIZE, HIDDEN_SIZE, num_layers=NUM_LAYERS, dropout_p=DROPOUT_P)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.NLLLoss()

    logger.info("Generating dataset...")
    pairs = get_dataset_pairs(dictionary_size=DICTIONARY_SIZE)
    correct_pairs = [(w, w) for _, w in random.sample(pairs, int(0.2 * len(pairs)))]  # add correct pairs
    pairs += correct_pairs
    random.shuffle(pairs)
    n = len(pairs)
    train_size = int(0.8 * n)
    train_pairs = pairs[:train_size]
    val_pairs = pairs[train_size:]
    logger.info(f"Dataset ready: {len(train_pairs)} train pairs, {len(val_pairs)} val pairs.")

    start_time = time.time()
    log_interval = 200
    loss_interval = 0.0
    samples_interval = 0  # samples accumulated since the last log line
    total_steps = epochs * len(train_pairs)
    iter_count = 0
    best_val_loss = float('inf')
    patience = 5
    counter = 0

    logger.info(f"Starting training for {epochs} epochs...")

    for epoch in range(1, epochs + 1):
        random.shuffle(train_pairs)
        for batch_pairs in get_batches(train_pairs, batch_size):
            batch_len = len(batch_pairs)
            iter_count += batch_len
            samples_interval += batch_len
            loss = train_batch(model, optimizer, criterion, batch_pairs, device)
            loss_interval += loss * batch_len
            if iter_count % log_interval == 0:
                # average over the samples accumulated since the last log line
                # (iter_count advances by batch_size, so this is usually more than log_interval samples)
                loss_avg = loss_interval / samples_interval
                logger.info(f"{iter_count} steps ({iter_count / total_steps * 100:.0f}% complete) | Loss: {loss_avg:.3f}")
                loss_interval = 0.0
                samples_interval = 0

        val_loss = validate(model, criterion, val_pairs, device)
        logger.info(f"Epoch {epoch} Validation Loss: {val_loss:.3f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_path = MODEL_SAVE_PATH[:-4] + f"_{epoch}.pth"  # e.g. ../models/Martltsera_5_3.pth for epoch 3
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(model.state_dict(), save_path)
            logger.info(f"New best model saved to {save_path}")
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                logger.info(f"Early stopping at epoch {epoch}")
                break

Training time: $\approx 100$ minutes on Colab's T4 GPU.

***Note***: I interrupted the execution of the next cell because I had already trained the model on Colab and downloaded the checkpoints from there. Training for $5$ epochs would take over $300$ minutes on my machine, whereas it took only $\approx 100$ minutes on Colab.

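The training `Loss` figure in the Colab run is not directly comparable to the validation loss: it was accumulated over 1,600 samples between log lines but divided by `log_interval = 200`, so 7.263 corresponds to roughly $7.263 / 8 \approx 0.908$ per token.

These were the last 2 lines of logs from Colab: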
```bash
13/12/2025 10:54:25 - MartltseraLogger (Train) - INFO - 240000 steps (100% complete) | Loss: 7.263
13/12/2025 10:55:42 - MartltseraLogger (Train) - INFO - Epoch 5 Validation Loss: 0.877
```

In [22]:
train_model(5, 64)

13/12/2025 15:04:14 - MartltseraLogger (Train) - INFO - Using device: cpu
13/12/2025 15:04:14 - MartltseraLogger (Train) - INFO - Generating dataset...
13/12/2025 15:04:14 - MartltseraLogger (Train) - INFO - Dataset ready: 48000 train pairs, 12000 val pairs.
13/12/2025 15:04:14 - MartltseraLogger (Train) - INFO - Starting training for 5 epochs...


KeyboardInterrupt: 