In [1]:
import torch
import torch.nn as nn
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Tuple


In [2]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("[INFO] Using device:", DEVICE)


[INFO] Using device: cuda


In [3]:
def read_image(path: Path) -> np.ndarray:
    """Load an image file and return it with channels in RGB order.

    Args:
        path: Location of the image file on disk.

    Returns:
        The decoded image, converted from OpenCV's BGR order to RGB.

    Raises:
        ValueError: If ``cv2.imread`` returns ``None`` for ``path``.
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise ValueError(f"Failed to read image: {path}")


In [None]:
def detect_text_regions(
    img: np.ndarray,
    min_width: int = 40,
    min_height: int = 15,
    kernel_size: Tuple[int, int] = (15, 5),
    invert: bool = False,
) -> List[np.ndarray]:
    """Crop candidate text regions out of an RGB image.

    The image is converted to grayscale, blurred, binarised with Otsu's
    threshold and dilated with a rectangular kernel; the bounding box of every
    external contour that is large enough is cropped from ``img``.

    Args:
        img: RGB image (it is passed through ``cv2.COLOR_RGB2GRAY``).
        min_width: A bounding box is kept only if its width is strictly
            greater than this value.
        min_height: A bounding box is kept only if its height is strictly
            greater than this value.
        kernel_size: ``(width, height)`` of the rectangular dilation kernel.
        invert: When ``False`` (default, original behaviour) pixels brighter
            than the Otsu threshold become foreground. Set to ``True`` to use
            ``THRESH_BINARY_INV`` so that darker pixels become foreground,
            e.g. for dark text on a light background.

    Returns:
        A list of crops, in the order the contours are returned by
        ``cv2.findContours``. Each crop is a slice of ``img`` (a NumPy view,
        not a copy). An empty list is returned for a zero-sized image.
    """
    # A zero-sized array cannot contain any region; bail out before handing it
    # to OpenCV, which rejects empty inputs.
    if img.size == 0:
        return []

    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)

    # With THRESH_OTSU the explicit threshold value (0) is ignored and the
    # threshold is chosen automatically from the histogram.
    thresh_mode = cv2.THRESH_BINARY_INV if invert else cv2.THRESH_BINARY
    _, thresh = cv2.threshold(blur, 0, 255, thresh_mode + cv2.THRESH_OTSU)

    # Dilation grows the foreground blobs so that neighbouring ones merge
    # before contour extraction.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, tuple(kernel_size))
    dilated = cv2.dilate(thresh, kernel, iterations=2)

    contours, _ = cv2.findContours(
        dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    regions = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Discard boxes that are too small to be useful.
        if w > min_width and h > min_height:
            regions.append(img[y:y + h, x:x + w])

    return regions


In [None]:
class CRNN(nn.Module):
    """Small convolutional-recurrent network producing per-column class scores.

    Two conv + ReLU + 2x2 max-pool stages shrink the input by a factor of four
    in both height and width. Every remaining image column is flattened into a
    feature vector, the columns are fed as a sequence through a bidirectional
    LSTM, and a linear layer maps each step to ``num_classes`` raw scores.

    Args:
        num_classes: Size of the output layer.
        img_height: Height in pixels of the images the model will receive.
            It determines the LSTM input size; the default of 32 reproduces
            the original fixed ``128 * 8`` feature size.
        hidden_size: Hidden units per LSTM direction (default 128, as before).

    Submodule names (``cnn``, ``rnn``, ``fc``) and layer positions are the same
    as in the original definition, so ``state_dict`` keys are unchanged.
    """

    def __init__(self, num_classes=36, img_height=32, hidden_size=128):
        super().__init__()
        if img_height < 4:
            raise ValueError(
                f"img_height must be at least 4 to survive two 2x2 poolings, got {img_height}"
            )

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # Each MaxPool2d(2, 2) floor-halves the height; the convolutions keep
        # it unchanged (kernel 3, padding 1).
        feature_height = (img_height // 2) // 2

        self.rnn = nn.LSTM(
            128 * feature_height,
            hidden_size,
            bidirectional=True,
            batch_first=True
        )

        # The bidirectional LSTM concatenates both directions.
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x):
        """Map a ``(batch, 1, H, W)`` tensor to ``(batch, W // 4, num_classes)`` scores.

        ``H`` must match the ``img_height`` the model was built with, otherwise
        the flattened column size will not match the LSTM input size.
        """
        x = self.cnn(x)
        b, c, h, w = x.size()
        # Move width to the sequence axis and flatten (channels, height) into
        # one feature vector per column.
        x = x.permute(0, 3, 1, 2).reshape(b, w, c * h)
        x, _ = self.rnn(x)
        return self.fc(x)


In [None]:
# Checkpoint location; the relative path is resolved against the current
# working directory.
OCR_MODEL_PATH = Path("../outputs/models/ocr/best_model.pth")

# Build the recogniser on the selected device and put it in inference mode.
ocr_model = CRNN().to(DEVICE).eval()

if not OCR_MODEL_PATH.exists():
    # No checkpoint on disk: the model keeps its freshly initialised weights.
    print("⚠️ No pretrained OCR model found")
else:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    # The checkpoint is expected to be a mapping holding the weights under
    # the "model_state" key.
    state = torch.load(OCR_MODEL_PATH, map_location=DEVICE)
    ocr_model.load_state_dict(state["model_state"])
    print("✅ OCR model loaded")


In [None]:
# Characters the recogniser can emit; class index i maps to ALPHABET[i].
ALPHABET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def decode_prediction(logits: torch.Tensor) -> str:
    """Greedy-decode one sequence of per-step class scores into text.

    The best class is taken at every step, consecutive repeats of the same
    class are collapsed, and any class index outside ``ALPHABET`` is treated
    as "no character" (it still breaks a run of repeats).

    Args:
        logits: Scores of shape ``(1, T, C)`` or ``(T, C)``.

    Returns:
        The decoded string (empty if no step yields an ``ALPHABET`` index).

    Raises:
        ValueError: If ``logits`` holds more than one sequence or is not a
            2-D / 3-D tensor. A single string cannot represent a batch.
    """
    if logits.dim() == 3:
        if logits.size(0) != 1:
            raise ValueError(
                f"decode_prediction expects a single sequence, got batch size {logits.size(0)}"
            )
        # Index the batch axis explicitly; squeeze(0) would also drop the
        # time axis of an unbatched single-step input.
        logits = logits[0]
    elif logits.dim() != 2:
        raise ValueError(
            f"decode_prediction expects logits of shape (1, T, C) or (T, C), got {tuple(logits.shape)}"
        )

    best_classes = torch.argmax(logits, dim=-1).tolist()

    chars = []
    prev = -1
    for idx in best_classes:
        if idx != prev and idx < len(ALPHABET):
            chars.append(ALPHABET[idx])
        prev = idx
    return "".join(chars)


In [None]:
@torch.no_grad()
def ocr_on_image(img: np.ndarray) -> List[str]:
    """Recognise the text of every region detected in an RGB image.

    Args:
        img: RGB image handed to ``detect_text_regions``.

    Returns:
        One decoded string per detected region, in detection order.
    """

    def _as_model_input(crop: np.ndarray) -> torch.Tensor:
        # Grayscale, resized to 128 wide x 32 high, divided by 255, then
        # shaped (1, 1, 32, 128) and moved to DEVICE.
        mono = cv2.resize(cv2.cvtColor(crop, cv2.COLOR_RGB2GRAY), (128, 32))
        scaled = torch.from_numpy(mono).float() / 255.0
        return scaled.unsqueeze(0).unsqueeze(0).to(DEVICE)

    crops = detect_text_regions(img)
    return [decode_prediction(ocr_model(_as_model_input(crop))) for crop in crops]


In [None]:
def compare_ocr(raw_img: np.ndarray, enhanced_img: np.ndarray):
    """Run OCR on a raw image and its enhanced version and print both results.

    Args:
        raw_img: The unprocessed RGB image.
        enhanced_img: The enhanced RGB image to compare against.
    """
    # Both images are processed before anything is printed.
    outputs = {
        "Raw OCR:": ocr_on_image(raw_img),
        "Enhanced OCR:": ocr_on_image(enhanced_img),
    }
    for label, texts in outputs.items():
        print(label, texts)


In [None]:
def visualize_ocr(img_path: Path):
    """Show an image with its recognised text as the figure title.

    Args:
        img_path: Path of the image to load, run OCR on and display.
    """
    image = read_image(img_path)
    recognised = ocr_on_image(image)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.imshow(image)
    ax.set_title(f"OCR Output: {recognised}")
    ax.axis("off")
    plt.show()
