In [None]:
"""
tk_detect_rows.ipynb

Created on Wed Nov 09 15:01:03 2022

@author: Lukas

This script is used to detect rows with text given a deskewed Teikoku page.
"""

# install Pytorch and Detectron2
#
# Versions are pinned: torch 1.5 / torchvision 0.6 wheels built for CUDA 10.1 (cu101)
# and the detectron2 0.1.3 wheel from the cu101 wheel index, so both target the same
# CUDA build. The COCO Python API is installed from the cocoapi repo (PythonAPI dir).
# NOTE(review): this uninstalls the runtime's preinstalled torch/torchvision, and the
# recorded output of this cell ends in an error — confirm the install succeeds (and
# restart the runtime if needed) before running the cells below.

!pip install -U torch==1.5 torchvision==0.6 -f https://download.pytorch.org/whl/cu101/torch_stable.html
!pip install cython pyyaml==5.1
!pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
!pip install detectron2==0.1.3 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/index.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://download.pytorch.org/whl/cu101/torch_stable.html
Collecting torch==1.5
  Downloading https://download.pytorch.org/whl/cu101/torch-1.5.0%2Bcu101-cp37-cp37m-linux_x86_64.whl (703.8 MB)
[K     |████████████████████████████████| 703.8 MB 20 kB/s 
[?25hCollecting torchvision==0.6
  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.6.0%2Bcu101-cp37-cp37m-linux_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 55.0 MB/s 
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.12.1+cu113
    Uninstalling torch-1.12.1+cu113:
      Successfully uninstalled torch-1.12.1+cu113
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.13.1+cu113
    Uninstalling torchvision-0.13.1+cu113:
      Successfully uninstalled torchvision-0.13.1+cu113
[31mERR

In [None]:
# import packages

import torch, torchvision
import detectron2

from detectron2.utils.visualizer import ColorMode
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.visualizer import Visualizer
from detectron2.engine import DefaultPredictor
from detectron2.utils.logger import setup_logger
from detectron2.structures import BoxMode

import os
import numpy as np
import json
import cv2
import matplotlib.pyplot as plt
import pickle

from skimage import io
from skimage.transform import resize
from skimage import img_as_bool

In [None]:
# Mount Google Drive so the model config, input scans and output folder are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# load the model: unpickle the saved detectron2 config and build a predictor from it

filename = '/content/drive/MyDrive/row_detection_with_blanks/comb_output/config.pkl'

# NOTE(review): pickle.load can execute arbitrary code — only load a config file
# from a trusted source.
with open(filename, mode='rb') as config_file:
    cfg = pickle.load(config_file)

predictor = DefaultPredictor(cfg)

In [None]:
# apply the model to a given input image and return the bounding boxes

def get_text_rows(image, text_class_id=4):
    """
    Run the row-detection model on one image and return the boxes of the text rows.

    Parameters
    ----------

    image : string
        Path to the image file; it is read with ``cv2.imread``.

    text_class_id : int, optional
        Predicted class id that marks a text row. Defaults to 4, the id that was
        previously hard-coded here.

    Returns
    -------

    boxes : list
        One entry per detection of class ``text_class_id``, each a list of four
        integer coordinates ``[x1, y1, x2, y2]`` (truncated with ``int``).

    Raises
    ------

    ValueError
        If ``cv2.imread`` cannot read ``image`` (missing file, not an image, ...).

    Notes
    -----
    Uses the module-level ``predictor`` built when the model config is loaded.
    """
    image_array = cv2.imread(image)
    # cv2.imread signals failure by returning None instead of raising; fail loudly
    # here instead of deep inside the predictor with an opaque TypeError.
    if image_array is None:
        raise ValueError(f"Could not read image: {image}")

    instances = predictor(image_array)["instances"]
    pred_boxes = instances.pred_boxes.to('cpu')
    pred_classes = instances.pred_classes.to('cpu')

    boxes = [
        [int(coord) for coord in box]
        for box, category in zip(pred_boxes, pred_classes)
        if category == text_class_id
    ]

    # progress/debug output for the notebook
    print(boxes)
    return boxes

In [None]:
# convert one box from [x1, y1, x2, y2] to the output format [center, size, angle]

def convert_to_output_format(box):
    """
    Convert a single axis-aligned box to the ``[center, size, angle]`` format.

    Parameters
    ----------

    box : sequence of 4 numbers
        One bounding box ``[x1, y1, x2, y2]`` where (x1, y1) is the top-left
        corner and (x2, y2) is the bottom-right corner.

    Returns
    -------

    output : list
        ``[[cx, cy], [width, height], angle]`` for that one box. ``cx``/``cy`` are
        the midpoints computed with true division (so floats). ``angle`` is always 0
        because the input box is axis-aligned. This is the (center, size, angle)
        layout that OpenCV's RotatedRect / ``cv2.boxPoints`` uses.
    """
    x1, y1, x2, y2 = box
    center = [(x1 + x2) / 2, (y1 + y2) / 2]
    size = [x2 - x1, y2 - y1]
    # Axis-aligned input box, so there is no rotation.
    angle = 0
    return [center, size, angle]

In [None]:
# auxiliary function that takes an image path and returns the parameters for all rectangles in a list

def get_rect_json(image, as_rotated_rect=False):
    """
    Detect the text rows in one image and return them as a JSON-serialisable list.

    Parameters
    ----------

    image : string
        Path to the image file. It is passed straight to ``get_text_rows``,
        which reads it with ``cv2.imread``.

    as_rotated_rect : bool, optional
        If False (default, the previous behaviour), each rectangle is returned as
        the ``[x1, y1, x2, y2]`` list produced by ``get_text_rows``. If True, each
        rectangle is converted with ``convert_to_output_format`` to
        ``[[cx, cy], [width, height], angle]``.

    Returns
    -------

    rect_json : list
        One entry per detected text row, in the format selected above.
    """
    boxes = get_text_rows(image)
    if as_rotated_rect:
        rect_json = [convert_to_output_format(box) for box in boxes]
    else:
        rect_json = list(boxes)

    # No print here: get_text_rows already prints these same boxes.
    return rect_json

In [None]:
# main function that applies get_rect_json to all images in a folder and saves the output in a json file

def main(image_path, output_path):
    """
    Apply get_rect_json to every file in a folder and save one JSON file per image.

    Parameters
    ----------

    image_path : string
        The file path of the input image directory. Sub-directories are skipped.

    output_path : string
        The file path of the output directory. It is created if it does not exist.

    Returns
    -------

    None (json files are saved in the output folder, named after the input file
    without its final extension). Images with no detected rows produce no file.

    """
    # Create the output folder up front so the first json write cannot fail on it.
    os.makedirs(output_path, exist_ok=True)

    # Sort for a reproducible processing order; os.listdir order is arbitrary.
    for image_name in sorted(os.listdir(image_path)):
        full_path = os.path.join(image_path, image_name)
        # Folders cannot be read as images; skip them instead of crashing.
        if not os.path.isfile(full_path):
            continue

        print(image_name)
        # splitext drops only the final extension ("a.b.png" -> "a.b"); the old
        # split(".")[0] gave "a" and could make different outputs collide.
        name = os.path.splitext(image_name)[0]
        rect_json = get_rect_json(full_path)
        if not rect_json:
            continue

        with open(os.path.join(output_path, name + '.json'), 'w') as f:
            json.dump(rect_json, f)

In [None]:
# Run the row detector over every deskewed scan and write one JSON file per page.
image_path = '/content/drive/MyDrive/deskew_scans_pipeline/Deskewed_Scans'
output_path = '/content/drive/MyDrive/row_detection_model/test_output(11 11 22)'

main(image_path=image_path, output_path=output_path)

dk_TK1935_90_3678_0.png
[[14, 446, 1497, 852], [14, 1233, 1497, 1636], [6, 26, 1494, 458], [4, 842, 1498, 1245], [0, 1624, 1490, 2048]]
[[14, 446, 1497, 852], [14, 1233, 1497, 1636], [6, 26, 1494, 458], [4, 842, 1498, 1245], [0, 1624, 1490, 2048]]
dk_TK1935_90_3678_1.png
[[13, 1215, 1488, 1626], [13, 427, 1488, 836], [0, 820, 1504, 1226], [0, 13, 1489, 440], [9, 1613, 1496, 2034]]
[[13, 1215, 1488, 1626], [13, 427, 1488, 836], [0, 820, 1504, 1226], [0, 13, 1489, 440], [9, 1613, 1496, 2034]]
dk_TK1934_1247_c76a_0.png
[]
[]
dk_TK1934_1247_c76a_1.png
[[2, 275, 1520, 532], [3, 21, 1496, 285], [4, 526, 1515, 782], [0, 1782, 1502, 2047], [0, 1532, 1494, 1785], [0, 777, 1513, 1033], [0, 1280, 1502, 1534], [0, 1028, 1496, 1282]]
[[2, 275, 1520, 532], [3, 21, 1496, 285], [4, 526, 1515, 782], [0, 1782, 1502, 2047], [0, 1532, 1494, 1785], [0, 777, 1513, 1033], [0, 1280, 1502, 1534], [0, 1028, 1496, 1282]]
dk_TK1934_870_0ee9_0.png
[[0, 17, 1498, 364], [0, 1347, 1497, 1696], [2, 1020, 1489, 1359], 

In [None]:
# for each rectangle in an input json, check if its center is included in another rectangle

def _as_rotated_rect(rect):
    """Normalise one JSON rectangle to ``((cx, cy), (w, h), angle_in_degrees)``."""
    if len(rect) == 4:
        # [x1, y1, x2, y2] axis-aligned box
        x1, y1, x2, y2 = rect
        return ((x1 + x2) / 2, (y1 + y2) / 2), (x2 - x1, y2 - y1), 0.0
    if len(rect) == 3:
        # [[cx, cy], [w, h], angle]
        (cx, cy), (w, h), angle = rect
        return (cx, cy), (w, h), angle
    raise ValueError(f"Unrecognised rectangle format: {rect!r}")


def _point_strictly_inside(px, py, rect):
    """Return True if (px, py) lies strictly inside the rotated rectangle ``rect``."""
    (cx, cy), (w, h), angle = rect
    theta = np.deg2rad(angle)
    dx, dy = px - cx, py - cy
    # Project onto the width axis (cos, sin) and the height axis (-sin, cos); these
    # are the edge directions cv2.boxPoints uses when it builds the corners.
    along_width = dx * np.cos(theta) + dy * np.sin(theta)
    along_height = -dx * np.sin(theta) + dy * np.cos(theta)
    # Strict inequality mirrors pointPolygonTest(...) == 1: a point exactly on an
    # edge is not "inside", so rows that merely touch are not flagged.
    return bool(abs(along_width) < w / 2 and abs(along_height) < h / 2)


def check_overlap(json_path):
    """
    Check, for each rectangle in an input json, whether its center lies inside another rectangle.

    Parameters
    ----------

    json_path : string
        The file path of the json file. It must hold a list of rectangles, each
        either ``[x1, y1, x2, y2]`` (the format main writes by default) or
        ``[[cx, cy], [w, h], angle]`` with the angle in degrees.

    Returns
    -------

    None if no rectangle's center lies strictly inside a *different* rectangle.

    Raises
    ------

    ValueError
        If an overlap is detected (the message names the two rectangle indices),
        or if an entry is in neither supported format.

    """
    with open(json_path, 'r') as f:
        rect_json = json.load(f)

    rects = [_as_rotated_rect(rect) for rect in rect_json]
    for i, ((cx, cy), _, _) in enumerate(rects):
        for j, other in enumerate(rects):
            # A rectangle always contains its own center, so comparing it with
            # itself would report an overlap for every non-empty file.
            if i != j and _point_strictly_inside(cx, cy, other):
                raise ValueError(
                    f"Overlap detected: center of rectangle {i} lies inside rectangle {j}"
                )
    return