In [None]:
import json
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import math
from pathlib import Path
from tqdm import tqdm
import numpy as np

# Set random seeds for reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# ============================================================================
# GEORGIAN KEYBOARD LAYOUT - For realistic misclick simulation
# ============================================================================
#ა ბ გ დ ე ვ ზ თ ი კ ლ მ ნ ო პ ჟ რ ს ტ უ ფ ქ ღ ყ შ ჩ ც ძ წ ჭ ხ ჯ ჰ
# Standard Georgian QWERTY keyboard layout (approximation)
GEORGIAN_KEYBOARD = {
   "ა" : ['ქ','ს,','ზ'],
   'ბ' : ['ვ','ნ','გ','ჰ'], # we can also add space but not sure that model is going to work for this one
   'გ' : ['ვ','ბ','ფ','ტ','ყ','ჰ'],
   'დ' : ['ხ','ც','ს','ფ','რ','ე'],
   'ე' : ['წ','რ','დ','ს'],
   'ვ': ['ც','ბ','ფ','გ'],
   'ზ': ['ა','ს','ხ'],
   "თ": ['ღ','ყ','ფ','გ','ტ','რ'], # pressed/not pressed shift
   'ი': ['უ','ო','ჯ','კ'],
   'კ': ['მ','ჯ','ლ','ი','ო'],
   'ლ': ['კ','ო','პ'],
   'მ': ['ნ','ჯ','კ','ლ'],
   'ნ': ['ბ','ჰ','ჯ'], #no მ
   'ო': ['ი','პ','კ','ლ'],
   'პ': ['ო','ლ'],
   'ჟ': ['ჯ','ჰ','უ','ნ','მ'],
   'რ': ['ღ','ე','ტ','თ','დ','ფ'],
   'ს': ['შ','ა','ზ','ხ','წ','ე'],
   'ტ': ['რ','ყ','ფ','გ'],
   'უ': ['ყ','ჰ','ჯ','ი'],
   'ფ': ['ც','ვ','დ','გ','რ','ტ'],
   'ქ': ['ა','წ'],
   'ღ': ['თ','რ','ტ','ე','დ','ფ'],
   'ყ': ['ტ','გ','ჰ','უ'],
   'შ': ['ს','ა','დ','წ','ე','ხ'],
   'ჩ': ['ც','ხ','ვ','დ','ფ'],
   'ც': ['ხ','ვ','დ','ფ'],
   'ძ': ['ა', 'ს', 'ხ'],
   'წ': ['ქ', 'ე', 'ს', 'ა'],
   'ჭ': ['ქ', 'ე', 'ს', 'ა'],
   'ხ': ['ა', 'ს', 'დ', 'ც','ზ'],
   'ჯ': ['ჰ', 'უ', 'ი', 'კ', 'მ', 'ნ'],
   'ჰ': ['გ', 'ყ', 'უ', 'ჯ', 'ნ', 'ბ']
}

# Around line 107 - Update this section:
ALL_GEORGIAN_CHARS = set()

# Define your data directory once
DATA_DIR = 'drive/MyDrive/data' # Update this path as needed, i have trained on colab current one would be data

for word_file in ['wordsChunk_0.json', 'wordsChunk_1.json', 'wordsChunk_2.json']:
    file_path = Path(DATA_DIR) / word_file
    if file_path.exists():
        with open(file_path, 'r', encoding='utf-8') as f:
            words = json.load(f)
            for word in words[:1000]:
                ALL_GEORGIAN_CHARS.update(word)














# ============================================================================
# INFERENCE - Generate corrected spellings
# ============================================================================

def correct_word(model, word, vocab, device='cuda', max_len=100):
    """
    Correct a single misspelled word using LSTM decoder.

    Uses greedy decoding with early stopping to prevent loops.
    """
    model.eval()

    # Encode input (no SOS for source)
    src = torch.LongTensor([vocab.encode(word, add_sos=False, add_eos=True)]).to(device)
    src_lengths = torch.LongTensor([src.size(1)])

    with torch.no_grad():
        # Encode source
        encoder_outputs, hidden, cell = model.encoder(src, src_lengths)

        # Bridge encoder hidden to decoder
        hidden_combined = []
        cell_combined = []
        for i in range(model.decoder.num_layers):
            h_forward = hidden[i*2]
            h_backward = hidden[i*2 + 1]
            h_combined = torch.cat([h_forward, h_backward], dim=1)
            hidden_combined.append(model.bridge_h(h_combined))

            c_forward = cell[i*2]
            c_backward = cell[i*2 + 1]
            c_combined = torch.cat([c_forward, c_backward], dim=1)
            cell_combined.append(model.bridge_c(c_combined))

        hidden = torch.stack(hidden_combined)
        cell = torch.stack(cell_combined)

        # Start with SOS token
        tgt_token = torch.LongTensor([[vocab.char2idx['<SOS>']]]).to(device)
        decoded_tokens = []

        # Greedy decoding with loop prevention
        for step in range(max_len):
            prediction, hidden, cell, _ = model.decoder(
                tgt_token, hidden, cell, encoder_outputs, mask=None
            )

            # Get most likely next token
            next_token_id = prediction.argmax(dim=-1).item()

            # Stop conditions
            if next_token_id == vocab.char2idx['<EOS>']:
                break
            if next_token_id == vocab.char2idx['<PAD>']:
                break
            if next_token_id == vocab.char2idx['<SOS>']:
                continue  # Skip if model predicts SOS again

            # Prevent infinite loops: stop if output is too long relative to input
            if len(decoded_tokens) > len(word) * 3:
                break

            # Detect repetition: if last 3 characters are same, stop
            if len(decoded_tokens) >= 3:
                last_three = decoded_tokens[-3:]
                if len(set(last_three)) == 1:  # All same character
                    break

            decoded_tokens.append(next_token_id)
            tgt_token = torch.LongTensor([[next_token_id]]).to(device)

    # Decode tokens to string
    return vocab.decode(decoded_tokens)



In [15]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [16]:
# ============================================================================
# DATA CORRUPTION - Simulating realistic typing errors
# ============================================================================

def corrupt_word(word, corruption_prob=1.0):
    """
    Apply realistic corruptions to simulate typing errors.

    ALWAYS corrupts words (100% rate) so model learns to actually correct,
    not just copy input to output.

    Error Types:
    1. Character-level errors:
       - Substitution: Adjacent key typo (35%)
       - Deletion: Missing character (25%)
       - Insertion: Extra character (20%)
       - Transposition: Swapped characters (15%)
       - Repetition: Doubled character (5%)
    """
    # Skip very short words
    if len(word) < 2:
        return word

    original_word = word
    max_attempts = 10  # Try harder to corrupt

    for attempt in range(max_attempts):
        word_list = list(original_word)

        # Number of errors: 1-3 based on word length
        if len(word_list) <= 4:
            num_errors = 1
        elif len(word_list) <= 8:
            num_errors = random.randint(1, 2)
        else:
            num_errors = random.randint(1, 3)

        for _ in range(num_errors):
            if len(word_list) < 2:
                break

            error_type = random.choices(
                ['substitute', 'delete', 'insert', 'transpose', 'repeat'],
                weights=[0.35, 0.25, 0.20, 0.15, 0.05]
            )[0]

            pos = random.randint(0, len(word_list) - 1)

            if error_type == 'substitute':
                char = word_list[pos]
                if char in GEORGIAN_KEYBOARD and GEORGIAN_KEYBOARD[char]:
                    word_list[pos] = random.choice(GEORGIAN_KEYBOARD[char])
                elif ALL_GEORGIAN_CHARS:
                    # Pick a random different character
                    candidates = [c for c in ALL_GEORGIAN_CHARS if c != char]
                    if candidates:
                        word_list[pos] = random.choice(candidates)

            elif error_type == 'delete':
                if len(word_list) > 2:
                    word_list.pop(pos)

            elif error_type == 'insert':
                if ALL_GEORGIAN_CHARS:
                    if pos > 0 and random.random() < 0.3:
                        # Duplicate adjacent char (common typo)
                        word_list.insert(pos, word_list[pos-1])
                    else:
                        word_list.insert(pos, random.choice(list(ALL_GEORGIAN_CHARS)))

            elif error_type == 'transpose':
                if pos < len(word_list) - 1:
                    word_list[pos], word_list[pos + 1] = word_list[pos + 1], word_list[pos]

            elif error_type == 'repeat':
                # Double a character
                word_list.insert(pos, word_list[pos])

        corrupted = ''.join(word_list)

        # Success if we actually changed the word
        if corrupted != original_word and len(corrupted) > 0:
            return corrupted

    # Last resort: force a substitution
    if len(original_word) >= 2 and ALL_GEORGIAN_CHARS:
        word_list = list(original_word)
        pos = random.randint(0, len(word_list) - 1)
        candidates = [c for c in ALL_GEORGIAN_CHARS if c != word_list[pos]]
        if candidates:
            word_list[pos] = random.choice(candidates)
            return ''.join(word_list)

    return original_word

In [17]:
def load_and_corrupt_data(data_dir='data', corruption_rate=1.0, max_words=None):
    """
    Load words from JSON files and create corrupted versions.

    NOW: 100% corruption rate - every word gets corrupted so model
    learns to actually CORRECT, not just COPY.

    Args:
        data_dir: Directory containing word JSON files
        corruption_rate: Percentage of words to corrupt (1.0 = 100%)
        max_words: Maximum number of words to load (None = all)

    Returns:
        List of (corrupted, correct) tuples
    """
    all_words = []
    data_path = Path(data_dir)

    # Load all word chunks
    for word_file in ['wordsChunk_0.json', 'wordsChunk_1.json', 'wordsChunk_2.json']:
        file_path = data_path / word_file
        if file_path.exists():
            print(f"Loading {word_file}...")
            with open(file_path, 'r', encoding='utf-8') as f:
                words = json.load(f)
                all_words.extend(words)

    # Limit dataset size if specified
    if max_words:
        all_words = all_words[:max_words]

    # Remove very short words (hard to corrupt meaningfully)
    all_words = [w for w in all_words if len(w) >= 2]

    print(f"Loaded {len(all_words)} words total (after filtering short words)")

    # Create training pairs: (corrupted, correct)
    training_pairs = []
    failed_corruptions = 0

    for word in tqdm(all_words, desc="Corrupting words"):
        corrupted = corrupt_word(word, corruption_prob=corruption_rate)

        # Only add if we successfully corrupted (corrupted != original)
        if corrupted != word:
            training_pairs.append((corrupted, word))
        else:
            failed_corruptions += 1
            # Still add as identity pair (some clean examples help)
            if random.random() < 0.1:  # Keep 10% of failed corruptions
                training_pairs.append((word, word))

    # Count actual corruptions
    num_corrupted = sum(1 for c, o in training_pairs if c != o)
    print(f"Successfully corrupted: {num_corrupted}/{len(training_pairs)} "
          f"({100*num_corrupted/len(training_pairs):.1f}%)")
    print(f"Failed to corrupt: {failed_corruptions} words")

    # Shuffle to mix corrupted and clean examples
    random.shuffle(training_pairs)

    return training_pairs

In [18]:
# ============================================================================
# CHARACTER VOCABULARY - Building character-to-index mappings
# ============================================================================

class CharVocab:
    """Character-level vocabulary for Georgian text."""

    def __init__(self):
        self.PAD_TOKEN = '<PAD>'
        self.SOS_TOKEN = '<SOS>'  # Start of sequence
        self.EOS_TOKEN = '<EOS>'  # End of sequence
        self.UNK_TOKEN = '<UNK>'  # Unknown character

        self.char2idx = {
            self.PAD_TOKEN: 0,
            self.SOS_TOKEN: 1,
            self.EOS_TOKEN: 2,
            self.UNK_TOKEN: 3,
        }
        self.idx2char = {v: k for k, v in self.char2idx.items()}
        self.next_idx = 4

    def build_vocab(self, words):
        """Build vocabulary from list of words."""
        for word in words:
            for char in word:
                if char not in self.char2idx:
                    self.char2idx[char] = self.next_idx
                    self.idx2char[self.next_idx] = char
                    self.next_idx += 1
        print(f"Vocabulary size: {len(self.char2idx)}")
        return self

    def encode(self, text, add_sos=False, add_eos=True):
        """Convert text to list of indices."""
        indices = []
        if add_sos:
            indices.append(self.char2idx[self.SOS_TOKEN])
        indices.extend([self.char2idx.get(char, self.char2idx[self.UNK_TOKEN])
                        for char in text])
        if add_eos:
            indices.append(self.char2idx[self.EOS_TOKEN])
        return indices

    def decode(self, indices):
        """Convert list of indices to text."""
        chars = []
        for idx in indices:
            if idx == self.char2idx[self.EOS_TOKEN]:
                break
            if idx == self.char2idx[self.PAD_TOKEN]:
                continue
            if idx == self.char2idx[self.SOS_TOKEN]:
                continue
            chars.append(self.idx2char.get(idx, self.UNK_TOKEN))
        return ''.join(chars)

    def __len__(self):
        return len(self.char2idx)

In [19]:
# ============================================================================
# DATASET - PyTorch Dataset for character sequences
# ============================================================================

class SpellingDataset(Dataset):
    """Dataset for spelling correction pairs."""

    def __init__(self, pairs, vocab, max_len=50):
        self.pairs = pairs
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        corrupted, correct = self.pairs[idx]

        # Encode sequences
        # Source: no SOS, just EOS
        src = self.vocab.encode(corrupted, add_sos=False, add_eos=True)
        # Target: add SOS at start, EOS at end for teacher forcing
        tgt = self.vocab.encode(correct, add_sos=True, add_eos=True)

        # Truncate if too long
        src = src[:self.max_len]
        tgt = tgt[:self.max_len]

        return torch.LongTensor(src), torch.LongTensor(tgt)


def collate_fn(batch):
    """Custom collate function to pad sequences in batch."""
    src_batch, tgt_batch = zip(*batch)

    # Pad sequences
    src_padded = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=0)
    tgt_padded = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=0)

    return src_padded, tgt_padded

In [20]:
# ============================================================================
# LSTM ENCODER-DECODER MODEL - With attention mechanism
# ============================================================================

class LSTMEncoder(nn.Module):
    """LSTM Encoder for sequence encoding."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_lengths=None):
        # src: [batch, seq_len]
        embedded = self.dropout(self.embedding(src))  # [batch, seq_len, emb_dim]

        if src_lengths is not None:
            # Pack padded sequences for efficiency
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            outputs, (hidden, cell) = self.lstm(packed)
            outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        else:
            outputs, (hidden, cell) = self.lstm(embedded)

        # outputs: [batch, seq_len, hidden_dim*2] (bidirectional)
        # hidden: [num_layers*2, batch, hidden_dim]
        # cell: [num_layers*2, batch, hidden_dim]
        return outputs, hidden, cell

In [21]:
class BahdanauAttention(nn.Module):
    """Bahdanau (additive) attention mechanism."""

    def __init__(self, hidden_dim, encoder_dim):
        super().__init__()
        self.attn = nn.Linear(hidden_dim + encoder_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        # hidden: [batch, hidden_dim]
        # encoder_outputs: [batch, src_len, encoder_dim]

        batch_size = encoder_outputs.size(0)
        src_len = encoder_outputs.size(1)

        # Repeat hidden state src_len times
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)  # [batch, src_len, hidden_dim]

        # Calculate attention scores
        energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2)))
        attention = self.v(energy).squeeze(2)  # [batch, src_len]

        # Apply mask if provided
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e4)

        # Apply softmax
        attention_weights = torch.softmax(attention, dim=1)  # [batch, src_len]

        # Apply attention to encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)  # [batch, encoder_dim]

        return context, attention_weights


class LSTMDecoder(nn.Module):
    """LSTM Decoder with attention mechanism."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, encoder_dim, num_layers=2, dropout=0.3):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.attention = BahdanauAttention(hidden_dim, encoder_dim)

        self.lstm = nn.LSTM(
            embedding_dim + encoder_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        self.fc_out = nn.Linear(hidden_dim + encoder_dim + embedding_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, hidden, cell, encoder_outputs, mask=None):
        # tgt: [batch, 1] (one token at a time for autoregressive generation)
        # hidden: [num_layers, batch, hidden_dim]
        # cell: [num_layers, batch, hidden_dim]
        # encoder_outputs: [batch, src_len, encoder_dim]

        embedded = self.dropout(self.embedding(tgt))  # [batch, 1, emb_dim]

        # Calculate attention using the top layer's hidden state
        context, attn_weights = self.attention(
            hidden[-1], encoder_outputs, mask
        )  # context: [batch, encoder_dim]

        # Combine embedded input and context
        context = context.unsqueeze(1)  # [batch, 1, encoder_dim]
        lstm_input = torch.cat([embedded, context], dim=2)  # [batch, 1, emb_dim + encoder_dim]

        # Pass through LSTM
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        # output: [batch, 1, hidden_dim]

        # Prepare for output projection
        output = output.squeeze(1)  # [batch, hidden_dim]
        context = context.squeeze(1)  # [batch, encoder_dim]
        embedded = embedded.squeeze(1)  # [batch, emb_dim]

        # Concatenate and project to vocabulary
        prediction = self.fc_out(
            torch.cat([output, context, embedded], dim=1)
        )  # [batch, vocab_size]

        return prediction, hidden, cell, attn_weights

In [22]:
class SpellingLSTM(nn.Module):
    """
    LSTM Encoder-Decoder with Attention for spelling correction.

    Architecture:
    - Encoder: 2-layer bidirectional LSTM (256 hidden units per direction)
    - Decoder: 2-layer unidirectional LSTM (512 hidden units)
    - Attention: Bahdanau (additive) attention mechanism
    - Embedding: 256 dimensions

    Why LSTM for this task:
    1. Sequential nature captures left-to-right character dependencies
    2. Attention helps align corrupted → correct characters
    3. Bidirectional encoder sees full context of misspelled word
    4. Proven effective for sequence-to-sequence tasks
    5. Less memory than Transformer for longer sequences
    """

    def __init__(self, vocab_size, embedding_dim=256, encoder_hidden_dim=256,
                 decoder_hidden_dim=512, num_layers=2, dropout=0.3):
        super().__init__()

        self.encoder = LSTMEncoder(
            vocab_size, embedding_dim, encoder_hidden_dim, num_layers, dropout
        )

        # Encoder is bidirectional, so encoder output dim is encoder_hidden_dim * 2
        encoder_output_dim = encoder_hidden_dim * 2

        self.decoder = LSTMDecoder(
            vocab_size, embedding_dim, decoder_hidden_dim,
            encoder_output_dim, num_layers, dropout
        )

        # Bridge from encoder hidden to decoder hidden (bidirectional to unidirectional)
        self.bridge_h = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)
        self.bridge_c = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)

        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Initialize weights."""
        for name, param in self.named_parameters():
            if 'weight' in name:
                if 'lstm' in name:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0.0)

    def forward(self, src, tgt, src_padding_mask=None, tgt_padding_mask=None, **kwargs):
        """
        Forward pass for training (teacher forcing).

        Args:
            src: Source sequence (corrupted) [batch, src_len]
            tgt: Target sequence (correct) [batch, tgt_len]
            src_padding_mask: Mask for source padding [batch, src_len]
            tgt_padding_mask: Mask for target padding (not used in LSTM)

        Returns:
            outputs: Predictions [batch, tgt_len, vocab_size]
        """
        batch_size = src.size(0)
        tgt_len = tgt.size(1)

        # Calculate source lengths for packing
        if src_padding_mask is not None:
            src_lengths = (~src_padding_mask).sum(dim=1)
        else:
            src_lengths = torch.full((batch_size,), src.size(1), dtype=torch.long, device=src.device)

        # Encode
        encoder_outputs, hidden, cell = self.encoder(src, src_lengths)
        # encoder_outputs: [batch, src_len, encoder_hidden*2]
        # hidden: [num_layers*2, batch, encoder_hidden]
        # cell: [num_layers*2, batch, encoder_hidden]

        # Bridge encoder hidden states to decoder (combine forward and backward)
        # Take pairs of layers (forward, backward) and combine them
        hidden_combined = []
        cell_combined = []
        for i in range(self.decoder.num_layers):
            # Combine forward and backward hidden states
            h_forward = hidden[i*2]      # Forward LSTM layer i
            h_backward = hidden[i*2 + 1]  # Backward LSTM layer i
            h_combined = torch.cat([h_forward, h_backward], dim=1)
            hidden_combined.append(self.bridge_h(h_combined))

            c_forward = cell[i*2]
            c_backward = cell[i*2 + 1]
            c_combined = torch.cat([c_forward, c_backward], dim=1)
            cell_combined.append(self.bridge_c(c_combined))

        hidden = torch.stack(hidden_combined)  # [num_layers, batch, decoder_hidden]
        cell = torch.stack(cell_combined)      # [num_layers, batch, decoder_hidden]

        # Create mask for attention (inverse of padding mask)
        attn_mask = None
        if src_padding_mask is not None:
            attn_mask = ~src_padding_mask  # [batch, src_len]

        # Decode with teacher forcing
        outputs = []
        for t in range(tgt_len):
            tgt_t = tgt[:, t].unsqueeze(1)  # [batch, 1]
            prediction, hidden, cell, _ = self.decoder(
                tgt_t, hidden, cell, encoder_outputs, attn_mask
            )
            outputs.append(prediction)

        outputs = torch.stack(outputs, dim=1)  # [batch, tgt_len, vocab_size]
        return outputs

    def generate_square_subsequent_mask(self, sz):
        """Dummy method for compatibility (not needed for LSTM)."""
        return None


In [None]:
# ============================================================================
# TRAINING LOOP - With mixed precision (FP16) and proper monitoring
# ============================================================================

def train_model(model, train_loader, val_loader, vocab,
                num_epochs=10, learning_rate=0.0001, device='cuda'):
    """
    Train the spelling correction model with FP16 mixed precision. otherwise training would take a lifetime, and its okay 16fp for this task.

    Uses:
    - Adam optimizer with learning rate 1e-4
    - CrossEntropyLoss (standard for classification)
    - GradScaler for FP16 training
    - Learning rate scheduling
    - Early stopping based on validation loss
    """
    model = model.to(device)

    # Optimizer and scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    # Loss function with label smoothing for better generalization
    criterion = nn.CrossEntropyLoss(
        ignore_index=vocab.char2idx['<PAD>'],
        label_smoothing=0.1
    )

    # Mixed precision training
    scaler = GradScaler()

    best_val_loss = float('inf')
    patience_counter = 0
    max_patience = 5

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for src, tgt in pbar:
            src, tgt = src.to(device), tgt.to(device)

            # Teacher forcing: use ground truth as decoder input
            tgt_input = tgt[:, :-1]  
            tgt_output = tgt[:, 1:]  

            # Create masks
            src_padding_mask = (src == vocab.char2idx['<PAD>'])
            tgt_padding_mask = (tgt_input == vocab.char2idx['<PAD>'])

            optimizer.zero_grad()

            # Mixed precision forward pass
            with autocast():
                output = model(src, tgt_input,
                             src_padding_mask=src_padding_mask,
                             tgt_padding_mask=tgt_padding_mask)

                # Reshape for loss calculation
                loss = criterion(
                    output.reshape(-1, len(vocab)),
                    tgt_output.reshape(-1)
                )

            # Backward pass with scaling
            scaler.scale(loss).backward()

            # Gradient clipping to prevent exploding gradients, capping the magnitude of gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

            # Metrics
            train_loss += loss.item()
            pred = output.argmax(dim=-1)
            mask = (tgt_output != vocab.char2idx['<PAD>'])
            train_correct += ((pred == tgt_output) & mask).sum().item()
            train_total += mask.sum().item()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100*train_correct/train_total:.2f}%'
            })

        avg_train_loss = train_loss / len(train_loader)
        train_accuracy = 100 * train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for src, tgt in val_loader:
                src, tgt = src.to(device), tgt.to(device)

                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:]

                src_padding_mask = (src == vocab.char2idx['<PAD>'])
                tgt_padding_mask = (tgt_input == vocab.char2idx['<PAD>'])

                with autocast():
                    output = model(src, tgt_input,
                                 src_padding_mask=src_padding_mask,
                                 tgt_padding_mask=tgt_padding_mask)

                    loss = criterion(
                        output.reshape(-1, len(vocab)),
                        tgt_output.reshape(-1)
                    )

                val_loss += loss.item()
                pred = output.argmax(dim=-1)
                mask = (tgt_output != vocab.char2idx['<PAD>'])
                val_correct += ((pred == tgt_output) & mask).sum().item()
                val_total += mask.sum().item()

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * val_correct / val_total

        # Learning rate scheduling
        scheduler.step(avg_val_loss)

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
        print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss,
            }, 'drive/MyDrive/best_model1.pt')
            print(f"  ✓ Saved best model (val_loss: {best_val_loss:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= max_patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

    return model

In [None]:
# ============================================================================
# INFERENCE - Generate corrected spellings
# ============================================================================

def correct_word(model, word, vocab, device='cuda', max_len=100):
    """
    Correct a single misspelled word using LSTM decoder.

    Uses greedy decoding with early stopping to prevent loops.
    """
    model.eval()

    # Encode input (no SOS for source)
    src = torch.LongTensor([vocab.encode(word, add_sos=False, add_eos=True)]).to(device)
    src_lengths = torch.LongTensor([src.size(1)])

    with torch.no_grad():
        # Encode source
        encoder_outputs, hidden, cell = model.encoder(src, src_lengths)

        # Bridge encoder hidden to decoder
        hidden_combined = []
        cell_combined = []
        for i in range(model.decoder.num_layers):
            h_forward = hidden[i*2]
            h_backward = hidden[i*2 + 1]
            h_combined = torch.cat([h_forward, h_backward], dim=1)
            hidden_combined.append(model.bridge_h(h_combined))

            c_forward = cell[i*2]
            c_backward = cell[i*2 + 1]
            c_combined = torch.cat([c_forward, c_backward], dim=1)
            cell_combined.append(model.bridge_c(c_combined))

        hidden = torch.stack(hidden_combined)
        cell = torch.stack(cell_combined)

        # Start with SOS token
        tgt_token = torch.LongTensor([[vocab.char2idx['<SOS>']]]).to(device)
        decoded_tokens = []

        # Greedy decoding with loop prevention
        for step in range(max_len):
            prediction, hidden, cell, _ = model.decoder(
                tgt_token, hidden, cell, encoder_outputs, mask=None
            )

            # Get most likely next token
            next_token_id = prediction.argmax(dim=-1).item()

            # Stop conditions
            if next_token_id == vocab.char2idx['<EOS>']:
                break
            if next_token_id == vocab.char2idx['<PAD>']:
                break
            if next_token_id == vocab.char2idx['<SOS>']:
                continue  

            # Prevent infinite loops: stop if output is too long relative to input
            if len(decoded_tokens) > len(word) * 3:
                break

            # Detect repetition: if last 3 characters are same, stop
            if len(decoded_tokens) >= 3:
                last_three = decoded_tokens[-3:]
                if len(set(last_three)) == 1:  # All same character
                    break

            decoded_tokens.append(next_token_id)
            tgt_token = torch.LongTensor([[next_token_id]]).to(device)

    # Decode tokens to string
    return vocab.decode(decoded_tokens)

In [None]:
CORRUPTION_RATE = 1.0  
MAX_WORDS = None  
BATCH_SIZE = 128
NUM_EPOCHS = 15  
LEARNING_RATE = 0.001  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Step 1: Load and corrupt data
print("\n" + "="*70)
print("STEP 1: Loading and corrupting data (100% corruption)")
print("="*70)
pairs = load_and_corrupt_data(
    data_dir=DATA_DIR,
    corruption_rate=CORRUPTION_RATE,
    max_words=MAX_WORDS
)

# Step 2: Build vocabulary
print("\n" + "="*70)
print("STEP 2: Building character vocabulary")
print("="*70)
all_words = [word for pair in pairs for word in pair]
vocab = CharVocab().build_vocab(all_words)

# Step 3: Split into train/val
print("\n" + "="*70)
print("STEP 3: Creating train/validation split")
print("="*70)
split_idx = int(0.9 * len(pairs))
train_pairs = pairs[:split_idx]
val_pairs = pairs[split_idx:]
print(f"Train: {len(train_pairs)} pairs")
print(f"Val: {len(val_pairs)} pairs")

# Step 4: Create datasets and loaders
train_dataset = SpellingDataset(train_pairs, vocab)
val_dataset = SpellingDataset(val_pairs, vocab)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE,
    shuffle=True, collate_fn=collate_fn, num_workers=0
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE,
    shuffle=False, collate_fn=collate_fn, num_workers=0
)

# Step 5: Initialize model
print("\n" + "="*70)
print("STEP 4: Initializing LSTM Encoder-Decoder model")
print("="*70)
model = SpellingLSTM(
    vocab_size=len(vocab),
    embedding_dim=256,
    encoder_hidden_dim=256,
    decoder_hidden_dim=512,
    num_layers=2,
    dropout=0.3
)
# great exploration, i didnt knew these functionionality existed
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

# Step 6: Train model
print("\n" + "="*70)
print("STEP 5: Training model with FP16 mixed precision")
print("="*70)
model = train_model(
    model, train_loader, val_loader, vocab,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    device=device
)

Using device: cuda

STEP 1: Loading and corrupting data (100% corruption)
Loading wordsChunk_0.json...
Loading wordsChunk_1.json...
Loading wordsChunk_2.json...
Loaded 271787 words total (after filtering short words)


Corrupting words: 100%|██████████| 271787/271787 [00:01<00:00, 179429.79it/s]


Successfully corrupted: 271787/271787 (100.0%)
Failed to corrupt: 0 words

STEP 2: Building character vocabulary
Vocabulary size: 39

STEP 3: Creating train/validation split
Train: 244608 pairs
Val: 27179 pairs

STEP 4: Initializing LSTM Encoder-Decoder model


  scaler = GradScaler()


Model parameters: 8,476,967

STEP 5: Training model with FP16 mixed precision


  with autocast():
Epoch 1/15: 100%|██████████| 1911/1911 [03:09<00:00, 10.07it/s, loss=1.1637, acc=80.47%]
  with autocast():



Epoch 1 Summary:
  Train Loss: 1.2986 | Train Acc: 80.47%
  Val Loss: 1.1313 | Val Acc: 85.56%
  ✓ Saved best model (val_loss: 1.1313)


Epoch 2/15: 100%|██████████| 1911/1911 [03:10<00:00, 10.04it/s, loss=1.0783, acc=86.81%]



Epoch 2 Summary:
  Train Loss: 1.0947 | Train Acc: 86.81%
  Val Loss: 1.0610 | Val Acc: 87.73%
  ✓ Saved best model (val_loss: 1.0610)


Epoch 3/15: 100%|██████████| 1911/1911 [03:12<00:00,  9.91it/s, loss=1.0769, acc=88.18%]



Epoch 3 Summary:
  Train Loss: 1.0486 | Train Acc: 88.18%
  Val Loss: 1.0364 | Val Acc: 88.45%
  ✓ Saved best model (val_loss: 1.0364)


Epoch 4/15: 100%|██████████| 1911/1911 [03:11<00:00,  9.95it/s, loss=1.0395, acc=89.16%]



Epoch 4 Summary:
  Train Loss: 1.0166 | Train Acc: 89.16%
  Val Loss: 1.0208 | Val Acc: 88.96%
  ✓ Saved best model (val_loss: 1.0208)


Epoch 5/15: 100%|██████████| 1911/1911 [03:12<00:00,  9.95it/s, loss=1.0309, acc=89.90%]



Epoch 5 Summary:
  Train Loss: 0.9928 | Train Acc: 89.90%
  Val Loss: 1.0103 | Val Acc: 89.29%
  ✓ Saved best model (val_loss: 1.0103)


Epoch 6/15: 100%|██████████| 1911/1911 [03:12<00:00,  9.91it/s, loss=0.9725, acc=90.56%]



Epoch 6 Summary:
  Train Loss: 0.9720 | Train Acc: 90.56%
  Val Loss: 1.0041 | Val Acc: 89.53%
  ✓ Saved best model (val_loss: 1.0041)


Epoch 7/15: 100%|██████████| 1911/1911 [03:12<00:00,  9.93it/s, loss=0.9629, acc=91.11%]



Epoch 7 Summary:
  Train Loss: 0.9559 | Train Acc: 91.11%
  Val Loss: 1.0027 | Val Acc: 89.59%
  ✓ Saved best model (val_loss: 1.0027)


Epoch 8/15: 100%|██████████| 1911/1911 [03:12<00:00,  9.93it/s, loss=0.9682, acc=91.63%]



Epoch 8 Summary:
  Train Loss: 0.9400 | Train Acc: 91.63%
  Val Loss: 1.0028 | Val Acc: 89.64%


Epoch 9/15: 100%|██████████| 1911/1911 [03:09<00:00, 10.10it/s, loss=0.9560, acc=92.03%]



Epoch 9 Summary:
  Train Loss: 0.9282 | Train Acc: 92.03%
  Val Loss: 1.0018 | Val Acc: 89.69%
  ✓ Saved best model (val_loss: 1.0018)


Epoch 10/15: 100%|██████████| 1911/1911 [03:10<00:00, 10.02it/s, loss=0.9335, acc=92.40%]



Epoch 10 Summary:
  Train Loss: 0.9168 | Train Acc: 92.40%
  Val Loss: 1.0028 | Val Acc: 89.77%


Epoch 11/15: 100%|██████████| 1911/1911 [03:09<00:00, 10.09it/s, loss=0.8944, acc=92.71%]



Epoch 11 Summary:
  Train Loss: 0.9081 | Train Acc: 92.71%
  Val Loss: 1.0034 | Val Acc: 89.73%


Epoch 12/15: 100%|██████████| 1911/1911 [03:09<00:00, 10.06it/s, loss=0.9121, acc=93.05%]



Epoch 12 Summary:
  Train Loss: 0.8984 | Train Acc: 93.05%
  Val Loss: 1.0060 | Val Acc: 89.65%


Epoch 13/15: 100%|██████████| 1911/1911 [03:09<00:00, 10.10it/s, loss=0.8857, acc=94.10%]



Epoch 13 Summary:
  Train Loss: 0.8690 | Train Acc: 94.10%
  Val Loss: 1.0038 | Val Acc: 89.87%


Epoch 14/15: 100%|██████████| 1911/1911 [03:08<00:00, 10.12it/s, loss=0.8399, acc=94.59%]



Epoch 14 Summary:
  Train Loss: 0.8558 | Train Acc: 94.59%
  Val Loss: 1.0077 | Val Acc: 89.81%

Early stopping triggered after 14 epochs


In [27]:
# Step 7: Test on examples
print("\n" + "="*70)
print("STEP 6: Testing on example corrections")
print("="*70)

# Load best model
checkpoint = torch.load('drive/MyDrive/best_model1.pt')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

# Test on validation pairs (these are all corrupted now!)
test_examples = val_pairs[:15]  # Take 15 examples

correct_count = 0
for corrupted_word, original_word in test_examples:
    corrected = correct_word(model, corrupted_word, vocab, device)
    is_correct = corrected == original_word
    correct_count += int(is_correct)
    status = "✓" if is_correct else "✗"
    print(f"{status} Corrupted: {corrupted_word:20} → Corrected: {corrected:20} (True: {original_word})")

print(f"\nAccuracy: {correct_count}/{len(test_examples)} ({100*correct_count/len(test_examples):.1f}%)")


STEP 6: Testing on example corrections
✓ Corrupted: ოქროოიპროიბთ         → Corrected: ოქროპირობით          (True: ოქროპირობით)
✓ Corrupted: პ-ეში                → Corrected: პეში                 (True: პეში)
✓ Corrupted: გაუჭირვწბლადდ        → Corrected: გაუჭირვებლად         (True: გაუჭირვებლად)
✗ Corrupted: ჰლია                 → Corrected: ჰოლია                (True: ჰულია)
✗ Corrupted: ხდების               → Corrected: ხედების              (True: ხეების)
✗ Corrupted: უკფრორე              → Corrected: უკროდე               (True: უფრორე)
✗ Corrupted: სასნელგ              → Corrected: სასნელით             (True: სასჯელთ)
✗ Corrupted: გს,სარკლი            → Corrected: გასარკლი             (True: გასარკული)
✗ Corrupted: კომოთი               → Corrected: კომოთი               (True: კომოდით)
✓ Corrupted: უყვარრადა            → Corrected: უყვარადა             (True: უყვარადა)
✗ Corrupted: ჯოლტ                 → Corrected: ჯოლტ                 (True: ნოლტი)
✗ Corrupted: კრიტჩტონის          

there could be some corrected words to be valid, despite the fact they are not for example,
✗ Corrupted: ხდების               → Corrected: ხედების              (True: ხეების)
also weird words showed up e.g ჯოლტი,აგუტები, top 5 betrayals moments 

weird words, top 5 betrayal moments in history

In [28]:
# Make sure the model and vocab are loaded from previous steps
# If you restarted the kernel, you might need to run the training and setup cells again.

# Put the model in evaluation mode
model.eval()

print("Enter a Georgian word to correct (type 'exit' to quit):")
while True:
    user_input = input("Your word: ").strip()
    if user_input.lower() == 'exit':
        break
    if not user_input:
        print("Please enter a word.")
        continue

    try:
        corrected_word = correct_word(model, user_input, vocab, device)
        print(f"  Original: {user_input}\n  Corrected: {corrected_word}")
    except Exception as e:
        print(f"An error occurred: {e}")
        print("Please ensure the model and vocabulary are properly loaded.")

print("Exiting interactive correction.")

Enter a Georgian word to correct (type 'exit' to quit):
Your word: გამარჯონა
  Original: გამარჯონა
  Corrected: გამარჯონა
Your word: არსი
  Original: არსი
  Corrected: არსის
Your word: არის
  Original: არის
  Corrected: არისა
Your word: კარგად
  Original: კარგად
  Corrected: კარგად
Your word: ოთახი
  Original: ოთახი
  Corrected: ოთახის
Your word: გამარჰობა
  Original: გამარჰობა
  Corrected: გამარჯობა
Your word: გაგიმარჰოს
  Original: გაგიმარჰოს
  Corrected: გაგიმართოს
Your word: ტერენტ
  Original: ტერენტ
  Corrected: ტერენტი
Your word: შოთ
  Original: შოთ
  Corrected: შოთ
Your word: შოთა
  Original: შოთა
  Corrected: შოთა
Your word: შოთი
  Original: შოთი
  Corrected: შოთი
Your word: კამპუსო
  Original: კამპუსო
  Corrected: კამპუსი
Your word: პროგამა
  Original: პროგამა
  Corrected: პროგამა
Your word: გამარჰონა
  Original: გამარჰონა
  Corrected: გამარგონა
Your word: პროგრმა
  Original: პროგრმა
  Corrected: პროგრამ
Your word: exit
Exiting interactive correction.


got my name correct, ტერენტ -> ტერენტი :D

there are weird words in dataset, that explain why it tries to match the uncorrect words to correct ones

In [31]:
print("\n" + "="*70)
print("STEP 2: Building character vocabulary")
print("="*70)
all_words = [word for pair in pairs for word in pair]
vocab = CharVocab().build_vocab(all_words)


vocab = CharVocab().build_vocab(all_words)

vocab_data = {
    "char2idx": vocab.char2idx,
    "idx2char": vocab.idx2char,
}

with open("drive/MyDrive/char_vocab.json", "w", encoding="utf-8") as f:
    json.dump(vocab_data, f, ensure_ascii=False, indent=2)



STEP 2: Building character vocabulary
Vocabulary size: 39
Vocabulary size: 39
