In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import glob
import json
import os
import sys
from dataclasses import dataclass
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.ops.boxes as bops
from pytesseract import Output
from pytesseract import pytesseract
from shapely.geometry import Polygon
from thefuzz import fuzz
from torch.utils.data import Dataset, DataLoader
from torchvision import io
from tqdm.notebook import tqdm

# Location of the Tesseract executable used by pytesseract. The TESSERACT_CMD
# environment variable overrides it, so the notebook can run on machines where
# the binary is not installed under /usr/local/bin; without the variable the
# original path is used.
PATH_2_TESSERACT = os.environ.get("TESSERACT_CMD", "/usr/local/bin/tesseract")
pytesseract.tesseract_cmd = PATH_2_TESSERACT

In [None]:
# jupyter notebook formatting
from IPython.display import display, HTML

# Stylesheet that lets the notebook container span the full page width.
full_width_style = "<style>.container { width:100% !important; }</style>"
display(HTML(full_width_style))

# Stylesheet that centers elements carrying the `output_png` class; it is the
# last expression of the cell, so it is rendered as the cell's output.
centered_png_style = """
<style>
.output_png {
    display: table-cell;
    text-align: center;
    vertical-align: middle;
}
</style>
"""
HTML(centered_png_style)

### Classes

In [None]:
@dataclass
class OcrBbox:
    """An OCR bounding box: its rectangle, its text and, for predictions, a confidence.

    The same class is used for ground-truth annotations and for OCR predictions;
    ground-truth boxes leave ``confidence`` at its default of ``None``.
    """
    bbox: tuple  # 4-tuple with (x_left, y_upper, x_right, y_lower)
    text: str  # the text in the prediction/ground truth
    # Confidence of the prediction (leave to None for ground truth). Annotated as
    # a float because run_tesseract stores normalized (raw / 100) values here.
    confidence: Optional[float] = None

In [None]:
class FUNSDDataset(Dataset):
    """Dataset pairing page images (``*.png``) with word-level annotation files (``*.json``).

    Images and annotations are paired by position after sorting both file lists,
    so the two directories must contain the same set of base names.

    Each item is a dictionary with the keys:
        - "image": the page image as returned by OpenCV, converted from BGR to RGB
        - "bboxes": list of OcrBbox, one per word found under ``form[*].words[*]``
          in the annotation JSON (built from each word's "text" and "box" entries)
    """

    def __init__(self, images_path, annotations_path):
        """
        :param images_path: directory containing the ``*.png`` images
        :param annotations_path: directory containing the ``*.json`` annotations
        :raises AssertionError: if the two directories do not hold matching files
        """
        super().__init__()
        self.image_paths = sorted(glob.glob(os.path.join(images_path, "*.png")))
        self.annotation_paths = sorted(glob.glob(os.path.join(annotations_path, "*.json")))

        # verify correctness of data
        assert len(self.image_paths) == len(self.annotation_paths), "ERROR: Must have 1-to-1 image-to-annotation correspondence"
        # splitext strips only the final extension, so base names that contain
        # dots are compared in full instead of being cut at the first dot
        image_indexes = [os.path.splitext(os.path.basename(p))[0] for p in self.image_paths]
        annotation_indexes = [os.path.splitext(os.path.basename(p))[0] for p in self.annotation_paths]
        assert image_indexes == annotation_indexes, "ERROR: At least one image does not have an annotation JSON file."

    def __getitem__(self, idx):
        """
        Load the image and the annotated word boxes at position ``idx``.

        :raises IndexError: if ``idx`` is out of range (this is also what ends
            plain ``for item in dataset`` iteration)
        :raises IOError: if the image file cannot be decoded
        """
        # Index first: an out-of-range idx must surface as IndexError
        image_path = self.image_paths[idx]
        annotation_path = self.annotation_paths[idx]

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising
            raise IOError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # read bounding boxes with their corresponding text annotations;
        # the context manager closes the file handle deterministically
        with open(annotation_path, encoding="utf-8") as annotation_file:
            json_data = json.load(annotation_file)
        bboxes = [OcrBbox(text=w["text"], bbox=tuple(w["box"])) for p in json_data["form"] for w in p["words"]]
        return {"image": image, "bboxes": bboxes}

    def __len__(self):
        """Number of image/annotation pairs."""
        return len(self.image_paths)

### Functions

In [None]:
def run_tesseract(img):
    """
    Run Google Tesseract OCR on the given image and return all the information provided by the OCR system.
    :param img: np array; RGB format image
    :return: a dictionary with the following keys (parallel lists, one entry per row
    returned by ``pytesseract.image_to_data``):
        - "text": list of strings; the text predicted by the OCR system
        - "confidence": list of floats; non-negative confidences are divided by 100,
        negative raw values are passed through without normalization
        - "box": list of 4-tuples containing the bounding box in the format (x0, y0, x1, y1). Defines the
        upper left and lower right points
    """
    d = pytesseract.image_to_data(img, output_type=Output.DICT)
    texts, confidences, boxes = [], [], []
    rows = zip(d['text'], d['conf'], d['left'], d['top'], d['width'], d['height'])
    for text, raw_conf, x, y, w, h in rows:
        # The type of the 'conf' entries is not guaranteed to be numeric (it may
        # arrive as int, float or str depending on the pytesseract version), so
        # coerce before comparing against 0.
        raw_conf = float(raw_conf)
        conf = raw_conf / 100. if raw_conf >= 0 else raw_conf  # normalized confidence

        texts.append(text)
        confidences.append(conf)
        # convert (left, top, width, height) into (x_left, y_top, x_right, y_down)
        boxes.append((x, y, x + w, y + h))
    return {"text": texts, "confidence": confidences, "box": boxes}

In [None]:
def tesseract_output_2_ocr_box(tesseract_out):
    """
    Convert the parallel-list OCR output into a list of OcrBbox objects.

    :param tesseract_out: dictionary with the keys "text", "confidence" and "box",
        each holding a list; entries at the same position describe the same box
    :return: list of OcrBbox, one per position shared by the three lists
    """
    rows = zip(tesseract_out["text"], tesseract_out["confidence"], tesseract_out["box"])
    ocr_boxes = []
    for word, score, rect in rows:
        ocr_boxes.append(OcrBbox(bbox=rect, text=word, confidence=score))
    return ocr_boxes

In [None]:
def coco_bbox_to_polygon(bbox):
    """
    Build a shapely polygon from a bounding box given by two corner points.

    :param bbox: sequence of 4 values (x_left, y_upper, x_right, y_lower), i.e. the
        upper-left point followed by the lower-right point
    :return: shapely Polygon whose shell visits the corners in the order
        upper-left, upper-right, lower-right, lower-left
    """
    # NOTE(review): despite the function name, the input is a corner-pair box,
    # not an (x, y, width, height) box - confirm callers pass corners.
    assert len(bbox) == 4
    x_left, y_upper = bbox[:2]
    x_right, y_lower = bbox[2:]
    corners = [
        [x_left, y_upper],
        [x_right, y_upper],
        [x_right, y_lower],
        [x_left, y_lower],
    ]
    return Polygon(shell=corners)

In [None]:
def calculate_bbox_iou(bb1, bb2):
    """
    Vectorized implementation which calculates the IoU between two lists of bounding boxes.
    
    Parameters
    ----------
    bb1 : list (of length N) of 4-tuples
        Format : [(x1, y1, x2, y2), b2, b3, ...]
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner
    bb2 : list (of length M) of 4-tuples
        Format : [(x1, y1, x2, y2), b2, b3, ...]
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner

    Returns
    -------
    PyTorch tensor of shape NxM. Each entry (i, j) contains the IoU between
    bounding box i from bb1 and bounding box j from bb2. N and/or M may be 0
    when the corresponding input list is empty.
    """
    # An empty list becomes a tensor of shape (0,), which is not a valid (K, 4)
    # box tensor; reshaping keeps non-empty inputs unchanged and turns the empty
    # case into shape (0, 4).
    box1 = torch.tensor(bb1, dtype=torch.float).reshape(-1, 4)
    box2 = torch.tensor(bb2, dtype=torch.float).reshape(-1, 4)
    return bops.box_iou(box1, box2)

In [None]:
def match_pred_to_gt_bboxes(gt_bboxes, pred_bboxes):
    """
    Match every ground-truth bounding box to the predicted bounding box it overlaps most.

    Each ground-truth box is matched independently to the prediction with the
    highest IoU, so the same prediction can be matched to several ground-truth
    boxes and some predictions may not be matched at all.
    
    Parameters
    ----------
    gt_bboxes : list (of length M) of ground-truth OcrBbox bounding boxes
    pred_bboxes : list (of length N) of predicted OcrBbox bounding boxes

    Returns
    -------
    List of dictionaries, one per ground-truth box (in the order of gt_bboxes).
    The list is empty when either input list is empty.
        Keys : {"gt", "pred", "iou"}
        The value at the "gt" key contains a ground-truth OcrBbox bounding box
        The value at the "pred" key contains the predicted OcrBbox with the highest IoU
        The value at the "iou" key contains the IoU (Python float) between these two boxes
    """
    # Without predictions (or ground truth) there is nothing to match, and the
    # max-reduction below is undefined over a zero-sized dimension.
    if len(gt_bboxes) == 0 or len(pred_bboxes) == 0:
        return []

    gt_bboxes_points = [b.bbox for b in gt_bboxes]
    pred_bboxes_points = [b.bbox for b in pred_bboxes]
    iou_matrix = calculate_bbox_iou(gt_bboxes_points, pred_bboxes_points)

    # For every GT row: the best IoU and the column (prediction index) achieving it
    max_iou, max_iou_index = torch.max(iou_matrix, dim=1)
    matched_boxes = [{"gt": gt_bbox, "pred": pred_bboxes[pred_idx], "iou": iou}
                     for gt_bbox, pred_idx, iou in zip(gt_bboxes, max_iou_index.tolist(), max_iou.tolist())]
    return matched_boxes

In [None]:
def pre_process_text(text):
    """
    Normalize a text before it is used for OCR evaluation.

    The only normalization applied is lower-casing: which letters are
    capitalized is not treated as vital information when comparing texts.

    :param text: string to normalize
    :return: the lower-cased string
    """
    return text.lower()

In [None]:
def evaluate_ocr_system(dset, ocr_fn, iou_threshold=None):
    """
    Run a given OCR system on all the images in the given dataset and evaluate the performance
    of the OCR system.
    
    Parameters
    ----------
    dset : PyTorch dataset, items should have the following keys: 
        "image": NumPy array, RGB format, channels last representation
        "bboxes": list of OcrBbox instances; The ground-truth bounding boxes from the given image
    ocr_fn : function which runs an OCR system on a image. Should return a dictionary with the
    following keys:
        "text": list of strings with the predicted text of each bounding box
        "box": list of tuples of 4 integers (so 2 2D points) representing the rectangle of the bounding box
        "confidence": list of floats representing the confidence of the i'th detection
        All the lists from the above keys have the same length.
    iou_threshold : float, optional
        Minimum IoU (exclusive) between a ground-truth box and its matched prediction for the
        text similarity to be counted; at or below it the similarity is 0.
        Defaults to the notebook-level IOU_THRESHOLD constant.
        
    Returns
    -------
    List of lists. Each element list l_i contains the similarity scores (in [0, 1]) from the
    ground-truth bounding boxes in the i'th image. Matches in which either the ground-truth
    text or the predicted text is empty/whitespace are skipped and contribute no score.
    """
    # Resolve the default at call time so the global only needs to exist when
    # the caller does not pass an explicit threshold.
    threshold = IOU_THRESHOLD if iou_threshold is None else iou_threshold

    similarities = []
    for dset_item in tqdm(dset):
        # get the current image and its GT bounding boxes
        image = dset_item["image"]
        bboxes_gt = dset_item["bboxes"]
        
        # run the OCR model on the image
        bboxes_dict = ocr_fn(image)
        bboxes_pred = tesseract_output_2_ocr_box(bboxes_dict)
        
        # match the GT bounding boxes with the PRED bounding boxes
        matching = match_pred_to_gt_bboxes(gt_bboxes=bboxes_gt, pred_bboxes=bboxes_pred)
        
        img_level_similarities = []
        for m in matching:
            gt_text, pred_text = m["gt"].text, m["pred"].text
            # skip matches where one of the two texts is blank
            if not gt_text.strip() or not pred_text.strip():
                continue
            
            # pre_process_text already lower-cases, so no further normalization is needed
            gt_text, pred_text = pre_process_text(gt_text), pre_process_text(pred_text)
            
            # consider similarity ratio only if the IoU is above a threshold and thus the GT text was recognized
            if m["iou"] > threshold:
                img_level_similarities.append(fuzz.ratio(pred_text, gt_text) / 100.)
            else:
                img_level_similarities.append(0)
        similarities.append(img_level_similarities)
    return similarities

### Dataset and constants

In [None]:
# Root directory of the FUNSD dataset (must contain "training_data" and
# "testing_data", each with "images" and "annotations" sub-directories).
# Set the FUNSD_ROOT environment variable to run the notebook on another
# machine; without it the original location is used.
FUNSD_ROOT = os.environ.get(
    "FUNSD_ROOT",
    "/Users/bogdan.vlad/Documents/code/mac_notebooks/personal/ocr_data/funsd_dataset",
)
funsd_train_dir = os.path.join(FUNSD_ROOT, "training_data")
funsd_test_dir = os.path.join(FUNSD_ROOT, "testing_data")

train_dset = FUNSDDataset(images_path=os.path.join(funsd_train_dir, "images"), 
                          annotations_path=os.path.join(funsd_train_dir, "annotations"))
test_dset = FUNSDDataset(images_path=os.path.join(funsd_test_dir, "images"), 
                          annotations_path=os.path.join(funsd_test_dir, "annotations"))

# A GT box and its matched prediction must overlap with an IoU strictly above
# this value for their text similarity to be counted (otherwise it is 0).
IOU_THRESHOLD = 0.25

### Visualize ground truth bounding boxes vs. predicted bounding boxes

In [None]:
# take one training sample: the page image and its ground-truth boxes
dset_item = train_dset[2]
image, bboxes_gt = dset_item["image"], dset_item["bboxes"]

# run the OCR on the page and keep only predictions whose text is not blank
bboxes_dict = run_tesseract(image)
bboxes_pred = [b for b in tesseract_output_2_ocr_box(bboxes_dict) if len(b.text.strip())]

# draw each set of boxes (green, 1px) on its own copy so `image` stays untouched
copy_gt = np.copy(image)
copy_pred = np.copy(image)
for canvas, boxes in ((copy_gt, bboxes_gt), (copy_pred, bboxes_pred)):
    for bbox in boxes:
        cv2.rectangle(canvas, bbox.bbox[:2], bbox.bbox[2:], (0, 255, 0), 1)

In [None]:
# one row with three panels: untouched page, page with GT boxes, page with predicted boxes
fig, axes = plt.subplots(1, 3, figsize=(24, 8))
panels = [
    (image, "Original"),
    (copy_gt, "GT bounding boxes"),
    (copy_pred, "PRED bounding boxes"),
]
for ax, (panel_image, panel_title) in zip(axes, panels):
    ax.imshow(panel_image)
    ax.axis("off")
    ax.set_title(panel_title)

plt.show()

### Visualize how the matching algorithm works

In [None]:
dset_item = train_dset[1]
image = dset_item["image"]
bboxes_gt = dset_item["bboxes"]

# run the OCR model on the image
bboxes_dict = run_tesseract(image)
bboxes_pred = tesseract_output_2_ocr_box(bboxes_dict)

# match the GT bounding boxes with the PRED bounding boxes, drop matches where
# either text is blank (spaces and tabs only), and order them by decreasing IoU
matching = match_pred_to_gt_bboxes(gt_bboxes=bboxes_gt, pred_bboxes=bboxes_pred)
filtered_matching = sorted(
    (m for m in matching if len(m["gt"].text.strip()) and len(m["pred"].text.strip())),
    key=lambda x: x["iou"],
    reverse=True,
)

# one small figure per match: GT crop on the left, predicted crop on the right
for m in filtered_matching:
    gt_bbox, pred_bbox, iou = m["gt"], m["pred"], m["iou"]

    gt_points, pred_points = gt_bbox.bbox, pred_bbox.bbox
    gt_text = pre_process_text(gt_bbox.text)
    pred_text = pre_process_text(pred_bbox.text)

    # crop both rectangles from the page; the +1 makes the far edge inclusive
    gt_x0, gt_y0, gt_x1, gt_y1 = gt_points
    pred_x0, pred_y0, pred_x1, pred_y1 = pred_points
    crop_gt = image[gt_y0 : gt_y1 + 1, gt_x0 : gt_x1 + 1, :]
    crop_pred = image[pred_y0 : pred_y1 + 1, pred_x0 : pred_x1 + 1, :]

    # the text similarity only counts when the boxes overlap enough
    similarity = fuzz.ratio(gt_text, pred_text) / 100. if iou > IOU_THRESHOLD else 0

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(4, 2))
    for ax, crop, title in ((ax1, crop_gt, "GT Bounding Box"), (ax2, crop_pred, "PRED Bounding Box")):
        ax.imshow(crop)
        ax.axis("off")
        ax.set_title(title, fontdict={"fontsize": 10})

    fig.subplots_adjust(top=0.6, wspace=0.5)
    fig.suptitle(f'GT: "{gt_text}"\nPRED: "{pred_text}"\nSIMILARITY: {similarity:.2f}', fontsize=11)
    plt.show()

### Evaluate the OCR system on the original images

In [None]:
similarities = evaluate_ocr_system(test_dset, run_tesseract)

# flatten the per-image lists into a single list of word-level similarities
similarities_flattened = []
for similarity_list in similarities:
    similarities_flattened.extend(similarity_list)

mean_similarity = np.mean(similarities_flattened)
median_similarity = np.median(similarities_flattened)
print(f"Mean similarity score: {mean_similarity:.2f}")
print(f"Median similarity score: {median_similarity:.2f}")