In [None]:
# LOCAL ENVIRONMENT SETUP 
!pip install torch torchvision torchaudio --quiet
!pip install pytorch-lightning albumentations scipy scikit-image matplotlib tqdm textdistance --quiet

import torch, cv2, numpy as np, textdistance

# Report the installed versions of the core libraries so that environment
# mismatches are visible before any other cell runs.
print("Torch version:", torch.__version__)
print("OpenCV version:", cv2.__version__)
print("NumPy version:", np.__version__)
print("TextDistance version:", textdistance.__version__)


In [None]:
import torch, cv2, numpy as np, textdistance

# Check GPU availability (the device name is only queried when CUDA is present)
print("GPU Available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))

# Test OpenCV image loading: draw text on a blank image, write it to
# "cv_test.jpg" in the current working directory and read it back.
import os
print("OpenCV test:")
blank = np.zeros((50, 100, 3), dtype=np.uint8)
cv2.putText(blank, "OK", (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2)
cv2.imwrite("cv_test.jpg", blank)
img = cv2.imread("cv_test.jpg")
print("Image shape:", img.shape)

# Test TextDistance Levenshtein distance (one substitution between the two words)
d = textdistance.levenshtein.distance("hello", "hallo")
print("TextDistance test: 'hello' vs 'hallo' =", d)


In [None]:
# LOCAL DATASET EXTRACTION
import os, zipfile

# Base path
# NOTE(review): "Your Path" is a placeholder and must be replaced with the
# folder that contains dataset/archive.zip before this cell can succeed.
BASE_DIR = r"Your Path"

# Path to ZIP dataset
ZIP_PATH = os.path.join(BASE_DIR, "dataset", "archive.zip")

# Where to extract the dataset
DATA_ROOT = os.path.join(BASE_DIR, "dataset", "extracted")

print("Using local dataset ZIP:", ZIP_PATH)

# Create extraction folder if not exists
os.makedirs(DATA_ROOT, exist_ok=True)

# Extract only if not already extracted (i.e. the target folder is still empty)
if not any(os.scandir(DATA_ROOT)):
    print("Extracting dataset... please wait.")
    with zipfile.ZipFile(ZIP_PATH, 'r') as z:
        z.extractall(DATA_ROOT)
    print("Extracted to:", DATA_ROOT)
else:
    print("Dataset already extracted at:", DATA_ROOT)

# Define the type of dataset you're using
DATA_TYPE = "line" 

# Verify contents: print only the first os.walk entry (the top-level folder)
print("\nTop-level structure:")
for root, dirs, files in os.walk(DATA_ROOT):
    print(f"{root} | subfolders: {dirs[:3]} | #files: {len(files)}")
    break

print("\nDATA_ROOT:", DATA_ROOT)
print("DATA_TYPE:", DATA_TYPE)


In [None]:
# Cell 3 - Inspect Local Dataset Structure
import os, glob, pprint

# Ensure paths from Cell 2 exist
# NOTE(review): the fallback below is a machine-specific absolute path; it is
# only used when this cell is run without the extraction cell having defined
# DATA_ROOT first.
if 'DATA_ROOT' not in globals():
    DATA_ROOT = r"E:\PYTHON\CV_Project2\dataset\extracted"   # fallback if not defined

if 'DATA_TYPE' not in globals():
    DATA_TYPE = "line"   

print("DATA_ROOT:", DATA_ROOT)
print("DATA_TYPE:", DATA_TYPE)

# Print only the first os.walk entry, i.e. the top-level folder itself.
print("\nTop-level tree:")
for root, dirs, files in os.walk(DATA_ROOT):
    print(f"{root}   subdirs: {dirs[:5]}   #files: {len(files)}")
    break  

# Locate the main dataset directories (searched recursively by folder name)
images_dirs = glob.glob(os.path.join(DATA_ROOT, "**", "Images"), recursive=True)
trans_dirs = glob.glob(os.path.join(DATA_ROOT, "**", "Transcriptions"), recursive=True)
sets_dirs = glob.glob(os.path.join(DATA_ROOT, "**", "Sets"), recursive=True)

print("\nDetected nodes:")
print(" Images dirs:", images_dirs[:3])
print(" Transcription dirs:", trans_dirs[:3])
print(" Sets dirs:", sets_dirs[:3])

# Show sample image and transcription filenames (first match of each kind,
# at most eight entries)
if images_dirs:
    sample_images = glob.glob(os.path.join(images_dirs[0], "*"))[:8]
    print("\nSample images:")
    for img in sample_images:
        print("  ", img)

if trans_dirs:
    sample_trans = glob.glob(os.path.join(trans_dirs[0], "*"))[:8]
    print("\nSample transcription files:")
    for txt in sample_trans:
        print("  ", txt)


In [None]:
# TRANSCRIPTION LOADER 

import os, re, glob
from typing import Dict

#  Utility functions 

def normalize_text(t: str) -> str:
    """Trim the string and collapse every run of whitespace into one space.

    Falsy input (``None`` or ``""``) yields the empty string.
    """
    return re.sub(r"\s+", " ", t.strip()) if t else ""

def _read_text_file(path: str) -> str:
    """Read a text file as UTF-8, falling back to Latin-1 for other encodings.

    The file is decoded strictly so that a file which is not valid UTF-8
    actually reaches the Latin-1 attempt instead of having its undecodable
    bytes silently dropped.

    Args:
        path: Path of the file to read.

    Returns:
        The decoded file content, or ``""`` when the file cannot be opened
        or read.
    """
    for enc in ("utf-8", "latin1"):
        try:
            with open(path, "r", encoding=enc) as fh:
                return fh.read()
        except UnicodeDecodeError:
            # Not valid in this encoding: try the next candidate.
            continue
        except OSError:
            # Missing or unreadable file: another encoding will not help.
            return ""
    return ""

def normalize_key(k: str) -> str:
    """
    Canonicalise a transcription key so it lines up with image filenames.

    The directory part and the extension are removed. When the remaining stem
    contains both "-" and "_", everything after the first "_" is left-padded
    with zeros to six characters:
       train2011-130_6 -> train2011-130_000006
    Any other stem is returned unchanged.
    """
    stem, _ext = os.path.splitext(os.path.basename(k))
    if "-" not in stem or "_" not in stem:
        return stem
    head, _sep, tail = stem.partition("_")
    return head + "_" + tail.zfill(6)

# Main Loader 

def load_transcriptions(data_root: str) -> Dict[str, str]:
    """Build a ``key -> transcription`` map from per-image ``.txt`` files.

    Every ``*.txt`` file found in any ``Transcriptions`` folder below
    ``data_root`` is read; its file name (passed through ``normalize_key``)
    becomes the key. Surrounding double quotes are removed and the text is
    passed through ``normalize_text``. Empty files and texts shorter than two
    characters are skipped.

    Args:
        data_root: Root folder that is searched recursively.

    Returns:
        Dictionary mapping normalized keys to cleaned transcription text.
    """
    trans_map = {}

    # Load ONLY per-image .txt files from Transcriptions folder
    trans_paths = glob.glob(
        os.path.join(data_root, "**", "Transcriptions", "*.txt"),
        recursive=True
    )

    print(f"Found {len(trans_paths)} transcription files in Transcriptions/ folder.")

    # Sorted order makes the result deterministic; if two files normalize to
    # the same key, the later one in sort order wins.
    for p in sorted(trans_paths):
        fname = os.path.basename(p).replace(".txt", "")
        key = normalize_key(fname)

        txt = _read_text_file(p).strip()
        if not txt:
            continue

        # Remove quotes "..."
        if txt.startswith('"') and txt.endswith('"'):
            txt = txt[1:-1]

        txt = normalize_text(txt)
        if len(txt) < 2:
            continue

        trans_map[key] = txt

    # We IGNORE mapping-style files (TrainLines.txt, TestLines.txt, etc.)
    #    because they DO NOT contain actual text -> only filenames.

    # Final cleanup of keys just to be safe (keys were already normalized above)
    trans_map = {normalize_key(k): v for k, v in trans_map.items()}

    print(f"Loaded {len(trans_map)} valid transcriptions.")
    print("Example keys:", list(trans_map.keys())[:8])
    return trans_map

# Assign global TRANS_MAP (requires DATA_ROOT from the extraction/inspection cells)
TRANS_MAP = load_transcriptions(DATA_ROOT)


In [None]:
# Spot-check one ground-truth entry; prints None when the key is not in TRANS_MAP.
key = "eval2011-0_000001"
print("GT:", TRANS_MAP.get(key))


In [None]:
# RWGD 

import numpy as np
import cv2
from scipy.interpolate import griddata

def RWGD_light(image, interval_at_h80=26, std_at_h80=1.7, target_h=80):
    """
    Random Warp Grid Distortion (RWGD) with linearly interpolated displacements.

    A coarse grid of control points is laid over the image, every control
    point is displaced by Gaussian noise drawn from the global NumPy RNG, the
    displacement field is interpolated linearly to every pixel, and the image
    is resampled with ``cv2.remap`` (bicubic, replicated border).

    Args:
        image: Single-channel 2-D image (H x W).
        interval_at_h80: Grid interval at height=80 (default: 26); scaled by
            ``H / target_h`` and never smaller than 4 pixels.
        std_at_h80: Standard deviation of the displacement at height=80
            (default: 1.7); scaled by ``H / target_h``.
        target_h: Reference height (default: 80)
    Returns:
        Warped image with the same shape as ``image``. Images with fewer than
        two rows or two columns are returned as an unmodified copy.
    """
    h, w = image.shape

    # Linear griddata triangulates the control points; a one-pixel-wide or
    # one-pixel-high image yields collinear points and the triangulation
    # fails, so such images are passed through unchanged.
    if h < 2 or w < 2:
        return image.copy()

    scale = h / target_h
    interval = max(4, int(round(interval_at_h80 * scale)))
    disp_std = float(std_at_h80 * scale)

    # Control-point coordinates; the last row/column is always included so
    # the grid spans the whole image.
    gx = list(range(0, w, interval))
    if gx[-1] != w - 1:
        gx.append(w - 1)
    gy = list(range(0, h, interval))
    if gy[-1] != h - 1:
        gy.append(h - 1)

    src_pts = np.array([[x, y] for y in gy for x in gx], dtype=np.float32)
    dst_pts = src_pts + np.random.normal(0, disp_std, src_pts.shape).astype(np.float32)

    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))

    # Linear interpolation of the x and y displacement to every pixel
    disp_xi = griddata(
        (src_pts[:, 0], src_pts[:, 1]),
        dst_pts[:, 0] - src_pts[:, 0],
        (grid_x, grid_y),
        method="linear",
        fill_value=0.0
    )
    disp_yi = griddata(
        (src_pts[:, 0], src_pts[:, 1]),
        dst_pts[:, 1] - src_pts[:, 1],
        (grid_x, grid_y),
        method="linear",
        fill_value=0.0
    )

    map_x = (grid_x + disp_xi).astype(np.float32)
    map_y = (grid_y + disp_yi).astype(np.float32)

    warped = cv2.remap(
        image, map_x, map_y, interpolation=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE
    )
    return warped


In [None]:
# Three-channel preprocessing 
import numpy as np
import cv2
from skimage.filters import threshold_local

def howe_binarize(img: np.ndarray) -> np.ndarray:
    """
    Approximate Howe adaptive binarization with a local threshold.

    The neighbourhood size grows with the image height (one third of it,
    limited to 15..55) and is forced to be odd, because ``threshold_local``
    rejects even block sizes. Without that, the default 80-pixel height
    (block 26) would always end up in the Otsu fallback below.

    Args:
        img: Grayscale image; values are clipped to [0, 255] and cast to uint8.

    Returns:
        uint8 image containing only 0 and 255.
    """
    img = np.clip(img, 0, 255).astype(np.uint8)
    h = img.shape[0]

    # Block size scaled with image height, kept between 15 and 55
    block = int(min(max(15, h // 3), 55))
    if block % 2 == 0:
        # Upper limit 55 is odd, so +1 never leaves the allowed range.
        block += 1
    offset = 10  # as per-paper

    try:
        T = threshold_local(img, block_size=block, offset=offset)
        out = (img > T).astype(np.uint8) * 255
    except Exception:
        # fallback: if threshold_local fails (rare on small images)
        _, out = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return out

def otsu_binarize(img: np.ndarray) -> np.ndarray:
    """
    Binarize with one global threshold chosen by Otsu's method.

    The input is clipped to [0, 255] and cast to uint8 before thresholding;
    the thresholded image from ``cv2.threshold`` is returned.
    """
    as_u8 = np.clip(img, 0, 255).astype(np.uint8)
    mode = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    return cv2.threshold(as_u8, 0, 255, mode)[1]

def make_three_channel(img: np.ndarray) -> np.ndarray:
    """
    Build an H x W x 3 CNN input from one image.

    Channel 0 is the grayscale image, channel 1 its Otsu binarization and
    channel 2 its Howe-style binarization. A 3-channel input is first
    converted from BGR to grayscale; values are clipped to [0, 255] and cast
    to uint8.
    """
    gray = img
    if gray.ndim == 3 and gray.shape[2] == 3:
        gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)
    gray = np.clip(gray, 0, 255).astype(np.uint8)

    planes = [gray, otsu_binarize(gray), howe_binarize(gray)]
    return np.stack(planes, axis=2)


In [None]:
# Class HandWritingDataset
import os
import cv2
import numpy as np
from torch.utils.data import Dataset

class HandwritingDataset(Dataset):
    """PyTorch dataset yielding ``(image, transcription, basename)`` triples.

    Images are returned as uint8 NumPy arrays: H x W in ``'gray'`` mode and
    H x W x 3 in ``'three'`` mode.
    """

    def __init__(self, image_paths, trans_map, height=80, augment_type='rwgd',
                 apply_pn=True, channel_mode='gray', dataset_type='line', seg_error_set=None):
        """
        PyTorch dataset for handwriting line/word recognition.
        
        Args:
            image_paths: list of image file paths
            trans_map: dict basename -> ground truth transcription
            height: target image height (default 80)
            augment_type: 'none' | 'rwgd' | 'simard' | 'affine'
            apply_pn: run profile_normalization before resizing
            channel_mode: 'gray' or 'three'
            dataset_type: 'line' or 'word'
            seg_error_set: optional set of basenames to skip (bad segmentations)
        """
        self.paths = image_paths
        self.trans = trans_map
        self.h = height
        self.augment_type = augment_type
        self.apply_pn = apply_pn
        self.channel_mode = channel_mode
        self.dataset_type = dataset_type
        self.seg_error_set = seg_error_set if seg_error_set is not None else set()

    def __len__(self):
        return len(self.paths)

    def _augment(self, img):
        """Apply the configured augmentation to one grayscale image."""
        if self.augment_type == 'rwgd':
            return RWGD_light(img)
        if self.augment_type == 'simard':
            return simard_wrapper(img)
        if self.augment_type == 'affine':
            return affine_rotate_shear(img)
        return img

    def __getitem__(self, idx):
        """Load, normalize, augment and label the sample at ``idx``.

        Returns:
            ``(image, gt, basename)``; ``gt`` is ``""`` when no transcription
            is found or the sample is flagged as a segmentation error.

        Raises:
            RuntimeError: if the image file cannot be read.
        """
        p = self.paths[idx]
        b = os.path.basename(p)
        three = self.channel_mode == 'three'

        # Skip bad segmentation: return a white placeholder whose rank matches
        # the channel mode, so it can be batched with regular samples.
        if b in self.seg_error_set:
            shape = (self.h, self.h, 3) if three else (self.h, self.h)
            return np.full(shape, 255, dtype=np.uint8), "", b

        # Load image
        img = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
        if img is None or img.size == 0:
            raise RuntimeError(f"Cannot read image: {p}")

        # Profile Normalization (falls through to the plain resize on failure)
        if self.apply_pn:
            try:
                img = profile_normalization(img, target_h=self.h)
            except Exception as e:
                print(f"PN failed for {b}: {e}")

        # Resize to the model height, scaling the width by the same factor so
        # the glyph aspect ratio is preserved.
        if img.shape[0] != self.h:
            new_w = max(1, int(round(img.shape[1] * self.h / img.shape[0])))
            img = cv2.resize(img, (new_w, self.h), interpolation=cv2.INTER_AREA)

        # Augment the grayscale image BEFORE building the channels: the warp is
        # random, so warping each channel separately would leave the three
        # channels geometrically misaligned.
        img = self._augment(img)

        # Channel construction
        img_ch = make_three_channel(img) if three else img

        # Robust ground-truth lookup: normalized key first, then raw basename
        key = normalize_key(os.path.splitext(b)[0])
        if key in self.trans:
            gt = self.trans[key]
        elif b in self.trans:
            gt = self.trans[b]
        else:
            gt = ""

        return img_ch, gt, b


In [None]:
# Locate the Images folder
import glob, os

# Prefer a folder literally named "Images" anywhere below DATA_ROOT (first match).
IMAGES_DIRS = glob.glob(os.path.join(DATA_ROOT, "**", "Images"), recursive=True)
if IMAGES_DIRS:
    IMAGES_DIR = IMAGES_DIRS[0]
else:
    # Fallback: take the first folder (in os.walk order) that directly
    # contains more than 20 image files.
    candidates = []
    for root, dirs, files in os.walk(DATA_ROOT):
        imgs = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff'))]
        if len(imgs) > 20:
            candidates.append(root)
    if candidates:
        IMAGES_DIR = candidates[0]
    else:
        raise RuntimeError(" No image folder found under DATA_ROOT. Please check extraction structure.")

print(f" Using IMAGES_DIR: {IMAGES_DIR}")


In [None]:
import string

def is_punctuation_only(text: str) -> bool:
    """
    Returns True if the text contains ONLY punctuation or whitespace.

    Empty and whitespace-only strings count as punctuation-only. Whitespace
    between punctuation marks (e.g. ``". ,"``) is ignored as well.
    """
    stripped = text.strip()
    if not stripped:
        return True
    return all(ch in string.punctuation or ch.isspace() for ch in stripped)


In [None]:
# DATASET SPLITTING FUNCTION

import os, glob, random

def build_splits(
    data_root, images_dir, trans_map, 
    test_filter_punct=True, 
    split=(0.8, 0.1, 0.1), 
    seed=42, 
    min_test_size=100
):
    """Shuffle the images deterministically and split into train/val/test.

    Args:
        data_root: Dataset root, searched recursively for segmentation-error
            flag files (``*seg*err*.*`` and ``*segmentation*.*``).
        images_dir: Folder whose image files (png/jpg/jpeg/tif/tiff) are split.
        trans_map: Normalized key -> transcription map used to filter the test set.
        test_filter_punct: When True, drop test images whose transcription is
            missing, shorter than two characters or punctuation-only, and
            images flagged as segmentation errors.
        split: Fractions; only the first two are used, the test set receives
            the remainder.
        seed: Seed of the private shuffle RNG.
        min_test_size: If filtering leaves fewer test images than this, the
            unfiltered test set is kept.

    Returns:
        ``(train, val, test, seg_error_set)``: three lists of image paths and
        the set of basenames read from the flag files.

    Raises:
        RuntimeError: if ``images_dir`` contains no image files.
    """
    # A private RNG gives the same permutation as seeding the global one with
    # `seed`, without resetting the random state of the rest of the notebook.
    rng = random.Random(seed)
    image_exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')

    #  Collect images
    images = sorted(
        p for p in glob.glob(os.path.join(images_dir, "*"))
        if p.lower().endswith(image_exts)
    )

    if not images:
        raise RuntimeError(f"No image files found in {images_dir}")

    #  Deterministic shuffle
    rng.shuffle(images)
    n = len(images)
    n1 = int(n * split[0])
    n2 = int(n * (split[0] + split[1]))

    train = images[:n1]
    val = images[n1:n2]
    test = images[n2:]

    #  Load segmentation error flags: first whitespace-separated token of
    #  every non-empty line, reduced to its basename.
    seg_error_set = set()
    possible_flag_files = glob.glob(os.path.join(data_root, "**", "*seg*err*.*"), recursive=True) + \
                          glob.glob(os.path.join(data_root, "**", "*segmentation*.*"), recursive=True)

    for ff in possible_flag_files:
        try:
            with open(ff, encoding='utf-8', errors='ignore') as f:
                for ln in f:
                    ln = ln.strip()
                    if ln:
                        seg_error_set.add(os.path.basename(ln.split()[0]))
        except OSError:
            # Best effort: a match that cannot be opened (e.g. a directory)
            # is skipped.
            continue

    #  Filter TEST set
    if test_filter_punct:
        test_filtered = []

        for p in test:
            b = os.path.basename(p)
            key = normalize_key(os.path.splitext(b)[0])
            gt = trans_map.get(key, "")

            # Remove punctuation-only or too-short GT
            if len(gt.strip()) < 2 or is_punctuation_only(gt):
                continue

            if b in seg_error_set:
                continue

            test_filtered.append(p)

        # Guarantee a usable test set
        if len(test_filtered) < min_test_size:
            print(f"Too few test samples ({len(test_filtered)}). Restoring unfiltered test set.")
        else:
            test = test_filtered

    print(f"Split complete: {len(train)} train, {len(val)} val, {len(test)} test | Seg errors: {len(seg_error_set)}")
    return train, val, test, seg_error_set


In [None]:
# Build the default 80/10/10 split (seed 42) with test-set filtering enabled.
train_list, val_list, test_list, seg_error_set = build_splits(
    DATA_ROOT, IMAGES_DIR, TRANS_MAP
)


In [None]:
# Ensure charset builder exists
def build_charset(trans_list):
    """Derive the character vocabulary from a collection of transcriptions.

    Non-string entries are ignored. Returns ``(stoi, itos)`` where ``itos`` is
    the sorted character list preceded by ``''`` at index 0 (the blank) and
    ``stoi`` is its inverse mapping.
    """
    vocab = sorted({ch for t in trans_list if isinstance(t, str) for ch in t})
    itos = [''] + vocab  # index 0 is reserved for the blank symbol
    stoi = {sym: idx for idx, sym in enumerate(itos)}
    return stoi, itos


In [None]:
# Cell 10 - CRNN model
import torch
import torch.nn as nn

class CRNN(nn.Module):
    """Convolutional-recurrent network emitting per-frame log-probabilities.

    Six 3x3 convolutions feed two stacked bidirectional LSTMs and a linear
    classifier. The output has shape [B, T, num_classes] (batch first).
    """

    def __init__(self, num_classes, in_channels=1, dropout=0.5):
        """
        Args:
            num_classes: size of the output layer.
            in_channels: number of input image channels.
            dropout: dropout probability applied after each LSTM.
        """
        super().__init__()

        # Channel widths: chs[i-1] -> chs[i] for conv layer i (1-based).
        chs = [in_channels, 64, 128, 256, 256, 512, 512]
        layers = []
        for i in range(1, len(chs)):
            layers.append(nn.Conv2d(chs[i-1], chs[i], kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))

            if i in [1, 2]:
                layers.append(nn.MaxPool2d(2, 2))         # downscale both H and W
            elif i in [4, 6]:
                layers.append(nn.MaxPool2d((2, 1), (2, 1)))  # downscale only H

            # Batch norm follows conv layers 4 and 5 (after the pool for layer 4).
            if i in [4, 5]:
                layers.append(nn.BatchNorm2d(chs[i]))

        self.cnn = nn.Sequential(*layers)
        #adaptive average pooling to flatten height to 1
        self.reduce_h = nn.AdaptiveAvgPool2d((1, None))

        #BiLSTMs (bidirectional, so each outputs twice its hidden size)
        self.rnn1 = nn.LSTM(512, 512, num_layers=1, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.rnn2 = nn.LSTM(1024, 256, num_layers=1, bidirectional=True, batch_first=True)

        #Output
        self.fc = nn.Linear(512, num_classes)
        self.logsoft = nn.LogSoftmax(dim=2)

    def forward(self, x):
        """Map an image batch [B, in_channels, H, W] to log-probs [B, T, num_classes]."""
        out = self.cnn(x)                   # [B, C, H, W]
        out = self.reduce_h(out)            # [B, C, 1, W]
        out = out.squeeze(2)                # [B, C, W]
        out = out.permute(0, 2, 1)          # [B, W, C]

        out, _ = self.rnn1(out)
        out = self.dropout(out)
        out, _ = self.rnn2(out)
        out = self.dropout(out)
        out = self.fc(out)
        out = self.logsoft(out)
        return out


# GPU setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Example test run: vocabulary from every loaded transcription
stoi, itos = build_charset(list(TRANS_MAP.values()))
num_classes = len(itos)

model = CRNN(num_classes=num_classes, in_channels=3, dropout=0.5).to(device)

# Test dummy input (shape check only; the model is left in training mode)
dummy = torch.randn(2, 3, 80, 400).to(device)
with torch.no_grad():
    out = model(dummy)

print("Model output shape:", out.shape)
# [batch=2, width=100, classes=num_classes]: the two 2x2 pools divide the
# 400-pixel width by 4; the remaining pools only reduce the height.


In [None]:
#Collate, encode/decode, lexicon-based scoring, evaluate with TTA
import os
import numpy as np
import torch
import torch.nn as nn
import textdistance                        
from collections import Counter


# Collation (pad/truncate widths)

def collate_batch(batch, max_width=1600):
    """Pad (or truncate) a list of samples into one float tensor.

    Images are inverted and scaled to [0, 1] (ink = 1, background = 0), so the
    zero padding on the right looks like background. Grayscale (H x W) and
    multi-channel (H x W x C) images may be mixed in one batch: a grayscale
    image is replicated across all channels.

    Args:
        batch: iterable of ``(image, gt, key)`` with uint8-range NumPy images.
        max_width: images wider than this are truncated on the right.

    Returns:
        ``(tensor, gts, keys)`` with ``tensor`` of shape [N, C, H, W], where
        H is the height of the first image and W the widest (clipped) width.
    """
    imgs, gts, keys = zip(*batch)

    # Give every image an explicit channel axis (H x W x C).
    arrays = [im if im.ndim == 3 else im[:, :, None] for im in imgs]

    C = max(a.shape[2] for a in arrays)
    H = arrays[0].shape[0]
    W = min(max(a.shape[1] for a in arrays), max_width)

    tensor = torch.zeros((len(arrays), C, H, W), dtype=torch.float32)
    for i, a in enumerate(arrays):
        h = min(a.shape[0], H)
        w = min(a.shape[1], W)
        norm = (255.0 - a[:h, :w, :]).astype(np.float32) / 255.0
        # (C, h, w); a single-channel image broadcasts over all C channels.
        chw = np.ascontiguousarray(norm.transpose(2, 0, 1))
        tensor[i, :, :h, :w] = torch.from_numpy(chw)
    return tensor, list(gts), list(keys)


# Text encoding/decoding helpers

def encode_text(s, stoi):
    """Encode a transcription as a list of class indices for CTC.

    The text is passed through ``normalize_text`` first. Characters missing
    from ``stoi`` are dropped instead of being mapped to index 0: index 0 is
    the CTC blank and must not appear inside a target sequence.
    """
    s = normalize_text(s)
    return [stoi[ch] for ch in s if stoi.get(ch, 0) != 0]

def greedy_decode(log_probs):
    """Best-path CTC decoding for a batch.

    Takes scores of shape [B, T, C], picks the arg-max class per frame,
    merges consecutive repeats and removes blanks (index 0). Returns one list
    of class indices per batch element.
    """
    best_paths = torch.argmax(log_probs, dim=2).cpu().numpy().tolist()
    decoded = []
    for path in best_paths:
        collapsed = []
        last = None
        for sym in path:
            if sym != 0 and sym != last:
                collapsed.append(sym)
            last = sym
        decoded.append(collapsed)
    return decoded

def indices_to_text(indices, itos):
    """Concatenate the symbols that ``itos`` holds at the given indices."""
    return ''.join(map(itos.__getitem__, indices))


# Distance / metrics (use textdistance)

def cer(ref, hyp):
    """Character error rate: Levenshtein distance divided by ``len(ref)``
    (a divisor of at least 1 is used, so an empty reference is safe)."""
    edits = textdistance.levenshtein.distance(ref, hyp)
    return edits / max(1, len(ref))

def wer(ref, hyp):
    """Word error rate over whitespace-separated tokens.

    With an empty reference the result is 0.0 for an empty hypothesis and
    1.0 otherwise.
    """
    ref_words = ref.split()
    hyp_words = hyp.split()
    if not ref_words:
        return 1.0 if hyp_words else 0.0

    edits = textdistance.levenshtein.distance(ref_words, hyp_words)
    return edits / max(1, len(ref_words))


# CTC loss and lexicon helpers
# Shared CTC criterion: class 0 is the blank; infinite losses are replaced by
# zero (zero_infinity=True).
ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)

def prune_lexicon_by_edit(greedy_txt, lexicon, max_dist=4):
    """Return lexicon words within ``max_dist`` edits of ``greedy_txt``.

    Words whose length differs by more than ``max_dist`` are rejected without
    computing the distance. The result is ordered by (distance, word).
    """
    target_len = len(greedy_txt)
    scored = []
    for word in lexicon:
        if abs(len(word) - target_len) <= max_dist:
            dist = textdistance.levenshtein.distance(greedy_txt, word)
            if dist <= max_dist:
                scored.append((dist, word))
    return [word for _, word in sorted(scored)]

def best_lexicon_ctc_choice(log_probs_tensor, lexicon, stoi, max_prune_dist=4):
    """Pick the lexicon word with the lowest CTC loss for one sample.

    The greedy (best-path) transcription is computed first and used to prune
    the lexicon to words within ``max_prune_dist`` edits; the surviving
    candidates are scored with ``ctc_loss_fn``.

    Args:
        log_probs_tensor: per-frame log-probabilities, [T, C] or [1, T, C]
            (only the first batch element of a 3-D tensor is used).
        lexicon: iterable of candidate words.
        stoi: character -> class index mapping (index 0 is the blank).
        max_prune_dist: maximum edit distance for candidates.

    Returns:
        ``(word, loss)`` for the best candidate, or ``(greedy_text, None)``
        when no candidate could be scored.
    """
    if log_probs_tensor.dim() == 3:
        logp = log_probs_tensor[0]
    else:
        logp = log_probs_tensor

    # Decode with the inverse of `stoi` so the function depends only on its
    # arguments, not on a notebook-global `itos`.
    idx_to_char = {idx: ch for ch, idx in stoi.items()}
    prev = None
    greedy_inds = []
    for a in torch.argmax(logp, dim=1).cpu().numpy().tolist():
        if a != prev and a != 0:
            greedy_inds.append(a)
        prev = a
    greedy_txt = ''.join(idx_to_char.get(i, '') for i in greedy_inds)

    candidates = prune_lexicon_by_edit(greedy_txt, lexicon, max_prune_dist)
    if not candidates:
        return greedy_txt, None

    best_w, best_loss = None, float('inf')
    T = logp.shape[0]
    input_len = torch.tensor([T], dtype=torch.long)
    # CTCLoss expects (T, N, C); N = 1 here.
    log_probs_tnc = logp.unsqueeze(1)

    for w in candidates:
        # A character outside the vocabulary would have to be encoded as the
        # blank, which is not a valid CTC target: skip such words.
        if any(stoi.get(ch, 0) == 0 for ch in w):
            continue
        # A word needs one frame per character plus one blank between equal
        # neighbours. Longer words cannot be aligned; their infinite loss
        # would be zeroed by zero_infinity and wrongly win the comparison.
        repeats = sum(1 for x, y in zip(w, w[1:]) if x == y)
        if len(w) + repeats > T:
            continue

        lab = torch.tensor([stoi[ch] for ch in w], dtype=torch.long)
        try:
            loss = ctc_loss_fn(log_probs_tnc, lab.unsqueeze(0), input_len,
                               torch.tensor([len(lab)], dtype=torch.long))
        except (RuntimeError, ValueError):
            # Best effort: a candidate that cannot be scored is skipped.
            continue
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_w = w

    if best_w is None:
        return greedy_txt, None
    return best_w, best_loss


# Evaluation with optional TTA and lexicon scoring

def _sample_to_pixels(sample, from_4d_batch):
    """Convert one batch element to a pixel-domain NumPy image (H x W or H x W x C)."""
    arr = sample.cpu().numpy()
    if not from_4d_batch:
        # Non-4D batches are passed through unchanged, as pixel-domain data.
        return arr
    pix = 255.0 - arr * 255.0            # undo the "inverted, 0..1" normalization
    return pix[0] if pix.shape[0] == 1 else pix.transpose(1, 2, 0)


def _rwgd_shared_warp(pix):
    """Apply one random RWGD warp to an image, identically on every channel."""
    u8 = np.clip(np.rint(pix), 0, 255).astype(np.uint8)
    if u8.ndim == 2:
        return RWGD_light(u8)
    # RWGD_light draws its displacement from the global NumPy RNG. Restoring
    # the same RNG state before each channel makes all channels share one
    # warp, so they stay geometrically aligned.
    rng_state = np.random.get_state()
    warped = np.empty_like(u8)
    for ch in range(u8.shape[2]):
        np.random.set_state(rng_state)
        warped[:, :, ch] = RWGD_light(np.ascontiguousarray(u8[:, :, ch]))
    return warped


def _pixels_to_input(pix, device):
    """Pixel-domain image -> normalized model input of shape [1, C, H, W]."""
    x = torch.from_numpy(np.ascontiguousarray((255.0 - pix).astype(np.float32) / 255.0))
    # permute (not transpose): Tensor.transpose only swaps two dimensions.
    x = x[None, None] if x.dim() == 2 else x.permute(2, 0, 1)[None]
    return x.to(device)


def _majority_greedy(variants, itos):
    """Greedy-decode every variant and return the most frequent text."""
    preds = [indices_to_text(greedy_decode(v)[0], itos) for v in variants]
    return Counter(preds).most_common(1)[0][0]


def evaluate_model(model, dataloader, stoi, itos, device, dataset_type='line',
                   lexicon=None, test_time_N=1, use_rwgd_on_test=True):
    """Compute mean CER and WER of ``model`` over ``dataloader``.

    Every sample is evaluated on its own. With test-time augmentation the
    sample is warped ``test_time_N`` times by RWGD (only when
    ``use_rwgd_on_test`` is set and the dataset's ``augment_type`` is
    ``'rwgd'``) and the predictions are combined: by lowest lexicon CTC loss
    for word datasets with a lexicon, otherwise by majority vote over the
    greedy decodings.

    Args:
        model: network returning [1, T, C] scores for a [1, C, H, W] input.
        dataloader: yields ``(imgs, gts, keys)`` batches.
        stoi, itos: vocabulary mappings.
        device: device the model lives on.
        dataset_type: 'line' or 'word'.
        lexicon: optional word list for lexicon-constrained decoding.
        test_time_N: number of variants per sample (values below 1 are
            treated as 1).
        use_rwgd_on_test: enable RWGD test-time augmentation.

    Returns:
        ``(mean_cer, mean_wer)``; both 0.0 for an empty dataloader.
    """
    model.eval()
    apply_rwgd = bool(use_rwgd_on_test) and \
        getattr(dataloader.dataset, "augment_type", "") == 'rwgd'
    n_variants = max(1, int(test_time_N))
    use_lexicon = lexicon is not None and dataset_type == 'word'

    total_cer = 0.0
    total_wer = 0.0
    n = 0
    with torch.no_grad():
        for imgs, gts, keys in dataloader:
            is_4d = imgs.dim() == 4
            for i in range(imgs.size(0)):
                pix = _sample_to_pixels(imgs[i], is_4d)

                # Run TTA variants
                variants = []
                for _ in range(n_variants):
                    v = _rwgd_shared_warp(pix) if apply_rwgd else pix
                    out = model(_pixels_to_input(v, device))  # 1 x T x C
                    variants.append(out.cpu())

                # Decide final prediction
                final_text = None
                if use_lexicon:
                    best_loss = float('inf')
                    for v_out in variants:
                        cand, loss = best_lexicon_ctc_choice(v_out, lexicon, stoi, max_prune_dist=4)
                        if loss is not None and loss < best_loss:
                            best_loss = loss
                            final_text = cand
                if final_text is None:
                    final_text = _majority_greedy(variants, itos)

                ref = normalize_text(gts[i])
                total_cer += cer(ref, final_text)
                total_wer += wer(ref, final_text)
                n += 1

    return total_cer / max(1, n), total_wer / max(1, n)


In [None]:
import os

non_empty = 0
empty = 0

# Count how many of the first 200 training images have a transcription.
# The key is built with normalize_key, exactly like the loader and the
# dataset classes do, so the numbers reflect what training will actually see.
for p in train_list[:200]:
    b = os.path.basename(p)
    key = normalize_key(os.path.splitext(b)[0])
    gt = TRANS_MAP.get(key, "")
    if gt.strip():
        non_empty += 1
    else:
        empty += 1

print(f" Non-empty transcriptions: {non_empty}")
print(f"Empty transcriptions: {empty}")


In [None]:
#  Profile Normalization 
import numpy as np, cv2, math, gc

def compute_horizontal_profile_std(img: np.ndarray) -> float:
    """Standard deviation of the ink distribution over the image rows.

    The image is inverted (ink = high values), summed per row, and the row
    index is treated as a random variable weighted by that profile.

    Args:
        img: 2-D grayscale image with white (255) background.

    Returns:
        The weighted standard deviation in pixels, or 1.0 for an image
        without any ink (the profile is all zero, so no spread is defined).
    """
    inv = 255 - img
    profile = inv.sum(axis=1).astype(np.float64)
    total = float(profile.sum())
    # The blank-image check must look at the raw sum: once an epsilon has
    # been added the sum can never compare as "zero" again.
    if total <= 0.0:
        return 1.0
    idx = np.arange(len(profile), dtype=np.float64)
    mean = (profile * idx).sum() / total
    var = (profile * (idx - mean) ** 2).sum() / total
    return float(math.sqrt(var + 1e-8))

def profile_normalization(img: np.ndarray,
                          target_h: int = 80,
                          r: float = 1.75,
                          ref_baseline: float = 16.0) -> np.ndarray:
    """
    Lightweight profile normalization:
      - Scales the image uniformly (height and width) by
        ``ref_baseline / (sigma * r)``, where ``sigma`` is the standard
        deviation of the horizontal projection profile
      - Centers the ink's vertical center of mass on a white canvas of
        height ``target_h`` (rows falling outside the canvas are cropped)
      - Falls back to a plain resize to ``target_h`` if anything fails

    Args:
        img: 2-D grayscale image with white (255) background.
        target_h: height of the returned canvas.
        r: divisor applied to the measured profile spread.
        ref_baseline: reference spread the text is scaled towards.

    Returns:
        uint8 image of height ``target_h``. An image without any ink is
        returned as a white canvas of the original width.
    """
    try:
        # A blank image has no measurable text height; scaling it by the
        # degenerate profile spread would request an enormous resize.
        if not (255 - img).any():
            return np.full((target_h, img.shape[1]), 255, dtype=np.uint8)

        sigma = compute_horizontal_profile_std(img)
        sigma = max(sigma, 1e-3)
        s = ref_baseline / (sigma * r)

        new_h = max(8, int(round(img.shape[0] * s)))
        new_w = max(1, int(round(img.shape[1] * s)))

        resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

        # Vertical center of mass of the ink in the resized image
        inv = 255 - resized
        total = inv.sum() + 1e-8
        rows = np.arange(resized.shape[0], dtype=np.float32)
        r_mean = float((inv.sum(axis=1) * rows).sum() / total)

        canvas = np.full((target_h, resized.shape[1]), 255, dtype=np.uint8)
        target_r = (target_h - 1) / 2.0
        shift = int(round(target_r - r_mean))

        # Destination rows [y0, y1) on the canvas and matching source rows
        y0 = max(0, shift)
        y1 = min(target_h, shift + resized.shape[0])
        src0 = max(0, -shift)
        src1 = src0 + (y1 - y0)

        if y0 < y1 and src0 < src1:
            canvas[y0:y1, :] = resized[src0:src1, :]

        # No explicit gc.collect() here: the temporaries are released when
        # the function returns, and a full collection per sample only slows
        # data loading down.
        return canvas

    except Exception as e:
        # fallback to simple resize
        print(f"PN fallback: {e}")
        return cv2.resize(img, (max(1, int(img.shape[1])), target_h),
                          interpolation=cv2.INTER_AREA)

# Confirmation that this cell ran and the helpers above are defined.
print("profile_normalization ready.")


In [None]:

# TINY TRAINING PIPELINE (FAST - 40 samples)


import os, cv2, torch, glob, random
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# 1. SMALL DATASET CREATION


# We assume TRANS_MAP, train_list, val_list, test_list already exist
# (created by the transcription-loader and split cells).
# Just pick small subsets: the first entries of each already-shuffled list.

small_train = train_list[:40]     # 40 training images
small_val   = val_list[:5]        # 5 validation images
small_test  = test_list[:5]       # 5 test images

print("Using tiny dataset:")
print("Train:", len(small_train))
print("Val:", len(small_val))
print("Test:", len(small_test))



# 2. Minimal Normalization


def normalize_text(t):
    """Trim a transcription and collapse internal whitespace runs.

    This redefinition replaces the ``normalize_text`` of the transcription
    loader cell for the rest of the kernel session, so it produces the same
    result for strings (strip + single spaces); helpers defined earlier keep
    their behaviour after this cell has run. Non-string input yields ``""``.
    """
    return " ".join(t.split()) if isinstance(t, str) else ""


# 3. Tiny Dataset Loader (simple, no caching needed)


class TinyDataset(Dataset):
    """Minimal dataset: grayscale image resized to a fixed height plus its text.

    Each item is ``(image, gt, key)`` where ``image`` is a float32 array of
    shape (1, H, W) in [0, 1] with ink close to 1, ``gt`` the transcription
    (``""`` when none is found) and ``key`` the file name without extension.
    """

    def __init__(self, paths, trans_map, height=80):
        self.paths = paths
        self.trans = trans_map
        self.h = height

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        key = os.path.splitext(os.path.basename(p))[0]

        # Load grayscale
        img = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise RuntimeError(f"Cannot load image: {p}")

        # simple height resize, keeping the aspect ratio; the width is kept
        # at >= 1 pixel because cv2.resize rejects a zero-sized target
        scale = self.h / img.shape[0]
        new_w = max(1, int(img.shape[1] * scale))
        img = cv2.resize(img, (new_w, self.h))

        # Normalize to 0..1, invert (text = white)
        img = (255 - img).astype(np.float32) / 255.0
        img = img[None, :, :]  # (1, H, W)

        # The transcription map is keyed by normalize_key(); look the
        # normalized key up first and fall back to the raw file stem.
        gt = self.trans.get(normalize_key(key), self.trans.get(key, ""))
        return img, gt, key



# 4. Collate Function


def collate_batch(batch):
    """Right-pad (1, H, W) arrays with zeros to the widest sample in the batch.

    Returns ``(padded, gts, keys)`` where ``padded`` has shape [N, 1, H, W_max]
    and H is taken from the first sample.
    """
    imgs, gts, keys = zip(*batch)
    widths = [im.shape[2] for im in imgs]
    padded = torch.zeros((len(imgs), 1, imgs[0].shape[1], max(widths)))

    for row, (im, w) in enumerate(zip(imgs, widths)):
        padded[row, :, :, :w] = torch.tensor(im)

    return padded, list(gts), list(keys)



# 5. Build Charset (stoi, itos)


def build_charset(texts):
    """Build ``(stoi, itos)`` from the characters occurring in ``texts``.

    ``itos`` is the sorted character list with ``''`` (blank) at index 0;
    ``stoi`` maps each symbol back to its index.
    """
    vocab = sorted({ch for t in texts for ch in t})
    itos = [''] + vocab
    stoi = {sym: idx for idx, sym in enumerate(itos)}
    return stoi, itos

# Vocabulary from the tiny training subset only; images without an entry in
# TRANS_MAP contribute an empty string.
# NOTE(review): the lookup uses the raw file stem, while TRANS_MAP keys were
# produced by normalize_key() - confirm both agree for this dataset.
train_texts = [normalize_text(TRANS_MAP.get(os.path.splitext(os.path.basename(p))[0], "")) 
               for p in small_train]
stoi, itos = build_charset(train_texts)
num_classes = len(itos)



# 6. CRNN Model (Simplified)


class SimpleCRNN(nn.Module):
    """Small CNN + bidirectional LSTM emitting per-frame log-probabilities.

    The recurrent layer is sized for 80-pixel-high single-channel input: the
    one 2x2 pooling step leaves 40 rows of 128 channels per time step.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
        )
        self.rnn = nn.LSTM(128*40, 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, num_classes)
        self.logsoft = nn.LogSoftmax(dim=2)

    def forward(self, x):
        """Map [B, 1, 80, W] images to log-probabilities of shape [B, W/2, num_classes]."""
        out = self.cnn(x)                 # [B,128,40,W/2]
        B,C,H,W = out.shape
        out = out.permute(0,3,1,2)        # [B,W/2,C,H]
        out = out.reshape(B, W, C*H)      # [B,T,features]
        out,_ = self.rnn(out)
        out = self.fc(out)
        # CTCLoss expects log-probabilities, not raw scores. The arg-max used
        # for greedy decoding is unaffected by this normalization.
        out = self.logsoft(out)
        return out



# 7. Encoding + Greedy Decoding


def encode_text(t, stoi):
    """Encode ``t`` as a list of class indices for CTC.

    Characters that are not in ``stoi`` are dropped: mapping them to index 0
    would insert the blank symbol into the target sequence, which CTC does
    not allow.
    """
    return [stoi[c] for c in t if stoi.get(c, 0) != 0]

def greedy_decode(logits):
    """Best-path CTC decode of the FIRST sample in a ``[B, T, C]`` batch.

    Takes the arg-max class per frame, collapses consecutive repeats, removes
    the blank (index 0) and maps the remaining indices through the
    module-level ``itos`` table.
    """
    best = torch.argmax(logits, dim=2)[0].cpu().tolist()
    kept = [
        cur
        for before, cur in zip([None] + best[:-1], best)
        if cur != before and cur != 0
    ]
    return ''.join(itos[k] for k in kept)



# 8. Train for 2 epochs (FAST)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleCRNN(num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Index 0 is the CTC blank (itos[0] == '').
ctc = nn.CTCLoss(blank=0, zero_infinity=True)

train_ds = TinyDataset(small_train, TRANS_MAP)
val_ds   = TinyDataset(small_val, TRANS_MAP)
test_ds  = TinyDataset(small_test, TRANS_MAP)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_batch)
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_batch)

for epoch in range(2):
    model.train()
    total_loss = 0

    for imgs, gts, _ in train_loader:
        # Encode first and remember WHICH samples have a usable target, so the
        # network output can be filtered to the same rows. Without this the
        # CTC loss sees B inputs but fewer target lengths and raises.
        targets = []
        lengths = []
        keep = []
        for i, gt in enumerate(gts):
            enc = encode_text(gt, stoi)
            if enc:
                keep.append(i)
                targets.extend(enc)
                lengths.append(len(enc))

        if not keep:
            continue

        imgs = imgs.to(device)
        logits = model(imgs)[keep]
        B,T,C = logits.shape
        # CTC needs log-probabilities. log_softmax is idempotent, so this is
        # harmless if the model already normalises its output.
        log_probs = logits.log_softmax(dim=2)

        targets = torch.tensor(targets, dtype=torch.long).to(device)
        input_len = torch.full((B,), T, dtype=torch.long)
        target_len = torch.tensor(lengths, dtype=torch.long)

        loss = ctc(log_probs.permute(1,0,2), targets, input_len, target_len)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Reported value is the SUM of batch losses over the epoch.
    print(f"Epoch {epoch+1} Loss: {total_loss:.4f}")



# 9. Evaluate on tiny test set


def CER(a, b):
    """Character error rate: Levenshtein distance of ``a`` vs ``b`` over ``len(a)`` (at least 1)."""
    import textdistance
    edits = textdistance.levenshtein.distance(a, b)
    ref_len = max(1, len(a))
    return edits / ref_len

def WER(a, b):
    """Word error rate: Levenshtein distance over whitespace tokens, divided by the token count of ``a`` (at least 1)."""
    import textdistance
    ref_words, hyp_words = a.split(), b.split()
    edits = textdistance.levenshtein.distance(ref_words, hyp_words)
    return edits / max(1, len(ref_words))

cer_total, wer_total, count = 0,0,0

print("\nPredictions on tiny test set:\n")

model.eval()
with torch.no_grad():
    # test_loader uses batch_size=1, so index 0 is the only sample per batch
    # (greedy_decode also decodes only the first row).
    for imgs, gts, keys in test_loader:
        imgs = imgs.to(device)
        out = model(imgs)
        pred = greedy_decode(out)

        gt = gts[0]

        print("Image:", keys[0])
        print("GT:   ", gt)
        print("Pred: ", pred)
        print("-"*40)

        cer_total += CER(gt, pred)
        wer_total += WER(gt, pred)
        count += 1

# max(1, count) keeps an empty test split from raising ZeroDivisionError.
print("\nFINAL TINY RESULTS:")
print("CER:", cer_total / max(1, count))
print("WER:", wer_total / max(1, count))


In [None]:
#  FINAL TRAINING SCRIPT 

import os, re, glob, gc, math, time
from typing import Dict
import random

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch import amp


#  PROFILE NORMALIZATION 

def compute_horizontal_profile_std(img: np.ndarray) -> float:
    """Ink-weighted standard deviation of the row index of a grayscale image.

    The image is inverted (``255 - img``) so dark pixels carry weight, summed
    per row, and the resulting profile is treated as a distribution over row
    indices.

    Args:
        img: 2-D grayscale image (dark text on light background).

    Returns:
        The standard deviation in rows, or ``1.0`` when the image holds no ink
        at all (a neutral value so callers do not divide by ~0).
    """
    inv = 255 - img
    profile = inv.sum(axis=1)
    ink = profile.sum()
    # Test the raw ink mass BEFORE adding the epsilon; checking afterwards can
    # never trigger, which made blank images return ~1e-4 instead of 1.0.
    if ink <= 0:
        return 1.0
    total = ink + 1e-8
    idx = np.arange(len(profile), dtype=np.float32)
    mean = (profile * idx).sum() / total
    var = (profile * (idx - mean) ** 2).sum() / total
    return float(math.sqrt(var + 1e-8))

def profile_normalization(img: np.ndarray, target_h=80, r=1.75, ref_baseline=16.0):
    """Rescale a line image by its ink spread and centre it on a fixed-height canvas.

    Steps:
      1. Measure the vertical ink spread ``sigma`` with
         ``compute_horizontal_profile_std``.
      2. Resize the whole image by ``s = ref_baseline / (sigma * r)``.
      3. Paste it onto a white ``target_h``-row canvas, shifted so the ink
         centroid row lands on the canvas centre; rows falling outside the
         canvas are cropped.

    Args:
        img: 2-D grayscale image (dark text on light background).
        target_h: height of the returned canvas.
        r: divisor applied to ``sigma`` when computing the scale.
        ref_baseline: numerator of the scale factor.

    Returns:
        ``uint8`` array of shape ``(target_h, new_w)``. If any step raises, a
        message is printed and the input is returned resized to ``target_h``
        rows with its width unchanged.
    """
    try:
        sigma = compute_horizontal_profile_std(img)
        sigma = max(sigma, 1e-3)
        s = ref_baseline / (sigma * r)
        new_h = max(8, int(round(img.shape[0] * s)))
        new_w = max(1, int(round(img.shape[1] * s)))
        resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

        # Ink centroid (row) of the resized image.
        inv = 255 - resized
        total = inv.sum() + 1e-8
        rows = np.arange(resized.shape[0], dtype=np.float32)
        r_mean = float((inv.sum(axis=1) * rows).sum() / total)

        canvas = np.full((target_h, resized.shape[1]), 255, dtype=np.uint8)
        target_r = (target_h - 1) / 2.0
        shift = int(round(target_r - r_mean))
        # Destination rows on the canvas and the matching source rows.
        y0, y1 = max(0, shift), min(target_h, shift + resized.shape[0])
        src0, src1 = max(0, -shift), max(0, -shift) + (y1 - y0)

        if y0 < y1 and src0 < src1:
            canvas[y0:y1, :] = resized[src0:src1, :]

        # No explicit del/gc.collect() here: the temporaries are released when
        # the function returns, and forcing a full collection per image is a
        # significant, unnecessary cost in the data-loading hot path.
        return canvas
    except Exception as e:
        # Deliberate best-effort: never let one bad image abort data loading.
        print(f"PN fallback: {e}")
        return cv2.resize(img, (max(1, int(img.shape[1])), target_h), interpolation=cv2.INTER_AREA)

print("Profile Normalization ready.")


#  TRANSCRIPTION LOADER  

def normalize_text(t):
    """Trim ``t`` and collapse every whitespace run to one space; ``None`` becomes ``""``."""
    return "" if t is None else re.sub(r"\s+", " ", t.strip())

def load_transcriptions(data_root: str) -> Dict[str, str]:
    """
    Load only the per-image transcription files from Transcriptions/ folder.
    This avoids reading Set files (TrainLines/TestLines) that contain filenames only.

    Args:
        data_root: directory searched recursively for ``Transcriptions/*.txt``
            (falling back to any folder whose name contains ``Transcriptions``).

    Returns:
        Dict mapping each file's base name (no extension) to its normalised
        text. A file that cannot be read is reported and stored as ``""``.
    """
    trans_map = {}
    trans_files = glob.glob(os.path.join(data_root, "**", "Transcriptions", "*.txt"), recursive=True)
    if not trans_files:
        # fallback: sometimes transcriptions are saved in a top-level Transcriptions folder name variant
        trans_files = glob.glob(os.path.join(data_root, "**", "*Transcriptions*", "*.txt"), recursive=True)

    for p in sorted(trans_files):
        try:
            # Context manager closes the handle deterministically; the old
            # bare open().read() leaked one descriptor per file.
            with open(p, "r", encoding="utf-8", errors="ignore") as fh:
                txt = fh.read().strip()
        except Exception as exc:
            # Still best-effort, but no longer silent.
            print(f"Warning: could not read transcription {p}: {exc}")
            txt = ""
        base = os.path.splitext(os.path.basename(p))[0]
        # remove surrounding quotes if any
        if txt.startswith('"') and txt.endswith('"'):
            txt = txt[1:-1]
        trans_map[base] = normalize_text(txt)
    print(f"Loaded {len(trans_map)} valid transcription files from Transcriptions/.")
    return trans_map

#  Ensure DATA_ROOT 
# Fall back to a default location when no earlier cell defined DATA_ROOT.
if "DATA_ROOT" not in globals():
    DATA_ROOT = r"E:\PYTHON\CV_Project2\dataset\extracted"
    print("DATA_ROOT not found â€” using default:", DATA_ROOT)

TRANS_MAP = load_transcriptions(DATA_ROOT)

# Quick sanity report: how many transcriptions actually contain text.
non_empty = len([v for v in TRANS_MAP.values() if v.strip()])
empty = len(TRANS_MAP) - non_empty
print(f"TRANS_MAP: {len(TRANS_MAP)} entries | Non-empty: {non_empty} | Empty: {empty}")


#  DATASET & CACHE CLASS 

class HandwritingDataset(Dataset):
    """Image/transcription dataset with optional on-disk caching of preprocessed images.

    Each item is ``(image, text, key)`` where ``image`` is a 2-D ``uint8``
    array of height ``height``, ``text`` is the transcription looked up by the
    file's base name (``""`` if missing) and ``key`` is that base name.

    Args:
        image_paths: list of image file paths.
        trans_map: dict mapping base name -> transcription.
        height: target image height in pixels.
        apply_pn: run ``profile_normalization`` before the final resize.
        cache_dir: directory for cached ``.npy`` arrays; a falsy value
            (``None`` / ``""``) disables caching.
        seg_error_set: keys to replace with a blank white placeholder and an
            empty transcription.
    """

    def __init__(self, image_paths, trans_map, height=80, apply_pn=True, cache_dir="cache_png", seg_error_set=None):
        self.paths = image_paths
        self.trans = trans_map
        self.h = height
        self.apply_pn = apply_pn
        self.cache_dir = cache_dir
        self.seg_error_set = seg_error_set if seg_error_set else set()
        if self.cache_dir:
            os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.paths)

    def _cache_path(self, key):
        """Return the cache file for ``key`` under the current settings, or ``None`` if caching is off.

        Height and the PN flag are part of the file name so datasets built with
        different preprocessing never read each other's cached arrays.
        """
        if not self.cache_dir:
            return None
        fname = f"{key}_h{self.h}_pn{int(bool(self.apply_pn))}.npy"
        return os.path.join(self.cache_dir, fname)

    def _preprocess(self, path):
        """Read ``path`` as grayscale, optionally profile-normalise, and force the height to ``self.h``."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise RuntimeError(f"Cannot read image: {path}")
        if self.apply_pn:
            img = profile_normalization(img, target_h=self.h)
        # Width is kept as-is; only the row count is forced to self.h.
        return cv2.resize(img, (max(1, int(img.shape[1])), self.h), interpolation=cv2.INTER_AREA)

    def __getitem__(self, idx):
        p = self.paths[idx]
        key = os.path.splitext(os.path.basename(p))[0]

        if key in self.seg_error_set:
            # Known-bad segmentation: white square, no target text.
            return np.ones((self.h, self.h), np.uint8) * 255, "", key

        cache_path = self._cache_path(key)
        if cache_path is not None and os.path.exists(cache_path):
            img = np.load(cache_path)
        else:
            img = self._preprocess(p)
            if cache_path is not None:
                np.save(cache_path, img)
        gt = self.trans.get(key, "")
        return img, gt, key

print("HandwritingDataset ready (cached).")


#  COLLATE & HELPERS  

def collate_batch(batch, max_width=1600):
    """Invert, scale to [0, 1] and zero-pad a list of grayscale images into one tensor.

    Args:
        batch: sequence of ``(image, text, key)`` with ``image`` indexed as
            ``(H, W)``; the first sample's height sets the batch height.
        max_width: images wider than this are cropped on the right.

    Returns:
        ``(tensor, texts, keys)`` with ``tensor`` of shape ``(B, 1, H, W)`` and
        dtype ``float32``. Ink is high (``(255 - pixel) / 255``), padding is 0.
    """
    images, texts, names = zip(*batch)
    height = images[0].shape[0]
    width = min(max(arr.shape[1] for arr in images), max_width)
    out = torch.zeros((len(images), 1, height, width), dtype=torch.float32)
    for slot, arr in enumerate(images):
        crop = arr[:, :min(width, arr.shape[1])]
        ink = (255.0 - crop).astype(np.float32) / 255.0
        out[slot, 0, :ink.shape[0], :ink.shape[1]] = torch.tensor(ink)
    return out, list(texts), list(names)

def encode_text(s, stoi):
    """Normalise ``s`` and convert it to a list of class indices for CTC.

    Characters that are not in ``stoi`` are dropped. Encoding them as 0 would
    insert the CTC blank into the target sequence, which ``nn.CTCLoss`` does
    not permit (blank is index 0 in this notebook).
    """
    s = normalize_text(s)
    return [stoi[ch] for ch in s if ch in stoi]

def greedy_decode(log_probs):  # log_probs: B x T x C
    """Best-path CTC decoding for a whole batch.

    For every sample, takes the arg-max class per frame, merges consecutive
    repeats and drops the blank (index 0).

    Returns:
        A list with one list of class indices per sample.
    """
    best_paths = torch.argmax(log_probs, dim=2).cpu().numpy().tolist()
    decoded = []
    for path in best_paths:
        decoded.append([
            cur
            for before, cur in zip([None] + path[:-1], path)
            if cur != before and cur != 0
        ])
    return decoded

def indices_to_text(indices, itos):
    """Concatenate the ``itos`` entries selected by ``indices`` into one string."""
    pieces = []
    for i in indices:
        pieces.append(itos[i])
    return ''.join(pieces)

# Levenshtein-based metrics 
import textdistance
def cer(ref, hyp):
    """Character error rate: Levenshtein distance divided by ``len(ref)`` (at least 1)."""
    edits = textdistance.levenshtein.distance(ref, hyp)
    ref_len = max(1, len(ref))
    return edits / ref_len

def wer(ref, hyp):
    """Word error rate over whitespace-separated tokens.

    With an empty reference the result is ``0.0`` when the hypothesis is also
    empty and ``1.0`` otherwise; in all other cases it is the token-level
    Levenshtein distance divided by the reference token count.
    """
    ref_words = ref.split()
    hyp_words = hyp.split()
    if not ref_words:
        return 1.0 if hyp_words else 0.0
    return textdistance.levenshtein.distance(ref_words, hyp_words) / max(1, len(ref_words))

# simple charset builder
def build_charset(trans_list):
    """Collect the sorted character vocabulary of ``trans_list``.

    Non-string entries are ignored. Index 0 is reserved for the empty string,
    which the rest of the pipeline uses as the CTC blank.

    Returns:
        ``(stoi, itos)`` — char-to-index dict and index-to-char list.
    """
    alphabet = set()
    for line in trans_list:
        if isinstance(line, str):
            alphabet.update(line)
    itos = [''] + sorted(alphabet)
    stoi = {ch: pos for pos, ch in enumerate(itos)}
    return stoi, itos


# CRNN model 

class CRNN(nn.Module):
    """Convolutional-recurrent recogniser producing per-frame log-probabilities.

    Six 3x3 conv + ReLU stages (64-128-256-256-512-512 channels). Stages 1-2
    are followed by 2x2 max-pooling, stages 4 and 6 by height-only (2,1)
    pooling, and stages 4-5 end with BatchNorm. The remaining height is
    averaged away, then two bidirectional LSTMs (with dropout after each) feed
    a linear classifier.

    Args:
        num_classes: output alphabet size.
        in_channels: number of input image channels.
        dropout: dropout probability applied after each LSTM.
    """

    def __init__(self, num_classes, in_channels=1, dropout=0.5):
        super().__init__()
        widths = [in_channels, 64, 128, 256, 256, 512, 512]
        square_pool_after = (1, 2)   # halve height and width
        height_pool_after = (4, 6)   # halve height only
        batchnorm_after = (4, 5)

        stack = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            stack += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
            if stage in square_pool_after:
                stack.append(nn.MaxPool2d(2,2))
            elif stage in height_pool_after:
                stack.append(nn.MaxPool2d((2,1),(2,1)))
            if stage in batchnorm_after:
                stack.append(nn.BatchNorm2d(c_out))
        self.cnn = nn.Sequential(*stack)
        self.reduce_h = nn.AdaptiveAvgPool2d((1, None))
        self.rnn1 = nn.LSTM(512, 512, num_layers=1, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.rnn2 = nn.LSTM(1024, 256, num_layers=1, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, num_classes)
        self.logsoft = nn.LogSoftmax(dim=2)

    def forward(self, x):
        """Map images ``[B, in_channels, H, W]`` to log-probabilities ``[B, W/4, num_classes]``."""
        feats = self.reduce_h(self.cnn(x))           # [B,C,1,W']
        seq = feats.squeeze(2).permute(0, 2, 1)      # [B,W',C]
        seq, _ = self.rnn1(seq)
        seq, _ = self.rnn2(self.dropout(seq))
        scores = self.fc(self.dropout(seq))
        return self.logsoft(scores)


#  CTC loss & eval helpers 

# Shared CTC criterion used by train_one_epoch. Class index 0 is the blank,
# matching build_charset, which reserves itos[0] == '' for it.
ctc_loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)

def evaluate_model(model, dataloader, stoi, itos, device, dataset_type='line',
                   lexicon=None, test_time_N=1, use_rwgd_on_test=False):
    """Compute mean CER and WER of ``model`` over ``dataloader`` with greedy decoding.

    The batch tensor from the loader is fed to the model unchanged, i.e. with
    the same polarity and scaling the model sees during training. The previous
    implementation converted each sample back to 0-255 and re-scaled it
    without re-inverting, so evaluation ran on ``1 - x``.

    Args:
        model: network returning ``[B, T, C]`` scores.
        dataloader: yields ``(images, texts, keys)`` batches.
        stoi: unused here; kept for signature compatibility.
        itos: index-to-character table used to build hypotheses.
        device: device the inputs are moved to.
        dataset_type, lexicon, test_time_N, use_rwgd_on_test: accepted for
            compatibility with existing callers; currently unused.

    Returns:
        ``(mean_cer, mean_wer)``; both are ``0.0`` for an empty loader.
    """
    model.eval()
    total_cer = 0.0
    total_wer = 0.0
    n = 0
    with torch.no_grad():
        for imgs, gts, keys in dataloader:
            out = model(imgs.to(device))
            # greedy_decode returns one index list per sample in the batch.
            decoded = greedy_decode(out.cpu())
            for inds, gt in zip(decoded, gts):
                hyp = indices_to_text(inds, itos)
                ref = normalize_text(gt)
                total_cer += cer(ref, hyp)
                total_wer += wer(ref, hyp)
                n += 1
    return total_cer / max(1, n), total_wer / max(1, n)


# === TRAIN/EVAL LOOPS 

def train_one_epoch(model, loader, optimizer, stoi, device, scaler, log_interval=200):
    """Run one optimisation pass over ``loader`` with CTC loss and (on CUDA) mixed precision.

    Samples whose encoded transcription is empty are excluded from the loss;
    a batch with no usable sample is skipped after the forward pass.
    Gradients are unscaled and clipped to norm 5.0 before each step.

    Args:
        model: network returning ``[B, T, C]`` log-probabilities.
        loader: yields ``(images, texts, keys)`` batches.
        optimizer: optimiser stepped through ``scaler``.
        stoi: char-to-index mapping passed to ``encode_text``.
        device: target device; autocast is enabled only when it is CUDA.
        scaler: gradient scaler used for backward/step/update.
        log_interval: print the batch loss every this many batches.

    Returns:
        Mean loss over the batches that contributed a step (0.0 if none).
    """
    model.train()
    running_loss = 0.0
    steps = 0
    use_amp = (device.type=='cuda')
    for step, (imgs, gts, _) in enumerate(loader, start=1):
        imgs = imgs.to(device, non_blocking=True)
        with amp.autocast(device_type='cuda', enabled=use_amp):
            out = model(imgs)             # B x T x C
            B, T, C = out.shape
            encoded = [encode_text(gt, stoi) for gt in gts]
            valid_idx = [i for i, lab in enumerate(encoded) if len(lab) > 0]
            if not valid_idx:
                continue
            flat_labels = [sym for i in valid_idx for sym in encoded[i]]
            label_lens = [len(encoded[i]) for i in valid_idx]
            targets = torch.tensor(flat_labels, dtype=torch.long, device=device)
            # Every sample is given the full padded sequence length T.
            input_lens = torch.full((len(valid_idx),), T, dtype=torch.long, device=device)
            target_lens = torch.tensor(label_lens, dtype=torch.long, device=device)
            loss = ctc_loss_fn(out[valid_idx].permute(1,0,2), targets, input_lens, target_lens)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        steps += 1
        if step % log_interval == 0:
            print(f"  Batch {step}: loss={batch_loss:.4f}")
        if step % 50 == 0:
            if device.type == "cuda":
                torch.cuda.empty_cache()
            gc.collect()
    return running_loss / max(1, steps)


# RUN EXPERIMENT
def run_experiment_timed(augment_type='none', epochs=10, batch_size=8,
                         use_pn=True, channel_mode='gray', fast_mode=False,
                         patience=3, min_epochs=4):
    """Train a CRNN with early stopping and report test CER/WER of the best checkpoint.

    Reads the module-level ``train_list``, ``val_list``, ``test_list``,
    ``TRANS_MAP`` and ``DATA_ROOT``; ``seg_error_set`` is used when defined.
    ``DATA_TYPE`` is auto-detected from ``DATA_ROOT`` if it is not set yet.

    Args:
        augment_type: only echoed in the banner; no augmentation is applied here.
        epochs: maximum number of training epochs.
        batch_size: training batch size (validation uses 4, test uses 1).
        use_pn: apply profile normalisation in the datasets.
        channel_mode: ``'gray'`` builds a 1-channel model, anything else 3.
        fast_mode: train/validate/test on small prefixes (300/100/50 items).
        patience: stop after this many epochs without validation improvement.
        min_epochs: never stop before this many epochs have run.

    Returns:
        ``(model, (test_cer, test_wer), (stoi, itos))``. ``model`` carries the
        weights of the best validation epoch.
    """
    print(f"\n=== Experiment: augment={augment_type}, channel={channel_mode} ===")
    global train_list, val_list, test_list, DATA_TYPE
    if "DATA_TYPE" not in globals():
        DATA_TYPE = "word" if any("word" in d.lower() for d in os.listdir(DATA_ROOT)) else "line"
        print(f" Auto-detected DATA_TYPE = '{DATA_TYPE}'")
    # Optional global: fall back to "no known segmentation errors" instead of
    # raising NameError when an earlier cell did not define it.
    seg_errors = globals().get("seg_error_set")
    if fast_mode:
        train_subset, val_subset, test_subset = train_list[:300], val_list[:100], test_list[:50]
        print(" Fast-mode enabled (small subsets).")
    else:
        train_subset, val_subset, test_subset = train_list, val_list, test_list
        print(f" Using full dataset: {len(train_subset)} train / {len(val_subset)} val / {len(test_subset)} test")

    train_ds = HandwritingDataset(train_subset, TRANS_MAP, height=80, apply_pn=use_pn, cache_dir="cache_png", seg_error_set=seg_errors)
    val_ds = HandwritingDataset(val_subset, TRANS_MAP, height=80, apply_pn=use_pn, cache_dir="cache_png")
    test_ds = HandwritingDataset(test_subset, TRANS_MAP, height=80, apply_pn=use_pn, cache_dir="cache_png")

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_batch, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_batch, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collate_batch, num_workers=0)

    # Vocabulary is derived from the training transcriptions only.
    train_texts = [TRANS_MAP.get(os.path.splitext(os.path.basename(p))[0], "") for p in train_subset]
    stoi, itos = build_charset(train_texts)
    num_classes = len(itos)
    in_channels = 1 if channel_mode=='gray' else 3

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CRNN(num_classes=num_classes, in_channels=in_channels).to(device).to(torch.float32)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # `verbose` is deprecated on LR schedulers; LR changes are reported
    # explicitly after scheduler.step() below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    scaler = amp.GradScaler(enabled=(device.type=='cuda'))

    print(f" Device: {device} | Model params: {sum(p.numel() for p in model.parameters())/1e6:.2f} M")

    best_val = float('inf')
    best_epoch = -1
    best_path = None
    for ep in range(epochs):
        print(f"\n--- Epoch {ep+1}/{epochs} ---")
        t0 = time.time()
        loss = train_one_epoch(model, train_loader, optimizer, stoi, device, scaler)
        t_elapsed = time.time() - t0
        print(f"Epoch {ep+1} done in {t_elapsed:.1f}s | Avg loss: {loss:.4f}")

        cer_v, wer_v = evaluate_model(model, val_loader, stoi, itos, device, dataset_type=DATA_TYPE)
        print(f" Validation: CER={cer_v:.4f}, WER={wer_v:.4f}")

        # Model-selection metric: the SUM of validation CER and WER.
        val_metric = cer_v + wer_v

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_metric)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"  Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        # checkpoint best
        if val_metric < best_val:
            best_val = val_metric
            best_epoch = ep
            best_path = f"crnn_best_epoch{ep+1}.pth"
            torch.save(model.state_dict(), best_path)
            print("  Saved best model.")

        # early stop logic (ensure at least min_epochs)
        if (ep - best_epoch) >= patience and ep+1 >= min_epochs:
            print(f"Early stopping triggered (no improvement for {patience} epochs).")
            break

    # Evaluate the best validation checkpoint, not whatever weights the last
    # (possibly worse) epoch left behind.
    if best_path is not None:
        model.load_state_dict(torch.load(best_path, map_location=device))
        print(f"Restored best weights from epoch {best_epoch+1} ({best_path}).")

    # final test
    cer_t, wer_t = evaluate_model(model, test_loader, stoi, itos, device, dataset_type=DATA_TYPE)
    print(f"\n=== Final Test ===\nTest CER={cer_t:.4f}, WER={wer_t:.4f}")
    return model, (cer_t, wer_t), (stoi, itos)


# RUN with tuned params 
# Launch the full-dataset run: grayscale input, profile normalisation on,
# at most 10 epochs with early stopping (patience 3, not before epoch 4).
model, metrics, voc = run_experiment_timed(
    epochs=10, batch_size=8, patience=3, min_epochs=4,
    augment_type='none', channel_mode='gray', use_pn=True,
    fast_mode=False,         # full dataset
)
print("\nTraining complete â€” metrics:", metrics)
