In [None]:
from util import get_pil_contract, split_into_rows


image = get_pil_contract("1", "1")
rows = split_into_rows(image, 25)

In [2]:
import cv2
import torch
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoImageProcessor
from tqdm import tqdm

# Initialize processors and model
trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
bollinger_image_processor = AutoImageProcessor.from_pretrained(
    "pstroe/bullinger-general-model"
)
bollinger_model = VisionEncoderDecoderModel.from_pretrained(
    "pstroe/bullinger-general-model"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
bollinger_model = bollinger_model.to(device)


@torch.no_grad()  # Optimization for inference
def process_batch(batch_rois, batch_indices, ocr_proc, htr_model, htr_proc):
    # Convert all ROIs to PIL Images in batch
    roi_pils = [Image.fromarray(roi).convert("RGB") for roi in batch_rois]

    # Process entire batch at once
    pixel_values = htr_proc(
        images=roi_pils, return_tensors="pt", padding=True
    ).pixel_values.to(device)

    generated_ids = htr_model.generate(pixel_values)
    texts = ocr_proc.batch_decode(generated_ids, skip_special_tokens=True)

    return list(zip(texts, [idx + 1 for idx in batch_indices]))


def process_image(image, rows):
    # Collect all ROIs and their corresponding row indices
    rois = []
    row_indices = []

    # Iterate through rows
    for row_idx, row in enumerate(tqdm(rows, desc="Processing rows")):
        for (x, y), contour in row.items():
            x, y, w, h = cv2.boundingRect(contour)
            roi = image[y : y + h, x : x + w]
            rois.append(roi)
            row_indices.append(row_idx)

    # Process in batches
    BATCH_SIZE = 16  # Adjust based on your GPU memory
    batched_rois = [rois[i : i + BATCH_SIZE] for i in range(0, len(rois), BATCH_SIZE)]
    batched_indices = [
        row_indices[i : i + BATCH_SIZE] for i in range(0, len(row_indices), BATCH_SIZE)
    ]

    print(f"\nProcessing {len(rois)} ROIs in {len(batched_rois)} batches...")

    # Use ThreadPoolExecutor for batch processing
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        for batch_roi, batch_idx in zip(batched_rois, batched_indices):
            future = executor.submit(
                process_batch,
                batch_roi,
                batch_idx,
                trocr_processor,
                bollinger_model,
                bollinger_image_processor,
            )
            futures.append(future)

        # Collect results
        results = []
        for future in futures:
            results.extend(future.result())

    print("\nSorting results...")
    return sorted(results, key=lambda x: x[1])  # Sort by row index

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.47.1"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

In [3]:
results = process_image(image, rows)
for text, row_num in results:
    print(f"Row {row_num}: {text}")

Processing rows: 0it [00:00, ?it/s]


AttributeError: 'Image' object has no attribute 'items'