<a href="https://colab.research.google.com/github/Tessellate-Imaging/Monk_Object_Detection/blob/master/application_model_zoo/Example%20-%20Document%20Layout%20Analysis%20(SSD512).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Document Layout Analysis Using SSD

## About the network:
1. Paper on SSD: https://arxiv.org/abs/1512.02325

2. Blog-1 on SSD: https://towardsdatascience.com/review-ssd-single-shot-detector-object-detection-851a94607d11

3. Blog-2 on SSD: https://medium.com/@jonathan_hui/ssd-object-detection-single-shot-multibox-detector-for-real-time-processing-9bd8deac0e06

# Table of Contents

### 1. Installation Instructions
### 2. Use trained Model for Document Layout Analysis
### 3. How to train using PRImA Layout Analysis Dataset

# Installation

- Run these commands

    - git clone https://github.com/Tessellate-Imaging/Monk_Object_Detection.git

    - cd Monk_Object_Detection/1_gluoncv_finetune/installation

- Select the right requirements file and run

    - cat requirements_cuda10.1.txt | xargs -n 1 -L 1 pip install

In [None]:
! git clone https://github.com/Tessellate-Imaging/Monk_Object_Detection.git

In [None]:
# For colab use the command below
#! cd Monk_Object_Detection/1_gluoncv_finetune/installation && cat requirements_colab.txt | xargs -n 1 -L 1 pip install


# For Local systems and cloud select the right CUDA version
!cd Monk_Object_Detection/1_gluoncv_finetune/installation && cat requirements_cuda10.1.txt | xargs -n 1 -L 1 pip install

# Use Already Trained Model for Demo

In [None]:
import os
import sys
sys.path.append("Monk_Object_Detection/1_gluoncv_finetune/lib/");

In [None]:
from inference_prototype import Infer

In [None]:
#Download trained model

In [None]:
! wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1E6T7RKGwy-v1MUxVJm-rxt5XcRyr2SQ7' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1E6T7RKGwy-v1MUxVJm-rxt5XcRyr2SQ7" -O obj_dla_ssd512_trained.zip && rm -rf /tmp/cookies.txt

In [None]:
! unzip -qq obj_dla_ssd512_trained.zip

In [None]:
model_name = "ssd_512_vgg16_atrous_coco";
params_file = "dla_ssd512/dla_ssd512-vgg16.params";
class_list = ["paragraph", "heading", "credit", "footer", "drop-capital", "floating", "noise", "maths", "header", "caption", "image", "linedrawing", "graphics", "fname", "page-number", "chart", "separator", "table"];

In [None]:
gtf = Infer(model_name, params_file, class_list, use_gpu=True);

In [None]:
img_name = "Test_Images/test1.jpg"; 
visualize = True;
thresh = 0.3;
output = gtf.run(img_name, visualize=visualize, thresh=thresh);

In [None]:
img_name = "Test_Images/test2.jpg"; 
visualize = True;
thresh = 0.3;
output = gtf.run(img_name, visualize=visualize, thresh=thresh);

In [None]:
img_name = "Test_Images/test3.jpg"; 
visualize = True;
thresh = 0.4;
output = gtf.run(img_name, visualize=visualize, thresh=thresh);

# Train Your Own Model

## Dataset Credits
- https://www.primaresearch.org/datasets/Layout_Analysis

In [None]:
#Download Dataset

In [None]:
! wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1iBfafT1WHAtKAW0a1ifLzvW5f0ytm2i_' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1iBfafT1WHAtKAW0a1ifLzvW5f0ytm2i_" -O PRImA_Layout_Analysis_Dataset.zip && rm -rf /tmp/cookies.txt

In [None]:
! unzip -qq PRImA_Layout_Analysis_Dataset.zip

# Data Preprocessing

### Library for Data Augmentation
Refer to https://github.com/albumentations-team/albumentations for more details

In [None]:
! pip install albumentations

In [None]:
import os
import sys
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import albumentations as A
import glob
import matplotlib.pyplot as plt
import xmltodict
import json
from tqdm.notebook import tqdm
from pycocotools.coco import COCO

In [None]:
root_dir = "PRImA Layout Analysis Dataset/";
img_dir = "Images/";
anno_dir = "XML/";
final_root_dir="Document_Layout_Analysis/" #Directory for jpeg and augmented images

In [None]:
# Create the output directories (parents included); no error if they already exist
os.makedirs(final_root_dir + img_dir, exist_ok=True)

## TIFF Image Format to JPEG Image Format

In [None]:
for name in glob.glob(root_dir+img_dir+'*.tif'):
    im = Image.open(name)
    # str.rstrip/lstrip remove sets of characters, not suffixes/prefixes,
    # so take the file stem explicitly instead
    stem = os.path.splitext(os.path.basename(name))[0]
    im.save(final_root_dir + img_dir + stem + '.jpg', 'JPEG')

# Format Conversion and Data Augmentation

## Given Format - PAGE XML Format

### Dataset Directory Structure

    ./PRImA Layout Analysis Dataset/ (root_dir)
          |
          |-----------Images (img_dir)
          |              |
          |              |------------------img1.tif
          |              |------------------img2.tif
          |              |------------------.........(and so on)
          |
          |
          |-----------XML (anno_dir)
          |              |
          |              |------------------img1.xml
          |              |------------------img2.xml
          |              |------------------.........(and so on)
          


## Required Format- Monk Format

### Dataset Directory Structure

    ./Document_Layout_Analysis/ (final_root_dir)
          |
          |-----------Images (img_dir)
          |              |
          |              |------------------img1.jpg
          |              |------------------img2.jpg
          |              |------------------.........(and so on)
          |
          |
          |-----------train_labels.csv (anno_file)
          
          
### Annotation file format

           | Id         | Labels                                 |
           | img1.jpg   | x1 y1 x2 y2 label1 x1 y1 x2 y2 label2  |
           
- Labels:  xmin ymin xmax ymax label
- xmin, ymin - top left corner of bounding box
- xmax, ymax - bottom right corner of bounding box
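
The label column described above can be read back into boxes with a few lines of Python. This is a minimal sketch, not part of the Monk library: the helper names `parse_monk_labels` and `format_monk_labels` are hypothetical, and it assumes class names contain no spaces (true for the classes used in this notebook).

```python
def parse_monk_labels(label_str):
    """Split 'x1 y1 x2 y2 label ...' into (xmin, ymin, xmax, ymax, label) tuples."""
    tokens = label_str.split()
    if len(tokens) % 5 != 0:
        raise ValueError("expected groups of 5 tokens, got %d tokens" % len(tokens))
    boxes = []
    for i in range(0, len(tokens), 5):
        xmin, ymin, xmax, ymax = (int(t) for t in tokens[i:i + 4])
        boxes.append((xmin, ymin, xmax, ymax, tokens[i + 4]))
    return boxes


def format_monk_labels(boxes):
    """Inverse of parse_monk_labels: join boxes back into one label string."""
    return " ".join("%d %d %d %d %s" % tuple(b) for b in boxes)


row = "10 20 110 220 paragraph 15 240 300 280 heading"
boxes = parse_monk_labels(row)
print(boxes)
print(format_monk_labels(boxes) == row)
```

A round trip through both helpers is a quick way to check that a generated `train_labels.csv` row is well formed.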

In [None]:
files = os.listdir(root_dir + anno_dir);

In [None]:
combined = [];

### Data Augmentation Function

In [None]:
def augmentData(fname, boxes):
    image = cv2.imread(final_root_dir+img_dir+fname)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
    transform = A.Compose([
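        # The imgaug-based IAA* transforms are deprecated in albumentations >= 1.0;
        # Perspective and GaussNoise are their native replacements
        A.Perspective(p=0.7),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=5, p=0.5),
        A.GaussNoise(),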
        A.ChannelShuffle(),
        A.RandomBrightnessContrast(),
        A.RGBShift(p=0.8),
        A.HueSaturationValue(p=0.8)
        ], bbox_params=A.BboxParams(format='pascal_voc', min_visibility=0.2))
    
    for i in range(1, 9):
        label=""
        transformed = transform(image=image, bboxes=boxes)
        transformed_image = transformed['image']
        transformed_bboxes = transformed['bboxes']
        #print(transformed_bboxes)
        flag=False
        for box in transformed_bboxes:
            x_min, y_min, x_max, y_max, class_name = box
            # Check the unpacked box itself (not the globals xmin/xmax from the conversion loop)
            if(x_max<=x_min or y_max<=y_min):
                flag=True
                break
            label+= str(int(x_min))+' '+str(int(y_min))+' '+str(int(x_max))+' '+str(int(y_max))+' '+class_name+' '
                        
        if(flag):
            continue
        cv2.imwrite(final_root_dir+img_dir+str(i)+fname, transformed_image)
        label=label[:-1]
        combined.append([str(i) + fname, label])
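
The function above discards an augmented image entirely if any transformed box is degenerate. An alternative, sketched here under stated assumptions, is to clip each box to the image and drop only the invalid ones. The helper `clip_and_filter` is hypothetical and works on plain tuples in the same `pascal_voc` order (`x_min, y_min, x_max, y_max, class_name`); it does not call albumentations.

```python
def clip_and_filter(boxes, width, height):
    """Clip boxes to the image bounds and drop boxes with zero or negative area."""
    kept = []
    for x_min, y_min, x_max, y_max, class_name in boxes:
        # Truncate to integer pixels, then clamp into [0, width] x [0, height]
        x_min = max(0, min(int(x_min), width))
        y_min = max(0, min(int(y_min), height))
        x_max = max(0, min(int(x_max), width))
        y_max = max(0, min(int(y_max), height))
        if x_max <= x_min or y_max <= y_min:
            continue  # degenerate after clipping
        kept.append((x_min, y_min, x_max, y_max, class_name))
    return kept


sample = [(-5.2, 10.7, 50.9, 60.1, "table"), (30, 30, 30, 40, "noise")]
print(clip_and_filter(sample, width=40, height=100))
```

Whether to drop single boxes or the whole image is a design choice: dropping boxes keeps more augmented images, but an image may then contain an unlabeled region.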


# PAGE XML to Monk Format Conversion
Data augmentation is applied only to images that contain at least one minority class, to reduce class bias in the dataset.
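
The notebook hard-codes which region types count as minority classes. One way to check that choice against the data is to count how often each class appears in the label strings. The sketch below is illustrative only: `class_counts`, `minority_classes`, and the `max_share` threshold are hypothetical names and values, not part of the notebook or of Monk.

```python
from collections import Counter


def class_counts(label_strings):
    """Count boxes per class across label strings of the form 'x1 y1 x2 y2 label ...'."""
    counts = Counter()
    for label_str in label_strings:
        tokens = label_str.split()
        counts.update(tokens[4::5])  # every fifth token is a class name
    return counts


def minority_classes(counts, max_share=0.05):
    """Return classes whose share of all boxes is below max_share (assumed threshold)."""
    total = sum(counts.values())
    if total == 0:
        return []
    return sorted(name for name, n in counts.items() if n / total < max_share)


labels = [
    "0 0 10 10 paragraph 0 20 10 30 paragraph",
    "0 0 5 5 paragraph 5 5 9 9 table",
]
counts = class_counts(labels)
print(counts)
print(minority_classes(counts, max_share=0.3))
```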

In [None]:
#label generation for csv
for i in tqdm(range(len(files))):
    box=[];
    augment=False;
    annoFile = root_dir + anno_dir + files[i];
    # Use a context manager so the annotation file is closed after reading
    with open(annoFile, 'r') as f:
        my_xml = f.read();
    anno= dict(dict(dict(xmltodict.parse(my_xml))['PcGts'])['Page'])
    fname=""
    for j in range(len(files[i])):
        if((files[i][j])>='0' and files[i][j]<='9'):
            fname+=files[i][j];
    fname+=".jpg"
    image = cv2.imread(final_root_dir+img_dir+fname)
    height, width = image.shape[:2]    
    label_str = ""
    for key in anno.keys():
        if(key=='@imageFilename' or key=='@imageWidth' or key=='@imageHeight'):
            continue
        if(key=="TextRegion"):
            if(type(anno["TextRegion"]) == list):
                for j in range(len(anno["TextRegion"])):
                    text=anno["TextRegion"][j]
                    xmin=width
                    ymin=height
                    xmax=0
                    ymax=0
                    if(text["Coords"]):
                        if(text["Coords"]["Point"]):
                            for k in range(len(text["Coords"]["Point"])):
                                coordinates=anno["TextRegion"][j]["Coords"]["Point"][k]
                                xmin= min(xmin, int(coordinates['@x']));
                                ymin= min(ymin, int(coordinates['@y']));
                                xmax= min(max(xmax, int(coordinates['@x'])), width);
                                ymax= min(max(ymax, int(coordinates['@y'])), height);
                            if('@type' in text.keys()):    
                                # Skip degenerate boxes before they reach the label string
                                if(xmax<=xmin or ymax<=ymin):
                                    continue
                                label_str+= str(xmin)+' '+str(ymin)+' '+str(xmax)+' '+str(ymax)+' '+text['@type']+' '
                                tbox=[];
                                tbox.append(xmin)
                                tbox.append(ymin)
                                tbox.append(xmax)
                                tbox.append(ymax)
                                tbox.append(text['@type'])
                                box.append(tbox)
            else:
                text=anno["TextRegion"]
                xmin=width
                ymin=height
                xmax=0
                ymax=0
                if(text["Coords"]):
                    if(text["Coords"]["Point"]):
                        for k in range(len(text["Coords"]["Point"])):
                            coordinates=anno["TextRegion"]["Coords"]["Point"][k]
                            xmin= min(xmin, int(coordinates['@x']));
                            ymin= min(ymin, int(coordinates['@y']));
                            xmax= min(max(xmax, int(coordinates['@x'])), width);
                            ymax= min(max(ymax, int(coordinates['@y'])), height);
                        if('@type' in text.keys()):    
                            # Skip degenerate boxes before they reach the label string
                            if(xmax<=xmin or ymax<=ymin):
                                continue
                            label_str+= str(xmin)+' '+str(ymin)+' '+str(xmax)+' '+str(ymax)+' '+text['@type']+' '
                            tbox=[];
                            tbox.append(xmin)
                            tbox.append(ymin)
                            tbox.append(xmax)
                            tbox.append(ymax)
                            tbox.append(text['@type'])
                            box.append(tbox)
        
        else:
            val=""
            if(key=='GraphicRegion'):
                val="graphics"
                augment=True
            elif(key=='ImageRegion'):
                val="image"
            elif(key=='NoiseRegion'):
                val="noise"
                augment=True
            elif(key=='ChartRegion'):
                val="chart"
                augment=True
            elif(key=='TableRegion'):
                val="table"
                augment=True
            elif(key=='SeparatorRegion'):
                val="separator"
            elif(key=='MathsRegion'):
                val="maths"
                augment=True
            elif(key=='LineDrawingRegion'):
                val="linedrawing"
                augment=True
            else:
                val="frame"
                augment=True

            
            if(type(anno[key]) == list):
                for j in range(len(anno[key])):
                    text=anno[key][j]
                    xmin=width
                    ymin=height
                    xmax=0
                    ymax=0
                    if(text["Coords"]):
                        if(text["Coords"]["Point"]):
                            for k in range(len(text["Coords"]["Point"])):
                                coordinates=anno[key][j]["Coords"]["Point"][k]
                                xmin= min(xmin, int(coordinates['@x']));
                                ymin= min(ymin, int(coordinates['@y']));
                                xmax= min(max(xmax, int(coordinates['@x'])), width);
                                ymax= min(max(ymax, int(coordinates['@y'])), height);
                        # Skip degenerate boxes before they reach the label string
                        if(xmax<=xmin or ymax<=ymin):
                            continue
                        label_str+= str(xmin)+' '+str(ymin)+' '+str(xmax)+' '+str(ymax)+' '+ val +' '
                        tbox=[];
                        tbox.append(xmin)
                        tbox.append(ymin)
                        tbox.append(xmax)
                        tbox.append(ymax)
                        tbox.append(val)
                        box.append(tbox)
            else:
                text=anno[key]
                xmin=width
                ymin=height
                xmax=0
                ymax=0
                if(text["Coords"]):
                    if(text["Coords"]["Point"]):
                        for k in range(len(text["Coords"]["Point"])):
                            coordinates=anno[key]["Coords"]["Point"][k]
                            xmin= min(xmin, int(coordinates['@x']));
                            ymin= min(ymin, int(coordinates['@y']));
                            xmax= min(max(xmax, int(coordinates['@x'])), width);
                            ymax= min(max(ymax, int(coordinates['@y'])), height);  
                        # Skip degenerate boxes before they reach the label string
                        if(xmax<=xmin or ymax<=ymin):
                            continue
                        label_str+= str(xmin)+' '+str(ymin)+' '+str(xmax)+' '+str(ymax)+' '+val+' '
                        tbox=[];
                        tbox.append(xmin)
                        tbox.append(ymin)
                        tbox.append(xmax)
                        tbox.append(ymax)
                        tbox.append(val)
                        box.append(tbox)

    label_str=label_str[:-1]
    combined.append([fname, label_str])

    if(augment):
        augmentData(fname, box)
        

In [None]:
df = pd.DataFrame(combined, columns = ['ID', 'Label']);
df.to_csv(final_root_dir + "train_labels.csv", index=False);  # final_root_dir already ends with "/"
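
The conversion loop above reduces each region polygon to an axis-aligned box by taking the minimum and maximum of its point coordinates, capped at the image size. That step can be sketched in isolation as follows. The helper name `polygon_to_bbox` is hypothetical; the point dictionaries mimic what `xmltodict` produces for `Point` elements (string attribute values under `@x` and `@y`).

```python
def polygon_to_bbox(points, width, height):
    """Return (xmin, ymin, xmax, ymax) enclosing the polygon, capped at the image size."""
    xs = [int(p["@x"]) for p in points]
    ys = [int(p["@y"]) for p in points]
    # Same starting values as the loop: mins start at the image size, maxes at 0
    xmin = min([width] + xs)
    ymin = min([height] + ys)
    xmax = min(max([0] + xs), width)
    ymax = min(max([0] + ys), height)
    return xmin, ymin, xmax, ymax


points = [
    {"@x": "120", "@y": "40"},
    {"@x": "480", "@y": "40"},
    {"@x": "480", "@y": "900"},
    {"@x": "120", "@y": "900"},
]
print(polygon_to_bbox(points, width=400, height=1000))
```

Writing the step as a function would also let the four near-identical branches of the loop share one implementation.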

# Training

In [None]:
import os
import sys

sys.path.append("Monk_Object_Detection/1_gluoncv_finetune/lib/");

In [None]:
from detector_prototype import Detector

In [None]:
gtf = Detector();

In [None]:
root = "Document_Layout_Analysis/";
img_dir = "Images/";
anno_file = "train_labels.csv";
batch_size=8;

In [None]:
gtf.Dataset(root, img_dir, anno_file, batch_size=batch_size);

### Available models
    ssd_300_vgg16_atrous_coco
    ssd_300_vgg16_atrous_voc
    ssd_512_vgg16_atrous_coco
    ssd_512_vgg16_atrous_voc
    ssd_512_resnet50_v1_coco
    ssd_512_resnet50_v1_voc
    ssd_512_mobilenet1.0_voc
    ssd_512_mobilenet1.0_coco
    yolo3_darknet53_voc
    yolo3_darknet53_coco
    yolo3_mobilenet1.0_voc
    yolo3_mobilenet1.0_coco

In [None]:
#vgg16 architecture, with atrous convolutions, pretrained on COCO dataset is used for this task
pretrained = True;         
gpu=True;
model_name = "ssd_512_vgg16_atrous_coco";

In [None]:
gtf.Model(model_name, use_pretrained=pretrained, use_gpu=gpu);

In [None]:
gtf.Set_Learning_Rate(0.003);

In [None]:
epochs=30;
params_file = "saved_model.params";

In [None]:
gtf.Train(epochs, params_file);

# Inference

In [None]:
import os
import sys
sys.path.append("Monk_Object_Detection/1_gluoncv_finetune/lib/");

In [None]:
from inference_prototype import Infer

In [None]:
model_name = "ssd_512_vgg16_atrous_coco";
params_file = "saved_model.params";
class_list = ["paragraph", "heading", "credit", "footer", "drop-capital", "floating", "noise", "maths", "header", "caption", "image", "linedrawing", "graphics", "fname", "page-number", "chart", "separator", "table"];

In [None]:
gtf = Infer(model_name, params_file, class_list, use_gpu=True);

In [None]:
img_name = "Test_Images/test1.jpg"; 
visualize = True;
thresh = 0.3;
output = gtf.run(img_name, visualize=visualize, thresh=thresh);

In [None]:
img_name = "Test_Images/test2.jpg"; 
visualize = True;
thresh = 0.3;
output = gtf.run(img_name, visualize=visualize, thresh=thresh);

In [None]:
img_name = "Test_Images/test3.jpg"; 
visualize = True;
thresh = 0.4;
output = gtf.run(img_name, visualize=visualize, thresh=thresh);

The model detects layout regions with high confidence, but its predictions are strongly biased towards the paragraph class. Performance can likely be improved by using a larger batch size, training for more epochs, and applying more data augmentation to the minority classes to reduce this bias.