In [None]:
import os
import math
import string
import random
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image, ImageOps, ImageFilter
from sklearn.model_selection import train_test_split

In [2]:
# ----------------------------
# 1) Character set & CTC codec
# ----------------------------
def default_charset():
    """Return the default alphabet as a fresh list of single characters.

    The list holds the digits, the ASCII letters, the ASCII punctuation
    marks and finally one space, in that order. Edit the pieces below to
    fit your data: drop characters that never occur, add accented letters
    if needed, and keep the space only if your text lines contain spaces.
    """
    pieces = (string.digits, string.ascii_letters, string.punctuation, ' ')
    return [symbol for piece in pieces for symbol in piece]

class CTCCodec:
    """
    Two-way lookup between characters and class indices for CTC.

    Class index 0 is set aside for the CTC blank; the entries of ``charset``
    take indices 1..len(charset) in the order they are given.
    """
    def __init__(self, charset: List[str]):
        self.blank_idx = 0
        # Full class table with a placeholder for the blank in slot 0.
        self.chars = ['<BLK>'] + charset
        self.char2idx = {}
        self.idx2char = {}
        # Start counting at 1 so that index 0 stays free for the blank.
        for position, symbol in enumerate(charset, start=1):
            self.char2idx[symbol] = position
            self.idx2char[position] = symbol

    def encode(self, text: str) -> torch.Tensor:
        """
        Turn ``text`` into a 1-D long tensor of class indices.

        Characters that are not in the charset are skipped silently.
        """
        table = self.char2idx
        indices = []
        for symbol in text:
            if symbol in table:
                indices.append(table[symbol])
        return torch.tensor(indices, dtype=torch.long)

    def decode_greedy(self, logits: torch.Tensor) -> List[str]:
        """
        Greedy (best-path) CTC decoding.

        logits: (T, N, C) raw scores or log-probs; the top-scoring class is
        picked independently at every time step.
        Returns a list of N strings in which consecutive repeated labels are
        merged and blanks are removed.
        """
        with torch.no_grad():
            best_path = logits.argmax(dim=-1).cpu().numpy()  # (T, N)

        decoded = []
        for column in range(best_path.shape[1]):
            pieces = []
            last = -1  # never equals a real class index
            for label in best_path[:, column]:
                repeated = label == last
                last = label
                if label == self.blank_idx or repeated:
                    continue
                # Indices without a character mapping decode to nothing.
                pieces.append(self.idx2char.get(int(label), ''))
            decoded.append(''.join(pieces))
        return decoded