In [1]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

from glob import glob
from tqdm import tqdm
from PIL import Image

from Modules import LineDetection, OCRInference
from Utils import create_dir, get_file_name

In [3]:
def show_image(image: np.ndarray):
    """Render an image array inline in the notebook.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
            NOTE(review): arrays loaded with ``cv2.imread`` are BGR while PIL
            interprets three-channel data as RGB, so colours may look
            swapped - confirm what callers pass in.
    """
    # `display` is not imported in this notebook; it is the built-in that
    # IPython injects into notebook sessions.
    display(Image.fromarray(image))

In [4]:
# JSON configuration files for the line-detection and OCR models.
line_model_config = "Models/LineModels/line_model_config.json"
ocr_model_config = "Models/OCRModels/LhasaKanjur/model_config.json"

# Both inference objects are module-level globals that run_ocr() reads
# directly, so rebinding either name changes what later run_ocr() calls use.
line_inference = LineDetection(
    config_file=line_model_config,
    dilate_kernel=100,
    dilate_iterations=100,
    binarize_output=False,
)
ocr_inference = OCRInference(config_file=ocr_model_config)

In [5]:
def run_ocr(image_path: str, out_path: str, save_preview: bool = False):
    """Run line detection and OCR on one image and write the text to disk.

    The module-level ``line_inference`` and ``ocr_inference`` objects are
    used, so rebinding either of them affects subsequent calls.

    Args:
        image_path: Path of the image file to process.
        out_path: Directory that receives ``<name>.txt`` and, optionally,
            ``<name>_prediction.jpg``, where ``<name>`` is the value returned
            by ``get_file_name(image_path)``. The directory is not created
            here and must already exist.
        save_preview: When True, also save the line prediction blended over
            the input image.

    Raises:
        OSError: If OpenCV cannot read ``image_path``.
    """
    image_name = get_file_name(image_path)

    # cv2.imread signals failure by returning None instead of raising, so
    # fail here with a clear message rather than deep inside the model code.
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"cv2.imread could not read image: {image_path}")

    # Only the prediction mask and the cropped line images are needed; the
    # remaining return values are discarded.
    # NOTE(review): the meaning of the second positional argument (0) is
    # defined by LineDetection.predict - confirm against that module.
    prediction, line_images, _, _ = line_inference.predict(image, 0)
    predicted_text, _ = ocr_inference.run(line_images)

    # One recognised line per row, UTF-8 encoded.
    out_text = os.path.join(out_path, f"{image_name}.txt")
    with open(out_text, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in predicted_text)

    if save_preview:
        # The mask is converted from single-channel to BGR so it can be
        # blended with the colour input image.
        overlay_weight = 0.4
        prediction = cv2.cvtColor(prediction, cv2.COLOR_GRAY2BGR)
        preview = cv2.addWeighted(
            prediction, overlay_weight, image, 1 - overlay_weight, 0
        )
        out_prediction = os.path.join(out_path, f"{image_name}_prediction.jpg")
        cv2.imwrite(out_prediction, preview)

#### Run OCR on Testset 1

In [6]:
# Directory holding the Testset 1 page images. os.path.join keeps the path
# portable and avoids the invalid "\W" escape sequence that a backslash
# string literal would contain.
image_path = os.path.join("Data", "W30125")

# Sorted so the processing order is the same on every run (glob itself
# returns matches in arbitrary order).
images = sorted(glob(os.path.join(image_path, "*.jpg")))
print(f"Image: {len(images)}")

# OCR results are written to a sub-directory next to the input images.
out_path = os.path.join(image_path, "predictions")
create_dir(out_path)

Image: 7


In [7]:
# Run the patched prediction on every Testset 1 image, saving the text and a
# preview overlay. A dedicated loop variable keeps the directory stored in
# `image_path` from being overwritten by the last processed file.
for image_file in tqdm(images):
    run_ocr(image_file, out_path, save_preview=True)

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [01:03<00:00,  9.06s/it]


#### Run OCR on Testset 2

In [23]:
# Rebind the global line detector with a smaller dilation setting; run_ocr()
# reads `line_inference` at call time, so later calls use this instance.
line_inference = LineDetection(
    config_file=line_model_config,
    dilate_kernel=10,
    dilate_iterations=10,
    binarize_output=False,
)

In [24]:
# Directory holding the Testset 2 page images. os.path.join keeps the path
# portable and avoids the invalid "\W" escape sequence that a backslash
# string literal would contain.
image_path = os.path.join("Data", "W2DB4577")

# Sorted so the processing order is the same on every run (glob itself
# returns matches in arbitrary order).
images = sorted(glob(os.path.join(image_path, "*.jpg")))
print(f"Image: {len(images)}")

# OCR results are written to a sub-directory next to the input images.
out_path = os.path.join(image_path, "predictions")
create_dir(out_path)

Image: 6


In [25]:
# Run OCR on every Testset 2 image, saving the text and a preview overlay.
# A dedicated loop variable keeps the directory stored in `image_path` from
# being overwritten by the last processed file.
for image_file in tqdm(images):
    run_ocr(image_file, out_path, save_preview=True)

100%|██████████| 6/6 [00:56<00:00,  9.46s/it]


#### Run OCR on Testset 3

In [8]:
# Rebind the global line detector used by run_ocr() for this test set.
line_inference = LineDetection(
    config_file=line_model_config,
    dilate_kernel=10,
    dilate_iterations=10,
    binarize_output=False,
)

# Directory holding the Testset 3 page images. os.path.join keeps the path
# portable and avoids the invalid "\W" escape sequence that a backslash
# string literal would contain.
image_path = os.path.join("Data", "W26071-v56")

# Sorted so the processing order is the same on every run (glob itself
# returns matches in arbitrary order).
images = sorted(glob(os.path.join(image_path, "*.jpg")))
print(f"Image: {len(images)}")

# OCR results are written to a sub-directory next to the input images.
out_path = os.path.join(image_path, "predictions")
create_dir(out_path)

Image: 6
