# 0) Setup (Install & Imports)

In [1]:
!pip install torch torchvision torchaudio rapidfuzz
!pip install pillow

import shutil, glob, os, math, random, string
from pathlib import Path
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont

import os
import cv2
import math
import time
import glob
import random
import string
import numpy as np
from pathlib import Path
from typing import List, Tuple
import shutil
import string

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

Collecting rapidfuzz
  Downloading rapidfuzz-3.14.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading rapidfuzz-3.14.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m66.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz
Successfully installed rapidfuzz-3.14.0


In [2]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

Device: cuda


# 1) Character Set & Utilities

In [3]:
# EXACTLY match what your generator emits
CHARSET = (
    list(string.digits) +
    list(string.ascii_uppercase) +
    list(string.ascii_lowercase) +
    list("-/:.,()&%+ #")
)
BLANK_IDX = 0
IDX2CHAR = {i+1: ch for i, ch in enumerate(CHARSET)}  # 1..N => char
CHAR2IDX = {ch: i+1 for i, ch in enumerate(CHARSET)}  # char => 1..N
NUM_CLASSES = len(CHARSET) + 1  # +1 for CTC blank at index 0

def text_to_labels(s: str):
    return [CHAR2IDX[ch] for ch in s if ch in CHAR2IDX]

def labels_to_text(labels):
    out, prev = [], None
    for l in labels:
        if l != BLANK_IDX and l != prev:
            out.append(IDX2CHAR.get(l, ""))
        prev = l
    return "".join(out)

# 2) Tiny Synthetic Dataset (OpenCV-generated)

## Random Text Gen Function

In [4]:
def _rand_text(min_len=5, max_len=12):
    """
    Generate realistic OCR strings with weighted patterns.
    Ensures final length is between [min_len, max_len].
    """
    def gen_12_digit():
        return "".join(random.choices(string.digits, k=12))

    def gen_lienceNo():
      return random.choice(["B", "v"]).join(random.choices(string.digits, k=7))

    def gen_9_digit_V():
        # 9 digits + space + 'V' or 'v'  (length 11)
        return "".join(random.choices(string.digits, k=9)) + " " + random.choice(["V", "v"])

    def gen_title_word():
        # One or two simple English-looking words (Title Case)
        vowels = "aeiou"
        consonants = "".join([c for c in string.ascii_lowercase if c not in vowels])
        def make_syllable():
            # simple CV or CVC to look word-like
            syl = random.choice(consonants) + random.choice(vowels)
            if random.random() < 0.5:
                syl += random.choice(consonants)
            return syl
        def make_word(min_w=3, max_w=9):
            wlen = random.randint(min_w, max_w)
            # stitch 2–4 syllables to reach length target
            s = ""
            while len(s) < wlen:
                s += make_syllable()
            s = s[:wlen]
            return s.capitalize()
        if random.random() < 0.5:
            return make_word()
        else:
            return f"{make_word()} {make_word()}"

    def gen_upper_code():
        # AB-1234 or ABC-12 etc.
        letters = "".join(random.choices(string.ascii_uppercase, k=random.randint(2, 3)))
        digits = "".join(random.choices(string.digits, k=random.randint(2, 4)))
        return f"{letters}-{digits}"

    def gen_invoice():
        # INV-123456 (8–11 chars typical)
        return "INV-" + "".join(random.choices(string.digits, k=random.randint(4, 7)))

    def gen_date():
        # DD/MM/YYYY
        d = random.randint(1, 28)
        m = random.randint(1, 12)
        y = random.randint(2000, 2029)
        return f"{d:02d}/{m:02d}/{y:04d}"

    def gen_plate():
        # ABC-1234 (common style)
        letters = "".join(random.choices(string.ascii_uppercase, k=3))
        digits = "".join(random.choices(string.digits, k=4))
        return f"{letters}-{digits}"

    # Weighted pattern list (func, weight)
    patterns = [
        (gen_12_digit, 3),
        (gen_9_digit_V, 3),
        (gen_lienceNo, 3),
        (gen_title_word, 2),
        (gen_upper_code, 2),
        (gen_invoice, 1),
        (gen_date, 1),
        (gen_plate, 1),
    ]

    funcs, weights = zip(*patterns)

    # Try up to a few times to honor length; then fall back to mixed
    for _ in range(6):
        txt = random.choices(funcs, weights=weights, k=1)[0]()
        if min_len <= len(txt) <= max_len:
            return txt

    # fallback (guaranteed length)
    return gen_mixed()

## Apply Backgrounds

In [5]:
BACKGROUND_DIR = Path("./backgrounds")

def _list_bg_images(bg_dir: Path):
    exts = ("*.png", "*.jpg", "*.jpeg", "*.webp")
    imgs = []
    if bg_dir.exists():
        for ext in exts:
            imgs += sorted(glob.glob(str(bg_dir / ext)))
    return imgs

def _prepare_bg(bg_bgr: np.ndarray, H: int, W: int) -> np.ndarray:
    """
    Accepts a small cropped swatch (BGR). Tiles if smaller than target,
    then random-crops to exactly HxW. Applies light jitter for realism.
    Returns BGR.
    """
    h, w = bg_bgr.shape[:2]
    # Tile if swatch is smaller than needed
    if h < H or w < W:
        rep_y = (H + h - 1) // h
        rep_x = (W + w - 1) // w
        bg_bgr = np.tile(bg_bgr, (rep_y, rep_x, 1))
        h, w = bg_bgr.shape[:2]

    # Random crop to HxW
    y0 = random.randint(0, h - H)
    x0 = random.randint(0, w - W)
    bg_bgr = bg_bgr[y0:y0 + H, x0:x0 + W]

    # Light jitter (optional)
    if random.random() < 0.5:
        bg_bgr = cv2.GaussianBlur(bg_bgr, (3, 3), 0)
    if random.random() < 0.8:
        alpha = random.uniform(0.85, 1.15)  # contrast-ish
        beta  = random.randint(-12, 12)     # brightness-ish
        bg_bgr = np.clip(bg_bgr.astype(np.float32) * alpha + beta, 0, 255).astype(np.uint8)

    return bg_bgr


## Generate Data

In [12]:
# ----------------- Paths & config -----------------
SYNTH_ROOT = Path("./synthetic_ocr")
SYNTH_IMG_DIR = SYNTH_ROOT / "images"
SYNTH_LABELS_PATH = SYNTH_ROOT / "labels.txt"
FONTS_DIR = Path("./fonts")
BACKGROUND_DIR = Path("./backgrounds")

N_SAMPLES = 10000
MIN_LENGTH = 2
MAX_LENGTH = 20

POOL = string.ascii_uppercase + string.ascii_lowercase + string.digits + " -/:.,()&%+ #"

# ----------------- helpers -----------------
def _list_fonts(fonts_dir: Path):
    if not fonts_dir.exists(): return []
    return sorted(glob.glob(str(fonts_dir / "*.ttf"))) + sorted(glob.glob(str(fonts_dir / "*.otf")))

def _list_bg_images(bg_dir: Path):
    if not bg_dir.exists(): return []
    exts = ("*.png","*.jpg","*.jpeg","*.webp","*.bmp")
    paths = []
    for e in exts: paths += glob.glob(str(bg_dir / e))
    return sorted(paths)

def _rand_text(min_len=5, max_len=12):
    length = random.randint(min_len, max_len)
    txt = "".join(random.choice(POOL) for _ in range(length)).strip()
    return txt if txt else "A1"

def _render_with_pillow(text: str, H: int = 48, W: int = 300, font_path: str = None):
    W = max(W, min(340, 12 * len(text) + 40))
    if font_path:
        best_size = 24
        for sz in [26, 28, 30, 32]:
            try:
                f = ImageFont.truetype(font_path, sz)
                _, ascent = f.getmetrics()
                if ascent <= int(0.8 * H): best_size = sz
            except Exception:
                pass
        font = ImageFont.truetype(font_path, best_size)
    else:
        font = ImageFont.load_default()

    img = Image.new("L", (W, H), color=255)
    draw = ImageDraw.Draw(img)
    bbox = draw.textbbox((0, 0), text, font=font)
    tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
    x = 5
    y = (H - th) // 2
    draw.text((x, y), text, font=font, fill=30)
    return np.array(img, dtype=np.uint8)

def _prepare_bg(bg_bgr, H, W):
    h, w = bg_bgr.shape[:2]
    if h < H or w < W:
        reps_y = math.ceil(H / h) + 1
        reps_x = math.ceil(W / w) + 1
        tiled = np.tile(bg_bgr, (reps_y, reps_x, 1))
        bg_bgr = tiled
        h, w = bg_bgr.shape[:2]
    y0 = random.randint(0, h - H)
    x0 = random.randint(0, w - W)
    crop = bg_bgr[y0:y0+H, x0:x0+W].copy()
    return cv2.GaussianBlur(crop, (3,3), 0)

# ----------------- main -----------------
def make_synthetic_dataset(n_samples=N_SAMPLES, min_len=MIN_LENGTH, max_len=MAX_LENGTH):
    # 1) Clean & recreate root dirs (this was missing)
    if SYNTH_ROOT.exists():
        shutil.rmtree(SYNTH_ROOT, ignore_errors=True)
    SYNTH_ROOT.mkdir(parents=True, exist_ok=True)
    SYNTH_IMG_DIR.mkdir(parents=True, exist_ok=True)   # ensures labels' parent exists too

    # 2) Collect resources
    font_paths = _list_fonts(FONTS_DIR)
    bg_paths = _list_bg_images(BACKGROUND_DIR)
    print(f"Fonts: {len(font_paths)} | Backgrounds: {len(bg_paths)}")
    if not font_paths:
        print("No TTF/OTF fonts found → using OpenCV Hershey fallback.")
    if not bg_paths:
        print("No backgrounds found → plain white backgrounds only.")

    # 3) Generate
    lines = []
    print("Generating synthetic dataset...")
    for i in range(n_samples):
        text = _rand_text(min_len, max_len)

        # base text image (grayscale)
        if font_paths:
            img = _render_with_pillow(text, H=48, W=300, font_path=random.choice(font_paths))
        else:
            img = np.ones((48, 300), dtype=np.uint8) * 255
            font = random.choice([
                cv2.FONT_HERSHEY_SIMPLEX, cv2.FONT_HERSHEY_DUPLEX,
                cv2.FONT_HERSHEY_PLAIN, cv2.FONT_HERSHEY_COMPLEX_SMALL
            ])
            font_scale = random.uniform(0.8, 1.2)
            thickness = random.randint(1, 2)
            (tw, th), _ = cv2.getTextSize(text, font, font_scale, thickness)
            x = 5
            y = (img.shape[0] + th) // 2 - 3
            cv2.putText(img, text, (x, y), font, font_scale, (0,), thickness, cv2.LINE_AA)

        # rotation
        ang = random.uniform(-4, 4)
        M = cv2.getRotationMatrix2D((img.shape[1]//2, img.shape[0]//2), ang, 1.0)
        img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]), borderValue=255)

        # perspective
        if random.random() < 0.3:
            h, w = img.shape[:2]
            src = np.float32([[0,0],[w-1,0],[0,h-1],[w-1,h-1]])
            dx = random.randint(0, 6)
            dy = random.randint(0, 4)
            dst = np.float32([[0+dx,0+dy],[w-1-dx,0],[0,h-1],[w-1-dx,h-1-dy]])
            P = cv2.getPerspectiveTransform(src, dst)
            img = cv2.warpPerspective(img, P, (w, h), borderValue=255)

        # optional textured background
        if bg_paths:
            bg = cv2.imread(random.choice(bg_paths), cv2.IMREAD_COLOR)
            if bg is not None:
                H, W = img.shape[:2]
                bg = _prepare_bg(bg, H, W)
                alpha = (255.0 - img.astype(np.float32)) / 255.0
                out = bg.astype(np.float32) * (1.0 - alpha[..., None])
                img = np.clip(out, 0, 255).astype(np.uint8)  # now BGR

        # light noise
        if random.random() < 0.5:
            noise = np.random.normal(0, 5, img.shape).astype(np.int16)
            img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

        # save
        fname = f"img_{i:05d}.png"
        cv2.imwrite(str(SYNTH_IMG_DIR / fname), img)
        lines.append(f"{fname}\t{text}")

    # 4) Write labels (parent exists now)
    with open(SYNTH_LABELS_PATH, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

    print(f"Saved {len(lines)} samples to {SYNTH_IMG_DIR} and labels to {SYNTH_LABELS_PATH}")

# run
make_synthetic_dataset()

Fonts: 3 | Backgrounds: 11
Generating synthetic dataset...
Saved 10000 samples to synthetic_ocr/images and labels to synthetic_ocr/labels.txt


# 3) Data Loading and Preparation

In [7]:
def collate_fn(batch):
    images, targets, target_lengths = zip(*batch)

    # Stack images (they are already the same size)
    images = torch.stack(images, 0)

    # Find max target length in this batch
    max_target_len = max(target_lengths).item()

    # Pad all targets to max length
    padded_targets = []
    for target in targets:
        pad_size = max_target_len - len(target)
        padded_target = torch.cat([target, torch.zeros(pad_size, dtype=torch.long)])
        padded_targets.append(padded_target)

    targets = torch.stack(padded_targets, 0)
    target_lengths = torch.stack(target_lengths, 0)

    return images, targets, target_lengths

In [8]:
class OCRDataset(Dataset):
    def __init__(self, labels_file, img_dir, img_height=32, img_width=128):
        self.img_dir = Path(img_dir)
        self.img_height = img_height
        self.img_width = img_width

        # Read labels
        with open(labels_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        self.samples = []
        for line in lines:
            parts = line.strip().split('\t')
            if len(parts) >= 2:
                self.samples.append((parts[0], parts[1]))

        print(f"Loaded {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, text = self.samples[idx]
        img_path = self.img_dir / img_name

        # Load and preprocess image
        img = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Fallback to a blank image if loading fails
            img = np.ones((self.img_height, self.img_width), dtype=np.uint8) * 255

        # Resize with aspect ratio preservation
        h, w = img.shape
        new_w = int(w * (self.img_height / h))
        img = cv2.resize(img, (new_w, self.img_height), interpolation=cv2.INTER_AREA)

        # Pad to fixed width
        if new_w < self.img_width:
            pad_width = self.img_width - new_w
            img = np.pad(img, ((0, 0), (0, pad_width)), mode='constant', constant_values=255)
        else:
            img = img[:, :self.img_width]

        # Normalize and convert to tensor
        img = img.astype(np.float32) / 255.0
        img = (img - 0.5) / 0.5  # Normalize to [-1, 1]
        img_tensor = torch.from_numpy(img).unsqueeze(0)  # Add channel dimension

        # Convert text to label indices
        target = text_to_labels(text)
        target_length = torch.tensor([len(target)], dtype=torch.long)
        target = torch.tensor(target, dtype=torch.long)

        return img_tensor, target, target_length

# Create datasets
train_dataset = OCRDataset(SYNTH_LABELS_PATH, SYNTH_IMG_DIR)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn
)

Loaded 10000 samples


# 4) CRNN Model Definition

In [9]:
class CRNN(nn.Module):
    def __init__(self, img_channel, img_height, img_width, num_classes, hidden_size=256):
        super(CRNN, self).__init__()

        # CNN layers
        self.cnn = nn.Sequential(
            nn.Conv2d(img_channel, 64, 3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Conv2d(512, 512, 3, stride=1, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, None))  # Adaptive pooling to handle variable width
        )

        # Calculate LSTM input features
        with torch.no_grad():
            dummy_input = torch.zeros(1, img_channel, img_height, img_width)
            dummy_output = self.cnn(dummy_input)
            lstm_input_size = dummy_output.size(1) * dummy_output.size(2)

        # RNN layers (LSTM)
        self.lstm = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        # Output layer
        self.fc = nn.Linear(hidden_size * 2, num_classes)  # *2 for bidirectional

    def forward(self, x):
        # CNN
        x = self.cnn(x)

        # Prepare for LSTM
        batch, channels, height, width = x.size()
        x = x.permute(0, 3, 1, 2)  # [batch, width, channels, height]
        x = x.reshape(batch, width, channels * height)  # [batch, width, channels*height]

        # LSTM
        x, _ = self.lstm(x)

        # Output layer
        x = self.fc(x)
        x = x.permute(1, 0, 2)  # [width, batch, num_classes] for CTC

        return x

# Initialize model
model = CRNN(
    img_channel=1,
    img_height=32,
    img_width=128,
    num_classes=NUM_CLASSES
).to(DEVICE)

print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

Model has 7,693,643 parameters


# 5) Training Setup

In [10]:
# Loss function (CTC)
criterion = nn.CTCLoss(blank=BLANK_IDX, zero_infinity=True)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# 6) Training Loop

In [11]:
def train_model(model, dataloader, criterion, optimizer, num_epochs=25):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0
        correct = 0
        total = 0

        for batch_idx, (data, targets, target_lengths) in enumerate(dataloader):
            data = data.to(DEVICE)
            targets = targets.to(DEVICE)
            target_lengths = target_lengths.to(DEVICE)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(data)

            # Prepare for CTC loss
            input_lengths = torch.full(
                size=(outputs.size(1),),
                fill_value=outputs.size(0),
                dtype=torch.long
            ).to(DEVICE)

            # Calculate loss
            loss = criterion(outputs, targets, input_lengths, target_lengths)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Statistics
            epoch_loss += loss.item()

            # Calculate accuracy (greedy decoding)
            _, max_probs = torch.max(outputs.permute(1, 0, 2), 2)
            predictions = []
            for i in range(max_probs.size(0)):
                raw_pred = max_probs[i].cpu().numpy()
                pred_text = labels_to_text(raw_pred)
                predictions.append(pred_text)

            # Compare with ground truth
            for i, (pred, target_len) in enumerate(zip(predictions, target_lengths)):
                # Only use the non-padded part of the target
                target_text = labels_to_text(targets[i][:target_len].cpu().numpy())
                if pred == target_text:
                    correct += 1
                total += 1

            if batch_idx % 20 == 0:
                print(f'Epoch: {epoch+1}/{num_epochs} | '
                      f'Batch: {batch_idx}/{len(dataloader)} | '
                      f'Loss: {loss.item():.4f}')

        scheduler.step()
        accuracy = 100 * correct / total if total > 0 else 0
        avg_loss = epoch_loss / len(dataloader)
        print(f'Epoch {epoch+1} completed | '
              f'Avg Loss: {avg_loss:.4f} | '
              f'Accuracy: {accuracy:.2f}%')

    return model

# Train the model
print("Starting training...")
model = train_model(model, train_loader, criterion, optimizer, num_epochs=30)

Starting training...
Epoch: 1/30 | Batch: 0/313 | Loss: -2.9227
Epoch: 1/30 | Batch: 20/313 | Loss: 7.0321
Epoch: 1/30 | Batch: 40/313 | Loss: 5.4155
Epoch: 1/30 | Batch: 60/313 | Loss: 5.9717
Epoch: 1/30 | Batch: 80/313 | Loss: 5.3376
Epoch: 1/30 | Batch: 100/313 | Loss: 4.9195
Epoch: 1/30 | Batch: 120/313 | Loss: 5.0653
Epoch: 1/30 | Batch: 140/313 | Loss: 4.6093
Epoch: 1/30 | Batch: 160/313 | Loss: 5.0029
Epoch: 1/30 | Batch: 180/313 | Loss: 4.8764
Epoch: 1/30 | Batch: 200/313 | Loss: 5.2081
Epoch: 1/30 | Batch: 220/313 | Loss: 4.9635
Epoch: 1/30 | Batch: 240/313 | Loss: 4.6636
Epoch: 1/30 | Batch: 260/313 | Loss: 4.6830
Epoch: 1/30 | Batch: 280/313 | Loss: 5.1866
Epoch: 1/30 | Batch: 300/313 | Loss: 4.8467
Epoch 1 completed | Avg Loss: 4.9715 | Accuracy: 0.00%
Epoch: 2/30 | Batch: 0/313 | Loss: 4.8250
Epoch: 2/30 | Batch: 20/313 | Loss: 4.9096
Epoch: 2/30 | Batch: 40/313 | Loss: 4.6700
Epoch: 2/30 | Batch: 60/313 | Loss: 4.5871
Epoch: 2/30 | Batch: 80/313 | Loss: 5.1533
Epoch: 2/30

KeyboardInterrupt: 

In [None]:
"""# 8) Model Evaluation and Export"""

def predict_image(model, image_path):
    # Load and preprocess image
    img = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
    if img is None:
        return "Error: Could not load image"

    # Preprocess like in dataset
    h, w = img.shape
    new_w = int(w * (32 / h))
    img = cv2.resize(img, (new_w, 32), interpolation=cv2.INTER_AREA)

    if new_w < 128:
        pad_width = 128 - new_w
        img = np.pad(img, ((0, 0), (0, pad_width)), mode='constant', constant_values=255)
    else:
        img = img[:, :128]

    # Normalize and convert to tensor
    img = img.astype(np.float32) / 255.0
    img = (img - 0.5) / 0.5
    img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(DEVICE)  # [1, 1, H, W]

    # Inference
    model.eval()
    with torch.no_grad():
        outputs = model(img_tensor)
        _, max_probs = torch.max(outputs.permute(1, 0, 2), 2)
        raw_pred = max_probs[0].cpu().numpy()
        pred_text = labels_to_text(raw_pred)

    return pred_text

# Test on a few samples
test_samples = random.sample(train_dataset.samples, min(5, len(train_dataset.samples)))
print("\nTesting on sample images:")
for img_name, true_text in test_samples:
    img_path = SYNTH_IMG_DIR / img_name
    pred_text = predict_image(model, img_path)
    print(f"Image: {img_name} | True: '{true_text}' | Pred: '{pred_text}'")

# Save the trained model
def save_model(model, path):
    torch.save({
        'model_state_dict': model.state_dict(),
        'charset': CHARSET,
        'img_height': 32,
        'img_width': 128
    }, path)
    print(f"Model saved to {path}")

MODEL_PATH = "./crnn_ocr_model.pth"
save_model(model, MODEL_PATH)

print("Training and export completed successfully!")