Template required for fine tuning Mask_RCNN

In [8]:
from xml.etree import ElementTree
import mrcnn.config
import mrcnn.utils
import mrcnn.model
import numpy as np
import os

In [4]:
# Class labels used by this template. Only the length is consumed in this file
# (TableBankConfig.NUM_CLASSES = len(CLASS_NAMES)). 'table' sits at index 1,
# the same ID that TableBankDataset.load_dataset passes to add_class().
# 'BG' is presumably the background class reserved by mrcnn -- TODO confirm.
CLASS_NAMES = ['BG', 'table']

In [7]:
# Class that defines and loads the tablebank dataset
class TableBankDataset(mrcnn.utils.Dataset):
	"""Single-class ("table") dataset built from images and XML box annotations.

	Each image in ``<dataset_dir>/images/`` is paired with the annotation file
	``<dataset_dir>/annots/<image name without extension>.xml``. The XML is
	expected to contain ``<bndbox>`` elements (``xmin``/``ymin``/``xmax``/``ymax``)
	and a ``<size>`` element (``width``/``height``).
	"""

	# Load the dataset definitions
	def load_dataset(self, dataset_dir, is_train=True):
		"""Register the "table" class and every image found under *dataset_dir*.

		Args:
			dataset_dir: Root folder holding the ``images`` and ``annots`` sub-folders.
			is_train: Currently unused -- every image is registered whatever its
				value. It is kept so existing callers keep working and so a
				train/validation split can be added here later.
		"""
		# Define one class
		self.add_class("dataset", 1, "table")
		# Define data locations
		images_dir = os.path.join(dataset_dir, 'images')
		annotations_dir = os.path.join(dataset_dir, 'annots')
		# os.listdir() returns entries in arbitrary order; sorting makes the
		# image ids assigned by add_image() reproducible between runs.
		for filename in sorted(os.listdir(images_dir)):
			img_path = os.path.join(images_dir, filename)
			# Skip hidden entries and sub-directories (.DS_Store,
			# .ipynb_checkpoints, ...): they are not images.
			if filename.startswith('.') or not os.path.isfile(img_path):
				continue
			# Strip the extension whatever its length (.jpg, .jpeg, .tiff, ...)
			image_id = os.path.splitext(filename)[0]
			# To hold images out for validation, or to skip known-bad images,
			# branch on is_train / image_id here and `continue`.
			ann_path = os.path.join(annotations_dir, image_id + '.xml')
			# Add to dataset
			self.add_image('dataset', image_id=image_id, path=img_path, annotation=ann_path)

	# Load the masks for an image
	def load_mask(self, image_id):
		"""Build one rectangular mask per annotated box of an image.

		Args:
			image_id: Index into ``self.image_info``.

		Returns:
			A tuple ``(masks, class_ids)``. ``masks`` is a uint8 array of shape
			``[height, width, number of boxes]`` holding 1 inside each box and 0
			elsewhere. ``class_ids`` is an int32 array that repeats the index of
			"table" in ``self.class_names`` once per box.
		"""
		# Get details of image
		info = self.image_info[image_id]
		# Load XML
		boxes, w, h = self.extract_boxes(info['annotation'])
		# Create one array for all masks, each on a different channel
		masks = np.zeros([h, w, len(boxes)], dtype='uint8')
		# The class index is the same for every box, so look it up once
		table_id = self.class_names.index('table')
		for i, (xmin, ymin, xmax, ymax) in enumerate(boxes):
			# Clamp at 0: a negative bound would make the slice count from the
			# far edge of the image and paint the wrong region.
			row_s, row_e = max(ymin, 0), max(ymax, 0)
			col_s, col_e = max(xmin, 0), max(xmax, 0)
			masks[row_s:row_e, col_s:col_e, i] = 1
		class_ids = np.full(len(boxes), table_id, dtype='int32')
		return masks, class_ids

	# Extract bounding boxes from an annotation file
	def extract_boxes(self, filename):
		"""Parse an annotation XML file.

		Args:
			filename: Path of the XML file.

		Returns:
			A tuple ``(boxes, width, height)``. ``boxes`` is a list with one
			``[xmin, ymin, xmax, ymax]`` list of ints per ``<bndbox>`` element;
			``width`` and ``height`` are read from the ``<size>`` element.
		"""
		# Load and parse the file, then get the root of the document
		root = ElementTree.parse(filename).getroot()
		# Extract each bounding box
		boxes = [
			[int(box.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
			for box in root.findall('.//bndbox')
		]
		# Extract image dimensions
		width = int(root.find('.//size/width').text)
		height = int(root.find('.//size/height').text)
		return boxes, width, height

	# Load an image reference
	def image_reference(self, image_id):
		"""Return the image file path stored for *image_id*."""
		return self.image_info[image_id]['path']

In [6]:
# Define a configuration for the model
class TableBankConfig(mrcnn.config.Config):
	"""Mask R-CNN settings for the fine-tuning run in this template.

	Overrides selected attributes of ``mrcnn.config.Config``; anything not
	listed here keeps the value defined by the base class.
	"""
	# Define the name of the configuration
	NAME = "template_model_cfg"
	# Backbone network identifier
	BACKBONE = "resnet50"
	# Resize mode and target dimensions; the exact semantics are defined by
	# mrcnn.config.Config (not visible in this file)
	IMAGE_RESIZE_MODE = "square"
	IMAGE_MIN_DIM = 512
	IMAGE_MAX_DIM = 512
	# Number of classes: one per entry of CLASS_NAMES (2 with the list above)
	NUM_CLASSES = len(CLASS_NAMES)
	GPU_COUNT = 1
	IMAGES_PER_GPU = 4
	# Number of training steps per epoch: the dataset size divided, rounding
	# down, by GPU_COUNT * IMAGES_PER_GPU (presumably the effective batch size)
	# <- MODIFY -> Replace 178 with the number of images in your dataset
	STEPS_PER_EPOCH = 178 // (GPU_COUNT * IMAGES_PER_GPU)

In [None]:
# <- MODIFY -> Name of the folder containing your images and annotations
DATASET_DIR = 'fine_dataset'
# <- MODIFY -> Weights file the fine-tuning starts from
INITIAL_WEIGHTS = 'tablebank_mask_rcnn_trained.h5'
# Output layers that are excluded when the weights file is loaded
EXCLUDED_LAYERS = ["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"]

# Build the training set and report how many images it holds
train_set = TableBankDataset()
train_set.load_dataset(DATASET_DIR, is_train=True)
train_set.prepare()
print('Train: %d' % len(train_set.image_ids))

# Build the configuration and print it
config = TableBankConfig()
config.display()

# Create the model in training mode; model_dir is the current working directory
model = mrcnn.model.MaskRCNN(mode='training', model_dir=os.getcwd(), config=config)

# Load the starting weights by layer name, leaving out the output layers
model.load_weights(INITIAL_WEIGHTS, by_name=True, exclude=EXCLUDED_LAYERS)

# NOTE(review): None is passed as the validation dataset in both train() calls.
# Stock Mask_RCNN builds a validation generator from that argument -- confirm
# that the mrcnn version in use tolerates None, or pass a real validation set.
# NOTE(review): also confirm whether `epochs` is a cumulative epoch count across
# successive train() calls in this mrcnn version.

# Stage 1: train the output layers ('heads') at the configured learning rate
base_lr = config.LEARNING_RATE
model.train(train_set, None, learning_rate=base_lr, epochs=2, layers='heads')

# Stage 2: train *all* layers at a tenth of the configured learning rate
model.train(train_set, None, learning_rate=base_lr / 10, epochs=5, layers="all")