# Install dependencies

In [51]:
!pip -q install torch torchvision torchaudio
!pip -q install datasets onnx onnxruntime opencv-python pillow tqdm
!pip install huggingface_hub[hf_xet]
!pip install ipywidgets



# Imports & Config

In [52]:
import os, json, random
from pathlib import Path
import numpy as np
import cv2
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset

DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"
SEED          = 2025
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# How much data to load (Hugging Face slice syntax)
TRAIN_SLICE = "train[:10000]"   # change to e.g. "train[:50000]" when you scale up
VAL_SLICE   = "val[:1000]"

# Image size (match inference)
IMG_H, IMG_W = 32, 128

# Training knobs
BATCH_SIZE  = 128
EPOCHS      = 15          # increase to 15–20 for better accuracy
LR          = 1e-3
NUM_WORKERS = 0          # keep 0 in Colab for stability
PIN_MEMORY  = False

# Charset: digits + upper + lower + common symbols + SPACE
DIGITS = "0123456789"
UPPER  = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
LOWER  = "abcdefghijklmnopqrstuvwxyz"
SYMS   = "-_:.,!?@#&()+*/=%$'\"[]{}<>\\|"
CHARS  = DIGITS + UPPER + LOWER + SYMS + " "   # include space
BLANK_IDX = 0
char2idx = {c:i+1 for i,c in enumerate(CHARS)}
idx2char = {i+1:c for i,c in enumerate(CHARS)}
NCLASS = len(CHARS) + 1

ARTIFACTS = Path(".\\artifacts")
ARTIFACTS.mkdir(parents=True, exist_ok=True)
CKPT_PATH = ARTIFACTS/f"crnn_nc{NCLASS}_k2x1_best.pt"
LAST_PATH = ARTIFACTS/f"crnn_nc{NCLASS}_k2x1_last.pt"
ONNX_PATH = ARTIFACTS/f"crnn_nc{NCLASS}_k2x1.onnx"

with open(ARTIFACTS/"charset.json","w") as f:
    json.dump({"chars": CHARS}, f, ensure_ascii=False)

# Load HuggingFace data

In [53]:
from datasets import Dataset
import pandas as pd

# Read manifest into DataFrame
df = pd.read_csv("C:\\Arvin\\icbt\\assignments\\Final Project\\mjsynth_500k\\manifest.tsv", sep="\t", names=["image", "label"])

# Prepend full path to images
df["image"] = df["image"].apply(lambda x: f"C:\\Arvin\\icbt\\assignments\\Final Project\\mjsynth_500k\\{x}")

# Build Hugging Face dataset
hf_dataset = Dataset.from_pandas(df)

# Train/val split
split = hf_dataset.train_test_split(test_size=0.1, seed=42)
train_ds_hf = split["train"]
val_ds_hf   = split["test"]

print(train_ds_hf, val_ds_hf)

Dataset({
    features: ['image', 'label'],
    num_rows: 10000
}) Dataset({
    features: ['image', 'label'],
    num_rows: 1000
})


# Preprocess: grayscale + keep-aspect + left-pad

In [54]:
def preprocess_gray_keep_ratio(pil_img, target_h=IMG_H, target_w=IMG_W):
    # ensure grayscale np.uint8
    if pil_img.mode != "L":
        pil_img = pil_img.convert("L")
    g = np.array(pil_img, dtype=np.uint8)
    h, w = g.shape[:2]
    # fit within target while preserving aspect
    scale = min(target_w / w, target_h / h)
    nw, nh = max(1, int(w*scale)), max(1, int(h*scale))
    r = cv2.resize(g, (nw, nh), interpolation=cv2.INTER_LINEAR)
    canvas = np.full((target_h, target_w), 255, np.uint8)
    y0 = (target_h - nh)//2
    x0 = 0   # left align horizontally
    canvas[y0:y0+nh, x0:x0+nw] = r
    ten = torch.from_numpy(canvas).unsqueeze(0).float()/255.0  # [1,H,W]
    return ten

def encode_label(text):
    # filter to known charset; ensure non-empty for CTC
    arr = [char2idx[c] for c in text if c in char2idx]
    if len(arr) == 0:
        arr = [char2idx['A']]
    return torch.tensor(arr, dtype=torch.long)

class HFWrapper(Dataset):
    def __init__(self, hf_ds):
        self.hf = hf_ds
    def __len__(self): return len(self.hf)
    def __getitem__(self, idx):
        ex = self.hf[idx]
        img: Image.Image = ex["image"]
        lbl: str = ex["label"]
        img_t = preprocess_gray_keep_ratio(img)
        lab_t = encode_label(lbl)
        return img_t, lab_t, lbl

def collate_fn(batch):
    imgs, labs, raw = zip(*batch)
    # per-sample empty guard (extra safety)
    fixed, fixed_raw = [], list(raw)
    for i, l in enumerate(labs):
        if len(l) == 0:
            fixed.append(torch.tensor([char2idx["A"]], dtype=torch.long))
            if fixed_raw[i].strip() == "":
                fixed_raw[i] = "A"
        else:
            fixed.append(l)
    imgs = torch.stack(imgs)  # [B,1,H,W]
    label_lengths = torch.tensor([len(x) for x in fixed], dtype=torch.long)
    labels_concat = torch.cat(fixed)
    return imgs, labels_concat, label_lengths, fixed_raw

train_dl = DataLoader(HFWrapper(train_ds_hf), batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)
val_dl   = DataLoader(HFWrapper(val_ds_hf), batch_size=BATCH_SIZE, shuffle=False,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)

# CRNN model

In [55]:
class CRNN(nn.Module):
    def __init__(self, nclass):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2,2),   # 32x128 -> 16x64

            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2,2),   # 16x64 -> 8x32

            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2,1), (2,1)),   # 8x32 -> 4x32

            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2,1), (2,1)),   # 4x32 -> 2x32

            nn.Conv2d(512, 512, kernel_size=(2,1), stride=1, padding=0),
            nn.ReLU(True),
        )
        self.rnn = nn.LSTM(512, 256, num_layers=2, bidirectional=True, batch_first=False)
        self.fc  = nn.Linear(512, nclass)

    def forward(self, x):
        f = self.cnn(x)            # [B,512,1,32]
        f = f.squeeze(2)           # [B,512,32]
        f = f.permute(2,0,1)       # [32,B,512] (T,B,C)
        r, _ = self.rnn(f)         # [32,B,512]
        y = self.fc(r)             # [32,B,nclass]
        return y

model = CRNN(NCLASS).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
ctc_loss  = nn.CTCLoss(blank=BLANK_IDX, zero_infinity=True)

def ctc_greedy_decode(logits):   # logits: [T,B,C]
    pred = logits.argmax(2).detach().cpu().numpy()
    T, B = pred.shape
    outs = []
    for b in range(B):
        seq, prev, chars = pred[:, b], BLANK_IDX, []
        for s in seq:
            if s != prev and s != BLANK_IDX:
                chars.append(idx2char.get(int(s), ""))
            prev = s
        outs.append("".join(chars))
    return outs

# Train (safe checkpointing)

In [56]:
def run_epoch(dl, train=True):
    model.train(train)
    total_loss, total, exact = 0.0, 0, 0
    for imgs, labels_concat, label_lengths, raw in tqdm(dl, leave=False):
        imgs = imgs.to(DEVICE)
        labels_concat = labels_concat.to(DEVICE)
        logits = model(imgs)                 # [T,B,C]
        log_probs = F.log_softmax(logits, dim=2)
        T, B, C = log_probs.shape
        input_lengths = torch.full((B,), T, dtype=torch.long, device=DEVICE)
        loss = ctc_loss(log_probs, labels_concat, input_lengths, label_lengths.to(DEVICE))
        if train:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
        total_loss += loss.item()
        total += B
        # strict exact-match (will be low early)
        preds = ctc_greedy_decode(logits.detach())
        for p, gt in zip(preds, raw):
            exact += (p == gt)
    return total_loss / max(1, len(dl)), exact / max(1, total)

best_val = -1.0
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = run_epoch(train_dl, True)
    vl_loss, vl_acc = run_epoch(val_dl, False)
    print(f"[Epoch {ep}/{EPOCHS}] train_loss={tr_loss:.4f} acc={tr_acc*100:.2f}% | val_loss={vl_loss:.4f} acc={vl_acc*100:.2f}%")
    # always save "last"
    torch.save(model.state_dict(), LAST_PATH)
    # save "best" when improved (>= ensures epoch1 saves)
    if vl_acc >= best_val:
        best_val = vl_acc
        torch.save(model.state_dict(), CKPT_PATH)

  0%|          | 0/782 [00:00<?, ?it/s]

FileNotFoundError: [Errno 2] No such file or directory: 'image'

# Export ONNX + quick ORT check

In [None]:
model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
model.eval()

dummy = torch.randn(1,1,IMG_H,IMG_W, device=DEVICE)
torch.onnx.export(
    model, dummy, str(ONNX_PATH),
    input_names=["input"], output_names=["logits"],
    opset_version=14
)
print("Saved ONNX:", ONNX_PATH, "| Charset:", ARTIFACTS/"charset.json")

import onnxruntime as ort
sess = ort.InferenceSession(str(ONNX_PATH))
out = sess.run(None, {"input": dummy.cpu().numpy()})
print("ONNXRuntime forward OK, output shape:", np.array(out[0]).shape)

# OpenCV DNN inference helper (word/line crops)

In [None]:
def ctc_decode_from_logits_np(logits_np):
    if logits_np.ndim == 3:      # [T,B,C]
        arg = logits_np.argmax(2)[:,0]
    elif logits_np.ndim == 2:    # [T,C]
        arg = logits_np.argmax(1)
    else:                        # [1,T,C]
        arg = logits_np[0].argmax(1)
    out, prev = [], BLANK_IDX
    for s in arg.tolist():
        if s != prev and s != BLANK_IDX:
            out.append(idx2char.get(int(s), ""))
        prev = s
    return "".join(out)

net = cv2.dnn.readNet(str(ONNX_PATH))

def preprocess_for_infer_bgr(bgr, target_h=IMG_H, target_w=IMG_W):
    g = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    h, w = g.shape[:2]
    scale = min(target_w / w, target_h / h)
    nw, nh = max(1, int(w*scale)), max(1, int(h*scale))
    r = cv2.resize(g, (nw, nh), interpolation=cv2.INTER_LINEAR)
    canvas = np.full((target_h, target_w), 255, np.uint8)
    y0 = (target_h - nh)//2
    x0 = 0
    canvas[y0:y0+nh, x0:x0+nw] = r
    blob = canvas.astype(np.float32)/255.0
    return blob[np.newaxis, np.newaxis, :, :]

def predict_image(path):
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError(path)
    inp = preprocess_for_infer_bgr(bgr)
    net.setInput(inp)
    y = net.forward()
    return ctc_decode_from_logits_np(y)

print("\nReady. Use: predict_image('/path/to/your_word_crop.png')  # returns text")