In [1]:
from __future__ import division
import random
import pprint
import sys
import time
import numpy as np
from optparse import OptionParser
import pickle
import re

from keras import backend as K
from keras.optimizers import Adam, SGD, RMSprop
from keras.layers import Input
from keras.models import Model
from keras_frcnn import config, data_generators
from keras_frcnn import losses as losses
import keras_frcnn.roi_helpers as roi_helpers
# from keras.utils import generic_utils
from tensorflow.python.keras.utils import generic_utils

In [17]:
import os
import cv2
import xml.etree.ElementTree as ET
import numpy as np
def get_data(input_path):
	all_imgs = []

	classes_count = {}

	class_mapping = {}

	visualise = False

	# data_paths = [os.path.join(input_path,s) for s in ['VOC2007', 'VOC2012']]
	data_paths = [os.path.join(input_path,s) for s in ['VOC2012']]
	

	print('Parsing annotation files')

	for data_path in data_paths:

		annot_path = os.path.join(data_path, 'Annotations')
		imgs_path = os.path.join(data_path, 'JPEGImages')
		imgsets_path_trainval = os.path.join(data_path, 'ImageSets','Main','trainval.txt')
		imgsets_path_test = os.path.join(data_path, 'ImageSets','Main','test.txt')

		trainval_files = []
		test_files = []
		try:
			with open(imgsets_path_trainval) as f:
				for line in f:
					if line.strip().startswith("2008"):
						trainval_files.append(line.strip() + '.jpg')
		except Exception as e:
			print(e)

		try:
			with open(imgsets_path_test) as f:
				for line in f:
					if line.strip().startswith("2008"):
						test_files.append(line.strip() + '.jpg')
		except Exception as e:
			if data_path[-7:] == 'VOC2012':
				# this is expected, most pascal voc distibutions dont have the test.txt file
				pass
			else:
				print(e)
		
		# annots = [os.path.join(annot_path, s) for s in os.listdir(annot_path)]
		annots = [os.path.join(annot_path, s) for s in os.listdir(annot_path) if s.startswith("2008")]
		idx = 0
		for annot in annots:
			print(annot)
			try:
				idx += 1

				et = ET.parse(annot)
				element = et.getroot()

				element_objs = element.findall('object')
				element_filename = element.find('filename').text
				element_width = int(element.find('size').find('width').text)
				element_height = int(element.find('size').find('height').text)

				if len(element_objs) > 0:
					annotation_data = {'filepath': os.path.join(imgs_path, element_filename), 'width': element_width,
									   'height': element_height, 'bboxes': []}

					if element_filename in trainval_files:
						annotation_data['imageset'] = 'trainval'
					elif element_filename in test_files:
						annotation_data['imageset'] = 'test'
					else:
						annotation_data['imageset'] = 'trainval'

				for element_obj in element_objs:
					class_name = element_obj.find('name').text
					if class_name not in classes_count:
						classes_count[class_name] = 1
					else:
						classes_count[class_name] += 1

					if class_name not in class_mapping:
						class_mapping[class_name] = len(class_mapping)

					obj_bbox = element_obj.find('bndbox')
					x1 = int(round(float(obj_bbox.find('xmin').text)))
					y1 = int(round(float(obj_bbox.find('ymin').text)))
					x2 = int(round(float(obj_bbox.find('xmax').text)))
					y2 = int(round(float(obj_bbox.find('ymax').text)))
					difficulty = int(element_obj.find('difficult').text) == 1
					annotation_data['bboxes'].append(
						{'class': class_name, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'difficult': difficulty})
				all_imgs.append(annotation_data)

				if visualise:
					img = cv2.imread(annotation_data['filepath'])
					for bbox in annotation_data['bboxes']:
						cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox[
									  'x2'], bbox['y2']), (0, 0, 255))
					cv2.imshow('img', img)
					cv2.waitKey(0)

			except Exception as e:
				print(e)
				continue
	return all_imgs, classes_count, class_mapping


In [16]:
def get_data_test(input_path):
	all_imgs = []

	classes_count = {}

	class_mapping = {}

	visualise = False

	# data_paths = [os.path.join(input_path,s) for s in ['VOC2007', 'VOC2012']]
	data_paths = [os.path.join(input_path,s) for s in ['VOC2012']]
	

	print('Parsing annotation files')

	for data_path in data_paths:

		annot_path = os.path.join(data_path, 'Annotations')
		imgs_path = os.path.join(data_path, 'JPEGImages')
		# imgsets_path_trainval = os.path.join(data_path, 'ImageSets','Main','trainval.txt')
		imgsets_path_test = os.path.join(data_path, 'ImageSets','Main','test.txt')

		# trainval_files = []
		test_files = []
		# try:
		# 	with open(imgsets_path_test) as f:
		# 		for line in f:
		# 			trainval_files.append(line.strip() + '.jpg')
		# except Exception as e:
		# 	print(e)

		try:
			with open(imgsets_path_test) as f:
				for line in f:
					if line.strip().startswith("2008"):
						test_files.append(line.strip() + '.jpg')
		except Exception as e:
			if data_path[-7:] == 'VOC2012':
				# this is expected, most pascal voc distibutions dont have the test.txt file
				pass
			else:
				print(e)
		
		# annots = [os.path.join(annot_path, s) for s in os.listdir(annot_path)]
		annots = [os.path.join(annot_path, s) for s in os.listdir(annot_path) if s.startswith("2008")]
		idx = 0
		for annot in annots:
			try:
				idx += 1

				et = ET.parse(annot)
				element = et.getroot()

				element_objs = element.findall('object')
				element_filename = element.find('filename').text
				element_width = int(element.find('size').find('width').text)
				element_height = int(element.find('size').find('height').text)

				if len(element_objs) > 0:
					annotation_data = {'filepath': os.path.join(imgs_path, element_filename), 'width': element_width,
									   'height': element_height, 'bboxes': []}

					# if element_filename in trainval_files:
					# 	annotation_data['imageset'] = 'trainval'
					# if element_filename in test_files:
					# 	annotation_data['imageset'] = 'test'
					# else:
					# 	annotation_data['imageset'] = 'trainval'

				for element_obj in element_objs:
					class_name = element_obj.find('name').text
					if class_name not in classes_count:
						classes_count[class_name] = 1
					else:
						classes_count[class_name] += 1

					if class_name not in class_mapping:
						class_mapping[class_name] = len(class_mapping)

					obj_bbox = element_obj.find('bndbox')
					x1 = int(round(float(obj_bbox.find('xmin').text)))
					y1 = int(round(float(obj_bbox.find('ymin').text)))
					x2 = int(round(float(obj_bbox.find('xmax').text)))
					y2 = int(round(float(obj_bbox.find('ymax').text)))
					# difficulty = int(element_obj.find('difficult').text) == 1
					annotation_data['bboxes'].append(
						# {'class': class_name, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'difficult': difficulty})
						{'class': class_name, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2})
				all_imgs.append(annotation_data)

				if visualise:
					img = cv2.imread(annotation_data['filepath'])
					for bbox in annotation_data['bboxes']:
						cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox[
									  'x2'], bbox['y2']), (0, 0, 255))
					cv2.imshow('img', img)
					cv2.waitKey(0)

			except Exception as e:
				print(e)
				continue
	return all_imgs, classes_count, class_mapping


In [18]:
# from keras_frcnn.pascal_voc_parser import get_data
# from keras_frcnn.pascal_voc_parser import get_data
# from keras_frcnn import vgg as nn
from keras_frcnn import resnet as nn

train_path = "D:\FasterRCNN\MaskRCNN\data"
test_path = "../DataTest/"
C = config.Config()

model_path_regex = re.match("^(.+)(\.hdf5)$", C.model_path)
if model_path_regex.group(2) != '.hdf5':
	print('Output weights must have .hdf5 filetype')
	exit(1)
	
# train_imgs, classes_count, class_mapping = get_data(train_path, 'trainval')
train_imgs, classes_count, class_mapping = get_data(train_path)
val_imgs, _, _ = get_data_test(test_path)

Parsing annotation files
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000002.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000003.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000007.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000008.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000009.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000015.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000016.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000019.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000021.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000023.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000026.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000027.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000028.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000032.xml
D:\FasterRCNN\MaskRCNN\data\VOC2012\Annotations\2008_000033.xml
D:\FasterRCNN\M

In [5]:
C.use_horizontal_flips = False
C.use_vertical_flips = False
C.rot_90 = False

In [6]:
# from keras_frcnn import vgg as nn
if 'bg' not in classes_count:
	classes_count['bg'] = 0
	class_mapping['bg'] = len(class_mapping)
	
C.class_mapping = class_mapping

In [7]:
C.base_net_weights = nn.get_weight_path()

In [8]:
inv_map = {v: k for k, v in class_mapping.items()}

In [9]:
print('Training images per class:')
pprint.pprint(classes_count)
print(f'Num classes (including bg) = {len(classes_count)}')

Training images per class:
{'aeroplane': 1002,
 'bg': 0,
 'bicycle': 837,
 'bird': 1271,
 'boat': 1059,
 'bottle': 1561,
 'bus': 685,
 'car': 2492,
 'cat': 1277,
 'chair': 3056,
 'cow': 771,
 'diningtable': 800,
 'dog': 1598,
 'horse': 803,
 'motorbike': 801,
 'person': 17401,
 'pottedplant': 1202,
 'sheep': 1084,
 'sofa': 841,
 'train': 704,
 'tvmonitor': 893}
Num classes (including bg) = 21


In [10]:
config_output_filename = "config_output.txt"

with open(config_output_filename, 'wb') as config_f:
	pickle.dump(C,config_f)
	print(f'Config has been written to {config_output_filename}, and can be loaded when testing to ensure correct results')


Config has been written to config_output.txt, and can be loaded when testing to ensure correct results


In [11]:
random.shuffle(train_imgs)

In [20]:
len(train_imgs)

4340

In [13]:
num_imgs = len(train_imgs)

In [14]:
print(f'Num train samples {len(train_imgs)}')
print(f'Num val samples {len(val_imgs)}')

Num train samples 17125
Num val samples 5138


In [15]:
data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, C, nn.get_img_output_length, K.image_data_format(), mode='train')
data_gen_val = data_generators.get_anchor_gt(val_imgs, classes_count, C, nn.get_img_output_length,K.image_data_format(), mode='val')


In [16]:
if K.image_data_format() == 'channels_first':
	input_shape_img = (3, None, None)
else:
	input_shape_img = (None, None, 3)

In [17]:
img_input = Input(shape=input_shape_img)
roi_input = Input(shape=(None, 4))

In [18]:
# from keras_frcnn import resnet as nn

In [19]:
# resnet50
shared_layers = nn.nn_base(img_input, trainable=True)

In [20]:
# define the RPN, built on the base layers
num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
rpn = nn.rpn(shared_layers, num_anchors)

In [21]:
classifier = nn.classifier(shared_layers, roi_input, C.num_rois, nb_classes=len(classes_count), trainable=True)

In [22]:
model_rpn = Model(img_input, rpn[:2])
model_classifier = Model([img_input, roi_input], classifier)

In [23]:
# this is a model that holds both the RPN and the classifier, used to load/save weights for the models
model_all = Model([img_input, roi_input], rpn[:2] + classifier)

In [24]:
try:
	print('loading weights from {C.base_net_weights}')
	model_rpn.load_weights(C.base_net_weights, by_name=True)
	model_classifier.load_weights(C.base_net_weights, by_name=True)
except:
	print('Could not load pretrained model weights. Weights can be found in the keras application folder \
		https://github.com/fchollet/keras/tree/master/keras/applications')


loading weights from {C.base_net_weights}


In [25]:
optimizer = Adam(learning_rate=0.01) #1e-5
optimizer_classifier = Adam(learning_rate=0.01) #1e-5
model_rpn.compile(optimizer=optimizer, loss=[losses.rpn_loss_cls(num_anchors), losses.rpn_loss_regr(num_anchors)])
model_classifier.compile(optimizer=optimizer_classifier, loss=[losses.class_loss_cls, losses.class_loss_regr(len(classes_count)-1)], metrics={f'dense_class_{len(classes_count)}': 'accuracy'})
model_all.compile(optimizer='sgd', loss='mae')

In [26]:
epoch_length = 1000
num_epochs = 20
iter_num = 0

losses = np.zeros((epoch_length, 5))
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []
start_time = time.time()

best_loss = np.Inf

class_mapping_inv = {v: k for k, v in class_mapping.items()}
print('Starting training')

Starting training


In [27]:
X, Y, img_data = next(data_gen_train)
P_rpn = model_rpn.predict_on_batch(X)
print(P_rpn)

[array([[[[0.78429365, 0.13850963, 0.34131494, ..., 0.24807344,
          0.9292226 , 0.16375497],
         [0.22021264, 0.01816508, 0.62012225, ..., 0.06712058,
          0.25343436, 0.14186546],
         [0.75310093, 0.00334291, 0.2581548 , ..., 0.12264472,
          0.84411603, 0.18229905],
         ...,
         [0.34604576, 0.6926964 , 0.06389381, ..., 0.23937437,
          0.10374898, 0.02382954],
         [0.74565077, 0.5895226 , 0.01954159, ..., 0.09034041,
          0.73331743, 0.26255432],
         [0.22933578, 0.28391963, 0.01483251, ..., 0.09523329,
          0.602917  , 0.16784957]],

        [[0.90222687, 0.01453712, 0.14070512, ..., 0.14171956,
          0.91788423, 0.01728046],
         [0.7842915 , 0.01838453, 0.04620825, ..., 0.44030032,
          0.14490482, 0.35395217],
         [0.48683107, 0.00495991, 0.0038912 , ..., 0.07348019,
          0.90329146, 0.24315943],
         ...,
         [0.5685059 , 0.49352807, 0.00173805, ..., 0.09927638,
          0.16878685, 0.

In [28]:
loss_rpn = model_rpn.train_on_batch(X, Y)
print(loss_rpn)

[8.432779312133789, 7.699596405029297, 0.7331826686859131]


In [30]:
R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_data_format(), use_regr=True, overlap_thresh=0.7, max_boxes=300)

In [33]:
X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)
# print(img_data)

In [41]:
neg_samples = np.where(Y1[0, :, -1] == 1)
pos_samples = np.where(Y1[0, :, -1] == 0)

print(neg_samples[0])

[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  91  92
  93  94  95  96  98  99 100 101 102 103 105 106 107 108 109 110 111 112
 113 114 115 116]


In [42]:
if len(neg_samples) > 0:
    neg_samples = neg_samples[0]
else:
    neg_samples = []

if len(pos_samples) > 0:
    pos_samples = pos_samples[0]
else:
    pos_samples = []

rpn_accuracy_rpn_monitor.append(len(pos_samples))
rpn_accuracy_for_epoch.append((len(pos_samples)))

if C.num_rois > 1:
    if len(pos_samples) < C.num_rois//2:
        selected_pos_samples = pos_samples.tolist()
    else:
        selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
    try:
        selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
    except:
        selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()

    sel_samples = selected_pos_samples + selected_neg_samples
else:
    # in the extreme case where num_rois = 1, we pick a random pos or neg sample
    selected_pos_samples = pos_samples.tolist()
    selected_neg_samples = neg_samples.tolist()
    if np.random.randint(0, 2):
        sel_samples = random.choice(neg_samples)
    else:
        sel_samples = random.choice(pos_samples)

In [43]:
loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

In [45]:
print(loss_class)

[3.5904674530029297, 3.044522285461426, 0.5459451675415039, 0.0]


In [29]:
for epoch_num in range(num_epochs):

	progbar = generic_utils.Progbar(epoch_length)
	print(f'Epoch {epoch_num + 1}/{num_epochs}')

	while True:
		try:

			if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
				mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
				rpn_accuracy_rpn_monitor = []
				print(f'Average number of overlapping bounding boxes from RPN = {mean_overlapping_bboxes} for {epoch_length} previous iterations')
				if mean_overlapping_bboxes == 0:
					print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')

			X, Y, img_data = next(data_gen_train)

			loss_rpn = model_rpn.train_on_batch(X, Y)

			P_rpn = model_rpn.predict_on_batch(X)

			R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_data_format(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
			# note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
			X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)

			if X2 is None:
				rpn_accuracy_rpn_monitor.append(0)
				rpn_accuracy_for_epoch.append(0)
				continue

			neg_samples = np.where(Y1[0, :, -1] == 1)
			pos_samples = np.where(Y1[0, :, -1] == 0)

			if len(neg_samples) > 0:
				neg_samples = neg_samples[0]
			else:
				neg_samples = []

			if len(pos_samples) > 0:
				pos_samples = pos_samples[0]
			else:
				pos_samples = []
			
			rpn_accuracy_rpn_monitor.append(len(pos_samples))
			rpn_accuracy_for_epoch.append((len(pos_samples)))

			if C.num_rois > 1:
				if len(pos_samples) < C.num_rois//2:
					selected_pos_samples = pos_samples.tolist()
				else:
					selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
				try:
					selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
				except:
					selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()

				sel_samples = selected_pos_samples + selected_neg_samples
			else:
				# in the extreme case where num_rois = 1, we pick a random pos or neg sample
				selected_pos_samples = pos_samples.tolist()
				selected_neg_samples = neg_samples.tolist()
				if np.random.randint(0, 2):
					sel_samples = random.choice(neg_samples)
				else:
					sel_samples = random.choice(pos_samples)

			loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

			losses[iter_num, 0] = loss_rpn[1]
			losses[iter_num, 1] = loss_rpn[2]

			losses[iter_num, 2] = loss_class[1]
			losses[iter_num, 3] = loss_class[2]
			losses[iter_num, 4] = loss_class[3]

			progbar.update(iter_num+1, [('rpn_cls', losses[iter_num, 0]), ('rpn_regr', losses[iter_num, 1]),
									  ('detector_cls', losses[iter_num, 2]), ('detector_regr', losses[iter_num, 3])])

			iter_num += 1
			
			if iter_num == epoch_length:
				loss_rpn_cls = np.mean(losses[:, 0])
				loss_rpn_regr = np.mean(losses[:, 1])
				loss_class_cls = np.mean(losses[:, 2])
				loss_class_regr = np.mean(losses[:, 3])
				class_acc = np.mean(losses[:, 4])

				mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
				rpn_accuracy_for_epoch = []

				if C.verbose:
					print(f'Mean number of bounding boxes from RPN overlapping ground truth boxes: {mean_overlapping_bboxes}')
					print(f'Classifier accuracy for bounding boxes from RPN: {class_acc}')
					print(f'Loss RPN classifier: {loss_rpn_cls}')
					print(f'Loss RPN regression: {loss_rpn_regr}')
					print(f'Loss Detector classifier: {loss_class_cls}')
					print(f'Loss Detector regression: {loss_class_regr}')
					print(f'Elapsed time: {time.time() - start_time}')

				curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
				iter_num = 0
				start_time = time.time()

				if curr_loss < best_loss:
					if C.verbose:
						print(f'Total loss decreased from {best_loss} to {curr_loss}, saving weights')
					best_loss = curr_loss
				model_all.save_weights(model_path_regex.group(1) + "_" + '{:04d}'.format(epoch_num) + model_path_regex.group(2))

				break

		except Exception as e:
			print(f'Exception: {e}')
			continue

print('Training complete, exiting.')
