In [2]:
import os
import json
import xml.etree
import numpy as np
import mrcnn.utils
import mrcnn.config
import mrcnn.model

Using TensorFlow backend.


In [3]:
# Label names indexed by class ID. 'BG' occupies index 0 (the background slot
# in Mask R-CNN's convention); 'table' is the only foreground class.
CLASS_NAMES = [
    'BG',
    'table',
]

In [4]:
class TableBankDataset(mrcnn.utils.Dataset):
    """TableBank table-detection data exposed through the Mask R-CNN ``Dataset`` API.

    Each registered image carries an ``annotation`` dict holding the image
    ``width``/``height`` and the bounding boxes of every table on the page.
    Masks are synthesised as filled rectangles, one channel per table.
    """

    def load_dataset(self, dataset_dir, is_train=True, annotation_file=None, max_images=50001):
        """Register the 'table' class and every usable image of one split.

        Args:
            dataset_dir: Directory containing the ``images`` and ``annotations``
                sub-directories.
            is_train: Selects the split when ``annotation_file`` is not given:
                ``tablebank_latex_train.json`` if true, otherwise
                ``tablebank_latex_val.json``.
                NOTE(review): the ``_val`` file name assumes the standard
                TableBank release layout - confirm it exists locally.
            annotation_file: Optional explicit annotation file name (relative to
                ``<dataset_dir>/annotations``); overrides ``is_train``.
            max_images: How many leading entries of the JSON ``images`` list are
                scanned (``None`` scans all). The default matches the cap that
                used to be hard-coded here.

        Raises:
            OSError: If the annotation file cannot be opened.
        """
        self.add_class("dataset", 1, "table")

        images_dir = os.path.join(dataset_dir, 'images')
        annotations_dir = os.path.join(dataset_dir, 'annotations')

        if annotation_file is None:
            annotation_file = ('tablebank_latex_train.json' if is_train
                               else 'tablebank_latex_val.json')

        # Context manager so the (large) JSON file handle is always released.
        with open(os.path.join(annotations_dir, annotation_file)) as annotation_fh:
            annotations = json.load(annotation_fh)

        # Index annotations by the image they belong to in a single pass.
        # COCO-style annotations reference their image through "image_id";
        # "id" is the annotation's own identifier and must not be used here.
        annotations_by_image = {}
        for annotation in annotations["annotations"]:
            annotations_by_image.setdefault(annotation["image_id"], []).append(annotation)

        for image in annotations["images"][:max_images]:
            image_annotations = annotations_by_image.get(image["id"])
            if not image_annotations:
                continue  # image without any labelled table

            img_path = os.path.join(images_dir, image["file_name"])
            if not os.path.exists(img_path):
                continue  # annotation refers to an image that is not on disk

            # Copy so the parsed JSON is not mutated. The first annotation's
            # fields (including "bbox") stay available as before; "bboxes"
            # additionally lists every table on the page.
            img_info = dict(image_annotations[0])
            img_info["bboxes"] = [ann["bbox"] for ann in image_annotations]
            img_info["width"] = image["width"]
            img_info["height"] = image["height"]
            self.add_image('dataset', image_id=image["id"], path=img_path, annotation=img_info)

    def extract_boxes(self, filename):
        """Return ``(boxes, width, height)`` for one image's annotation record.

        Args:
            filename: Despite the name, this is the annotation dict stored by
                ``load_dataset`` (not a path). It must contain ``width`` and
                ``height`` and either ``bboxes`` (list of boxes) or a single
                ``bbox``.

        Returns:
            A list of ``[x, y, width, height]`` boxes, the image width and the
            image height.
        """
        img_info = filename
        if "bboxes" in img_info:
            boxes = list(img_info["bboxes"])
        else:
            # Records that only carry a single box.
            boxes = [img_info["bbox"]]
        return boxes, img_info["width"], img_info["height"]

    def load_mask(self, image_id):
        """Build rectangular instance masks for an image.

        Args:
            image_id: Index into ``self.image_info``.

        Returns:
            ``masks``: uint8 array of shape ``[height, width, n_boxes]`` with 1
            inside each table's bounding box, and ``class_ids``: int32 array of
            length ``n_boxes`` holding the 'table' class index.
        """
        info = self.image_info[image_id]
        boxes, w, h = self.extract_boxes(info['annotation'])
        masks = np.zeros([h, w, len(boxes)], dtype='uint8')
        table_class_id = self.class_names.index('table')

        for i, (x, y, box_w, box_h) in enumerate(boxes):
            # COCO boxes are [x, y, width, height]; convert to slice bounds.
            # Coordinates are rounded to ints (slicing rejects floats) and
            # clamped at 0 so a negative value cannot wrap around the array.
            col_s = max(int(round(x)), 0)
            row_s = max(int(round(y)), 0)
            col_e = max(int(round(x + box_w)), 0)
            row_e = max(int(round(y + box_h)), 0)
            masks[row_s:row_e, col_s:col_e, i] = 1

        class_ids = np.full(len(boxes), table_class_id, dtype='int32')
        return masks, class_ids

In [5]:
class TableBankConfig(mrcnn.config.Config):
    """Mask R-CNN settings for the TableBank run.

    Only the attributes below are overridden; every other setting is inherited
    from ``mrcnn.config.Config``.
    """

    # Identifier for this configuration.
    NAME = "mask_rcnn_tablebank_cfg"

    # Feature-extraction backbone.
    BACKBONE = "resnet50"

    # One entry per label in CLASS_NAMES.
    NUM_CLASSES = len(CLASS_NAMES)

    # Single GPU, one image per GPU.
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

In [6]:
# Locations used by this cell, gathered in one place. The dataset directory can
# be overridden with the TABLEBANK_DIR environment variable; the default is the
# path this notebook was originally written against.
DATASET_DIR = os.environ.get('TABLEBANK_DIR',
                             'C:/Users/dsash/Repository/table/TableBank/Detection')
COCO_WEIGHTS_PATH = 'mask_rcnn_coco.h5'
TRAINED_WEIGHTS_PATH = 'tablebank_mask_rcnn_trained.h5'

# Train
train_dataset = TableBankDataset()
train_dataset.load_dataset(dataset_dir=DATASET_DIR, is_train=True)
train_dataset.prepare()

# Validation (requests the non-training split)
validation_dataset = TableBankDataset()
validation_dataset.load_dataset(dataset_dir=DATASET_DIR, is_train=False)
validation_dataset.prepare()

# Model Configuration
tablebank_config = TableBankConfig()

# Build the Mask R-CNN Model Architecture; logs/checkpoints go to the current
# working directory.
model = mrcnn.model.MaskRCNN(mode='training',
                             model_dir=os.getcwd(),
                             config=tablebank_config)

# Start from the pre-trained checkpoint but skip the listed head layers -
# presumably because their shapes depend on NUM_CLASSES, which differs from the
# checkpoint's class count.
model.load_weights(filepath=COCO_WEIGHTS_PATH,
                   by_name=True,
                   exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",  "mrcnn_bbox", "mrcnn_mask"])

# Train only the head layers for a single epoch.
model.train(train_dataset=train_dataset,
            val_dataset=validation_dataset,
            learning_rate=tablebank_config.LEARNING_RATE,
            epochs=1,
            layers='heads')

model_path = TRAINED_WEIGHTS_PATH
model.keras_model.save_weights(model_path)
