In [21]:
import os, sys, json, numpy as np
from datetime import datetime

def _import_or_fail():
    """Locate the project helper modules this notebook depends on.

    The data helpers (``set_seed``, ``load_split``, ``split_train_valid``) are
    searched for in the modules ``ocr_utils``, ``model``, ``utils`` and
    ``dataset``, in that order; the first module that imports successfully AND
    exposes all three names wins. The training helpers are always taken from
    the ``train`` module.

    Returns:
        tuple: ``(ocr_utils, train_mod)`` -- the module providing the data
        helpers and the imported ``train`` module.

    Raises:
        ImportError: if no candidate module both imports and exposes every
            required data helper (the message lists why each candidate was
            rejected), or if ``train`` itself cannot be imported.
        AssertionError: if ``train`` lacks one of ``train_one_model``,
            ``evaluate_and_save``, ``_ensure_X`` or ``build_model``.
    """
    util_names = ("set_seed", "load_split", "split_train_valid")
    train_names = ("train_one_model", "evaluate_and_save", "_ensure_X", "build_model")

    err_msgs = []
    ocr_utils = None
    for try_mod in ("ocr_utils", "model", "utils", "dataset"):
        try:
            candidate = __import__(try_mod, fromlist=["*"])
        except Exception as e:
            # Best-effort probing: remember why this candidate failed, try the next.
            err_msgs.append(f"[warn] import {try_mod} failed: {e}")
            continue

        missing = [name for name in util_names if not hasattr(candidate, name)]
        if not missing:
            ocr_utils = candidate
            break
        # An importable module without the helpers must not be accepted: only
        # a fully matching candidate is ever bound to `ocr_utils`.
        err_msgs.append(f"[warn] {try_mod} imported but lacks: {', '.join(missing)}")

    if ocr_utils is None:
        raise ImportError("Could not import set_seed/load_split/split_train_valid from ocr_utils/model/utils/dataset.\n" + "\n".join(err_msgs))

    train_mod = __import__("train", fromlist=["*"])
    missing_train = [nm for nm in train_names if not hasattr(train_mod, nm)]
    assert not missing_train, "train module is missing: " + ", ".join(missing_train)
    return ocr_utils, train_mod

# Resolve the two helper modules once, then expose the functions the rest of
# the notebook calls as plain module-level names.
ocr_utils, train_mod = _import_or_fail()

# Data / seeding helpers (taken from whichever module _import_or_fail picked).
set_seed = ocr_utils.set_seed
load_split = ocr_utils.load_split
split_train_valid = ocr_utils.split_train_valid

# Training / evaluation helpers from the `train` module.
train_one_model = train_mod.train_one_model
evaluate_and_save = train_mod.evaluate_and_save
_ensure_X = train_mod._ensure_X
build_model = train_mod.build_model

print("[OK] Imports ready.")


[OK] Imports ready.


In [25]:

# --- Data locations (relative paths, resolved against the working directory) ---
TRAIN_DIR = "../data/train"
TEST_DIR = "../data/test"

# --- Model selection and input geometry ---
MODEL = "cnn"        # later cells branch on "cnn" (image-shaped input) and "mlp" (batchnorm on)
IMAGE_SZ = 32        # passed to load_split as image_size; also the side length used when reshaping

# --- Optimisation settings forwarded to train_one_model ---
EPOCHS = 10
LR = 1e-3
BATCH = 128
HIDDEN = [128]       # forwarded as hidden_sizes
USE_DROPOUT = True
DROPOUT_RATIO = 0.5

# --- Reproducibility and output ---
SEED = 42            # shared by set_seed, split_train_valid and train_one_model
OUT_ROOT = "runs"    # parent directory of the timestamped run folder

# Seed right away so every later cell starts from the same state.
set_seed(SEED)



In [23]:

# ---- Load the training split -------------------------------------------------
print(f"[Load][Train] {TRAIN_DIR}")
Xtr_all, ytr_all, label2idx, idx2label, stats_tr = load_split(
    TRAIN_DIR,
    image_size=IMAGE_SZ,
    filter_to_train_labels=None,
    char_filter_fn=None,
)
print("Train stats:", stats_tr)

# ---- Carve out a validation subset (valid_ratio=0.2, seeded) ------------------
# NOTE(review): the recorded run reports 1103 classes for 1612 usable samples
# and a validation accuracy near zero -- presumably most classes have only one
# or two examples; confirm the dataset/split is adequate before tuning the model.
Xtr, ytr, Xva, yva = split_train_valid(Xtr_all, ytr_all, valid_ratio=0.2, seed=SEED)

# ---- Shape the inputs for the chosen model -----------------------------------
# The CNN path turns each sample into an IMAGE_SZ x IMAGE_SZ array; any other
# model keeps the arrays exactly as returned by split_train_valid.
if MODEL == "cnn":
    Xtr = Xtr.reshape(len(Xtr), IMAGE_SZ, IMAGE_SZ)
    Xva = Xva.reshape(len(Xva), IMAGE_SZ, IMAGE_SZ)
input_info = (1, IMAGE_SZ, IMAGE_SZ) if MODEL == "cnn" else IMAGE_SZ * IMAGE_SZ

output_size = len(label2idx)
print("[Info] classes:", output_size, "input_info:", input_info)

# ---- Train -------------------------------------------------------------------
net, best = train_one_model(
    Xtr, ytr, Xva, yva,
    input_info=input_info,
    output_size=output_size,
    model_type=MODEL,
    hidden_sizes=HIDDEN,
    lr=LR,
    batch_size=BATCH,
    epochs=EPOCHS,
    use_batchnorm=(MODEL == "mlp"),  # batchnorm is requested for the MLP only
    use_dropout=USE_DROPOUT,
    dropout_ratio=DROPOUT_RATIO,
    optimizer_name="adam",
    seed=SEED,
)


[Load][Train] ../data/train
[../data/train] 100/814 used=203
[../data/train] 200/814 used=402
[../data/train] 300/814 used=598
[../data/train] 400/814 used=790
[../data/train] 500/814 used=985
[../data/train] 600/814 used=1185
[../data/train] 700/814 used=1384
[../data/train] 800/814 used=1584
Train stats: defaultdict(<class 'int'>, {'skipped_non_single': 32197, 'used': 1612})
[Info] classes: 1103 input_info: (1, 32, 32)
[CNN][Epoch 001] train_acc=0.0054  val_acc=0.0000
[CNN][Epoch 002] train_acc=0.0047  val_acc=0.0000
[CNN][Epoch 003] train_acc=0.0101  val_acc=0.0000
[CNN][Epoch 004] train_acc=0.0124  val_acc=0.0000
[CNN][Epoch 005] train_acc=0.0264  val_acc=0.0000
[CNN][Epoch 006] train_acc=0.0271  val_acc=0.0000
[CNN][Epoch 007] train_acc=0.0419  val_acc=0.0000
[CNN][Epoch 008] train_acc=0.0597  val_acc=0.0000
[CNN][Epoch 009] train_acc=0.0814  val_acc=0.0000
[CNN][Epoch 010] train_acc=0.1264  val_acc=0.0031
=> Restored best epoch 10 (val_acc=0.0031)


In [26]:

# ---- Load the test split, restricted to the labels known from training --------
print(f"[Load][Test] {TEST_DIR}")
Xte, yte, _, _, stats_te = load_split(
    TEST_DIR,
    image_size=IMAGE_SZ,
    filter_to_train_labels=label2idx,
    char_filter_fn=None,
)
print("Test stats:", stats_te)

# Same per-sample reshape as applied to the training data on the CNN path.
if MODEL == "cnn":
    Xte = Xte.reshape(len(Xte), IMAGE_SZ, IMAGE_SZ)

# ---- Create a timestamped run directory under OUT_ROOT ------------------------
ts = datetime.now().strftime("%Y%m%d-%H%M%S")
out_dir = os.path.join(OUT_ROOT, f"{MODEL}_{ts}")
os.makedirs(out_dir, exist_ok=True)

# Persist the label -> index map; values are coerced to plain int for JSON.
with open(os.path.join(out_dir, "label2idx.json"), "w", encoding="utf-8") as f:
    json.dump(
        dict((k, int(v)) for k, v in label2idx.items()),
        f,
        ensure_ascii=False,
        indent=2,
    )

# ---- Evaluate and store artefacts ---------------------------------------------
acc, cm = evaluate_and_save(net, Xte, yte, idx2label, out_dir=out_dir)
# NOTE(review): the recorded output shows evaluate_and_save already reporting
# "Saved: ... cm.npy", so this write is presumably redundant -- confirm.
np.save(os.path.join(out_dir, "cm.npy"), cm)

# Snapshot the network state and store each entry as a named array.
state = net.snapshot_state()
np.savez(os.path.join(out_dir, "model_state.npz"), **state)

print("[Done] out_dir =", out_dir, "  test_acc =", acc)


[Load][Test] ../data/test
[../data/test] 100/1057 used=115
[../data/test] 200/1057 used=236
[../data/test] 300/1057 used=361
[../data/test] 400/1057 used=494
[../data/test] 500/1057 used=610
[../data/test] 600/1057 used=733
[../data/test] 700/1057 used=870
[../data/test] 800/1057 used=991
[../data/test] 900/1057 used=1109
[../data/test] 1000/1057 used=1232
Test stats: defaultdict(<class 'int'>, {'skipped_non_single': 41721, 'skipped_unknown_label': 812, 'used': 1288})
[Validation] accuracy = 0.0070  (N=1288)
Saved: runs\cnn_20251023-165141\cm.npy and labels.json
[Done] out_dir = runs\cnn_20251023-165141   test_acc = 0.006987577639751553
