In [17]:
# ------------------------------
# Cell 1: Imports & Setup
# ------------------------------
import os
import random
import pickle
from collections import Counter
import time
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x202d2ecc910>

In [18]:
# ------------------------------
# Cell 2: Device detection
# ------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    print("CUDA device:", torch.cuda.get_device_name(0))
    torch.backends.cudnn.benchmark = True

Using device: cuda
CUDA device: NVIDIA GeForce RTX 3050 4GB Laptop GPU


In [19]:
# ------------------------------
# Cell 3: Paths & folders
# ------------------------------
os.makedirs("data", exist_ok=True)
os.makedirs("vocab", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

CORPUS_PATH = "data/all_hindi_clean.txt"
DATA_PAIRS_PATH = "data/data_pairs.pkl"
VOCAB_PATH = "vocab/hindi_vocab.tsv"

In [20]:
# ------------------------------
# Cell 4: Load corpus
# ------------------------------
MAX_LINES = 100000  # limit

if not os.path.exists(CORPUS_PATH):
    raise FileNotFoundError(f"Corpus not found at {CORPUS_PATH}")

sentences = []
with open(CORPUS_PATH, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i >= MAX_LINES:
            break
        line = line.strip()
        if line:
            sentences.append(line)

print(f"Loaded {len(sentences)} sentences (limited to {MAX_LINES})")

Loaded 100000 sentences (limited to 100000)


In [21]:
# ------------------------------
# Cell 5: Create typos function
# ------------------------------
def create_typos(sentence, typo_prob=0.2):
    words = sentence.split()
    new_words = []
    for w in words:
        if random.random() < typo_prob:
            typo_type = random.choice(["delete", "replace", "transpose"])
            if typo_type == "delete" and len(w) > 1:
                i = random.randint(0, len(w)-1)
                w = w[:i] + w[i+1:]
            elif typo_type == "replace" and len(w) > 0:
                i = random.randint(0, len(w)-1)
                w = w[:i] + random.choice(list(w)) + w[i+1:]
            elif typo_type == "transpose" and len(w) > 1:
                i = random.randint(0, len(w)-2)
                w = w[:i] + w[i+1] + w[i] + w[i+2:]
        new_words.append(w)
    return " ".join(new_words)

In [22]:
# ------------------------------
# Cell 6: Create dataset pairs
# ------------------------------
data_pairs = [(create_typos(s), s) for s in sentences]
print("Sample pair:", data_pairs[0])

Sample pair: ('के', 'के')


In [23]:
# ------------------------------
# Cell 7: Build char-level vocabulary
# ------------------------------

# Collect all characters in the dataset
all_text = " ".join(s for _, s in data_pairs)
chars = sorted(list(set(all_text)))

PAD, SOS, EOS, UNK = "<PAD>", "<SOS>", "<EOS>", "<UNK>"

vocab = {PAD: 0, SOS: 1, EOS: 2, UNK: 3}
for c in chars:
    if c not in vocab:
        vocab[c] = len(vocab)

rev_vocab = {idx: char for char, idx in vocab.items()}
vocab_size = len(vocab)
print(f"Char-level vocab size: {vocab_size}")

# Save char vocab
with open(VOCAB_PATH, "w", encoding="utf-8") as f:
    for char, idx in vocab.items():
        f.write(f"{char}\t{idx}\n")
print("Character-level vocabulary saved!")


Char-level vocab size: 91
Character-level vocabulary saved!


In [24]:
# ------------------------------
# Cell 8: Save data pairs
# ------------------------------
with open(DATA_PAIRS_PATH, "wb") as f:
    pickle.dump(data_pairs, f)
print("Data pairs saved!")

Data pairs saved!


In [25]:
# ------------------------------
# Cell 9: Dataset & DataLoader
# ------------------------------
class HindiSpellDataset(Dataset):
    def __init__(self, pairs, vocab):
        self.pairs = pairs
        self.vocab = vocab
        self.SOS = vocab[SOS]
        self.EOS = vocab[EOS]
        self.UNK = vocab[UNK]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        # char-level tokenization
        src_ids = [self.vocab.get(c, self.UNK) for c in src] + [self.EOS]
        tgt_ids = [self.SOS] + [self.vocab.get(c, self.UNK) for c in tgt] + [self.EOS]
        return torch.tensor(src_ids, dtype=torch.long), torch.tensor(tgt_ids, dtype=torch.long)

def collate_fn(batch):
    PAD_IDX = vocab[PAD]
    src_batch, tgt_batch = zip(*batch)
    src_max = max(len(s) for s in src_batch)
    tgt_max = max(len(t) for t in tgt_batch)
    src_padded = torch.full((len(batch), src_max), PAD_IDX, dtype=torch.long)
    tgt_padded = torch.full((len(batch), tgt_max), PAD_IDX, dtype=torch.long)
    src_lengths = []
    tgt_lengths = []
    for i, (s, t) in enumerate(zip(src_batch, tgt_batch)):
        src_padded[i, :len(s)] = s
        tgt_padded[i, :len(t)] = t
        src_lengths.append(len(s))
        tgt_lengths.append(len(t))
    return src_padded, tgt_padded, torch.tensor(src_lengths), torch.tensor(tgt_lengths)

def make_dataloader(pairs, batch_size=16, shuffle=True):
    dataset = HindiSpellDataset(pairs, vocab)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn,
                        num_workers=0, pin_memory=(device.type=="cuda"))
    return loader

In [26]:
# ------------------------------
# Cell 10: Subsample dataset for laptop training
# ------------------------------
TRAIN_SUBSET = 100_000
if len(data_pairs) > TRAIN_SUBSET:
    data_pairs_subset = random.sample(data_pairs, TRAIN_SUBSET)
else:
    data_pairs_subset = data_pairs

print(f"Training on {len(data_pairs_subset)} sentence pairs (subset)")
dataloader = make_dataloader(data_pairs_subset, batch_size=16)


Training on 100000 sentence pairs (subset)


In [27]:
# ------------------------------
# Cell 11: Model definition
# ------------------------------
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        embedded = self.dropout(self.embedding(x))
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, (h, c) = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return out, (h, c)

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden):
        x = x.unsqueeze(1)
        embedded = self.dropout(self.embedding(x))
        output, hidden = self.lstm(embedded, hidden)
        output = self.fc(output.squeeze(1))
        return output, hidden

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        batch_size = src.size(0)
        tgt_len = tgt.size(1)
        vocab_size = self.decoder.fc.out_features
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=self.device)

        encoder_out, hidden = self.encoder(src, src_lengths)
        input = tgt[:, 0]
        for t in range(1, tgt_len):
            output, hidden = self.decoder(input, hidden)
            outputs[:, t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = tgt[:, t] if teacher_force else top1
        return outputs


In [28]:
# ------------------------------
# Cell 12: Initialize model & optimizer
# ------------------------------
embed_size = 192
hidden_size = 256
num_layers = 2
dropout = 0.1

encoder = Encoder(vocab_size, embed_size, hidden_size, num_layers, dropout)
decoder = Decoder(vocab_size, embed_size, hidden_size, num_layers, dropout)
model = Seq2Seq(encoder, decoder, device).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD])
optimizer = optim.Adam(model.parameters(), lr=0.001)

use_amp = (device.type=="cuda")
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


In [29]:
# ------------------------------
# Cell 13: Improved Training Loop with Validation & Scheduled Teacher Forcing
# ------------------------------

# Hyperparameters
num_epochs = 10
max_grad_norm = 1.0
accum_steps = 2  # simulate larger batch size
best_val_loss = float('inf')

# Split train/validation
random.shuffle(data_pairs)
split_idx = int(0.9 * len(data_pairs))
train_pairs = data_pairs[:split_idx]
val_pairs = data_pairs[split_idx:]

train_loader = make_dataloader(train_pairs, batch_size=16)
val_loader = make_dataloader(val_pairs, batch_size=16, shuffle=False)

# Scheduled teacher forcing
def get_teacher_forcing_ratio(epoch, max_epochs=num_epochs):
    # start at 1.0, decay to 0.5
    return max(0.5, 1.0 - epoch/max_epochs)

# Training loop
for epoch in range(1, num_epochs+1):
    model.train()
    epoch_loss = 0
    start_time = time.time()
    teacher_forcing_ratio = get_teacher_forcing_ratio(epoch)
    
    batch_iterator = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}")
    
    for batch_idx, (src, tgt, src_lengths, tgt_lengths) in batch_iterator:
        batch_start = time.time()
        
        src, tgt = src.to(device), tgt.to(device)
        src_lengths, tgt_lengths = src_lengths.to(device), tgt_lengths.to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            output = model(src, src_lengths, tgt, teacher_forcing_ratio=teacher_forcing_ratio)
            output_dim = output.shape[-1]
            output = output[:,1:].reshape(-1, output_dim)
            tgt_target = tgt[:,1:].reshape(-1)
            loss = criterion(output, tgt_target) / accum_steps

        scaler.scale(loss).backward()

        if (batch_idx + 1) % accum_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        epoch_loss += loss.item() * accum_steps
        
        batch_time = time.time() - batch_start
        elapsed = time.time() - start_time
        avg_batch_time = elapsed / (batch_idx + 1)
        remaining_batches = len(train_loader) - batch_idx - 1
        eta = remaining_batches * avg_batch_time
        
        batch_iterator.set_postfix({
            "Loss": f"{loss.item()*accum_steps:.4f}",
            "TF_Ratio": f"{teacher_forcing_ratio:.2f}",
            "Batch Time": f"{batch_time:.2f}s",
            "ETA": f"{eta/60:.1f}m"
        })

    avg_train_loss = epoch_loss / len(train_loader)
    print(f"\nEpoch {epoch} Average Training Loss: {avg_train_loss:.4f}")

    # Validation
    val_loss, val_acc = evaluate(model, val_loader, criterion, vocab, device)
    print(f"Epoch {epoch} Validation Loss: {val_loss:.4f}, Token Accuracy: {val_acc*100:.2f}%")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        ckpt_path = f"checkpoints/seq2seq_best.pt"
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "vocab": vocab
        }, ckpt_path)
        print(f"Saved best model checkpoint: {ckpt_path}\n")


Epoch 1: 100%|██████████| 5625/5625 [04:09<00:00, 22.54it/s, Loss=0.2035, TF_Ratio=0.90, Batch Time=0.05s, ETA=0.0m]



Epoch 1 Average Training Loss: 0.7577


Evaluating: 100%|██████████| 625/625 [00:05<00:00, 112.87it/s]


Avg Loss: 0.7081 | Token Accuracy: 86.85%
Epoch 1 Validation Loss: 0.7081, Token Accuracy: 86.85%
Saved best model checkpoint: checkpoints/seq2seq_best.pt



Epoch 2: 100%|██████████| 5625/5625 [04:39<00:00, 20.10it/s, Loss=0.3037, TF_Ratio=0.80, Batch Time=0.04s, ETA=0.0m]



Epoch 2 Average Training Loss: 0.3022


Evaluating: 100%|██████████| 625/625 [00:06<00:00, 98.23it/s] 


Avg Loss: 0.5745 | Token Accuracy: 88.47%
Epoch 2 Validation Loss: 0.5745, Token Accuracy: 88.47%
Saved best model checkpoint: checkpoints/seq2seq_best.pt



Epoch 3: 100%|██████████| 5625/5625 [05:18<00:00, 17.66it/s, Loss=0.1857, TF_Ratio=0.70, Batch Time=0.07s, ETA=0.0m]



Epoch 3 Average Training Loss: 0.2615


Evaluating: 100%|██████████| 625/625 [00:08<00:00, 73.73it/s]


Avg Loss: 0.4648 | Token Accuracy: 90.42%
Epoch 3 Validation Loss: 0.4648, Token Accuracy: 90.42%
Saved best model checkpoint: checkpoints/seq2seq_best.pt



Epoch 4: 100%|██████████| 5625/5625 [05:55<00:00, 15.80it/s, Loss=0.1173, TF_Ratio=0.60, Batch Time=0.05s, ETA=0.0m]  



Epoch 4 Average Training Loss: 0.2436


Evaluating: 100%|██████████| 625/625 [00:06<00:00, 96.29it/s] 


Avg Loss: 0.4116 | Token Accuracy: 91.42%
Epoch 4 Validation Loss: 0.4116, Token Accuracy: 91.42%
Saved best model checkpoint: checkpoints/seq2seq_best.pt



Epoch 5: 100%|██████████| 5625/5625 [05:12<00:00, 18.02it/s, Loss=0.2409, TF_Ratio=0.50, Batch Time=0.05s, ETA=0.0m]



Epoch 5 Average Training Loss: 0.2377


Evaluating: 100%|██████████| 625/625 [00:05<00:00, 109.16it/s]


Avg Loss: 0.3889 | Token Accuracy: 91.62%
Epoch 5 Validation Loss: 0.3889, Token Accuracy: 91.62%
Saved best model checkpoint: checkpoints/seq2seq_best.pt



Epoch 6: 100%|██████████| 5625/5625 [05:10<00:00, 18.12it/s, Loss=0.2860, TF_Ratio=0.50, Batch Time=0.05s, ETA=0.0m]



Epoch 6 Average Training Loss: 0.2187


Evaluating: 100%|██████████| 625/625 [00:05<00:00, 116.27it/s]


Avg Loss: 0.3828 | Token Accuracy: 91.79%
Epoch 6 Validation Loss: 0.3828, Token Accuracy: 91.79%
Saved best model checkpoint: checkpoints/seq2seq_best.pt



Epoch 7: 100%|██████████| 5625/5625 [05:08<00:00, 18.26it/s, Loss=0.1998, TF_Ratio=0.50, Batch Time=0.06s, ETA=0.0m]  



Epoch 7 Average Training Loss: 0.2060


Evaluating: 100%|██████████| 625/625 [00:07<00:00, 85.07it/s]


Avg Loss: 0.3708 | Token Accuracy: 92.16%
Epoch 7 Validation Loss: 0.3708, Token Accuracy: 92.16%
Saved best model checkpoint: checkpoints/seq2seq_best.pt



Epoch 8: 100%|██████████| 5625/5625 [04:33<00:00, 20.58it/s, Loss=0.1393, TF_Ratio=0.50, Batch Time=0.02s, ETA=0.0m]



Epoch 8 Average Training Loss: 0.1941


Evaluating: 100%|██████████| 625/625 [00:04<00:00, 135.83it/s]


Avg Loss: 0.3668 | Token Accuracy: 92.36%
Epoch 8 Validation Loss: 0.3668, Token Accuracy: 92.36%
Saved best model checkpoint: checkpoints/seq2seq_best.pt



Epoch 9: 100%|██████████| 5625/5625 [04:53<00:00, 19.18it/s, Loss=0.2447, TF_Ratio=0.50, Batch Time=0.05s, ETA=0.0m]



Epoch 9 Average Training Loss: 0.1832


Evaluating: 100%|██████████| 625/625 [00:05<00:00, 118.10it/s]


Avg Loss: 0.3672 | Token Accuracy: 92.36%
Epoch 9 Validation Loss: 0.3672, Token Accuracy: 92.36%


Epoch 10: 100%|██████████| 5625/5625 [04:53<00:00, 19.19it/s, Loss=0.1288, TF_Ratio=0.50, Batch Time=0.03s, ETA=0.0m]



Epoch 10 Average Training Loss: 0.1756


Evaluating: 100%|██████████| 625/625 [00:04<00:00, 128.42it/s]

Avg Loss: 0.3659 | Token Accuracy: 92.53%
Epoch 10 Validation Loss: 0.3659, Token Accuracy: 92.53%
Saved best model checkpoint: checkpoints/seq2seq_best.pt






In [30]:
torch.save(encoder.state_dict(), 'encoder_state_dict.h5')
torch.save(decoder.state_dict(), 'decoder_state_dict.h5')

In [34]:
# ------------------------------
# Cell 14: Evaluation (Improved)
# ------------------------------
def evaluate(model, dataloader, criterion, vocab, device):
    model.eval()
    total_loss = 0
    total_items = 0      # number of sequences
    total_tokens = 0
    correct_tokens = 0

    PAD_IDX = vocab["<PAD>"]

    with torch.no_grad():
        for src, tgt, src_lengths, tgt_lengths in tqdm(dataloader, desc="Evaluating"):
            src, tgt = src.to(device), tgt.to(device)
            src_lengths, tgt_lengths = src_lengths.to(device), tgt_lengths.to(device)

            # forward pass (no teacher forcing)
            output = model(src, src_lengths, tgt, teacher_forcing_ratio=0.0)

            # flatten for CE loss
            output_dim = output.shape[-1]
            output_flat = output[:, 1:].reshape(-1, output_dim)
            tgt_flat = tgt[:, 1:].reshape(-1)

            loss = criterion(output_flat, tgt_flat)

            batch_size = src.size(0)
            total_loss += loss.item() * batch_size
            total_items += batch_size

            # accuracy (char-level)
            pred_tokens = output_flat.argmax(dim=1)
            mask = tgt_flat != PAD_IDX

            correct_tokens += (pred_tokens[mask] == tgt_flat[mask]).sum().item()
            total_tokens += mask.sum().item()

    # final metrics
    avg_loss = total_loss / max(1, total_items)
    accuracy = correct_tokens / max(1, total_tokens)

    print(f"\n=== Evaluation Result ===")
    print(f"Avg Loss       : {avg_loss:.4f}")
    print(f"Char Accuracy  : {accuracy * 100:.2f}%")
    print("==========================\n")

    return avg_loss, accuracy



def decode_sequence(ids, rev_vocab):
    """
    Convert ID list → Hindi string (character-level)
    Ignore special tokens.
    """
    ignore = {vocab["<PAD>"], vocab["<SOS>"], vocab["<EOS>"]}
    chars = [rev_vocab.get(i, "") for i in ids if i not in ignore]
    return "".join(chars)




def show_sample_predictions(model, dataloader, rev_vocab, num_samples=5):
    model.eval()
    shown = 0
    
    with torch.no_grad():
        for src, tgt, src_lengths, tgt_lengths in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            src_lengths = src_lengths.to(device)

            output = model(src, src_lengths, tgt, teacher_forcing_ratio=0.0)
            pred_ids = output.argmax(dim=-1)

            for i in range(src.size(0)):
                src_txt = decode_sequence(src[i].cpu().tolist(), rev_vocab)
                tgt_txt = decode_sequence(tgt[i].cpu().tolist(), rev_vocab)
                pred_txt = decode_sequence(pred_ids[i].cpu().tolist(), rev_vocab)

                print("Input     :", src_txt)
                print("Target    :", tgt_txt)
                print("Predicted :", pred_txt)
                print("-" * 60)

                shown += 1
                if shown >= num_samples:
                    return



# ------------------------------
# Load checkpoint (optional)
# ------------------------------
ckpt_path = "checkpoints/seq2seq_best.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state"])
print(f"Loaded BEST checkpoint from {ckpt_path}")



# ------------------------------
# Evaluate on the same subset
# ------------------------------
evaluate(model, dataloader, criterion, vocab, device)

# Show sample predictions
show_sample_predictions(model, dataloader, rev_vocab, num_samples=5)

  checkpoint = torch.load(ckpt_path, map_location=device)


Loaded BEST checkpoint from checkpoints/seq2seq_best.pt


Evaluating: 100%|██████████| 6250/6250 [02:45<00:00, 37.76it/s] 



=== Evaluation Result ===
Avg Loss       : 0.2541
Char Accuracy  : 94.38%

Input     : रॉलिंग
Target    : रॉलिंग
Predicted : रॉलिंग
------------------------------------------------------------
Input     : कनका
Target    : कनका
Predicted : कनका
------------------------------------------------------------
Input     : तम्बोली
Target    : तम्बोली
Predicted : तम्बोली
------------------------------------------------------------
Input     : ओल्मपियन
Target    : ओलम्पियन
Predicted : ओल्मपियन
------------------------------------------------------------
Input     : पुश्तैनी
Target    : पुश्तैनी
Predicted : पुश्तैनी
------------------------------------------------------------
