<a href="https://colab.research.google.com/github/SandroMuradashvili/The-Georgian-Spellcheck/blob/main/interface.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Georgian Spellcheck - Inference Notebook
Load trained model and test corrections
"""

# ============================================================================
# SETUP & DEPENDENCIES
# ============================================================================

import numpy as np
import torch
import torch.nn as nn
import pickle
from typing import List

# Pick the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# ============================================================================
# LOAD VOCABULARY
# ============================================================================

# Read the pickled vocabulary (presumably written by the training notebook).
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# only load a vocab.pkl that comes from a trusted source.
print("Loading vocabulary...")
with open('vocab.pkl', 'rb') as f:
    vocab_data = pickle.load(f)

# Pull out the three entries the rest of the notebook reads.
char_to_idx, idx_to_char, vocab_size = (
    vocab_data[key] for key in ('char_to_idx', 'idx_to_char', 'vocab_size')
)

print(f"Vocabulary size: {vocab_size}")

# ============================================================================
# MODEL ARCHITECTURE (Same as training)
# ============================================================================

class Encoder(nn.Module):
    """Bidirectional LSTM encoder over sequences of character indices.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each character embedding.
        hidden_dim: Hidden size of the LSTM, per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the embeddings and, when
            ``num_layers > 1``, between LSTM layers.
        pad_idx: Index of the padding token. When ``None`` (the default) it
            is looked up as ``char_to_idx['<PAD>']`` in the module-level
            vocabulary, so existing callers behave exactly as before.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=2, dropout=0.3, pad_idx=None):
        super().__init__()
        if pad_idx is None:
            # Fall back to the notebook's global vocabulary.
            pad_idx = char_to_idx['<PAD>']
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        # Inter-layer dropout is only requested when there is more than one
        # layer; for a single layer it is passed as 0.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True,
                           dropout=dropout if num_layers > 1 else 0, bidirectional=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Embed and encode a batch of index sequences.

        Args:
            x: Integer index tensor; the LSTM is ``batch_first``, so it is
                expected to be laid out as (batch, src_len).

        Returns:
            ``(outputs, (hidden, cell))`` exactly as produced by ``nn.LSTM``.
            Because the LSTM is bidirectional, the last dimension of
            ``outputs`` is ``2 * hidden_dim``.
        """
        embedded = self.dropout(self.embedding(x))
        outputs, (hidden, cell) = self.lstm(embedded)
        return outputs, (hidden, cell)

class Attention(nn.Module):
    """Additive attention that scores each encoder position against a decoder state."""

    def __init__(self, hidden_dim):
        super().__init__()
        # The scoring layer takes the decoder state (hidden_dim) concatenated
        # with one encoder output, which together must be 3 * hidden_dim wide.
        self.attn = nn.Linear(hidden_dim * 3, hidden_dim)
        # Collapses each energy vector to a single unnormalised score.
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights over the source positions.

        Args:
            hidden: Decoder state of shape (batch, hidden_dim).
            encoder_outputs: Encoder outputs of shape
                (batch, src_len, 2 * hidden_dim).

        Returns:
            Tensor of shape (batch, src_len) whose rows sum to 1.
        """
        steps = encoder_outputs.size(1)
        # Copy the decoder state once per source position so it can be
        # concatenated with every encoder output.
        tiled_state = hidden.unsqueeze(1).repeat(1, steps, 1)
        joined = torch.cat((tiled_state, encoder_outputs), dim=2)
        scores = self.v(torch.tanh(self.attn(joined))).squeeze(2)
        return scores.softmax(dim=1)

class Decoder(nn.Module):
    """Single-step LSTM decoder that attends over the encoder outputs.

    Args:
        vocab_size: Number of output classes and embedding rows.
        embed_dim: Size of each character embedding.
        hidden_dim: Hidden size of the decoder LSTM.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied to the embeddings and, when
            ``num_layers > 1``, between LSTM layers.
        pad_idx: Index of the padding token. When ``None`` (the default) it
            is looked up as ``char_to_idx['<PAD>']`` in the module-level
            vocabulary, so existing callers behave exactly as before.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=2, dropout=0.3, pad_idx=None):
        super().__init__()
        if pad_idx is None:
            # Fall back to the notebook's global vocabulary.
            pad_idx = char_to_idx['<PAD>']
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.attention = Attention(hidden_dim)
        # The LSTM input is the token embedding concatenated with the
        # attention context, which is 2 * hidden_dim wide.
        self.lstm = nn.LSTM(embed_dim + hidden_dim * 2, hidden_dim, num_layers,
                           batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden, cell, encoder_outputs):
        """Run one decoding step.

        Args:
            x: Indices of the current input tokens, expected as (batch, 1).
            hidden: Decoder hidden state, (num_layers, batch, hidden_dim).
            cell: Decoder cell state, same layout as ``hidden``.
            encoder_outputs: Encoder outputs, (batch, src_len, 2 * hidden_dim).

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` holds the
            unnormalised scores over the vocabulary, (batch, vocab_size), and
            ``hidden`` / ``cell`` are the updated LSTM states.
        """
        embedded = self.dropout(self.embedding(x))
        # Attention is driven by the top layer's hidden state.
        attn_weights = self.attention(hidden[-1], encoder_outputs)
        attn_weights = attn_weights.unsqueeze(1)
        # Weighted sum of encoder outputs: (batch, 1, 2 * hidden_dim).
        context = torch.bmm(attn_weights, encoder_outputs)
        lstm_input = torch.cat((embedded, context), dim=2)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden, cell

class Seq2SeqSpellchecker(nn.Module):
    """Character-level encoder/decoder model with attention.

    Args:
        vocab_size: Size of the character vocabulary.
        embed_dim: Size of each character embedding.
        hidden_dim: LSTM hidden size (per direction in the encoder).
        num_layers: Number of stacked LSTM layers in encoder and decoder.
        dropout: Dropout probability passed to encoder and decoder.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.encoder = Encoder(vocab_size, embed_dim, hidden_dim, num_layers, dropout)
        self.decoder = Decoder(vocab_size, embed_dim, hidden_dim, num_layers, dropout)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def _forward_direction_state(self, state, batch_size):
        """Keep only the forward-direction slice of a bidirectional LSTM state.

        ``state`` is laid out as (num_layers * 2, batch, hidden_dim). Viewing
        it as (num_layers, 2, batch, hidden_dim) and taking direction 0 gives
        the (num_layers, batch, hidden_dim) tensor the decoder expects. This
        yields the same values as concatenating both directions and slicing
        off the first ``hidden_dim`` features.
        """
        state = state.view(self.num_layers, 2, batch_size, self.hidden_dim)
        return state[:, 0].contiguous()

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Greedily decode ``trg.shape[1]`` steps for each source sequence.

        Args:
            src: Source index tensor, (batch, src_len).
            trg: Target index tensor, (batch, trg_len). Only its first column
                (the start tokens) and its length are used.
            teacher_forcing_ratio: Accepted for signature compatibility but
                ignored: this inference version never applies teacher forcing
                and always feeds back its own argmax prediction.

        Returns:
            Tensor of shape (batch, trg_len, vocab_size) with the decoder
            scores for steps 1..trg_len-1; position 0 is left as zeros.
        """
        batch_size = src.shape[0]
        trg_len = trg.shape[1]

        encoder_outputs, (hidden, cell) = self.encoder(src)
        hidden = self._forward_direction_state(hidden, batch_size)
        cell = self._forward_direction_state(cell, batch_size)

        outputs = torch.zeros(batch_size, trg_len, self.decoder.vocab_size, device=src.device)
        input_token = trg[:, 0].unsqueeze(1)

        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input_token, hidden, cell, encoder_outputs)
            outputs[:, t] = output
            # No teacher forcing in inference: feed back the greedy prediction.
            input_token = output.argmax(1).unsqueeze(1)

        return outputs

# ============================================================================
# LOAD TRAINED MODEL
# ============================================================================

print("\nLoading trained model...")

# Build the network, then move it to the selected device.
# NOTE(review): these hyperparameters presumably have to match the ones used
# when best_model.pt was trained, otherwise load_state_dict will reject the
# checkpoint - confirm against the training notebook.
model = Seq2SeqSpellchecker(
    vocab_size=vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3
)
model = model.to(device)

# Restore the trained weights and switch to evaluation mode (disables dropout).
checkpoint = torch.load('best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Model loaded successfully!")
print(f"Validation loss: {checkpoint['val_loss']:.4f}")

# ============================================================================
# INFERENCE FUNCTION
# ============================================================================

def correct_word(word: str, model_path: str = 'best_model.pt', max_length: int = 50) -> str:
    """
    Takes a potentially misspelled Georgian word and returns the corrected version.

    Decoding is greedy: at each step the highest-scoring character is emitted
    and fed back into the decoder until <EOS> or <PAD> is predicted or
    ``max_length`` steps have run.

    Args:
        word: Input word (potentially misspelled). Characters missing from the
            vocabulary are mapped to <UNK>.
        model_path: Path to trained model (not used here since the module-level
            ``model`` is already loaded; kept for interface compatibility)
        max_length: Maximum number of decoding steps

    Returns:
        Corrected word; an empty string when ``word`` is empty.
    """
    # Nothing to correct. Without this guard the empty index list becomes a
    # float tensor of shape (1, 0), which nn.Embedding rejects.
    if not word:
        return ''

    model.eval()

    # Look the special tokens up once instead of on every decoding step.
    sos_idx = char_to_idx['<SOS>']
    unk_idx = char_to_idx['<UNK>']
    stop_tokens = {char_to_idx['<EOS>'], char_to_idx['<PAD>']}
    skipped_tokens = {sos_idx, unk_idx}

    def keep_forward_direction(state):
        # The encoder state is (num_layers * 2, 1, hidden_dim); the decoder
        # expects (num_layers, 1, hidden_dim), so keep only direction 0 of
        # each layer. This matches concatenating both directions and slicing
        # off the first hidden_dim features.
        state = state.view(model.num_layers, 2, 1, model.hidden_dim)
        return state[:, 0].contiguous()

    with torch.no_grad():
        # Convert word to indices (batch of one)
        input_indices = [char_to_idx.get(c, unk_idx) for c in word]
        input_tensor = torch.tensor([input_indices], dtype=torch.long, device=device)

        # Encode
        encoder_outputs, (hidden, cell) = model.encoder(input_tensor)

        # Prepare decoder hidden states
        hidden = keep_forward_direction(hidden)
        cell = keep_forward_direction(cell)

        # Start decoding with <SOS> token
        decoder_input = torch.tensor([[sos_idx]], dtype=torch.long, device=device)

        output_chars = []
        for _ in range(max_length):
            output, hidden, cell = model.decoder(decoder_input, hidden, cell, encoder_outputs)
            predicted_idx = output.argmax(1).item()

            # <EOS> or <PAD> ends the word
            if predicted_idx in stop_tokens:
                break

            # <SOS>/<UNK> are not emitted, but are still fed back as input
            if predicted_idx not in skipped_tokens:
                output_chars.append(idx_to_char[predicted_idx])

            decoder_input = torch.tensor([[predicted_idx]], dtype=torch.long, device=device)

        return ''.join(output_chars)

# ============================================================================
# TEST EXAMPLES
# ============================================================================

print("\n" + "="*70)
print("SPELLCHECK DEMONSTRATION")
print("="*70 + "\n")

# Demo inputs: misspellings first (intended word and error type in the
# trailing comment), followed by words that are already spelled correctly.
test_words = [
    # --- misspelled ---
    "გამარჰობა",      # intended: გამარჯობა (hello) - substitution
    "თბილსი",         # intended: თბილისი (Tbilisi) - deletion
    "საქარტველო",     # intended: საქართველო (Georgia) - substitution
    "პროგამა",        # intended: პროგრამა (program) - deletion
    "კომპიუტერი",     # already correct
    "გაიარჯობა",      # intended: გამარჯობა - substitution
    "ქართუული",       # intended: ქართული (Georgian) - insertion
    "უნივრსიტეტი",    # intended: უნივერსიტეტი (university) - deletion
    "დილამშვიდობისა",  # already correct (good morning)
    "მადლბა",         # intended: მადლობა (thank you) - deletion
    # --- already correct ---
    "თბილისი",        # Tbilisi
    "საქართველო",     # Georgia
    "ბატონი",         # mister
    "ქალბატონი",      # madam
    "დღეს",           # today
    "ხვალ",           # tomorrow
    "წიგნი",          # book
    "სკოლა",          # school
    "სახლი",          # house
    "მანქანა",        # car
]

print(f"{'Input Word':<25} {'Corrected Word':<25} {'Status'}")
print("-" * 70)

# One table row per word. The status only says whether the model changed the
# input, not whether the result is the intended word.
for word in test_words:
    corrected = correct_word(word)
    if word == corrected:
        status = "✓ No change"
    else:
        status = "✎ Corrected"
    print(f"{word:<25} {corrected:<25} {status}")

# ============================================================================
# INTERACTIVE TESTING
# ============================================================================

print("\n" + "="*70)
print("INTERACTIVE MODE")
print("="*70)
print("\nEnter Georgian words to correct (or 'quit' to exit):\n")

# Read words until the user asks to quit, stdin is closed, or the user
# interrupts; each non-empty entry is passed through correct_word.
while True:
    try:
        user_input = input("Enter word: ").strip()
    except (EOFError, KeyboardInterrupt):
        # stdin was closed (e.g. piped or non-interactive run) or the user
        # pressed Ctrl-C: leave the loop instead of dying with a traceback.
        print("\nExiting...")
        break

    if user_input.lower() in ('quit', 'exit', 'q'):
        print("Exiting...")
        break

    # Ignore blank lines and prompt again.
    if not user_input:
        continue

    try:
        corrected = correct_word(user_input)
    except Exception as e:
        # Top-level boundary of the interactive session: report the problem
        # and keep the prompt alive rather than aborting.
        print(f"Error: {e}")
    else:
        if user_input == corrected:
            print(f"Result: {corrected} (no correction needed)")
        else:
            print(f"Result: {user_input} → {corrected}")

    print()

print("\nDone!")

Using device: cpu
Loading vocabulary...
Vocabulary size: 37

Loading trained model...
Model loaded successfully!
Validation loss: 0.2409

SPELLCHECK DEMONSTRATION

Input Word                Corrected Word            Status
----------------------------------------------------------------------
გამარჰობა                 გამარჰობაა                ✎ Corrected
თბილსი                    თბილისი                   ✎ Corrected
საქარტველო                საქარტველო                ✓ No change
პროგამა                   როგამაა                   ✎ Corrected
კომპიუტერი                კომპიუტერი                ✓ No change
გაიარჯობა                 გაიარჯობაა                ✎ Corrected
ქართუული                  ქართული                   ✎ Corrected
უნივრსიტეტი               უნივრსიტეტი               ✓ No change
დილამშვიდობისა            დილამშვიდობისა            ✓ No change
მადლბა                    მადლბაა                   ✎ Corrected
თბილისი                   თბილისი                   ✓ No change
სა