# Text Recognition Model Training

Train a scene text recognition model (CRNN or TrOCR); this notebook implements CRNN + BiLSTM + CTC on ICDAR training data.


In [None]:

#  CRNN + BiLSTM + CTC

import os
import random
from pathlib import Path
from collections import Counter
from tqdm import tqdm

import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:
# Input geometry: every sample is resized to H x W single-channel pixels.
H = 32
W = 128

# Dataset locations (ICDAR 'ch4_' folder layout on Google Drive).
ICDAR_ROOT = Path('/content/drive/MyDrive/ICDAR_update')
_TRAIN_SPLIT = ICDAR_ROOT / 'Train'
TRAIN_IMAGES = _TRAIN_SPLIT / 'ch4_training_images'
TRAIN_GT = _TRAIN_SPLIT / 'ch4_training_localization_transcription_gt'

# Checkpoint directory; created up front so checkpoint saving can assume it exists.
MODELS_DIR = Path('/content/drive/MyDrive/models_crnn_ctc')
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Optimisation / loader settings.
BATCH_SIZE = 32
NUM_WORKERS = 2
NUM_EPOCHS = 20
LR = 2e-4
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Transcriptions made only of '#' are placeholders: when filtering is on they
# are dropped except for a random KEEP_RATE_FOR_PLACEHOLDER fraction.
FILTER_PLACEHOLDERS = True
KEEP_RATE_FOR_PLACEHOLDER = 0.02

In [None]:
class CTCTokenizer:
    """Character-level tokenizer for CTC training and decoding.

    Id 0 is reserved for the CTC blank (decoded as the empty string); the
    unique input characters, in sorted order, receive ids 1..N.

    Attributes:
        blank_idx: id of the blank symbol (always 0).
        idx_to_char: id -> character map, including the blank.
        char_to_idx: character -> id map, excluding the blank.
        vocab_size: number of ids including the blank (N + 1).
    """

    def __init__(self, chars):
        alphabet = sorted(set(chars))
        self.blank_idx = 0
        self.idx_to_char = {0: ''}
        self.idx_to_char.update(enumerate(alphabet, start=1))
        self.char_to_idx = {ch: i for i, ch in enumerate(alphabet, start=1)}
        self.vocab_size = len(self.idx_to_char)

    def encode(self, text):
        """Map ``text`` to a list of ids; characters not in the vocabulary are dropped."""
        lookup = self.char_to_idx
        return [lookup[ch] for ch in text if ch in lookup]

    def decode(self, seq):
        """Collapse a per-timestep id sequence into a string (CTC best-path rule).

        Consecutive repeats of the same id are emitted once, blanks are
        removed, and a blank between two equal ids lets the id be emitted again.
        """
        pieces = []
        last = None
        for token in seq:
            is_blank = token == self.blank_idx
            if not is_blank and token != last:
                pieces.append(self.idx_to_char.get(int(token), ''))
            last = token if is_blank else int(token)
        return ''.join(pieces)



In [None]:
def build_vocab(images_dir, gt_dir):
    """Return the sorted list of characters appearing in the GT transcriptions.

    Scans every ``*.jpg`` / ``*.png`` in ``images_dir`` and its matching
    ``gt_<stem>.txt`` (or ``<stem>.txt``) in ``gt_dir``. Each GT line is
    expected to hold 8 comma-separated coordinates followed by the
    transcription (which may itself contain commas); the transcription is
    stripped of quotes and upper-cased. A space character is always included.

    Placeholder transcriptions (made only of '#') are handled
    deterministically: their characters are included whenever a dataset using
    the same settings could keep such a sample (``FILTER_PLACEHOLDERS`` off,
    or ``KEEP_RATE_FOR_PLACEHOLDER > 0``). Drawing random numbers here,
    independently of the dataset's own draws, could leave '#' out of the
    vocabulary while the dataset still emits '###' samples, which the
    tokenizer would then silently encode as empty targets.

    Args:
        images_dir: directory containing the scene images.
        gt_dir: directory containing the ground-truth text files.

    Returns:
        Sorted list of unique characters.
    """
    placeholders_reachable = (not FILTER_PLACEHOLDERS) or KEEP_RATE_FOR_PLACEHOLDER > 0
    images_dir, gt_dir = Path(images_dir), Path(gt_dir)
    chars = set()
    image_files = sorted(images_dir.glob("*.jpg")) + sorted(images_dir.glob("*.png"))
    for img_path in tqdm(image_files, desc="Scanning characters"):
        gt_file = gt_dir / f"gt_{img_path.stem}.txt"
        if not gt_file.exists():
            gt_file = gt_dir / f"{img_path.stem}.txt"
        if not gt_file.exists():
            continue
        try:
            with open(gt_file, 'r', encoding='utf-8', errors='ignore') as f:
                lines = f.readlines()
        except OSError as exc:
            # Best effort: one unreadable file should not abort the scan, but report it.
            print(f"[Vocab] skipping unreadable GT file {gt_file}: {exc}")
            continue
        for line in lines:
            parts = line.strip().split(',')
            if len(parts) < 9:
                continue
            text = ','.join(parts[8:]).strip().strip('"').strip("'").upper()
            if not placeholders_reachable and all(ch == '#' for ch in text):
                continue
            chars.update(text)
    chars.add(' ')
    return sorted(chars)


In [None]:
class ICDAR_CTCDataset(Dataset):
    """Word-recognition dataset built from ICDAR-style localisation + transcription GT.

    Each non-empty GT line is expected to be ``x1,y1,x2,y2,x3,y3,x4,y4,text``
    (``text`` may itself contain commas). Every accepted line becomes one
    sample: the axis-aligned bounding box of the quadrilateral is cropped from
    the scene image, converted to grayscale, resized to ``(H, W)`` and scaled
    to [-1, 1].

    Args:
        images_dir: directory containing ``*.jpg`` / ``*.png`` scene images.
        gt_dir: directory containing ``gt_<stem>.txt`` or ``<stem>.txt`` files.
        tokenizer: object whose ``encode(text)`` returns a list of int ids.
        max_label_len: transcriptions longer than this are skipped.

    Attributes:
        samples: list of ``(img_path, text, box)`` tuples, where ``box`` is
            ``(x_min, y_min, x_max, y_max)`` in pixels, or ``None`` if the
            coordinates could not be parsed (the full image is then used).
    """

    def __init__(self, images_dir, gt_dir, tokenizer, max_label_len=64):
        self.images_dir = Path(images_dir)
        self.gt_dir = Path(gt_dir)
        self.tokenizer = tokenizer
        self.max_label_len = max_label_len
        self.samples = self._load_samples()
        print(f"[Dataset] samples: {len(self.samples)}")

    def _load_samples(self):
        """Parse every GT file into ``(img_path, text, box)`` sample tuples."""
        samples = []
        image_files = sorted(self.images_dir.glob("*.jpg")) + sorted(self.images_dir.glob("*.png"))
        for img_path in tqdm(image_files, desc="Loading samples"):
            img_name = img_path.stem
            gt_file = self.gt_dir / f"gt_{img_name}.txt"
            if not gt_file.exists():
                gt_file = self.gt_dir / f"{img_name}.txt"
            if not gt_file.exists():
                continue
            try:
                # 'utf-8-sig' strips a leading BOM, if any, which would otherwise
                # be glued to the first x-coordinate and break its parsing.
                with open(gt_file, 'r', encoding='utf-8-sig', errors='ignore') as f:
                    lines = f.readlines()
            except OSError as exc:
                print(f"[Dataset] skipping unreadable GT file {gt_file}: {exc}")
                continue
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(',')
                if len(parts) < 9:
                    continue
                text = ','.join(parts[8:]).strip().strip('"').strip("'")
                text = text.upper()
                if FILTER_PLACEHOLDERS and all(ch == '#' for ch in text):
                    if random.random() > KEEP_RATE_FOR_PLACEHOLDER:
                        continue
                if 1 <= len(text) <= self.max_label_len:
                    samples.append((str(img_path), text, self._parse_box(parts[:8])))
        return samples

    @staticmethod
    def _parse_box(coords):
        """Return the axis-aligned ``(x_min, y_min, x_max, y_max)`` of 8 quad coords, or None."""
        try:
            values = [float(v) for v in coords]
        except ValueError:
            return None
        xs, ys = values[0::2], values[1::2]
        return (min(xs), min(ys), max(xs), max(ys))

    @staticmethod
    def _crop_to_box(img, box):
        """Crop an (h, w[, c]) array to ``box`` clamped to the image bounds.

        Falls back to the whole image when ``box`` is None or the clamped
        region is smaller than 2x2 pixels.
        """
        if box is None:
            return img
        h, w = img.shape[:2]
        x0 = max(0, int(np.floor(box[0])))
        y0 = max(0, int(np.floor(box[1])))
        x1 = min(w, int(np.ceil(box[2])) + 1)
        y1 = min(h, int(np.ceil(box[3])) + 1)
        if x1 - x0 < 2 or y1 - y0 < 2:
            return img
        # Slices are non-contiguous views; give OpenCV a contiguous buffer.
        return np.ascontiguousarray(img[y0:y1, x0:x1])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, labels, text)`` for sample ``idx``.

        ``image`` is a float32 tensor of shape (1, H, W) scaled to about
        [-1, 1]; ``labels`` is a 1-D long tensor of ``tokenizer.encode(text)``.
        """
        img_path, text, box = self.samples[idx]
        # The context manager closes the file handle PIL keeps open for lazy loading.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        # Crop the word region: without it every word of a photo would map to
        # the same whole-scene input and the labels could not be learned.
        img = self._crop_to_box(img, box)
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img = cv2.resize(img, (W, H))
        img = img.astype(np.float32) / 255.0
        img = (img - 0.5) / 0.5  # normalize to approx [-1,1]
        img = torch.from_numpy(img).unsqueeze(0)  # (1, H, W)
        labels = torch.tensor(self.tokenizer.encode(text), dtype=torch.long)
        return img, labels, text


In [None]:
def collate_fn(batch):
    """Merge ``(image, labels, text)`` samples into a CTC-ready batch.

    Returns:
        images: tensor stacked along a new leading batch dimension.
        flat_labels: all label tensors concatenated into one 1-D long tensor
            (the concatenated-targets form accepted by ``nn.CTCLoss``).
        lengths: long tensor with the number of labels of each sample.
        texts: list of the original transcriptions.
    """
    images, targets, texts = map(list, zip(*batch)) if batch else ([], [], [])
    lengths = torch.tensor([t.numel() for t in targets], dtype=torch.long)
    stacked = torch.stack(images, dim=0)
    flat_labels = torch.cat(targets) if targets else torch.tensor([], dtype=torch.long)
    return stacked, flat_labels, lengths, texts

In [None]:
class CRNNEncoder(nn.Module):
    """Convolutional feature extractor followed by a 2-layer bidirectional LSTM.

    Turns a batch of images into a left-to-right sequence of feature vectors,
    one per column of the final feature map.

    The pooling schedule halves the height four times and the width twice;
    a final (2, 1) convolution without padding then removes one more row.
    For a 32 x 128 input this yields a 1 x 32 feature map, i.e. 32 time
    steps. The shape comments below assume that 32 x 128 input.

    Args:
        in_channels: number of input image channels (1 for grayscale).
        lstm_hidden: hidden size of each LSTM direction; the output feature
            size is ``2 * lstm_hidden``.
    """
    def __init__(self, in_channels=1, lstm_hidden=256):
        super().__init__()
        # conv layers (careful to reduce height to 1)
        # NOTE: the child order of this Sequential defines the state_dict keys
        # ("features.<index>.*") of saved checkpoints; reordering breaks loading.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1),  # 64 x 32x128
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),                   # 16x64

            nn.Conv2d(64, 128, 3, 1, 1),          # 128 x 16x64
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),                   # 8x32

            nn.Conv2d(128, 256, 3, 1, 1),         # 256 x 8x32
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2,1), (2,1)),           # 4x32 (height only)

            nn.Conv2d(256, 512, 3, 1, 1),         # 512 x 4x32
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2,1), (2,1)),           # 2x32 (height only)

            # Unpadded (2,1) kernel: collapses the remaining 2 rows to 1.
            nn.Conv2d(512, 512, kernel_size=(2,1), stride=1, padding=0),  # 1x32
            nn.ReLU(inplace=True),
        )
        # Input size 512 = channels of the last conv; outputs 2*lstm_hidden per step.
        self.lstm = nn.LSTM(input_size=512, hidden_size=lstm_hidden, num_layers=2, batch_first=True, bidirectional=True)

    def forward(self, x):
        """Encode images into a feature sequence.

        Args:
            x: image batch of shape (B, in_channels, H, W). H must be 32 so the
                conv stack reduces height to exactly 1; otherwise squeeze(2) is
                a no-op and the 3-axis permute below raises.

        Returns:
            Tensor of shape (B, W', 2 * lstm_hidden), where W' is the feature-map
            width (32 for W = 128).
        """
        # x: (B, 1, H, W)
        feat = self.features(x)  # (B, C=512, H'=1, W')
        feat = feat.squeeze(2)   # (B, 512, W')
        feat = feat.permute(0, 2, 1)  # (B, W', 512) -> time dim = W'
        out, _ = self.lstm(feat)      # (B, W', 2*lstm_hidden)
        return out


class CRNN_CTC_Model(nn.Module):
    """CRNN encoder plus a per-timestep linear classifier producing CTC logits.

    Args:
        vocab_size: number of output classes, including the CTC blank at index 0.
        lstm_hidden: hidden size of each direction of the encoder's BiLSTM.
    """

    def __init__(self, vocab_size, lstm_hidden=256):
        super().__init__()
        self.encoder = CRNNEncoder(in_channels=1, lstm_hidden=lstm_hidden)
        # The BiLSTM concatenates both directions, hence twice the hidden size.
        feature_dim = 2 * lstm_hidden
        self.classifier = nn.Linear(feature_dim, vocab_size)

    def forward(self, images):
        """Map a (B, 1, H, W) image batch to raw (B, T, vocab_size) logits."""
        return self.classifier(self.encoder(images))

In [None]:
def ctc_greedy_decode(logits, tokenizer):
    """Greedy (best-path) CTC decoding of a batch of logits.

    Picks the arg-max class at every time step and lets ``tokenizer.decode``
    collapse repeats and drop blanks. Both code paths therefore share a single
    implementation of the collapse rule and cannot drift apart.

    Args:
        logits: tensor of shape (B, T, V) with raw scores, on any device.
        tokenizer: object with a ``decode(seq)`` method (e.g. ``CTCTokenizer``).

    Returns:
        List of B decoded strings.
    """
    with torch.no_grad():
        # softmax is monotonic, so arg-max over raw logits selects the same path
        # (up to floating-point ties) without exponentiating the whole (B, T, V) tensor.
        best_path = logits.argmax(-1).cpu().tolist()  # B lists of T ints
    return [tokenizer.decode(row) for row in best_path]

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Every sample is given the full output length ``T`` as its CTC input length.
    Gradients are clipped to a global L2 norm of 1.0 before each step.
    """
    model.train()
    loss_sum = 0.0
    progress = tqdm(loader, desc="Train")
    for images, targets, target_lengths, _texts in progress:
        images = images.to(device)
        targets = targets.to(device)
        target_lengths = target_lengths.to(device)

        log_probs = model(images).log_softmax(2)  # (B, T, V), log-probs required by CTCLoss
        batch_size, steps, _vocab = log_probs.shape
        input_lengths = torch.full((batch_size,), steps, dtype=torch.long, device=device)
        # nn.CTCLoss expects time-major input: (T, B, V).
        loss = criterion(log_probs.permute(1, 0, 2), targets, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_loss = loss.item()
        loss_sum += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.4f}'})
    return loss_sum / max(1, len(loader))


def evaluate(model, loader, criterion, device, tokenizer, n_display=8):
    """Return the mean CTC loss over ``loader`` and print up to ``n_display`` decodes.

    Runs without gradients in eval mode. Predictions come from greedy CTC
    decoding; the first ``n_display`` (ground truth, prediction) pairs seen are printed.
    """
    model.eval()
    loss_sum = 0.0
    n_batches = 0
    shown = 0
    with torch.no_grad():
        for images, targets, target_lengths, texts in tqdm(loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            logits = model(images)
            log_probs = logits.log_softmax(2)
            batch_size, steps, _vocab = log_probs.shape
            input_lengths = torch.full((batch_size,), steps, dtype=torch.long, device=device)
            loss = criterion(log_probs.permute(1, 0, 2), targets, input_lengths, target_lengths)
            loss_sum += loss.item()
            n_batches += 1

            decoded = ctc_greedy_decode(logits, tokenizer)
            for truth, guess in zip(texts, decoded):
                if shown >= n_display:
                    break
                print("GT:", truth)
                print("PR:", guess)
                print("-" * 40)
                shown += 1
    return loss_sum / max(1, n_batches)



In [None]:
def main():
    """Train the CRNN + CTC recogniser end to end and keep the best checkpoint.

    Steps:
    1. Seed the RNGs.
    2. Build the character vocabulary and tokenizer.
    3. Load the dataset and split it 90/10 by image.
    4. Train for ``NUM_EPOCHS`` epochs.

    Whenever the validation loss improves, the model weights, the character
    list (needed to rebuild the tokenizer), the epoch and the loss are saved to
    ``MODELS_DIR / 'crnn_ctc_best.pth'``.

    Raises:
        RuntimeError: if no samples are found, or there are too few images to
            form a non-empty training split.
    """
    torch.manual_seed(42)
    random.seed(42)

    # Build vocab
    print("Building vocabulary...")
    chars = build_vocab(TRAIN_IMAGES, TRAIN_GT)
    print(f"Chars ({len(chars)}): {chars[:80]}")

    tokenizer = CTCTokenizer(chars)
    print(f"Tokenizer: vocab_size(including blank) = {tokenizer.vocab_size} (blank idx=0)")

    # Dataset + dataloaders
    full_ds = ICDAR_CTCDataset(TRAIN_IMAGES, TRAIN_GT, tokenizer)
    if len(full_ds) == 0:
        raise RuntimeError(f"No training samples found in {TRAIN_IMAGES} / {TRAIN_GT}")

    # Split by image, not by sample: every GT line of a photo is its own sample,
    # so a per-sample random_split would put words from the same photo in both
    # train and validation and leak training data into the validation loss.
    # A dedicated Random keeps the split independent of earlier random draws.
    image_paths = sorted({sample[0] for sample in full_ds.samples})
    random.Random(42).shuffle(image_paths)
    val_images = set(image_paths[int(0.9 * len(image_paths)):])
    train_idx = [i for i, s in enumerate(full_ds.samples) if s[0] not in val_images]
    val_idx = [i for i, s in enumerate(full_ds.samples) if s[0] in val_images]
    if not train_idx:
        raise RuntimeError("Too few images to form a train/validation split")
    train_ds = torch.utils.data.Subset(full_ds, train_idx)
    val_ds = torch.utils.data.Subset(full_ds, val_idx)
    print(f"Split: Train {len(train_ds)}, Val {len(val_ds)}")

    # Pinned host memory only speeds up host->GPU copies; PyTorch warns if it is requested without CUDA.
    pin = DEVICE.type == 'cuda'
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, collate_fn=collate_fn, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, collate_fn=collate_fn, pin_memory=pin)

    # Model, loss, optimizer
    model = CRNN_CTC_Model(vocab_size=tokenizer.vocab_size, lstm_hidden=256).to(DEVICE)
    print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

    # zero_infinity: samples whose label cannot be aligned to T steps give inf loss; zero them instead.
    criterion = nn.CTCLoss(blank=tokenizer.blank_idx, zero_infinity=True)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    best_val = float('inf')

    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{NUM_EPOCHS} ---")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss = evaluate(model, val_loader, criterion, DEVICE, tokenizer, n_display=8)
        print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        scheduler.step(val_loss)
        if val_loss < best_val:
            best_val = val_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'chars': chars,
                'epoch': epoch,
                'val_loss': val_loss
            }, MODELS_DIR / 'crnn_ctc_best.pth')
            print("Saved best model")

    print("Training finished.")


if __name__ == '__main__':
    main()