In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] Device:", DEVICE)

# =====================
# PATHS
# =====================
DATA_ROOT = Path("../data/datasets/ICDAR2015/Word recognition train set")
GT_FILE = DATA_ROOT / "gt.txt"
MODEL_SAVE = Path("../outputs/models/ocr")
MODEL_SAVE.mkdir(parents=True, exist_ok=True)


[INFO] Device: cuda


In [3]:
ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
NUM_CLASSES = len(ALPHABET) + 1  # +1 for CTC blank

char2idx = {c: i for i, c in enumerate(ALPHABET)}


In [4]:
class ICDARWordDataset(Dataset):
    def __init__(self, root, gt_file):
        self.root = root
        self.samples = []

        with open(gt_file, "r", encoding="utf-8") as f:
            for line in f:
                name, text = line.strip().split(",", 1)
                self.samples.append((name, text))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, text = self.samples[idx]
        img_path = self.root / img_name

        img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (128, 32))
        img = img.astype(np.float32) / 255.0

        img = torch.from_numpy(img).unsqueeze(0)

        label = torch.tensor([char2idx[c] for c in text], dtype=torch.long)

        return img, label, len(label)


In [5]:
def collate_fn(batch):
    imgs, labels, lengths = zip(*batch)

    imgs = torch.stack(imgs)
    labels = torch.cat(labels)
    lengths = torch.tensor(lengths)

    return imgs, labels, lengths


In [None]:
dataset = ICDARWordDataset(DATA_ROOT, GT_FILE)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

model = CRNN().to(DEVICE)
criterion = nn.CTCLoss(blank=NUM_CLASSES-1, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
NUM_EPOCHS = 30

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0

    for imgs, labels, label_lens in tqdm(loader, desc=f"Epoch {epoch}"):
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        label_lens = label_lens.to(DEVICE)

        preds = model(imgs)
        T = preds.size(1)
        pred_lens = torch.full(
            size=(imgs.size(0),),
            fill_value=T,
            dtype=torch.long
        ).to(DEVICE)

        loss = criterion(
            preds.permute(1,0,2),
            labels,
            pred_lens,
            label_lens
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"[Epoch {epoch}] Loss: {avg_loss:.4f}")

    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "loss": avg_loss
        },
        MODEL_SAVE / "ocr_latest.pth"
    )


Epoch 0:   0%|          | 0/140 [00:00<?, ?it/s]