diff --git a/fluid/ocr_recognition/crnn_ctc_model.py b/fluid/ocr_recognition/crnn_ctc_model.py index 73616ecb36..945cc334c8 100644 --- a/fluid/ocr_recognition/crnn_ctc_model.py +++ b/fluid/ocr_recognition/crnn_ctc_model.py @@ -194,8 +194,19 @@ def ctc_train_net(images, label, args, num_classes): optimizer = fluid.optimizer.Momentum( learning_rate=args.learning_rate, momentum=args.momentum) _, params_grads = optimizer.minimize(sum_cost) - - return sum_cost, error_evaluator, inference_program + model_average = None + if args.model_average: + model_average = fluid.optimizer.ModelAverage( + params_grads, + args.average_window, + min_average_window=args.min_average_window, + max_average_window=args.max_average_window) + decoded_out = fluid.layers.ctc_greedy_decoder( + input=fc_out, blank=num_classes) + casted_label = fluid.layers.cast(x=label, dtype='int64') + error_evaluator = fluid.evaluator.EditDistance( + input=decoded_out, label=casted_label) + return sum_cost, error_evaluator, inference_program, model_average def ctc_infer(images, num_classes): diff --git a/fluid/ocr_recognition/ctc_train.py b/fluid/ocr_recognition/ctc_train.py index c2d8fd26bb..a02017ccd0 100644 --- a/fluid/ocr_recognition/ctc_train.py +++ b/fluid/ocr_recognition/ctc_train.py @@ -23,6 +23,10 @@ add_arg('rnn_hidden_size',int, 200, "Hidden size of rnn layers.") add_arg('device', int, 0, "Device id.'-1' means running on CPU" "while '0' means GPU-0.") +add_arg('model_average', bool, True, "Whether to aevrage model for evaluation.") +add_arg('min_average_window', int, 10000, "Min average window.") +add_arg('max_average_window', int, 15625, "Max average window.") +add_arg('average_window', float, 0.15, "Average window.") add_arg('parallel', bool, True, "Whether use parallel training.") # yapf: disable @@ -40,7 +44,7 @@ def train(args, data_reader=dummy_reader): # define network images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int32', lod_level=1) - sum_cost, error_evaluator, inference_program = ctc_train_net(images, label, args, num_classes) + sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(images, label, args, num_classes) # data reader train_reader = data_reader.train(args.batch_size) @@ -75,12 +79,16 @@ def train(args, data_reader=dummy_reader): sys.stdout.flush() batch_id += 1 - error_evaluator.reset(exe) - for data in test_reader(): - exe.run(inference_program, feed=get_feeder_data(data, place)) - _, test_seq_error = error_evaluator.eval(exe) - print "\nEnd pass[%d]; Test seq error: %s.\n" % ( - pass_id, str(test_seq_error[0])) + with model_average.apply(exe): + error_evaluator.reset(exe) + for data in test_reader(): + exe.run(inference_program, feed=get_feeder_data(data, place)) + _, test_seq_error = error_evaluator.eval(exe) + if model_average != None: + model_average.restore(exe) + + print "\nEnd pass[%d]; Test seq error: %s.\n" % ( + pass_id, str(test_seq_error[0])) def main(): args = parser.parse_args()