Skip to content

Commit

Permalink
fix code format
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
luotao1 committed Sep 23, 2020
1 parent 6a13e7c commit 99a9490
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion dygraph/resnet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,25 @@ def parse_args():
args = parse_args()
batch_size = args.batch_size


class TimeCostAverage(object):
def __init__(self):
self.reset()

def reset(self):
self.cnt = 0
self.total_time = 0

def record(self, usetime):
self.cnt += 1
self.total_time += usetime

def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt


def optimizer_setting(parameter_list=None):

total_images = IMAGENET1000
Expand Down Expand Up @@ -493,7 +498,8 @@ def train_resnet():
"[Epoch %d, batch %d] loss %.5f, acc1 %.5f, acc5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s"
% (eop, batch_id, total_loss / total_sample,
total_acc1 / total_sample, total_acc5 / total_sample,
train_batch_cost_avg.get_average(), train_reader_cost_avg.get_average()))
train_batch_cost_avg.get_average(),
train_reader_cost_avg.get_average()))
train_batch_cost_avg.reset()
train_reader_cost_avg.reset()
batch_start = time.time()
Expand Down

0 comments on commit 99a9490

Please sign in to comment.