From ae9a7e6554cf8d631bd37c363600e0e7c3936c4e Mon Sep 17 00:00:00 2001 From: Dang Date: Tue, 29 May 2018 08:51:34 +0000 Subject: [PATCH 1/5] Support groups in de-conv and fix bug. --- fluid/face_detction/pyramidbox.py | 18 +++++++++++++----- fluid/face_detction/train.py | 13 ++++++++----- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/fluid/face_detction/pyramidbox.py b/fluid/face_detction/pyramidbox.py index b91f75c1b2..93faad05bf 100644 --- a/fluid/face_detction/pyramidbox.py +++ b/fluid/face_detction/pyramidbox.py @@ -82,8 +82,7 @@ def _vgg(self): self.conv5 = conv_block(self.conv4, 3, [512] * 3, [3] * 3) # fc6 and fc7 in paper, priorbox min_size is 128 - self.conv6 = conv_block( - self.conv5, 2, [1024, 1024], [3, 1], with_pool=False) + self.conv6 = conv_block(self.conv5, 2, [1024, 1024], [3, 1]) # conv6_1 and conv6_2 in paper, priorbox min_size is 256 self.conv7 = conv_block( self.conv6, 2, [256, 512], [1, 3], [1, 2], with_pool=False) @@ -101,9 +100,15 @@ def fpn(up_from, up_to): b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.)) conv1 = fluid.layers.conv2d( up_from, ch, 1, act='relu', bias_attr=b_attr) - # TODO: add group conv_trans = fluid.layers.conv2d_transpose( - conv1, ch, None, 4, 1, 2, bias_attr=False) + conv1, + ch, + output_size=None, + filter_size=4, + padding=1, + stride=2, + groups=ch, + bias_attr=False) b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.)) conv2 = fluid.layers.conv2d( up_to, ch, 1, act='relu', bias_attr=b_attr) @@ -275,7 +280,10 @@ def train(self): head_loss = fluid.layers.ssd_loss( self.head_mbox_loc, self.head_mbox_conf, self.gt_box, self.gt_label, self.prior_boxes, self.box_vars) - return face_loss, head_loss + face_loss = fluid.layers.reduce_sum(face_loss) + head_loss = fluid.layers.reduce_sum(head_loss) + total_loss = face_loss + head_loss + return face_loss, head_loss, total_loss def test(self): test_program = fluid.default_main_program().clone(for_test=True) diff --git a/fluid/face_detction/train.py b/fluid/face_detction/train.py index 8fe122a60c..fcd7a0da26 100644 --- a/fluid/face_detction/train.py +++ b/fluid/face_detction/train.py @@ -19,6 +19,7 @@ add_arg('batch_size', int, 16, "Minibatch size.") add_arg('num_passes', int, 120, "Epoch number.") add_arg('use_gpu', bool, True, "Whether use GPU.") +add_arg('use_pyramidbox', bool, False, "Whether use GPU.") add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.") add_arg('model_save_dir', str, 'model', "The path to save model.") add_arg('pretrained_model', str, './vgg_model/', "The init model path.") @@ -37,8 +38,12 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, image_shape = [3, data_args.resize_h, data_args.resize_w] - network = PyramidBox(image_shape) - loss = network.vgg_ssd(num_classes, image_shape) + if args.use_pyramidbox: + network = PyramidBox(image_shape, sub_network=args.use_pyramidbox) + face_loss, head_loss, loss = network.train() + else: + network = PyramidBox(image_shape, sub_network=args.use_pyramidbox) + loss = network.vgg_ssd(num_classes, image_shape) epocs = 12880 / batch_size boundaries = [epocs * 100, epocs * 125, epocs * 150] @@ -46,6 +51,7 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, learning_rate, learning_rate * 0.1, learning_rate * 0.01, learning_rate * 0.001 ] + #print('main program ', fluid.default_main_program()) optimizer = fluid.optimizer.RMSProp( learning_rate=fluid.layers.piecewise_decay(boundaries, values), @@ -60,10 +66,8 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, # fluid.io.save_inference_model('./vgg_model/', ['image'], [loss], exe) if pretrained_model: - def if_exist(var): return os.path.exists(os.path.join(pretrained_model, var.name)) - print('Load pre-trained model.') fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) @@ -108,7 +112,6 @@ def save_model(postfix): if batch_id % 1 == 0: print("Pass {0}, batch {1}, loss {2}, time {3}".format( pass_id, batch_id, loss_v, start_time - prev_start_time)) - test(pass_id, best_map) if pass_id % 10 == 0 or pass_id == num_passes - 1: save_model(str(pass_id)) print("Best test map {0}".format(best_map)) From f530d784484cc3e5730e3f09b9a10d9971f1d1d6 Mon Sep 17 00:00:00 2001 From: Dang Date: Tue, 29 May 2018 08:54:30 +0000 Subject: [PATCH 2/5] Clean code. --- fluid/face_detction/image_util.py | 7 ++----- fluid/face_detction/pyramidbox.py | 3 +-- fluid/face_detction/train.py | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/fluid/face_detction/image_util.py b/fluid/face_detction/image_util.py index eb7be3ff84..210c49720b 100644 --- a/fluid/face_detction/image_util.py +++ b/fluid/face_detction/image_util.py @@ -75,11 +75,8 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0: return True for i in range(len(bbox_labels)): - object_bbox = bbox( - bbox_labels[i][0], - bbox_labels[i][1], # tangxu @ 2018-05-17 - bbox_labels[i][2], - bbox_labels[i][3]) + object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1], + bbox_labels[i][2], bbox_labels[i][3]) overlap = jaccard_overlap(sample_bbox, object_bbox) if sampler.min_jaccard_overlap != 0 and \ overlap < sampler.min_jaccard_overlap: diff --git a/fluid/face_detction/pyramidbox.py b/fluid/face_detction/pyramidbox.py index 93faad05bf..ed80bea532 100644 --- a/fluid/face_detction/pyramidbox.py +++ b/fluid/face_detction/pyramidbox.py @@ -240,8 +240,7 @@ def permute_and_reshape(input, last_dim): self.prior_boxes = fluid.layers.concat(boxes) self.box_vars = fluid.layers.concat(vars) - def vgg_ssd(self, num_classes, image_shape): # tangxu - + def vgg_ssd(self, num_classes, image_shape): self.conv3_norm = self._l2_norm_scale(self.conv3) self.conv4_norm = self._l2_norm_scale(self.conv4) self.conv5_norm = self._l2_norm_scale(self.conv5) diff --git a/fluid/face_detction/train.py b/fluid/face_detction/train.py index fcd7a0da26..ae79f6f3dc 100644 --- a/fluid/face_detction/train.py +++ b/fluid/face_detction/train.py @@ -19,7 +19,7 @@ add_arg('batch_size', int, 16, "Minibatch size.") add_arg('num_passes', int, 120, "Epoch number.") add_arg('use_gpu', bool, True, "Whether use GPU.") -add_arg('use_pyramidbox', bool, False, "Whether use GPU.") +add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.") add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.") add_arg('model_save_dir', str, 'model', "The path to save model.") add_arg('pretrained_model', str, './vgg_model/', "The init model path.") From e44757f1121bcc79672948eb5471d6310f0c8ef9 Mon Sep 17 00:00:00 2001 From: Dang Date: Tue, 29 May 2018 11:17:05 +0000 Subject: [PATCH 3/5] Fix model config. --- fluid/face_detction/pyramidbox.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/fluid/face_detction/pyramidbox.py b/fluid/face_detction/pyramidbox.py index ed80bea532..fa37a893fc 100644 --- a/fluid/face_detction/pyramidbox.py +++ b/fluid/face_detction/pyramidbox.py @@ -39,7 +39,7 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True): if with_pool: pool = fluid.layers.pool2d( input=conv, pool_size=2, pool_type='max', pool_stride=2) - return pool + return conv, pool else: return conv @@ -71,18 +71,19 @@ def _input(self): name='gt_difficult', shape=[1], dtype='int32', lod_level=1) def _vgg(self): - self.conv1 = conv_block(self.image, 2, [64] * 2, [3] * 2) - self.conv2 = conv_block(self.conv1, 2, [128] * 2, [3] * 2) + self.conv1, self.pool1 = conv_block(self.image, 2, [64] * 2, [3] * 2) + self.conv2, self.pool2 = conv_block(self.pool1, 2, [128] * 2, [3] * 2) #priorbox min_size is 16 - self.conv3 = conv_block(self.conv2, 3, [256] * 3, [3] * 3) + self.conv3, self.pool3 = conv_block(self.pool2, 3, [256] * 3, [3] * 3) #priorbox min_size is 32 - self.conv4 = conv_block(self.conv3, 3, [512] * 3, [3] * 3) + self.conv4, self.pool4 = conv_block(self.pool3, 3, [512] * 3, [3] * 3) #priorbox min_size is 64 - self.conv5 = conv_block(self.conv4, 3, [512] * 3, [3] * 3) + self.conv5, self.pool5 = conv_block(self.pool4, 3, [512] * 3, [3] * 3) # fc6 and fc7 in paper, priorbox min_size is 128 - self.conv6 = conv_block(self.conv5, 2, [1024, 1024], [3, 1]) + self.conv6 = conv_block( + self.pool5, 2, [1024, 1024], [3, 1], with_pool=False) # conv6_1 and conv6_2 in paper, priorbox min_size is 256 self.conv7 = conv_block( self.conv6, 2, [256, 512], [1, 3], [1, 2], with_pool=False) From 0683254df1634dd5fdbca239a6f9e571b0be884a Mon Sep 17 00:00:00 2001 From: Dang Date: Wed, 30 May 2018 09:30:30 +0000 Subject: [PATCH 4/5] Some mirror changes in data argumentations. --- fluid/face_detction/image_util.py | 71 ++++++++++++++++++++++--------- fluid/face_detction/reader.py | 22 +++++----- fluid/face_detction/train.py | 24 ++++++++--- 3 files changed, 81 insertions(+), 36 deletions(-) diff --git a/fluid/face_detction/image_util.py b/fluid/face_detction/image_util.py index 210c49720b..763d631dfd 100644 --- a/fluid/face_detction/image_util.py +++ b/fluid/face_detction/image_util.py @@ -8,9 +8,16 @@ class sampler(): - def __init__(self, max_sample, max_trial, min_scale, max_scale, - min_aspect_ratio, max_aspect_ratio, min_jaccard_overlap, - max_jaccard_overlap): + def __init__(self, + max_sample, + max_trial, + min_scale, + max_scale, + min_aspect_ratio, + max_aspect_ratio, + min_jaccard_overlap, + max_jaccard_overlap, + use_square=False): self.max_sample = max_sample self.max_trial = max_trial self.min_scale = min_scale @@ -19,6 +26,7 @@ def __init__(self, max_sample, max_trial, min_scale, max_scale, self.max_aspect_ratio = max_aspect_ratio self.min_jaccard_overlap = min_jaccard_overlap self.max_jaccard_overlap = max_jaccard_overlap + self.use_square = use_square class bbox(): @@ -35,13 +43,23 @@ def bbox_area(src_bbox): return width * height -def generate_sample(sampler): +def generate_sample(sampler, image_width, image_height): scale = random.uniform(sampler.min_scale, sampler.max_scale) - min_aspect_ratio = max(sampler.min_aspect_ratio, (scale**2.0)) - max_aspect_ratio = min(sampler.max_aspect_ratio, 1 / (scale**2.0)) - aspect_ratio = random.uniform(min_aspect_ratio, max_aspect_ratio) + aspect_ratio = random.uniform(sampler.min_aspect_ratio, + sampler.max_aspect_ratio) + aspect_ratio = max(aspect_ratio, (scale**2.0)) + aspect_ratio = min(aspect_ratio, 1 / (scale**2.0)) + bbox_width = scale * (aspect_ratio**0.5) bbox_height = scale / (aspect_ratio**0.5) + + # guarantee a squared image patch after cropping + if sampler.use_square: + if image_height < image_width: + bbox_width = bbox_height * image_height / image_width + else: + bbox_height = bbox_width * image_width / image_height + xmin_bound = 1 - bbox_width ymin_bound = 1 - bbox_height xmin = random.uniform(0, xmin_bound) @@ -77,6 +95,7 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): for i in range(len(bbox_labels)): object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1], bbox_labels[i][2], bbox_labels[i][3]) + # now only support constraint by jaccard overlap overlap = jaccard_overlap(sample_bbox, object_bbox) if sampler.min_jaccard_overlap != 0 and \ overlap < sampler.min_jaccard_overlap: @@ -88,7 +107,8 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): return False -def generate_batch_samples(batch_sampler, bbox_labels): +def generate_batch_samples(batch_sampler, bbox_labels, image_width, + image_height): sampled_bbox = [] index = [] c = 0 @@ -97,7 +117,7 @@ def generate_batch_samples(batch_sampler, bbox_labels): for i in range(sampler.max_trial): if found >= sampler.max_sample: break - sample_bbox = generate_sample(sampler) + sample_bbox = generate_sample(sampler, image_width, image_height) if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): sampled_bbox.append(sample_bbox) found = found + 1 @@ -125,15 +145,14 @@ def meet_emit_constraint(src_bbox, sample_bbox): return False -def transform_labels(bbox_labels, sample_bbox): - proj_bbox = bbox(0, 0, 0, 0) - sample_labels = [] - for i in range(len(bbox_labels)): - sample_label = [] - object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1], - bbox_labels[i][2], bbox_labels[i][3]) - if not meet_emit_constraint(object_bbox, sample_bbox): - continue +def project_bbox(object_bbox, sample_bbox): + if object_bbox.xmin >= sample_bbox.xmax or \ + object_bbox.xmax <= sample_bbox.xmin or \ + object_bbox.ymin >= sample_bbox.ymax or \ + object_bbox.ymax <= sample_bbox.ymin: + return False + else: + proj_bbox = bbox(0, 0, 0, 0) sample_width = sample_bbox.xmax - sample_bbox.xmin sample_height = sample_bbox.ymax - sample_bbox.ymin proj_bbox.xmin = (object_bbox.xmin - sample_bbox.xmin) / sample_width @@ -142,12 +161,26 @@ def transform_labels(bbox_labels, sample_bbox): proj_bbox.ymax = (object_bbox.ymax - sample_bbox.ymin) / sample_height proj_bbox = clip_bbox(proj_bbox) if bbox_area(proj_bbox) > 0: + return proj_bbox + else: + return False + + +def transform_labels(bbox_labels, sample_bbox): + sample_labels = [] + for i in range(len(bbox_labels)): + sample_label = [] + object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1], + bbox_labels[i][2], bbox_labels[i][3]) + if not meet_emit_constraint(object_bbox, sample_bbox): + continue + proj_bbox = project_bbox(object_bbox, sample_bbox) + if proj_bbox: sample_label.append(bbox_labels[i][0]) sample_label.append(float(proj_bbox.xmin)) sample_label.append(float(proj_bbox.ymin)) sample_label.append(float(proj_bbox.xmax)) sample_label.append(float(proj_bbox.ymax)) - #sample_label.append(bbox_labels[i][5]) sample_label = sample_label + bbox_labels[i][5:] sample_labels.append(sample_label) return sample_labels diff --git a/fluid/face_detction/reader.py b/fluid/face_detction/reader.py index aa7e45aaba..9ed82c02e2 100644 --- a/fluid/face_detction/reader.py +++ b/fluid/face_detction/reader.py @@ -29,9 +29,9 @@ def __init__(self, dataset=None, data_dir=None, label_file=None, - resize_h=300, - resize_w=300, - mean_value=[127.5, 127.5, 127.5], + resize_h=None, + resize_w=None, + mean_value=[104., 117., 123.], apply_distort=True, apply_expand=True, ap_version='11point', @@ -55,6 +55,8 @@ def __init__(self, self._saturation_prob = 0.5 self._saturation_delta = 0.5 self._brightness_prob = 0.5 + # _brightness_delta is the normalized value by 256 + # self._brightness_delta = 32 self._brightness_delta = 0.125 @property @@ -115,17 +117,17 @@ def preprocess(img, bbox_labels, mode, settings): batch_sampler = [] # hard-code here batch_sampler.append( - image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)) + image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, True)) batch_sampler.append( - image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0)) + image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True)) batch_sampler.append( - image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0)) + image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True)) batch_sampler.append( - image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0)) + image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True)) batch_sampler.append( - image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0)) - sampled_bbox = image_util.generate_batch_samples(batch_sampler, - bbox_labels) + image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 0.0, True)) + sampled_bbox = image_util.generate_batch_samples( + batch_sampler, bbox_labels, img_width, img_height) img = np.array(img) if len(sampled_bbox) > 0: diff --git a/fluid/face_detction/train.py b/fluid/face_detction/train.py index ae79f6f3dc..ce9d45807c 100644 --- a/fluid/face_detction/train.py +++ b/fluid/face_detction/train.py @@ -29,7 +29,7 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, - num_passes): + num_passes, optimizer_method): num_classes = 2 @@ -51,12 +51,19 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, learning_rate, learning_rate * 0.1, learning_rate * 0.01, learning_rate * 0.001 ] - #print('main program ', fluid.default_main_program()) - optimizer = fluid.optimizer.RMSProp( - learning_rate=fluid.layers.piecewise_decay(boundaries, values), - regularization=fluid.regularizer.L2Decay(0.0005), - ) + if optimizer_method == "momentum": + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=boundaries, values=values), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(0.0005), + ) + else: + optimizer = fluid.optimizer.RMSProp( + learning_rate=fluid.layers.piecewise_decay(boundaries, values), + regularization=fluid.regularizer.L2Decay(0.0005), + ) optimizer.minimize(loss) @@ -131,6 +138,8 @@ def save_model(postfix): data_dir=data_dir, resize_h=args.resize_h, resize_w=args.resize_w, + apply_expand=False, + mean_value=[104., 117., 123], ap_version='11point') train( args, @@ -138,4 +147,5 @@ def save_model(postfix): learning_rate=0.01, batch_size=args.batch_size, pretrained_model=args.pretrained_model, - num_passes=args.num_passes) + num_passes=args.num_passes, + optimizer_method="momentum") From 5c06cb82220f81e04104538f54bff07b78c6b4ec Mon Sep 17 00:00:00 2001 From: Dang Date: Wed, 30 May 2018 11:04:42 +0000 Subject: [PATCH 5/5] Set learning rate by input arguments. --- fluid/face_detction/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fluid/face_detction/train.py b/fluid/face_detction/train.py index ce9d45807c..e0a0b648e4 100644 --- a/fluid/face_detction/train.py +++ b/fluid/face_detction/train.py @@ -144,7 +144,7 @@ def save_model(postfix): train( args, data_args=data_args, - learning_rate=0.01, + learning_rate=args.learning_rate, batch_size=args.batch_size, pretrained_model=args.pretrained_model, num_passes=args.num_passes,