# `predict.py`

This file implements inference: it loads the trained model and uses it to correct words.

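The `SpellChecker` class loads the model and converts each word to a tensor of character indices. Characters outside the vocabulary are dropped and EOS is appended. It then encodes the input and decodes greedily, picking the highest-scoring character at each step, until it emits EOS or reaches a safety limit of 50 characters.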

There is no teacher forcing here: generation is purely autoregressive, with each predicted character fed back as the next decoder input. The `correct_word` function wraps this in a single call for convenience.

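In testing, the model is run on correctly spelled words, which it should return unchanged. A misspelling such as "გამრჯობა" should likewise be corrected to "გამარჯობა". The log below shows that most inputs are preserved (e.g. "თბილისი" $\mapsto$ "თბილისი"). Some, however, are corrupted (e.g. "გაგიმარჯოს" $\mapsto$ "გაგიმაროოს", "ნარუტო" $\mapsto$ "ნარტულო"), so the model does not yet reliably preserve correct inputs.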

In [15]:
import functools
import logging
import random

import torch

from src.Gamarjoba import Gamarjoba
from src.get_data import ALL_GEORGIAN_CHARS


# Seed Python's global `random` RNG (nothing visible in this cell draws from it).
random.seed(95)  # ⚡

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%d/%m/%Y %H:%M:%S',
)
logger = logging.getLogger("MartltseraLogger (Inference)")

# Architecture settings passed to `Gamarjoba`; presumably they have to match
# the values the checkpoint was trained with.
HIDDEN_SIZE = 512
NUM_LAYERS = 1
DROPOUT_P = 0.2

# Special tokens take indices 0 and 1, so alphabet characters start at 2.
SOS_token = 0
EOS_token = 1

char_to_index = {char: idx for idx, char in enumerate(ALL_GEORGIAN_CHARS, start=2)}
index_to_char = dict(enumerate(ALL_GEORGIAN_CHARS, start=2))
VOCAB_SIZE = 2 + len(ALL_GEORGIAN_CHARS)  # two special tokens + the alphabet


class SpellChecker:
    """Greedy encoder/decoder spelling corrector backed by a trained ``Gamarjoba`` model.

    The model is built from the module-level ``VOCAB_SIZE``, ``HIDDEN_SIZE``,
    ``NUM_LAYERS`` and ``DROPOUT_P`` and then loaded from a saved state_dict.
    """

    def __init__(self, model_path: str = "../models/Martltsera_5.pth"):
        """Load the checkpoint at *model_path* onto CUDA if available, else CPU.

        Args:
            model_path: Path to a ``state_dict`` saved with ``torch.save``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = Gamarjoba(VOCAB_SIZE, HIDDEN_SIZE, num_layers=NUM_LAYERS, dropout_p=DROPOUT_P)
        # A state_dict only contains tensors, so refuse to unpickle arbitrary
        # objects from the checkpoint file.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()  # important for dropout and inference behaviour

    def tensor_from_word(self, word: str) -> torch.Tensor:
        """Encode *word* as an ``(n + 1, 1)`` long tensor of indices ending in EOS.

        ``n`` is the number of characters of *word* found in ``char_to_index``;
        any other character is silently dropped.
        """
        idxs = [char_to_index[char] for char in word if char in char_to_index]
        idxs.append(EOS_token)
        return torch.tensor(idxs, dtype=torch.long, device=self.device).view(-1, 1)

    @staticmethod
    def idx_to_char(idx: int) -> str:
        """Map a vocabulary index to its character; special/unknown indices give ``""``."""
        return index_to_char.get(idx, "")

    def fix(self, word: str, *, max_length: int = 50) -> str:
        """Return the model's greedy correction of *word*.

        Args:
            word: The (possibly misspelled) input word.
            max_length: Maximum number of characters to generate; a safety
                limit in case the decoder never emits EOS.

        Returns:
            The decoded word. Blank input is returned unchanged, as is input
            containing no character from the model's vocabulary (the encoder
            would only see EOS, so there is nothing to correct).
        """
        if not word.strip():  # edge case: empty / whitespace-only input
            return word

        input_tensor = self.tensor_from_word(word)
        if input_tensor.size(0) == 1:  # only the EOS token: no known characters
            return word

        with torch.no_grad():
            encoder_hidden, encoder_cell = self.model.encoder.init_hidden(self.device)
            # Feed the input one token at a time; each row has shape (1,).
            for token in input_tensor:
                _, encoder_hidden, encoder_cell = self.model.encoder(token, encoder_hidden, encoder_cell)

            # The decoder starts from the encoder's final state and an SOS token.
            decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
            decoder_input = torch.tensor([[SOS_token]], device=self.device)

            decoded_chars = []
            for _ in range(max_length):
                prediction, decoder_hidden, decoder_cell = self.model.decoder(decoder_input, decoder_hidden, decoder_cell)
                _, top_i = prediction.topk(1)
                idx = top_i.squeeze().item()

                if idx == EOS_token:
                    break

                decoded_chars.append(self.idx_to_char(idx))
                # Greedy autoregression (no teacher forcing): the prediction is the
                # next input, kept in the same (1, 1) shape as the initial SOS input.
                decoder_input = torch.tensor([[idx]], device=self.device)

        return "".join(decoded_chars)


@functools.lru_cache(maxsize=4)
def _load_spell_checker(model_path: str) -> SpellChecker:
    """Build and return a ``SpellChecker`` for *model_path*, once per path.

    Cached so repeated ``correct_word`` calls do not reload the checkpoint from
    disk; ``maxsize`` bounds how many models stay resident in memory.
    """
    return SpellChecker(model_path=model_path)


def correct_word(word: str, model_path: str = "../models/Martltsera_5.pth") -> str:
    """Correct a single *word* using the model stored at *model_path*.

    The loaded model is reused across calls with the same path. A checkpoint
    overwritten on disk is therefore not picked up until
    ``_load_spell_checker.cache_clear()`` is called.
    """
    return _load_spell_checker(model_path).fix(word)


# Demo: run the checker on correctly spelled words, which ideally come back
# unchanged. There is no `if __name__ == "__main__":` guard because this code
# runs as a notebook cell.
example_words = [
    "თბილისი", "საქონელი", "გამარჯობა", "გაგიმარჯოს", "ტელევიზია",
    "ბაყაყი", "მხედარი", "ნადირობა", "ნაპოლეონი", "ნარუტო",
    "შოთა", "აზარტი", "კომპიუტერი", "დაცვა", "ძლიერი",
    "ქართული", "ქართველი", "გამარჯვებულია", "მეცნიერი", "პროგრამისტი",
]
logger.info("Inference examples:")
checker = SpellChecker(model_path="../models/Martltsera_5.pth")  # load the model once, not per word
for i, w in enumerate(example_words, start=1):
    logger.info("%d) %s -> %s", i, w, checker.fix(w))

13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - Inference examples:
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 1) თბილისი -> თბილისი
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 2) საქონელი -> საქონელი
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 3) გამარჯობა -> გამარჯობა
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 4) გაგიმარჯოს -> გაგიმაროოს
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 5) ტელევიზია -> ტელევიიაია
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 6) ბაყაყი -> ბაყაყი
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 7) მხედარი -> მხედარი
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 8) ნადირობა -> ნადირობა
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 9) ნაპოლეონი -> ნაპოლეონი
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 10) ნარუტო -> ნარტულო
13/12/2025 15:04:36 - MartltseraLogger (Inference) - INFO - 11) შოთა -> შოთა
13