# Install deps (Colab-safe)

In [None]:
!pip -q install torch torchvision torchaudio
!pip -q install opencv-python numpy pillow tqdm
!pip install onnx
!pip install onnx onnxruntime



# Imports & Config

In [None]:
import os, json, random
import cv2
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# --- Charset: digits + uppercase + a few symbols (no space to avoid CTC empties)
DIGITS = "0123456789"
UPPER  = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
SYMS   = "-_:.,!?@#&()+*/="
CHARS  = DIGITS + UPPER + SYMS   # <-- keep this unchanged in this run
# A repeated character would silently map two positions onto one dict key and
# make NCLASS disagree with the size of the char<->index mapping.
assert len(set(CHARS)) == len(CHARS), "CHARS must not contain duplicate characters"

IMG_H = 32
IMG_W = 128       # fixed width to keep ONNX export simple and stable
MAX_LABEL_LEN = 10
TRAIN_SAMPLES = 50000
VAL_SAMPLES   = 5000
BATCH_SIZE    = 128
EPOCHS        = 20      # number of passes over the synthetic training set
LR            = 1e-3
SEED          = 2025
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG this notebook draws from (Python, NumPy, PyTorch)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

os.makedirs("artifacts", exist_ok=True)

# --- CTC mapping (0 = blank); real characters occupy indices 1..len(CHARS)
char2idx = {c:i+1 for i,c in enumerate(CHARS)}
idx2char = {i+1:c for i,c in enumerate(CHARS)}
BLANK_IDX = 0
NCLASS = len(CHARS) + 1

# --- Save charset (the only setting written to disk here)
# Explicit UTF-8: ensure_ascii=False writes characters verbatim, so the file
# encoding must not depend on the platform default.
with open("artifacts/charset.json", "w", encoding="utf-8") as f:
    json.dump({"chars": CHARS}, f, ensure_ascii=False)

# --- Artifact paths; class count and last-conv tag are embedded in the file names
LAST_CONV_TAG = "k2x1"                    # we use kernel_size=(2,1) below
CKPT_PATH = f"artifacts/crnn_nc{NCLASS}_{LAST_CONV_TAG}_best.pt"
ONNX_PATH = f"artifacts/crnn_nc{NCLASS}_{LAST_CONV_TAG}.onnx"


# Label helpers

In [None]:
def encode_label(text):
    """Map a string to its list of class indices via ``char2idx``.

    Characters missing from ``char2idx`` are skipped. If nothing survives,
    a single 'A' index is returned so the CTC target is never empty.
    """
    indices = []
    for ch in text:
        if ch in char2idx:
            indices.append(char2idx[ch])
    # Fall back to one 'A' when every character was dropped (or text was empty)
    return indices or [char2idx['A']]

def ctc_greedy_decode(logits):     # logits: T x B x C
    """Greedy CTC decoding: argmax per step, collapse repeats, drop blanks.

    Returns a list with one decoded string per batch element. Indices that
    are not present in ``idx2char`` contribute an empty string.
    """
    best = logits.argmax(2).detach().cpu().numpy()
    _, batch_size = best.shape
    decoded = []
    for b in range(batch_size):
        path = best[:, b].tolist()
        # Pair each step with its predecessor; the first step is compared against blank
        shifted = [BLANK_IDX] + path[:-1]
        decoded.append("".join(
            idx2char.get(cur, "")
            for cur, prev in zip(path, shifted)
            if cur != BLANK_IDX and cur != prev
        ))
    return decoded

# Minimal synthetic dataset (OpenCV text, no risky augments)

In [None]:
def rand_text(min_len=3, max_len=MAX_LABEL_LEN):
    """Draw a random string over ``CHARS`` with length in [min_len, max_len].

    If the result contains no alphanumeric character, its first character is
    replaced with 'A' so that a label is never made of symbols only.
    """
    length = random.randint(min_len, max_len)
    picked = []
    for _ in range(length):
        picked.append(random.choice(CHARS))
    text = "".join(picked)
    for ch in text:
        if ch.isalnum():
            return text
    # No letter or digit anywhere: force one in at the front
    return "A" + text[1:]

def render_text_image(text, h=IMG_H, w=IMG_W):
    """Draw ``text`` in black on a white h x w grayscale (uint8) canvas.

    The font scale starts at 0.9 and is reduced (never below 0.5) when the
    text is wider than the canvas minus a 6-pixel allowance. The text is
    centred, with a minimum 3-pixel left offset.
    """
    face = cv2.FONT_HERSHEY_SIMPLEX
    stroke = 1
    font_scale = 0.9

    (text_w, text_h), _ = cv2.getTextSize(text, face, font_scale, stroke)
    if text_w > w - 6:
        # Shrink proportionally to the available width, then re-measure
        font_scale = max(0.5, (w - 6) / max(1, text_w) * font_scale)
        (text_w, text_h), _ = cv2.getTextSize(text, face, font_scale, stroke)

    # putText takes the bottom-left corner of the text as its origin
    origin_x = max(3, (w - text_w) // 2)
    origin_y = max(text_h + 3, (h + text_h) // 2)

    canvas = np.full((h, w), 255, np.uint8)
    cv2.putText(canvas, text, (origin_x, origin_y), face, font_scale, (0,), stroke, cv2.LINE_AA)
    return canvas

class SynthDataset(Dataset):
    """Synthetic text-line dataset: each item is a freshly rendered random string.

    Args:
        n: Number of items reported by ``len()``.
        seed: Optional base seed. With ``None`` (default) every access draws
            new text from the global ``random`` state and ``idx`` is ignored,
            so the same index yields different samples over time. With an
            integer, item ``idx`` always produces the same text (generated
            under ``random.seed(seed + idx)``) and the global ``random``
            state is restored afterwards.
    """
    def __init__(self, n, seed=None):
        self.n = n
        self.seed = seed
    def __len__(self): return self.n
    def _sample_text(self, idx):
        """Return the label text for ``idx`` (reproducible only when a seed is set)."""
        if self.seed is None:
            return rand_text()
        saved_state = random.getstate()
        random.seed(self.seed + idx)
        try:
            return rand_text()
        finally:
            # Leave the global stream exactly as we found it so other
            # consumers of ``random`` are unaffected.
            random.setstate(saved_state)
    def __getitem__(self, idx):
        """Return (1xHxW float image in [0, 1], label index tensor, raw text)."""
        t = self._sample_text(idx)
        im = render_text_image(t)
        im_t = torch.from_numpy(im).unsqueeze(0).float() / 255.0  # 1xHxW
        lab = torch.tensor(encode_label(t), dtype=torch.long)
        return im_t, lab, t

def collate_fn(batch):
    """Collate dataset items into CTC-ready tensors.

    Returns:
        imgs: stacked images, Bx1xHxW.
        labels_concat: all label indices concatenated into one 1-D tensor.
        label_lengths: per-sample label length (long tensor of size B).
        fixed_raw: list of ground-truth strings, aligned with the batch.
    """
    imgs, labs, raw = zip(*batch)
    fixed = []
    fixed_raw = list(raw)
    for i, l in enumerate(labs):
        if len(l) == 0:
            # Defensive: never hand CTC an empty target
            fixed.append(torch.tensor([char2idx["A"]], dtype=torch.long))
            if fixed_raw[i] == "":
                fixed_raw[i] = "A"
        else:
            fixed.append(l)
    imgs = torch.stack(imgs)  # Bx1xHxW
    label_lengths = torch.tensor([len(x) for x in fixed], dtype=torch.long)
    labels_concat = torch.cat(fixed)
    return imgs, labels_concat, label_lengths, fixed_raw

# Training data stays an unseeded random stream: new text on every access.
train_dl = DataLoader(
    SynthDataset(TRAIN_SAMPLES),
    batch_size=BATCH_SIZE, shuffle=True,
    num_workers=0, pin_memory=False, collate_fn=collate_fn
)
# Validation data is seeded per index so every epoch is scored on the same
# samples; the +1 offset keeps all item seeds different from the global SEED.
val_dl = DataLoader(
    SynthDataset(VAL_SAMPLES, seed=SEED + 1),
    batch_size=BATCH_SIZE, shuffle=False,
    num_workers=0, pin_memory=False, collate_fn=collate_fn
)

# CRNN model (CNN -> BiLSTM -> Linear), T=32 for W=128 (downscale W by 4; collapse H to 1)

In [None]:
class CRNN(nn.Module):
    """CNN feature extractor -> 2-layer bidirectional LSTM -> per-step linear classifier.

    For a Bx1x32x128 input the CNN yields Bx512x1x32, which is fed to the
    LSTM as a 32-step sequence; the output is 32 x B x nclass (time-major).
    """
    def __init__(self, nclass):
        super().__init__()
        # Layers are listed in the exact order they occupy inside the Sequential
        stages = [
            # 32x128 -> 16x64
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # 16x64 -> 8x32
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # 8x32 -> 4x32 (pool height only, keep width)
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),

            # 4x32 -> 2x32
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),

            # 2x32 -> 1x32; the (2,1) kernel matches LAST_CONV_TAG
            nn.Conv2d(512, 512, kernel_size=(2, 1), stride=1, padding=0),
            nn.ReLU(inplace=True),
        ]
        self.cnn = nn.Sequential(*stages)
        self.rnn = nn.LSTM(
            input_size=512, hidden_size=256, num_layers=2,
            bidirectional=True, batch_first=False,
        )
        self.fc = nn.Linear(in_features=512, out_features=nclass)

    def forward(self, x):
        """Return per-step class scores, shaped T x B x nclass."""
        feats = self.cnn(x)                          # Bx512x1x32
        steps = feats.squeeze(2).permute(2, 0, 1)    # 32 x B x 512
        encoded, _ = self.rnn(steps)                 # 32 x B x 512
        return self.fc(encoded)                      # 32 x B x nclass

# Build the network with NCLASS outputs (charset size + 1 for the CTC blank) on the selected device
model = CRNN(NCLASS).to(DEVICE)
# AdamW over all model parameters with the configured learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# CTC loss whose blank symbol is BLANK_IDX (0), matching the char2idx mapping that starts at 1
ctc_loss  = nn.CTCLoss(blank=BLANK_IDX, zero_infinity=True)

# Train & Validate

In [None]:
import os

def run_epoch(dl, train=True):
    """Run one full pass over ``dl`` and report loss and accuracy.

    Args:
        dl: Iterable of batches ``(imgs, labels_concat, label_lengths, raw)``
            as produced by ``collate_fn``; must support ``len()``.
        train: When True the model is put in train mode and the optimizer is
            stepped after every batch. When False the model is put in eval
            mode and the forward pass runs with gradient tracking disabled.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the average of the
        per-batch CTC losses and ``accuracy`` is the fraction of samples whose
        greedy decoding equals the ground-truth string exactly.
    """
    model.train(train)
    total_loss, total, correct = 0.0, 0, 0
    for imgs, labels_concat, label_lengths, raw in tqdm(dl, leave=False):
        imgs = imgs.to(DEVICE)
        labels_concat = labels_concat.to(DEVICE)
        label_lengths = label_lengths.to(DEVICE)

        # Only record an autograd graph when we are actually going to
        # backpropagate; evaluation then needs no activation storage.
        with torch.set_grad_enabled(train):
            logits = model(imgs)                 # T x B x C
            log_probs = F.log_softmax(logits, dim=2)
            T, B, C = log_probs.shape
            # Every sample uses the full sequence length produced by the model
            input_lengths = torch.full((B,), T, dtype=torch.long, device=DEVICE)
            loss = ctc_loss(log_probs, labels_concat, input_lengths, label_lengths)

        if train:
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

        total_loss += loss.item()
        total += B

        # exact string match (strict); tends to be 0% early on
        preds = ctc_greedy_decode(logits.detach())
        for p, gt in zip(preds, raw):
            correct += int(p == gt)

    return total_loss / max(1, len(dl)), correct / max(1, total)

# Best validation accuracy so far; starting below 0 guarantees epoch 1 is saved.
best_val = -1.0
# Validation loss of the currently saved 'best' checkpoint (tie-breaker for equal accuracy).
best_val_loss = float("inf")
LAST_PATH = CKPT_PATH.replace("_best.pt", "_last.pt")

for ep in range(1, EPOCHS+1):
    tr_loss, tr_acc = run_epoch(train_dl, True)
    vl_loss, vl_acc = run_epoch(val_dl, False)
    print(f"[Epoch {ep}/{EPOCHS}] train_loss={tr_loss:.4f} acc={tr_acc*100:.2f}% | val_loss={vl_loss:.4f} acc={vl_acc*100:.2f}%")

    # Always save the 'last' checkpoint
    torch.save(model.state_dict(), LAST_PATH)

    # Save 'best' on higher accuracy, or on equal accuracy with lower loss.
    # A plain '>=' would let a later, equally accurate epoch overwrite a
    # checkpoint that had a lower validation loss.
    improved = vl_acc > best_val or (vl_acc == best_val and vl_loss < best_val_loss)
    if improved:
        best_val, best_val_loss = vl_acc, vl_loss
        torch.save(model.state_dict(), CKPT_PATH)

# After training, guarantee CKPT_PATH exists (covers EPOCHS == 0)
if not os.path.exists(CKPT_PATH):
    torch.save(model.state_dict(), CKPT_PATH)

# Load back the checkpoint **we just saved for THIS run**
model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
# Trailing ';' keeps the notebook from echoing the full module repr as cell output
model.eval();

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 1/20] train_loss=4.4341 acc=0.00% | val_loss=4.3562 acc=0.00%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 2/20] train_loss=4.3522 acc=0.00% | val_loss=4.3529 acc=0.00%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 3/20] train_loss=4.3215 acc=0.00% | val_loss=4.0943 acc=0.00%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 4/20] train_loss=1.1344 acc=50.91% | val_loss=0.0239 acc=97.00%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 5/20] train_loss=0.0134 acc=98.17% | val_loss=0.0050 acc=99.70%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 6/20] train_loss=0.0110 acc=98.35% | val_loss=0.0035 acc=99.62%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 7/20] train_loss=0.0027 acc=99.65% | val_loss=0.0017 acc=99.70%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 8/20] train_loss=0.0025 acc=99.65% | val_loss=0.7094 acc=27.12%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 9/20] train_loss=0.0034 acc=99.48% | val_loss=0.0010 acc=99.88%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 10/20] train_loss=0.0009 acc=99.90% | val_loss=0.0007 acc=99.90%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 11/20] train_loss=0.0013 acc=99.79% | val_loss=0.0012 acc=99.74%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 12/20] train_loss=0.0008 acc=99.88% | val_loss=0.0002 acc=100.00%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 13/20] train_loss=0.0007 acc=99.90% | val_loss=0.0007 acc=99.92%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 14/20] train_loss=0.0175 acc=97.88% | val_loss=0.0009 acc=99.84%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 15/20] train_loss=0.0005 acc=99.94% | val_loss=0.0003 acc=99.98%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 16/20] train_loss=0.0008 acc=99.86% | val_loss=0.0004 acc=99.96%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 17/20] train_loss=0.0003 acc=99.96% | val_loss=0.0002 acc=99.98%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 18/20] train_loss=0.0003 acc=99.95% | val_loss=0.0002 acc=99.94%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 19/20] train_loss=0.0041 acc=99.26% | val_loss=0.0006 acc=99.90%


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

[Epoch 20/20] train_loss=0.0005 acc=99.91% | val_loss=0.0002 acc=100.00%


CRNN(
  (cnn): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inp

# Export ONNX (static shape 1x1x32x128 for simplicity)

In [None]:
# Put the model in inference mode here rather than relying on an earlier cell
# having done so; BatchNorm must use its running statistics in the exported graph.
model.eval()
dummy = torch.randn(1,1,IMG_H,IMG_W, device=DEVICE)
# Export needs no gradients, so trace without autograd bookkeeping
with torch.no_grad():
    torch.onnx.export(
        model, dummy, ONNX_PATH,
        input_names=["input"], output_names=["logits"],
        opset_version=14
    )
print("Saved ONNX to", ONNX_PATH, " | Checkpoint:", CKPT_PATH, " | Charset: artifacts/charset.json")

  torch.onnx.export(


Saved ONNX to artifacts/crnn_nc53_k2x1.onnx  | Checkpoint: artifacts/crnn_nc53_k2x1_best.pt  | Charset: artifacts/charset.json


# Quick check: OpenCV DNN inference on a synthetic sample

In [None]:
def preprocess_for_onnx(img_gray):
    """Resize a grayscale image to IMG_W x IMG_H, scale to [0, 1] float32, add two leading axes."""
    resized = cv2.resize(img_gray, (IMG_W, IMG_H))
    scaled = resized.astype(np.float32) / 255.0
    # Prepend batch and channel axes -> 1x1xHxW
    return scaled[None, None, :, :]

def ctc_decode_from_onnx(logits_np):
    """Greedy CTC decode of a single sample from a raw logits array.

    Accepted layouts:
        * ``T x C``      - one sample, no batch axis.
        * ``T x B x C``  - time-major; the first batch element is decoded.
        * ``1 x T x C``  - batch-first with a single sample (T > 1).

    A 3-D array counts as batch-first only when its first axis has size 1
    and its second axis is larger than 1; the layout cannot be told apart
    by rank alone. Every other 3-D array is handled as time-major.

    Returns:
        The decoded string (repeats collapsed, blanks removed).

    Raises:
        ValueError: if the array is neither 2-D nor 3-D.
    """
    if logits_np.ndim == 2:                       # T x C
        steps = logits_np
    elif logits_np.ndim == 3:
        if logits_np.shape[0] == 1 and logits_np.shape[1] > 1:
            steps = logits_np[0]                  # 1 x T x C
        else:
            steps = logits_np[:, 0, :]            # T x B x C -> first sample
    else:
        raise ValueError(
            f"expected logits with 2 or 3 dimensions, got shape {logits_np.shape}"
        )

    out, prev = [], BLANK_IDX
    for s in steps.argmax(1).tolist():
        if s != prev and s != BLANK_IDX:
            out.append(idx2char.get(int(s), ""))
        prev = s
    return "".join(out)

# Render a known string with the same renderer the synthetic dataset uses
test_text = "A1-29!B"
test_img  = render_text_image(test_text)
# Load the exported ONNX file with OpenCV's DNN module and run one forward pass
net = cv2.dnn.readNet(ONNX_PATH)
net.setInput(preprocess_for_onnx(test_img))
out = net.forward()
pred = ctc_decode_from_onnx(out)
# Write the input image to disk (converted from grayscale to 3-channel BGR)
cv2.imwrite("artifacts/demo_input.png", cv2.cvtColor(test_img, cv2.COLOR_GRAY2BGR))
# Show ground truth next to the decoded prediction
print("GT:", test_text, "| PRED:", pred)
print("Demo image saved at artifacts/demo_input.png")

GT: A1-29!B | PRED: A1-29!B
Demo image saved at artifacts/demo_input.png
