# Non Max Suppression

**Author:** Alan Meeson <alan@carefullycalculated.co.uk>

**Date:** 2023-07-29

Adapting the non-max and sub-section suppression algorithms from my handwriting repo.

In [None]:
import os
import io
import fitz
import torch
import torchvision
import json
import numpy as np
import matplotlib.pyplot as plt
import layoutparser as lp
import pytesseract

from tqdm.notebook import tqdm
from typing import List, Dict, Set, Union
from pyprojroot import here
from PIL import Image

## Declare & Apply the analysis pipeline

In [None]:
# Build the layoutparser Detectron2 layout model from the config and weights
# stored in the `model` directory under the path returned by here().
model = lp.Detectron2LayoutModel(
    config_path=os.path.join(here(), 'model', 'config.yaml'), 
    model_path=os.path.join(here(), 'model', 'model_final.pth'),
    # Config override as a flat [key, value] list; this Detectron2 key is the
    # minimum score a detection needs in order to be reported.
    extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5], 
    # Translate the model's numeric class ids into readable region types.
    label_map={0: "Text", 1: "Title", 2: "List", 3: "Table", 4: "Figure"}
)

## Load a paper and display a page

In [None]:
# Location of the sample paper inside the `data` directory under here().
paper_pdf = os.path.join(here(), 'data', 'Conditional-level-of-students-t-test.pdf')

In [None]:
# Open the document with PyMuPDF (imported as fitz); pages are indexed and
# iterated from this object in the cells below.
pdf = fitz.open(paper_pdf)

In [None]:
# Render the first page to a pixmap, scaled by 2 on both axes.
page = pdf[0]
pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))

# Convert the pixmap's raw samples into a PIL image, using RGBA only when the
# pixmap carries an alpha channel.
mode = "RGBA" if pix.alpha else "RGB"
img = Image.frombytes(mode, [pix.width, pix.height], pix.samples)

# Detect layout regions and draw them over the page image.  As the last
# expression of the cell, the drawn image is what the notebook displays.
layout = model.detect(img)
lp.draw_box(img, layout, box_width=3)

## Suppress Non Max Regions

There are some instances where we will detect two regions for what is likely the same object.  (See the example below).

When this happens we want to keep the box in which we have the greatest confidence.  This is where Non-Max Suppression comes in.

The idea here is to take boxes in order of greatest score and as we do so, we exclude any boxes which significantly overlap with them.

In [None]:
# Draw only detections 7 and 8: the example pair of regions that appear to
# cover the same object on this page.
lp.draw_box(img, [layout[7], layout[8]], box_width=3)

In [None]:
# non-max suppression

# thresh_iou: how much we allow two regions to overlap before we filter them out.
thresh_iou = 0.5

# Declare the list of regions we're going to keep
keep = []

# Sort ascending by score so that pop() always yields the highest-scoring
# region still under consideration.
candidate_regions = sorted(layout, key=lambda x: x.score)
while candidate_regions:

    # extract and keep the region with highest score
    region = candidate_regions.pop()
    keep.append(region)

    # then filter out the remaining regions which significantly overlap with
    # the current region.  We do this by looking at the ratio of the
    # intersection area over the area of region.union(x).
    remaining = []
    for x in candidate_regions:
        overlap = region.intersect(x)

        # When two boxes do not overlap, the intersection comes back with a
        # negative width and/or height.  If BOTH are negative their product
        # (the area) is positive, so the area alone cannot be trusted: only
        # count the overlap when both dimensions are strictly positive.
        if overlap.width > 0 and overlap.height > 0:
            iou = overlap.area / region.union(x).area
        else:
            iou = 0.0

        if iou < thresh_iou:
            remaining.append(x)

    candidate_regions = remaining
         
        

After we apply the above algorithm, we see below that we have eliminated one of the two overlapping regions, retaining the one with the highest score.
Worth noting is that we still have two places where regions overlap, but in each case the smaller region is almost entirely contained within the larger one.
To solve these we look at Sub-section suppression.

In [None]:
# Draw the regions that survived non-max suppression (the `keep` list built above).
lp.draw_box(img, keep, box_width=3)

## Sub-section Suppression

Even after we suppress overlapping boxes, we are still often left with situations where we have one box almost entirely as a subset of another. (See below for two examples)

In this case, particularly when we are dealing with regions of the same type, we want to take the largest box.

Here I have adapted the non-max suppression algorithm to choose the regions with the largest area first, and then exclude those which are within them.

In [None]:
# Draw detections 4, 6, 5 and 9 from the original layout: the examples of one
# box sitting almost entirely inside another.
lp.draw_box(img, [layout[4], layout[6], layout[5], layout[9]], box_width=3)

In [None]:
# Wrap the list of NMS survivors in a Layout, since the next cell calls
# Layout.sort(key=...) on it and uses the sorted result it returns.
sample_layout = lp.Layout(keep)

In [None]:
# sub-section suppression

# tolerance: fraction of the kept region's width / height by which that region
# is grown (half on each side) before testing whether other regions lie inside it.
tolerance = 0.1

keep = []

# Ascending sort by area, so pop() always hands back the largest region left.
candidate_regions = sample_layout.sort(key=lambda x: x.area)
while candidate_regions:

    # take the largest remaining region and keep it
    region = candidate_regions.pop()
    keep.append(region)

    # nothing left to compare against, so we are done
    if not candidate_regions:
        break

    # Grow the kept region a little so that boxes which poke slightly outside
    # of it still count as being contained.
    pad_x = region.width * tolerance / 2
    pad_y = region.height * tolerance / 2
    padded_region = region.pad(left=pad_x, right=pad_x, top=pad_y, bottom=pad_y)

    # carry forward only the regions which are not inside the grown region
    candidate_regions = [x for x in candidate_regions if not x.is_in(padded_region)]
         
        

In [None]:
# Draw the regions that survived sub-section suppression (`keep` was rebuilt above).
lp.draw_box(img, keep, box_width=3)

## Bring it all together

In [None]:
def non_max_suppression(layout: lp.Layout, threshold_iou: float = 0.5) -> lp.Layout:
    """
    Apply non-maximum suppression to avoid detecting too many
    overlapping bounding boxes for a given object.

    Regions are visited from highest to lowest score.  Each visited region is
    kept, and every not-yet-visited region whose overlap ratio with it is
    greater than or equal to `threshold_iou` is discarded.  So for any group
    of heavily overlapping regions, the one with the highest score is kept.

    The overlap ratio is the area of `a.intersect(b)` divided by the area of
    `a.union(b)`.  NOTE: in layoutparser, `union` is documented as the
    rectangle enclosing both regions, so the denominator is the enclosing-box
    area rather than the exact union area.

    Args:
        layout: (layoutparser.Layout) a Layout generated by layoutparser.  Any
            iterable of blocks exposing `score`, `intersect` and `union`
            also works.
        threshold_iou: (float) The overlap ratio at or above which a
            lower-scoring region is suppressed.
    Returns:
        A Layout with the overlapping regions removed, ordered from highest
        to lowest score.
    """

    def overlap_ratio(a, b) -> float:
        # When two boxes do not overlap, the intersection comes back with a
        # negative width and/or height.  If BOTH are negative their product
        # (the area) is positive, so the area alone cannot be trusted: only
        # count the overlap when both dimensions are strictly positive.
        overlap = a.intersect(b)
        if overlap.width <= 0 or overlap.height <= 0:
            return 0.0

        # A positive-area intersection guarantees a positive union area, so
        # this division cannot hit zero.
        return overlap.area / a.union(b).area

    # Declare the list of regions we're going to keep
    keep = []

    # Sort ascending by score so that pop() always yields the highest-scoring
    # region still under consideration.  sorted() returns a new list and
    # leaves the caller's layout untouched.
    candidate_regions = sorted(layout, key=lambda x: x.score)
    while candidate_regions:

        # extract and keep the region with highest score
        region = candidate_regions.pop()
        keep.append(region)

        # then filter out the ones which significantly overlap with the current region
        candidate_regions = [
            x
            for x in candidate_regions
            if overlap_ratio(region, x) < threshold_iou
        ]

    return lp.Layout(keep)
    

In [None]:
def sub_section_suppression(layout: lp.Layout, tolerance: float = 0.1) -> lp.Layout:
    """
    Remove regions which sit (almost) entirely inside a larger region.

    Regions are visited from largest to smallest area.  Each visited region
    is kept; it is then grown slightly, and every not-yet-visited region that
    lies inside the grown box is dropped.

    Args:
        layout: (layoutparser.Layout) a Layout generated by layoutparser
        tolerance: (float) fraction of the kept region's width and height by
            which it is grown before the containment test.  Half of the
            growth is applied to each side.
    Returns:
        A Layout holding only the regions that were kept, ordered from
        largest to smallest area.
    """

    survivors = []

    # Ascending sort by area, so pop() always hands back the largest region left.
    remaining = layout.sort(key=lambda block: block.area)
    while remaining:

        largest = remaining.pop()
        survivors.append(largest)

        # nothing left to compare against
        if not remaining:
            break

        # Grow the kept region so that boxes poking slightly outside of it
        # still count as being contained.
        margin_x = largest.width * tolerance / 2
        margin_y = largest.height * tolerance / 2
        grown = largest.pad(left=margin_x, right=margin_x, top=margin_y, bottom=margin_y)

        # carry forward only the regions that are not inside the grown region
        outside = []
        for block in remaining:
            if block.is_in(grown):
                continue
            outside.append(block)
        remaining = outside

    return lp.Layout(survivors)


### Apply to all pages in PDF to get a nice side by side view

In [None]:
# For every page: render it, run detection, apply both suppression passes and
# show the raw / NMS / NMS + sub-section results next to each other.
for page in pdf:

    # render the page at 2x scale and convert it to a PIL image
    pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
    mode = "RGBA" if pix.alpha else "RGB"
    img = Image.frombytes(mode, [pix.width, pix.height], pix.samples)

    # stage 1: raw detections
    layout = model.detect(img)
    before_img = lp.draw_box(img, layout, box_width=3)

    # stage 2: after non-max suppression
    nms_layout = non_max_suppression(layout)
    nms_img = lp.draw_box(img, nms_layout, box_width=3)

    # stage 3: after sub-section suppression applied to the NMS output
    sss_layout = sub_section_suppression(nms_layout)
    sss_img = lp.draw_box(img, sss_layout, box_width=3)

    # one row of three panels, left to right in pipeline order, axes hidden
    f, axarr = plt.subplots(1, 3, figsize=(24, 8))
    for axis, stage_img in zip(axarr, (before_img, nms_img, sss_img)):
        axis.imshow(stage_img)
        axis.set_axis_off()
    plt.show()

    