In [1]:
# ================================================================
# Segmentation-based Offline Handwritten Text OCR (PyTorch Lightning)
# - Builds char glyphs FROM word images (once), then saves NPZs
# - On next runs, loads NPZs instead of rebuilding
# - Word inference: segment (contours) -> classify each ROI -> join
# - Preserves your preprocessing:
#     * Drop NaNs, drop 'unreadable'
#     * Crop top-left 64x256; pad with white (255)
# - Reports: CER, 1-CER, ACC (exact match), WER
# - Plots loss curves after training (loss_curve.png)
# ================================================================

import os
import cv2
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning import seed_everything

# ---------------------------
# Config (word dataset paths)
# ---------------------------

TRAIN_CSV     = "/kaggle/input/handwriting-recognition/written_name_train_v2.csv"
TRAIN_IMG_DIR = "/kaggle/input/handwriting-recognition/train_v2/train"
VAL_CSV       = "/kaggle/input/handwriting-recognition/written_name_validation_v2.csv"
VAL_IMG_DIR   = "/kaggle/input/handwriting-recognition/validation_v2/validation"
TEST_CSV      = "/kaggle/input/handwriting-recognition/written_name_test_v2.csv"
TEST_IMG_DIR  = "/kaggle/input/handwriting-recognition/test_v2/test"

# Artifacts (model + saved glyph datasets)
MODEL_PATH       = "char_cnn_32x32.pt"
CLASSES_PATH     = "char_classes.pkl"
TRAIN_GLYPHS_NPZ = "train_chars.npz"
VAL_GLYPHS_NPZ   = "val_chars.npz"
LOSS_PNG         = "loss_curve.png"

GLYPH_SIZE  = 32
WORD_CROP_H = 64
WORD_CROP_W = 256

seed_everything(42, workers=True)

# ---------------------------
# Metrics
# ---------------------------

def levenshtein_distance(s1: str, s2: str) -> int:
    if len(s1) < len(s2):
        s1, s2 = s2, s1
    previous = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        current = [i + 1]
        for j, c2 in enumerate(s2):
            ins = previous[j + 1] + 1
            dele = current[j] + 1
            sub = previous[j] + (c1 != c2)
            current.append(min(ins, dele, sub))
        previous = current
    return previous[-1]

def compute_metrics(preds, truths):
    total_chars, total_char_errs = 0, 0
    total_words, total_word_errs = 0, 0
    exact = 0
    for gt, pr in zip(truths, preds):
        dist = levenshtein_distance(gt, pr)
        total_char_errs += dist
        total_chars += len(gt)
        total_words += 1
        total_word_errs += int(gt != pr)
        if gt == pr:
            exact += 1
    cer = (total_char_errs / total_chars) if total_chars > 0 else 0.0
    one_minus_cer = 1.0 - cer
    acc = (exact / total_words) if total_words > 0 else 0.0
    wer = (total_word_errs / total_words) if total_words > 0 else 0.0
    return cer, one_minus_cer, acc, wer

# ---------------------------
# Word CSV loader (keep NaN/'unreadable' filtering)
# ---------------------------

def load_word_index(csv_path, images_dir):
    df = pd.read_csv(csv_path)
    if "FILENAME" not in df.columns or "IDENTITY" not in df.columns:
        raise ValueError(f"{csv_path} must have columns 'FILENAME' and 'IDENTITY'")
    df = df.dropna(subset=["FILENAME", "IDENTITY"]).copy()
    df["IDENTITY"] = df["IDENTITY"].astype(str)
    df = df[df["IDENTITY"].str.strip().str.lower() != "unreadable"].reset_index(drop=True)
    return [(os.path.join(images_dir, r["FILENAME"]), r["IDENTITY"]) for _, r in df.iterrows()]

# ---------------------------
# Preprocess word images (yours)
# ---------------------------

def preprocess_word_image_top_left_crop_pad(img_gray: np.ndarray, crop_h=64, crop_w=256) -> np.ndarray:
    h, w = img_gray.shape[:2]
    crop = img_gray[:min(h, crop_h), :min(w, crop_w)]
    out = np.ones((crop_h, crop_w), dtype=np.uint8) * 255  # white pad
    out[:crop.shape[0], :crop.shape[1]] = crop
    return out

# ---------------------------
# Segmentation helpers
# ---------------------------

def sort_contours(cnts, method="left-to-right"):
    reverse = False
    axis = 0
    if method in ("right-to-left", "bottom-to-top"):
        reverse = True
    if method in ("top-to-bottom", "bottom-to-top"):
        axis = 1
    boxes = [cv2.boundingRect(c) for c in cnts]
    cnts = list(cnts)
    pairs = sorted(zip(cnts, boxes), key=lambda b: b[1][axis], reverse=reverse)
    return [c for c, _ in pairs]

def segment_rois(img_gray_64x256: np.ndarray):
    """Return list of ROIs (grayscale uint8) left->right."""
    _, thresh1 = cv2.threshold(img_gray_64x256, 127, 255, cv2.THRESH_BINARY_INV)
    dilated = cv2.dilate(thresh1, None, iterations=2)
    cnts = cv2.findContours(dilated.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    cnts = sort_contours(cnts, method="left-to-right")
    rois = []
    H, W = img_gray_64x256.shape[:2]
    for c in cnts:
        if cv2.contourArea(c) <= 10:
            continue
        x, y, w, h = cv2.boundingRect(c)
        x0, y0 = max(0, x), max(0, y)
        x1, y1 = min(W, x + w), min(H, y + h)
        roi = img_gray_64x256[y0:y1, x0:x1]
        if roi.size == 0:
            continue
        rois.append(roi)
    return rois

def roi_to_32x32(roi):
    _, roi_bin = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
    roi_bin = cv2.resize(roi_bin, (GLYPH_SIZE, GLYPH_SIZE), interpolation=cv2.INTER_CUBIC)
    roi_bin = roi_bin.astype(np.float32) / 255.0
    roi_bin = np.expand_dims(roi_bin, axis=0)  # [1,32,32]
    return roi_bin

# ---------------------------
# Build & SAVE char dataset FROM word images
# ---------------------------

def build_and_save_char_glyphs(train_csv, train_img_dir, val_csv, val_img_dir,
                               out_train_npz=TRAIN_GLYPHS_NPZ, out_val_npz=VAL_GLYPHS_NPZ,
                               out_classes_pkl=CLASSES_PATH,
                               crop_h=WORD_CROP_H, crop_w=WORD_CROP_W,
                               max_words_train=None, max_words_val=None):
    """
    Build glyph datasets from train/val words and save to NPZ.
    Saves:
      - train_chars.npz: X [N,1,32,32] float32, y [N] object (characters)
      - val_chars.npz:   X [M,1,32,32], y [M] object
      - char_classes.pkl: list of sorted unique characters across train+val
    """
    def build(csv_path, img_dir, max_words):
        pairs = load_word_index(csv_path, img_dir)
        if max_words is not None:
            pairs = pairs[:max_words]
        X_list, y_list = [], []
        for img_path, text in tqdm(pairs, desc=f"Building char dataset from {os.path.basename(csv_path)}"):
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None or not isinstance(text, str) or len(text) == 0:
                continue
            img = preprocess_word_image_top_left_crop_pad(img, crop_h, crop_w)
            rois = segment_rois(img)
            if not rois:
                continue
            k = min(len(rois), len(text))
            for i in range(k):
                roi32 = roi_to_32x32(rois[i])  # [1,32,32]
                X_list.append(roi32)
                y_list.append(text[i])
        if not X_list:
            raise RuntimeError(f"No glyphs built from {csv_path}.")
        X = np.stack(X_list, axis=0).astype(np.float32)  # [N,1,32,32]
        y = np.array(y_list, dtype=object)               # keep raw chars
        return X, y

    Xtr, ytr = build(train_csv, train_img_dir, max_words_train)
    Xva, yva = build(val_csv, val_img_dir, max_words_val)

    # union of classes
    classes = sorted(list(set(ytr.tolist()) | set(yva.tolist())))

    # save compressed
    np.savez_compressed(out_train_npz, X=Xtr, y=ytr)
    np.savez_compressed(out_val_npz,   X=Xva, y=yva)
    with open(out_classes_pkl, "wb") as f:
        pickle.dump(classes, f)

    print(f"[Saved] {out_train_npz}  (X={Xtr.shape}, y={len(ytr)})")
    print(f"[Saved] {out_val_npz}    (X={Xva.shape}, y={len(yva)})")
    print(f"[Saved] {out_classes_pkl}  (#classes={len(classes)}): {classes}")

# ---------------------------
# LOAD saved glyph datasets
# ---------------------------

def load_char_glyphs(train_npz=TRAIN_GLYPHS_NPZ, val_npz=VAL_GLYPHS_NPZ, classes_pkl=CLASSES_PATH):
    if not (os.path.isfile(train_npz) and os.path.isfile(val_npz) and os.path.isfile(classes_pkl)):
        return None
    dtr = np.load(train_npz, allow_pickle=True)
    dva = np.load(val_npz, allow_pickle=True)
    with open(classes_pkl, "rb") as f:
        classes = pickle.load(f)
    Xtr, ytr = dtr["X"], dtr["y"].tolist()
    Xva, yva = dva["X"], dva["y"].tolist()
    return (Xtr, ytr), (Xva, yva), classes

# ---------------------------
# Character CNN & Lightning
# ---------------------------

class CharCNN(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 16x16
            nn.Conv2d(32, 64, kernel_size=3), nn.ReLU(inplace=True),  # 14x14
            nn.MaxPool2d(2),  # 7x7
            nn.Conv2d(64, 128, kernel_size=3), nn.ReLU(inplace=True), # 5x5
            nn.MaxPool2d(2),  # 2x2
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(128 * 2 * 2, 128), nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )
    def forward(self, x): return self.net(x)

class CharModule(pl.LightningModule):
    def __init__(self, num_classes: int, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = CharCNN(num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.train_epoch_losses, self.val_epoch_losses = [], []
        self._train_buf, self._val_buf = [], []
    def forward(self, x): return self.model(x)
    def training_step(self, batch, _):
        x, y = batch
        logits = self(x); loss = self.criterion(logits, y)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, batch_size=x.size(0))
        self._train_buf.append(loss.detach().cpu().item()); return loss
    def validation_step(self, batch, _):
        x, y = batch
        loss = self.criterion(self(x), y)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, batch_size=x.size(0))
        self._val_buf.append(loss.detach().cpu().item())
    def on_train_epoch_end(self):
        if self._train_buf: self.train_epoch_losses.append(float(np.mean(self._train_buf))); self._train_buf = []
    def on_validation_epoch_end(self):
        if self._val_buf: self.val_epoch_losses.append(float(np.mean(self._val_buf))); self._val_buf = []
    def configure_optimizers(self): return optim.Adam(self.parameters(), lr=self.hparams.lr)

# ---------------------------
# Dataloaders from NPZs
# ---------------------------

def make_loaders_from_npz(Xtr, ytr, Xva, yva, classes, batch_size=256):
    class_to_idx = {c: i for i, c in enumerate(classes)}
    ytr_idx = torch.tensor([class_to_idx[c] for c in ytr], dtype=torch.long)
    yva_idx = torch.tensor([class_to_idx[c] for c in yva], dtype=torch.long)

    # X are numpy arrays [N,1,32,32] float32
    Xtr_t = torch.from_numpy(Xtr).float()
    Xva_t = torch.from_numpy(Xva).float()

    train_ds = TensorDataset(Xtr_t, ytr_idx)
    val_ds   = TensorDataset(Xva_t, yva_idx)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=3, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=3, pin_memory=True)
    return train_loader, val_loader, class_to_idx

# ---------------------------
# Segment & predict a word
# ---------------------------

@torch.no_grad()
def segment_and_predict_word(img_gray_64x256: np.ndarray, model: CharCNN, idx_to_class: dict, device: torch.device) -> str:
    rois = segment_rois(img_gray_64x256)
    if not rois: return ""
    model.eval()
    letters = []
    for roi in rois:
        roi32 = roi_to_32x32(roi)  # [1,32,32]
        tensor = torch.from_numpy(np.expand_dims(roi32, 0)).to(device)  # [1,1,32,32]
        logits = model(tensor)
        pred_idx = int(torch.argmax(logits, dim=1).item())
        letters.append(idx_to_class.get(pred_idx, ""))
    return "".join(letters)

# ---------------------------
# Evaluate on full word images
# ---------------------------

def evaluate_on_words(csv_path, img_dir, model, idx_to_class, device, max_items=None, show_examples=5):
    pairs = load_word_index(csv_path, img_dir)
    if max_items is not None: pairs = pairs[:max_items]
    preds, truths = [], []
    for img_path, gt in tqdm(pairs, desc=f"Eval {os.path.basename(csv_path)}"):
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            preds.append(""); truths.append(gt); continue
        img_proc = preprocess_word_image_top_left_crop_pad(img, WORD_CROP_H, WORD_CROP_W)
        pred = segment_and_predict_word(img_proc, model, idx_to_class, device)
        preds.append(pred); truths.append(gt)
    print("\n--- Examples ---")
    for i in range(min(show_examples, len(preds))):
        print(f"GT: {truths[i]} | PRED: {preds[i]}")
    cer, one_minus_cer, acc, wer = compute_metrics(preds, truths)
    print("\n--- Metrics ---")
    print(f"CER: {cer:.6f}")
    print(f"1 - CER (Char Acc): {one_minus_cer:.6f}")
    print(f"ACC (Exact Match): {acc:.6f}")
    print(f"WER: {wer:.6f}")
    return preds, truths, (cer, one_minus_cer, acc, wer)

# ---------------------------
# Main
# ---------------------------

def main():
    # 1) Try loading saved glyphs
    loaded = load_char_glyphs()
    if loaded is None:
        # 2) Build and save glyphs (one-time)
        build_and_save_char_glyphs(TRAIN_CSV, TRAIN_IMG_DIR, VAL_CSV, VAL_IMG_DIR)
        loaded = load_char_glyphs()
    (Xtr, ytr), (Xva, yva), classes = loaded
    print(f"[Glyphs] train X={Xtr.shape}, val X={Xva.shape}, classes={len(classes)}")

    # 3) Train (or load) character classifier
    if os.path.isfile(MODEL_PATH) and os.path.isfile(CLASSES_PATH):
        with open(CLASSES_PATH, "rb") as f:
            classes_loaded = pickle.load(f)
        if classes_loaded != classes:
            print("[Warn] classes in disk differ from NPZ classes; overwriting with NPZ classes.")
            with open(CLASSES_PATH, "wb") as f:
                pickle.dump(classes, f)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = CharCNN(num_classes=len(classes)).to(device)
        if os.path.isfile(MODEL_PATH):
            model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
            print("[Info] Loaded existing model weights.")
    else:
        with open(CLASSES_PATH, "wb") as f:
            pickle.dump(classes, f)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = CharCNN(num_classes=len(classes)).to(device)

    # 4) Make loaders from NPZ and train with Lightning
    train_loader, val_loader, class_to_idx = make_loaders_from_npz(Xtr, ytr, Xva, yva, classes, batch_size=256)
    idx_to_class = {i: c for c, i in class_to_idx.items()}

    module = CharModule(num_classes=len(classes), lr=1e-3)
    # initialize with current model weights if present
    module.model.load_state_dict(model.state_dict(), strict=False)

    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer = pl.Trainer(max_epochs=40, accelerator=accelerator, devices=1, log_every_n_steps=20)
    trainer.fit(module, train_loader, val_loader)

    # Save weights and loss curve
    torch.save(module.model.state_dict(), MODEL_PATH)
    plt.figure(figsize=(6,4))
    plt.plot(range(1, len(module.train_epoch_losses)+1), module.train_epoch_losses, label="Train Loss")
    plt.plot(range(1, len(module.val_epoch_losses)+1), module.val_epoch_losses, label="Val Loss")
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Char CNN Loss")
    plt.legend(); plt.tight_layout(); plt.savefig(LOSS_PNG)
    print(f"[Saved] Model -> {MODEL_PATH}")
    print(f"[Saved] Classes -> {CLASSES_PATH}")
    print(f"[Saved] Loss curve -> {LOSS_PNG}")

    # Reload best state for eval (just use saved)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CharCNN(len(classes)).to(device)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

    # 5) Evaluate on validation and test (segment -> classify -> join)
    print("\n[Validation]")
    evaluate_on_words(VAL_CSV, VAL_IMG_DIR, model, idx_to_class, device)

    print("\n[Test]")
    evaluate_on_words(TEST_CSV, TEST_IMG_DIR, model, idx_to_class, device)

if __name__ == "__main__":
    main()


Building char dataset from written_name_train_v2.csv: 100%|██████████| 330294/330294 [33:08<00:00, 166.07it/s]
Building char dataset from written_name_validation_v2.csv: 100%|██████████| 41280/41280 [04:04<00:00, 169.07it/s]


[Saved] train_chars.npz  (X=(2105109, 1, 32, 32), y=2105109)
[Saved] val_chars.npz    (X=(263612, 1, 32, 32), y=263612)
[Saved] char_classes.pkl  (#classes=50): [' ', "'", '-', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '`', 'a', 'b', 'c', 'e', 'f', 'g', 'h', 'i', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'y', 'z']
[Glyphs] train X=(2105109, 1, 32, 32), val X=(263612, 1, 32, 32), classes=50


2025-10-09 19:03:40.010702: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760036620.197268      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760036620.248136      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[Saved] Model -> char_cnn_32x32.pt
[Saved] Classes -> char_classes.pkl
[Saved] Loss curve -> loss_curve.png

[Validation]


Eval written_name_validation_v2.csv: 100%|██████████| 41280/41280 [06:35<00:00, 104.29it/s]



--- Examples ---
GT: BILEL | PRED: BAILEL
GT: LAUMIONIER | PRED: LAUMONIER
GT: LEA | PRED: ALEAIEAA
GT: JEAN-ROCH | PRED: JEANEROCH
GT: RUPP | PRED: RUPP

--- Metrics ---
CER: 0.645707
1 - CER (Char Acc): 0.354293
ACC (Exact Match): 0.205547
WER: 0.794453

[Test]


Eval written_name_test_v2.csv: 100%|██████████| 41289/41289 [11:30<00:00, 59.82it/s]



--- Examples ---
GT: KEVIN | PRED: KAVIEN
GT: CLOTAIRE | PRED: LOEAIEE
GT: LENA | PRED: LENA
GT: JULES | PRED: TEULES
GT: CHERPIN | PRED: CHERPIIN

--- Metrics ---
CER: 0.655236
1 - CER (Char Acc): 0.344764
ACC (Exact Match): 0.194798
WER: 0.805202
