Skip to content

Commit

Permalink
fix resnet dygraph model time print (#4868)
Browse files Browse the repository at this point in the history
  • Loading branch information
luotao1 committed Sep 23, 2020
1 parent ba9a787 commit 38ada7f
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion dygraph/resnet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,24 @@ def parse_args():
batch_size = args.batch_size


class TimeCostAverage(object):
def __init__(self):
self.reset()

def reset(self):
self.cnt = 0
self.total_time = 0

def record(self, usetime):
self.cnt += 1
self.total_time += usetime

def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt


def optimizer_setting(parameter_list=None):

total_images = IMAGENET1000
Expand Down Expand Up @@ -433,6 +451,8 @@ def train_resnet():
total_acc5 = 0.0
total_sample = 0

train_batch_cost_avg = TimeCostAverage()
train_reader_cost_avg = TimeCostAverage()
batch_start = time.time()
for batch_id, data in enumerate(train_loader()):
#NOTE: used in benchmark
Expand Down Expand Up @@ -469,13 +489,19 @@ def train_resnet():
total_sample += 1

train_batch_cost = time.time() - batch_start
train_batch_cost_avg.record(train_batch_cost)
train_reader_cost_avg.record(train_reader_cost)

total_batch_num = total_batch_num + 1 #this is for benchmark
if batch_id % 10 == 0:
print(
"[Epoch %d, batch %d] loss %.5f, acc1 %.5f, acc5 %.5f, batch_cost: %.5f s, reader_cost: %.5f s"
% (eop, batch_id, total_loss / total_sample,
total_acc1 / total_sample, total_acc5 / total_sample,
train_batch_cost, train_reader_cost))
train_batch_cost_avg.get_average(),
train_reader_cost_avg.get_average()))
train_batch_cost_avg.reset()
train_reader_cost_avg.reset()
batch_start = time.time()

if args.ce:
Expand Down

0 comments on commit 38ada7f

Please sign in to comment.