Skip to content

Commit

Permalink
Update the mnist usage for more async-style
Browse files Browse the repository at this point in the history
  • Loading branch information
bichengying committed May 25, 2020
1 parent b76ac04 commit 5ad46ae
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions examples/pytorch_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@
default=False, help="disables bluefog library. Use horovod instead.")
parser.add_argument("--no-rma", action="store_true",
default=False, help="Do no use remote memory access(no window ops).")
parser.add_argument("--average-test-result", action="store_true",
default=False,
help=("Allreduce called to average test result. Warning this will " +
"force the algorithm to sync every end of epoch."))

parser.add_argument(
"--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
Expand Down Expand Up @@ -222,7 +226,8 @@ def train(epoch):
# Bluefog: use train_sampler to determine the number of examples in
# this worker's partition.
print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
"[{}] Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
bf.rank(),
epoch,
batch_idx * len(data),
len(train_sampler),
Expand Down Expand Up @@ -259,8 +264,9 @@ def test():
test_accuracy /= len(test_sampler)

# Bluefog: average metric values across workers.
test_loss = metric_average(test_loss, "avg_loss")
test_accuracy = metric_average(test_accuracy, "avg_accuracy")
if args.average_test_result:
test_loss = metric_average(test_loss, "avg_loss")
test_accuracy = metric_average(test_accuracy, "avg_accuracy")

# Bluefog: print output only on first rank.
if bf.rank() == 0:
Expand All @@ -276,5 +282,5 @@ def test():
for epoch in range(1, args.epochs + 1):
train(epoch)
record.append(test())
if bf.rank() == 0:
print(record)

print(f"[{bf.rank()}]: ", record)

0 comments on commit 5ad46ae

Please sign in to comment.