Skip to content

Commit

Permalink
Merge pull request #740 from wanghaoshuang/model_avg
Browse files Browse the repository at this point in the history
Add model average option for OCR CTC model
  • Loading branch information
wanghaoshuang committed Mar 27, 2018
2 parents 36ca387 + 7df53c9 commit 61499cd
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
15 changes: 13 additions & 2 deletions fluid/ocr_recognition/crnn_ctc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,19 @@ def ctc_train_net(images, label, args, num_classes):
optimizer = fluid.optimizer.Momentum(
learning_rate=args.learning_rate, momentum=args.momentum)
_, params_grads = optimizer.minimize(sum_cost)

return sum_cost, error_evaluator, inference_program
model_average = None
if args.model_average:
model_average = fluid.optimizer.ModelAverage(
params_grads,
args.average_window,
min_average_window=args.min_average_window,
max_average_window=args.max_average_window)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
casted_label = fluid.layers.cast(x=label, dtype='int64')
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
return sum_cost, error_evaluator, inference_program, model_average


def ctc_infer(images, num_classes):
Expand Down
22 changes: 15 additions & 7 deletions fluid/ocr_recognition/ctc_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
# Command-line options registered through the file's add_arg helper.
# Each call appears to pass: option name, type, default value, help text
# (add_arg itself is defined outside this view -- confirm against its definition).
add_arg('rnn_hidden_size',int, 200, "Hidden size of rnn layers.")
# The help text is split across two adjacent literals; the trailing space on
# the first literal keeps the words apart after implicit concatenation.
add_arg('device', int, 0, "Device id. '-1' means running on CPU "
"while '0' means GPU-0.")
# Model-averaging options: ctc_train_net reads these from args when it builds
# fluid.optimizer.ModelAverage (average_window is passed positionally,
# min/max_average_window as keyword arguments).
add_arg('model_average', bool, True, "Whether to average model for evaluation.")
add_arg('min_average_window', int, 10000, "Min average window.")
add_arg('max_average_window', int, 15625, "Max average window.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, True, "Whether use parallel training.")
# yapf: disable

Expand All @@ -40,7 +44,7 @@ def train(args, data_reader=dummy_reader):
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int32', lod_level=1)
sum_cost, error_evaluator, inference_program = ctc_train_net(images, label, args, num_classes)
sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(images, label, args, num_classes)

# data reader
train_reader = data_reader.train(args.batch_size)
Expand Down Expand Up @@ -75,12 +79,16 @@ def train(args, data_reader=dummy_reader):
sys.stdout.flush()
batch_id += 1

error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
print "\nEnd pass[%d]; Test seq error: %s.\n" % (
pass_id, str(test_seq_error[0]))
with model_average.apply(exe):
error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
if model_average != None:
model_average.restore(exe)

print "\nEnd pass[%d]; Test seq error: %s.\n" % (
pass_id, str(test_seq_error[0]))

def main():
args = parser.parse_args()
Expand Down

0 comments on commit 61499cd

Please sign in to comment.