# Imports

In [1]:
import os, json, math, random
from pathlib import Path
import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

from datasets import Dataset as HFDataset, Features, Value
from datasets import Image as HFImage
from tqdm.auto import tqdm

# Config

In [3]:
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
SEED         = 2025
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# (A) paths (EDIT BASE_DIR, or set the MJSYNTH_DIR environment variable to override it)
BASE_DIR     = os.environ.get("MJSYNTH_DIR", r"C:\Arvin\icbt\assignments\Final Project\mjsynth_500k")
# os.path.join picks the platform's separator; a hard-coded "\" is only a separator on Windows.
MANIFEST     = os.path.join(BASE_DIR, "manifest.tsv")

# (B) subset sizes (edit any time)
N_TRAIN      = 10_000
N_TEST       = 1_000

# (C) training knobs
IMG_H, IMG_W = 32, 128           # widen to 256 if your words are long (then retrain & re-export)
BATCH_SIZE   = 128
EPOCHS       = 5                  # raise to 10-20 for better accuracy
LR           = 1e-3
NUM_WORKERS  = 0                  # Windows-safe
PIN_MEMORY   = torch.cuda.is_available()

# (D) charset (must include every char in your labels; SPACE included)
DIGITS = "0123456789"
UPPER  = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
LOWER  = "abcdefghijklmnopqrstuvwxyz"
SYMS   = "-_:.,!?@#&()+*/=%$'\"[]{}<>\\|"
CHARS  = DIGITS + UPPER + LOWER + SYMS + " "
# A duplicated character would get two indices: char2idx keeps the later one while idx2char
# keeps both, so encoding and decoding would silently disagree.
assert len(set(CHARS)) == len(CHARS), "CHARS contains duplicate characters"
BLANK_IDX = 0
char2idx = {c:i+1 for i,c in enumerate(CHARS)}  # 1..N for chars; 0 is CTC blank
idx2char = {i+1:c for i,c in enumerate(CHARS)}
NCLASS   = len(CHARS) + 1

# (E) artifacts
ARTIFACTS = Path(BASE_DIR) / "artifacts"
ARTIFACTS.mkdir(parents=True, exist_ok=True)
CKPT_BEST = ARTIFACTS / f"crnn_nc{NCLASS}_best.pt"
CKPT_LAST = ARTIFACTS / f"crnn_nc{NCLASS}_last.pt"
ONNX_PATH = ARTIFACTS / f"crnn_nc{NCLASS}.onnx"
# Persist the charset so inference code can rebuild the exact same index mapping.
with open(ARTIFACTS/"charset.json","w",encoding="utf-8") as f:
    json.dump({"chars": CHARS}, f, ensure_ascii=False)

# Build a HF dataset from ONE manifest (skip header if present)

In [4]:
import csv
import pandas as pd

# Read every field verbatim:
#  * na_filter=False - otherwise pandas turns genuine MJSynth words such as "NULL", "null",
#    "nan", "NA" or "None" into NaN; these would become missing labels and crash encode_label.
#  * quoting=QUOTE_NONE - a label containing '"' is text, not the start of a quoted field.
df = pd.read_csv(
    MANIFEST, sep="\t", header=None, names=["image","label"],
    dtype={"image": str, "label": str}, engine="python",
    na_filter=False, quoting=csv.QUOTE_NONE,
)
# short rows (missing label column) are padded with NaN; treat them as empty labels
df = df.fillna("")
# drop a header row if someone wrote "image\tlabel"
df = df[df["image"].str.lower() != "image"].reset_index(drop=True)

# drop rows whose label has no character from the charset: encode_label would otherwise
# silently train them against the placeholder target "A"
df = df[df["label"].apply(lambda s: any(c in char2idx for c in s))].reset_index(drop=True)

# prepend absolute path and filter missing files
df["image"] = df["image"].apply(lambda rel: os.path.join(BASE_DIR, rel))
df = df[df["image"].apply(os.path.exists)].reset_index(drop=True)
assert len(df) > 0, "No valid rows after filtering missing files."

# make HF dataset that yields PIL.Image automatically
features = Features({"image": HFImage(), "label": Value("string")})
hf_full  = HFDataset.from_pandas(df, features=features)

# shuffle once, then select fixed, non-overlapping train/val slices
hf_full  = hf_full.shuffle(seed=SEED)
take_tr  = min(N_TRAIN, len(hf_full))
take_te  = min(N_TEST, max(0, len(hf_full)-take_tr))
train_ds_hf = hf_full.select(range(take_tr))
val_ds_hf   = hf_full.select(range(take_tr, take_tr+take_te))

print(f"Loaded: train={len(train_ds_hf)} | val={len(val_ds_hf)}")


Loaded: train=10000 | val=1000


# Preprocessing, label encoding, dataset wrapper and collate function

In [5]:
def preprocess_gray_keep_ratio(pil_img, target_h=IMG_H, target_w=IMG_W):
    """Turn a PIL image into a normalised grayscale tensor of fixed size.

    The image is converted to 8-bit grayscale and scaled by the largest factor that fits
    both ``target_h`` and ``target_w`` (aspect ratio preserved). The scaled image is pasted
    left-aligned and vertically centred onto a white (255) canvas.

    Returns:
        float tensor of shape [1, target_h, target_w] with values in [0, 1].
    """
    gray = pil_img if pil_img.mode == "L" else pil_img.convert("L")
    src = np.array(gray, dtype=np.uint8)
    src_h, src_w = src.shape[:2]

    ratio = min(target_w / src_w, target_h / src_h)
    new_w = max(1, int(src_w * ratio))
    new_h = max(1, int(src_h * ratio))
    resized = cv2.resize(src, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    out = np.full((target_h, target_w), 255, np.uint8)
    top = (target_h - new_h) // 2
    out[top:top + new_h, :new_w] = resized   # left-aligned, vertically centred
    return torch.from_numpy(out).unsqueeze(0).float() / 255.0

def encode_label(text):
    """Encode ``text`` as a LongTensor of charset indices (1..N; 0 is reserved for the CTC blank).

    Characters not present in ``char2idx`` are skipped. If nothing is left, the single
    index of 'A' is used so that CTC always receives a non-empty target.
    """
    indices = []
    for ch in text:
        if ch in char2idx:
            indices.append(char2idx[ch])
    if not indices:
        indices = [char2idx['A']]
    return torch.tensor(indices, dtype=torch.long)

class HFWrapper(Dataset):
    """torch ``Dataset`` view over an indexable Hugging Face dataset.

    Each row is expected to look like {"image": PIL.Image, "label": str}. An item is
    returned as ``(image_tensor, label_tensor, raw_label)``.
    """

    def __init__(self, hf_ds):
        self.hf = hf_ds

    def __len__(self):
        return len(self.hf)

    def __getitem__(self, idx):
        record = self.hf[idx]
        text = record["label"]
        image_tensor = preprocess_gray_keep_ratio(record["image"])
        label_tensor = encode_label(text)
        return image_tensor, label_tensor, text

def collate_fn(batch):
    """Merge ``(image, label_tensor, raw_label)`` samples into one CTC-ready batch.

    Returns:
        imgs:           [B, 1, H, W] stacked images
        labels_concat:  1-D tensor with all label indices concatenated
        label_lengths:  [B] long tensor with each sample's label length
        raws:           list of B raw label strings
    Any empty label is replaced by the single character 'A' (tensor and raw string alike).
    """
    images, targets, texts = map(list, zip(*batch))
    for pos, tgt in enumerate(targets):
        if len(tgt) == 0:
            targets[pos] = torch.tensor([char2idx['A']], dtype=torch.long)
            texts[pos] = "A"
    lengths = torch.tensor([len(t) for t in targets], dtype=torch.long)
    return torch.stack(images), torch.cat(targets), lengths, texts

# Shared loader settings; only shuffling differs between the train and validation loaders.
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)
train_dl = DataLoader(HFWrapper(train_ds_hf), shuffle=True, **_loader_opts)
val_dl = DataLoader(HFWrapper(val_ds_hf), shuffle=False, **_loader_opts)

# CRNN model (CNN -> BiLSTM -> Linear)

In [6]:
class CRNN(nn.Module):
    """Convolutional-recurrent text recogniser that emits per-column CTC logits.

    A VGG-style CNN reduces a [B, 1, 32, W] grayscale crop to a one-row feature map with
    512 channels and W//4 columns. A 2-layer bidirectional LSTM (256 units per direction)
    runs over those columns, and a linear layer maps every time step to ``nclass`` scores.
    The output is time-major: [T, B, nclass].
    """

    def __init__(self, nclass):
        super().__init__()
        # The layer order fixes the state_dict keys (cnn.0 ... cnn.18) and the order in which
        # parameters are randomly initialised; keep it stable so saved checkpoints still load.
        self.cnn = nn.Sequential(
            # stage 1: 32 x W -> 16 x W/2
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # stage 2: 16 x W/2 -> 8 x W/4
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # stage 3: 8 x W/4 -> 4 x W/4 (pool height only)
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
            # stage 4: 4 x W/4 -> 2 x W/4 (pool height only)
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
            # collapse the remaining 2 rows: 2 x W/4 -> 1 x W/4
            nn.Conv2d(512, 512, kernel_size=(2, 1), stride=1),
            nn.ReLU(inplace=True),
        )
        self.rnn = nn.LSTM(input_size=512, hidden_size=256, num_layers=2, bidirectional=True)
        # bidirectional LSTM concatenates both directions: 2 * 256 features per step
        self.fc = nn.Linear(in_features=2 * 256, out_features=nclass)

    def forward(self, x):
        # [B, 512, 1, T] -> [B, 512, T] -> [T, B, 512]  (time-major input for the LSTM)
        seq = self.cnn(x).squeeze(2).permute(2, 0, 1)
        hidden, _ = self.rnn(seq)
        return self.fc(hidden)

# Network on the selected device, its AdamW optimiser, and the CTC criterion
# (index BLANK_IDX is the blank; infinite losses from impossible alignments are zeroed).
model = CRNN(NCLASS).to(DEVICE)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)
ctc_loss = nn.CTCLoss(blank=BLANK_IDX, zero_infinity=True)

def ctc_greedy_decode(logits):          # logits: [T,B,C]
    """Best-path CTC decoding: take the argmax class per time step, merge consecutive
    repeats, then drop blanks. Indices missing from ``idx2char`` decode to "".

    Returns:
        list of B decoded strings.
    """
    best_path = logits.argmax(2).detach().cpu().numpy()   # [T, B]
    steps, batch_size = best_path.shape
    decoded = []
    for b in range(batch_size):
        last = BLANK_IDX
        pieces = []
        for t in range(steps):
            token = best_path[t, b]
            if token != BLANK_IDX and token != last:
                pieces.append(idx2char.get(int(token), ""))
            last = token
        decoded.append("".join(pieces))
    return decoded

# Train loop (safe checkpointing)

In [None]:
def run_epoch(dl, train=True):
    """Run one full pass of the global ``model`` over ``dl``.

    Args:
        dl: DataLoader yielding ``(imgs, labels_concat, label_lengths, raw_labels)`` batches,
            as produced by ``collate_fn``.
        train: if True, use train mode and take an optimizer step per batch; if False,
            evaluate only (eval mode, no autograd graph is built).

    Returns:
        ``(mean_loss, exact_match_accuracy)``. ``mean_loss`` is the CTC loss averaged over
        samples. Accuracy is the fraction of greedy decodes that equal the raw label string.
    """
    model.train(train)
    total_loss, total, exact = 0.0, 0, 0
    # Only build the autograd graph when we back-propagate; during validation this avoids
    # keeping every activation alive for nothing.
    with torch.set_grad_enabled(train):
        for imgs, labels_concat, label_lengths, raw in tqdm(dl, leave=False):
            imgs = imgs.to(DEVICE)
            labels_concat = labels_concat.to(DEVICE)

            logits = model(imgs)                        # [T,B,C]
            log_probs = F.log_softmax(logits, dim=2)
            T, B, C = log_probs.shape
            # every sample uses the full output sequence length as its CTC input length
            input_lengths = torch.full((B,), T, dtype=torch.long, device=DEVICE)

            loss = ctc_loss(log_probs, labels_concat, input_lengths, label_lengths.to(DEVICE))
            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()

            # ctc_loss uses the default reduction="mean" (a per-batch average), so weight it
            # by B to get a per-sample mean that a smaller final batch cannot skew.
            total_loss += loss.item() * B
            total += B

            preds = ctc_greedy_decode(logits.detach())
            exact += sum(p == gt for p, gt in zip(preds, raw))

    return total_loss / max(1, total), exact / max(1, total)

def _atomic_save(state, path):
    """Save ``state`` to ``path`` without ever leaving a half-written file behind.

    The data is written to a temporary sibling file and then moved over ``path`` with
    ``os.replace``. The previous checkpoint is only replaced once the new one is fully on
    disk, so interrupting the kernel mid-save cannot corrupt it.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


best_val = -1.0
for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = run_epoch(train_dl, True)
    vl_loss, vl_acc = run_epoch(val_dl, False)
    print(f"[Epoch {ep}/{EPOCHS}] train_loss={tr_loss:.4f} acc={tr_acc*100:.2f}% | val_loss={vl_loss:.4f} acc={vl_acc*100:.2f}%")
    _atomic_save(model.state_dict(), CKPT_LAST)
    # ">=" so that on a tie the later (longer-trained) weights become the best checkpoint
    if vl_acc >= best_val:
        best_val = vl_acc
        _atomic_save(model.state_dict(), CKPT_BEST)

print("Saved:", CKPT_BEST, "|", CKPT_LAST)

  0%|          | 0/79 [00:00<?, ?it/s]

# Export ONNX

In [None]:
import onnx, onnxruntime as ort

# Export the best checkpoint (not whatever the last epoch left in memory), in eval mode so
# BatchNorm uses its running statistics.
model.load_state_dict(torch.load(CKPT_BEST, map_location=DEVICE))
model.eval()
dummy = torch.randn(1,1,IMG_H,IMG_W, device=DEVICE)
torch.onnx.export(
    model, dummy, str(ONNX_PATH),
    input_names=["input"], output_names=["logits"],
    # Without dynamic axes the graph is frozen to the dummy's batch size of 1.
    # Logits are [T, B, C], so the batch is axis 1 of the output.
    dynamic_axes={"input": {0: "batch"}, "logits": {1: "batch"}},
    opset_version=14,
)

# Sanity-check the exported graph and run it with batch size 2 to prove the batch axis is dynamic.
onnx.checker.check_model(onnx.load(str(ONNX_PATH)))
check_x = torch.randn(2, 1, IMG_H, IMG_W, device=DEVICE)
with torch.no_grad():
    torch_out = model(check_x).cpu().numpy()
sess = ort.InferenceSession(str(ONNX_PATH), providers=["CPUExecutionProvider"])
ort_out = sess.run(["logits"], {"input": check_x.cpu().numpy()})[0]
assert ort_out.shape == torch_out.shape, (ort_out.shape, torch_out.shape)
# Report (rather than assert) the numeric gap: GPU kernels (e.g. TF32) vs ORT on CPU differ slightly.
print("ONNX vs torch max |diff|:", float(np.abs(ort_out - torch_out).max()))
print("ONNX saved to:", ONNX_PATH)