In [24]:
import argparse
import json
import math
import os
from typing import List, Dict, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from craft_text_detector import Craft
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

In [25]:
# Pipeline configuration. NOTE: despite the *_DIR suffix, the three path
# constants below are single FILE paths, not directories.

# Input image that text detection and recognition are run on.
IMAGE_DIR = "test.png"
# Destination of the JSON results written by main(). The content is JSON,
# so the file carries a .json extension (it was previously "out.png").
OUTPUT_DIR = "out.json"
# Destination of the annotated visualization image; a falsy value ("" or
# None) makes main() skip the visualization step.
OUTPUT_IMG_DIR = "out2.png"
# Upper bound on tokens generated per text crop (passed as max_new_tokens).
MAX_TOKENS = 64
# Hugging Face model id loaded by HandwritingRecognizer.
MODEL_NAME = "microsoft/trocr-base-handwritten"

In [26]:
def load_detector(use_cuda: bool) -> Craft:
    """
    Build the CRAFT text detector used by the pipeline.

    ``use_cuda`` is forwarded as the detector's ``cuda`` flag; the remaining
    options are fixed here.
    """
    detector_options = {
        "output_dir": None,     # no disk output
        "crop_type": "poly",    # polygon crops
        "cuda": use_cuda,
        "long_size": 1280,      # increase for higher-res detection if needed
    }
    return Craft(**detector_options)


def detect_regions(image_path: str, craft: Craft) -> Dict:
    """
    Run CRAFT text detection on the image stored at ``image_path``.

    The detector's result is handed back untouched. The pipeline reads the
    "polys" entry from it; per the detector's output it presumably also
    carries boxes, heatmaps, etc. (verify against craft_text_detector).
    """
    return craft.detect_text(image_path)


def poly_to_bbox(poly: np.ndarray, margin: int = 2, img_w: int = None, img_h: int = None) -> Tuple[int, int, int, int]:
    """
    Convert a polygon (Nx2 array of x, y points) to a padded axis-aligned
    bounding box.

    Args:
        poly: array whose column 0 holds x coordinates and column 1 holds y.
        margin: padding in pixels added on every side of the polygon extent.
        img_w: optional image width; when given, the box is clipped to it.
        img_h: optional image height; when given, the box is clipped to it.

    Returns:
        (x, y, w, h) as Python ints. The box covers pixel columns
        ``x .. x + w - 1`` and rows ``y .. y + h - 1``, i.e. it is meant to be
        used as ``image[y:y + h, x:x + w]``. ``w`` and ``h`` are at least 1.
    """
    xs = poly[:, 0]
    ys = poly[:, 1]
    x_min = max(int(np.floor(xs.min())) - margin, 0)
    y_min = max(int(np.floor(ys.min())) - margin, 0)
    # x_max / y_max are EXCLUSIVE bounds (w = x_max - x_min), so they are
    # clipped to the image size itself. Clipping to size - 1 would make the
    # last column/row of the image unreachable.
    x_max = int(np.ceil(xs.max())) + margin
    y_max = int(np.ceil(ys.max())) + margin
    if img_w is not None:
        x_max = min(x_max, img_w)
    if img_h is not None:
        y_max = min(y_max, img_h)
    # Degenerate polygons (or ones outside the image) still yield a 1px box
    # so callers never receive a zero or negative extent.
    w = max(x_max - x_min, 1)
    h = max(y_max - y_min, 1)
    return x_min, y_min, w, h


def sort_polys_reading_order(polys: List[np.ndarray], line_tol: float = 15.0) -> List[int]:
    """
    Sort polygons roughly in reading order: top-to-bottom, then left-to-right.

    Polygons are grouped into text lines by the top edge (minimum y) of each
    polygon: walking down the page, a polygon joins the current line while its
    top is within ``line_tol`` pixels of the top of the line's first polygon;
    otherwise it starts a new line. Each line is then ordered by left edge
    (minimum x).

    Grouping relative to the line's own top avoids the artifact of fixed
    buckets (``round(y / 15)``), where two boxes on the same line could land
    in different buckets just because they straddle a bucket boundary.

    Args:
        polys: list of Nx2 arrays (column 0 = x, column 1 = y).
        line_tol: maximum vertical distance, in pixels, between a polygon's
            top and the line's top for the polygon to belong to that line.

    Returns:
        Indices into ``polys`` in reading order (empty list for empty input).
    """
    # (top, left, original index); sorting tuples orders by top, then left.
    anchors = sorted(
        (float(np.min(p[:, 1])), float(np.min(p[:, 0])), i)
        for i, p in enumerate(polys)
    )

    order: List[int] = []
    line: List[Tuple[float, int]] = []   # (left, index) entries of the current line
    line_top = 0.0
    for top, left, idx in anchors:
        if line and top - line_top > line_tol:
            # Current polygon is too far below: flush the finished line.
            order.extend(i for _, i in sorted(line))
            line = []
        if not line:
            line_top = top
        line.append((left, idx))
    order.extend(i for _, i in sorted(line))
    return order


class HandwritingRecognizer:
    """
    Thin wrapper around a TrOCR processor/model pair for recognizing text in
    cropped images.
    """

    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu", model_name: str = MODEL_NAME):
        """
        Load the processor and model named ``model_name`` and move the model
        to ``device``. The model is put in eval mode.
        """
        self.device = device
        self.processor = TrOCRProcessor.from_pretrained(model_name)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def _sequence_confidences(self, gen_out, count: int) -> List:
        """
        Derive one confidence in (0, 1] per generated sequence.

        * Beam search: ``generate`` returns ``sequences_scores`` (a per-sequence
          log-probability score); its exponential is used.
        * Greedy decoding: ``sequences_scores`` is not returned, so the
          per-token log-probabilities are rebuilt from the step ``scores`` with
          ``compute_transition_scores`` and averaged over the real (non-padding)
          tokens before exponentiating, i.e. the geometric mean of the token
          probabilities.

        Returns ``[None] * count`` when neither source of scores is available.
        """
        seq_scores = getattr(gen_out, "sequences_scores", None)
        if seq_scores is not None:
            # Clamp for numerical safety before converting to a probability.
            return torch.exp(torch.clamp(seq_scores, min=-50.0, max=0.0)).tolist()

        step_scores = getattr(gen_out, "scores", None)
        if not step_scores or not hasattr(self.model, "compute_transition_scores"):
            return [None] * count

        token_logprobs = self.model.compute_transition_scores(
            gen_out.sequences, step_scores, normalize_logits=True
        )
        # sequences also contain the decoder start token(s); the generated
        # tokens are the trailing columns matching the number of score steps.
        generated = gen_out.sequences[:, -token_logprobs.shape[1]:]
        keep = torch.isfinite(token_logprobs)
        pad_id = getattr(getattr(self.processor, "tokenizer", None), "pad_token_id", None)
        if pad_id is not None:
            # Positions after a sequence finished are filled with padding and
            # must not influence the confidence.
            keep = keep & (generated != pad_id)
        token_counts = keep.sum(dim=1).clamp(min=1)
        summed = torch.where(keep, token_logprobs, torch.zeros_like(token_logprobs)).sum(dim=1)
        mean_logprob = summed / token_counts
        return torch.exp(torch.clamp(mean_logprob, min=-50.0, max=0.0)).tolist()

    @torch.no_grad()
    def recognize_batch(self, images: List[Image.Image], max_new_tokens: int = 64) -> List[Dict]:
        """
        Recognize a batch of PIL images.

        Args:
            images: crops to recognize; an empty list returns ``[]``.
            max_new_tokens: generation length limit per image.

        Returns:
            One dict per image, in input order: ``{"text", "confidence"}``.
            ``text`` is the decoded string with surrounding whitespace
            stripped. ``confidence`` is an approximate probability in (0, 1]
            (see ``_sequence_confidences``) or ``None`` if no scores were
            available.
        """
        if not images:
            return []
        pixel_values = self.processor(images=images, return_tensors="pt").pixel_values.to(self.device)

        gen_out = self.model.generate(
            pixel_values,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            return_dict_in_generate=True,
            output_scores=True
        )
        texts = self.processor.batch_decode(gen_out.sequences, skip_special_tokens=True)
        confidences = self._sequence_confidences(gen_out, len(texts))

        return [{"text": t.strip(), "confidence": c} for t, c in zip(texts, confidences)]


def visualize(image_bgr: np.ndarray, polys: List[np.ndarray], texts: List[str]) -> np.ndarray:
    """
    Draw each detected polygon and its 1-based index on a copy of the image.

    Args:
        image_bgr: source image; it is not modified.
        polys: polygons (Nx2 arrays, column 0 = x, column 1 = y) in the order
            they should be numbered.
        texts: recognized strings paired with ``polys``. Only the index is
            drawn (to avoid clutter with long text), so ``texts`` just bounds
            how many polygons are drawn via ``zip``.

    Returns:
        The annotated copy of ``image_bgr``.
    """
    vis = image_bgr.copy()
    img_w = vis.shape[1]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 2
    for i, (poly, _txt) in enumerate(zip(polys, texts)):
        pts = poly.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(vis, [pts], isClosed=True, color=(0, 255, 0), thickness=2)

        label = f"#{i+1}"
        (label_w, label_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)
        # putText's origin is the BOTTOM-left corner of the text. Keep it at
        # least one text-height below the top edge, and far enough from the
        # right edge, so labels of regions touching the border stay visible.
        label_x = min(max(int(poly[:, 0].min()), 0), max(img_w - label_w, 0))
        label_y = max(int(poly[:, 1].min()) - 5, label_h)
        cv2.putText(vis, label, (label_x, label_y), font, font_scale, (50, 220, 50), thickness, cv2.LINE_AA)
    return vis

In [27]:
def _write_result_json(items: List[Dict]) -> None:
    """Write ``{"image": IMAGE_DIR, "items": items}`` as JSON to ``OUTPUT_DIR``."""
    os.makedirs(os.path.dirname(OUTPUT_DIR) or ".", exist_ok=True)
    with open(OUTPUT_DIR, "w", encoding="utf-8") as f:
        json.dump({"image": IMAGE_DIR, "items": items}, f, ensure_ascii=False, indent=2)


def main():
    """
    Run the full pipeline on ``IMAGE_DIR``:

    1. detect text regions with CRAFT,
    2. crop each region (in reading order) and recognize it with TrOCR,
    3. write the results as JSON to ``OUTPUT_DIR``,
    4. optionally write an annotated image to ``OUTPUT_IMG_DIR``.

    Raises:
        FileNotFoundError: if the input image cannot be read.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    # Load image
    image_bgr = cv2.imread(IMAGE_DIR)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {IMAGE_DIR}")
    img_h, img_w = image_bgr.shape[:2]
    # Converted once; crops are sliced from this instead of converting each crop.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # 1) Detect text regions. The detector is released as soon as detection
    #    is done (also on error / early return) so its memory is free before
    #    the recognition model is loaded.
    craft = load_detector(use_cuda=use_cuda)
    try:
        detection = detect_regions(IMAGE_DIR, craft)
    finally:
        craft.unload_craftnet_model()
        craft.unload_refinenet_model()

    polys = detection.get("polys", [])
    # `polys` may be a numpy array, whose truth value is ambiguous; test the
    # length explicitly instead of using `not polys`.
    if polys is None or len(polys) == 0:
        print("No text regions detected.")
        _write_result_json([])
        return

    np_polys = [np.array(p, dtype=np.float32) for p in polys]
    idx_order = sort_polys_reading_order(np_polys)
    ordered_polys = [np_polys[i] for i in idx_order]

    # 2) Crop regions and recognize
    pil_crops = []
    meta = []
    for poly in ordered_polys:
        x, y, w, h = poly_to_bbox(poly, margin=3, img_w=img_w, img_h=img_h)
        crop_rgb = image_rgb[y:y+h, x:x+w]
        # Tiny or empty crops are kept in the output (flagged invalid) but
        # are not sent to the recognizer.
        valid = crop_rgb.size > 0 and w >= 4 and h >= 4
        meta.append({"poly": poly.tolist(), "bbox": [x, y, w, h], "valid": valid})
        pil_crops.append(Image.fromarray(np.ascontiguousarray(crop_rgb)) if valid else None)

    # Some crops may be None; filter, run, then map back
    valid_indices = [i for i, im in enumerate(pil_crops) if im is not None]
    valid_images = [pil_crops[i] for i in valid_indices]

    rec_results = []
    if valid_images:
        # Only pay for loading the recognition model when there is work to do.
        recognizer = HandwritingRecognizer(device=device)
        rec_results = recognizer.recognize_batch(valid_images, max_new_tokens=MAX_TOKENS)

    texts_all = [""] * len(pil_crops)
    confs_all = [None] * len(pil_crops)
    for rec, i in zip(rec_results, valid_indices):
        texts_all[i] = rec["text"]
        confs_all[i] = rec["confidence"]

    # 3) Prepare output
    items = [
        {
            "index": i + 1,
            "text": text,
            "confidence": None if conf is None else float(conf),
            "polygon": m["poly"],
            "bbox": m["bbox"],
            "valid": m["valid"],
        }
        for i, (m, text, conf) in enumerate(zip(meta, texts_all, confs_all))
    ]
    _write_result_json(items)

    # 4) Optional visualization
    vis_written = False
    if OUTPUT_IMG_DIR:
        vis_img = visualize(image_bgr, ordered_polys, texts_all)
        # cv2.imwrite reports failure through its return value, not an exception.
        vis_written = bool(cv2.imwrite(OUTPUT_IMG_DIR, vis_img))
        if not vis_written:
            print(f"Warning: could not write visualization image: {OUTPUT_IMG_DIR}")

    print(f"Done. JSON -> {OUTPUT_DIR}" + (f", VIS -> {OUTPUT_IMG_DIR}" if vis_written else ""))

In [28]:
# Run the pipeline defined above: detect text regions in IMAGE_DIR, recognize
# them, and write the JSON results (and, if configured, the annotated image).
main()

ImportError: cannot import name 'model_urls' from 'torchvision.models.vgg' (/home/kronbii/miniconda/envs/medical-ocr/lib/python3.10/site-packages/torchvision/models/vgg.py)