Skip to content

Commit

Permalink
simplify recognize digits example code (#10722)
Browse files Browse the repository at this point in the history
  • Loading branch information
kexinzhao authored and daming-lu committed May 17, 2018
1 parent 2a63652 commit bbd7580
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,18 @@ def event_handler(event):
if isinstance(event, fluid.EndEpochEvent):
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
test_metrics = trainer.test(
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
avg_cost_set = test_metrics[0]
acc_set = test_metrics[1]

# get test acc and loss
acc = numpy.array(acc_set).mean()
avg_cost = numpy.array(avg_cost_set).mean()

print("avg_cost: %s" % avg_cost)
print("acc : %s" % acc)

if float(acc) > 0.2: # Smaller value to increase CI speed
if acc > 0.2: # Smaller value to increase CI speed
trainer.save_params(save_dirname)
else:
print('BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'.format(
event.epoch + 1, float(avg_cost), float(acc)))
if math.isnan(float(avg_cost)):
event.epoch + 1, avg_cost, acc))
if math.isnan(avg_cost):
sys.exit("got NaN loss, training failed.")
elif isinstance(event, fluid.EndStepEvent):
print("Step {0}, Epoch {1} Metrics {2}".format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,18 @@ def event_handler(event):
if isinstance(event, fluid.EndEpochEvent):
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
test_metrics = trainer.test(
avg_cost, acc = trainer.test(
reader=test_reader, feed_order=['img', 'label'])
avg_cost_set = test_metrics[0]
acc_set = test_metrics[1]

# get test acc and loss
acc = numpy.array(acc_set).mean()
avg_cost = numpy.array(avg_cost_set).mean()

print("avg_cost: %s" % avg_cost)
print("acc : %s" % acc)

if float(acc) > 0.2: # Smaller value to increase CI speed
if acc > 0.2: # Smaller value to increase CI speed
trainer.save_params(save_dirname)
else:
print('BatchID {0}, Test Loss {1:0.2}, Acc {2:0.2}'.format(
event.epoch + 1, float(avg_cost), float(acc)))
if math.isnan(float(avg_cost)):
event.epoch + 1, avg_cost, acc))
if math.isnan(avg_cost):
sys.exit("got NaN loss, training failed.")

train_reader = paddle.batch(
Expand Down

0 comments on commit bbd7580

Please sign in to comment.