In [1]:
import os
import math
import string
import random
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image, ImageOps, ImageFilter
from sklearn.model_selection import train_test_split

In [2]:
# ----------------------------
# 1) Character set & CTC codec
# ----------------------------
def default_charset():
    # You can customize this to match your data (e.g., only lowercase + space)
    # Keep space ' ' included if your lines contain spaces.
    charset = list(string.digits + string.ascii_letters + string.punctuation + ' ')
    # Remove characters you know you don't have, or add accents if needed.
    return charset

class CTCCodec:
    """
    Maps characters <-> indices. Index 0 is reserved for CTC blank.
    """
    def __init__(self, charset: List[str]):
        self.blank_idx = 0
        self.chars = ['<BLK>'] + charset
        self.char2idx = {c: i+1 for i, c in enumerate(charset)}  # shift by +1
        self.idx2char = {i+1: c for i, c in enumerate(charset)}

    def encode(self, text: str) -> torch.Tensor:
        return torch.tensor([self.char2idx[c] for c in text if c in self.char2idx], dtype=torch.long)

    def decode_greedy(self, logits: torch.Tensor) -> List[str]:
        """
        logits: (T, N, C) log-probs or raw scores. We'll argmax over classes.
        Returns list of length N with collapsed CTC decoding.
        """
        with torch.no_grad():
            pred = logits.argmax(dim=-1)  # (T, N)
            pred = pred.cpu().numpy()
        N = pred.shape[1]
        texts = []
        for n in range(N):
            seq = pred[:, n]
            prev = -1
            out = []
            for idx in seq:
                if idx != self.blank_idx and idx != prev:
                    out.append(self.idx2char.get(int(idx), ''))
                prev = idx
            texts.append(''.join(out))
        return texts

In [3]:
# ----------------------------------
# 2) Image transforms & augmentations
# ----------------------------------
class KeepRatioResize:
    """
    Resize PIL image to target height with proportional width, no crop.
    """
    def __init__(self, target_h: int):
        self.target_h = target_h

    def __call__(self, img: Image.Image) -> Image.Image:
        w, h = img.size
        if h == self.target_h:
            return img
        new_w = max(1, round(w * (self.target_h / h)))
        return img.resize((new_w, self.target_h), Image.BILINEAR)

class ElasticLike:
    """
    Lightweight 'elastic' style warp using PIL perspective + slight blur/sharpen.
    Keeps text legible but varied.
    """
    def __init__(self, p=0.5, max_warp=0.08):
        self.p = p
        self.max_warp = max_warp

    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() > self.p:
            return img
        w, h = img.size
        dx = int(self.max_warp * w)
        dy = int(self.max_warp * h)
        # random offsets for corners
        src = [(0,0),(w,0),(w,h),(0,h)]
        img = img.transform((w + random.randint(-dx, dx), h + random.randint(-dx, dx)), Image.QUAD, src)
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.2, 0.6)))
        if random.random() < 0.3:
            img = img.filter(ImageFilter.UnsharpMask(radius=1.0, percent=80, threshold=3))
        return img

def pil_to_tensor_normalized(img: Image.Image) -> torch.Tensor:
    """
    Convert PIL (grayscale) -> Tensor in [0,1], normalize to mean=0.5, std=0.5
    Output shape: (1, H, W)
    """
    t = transforms.functional.pil_to_tensor(img).float() / 255.0  # (1,H,W) for 'L'
    return transforms.functional.normalize(t, mean=[0.5], std=[0.5])

def binarize_if_needed(img: Image.Image, p=0.0):
    if p > 0 and random.random() < p:
        return img.convert('L').point(lambda x: 255 if x > 200 else 0, mode='L')
    return img

In [4]:
# ------------------------
# 3) Dataset definitions
# ------------------------
class LinesFile(Dataset):
    """
    labels.txt format: path<TAB>text (UTF-8)
    Converts to grayscale, resizes to H=64 with proportional width.
    """
    def __init__(self, labels_file: str, codec: CTCCodec, target_h: int = 64, keep_aspect=True, binarize_p=0.0):
        super().__init__()
        self.samples = []
        self.folder_path = os.path.dirname(labels_file)
        with open(labels_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.rstrip('\n')
                if not line.strip():
                    continue
                parts = line.split('\t', 1)
                if len(parts) != 2:
                    continue
                path, text = parts
                self.samples.append((path, text))
        self.codec = codec
        self.target_h = target_h
        self.keep_aspect = keep_aspect
        self.resize = KeepRatioResize(target_h)
        self.binarize_p = binarize_p

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file, text = self.samples[idx]
        path = os.path.join(self.folder_path, file)
        img = Image.open(path).convert('L')
        img = ImageOps.exif_transpose(img)
        if self.keep_aspect:
            img = self.resize(img)
        img = binarize_if_needed(img, self.binarize_p)
        tensor = pil_to_tensor_normalized(img)
        label = self.codec.encode(text)
        return tensor, label, text, os.path.basename(path)

In [5]:
class AugmentedWrapper(Dataset):
    """
    Wrap a base dataset and apply strong augmentations.
    Use multiple instances of this wrapper to grow dataset size to >=5x.
    """
    def __init__(self, base: LinesFile):
        super().__init__()
        self.base = base
        # Compose handwriting-friendly augmentations
        self.resize = KeepRatioResize(base.target_h)
        self.aug = transforms.RandomChoice([
            transforms.RandomAffine(degrees=2, translate=(0.02, 0.03), scale=(0.95, 1.05), shear=(-2, 2), fill=255),
            transforms.RandomPerspective(distortion_scale=0.3, p=1.0),
        ])
        self.colorjitter = transforms.ColorJitter(brightness=0.2, contrast=0.2)
        self.elastic = ElasticLike(p=0.7, max_warp=0.06)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        tensor, label, text, name = self.base[idx]
        # back to PIL to apply augmentations that expect PIL
        H = tensor.shape[1]
        W = tensor.shape[2]
        pil = transforms.functional.to_pil_image(((tensor * 0.5 + 0.5) * 255.0).byte())  # unnormalize for aug
        # augment in PIL space
        pil = self.elastic(pil)
        pil = self.aug(pil)
        pil = self.colorjitter(pil)
        # Small random Gaussian blur helps mimic scanning
        if random.random() < 0.3:
            pil = pil.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.2, 0.7)))
        # Ensure size back to desired height (keeps ratio)
        pil = self.resize(pil)
        # Occasionally invert (handwriting scans vary)
        if random.random() < 0.25:
            pil = ImageOps.invert(pil)
        tensor_aug = pil_to_tensor_normalized(pil)
        return tensor_aug, label, text, name

In [6]:
# ---------------------------------------
# 4) Collate: pad widths & build lengths
# ---------------------------------------
@dataclass
class Batch:
    imgs: torch.Tensor        # (B, 1, H, Wmax)
    labels: torch.Tensor      # (sum_targets,)
    label_lengths: torch.Tensor  # (B,)
    input_lengths: torch.Tensor  # (B,)  number of time steps per sample after CNN
    texts: List[str]
    names: List[str]
    orig_widths: List[int]

class PadCollate:
    """
    Pads each batch to max width (also to multiple of 4) and prepares CTC lengths.
    """
    def __init__(self, multiple_of: int = 4, height: int = 64):
        self.multiple_of = multiple_of
        self.height = height

    def __call__(self, batch):
        # batch: list of (tensor(1,H,W), label, text, name)
        imgs, labels, texts, names, widths = [], [], [], [], []
        for t, lab, txt, name in batch:
            _, h, w = t.size()
            assert h == self.height, f"Expected height {self.height}, got {h}"
            imgs.append(t)
            labels.append(lab)
            texts.append(txt)
            names.append(name)
            widths.append(w)

        B = len(imgs)
        max_w = max(widths)
        # pad to next multiple of self.multiple_of (for CNN width downsampling)
        if self.multiple_of > 1:
            max_w = int(math.ceil(max_w / self.multiple_of) * self.multiple_of)

        padded = torch.full((B, 1, self.height, max_w), fill_value=(0.5 - 0.5)/0.5, dtype=imgs[0].dtype)
        # Explanation: because we normalized to mean=0.5, std=0.5,
        # "white" (1.0) becomes (1-0.5)/0.5 = +1.0, "gray 0.5" is 0; but to avoid halo,
        # we can pad with normalized value of 1.0 (white) -> +1.0:
        padded.fill_(+1.0)

        for i, t in enumerate(imgs):
            _, _, w = t.size()
            padded[i, :, :, :w] = t

        labels_concat = torch.cat(labels, dim=0)
        label_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)

        # We'll use a CNN that downsamples width by factor 4 -> input_lengths = ceil(w/4)
        input_lengths = torch.tensor([math.ceil(w / 4) for w in widths], dtype=torch.long)

        return Batch(
            imgs=padded,
            labels=labels_concat,
            label_lengths=label_lengths,
            input_lengths=input_lengths,
            texts=texts,
            names=names,
            orig_widths=widths
        )

In [7]:
# ---------------------------
# 5) CRNN model (CNN + BiLSTM)
# ---------------------------
class CRNN(nn.Module):
    """
    CNN reduces H and W (width by 4x overall), then we pool height to 1 and treat width as time.
    """
    def __init__(self, num_classes: int, in_channels=1):
        super().__init__()
        # VGG-ish feature extractor; keep it simple and efficient
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # H/2, W/2

            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # H/4, W/4

            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(inplace=True),
            # keep width stride=1 here

            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(inplace=True),
        )
        # Collapse height to 1 with adaptive pooling; width stays ~W/4 due to the two pools above.
        self.height_pool = nn.AdaptiveAvgPool2d((1, None))

        self.rnn = nn.LSTM(
            input_size=256, hidden_size=256, num_layers=2,
            bidirectional=True, dropout=0.1, batch_first=False
        )
        self.fc = nn.Linear(512, num_classes)  # 2*hidden

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,1,H,W)
        f = self.features(x)                        # (B, C=256, H', W')
        f = self.height_pool(f).squeeze(2)          # (B, C, W')
        f = f.permute(2, 0, 1)                      # (T=W', B, C)
        out, _ = self.rnn(f)                        # (T, B, 2*H)
        logits = self.fc(out)                       # (T, B, num_classes)
        return logits

In [8]:
# --------------------------
# 6) Metrics: CER / WER
# --------------------------
def levenshtein(a: List[str], b: List[str]) -> int:
    # Levenshtein distance for lists of tokens (chars or words)
    dp = [[0]*(len(b)+1) for _ in range(len(a)+1)]
    for i in range(len(a)+1):
        dp[i][0] = i
    for j in range(len(b)+1):
        dp[0][j] = j
    for i in range(1, len(a)+1):
        for j in range(1, len(b)+1):
            cost = 0 if a[i-1]==b[j-1] else 1
            dp[i][j] = min(dp[i-1][j]+1, dp[i][j-1]+1, dp[i-1][j-1]+cost)
    return dp[-1][-1]

def cer(ref: str, hyp: str) -> float:
    return levenshtein(list(ref), list(hyp)) / max(1, len(ref))

def wer(ref: str, hyp: str) -> float:
    return levenshtein(ref.split(), hyp.split()) / max(1, len(ref.split()))

In [9]:
# --------------------------
# 7) Training / Validation
# --------------------------
@dataclass
class TrainConfig:
    labels_file: str = os.path.join("images", "labels.txt")
    batch_size: int = 16
    epochs: int = 30
    lr: float = 1e-3
    num_workers: int = 4
    height: int = 64
    seed: int = 42
    aug_factor: int = 4   # original + 4× augmented = 5× total

def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def make_dataloaders(cfg: TrainConfig, codec: CTCCodec):
    # Load all samples once
    all_data = []
    with open(cfg.labels_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.rstrip('\n')
            if not line.strip():
                continue
            parts = line.split('\t', 1)
            if len(parts) != 2:
                continue
            all_data.append(line)

    train_lines, val_lines = train_test_split(all_data, test_size=0.1, random_state=cfg.seed, shuffle=True)

    # Save temporary split files
    train_file = os.path.join("images", "train_split.txt")
    val_file = os.path.join("images", "val_split.txt")
    with open(train_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(train_lines))
    with open(val_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(val_lines))

    # Create datasets
    base_train = LinesFile(train_file, codec, target_h=cfg.height, keep_aspect=True)
    aug_wrappers = [AugmentedWrapper(base_train) for _ in range(cfg.aug_factor)]
    train_set = ConcatDataset([base_train] + aug_wrappers)

    val_set = LinesFile(val_file, codec, target_h=cfg.height, keep_aspect=True)

    collate = PadCollate(multiple_of=4, height=cfg.height)
    train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)
    val_loader = DataLoader(val_set, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=cfg.num_workers, pin_memory=True, collate_fn=collate)
    return train_loader, val_loader

def train_one_epoch(model, loader, criterion, optimizer, device, codec: CTCCodec, log_interval=100):
    model.train()
    running_loss = 0.0
    for step, batch in enumerate(loader, 1):
        imgs = batch.imgs.to(device)
        labels = batch.labels.to(device)
        label_lengths = batch.label_lengths.to(device)
        input_lengths = batch.input_lengths.to(device)

        logits = model(imgs)  # (T, B, C)
        log_probs = logits.log_softmax(dim=-1)

        loss = criterion(log_probs, labels, input_lengths, label_lengths)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        running_loss += loss.item()
        if step % log_interval == 0:
            avg = running_loss / log_interval
            print(f"  step {step:5d} | train loss {avg:.4f}")
            running_loss = 0.0

def validate(model, loader, device, codec: CTCCodec):
    model.eval()
    total_loss = 0.0
    total_cer = 0.0
    total_wer = 0.0
    count = 0
    criterion = nn.CTCLoss(blank=codec.blank_idx, zero_infinity=True)

    with torch.no_grad():
        for batch in loader:
            imgs = batch.imgs.to(device)
            labels = batch.labels.to(device)
            label_lengths = batch.label_lengths.to(device)
            input_lengths = batch.input_lengths.to(device)

            logits = model(imgs)
            log_probs = logits.log_softmax(dim=-1)
            loss = criterion(log_probs, labels, input_lengths, label_lengths)
            total_loss += loss.item()

            # Greedy decode for metrics
            hyps = codec.decode_greedy(log_probs)
            for hyp, ref in zip(hyps, batch.texts):
                total_cer += cer(ref, hyp)
                total_wer += wer(ref, hyp)
                count += 1

    return {
        "loss": total_loss / max(1, len(loader)),
        "cer": total_cer / max(1, count),
        "wer": total_wer / max(1, count),
    }


In [15]:
def main():
    cfg = TrainConfig(
        labels_file=os.path.join("images", "labels.txt"),
        batch_size=16,
        epochs=30,
        lr=1e-3,
        num_workers=0,
        height=64,
        seed=42,
        aug_factor=5,
    )

    set_seed(cfg.seed)

    charset = default_charset()
    codec = CTCCodec(charset)
    train_loader, val_loader = make_dataloaders(cfg, codec)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 1 + len(charset)
    model = CRNN(num_classes=num_classes, in_channels=1).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=cfg.lr, steps_per_epoch=len(train_loader), epochs=cfg.epochs
    )
    criterion = nn.CTCLoss(blank=codec.blank_idx, zero_infinity=True)

    best_val = float("inf")
    for epoch in range(1, cfg.epochs + 1):
        print(f"\nEpoch {epoch}/{cfg.epochs}")
        train_one_epoch(model, train_loader, criterion, optimizer, device, codec, log_interval=100)
        metrics = validate(model, val_loader, device, codec)
        scheduler.step()

        print(f"  Val loss: {metrics['loss']:.4f} | CER: {metrics['cer']:.4f} | WER: {metrics['wer']:.4f}")

        if metrics['loss'] < best_val:
            best_val = metrics['loss']
            torch.save({
                "model": model.state_dict(),
                "codec_chars": codec.chars,
                "config": cfg.__dict__,
            }, "best_crnn_ctc.pth")
            print("  Saved checkpoint: best_crnn_ctc.pth")

if __name__ == "__main__":
    main()


Epoch 1/30




IndexError: list index out of range

**Inference code:**

In [11]:
# ---------- Load model checkpoint ----------
def load_model(checkpoint_path, device):
    ckpt = torch.load(checkpoint_path, map_location=device)
    codec = CTCCodec(ckpt["codec_chars"][1:])  # skip <BLK> token
    model = CRNN(num_classes=len(ckpt["codec_chars"]), in_channels=1).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model, codec

# ---------- Preprocess single image ----------
def preprocess_image(img_path, target_h=64):
    img = Image.open(img_path).convert("L")
    img = ImageOps.exif_transpose(img)
    resize = KeepRatioResize(target_h)
    img = resize(img)
    tensor = pil_to_tensor_normalized(img).unsqueeze(0)  # (1,1,H,W)
    return tensor, img.size  # (W,H)

# ---------- Decode prediction ----------
def predict(model, codec, img_tensor, device):
    with torch.no_grad():
        img_tensor = img_tensor.to(device)
        logits = model(img_tensor)         # (T, B, C)
        log_probs = logits.log_softmax(dim=-1)
        text = codec.decode_greedy(log_probs)[0]
    return text

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = "best_crnn_ctc.pth"
test_image = "20250625_095530.jpg"

print(f"Loading model from {model_path} ...")
model, codec = load_model(model_path, device)

print(f"Running inference on {test_image} ...")
img_tensor, (w, h) = preprocess_image(test_image, target_h=64)
pred_text = predict(model, codec, img_tensor, device)

print(f"\nPredicted text:\n{pred_text}")

Loading model from best_crnn_ctc.pth ...
Running inference on 20250625_095530.jpg ...

Predicted text:

