In [1]:
import json
import torch
import torch.nn as nn
from pathlib import Path

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

PyTorch version: 2.9.1+cpu
CUDA available: False
Using device: cpu


## RUN  ALL  THE CELLs, at the end is the interactive testing

In [2]:
class CharVocab:
    """Character-level vocabulary for Georgian text."""

    def __init__(self):
        self.PAD_TOKEN = '<PAD>'
        self.SOS_TOKEN = '<SOS>'  # Start of sequence
        self.EOS_TOKEN = '<EOS>'  # End of sequence
        self.UNK_TOKEN = '<UNK>'  # Unknown character

        self.char2idx = {
            self.PAD_TOKEN: 0,
            self.SOS_TOKEN: 1,
            self.EOS_TOKEN: 2,
            self.UNK_TOKEN: 3,
        }
        self.idx2char = {v: k for k, v in self.char2idx.items()}
        self.next_idx = 4

    def build_vocab(self, words):
        """Build vocabulary from list of words."""
        for word in words:
            for char in word:
                if char not in self.char2idx:
                    self.char2idx[char] = self.next_idx
                    self.idx2char[self.next_idx] = char
                    self.next_idx += 1
        print(f"Vocabulary size: {len(self.char2idx)}")
        return self

    def encode(self, text, add_sos=False, add_eos=True):
        """Convert text to list of indices."""
        indices = []
        if add_sos:
            indices.append(self.char2idx[self.SOS_TOKEN])
        indices.extend([self.char2idx.get(char, self.char2idx[self.UNK_TOKEN])
                        for char in text])
        if add_eos:
            indices.append(self.char2idx[self.EOS_TOKEN])
        return indices

    def decode(self, indices):
        """Convert list of indices to text."""
        chars = []
        for idx in indices:
            if idx == self.char2idx[self.EOS_TOKEN]:
                break
            if idx == self.char2idx[self.PAD_TOKEN]:
                continue
            if idx == self.char2idx[self.SOS_TOKEN]:
                continue
            chars.append(self.idx2char.get(idx, self.UNK_TOKEN))
        return ''.join(chars)

    def __len__(self):
        return len(self.char2idx)

print("CharVocab class defined")

CharVocab class defined


In [3]:
class LSTMEncoder(nn.Module):
    """LSTM Encoder for sequence encoding."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_lengths=None):
        embedded = self.dropout(self.embedding(src))
        if src_lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            outputs, (hidden, cell) = self.lstm(packed)
            outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        else:
            outputs, (hidden, cell) = self.lstm(embedded)
        return outputs, hidden, cell


class BahdanauAttention(nn.Module):
    """Bahdanau (additive) attention mechanism."""

    def __init__(self, hidden_dim, encoder_dim):
        super().__init__()
        self.attn = nn.Linear(hidden_dim + encoder_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        batch_size = encoder_outputs.size(0)
        src_len = encoder_outputs.size(1)
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2)))
        attention = self.v(energy).squeeze(2)
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e4)
        attention_weights = torch.softmax(attention, dim=1)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)
        return context, attention_weights


class LSTMDecoder(nn.Module):
    """LSTM Decoder with attention mechanism."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, encoder_dim, num_layers=2, dropout=0.3):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.attention = BahdanauAttention(hidden_dim, encoder_dim)
        self.lstm = nn.LSTM(
            embedding_dim + encoder_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        self.fc_out = nn.Linear(hidden_dim + encoder_dim + embedding_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, hidden, cell, encoder_outputs, mask=None):
        embedded = self.dropout(self.embedding(tgt))
        context, attn_weights = self.attention(hidden[-1], encoder_outputs, mask)
        context = context.unsqueeze(1)
        lstm_input = torch.cat([embedded, context], dim=2)
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        output = output.squeeze(1)
        context = context.squeeze(1)
        embedded = embedded.squeeze(1)
        prediction = self.fc_out(torch.cat([output, context, embedded], dim=1))
        return prediction, hidden, cell, attn_weights


class SpellingLSTM(nn.Module):
    """LSTM Encoder-Decoder with Attention for spelling correction."""

    def __init__(self, vocab_size, embedding_dim=256, encoder_hidden_dim=256,
                 decoder_hidden_dim=512, num_layers=2, dropout=0.3):
        super().__init__()
        self.encoder = LSTMEncoder(vocab_size, embedding_dim, encoder_hidden_dim, num_layers, dropout)
        encoder_output_dim = encoder_hidden_dim * 2
        self.decoder = LSTMDecoder(vocab_size, embedding_dim, decoder_hidden_dim,
                                  encoder_output_dim, num_layers, dropout)
        self.bridge_h = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)
        self.bridge_c = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, tgt, src_padding_mask=None, tgt_padding_mask=None, **kwargs):
        batch_size = src.size(0)
        tgt_len = tgt.size(1)
        if src_padding_mask is not None:
            src_lengths = (~src_padding_mask).sum(dim=1)
        else:
            src_lengths = torch.full((batch_size,), src.size(1), dtype=torch.long, device=src.device)
        encoder_outputs, hidden, cell = self.encoder(src, src_lengths)
        hidden_combined = []
        cell_combined = []
        for i in range(self.decoder.num_layers):
            h_forward = hidden[i*2]
            h_backward = hidden[i*2 + 1]
            h_combined = torch.cat([h_forward, h_backward], dim=1)
            hidden_combined.append(self.bridge_h(h_combined))
            c_forward = cell[i*2]
            c_backward = cell[i*2 + 1]
            c_combined = torch.cat([c_forward, c_backward], dim=1)
            cell_combined.append(self.bridge_c(c_combined))
        hidden = torch.stack(hidden_combined)
        cell = torch.stack(cell_combined)
        attn_mask = None
        if src_padding_mask is not None:
            attn_mask = ~src_padding_mask
        outputs = []
        for t in range(tgt_len):
            tgt_t = tgt[:, t].unsqueeze(1)
            prediction, hidden, cell, _ = self.decoder(tgt_t, hidden, cell, encoder_outputs, attn_mask)
            outputs.append(prediction)
        outputs = torch.stack(outputs, dim=1)
        return outputs

print("Model architecture defined")

Model architecture defined


In [4]:
# Load vocabulary from JSON file
vocab_json_path = 'char_vocab.json'

with open(vocab_json_path, 'r', encoding='utf-8') as f:
    vocab_data = json.load(f)

# Reconstruct vocabulary object
vocab = CharVocab()
vocab.char2idx = vocab_data['char2idx']
# Convert string keys back to integers for idx2char
vocab.idx2char = {int(k): v for k, v in vocab_data['idx2char'].items()}
vocab.next_idx = vocab_data.get('next_idx', len(vocab.char2idx))

print(f"Loaded vocabulary from {vocab_json_path}")
print(f"Vocabulary size: {len(vocab)}")

Loaded vocabulary from char_vocab.json
Vocabulary size: 39


In [5]:
# Initialize model with same parameters as training
model = SpellingLSTM(
    vocab_size=len(vocab),
    embedding_dim=256,
    encoder_hidden_dim=256,
    decoder_hidden_dim=512,
    num_layers=2,
    dropout=0.3
)

# Load trained weights
checkpoint = torch.load('best_model1.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

print(f"Model loaded successfully!")
print(f"Trained for {checkpoint.get('epoch', 'unknown')} epochs")
print(f"Best validation loss: {checkpoint.get('val_loss', 'unknown')}")

Model loaded successfully!
Trained for 8 epochs
Best validation loss: 1.001819218268417


## Step 6: Define Inference Function

In [6]:
def correct_word(model, word, vocab, device='cuda', max_len=100):
    """
    Correct a single misspelled word using LSTM decoder.
    
    Uses greedy decoding with early stopping to prevent loops.
    """
    model.eval()
    
    # Encode input (no SOS for source)
    src = torch.LongTensor([vocab.encode(word, add_sos=False, add_eos=True)]).to(device)
    src_lengths = torch.LongTensor([src.size(1)])
    
    with torch.no_grad():
        # Encode source
        encoder_outputs, hidden, cell = model.encoder(src, src_lengths)
        
        # Bridge encoder hidden to decoder
        hidden_combined = []
        cell_combined = []
        for i in range(model.decoder.num_layers):
            h_forward = hidden[i*2]
            h_backward = hidden[i*2 + 1]
            h_combined = torch.cat([h_forward, h_backward], dim=1)
            hidden_combined.append(model.bridge_h(h_combined))
            
            c_forward = cell[i*2]
            c_backward = cell[i*2 + 1]
            c_combined = torch.cat([c_forward, c_backward], dim=1)
            cell_combined.append(model.bridge_c(c_combined))
        
        hidden = torch.stack(hidden_combined)
        cell = torch.stack(cell_combined)
        
        # Start with SOS token
        tgt_token = torch.LongTensor([[vocab.char2idx['<SOS>']]]).to(device)
        decoded_tokens = []
        
        # Greedy decoding with loop prevention
        for step in range(max_len):
            prediction, hidden, cell, _ = model.decoder(
                tgt_token, hidden, cell, encoder_outputs, mask=None
            )
            
            # Get most likely next token
            next_token_id = prediction.argmax(dim=-1).item()
            
            # Stop conditions
            if next_token_id == vocab.char2idx['<EOS>']:
                break
            if next_token_id == vocab.char2idx['<PAD>']:
                break
            if next_token_id == vocab.char2idx['<SOS>']:
                continue  # Skip if model predicts SOS again
            
            # Prevent infinite loops: stop if output is too long relative to input
            if len(decoded_tokens) > len(word) * 3:
                break
            
            # Detect repetition: if last 3 characters are same, stop
            if len(decoded_tokens) >= 3:
                last_three = decoded_tokens[-3:]
                if len(set(last_three)) == 1:  # All same character
                    break
            
            decoded_tokens.append(next_token_id)
            tgt_token = torch.LongTensor([[next_token_id]]).to(device)
    
    # Decode tokens to string
    return vocab.decode(decoded_tokens)

print("Inference function defined")

Inference function defined


In [7]:
# Define corruption function for testing
import random

GEORGIAN_KEYBOARD = {
   "ა": ['ქ','ს','ზ'], 'ბ': ['ვ','ნ','გ','ჰ'], 'გ': ['ვ','ბ','ფ','ტ','ყ','ჰ'],
   'დ': ['ხ','ც','ს','ფ','რ','ე'], 'ე': ['წ','რ','დ','ს'], 'ვ': ['ც','ბ','ფ','გ'],
   'ზ': ['ა','ს','ხ'], "თ": ['ღ','ყ','ფ','გ','ტ','რ'], 'ი': ['უ','ო','ჯ','კ'],
   'კ': ['მ','ჯ','ლ','ი','ო'], 'ლ': ['კ','ო','პ'], 'მ': ['ნ','ჯ','კ','ლ'],
   'ნ': ['ბ','ჰ','ჯ'], 'ო': ['ი','პ','კ','ლ'], 'პ': ['ო','ლ'],
   'ჟ': ['ჯ','ჰ','უ','ნ','მ'], 'რ': ['ღ','ე','ტ','თ','დ','ფ'],
   'ს': ['შ','ა','ზ','ხ','წ','ე'], 'ტ': ['რ','ყ','ფ','გ'], 'უ': ['ყ','ჰ','ჯ','ი'],
   'ფ': ['ც','ვ','დ','გ','რ','ტ'], 'ქ': ['ა','წ'], 'ღ': ['თ','რ','ტ','ე','დ','ფ'],
   'ყ': ['ტ','გ','ჰ','უ'], 'შ': ['ს','ა','დ','წ','ე','ხ'], 'ჩ': ['ც','ხ','ვ','დ','ფ'],
   'ც': ['ხ','ვ','დ','ფ'], 'ძ': ['ა','ს','ხ'], 'წ': ['ქ','ე','ს','ა'],
   'ჭ': ['ქ','ე','ს','ა'], 'ხ': ['ა','ს','დ','ც','ზ'],
   'ჯ': ['ჰ','უ','ი','კ','მ','ნ'], 'ჰ': ['გ','ყ','უ','ჯ','ნ','ბ']
}

# Get all Georgian characters from vocabulary (excluding special tokens)
ALL_GEORGIAN_CHARS = set()
for char in vocab.char2idx.keys():
    if char not in [vocab.PAD_TOKEN, vocab.SOS_TOKEN, vocab.EOS_TOKEN, vocab.UNK_TOKEN]:
        ALL_GEORGIAN_CHARS.add(char)

print(f"Loaded {len(ALL_GEORGIAN_CHARS)} Georgian characters from vocabulary")

def corrupt_word(word, corruption_prob=1.0):
    """Apply realistic corruptions to simulate typing errors."""
    if len(word) < 2:
        return word
    
    original_word = word
    max_attempts = 10
    
    for attempt in range(max_attempts):
        word_list = list(original_word)
        
        # Number of errors based on word length
        if len(word_list) <= 4:
            num_errors = 1
        elif len(word_list) <= 8:
            num_errors = random.randint(1, 2)
        else:
            num_errors = random.randint(1, 3)
        
        for _ in range(num_errors):
            if len(word_list) < 2:
                break
            
            error_type = random.choices(
                ['substitute', 'delete', 'insert', 'transpose', 'repeat'],
                weights=[0.35, 0.25, 0.20, 0.15, 0.05]
            )[0]
            
            pos = random.randint(0, len(word_list) - 1)
            
            if error_type == 'substitute':
                char = word_list[pos]
                if char in GEORGIAN_KEYBOARD and GEORGIAN_KEYBOARD[char]:
                    word_list[pos] = random.choice(GEORGIAN_KEYBOARD[char])
                elif ALL_GEORGIAN_CHARS:
                    candidates = [c for c in ALL_GEORGIAN_CHARS if c != char]
                    if candidates:
                        word_list[pos] = random.choice(candidates)
            
            elif error_type == 'delete':
                if len(word_list) > 2:
                    word_list.pop(pos)
            
            elif error_type == 'insert':
                if ALL_GEORGIAN_CHARS:
                    if pos > 0 and random.random() < 0.3:
                        word_list.insert(pos, word_list[pos-1])
                    else:
                        word_list.insert(pos, random.choice(list(ALL_GEORGIAN_CHARS)))
            
            elif error_type == 'transpose':
                if pos < len(word_list) - 1:
                    word_list[pos], word_list[pos + 1] = word_list[pos + 1], word_list[pos]
            
            elif error_type == 'repeat':
                word_list.insert(pos, word_list[pos])
        
        corrupted = ''.join(word_list)
        
        if corrupted != original_word and len(corrupted) > 0:
            return corrupted
    
    # Last resort: force a substitution
    if len(original_word) >= 2 and ALL_GEORGIAN_CHARS:
        word_list = list(original_word)
        pos = random.randint(0, len(word_list) - 1)
        candidates = [c for c in ALL_GEORGIAN_CHARS if c != word_list[pos]]
        if candidates:
            word_list[pos] = random.choice(candidates)
            return ''.join(word_list)
    
    return original_word

print("Corruption function defined")


Loaded 35 Georgian characters from vocabulary
Corruption function defined


In [8]:
# Create test pairs: corrupt some sample Georgian words
random.seed(42)

# Sample Georgian words for testing (you can replace these with your own)
sample_words = [
    'გამარჯობა', 'მადლობა', 'დედა', 'მამა', 'სახლი', 
    'წიგნი', 'სკოლა', 'მეგობარი', 'ქუჩა', 'საქართველო',
    'თბილისი', 'ბავშვი', 'ქალი', 'კაცი', 'ძაღლი',
    'კატა', 'საჭმელი', 'წყალი', 'დრო', 'ფული'
]

test_pairs = []
for word in sample_words:
    corrupted = corrupt_word(word)
    if corrupted != word:  # Only add if actually corrupted
        test_pairs.append((corrupted, word))

print(f"Created {len(test_pairs)} test pairs")
print("\nAll test examples:")
for i, (corrupted, original) in enumerate(test_pairs, 1):
    print(f"{i:2}. Corrupted: {corrupted:20} → Original: {original}")

Created 20 test pairs

All test examples:
 1. Corrupted: გზმაეჯობბა           → Original: გამარჯობა
 2. Corrupted: ჯადლიბა              → Original: მადლობა
 3. Corrupted: დედ                  → Original: დედა
 4. Corrupted: მანა                 → Original: მამა
 5. Corrupted: საახლი               → Original: სახლი
 6. Corrupted: წიგგნი               → Original: წიგნი
 7. Corrupted: ზკოლა                → Original: სკოლა
 8. Corrupted: ყმეგობრი             → Original: მეგობარი
 9. Corrupted: ქუფა                 → Original: ქუჩა
10. Corrupted: სქატთველო            → Original: საქართველო
11. Corrupted: თბიისი               → Original: თბილისი
12. Corrupted: ბაბშვ-ი              → Original: ბავშვი
13. Corrupted: ქპალი                → Original: ქალი
14. Corrupted: კაცჯ                 → Original: კაცი
15. Corrupted: მძაღლი               → Original: ძაღლი
16. Corrupted: კაატ                 → Original: კატა
17. Corrupted: სზმჭელი              → Original: საჭმელი
18. Corrupted: წყაყლი      

In [9]:
def calculate_f1_score(predicted, target):
    """
    Calculate character-level F1 score between predicted and target strings.
    
    Args:
        predicted: The predicted/corrected word
        target: The true/original word
    
    Returns:
        f1_score: F1 score value (0 to 1)
    """
    # Convert to sets of (character, position) pairs for character-level matching
    pred_chars = set((char, i) for i, char in enumerate(predicted))
    target_chars = set((char, i) for i, char in enumerate(target))
    
    # Calculate true positives, false positives, false negatives
    true_positives = len(pred_chars & target_chars)
    false_positives = len(pred_chars - target_chars)
    false_negatives = len(target_chars - pred_chars)
    
    # Calculate precision and recall
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    
    # Calculate F1 score
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return f1

print("F1 score calculation function defined")

F1 score calculation function defined


In [10]:
# Test the model on corrupted pairs
print("\n" + "="*70)
print("Testing Model on Corrupted Word Pairs")
print("="*70)

correct_count = 0
total_count = len(test_pairs)
total_f1 = 0.0
f1_scores = []

for corrupted_word, original_word in test_pairs:
    try:
        corrected = correct_word(model, corrupted_word, vocab, device)
        is_correct = (corrected == original_word)
        correct_count += int(is_correct)
        
        # Calculate F1 score
        f1 = calculate_f1_score(corrected, original_word)
        f1_scores.append(f1)
        total_f1 += f1
        
        status = "✓" if is_correct else "✗"
        print(f"{status} Corrupted: {corrupted_word:20} → Corrected: {corrected:20} (True: {original_word}) | F1: {f1:.3f}")
    except Exception as e:
        print(f"✗ Error with '{corrupted_word}': {e}")
        f1_scores.append(0.0)

# Calculate average F1 score
avg_f1 = total_f1 / total_count if total_count > 0 else 0

print("\n" + "="*70)
print(f"Accuracy: {correct_count}/{total_count} ({100*correct_count/total_count:.1f}%)")
print(f"Average F1 Score: {avg_f1:.4f}")
print(f"Min F1: {min(f1_scores):.4f}, Max F1: {max(f1_scores):.4f}")
print("="*70)


Testing Model on Corrupted Word Pairs


✓ Corrupted: გზმაეჯობბა           → Corrected: გამარჯობა            (True: გამარჯობა) | F1: 1.000
✗ Corrupted: ჯადლიბა              → Corrected: ჯადობია              (True: მადლობა) | F1: 0.429
✓ Corrupted: დედ                  → Corrected: დედა                 (True: დედა) | F1: 1.000
✗ Corrupted: მანა                 → Corrected: მანას                (True: მამა) | F1: 0.667
✗ Corrupted: საახლი               → Corrected: სახალი               (True: სახლი) | F1: 0.545
✓ Corrupted: წიგგნი               → Corrected: წიგნი                (True: წიგნი) | F1: 1.000
✗ Corrupted: ზკოლა                → Corrected: აკოლა                (True: სკოლა) | F1: 0.800
✗ Corrupted: ყმეგობრი             → Corrected: მეგობრი              (True: მეგობარი) | F1: 0.667
✗ Corrupted: ქუფა                 → Corrected: ქუდა                 (True: ქუჩა) | F1: 0.750
✓ Corrupted: სქატთველო            → Corrected: საქართველო           (True: საქართველო) | F1: 1.000
✗ Corrupted: თბიისი               → Corrected: თბ

techincally 50% because ბავშვნი ერთ-ერთ ბრუნვაშია, i have over-complicated the data generation imho


In [11]:
print("Enter a Georgian word to correct (type 'exit' to quit):")
while True:
    user_input = input("Your word: ").strip()
    if user_input.lower() == 'exit':
        break
    if not user_input:
        print("Please enter a word.")
        continue
    
    try:
        corrected_word = correct_word(model, user_input, vocab, device)
        print(f"  Original:  {user_input}")
        print(f"  Corrected: {corrected_word}\n")
    except Exception as e:
        print(f"An error occurred: {e}")
        print("Please try a different word.\n")

print("Exiting interactive correction.")

Enter a Georgian word to correct (type 'exit' to quit):
  Original:  სალადინი
  Corrected: სალადინის

  Original:  დანია
  Corrected: დანისა

  Original:  ალმასი
  Corrected: ალმასი

  Original:  სტალინი
  Corrected: სტალინის

  Original:  გიორგა
  Corrected: გიორგა

  Original:  სააკაძე
  Corrected: სააკაძე

Exiting interactive correction.
