In [None]:
import os
import math
import string
import random
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image, ImageOps, ImageFilter
from sklearn.model_selection import train_test_split

In [2]:
# ----------------------------
# 1) Character set & CTC codec
# ----------------------------
def default_charset():
    """
    Build the default list of symbols the recognizer can emit.

    Order: the ten digits, ASCII letters (lowercase then uppercase), ASCII
    punctuation, and finally a single space — one character per list entry.
    Customize this to match your data (e.g. drop characters that never occur,
    or add accented letters). Keep ' ' if your text lines contain spaces.
    """
    pieces = (string.digits, string.ascii_letters, string.punctuation, ' ')
    return [ch for piece in pieces for ch in piece]

class CTCCodec:
    """
    Character <-> class-index mapping for CTC training and decoding.

    Index 0 is reserved for the CTC blank; ``charset[i]`` is assigned index
    ``i + 1``. ``chars`` lists every class label in index order, starting
    with the '<BLK>' placeholder.
    """
    def __init__(self, charset: List[str]):
        self.blank_idx = 0
        self.chars = ['<BLK>'] + charset
        self.char2idx = {}
        self.idx2char = {}
        for cls_id, ch in enumerate(charset, start=1):
            self.char2idx[ch] = cls_id
            self.idx2char[cls_id] = ch

    def encode(self, text: str) -> torch.Tensor:
        """
        Convert ``text`` to a 1-D ``torch.long`` tensor of class indices.

        Characters that are not in the charset are dropped without warning.
        """
        known = (self.char2idx[ch] for ch in text if ch in self.char2idx)
        return torch.tensor(list(known), dtype=torch.long)

    def decode_greedy(self, logits: torch.Tensor) -> List[str]:
        """
        Greedy (best-path) CTC decoding.

        logits: (T, N, C) log-probabilities or raw scores; only the argmax over
        the last (class) axis is used.

        Returns a list of N strings. In each column, consecutive repeats of the
        same index are merged and blanks are removed; a blank between two equal
        indices keeps both. Indices with no character mapping decode to ''.
        """
        with torch.no_grad():
            best_path = logits.argmax(dim=-1).cpu().numpy()  # (T, N)
        results = []
        for col in range(best_path.shape[1]):
            column = best_path[:, col]
            # Pair every step with its predecessor (-1 before the first step).
            shifted = [-1, *column[:-1]]
            results.append(''.join(
                self.idx2char.get(int(cur), '')
                for cur, before in zip(column, shifted)
                if cur != self.blank_idx and cur != before
            ))
        return results

In [None]:
# ----------------------------------
# 2) Image transforms & augmentations
# ----------------------------------
class KeepRatioResize:
    """
    Scale a PIL image so its height equals ``target_h``, preserving aspect ratio.

    The width is scaled by the same factor, rounded, and clamped to at least 1 px.
    Nothing is cropped or padded. Images already at the target height are
    returned unchanged (same object).
    """
    def __init__(self, target_h: int):
        self.target_h = target_h

    def __call__(self, img: Image.Image) -> Image.Image:
        width, height = img.size
        if height != self.target_h:
            scale = self.target_h / height
            img = img.resize((max(1, round(width * scale)), self.target_h), Image.BILINEAR)
        return img

class ElasticLike:
    """
    Lightweight 'elastic'-style augmentation.

    Applies a random perspective warp (PIL ``QUAD`` transform with independently
    jittered source corners). With some probability it then applies a slight
    Gaussian blur and/or an unsharp mask. Keeps text legible but varied.
    The output has the same size as the input.

    Args:
        p: probability of applying the augmentation at all.
        max_warp: maximum corner displacement, as a fraction of the image width
            (horizontal jitter) and height (vertical jitter).
        fillcolor: colour for output pixels that map outside the source image.
            ``None`` keeps PIL's default fill. NOTE(review): for dark text on
            a light background you probably want 255 (mode 'L') — confirm
            against your data.
    """
    def __init__(self, p=0.5, max_warp=0.08, fillcolor=None):
        self.p = p
        self.max_warp = max_warp
        self.fillcolor = fillcolor

    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p:
            return img
        w, h = img.size
        dx = int(self.max_warp * w)
        dy = int(self.max_warp * h)

        def jx(x):
            # Horizontal jitter of one corner coordinate.
            return x + random.randint(-dx, dx)

        def jy(y):
            # Vertical jitter of one corner coordinate.
            return y + random.randint(-dy, dy)

        # PIL's QUAD method takes a flat 8-tuple of source corners ordered
        # upper-left, lower-left, lower-right, upper-right. The (previous) list
        # of point tuples in NW, NE, SE, SW order made PIL raise on every call.
        quad = (jx(0), jy(0), jx(0), jy(h), jx(w), jy(h), jx(w), jy(0))
        img = img.transform((w, h), Image.QUAD, quad, fillcolor=self.fillcolor)
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.2, 0.6)))
        if random.random() < 0.3:
            img = img.filter(ImageFilter.UnsharpMask(radius=1.0, percent=80, threshold=3))
        return img

def pil_to_tensor_normalized(img: Image.Image) -> torch.Tensor:
    """
    Turn a PIL image into a normalised float tensor.

    Pixel values are divided by 255 and then normalised per channel with
    mean 0.5 and std 0.5, so 8-bit data lands in [-1, 1].
    For a single-band ('L') image the output shape is (1, H, W).
    """
    F = transforms.functional
    scaled = F.pil_to_tensor(img).float().div(255.0)
    return F.normalize(scaled, mean=[0.5], std=[0.5])

def binarize_if_needed(img: "Image.Image", p=0.0, threshold: int = 200):
    """
    With probability ``p``, convert ``img`` to grayscale and hard-threshold it.

    Pixels strictly brighter than ``threshold`` become 255 and all others 0;
    the binarized result is mode 'L'. When the augmentation does not fire
    (always the case for the default ``p=0.0``), the input image is returned
    unchanged, in its original mode.

    Args:
        img: PIL image.
        p: probability of binarizing; ``p <= 0`` disables it.
        threshold: grayscale cut-off in 0-255 (default 200).
    """
    if p > 0 and random.random() < p:
        return img.convert('L').point(lambda x: 255 if x > threshold else 0, mode='L')
    return img

In [None]:
# ------------------------
# 3) Dataset definitions
# ------------------------
class LinesFile(Dataset):
    """
    Text-line image dataset described by a labels file.

    Labels file format: one sample per line, ``path<TAB>text``, UTF-8 (a leading
    BOM is tolerated). Only the first tab splits, so the transcription may itself
    contain tabs. Blank lines and lines without a tab are skipped. Trailing
    ``\\n`` / ``\\r\\n`` line endings are stripped so CRLF files do not leave
    ``'\\r'`` in the transcription.

    Images are EXIF-orientation corrected and converted to grayscale ('L').
    When ``keep_aspect`` is True they are resized to height ``target_h``
    (default 64) with proportional width; when it is False no resizing is done.
    Optionally, each image is randomly binarized with probability ``binarize_p``.

    Each item is ``(image_tensor, label, text, basename)``:
        image_tensor: output of ``pil_to_tensor_normalized`` (1, H, W) float.
        label: ``codec.encode(text)`` — 1-D long tensor of class indices.
        text: the raw transcription string from the labels file.
        basename: ``os.path.basename(path)``.
    Image paths are used exactly as written in the labels file.
    NOTE(review): relative paths resolve against the current working directory,
    not the labels file's directory — confirm this matches your layout.
    """
    def __init__(self, file_path: str, codec: CTCCodec, target_h: int = 64, keep_aspect=True, binarize_p=0.0):
        super().__init__()
        self.samples = []
        # 'utf-8-sig' behaves exactly like 'utf-8' but drops a leading BOM,
        # which would otherwise be glued onto the first image path.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            for line in f:
                line = line.rstrip('\r\n')
                if not line.strip():
                    continue
                parts = line.split('\t', 1)
                if len(parts) != 2:
                    continue
                path, text = parts
                self.samples.append((path, text))
        self.codec = codec
        self.target_h = target_h
        self.keep_aspect = keep_aspect
        self.resize = KeepRatioResize(target_h)
        self.binarize_p = binarize_p

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, text = self.samples[idx]
        # Close the file handle deterministically. Unclosed handles accumulate
        # across DataLoader workers. Orientation is fixed from the source image's
        # EXIF before the grayscale conversion produces a fresh, fully-loaded image.
        with Image.open(path) as src:
            img = ImageOps.exif_transpose(src).convert('L')
        if self.keep_aspect:
            img = self.resize(img)
        img = binarize_if_needed(img, self.binarize_p)
        tensor = pil_to_tensor_normalized(img)
        label = self.codec.encode(text)
        return tensor, label, text, os.path.basename(path)