# End to End PDF Annotation Demo

**Author:** Alan Meeson <alan@carefullycalculated.co.uk>

**Date:** 2023-04-16

This notebook applies the custom-trained handwriting region detection model (based on Faster R-CNN) and the TrOCR handwriting OCR model to detect and parse handwriting.

In addition, we apply Non-Max Suppression and a custom removal of sub-set regions (boxes almost entirely contained in another box) to identify the key text regions to process.

The output is a JSON file of region coordinates and parsed text, plus an annotated PDF in which the parsed text is overlaid as invisible text on top of the original handwriting.

## Setup

In [None]:
import os
import json
import numpy as np
import fitz
import cv2
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

## Declare functions

In [None]:
def get_text_region_model(
    num_classes: int = 1, weights_file: str = None, for_eval: bool = True
) -> "tuple[torch.nn.Module, torch.device]":
    """Build the Faster R-CNN text-region detector and optionally load weights.

    Args:
        num_classes: number of classes for the replacement box-predictor head.
        weights_file: optional path to a saved state dict. When falsy, the
            model keeps torchvision's default pre-trained weights with a
            freshly created predictor head.
        for_eval: when True, the model is switched to evaluation mode.
    Returns:
        A ``(model, device)`` tuple. ``device`` is CUDA when available,
        otherwise CPU, and the model has been moved onto it.
    """
    # Start from torchvision's Faster R-CNN with its default pre-trained weights.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # Swap the pre-trained box predictor for one sized to our class count.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Choose the device up-front so it is always defined for the return value,
    # including when no weights file is supplied.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if weights_file:
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(weights_file, map_location=device))

    model.to(device)

    if for_eval:
        model.eval()

    return model, device

In [None]:
def nms_pytorch(boxes: torch.Tensor, scores: torch.Tensor, thresh_iou: float = 0.5) -> list:
    """
    Apply non-maximum suppression to avoid detecting too many
    overlapping bounding boxes for a given object.

    Boxes are visited from highest to lowest score. Each visited box is kept,
    and every remaining box whose IoU with it is not below ``thresh_iou`` is
    discarded.

    Args:
        boxes: (tensor) box corners as (x1, y1, x2, y2) rows, [num_boxes, 4]
        scores: (tensor) the score of each box; indexed as a 1-D tensor of
            length num_boxes
        thresh_iou: (float) the overlap thresh for suppressing unnecessary boxes.
    Returns:
        A Python list with one ``[x1, y1, x2, y2, score]`` entry per kept box,
        ordered by descending score. Each element is a 0-d tensor.
    """

    #TODO: replace this with the torchvision version
    # https://pytorch.org/vision/master/generated/torchvision.ops.nms.html

    # area of every candidate box
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    # candidate indices sorted by ascending score: the best one is last
    remaining = scores.argsort()

    keep = []

    while len(remaining) > 0:
        # pop the highest scoring candidate and keep it
        best = remaining[-1]
        remaining = remaining[:-1]
        best_box = boxes[best, :4]
        keep.append([*best_box, scores[best]])

        # nothing left to compare against
        if len(remaining) == 0:
            break

        others = boxes[remaining]

        # corners of the intersection between the kept box and each other box
        inter_x1 = torch.max(others[:, 0], best_box[0])
        inter_y1 = torch.max(others[:, 1], best_box[1])
        inter_x2 = torch.min(others[:, 2], best_box[2])
        inter_y2 = torch.min(others[:, 3], best_box[3])

        # clamp at zero: non-overlapping boxes give negative width/height
        inter_w = torch.clamp(inter_x2 - inter_x1, min=0.0)
        inter_h = torch.clamp(inter_y2 - inter_y1, min=0.0)
        inter = inter_w * inter_h

        # union = area(other) + area(best) - intersection
        union = (areas[remaining] - inter) + areas[best]
        iou = inter / union

        # only candidates that overlap the kept box less than the threshold survive
        remaining = remaining[iou < thresh_iou]

    return keep

In [None]:
def sss_pytorch(boxes: torch.Tensor, thresh_overlap: float = 0.9) -> torch.Tensor:
    """
    Apply sub-set suppression to avoid detecting too many
    overlapping bounding boxes for a given object.
    This specifically removes boxes which are (almost) entirely contained within
    another box.

    Boxes are visited from largest to smallest area. The largest box is always
    kept. Each later box is kept only if the fraction of its own area covered
    by every already-kept box is below ``thresh_overlap``.

    Args:
        boxes: (tensor) rows of (x1, y1, x2, y2, score), [num_boxes, 5]
        thresh_overlap: (float) the overlap thresh for suppressing unnecessary boxes.
    Returns:
        The kept rows of ``boxes`` in their original order, Shape: [ , 5]
    """
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1) * (y2 - y1)

    # largest box first
    order = areas.argsort().flip(0)

    # boolean mask of the boxes retained so far, on the same device as boxes
    keep = torch.zeros(boxes.shape[0], dtype=torch.bool, device=boxes.device)

    if order.numel() > 0:
        # nothing can contain the largest box, so it is kept unconditionally
        keep[order[0]] = True

    # Every remaining box is evaluated, including the smallest one; the
    # previous early "sanity check" break skipped the final box entirely.
    for idx in order[1:]:
        # intersection of this box with each box kept so far
        inter_w = torch.clamp(
            torch.min(x2[keep], x2[idx]) - torch.max(x1[keep], x1[idx]), min=0.0
        )
        inter_h = torch.clamp(
            torch.min(y2[keep], y2[idx]) - torch.max(y1[keep], y1[idx]), min=0.0
        )

        # fraction of this box's own area that lies inside each kept box
        overlap = (inter_w * inter_h) / areas[idx]

        keep[idx] = bool((overlap < thresh_overlap).all())

    return boxes[keep, :]

In [None]:
def parse_image(model, device, image: Image.Image):
    """Applies the text region detection to an Image.

    The detections for class 1 are filtered with non-max suppression
    (IoU threshold 0.25) and then sub-set suppression (overlap threshold 0.9).

    Args:
        model: the model to apply
        device: the device the model is on, so we can send image there.
        image: the PIL image to process
    Returns:
        A list of dicts containing class label, confidence score and coordinates
        (``{'label', 'score', 'box': {'x1', 'y1', 'x2', 'y2'}}``), with the
        coordinates truncated to ints.
    """

    # TODO: refactor,  device should probably not be in here.
    #  Perhaps wrap this up in a class?

    # Scale pixel values to [0, 1] and move the tensor to the model's device.
    image_tensor = (F.pil_to_tensor(image) / 255).to(device)

    # Pure inference: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        prediction = model([image_tensor])[0]

    # Keep only detections of the target class. The suppression steps below
    # are Python loops over individual boxes, so run them on CPU tensors.
    target_class = 1
    mask = prediction['labels'] == target_class
    target_boxes = prediction['boxes'][mask, :].cpu()
    target_scores = prediction['scores'][mask].cpu()

    if target_boxes.shape[0] > 0:
        # nms_pytorch returns a list of [x1, y1, x2, y2, score] rows
        filtered_boxes = nms_pytorch(target_boxes, target_scores, thresh_iou=0.25)
        filtered_boxes = torch.as_tensor(filtered_boxes)
        filtered_boxes = sss_pytorch(filtered_boxes, thresh_overlap=0.9)
    else:
        # no detections: the empty tensor yields an empty result list below
        filtered_boxes = target_boxes

    # reformat results
    return [
        {
            'label': target_class,
            'score': float(row[4]),
            'box': {
                'x1': int(row[0]),
                'y1': int(row[1]),
                'x2': int(row[2]),
                'y2': int(row[3])
            }
        }
        for row in filtered_boxes
    ]

In [None]:
def apply_ocr(ocr_model, processor, image, boxes, device=None):
    """Runs the TrOCR model over every detected region of an image.

    Each entry of ``boxes`` is updated in place with a ``'text'`` field
    holding whatever ``processor.batch_decode`` returns for that region.

    Args:
        ocr_model - the trocr model to apply
        processor - preprocessing / decoding companion of the model
        image - the image to crop the regions from
        boxes - the list of dicts describing the identified text regions
        device - the device the model is on; when truthy the pixel values
            are moved there before generation
    Return:
        The same ``boxes`` list, with the parsed text added to each entry.
    """

    # TODO: refactor

    for entry in boxes:
        coords = entry['box']
        corners = (coords['x1'], coords['y1'], coords['x2'], coords['y2'])
        crop = image.crop(corners)

        # preprocess the crop into the tensor the model consumes
        encoded = processor(images=crop, return_tensors="pt")
        pixel_values = encoded.pixel_values
        if device:
            pixel_values = pixel_values.to(device)

        # generate token ids, then decode them back to text
        token_ids = ocr_model.generate(pixel_values)
        entry['text'] = processor.batch_decode(token_ids, skip_special_tokens=True)

    return boxes

In [None]:
def add_text_annotations(page, results, zoom: float = 2):
    """Annotates a PDF page with the parsed text.

    The text is written in render mode 3 (hidden), positioned at the
    bottom-left corner of each detected box.

    Args:
        page - the PyMuPDF page to annotate
        results - the boxes and text to add; each entry needs a 'box' dict
            (x1, y1, x2, y2 in rendered-image pixels) and a 'text' sequence
            whose first item is the string to write
        zoom - the zoom factor the page was rendered at when the boxes were
            detected; box coordinates are divided by it to get back to page
            coordinates. Defaults to 2, the factor used by this notebook.
    Raises:
        ValueError: if zoom is not positive.
    """
    if zoom <= 0:
        raise ValueError("zoom must be positive")

    tw = fitz.TextWriter(page.rect)  # need the intended page's size here

    font = fitz.Font("helv")

    for box in results:
        coords = box['box']

        # Insertion point: bottom-left corner of the box in page coordinates.
        pos = (coords['x1'] / zoom, coords['y2'] / zoom)
        text = box['text'][0]

        # Font size is half the box height, measured in page coordinates.
        fontsize = (coords['y2'] - coords['y1']) / (2 * zoom)

        tw.append(
            pos,  # the insertion point
            text,
            font=font,  # a fitz.Font(...) object
            fontsize=fontsize,
        )

    # write the whole text writer as hidden (render mode 3) text.
    tw.write_text(page, render_mode=3)

## Declare Data

In [None]:
# PDF to process; the path is relative to the working directory.
input_file = '../data/raw/demo.pdf'
# Weights file passed to get_text_region_model for the text-region detector.
model_file = '../models/handwriting_region.pth'

## Load data and Models

In [None]:
# TODO: download separately to allow no-internet deploy

# Text-region detector (2 classes, weights from model_file) and the device it lives on.
td_model, device = get_text_region_model(weights_file=model_file, num_classes=2)

# TrOCR processor and model come from the same pretrained checkpoint name;
# the model is moved to the detector's device and put in evaluation mode.
ocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
ocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
ocr_model.eval()

## Apply the models as a Demo

In [None]:
# Render the first page of the input PDF at 2x zoom and detect its text regions.
zoom = 2    # zoom factor
mat = fitz.Matrix(zoom, zoom)

input_doc = fitz.open(input_file)
page = input_doc.load_page(0)
pix = page.get_pixmap(matrix=mat)

# Wrap the rendered samples in a PIL image for the detector.
image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
results = parse_image(td_model, device, image)

# Last expression of the cell: displays the detections.
results

### Plot the detected text regions

In [None]:
num_boxes = len(results)

# Draw on a numpy copy so the PIL image itself stays untouched for later cells.
img = np.array(image)

for detection in results:
    coords = detection['box']
    top_left = (int(coords['x1']), int(coords['y1']))
    bottom_right = (int(coords['x2']), int(coords['y2']))

    # The array is in RGB order, so (0,0,255) draws a blue outline.
    cv2.rectangle(img, top_left, bottom_right, color=(0,0,255),thickness=3)

plt.figure(figsize=(16,8))
plt.imshow(img)

### Apply the OCR step and visualise

In [None]:
# Reuse apply_ocr instead of repeating its crop -> preprocess -> generate ->
# decode loop here. It adds a 'text' entry to each box in place and returns
# the same list.
results = apply_ocr(ocr_model, ocr_processor, image, results, device)

In [None]:
num_samples = len(results)

if num_samples > 0:
    # squeeze=False keeps axarr two-dimensional even when there is a single
    # region; otherwise plt.subplots returns a bare Axes that cannot be indexed.
    f, axarr = plt.subplots(num_samples, 1, figsize=(8, num_samples), squeeze=False)

    # Show the regions ordered by their top edge (y1).
    sorted_index = np.argsort([box['box']['y1'] for box in results])
    for ax, idx in zip(axarr[:, 0], sorted_index):
        box = results[idx]
        x1 = box['box']['x1']
        y1 = box['box']['y1']
        x2 = box['box']['x2']
        y2 = box['box']['y2']
        region_of_interest = image.crop((x1, y1, x2, y2))

        # box['text'] is the decoded batch; its first item is the string for
        # this region (the same element add_text_annotations writes out).
        ax.title.set_text(box['text'][0])
        ax.imshow(region_of_interest)
        ax.axis('off')

    plt.show()
else:
    # plt.subplots cannot build a grid with zero rows.
    print("No text regions detected; nothing to plot.")

## Apply to a PDF

In [None]:
# Directory that receives the generated files (named after the input file).
output_path = '../data/output'
# Toggle writing the annotated PDF.
output_pdf = True
# Toggle writing the JSON file of regions and parsed text.
output_json = True

In [None]:
# Output files share the input's base name (without directory or extension).
basename = os.path.splitext(os.path.basename(input_file))[0]

# Make sure the destination exists before anything is written to it.
if output_pdf or output_json:
    os.makedirs(output_path, exist_ok=True)

input_doc = fitz.open(input_file)
output_doc = fitz.open() if output_pdf else None

results = []

# NOTE: add_text_annotations maps box coordinates back to page coordinates
# assuming the page was rendered at zoom 2, so keep this value in sync with it.
zoom = 2    # zoom factor
mat = fitz.Matrix(zoom, zoom)

try:
    for page_no, page in enumerate(input_doc):
        # Render the page, detect text regions, then OCR each region.
        pix = page.get_pixmap(matrix=mat)
        image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        page_results = parse_image(td_model, device, image)
        page_results = apply_ocr(ocr_model, ocr_processor, image, page_results, device)
        results.append(page_results)

        if output_pdf:
            # Copy the original page into the output, then overlay the hidden text.
            out_page = output_doc.new_page(width=page.rect.width, height=page.rect.height)
            out_page.show_pdf_page(page.rect, input_doc, page_no)
            add_text_annotations(out_page, page_results)

        print("Processed Page: %d" % page_no)

    if output_json:
        # One list of region dicts per page, in page order.
        output_json_file = os.path.join(output_path, basename + '.json')
        with open(output_json_file, 'w', encoding='utf-8') as fp:
            json.dump(results, fp)

    if output_pdf:
        output_pdf_file = os.path.join(output_path, basename + '.pdf')
        output_doc.save(output_pdf_file)
finally:
    # Release the documents even if processing a page fails.
    if output_doc is not None:
        output_doc.close()
    input_doc.close()