In [13]:
import os
import json
import xml.etree
import numpy as np
import mrcnn.utils
import mrcnn.config
import mrcnn.model

In [12]:
class TableBankDataset(mrcnn.utils.Dataset):

    def load_dataset(self, dataset_dir, is_train=True):
        # Adds information (image ID, image path, and annotation file path) about each image in a dictionary.
        self.add_class("dataset", 1, "table")

        images_dir = dataset_dir + '/images/'
        annotations_dir = dataset_dir + '/annotations/'

        for filename in os.listdir(images_dir):
            image_id = filename[:-4]

            img_path = images_dir + filename
            ann_path = annotations_dir + image_id + '.xml'

            self.add_image('dataset', image_id=image_id, path=img_path, annotation=ann_path)

    # Loads the binary masks for an image.
    def load_mask(self, image_id):
        info = self.image_info[image_id]
        path = info['annotation']
        boxes, w, h = self.extract_boxes(path)
        masks = np.zeros([h, w, len(boxes)], dtype='uint8')

        class_ids = list()
        for i in range(len(boxes)):
            box = boxes[i]
            row_s, row_e = box[1], box[3]
            col_s, col_e = box[0], box[2]
            masks[row_s:row_e, col_s:col_e, i] = 1
            class_ids.append(self.class_names.index('table'))
        return masks, np.asarray(class_ids, dtype='int32')

    # A helper method to extract the bounding boxes from the annotation file
    def extract_boxes(self, filename):
        tree = xml.etree.ElementTree.parse(filename)

        root = tree.getroot()

        boxes = list()
        for box in root.findall('.//bndbox'):
            xmin = int(box.find('xmin').text)
            ymin = int(box.find('ymin').text)
            xmax = int(box.find('xmax').text)
            ymax = int(box.find('ymax').text)
            coors = [xmin, ymin, xmax, ymax]
            boxes.append(coors)

        width = int(root.find('.//size/width').text)
        height = int(root.find('.//size/height').text)
        return boxes, width, height

In [27]:
dataset_dir='C:/Users/dsash/Repository/table/TableBank/Detection' 
annotations = json.load(open(dataset_dir + '/annotations/' + 'tablebank_latex_train.json'))
print(type(annotations))
# Заданный id для поиска
target_id = 19

for i, image in enumerate(annotations["images"]):
    if image["id"] == target_id:
        for annotation in annotations["annotations"]:
            if image["id"] == annotation["id"]:
                annotation["width"] = image["width"]
                annotation["height"] = image["height"]
                print(annotation)
                break

<class 'dict'>
{'segmentation': [[102, 601, 102, 672, 251, 672, 251, 601]], 'area': 10579, 'image_id': 15, 'category_id': 1, 'id': 19, 'iscrowd': 0, 'bbox': [102, 601, 149, 71], 'width': 596, 'height': 842}


In [23]:
class TableBankDataset(mrcnn.utils.Dataset):

    def load_dataset(self, dataset_dir, mode='train'):
        # Adds information (image ID, image path, and annotation file path) about each image in a dictionary.
        self.add_class("dataset", 1, "table")
        
        images_dir = dataset_dir + '/images/'
        annotations_dir = dataset_dir + '/annotations/'

        annotations = json.load(open(dataset_dir + '/annotations/' + 'tablebank_latex_train.json'))

        for i, image in enumerate(annotations["images"]):
            if i > 1000:
                break
            for annotation in annotations["annotations"]:
                if image["id"] == annotation["id"]:
                    img_path = images_dir + image["file_name"]
                    img_info = annotation
                    img_info["width"] = image["width"]
                    img_info["height"] = image["height"]
                    self.add_image('dataset', image_id=image["id"], path=img_path, annotation=img_info)
                    break
    
    # A helper method to extract the bounding boxes from the annotation file
    def extract_boxes(self, filename):
        img_info = filename
        boxes = []
        boxes.append(img_info["bbox"])
        width = img_info["width"]
        height = img_info["height"]
        return boxes, width, height

    # Loads the binary masks for an image.
    def load_mask(self, image_id):
        info = self.image_info[image_id]
        path = info['annotation']
        boxes, w, h = self.extract_boxes(path)
        masks = np.zeros([h, w, len(boxes)], dtype='uint8')
        
        class_ids = list()
        for i in range(len(boxes)):
            box = boxes[i]
            row_s, row_e = box[1], box[3]
            col_s, col_e = box[0], box[2]
            masks[row_s:row_e, col_s:col_e, i] = 1
            class_ids.append(self.class_names.index('table'))
        return masks, np.asarray(class_ids, dtype='int32')

In [None]:
# Train
train_dataset = TableBankDataset()
train_dataset.load_dataset(dataset_dir='C:/Users/dsash/Repository/table/TableBank/Detection', is_train=True)
train_dataset.

In [13]:
class TableBankConfig(mrcnn.config.Config):
    NAME = "mask_rcnn_tablebank_cfg"

    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    
    NUM_CLASSES = 2

    STEPS_PER_EPOCH = 131

In [17]:
# Train
train_dataset = TableBankDataset()
train_dataset.load_dataset(dataset_dir='C:/Users/dsash/Repository/table/TableBank/Detection', is_train=True)
train_dataset.prepare()

# Validation
validation_dataset = TableBankDataset()
validation_dataset.load_dataset(dataset_dir='C:/Users/dsash/Repository/table/TableBank/Detection', is_train=False)
validation_dataset.prepare()

# Model Configuration
tablebank_config = TableBankConfig()

# Build the Mask R-CNN Model Architecture
model = mrcnn.model.MaskRCNN(mode='training', 
                             model_dir='./', 
                             config=tablebank_config)

model.load_weights(filepath='mask_rcnn_coco.h5', 
                   by_name=True, 
                   exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",  "mrcnn_bbox", "mrcnn_mask"])

model.train(train_dataset=train_dataset, 
            val_dataset=validation_dataset, 
            learning_rate=tablebank_config.LEARNING_RATE, 
            epochs=1, 
            layers='heads')

model_path = 'tablebank_mask_rcnn_trained.h5'
model.keras_model.save_weights(model_path)


ValueError: invalid literal for int() with base 10: '%20%20%202013_2'