Adjust resnet18+cifar10 benchmark and example
bichengying committed May 25, 2020
1 parent ff91803 commit 0eb4bec
Showing 3 changed files with 38 additions and 16 deletions.
6 changes: 5 additions & 1 deletion examples/pytorch_benchmark.py
@@ -130,8 +130,12 @@ def forward(self, x):
 # Set up fake data
 datasets = []
 for _ in range(100):
+    # First two should be CPU usage only.
     if args.model == "lenet":
-        data = torch.rand(args.batch_size, 1, 28, 28)
+        data = torch.rand(args.batch_size, 1, 28, 28)  # mnist size
         target = torch.LongTensor(args.batch_size).random_() % 10
+    elif args.model == 'resnet18':
+        data = torch.rand(args.batch_size, 3, 32, 32)  # CIFAR10 size
+        target = torch.LongTensor(args.batch_size).random_() % 10
     else:
         data = torch.rand(args.batch_size, 3, 224, 224)
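For context on the shapes above, here is a minimal, self-contained sketch of one benchmark step on fake CIFAR10-sized data (the torchvision resnet18 and the batch size of 32 are assumptions for illustration; the script's actual model setup lives outside this hunk):

    import torch
    import torch.nn.functional as F
    from torchvision import models

    # Fake CIFAR10-sized batch: 3-channel 32x32 images, 10 classes.
    batch_size = 32  # assumed value of args.batch_size
    data = torch.rand(batch_size, 3, 32, 32)
    target = torch.LongTensor(batch_size).random_() % 10

    model = models.resnet18(num_classes=10)      # assumption: torchvision model
    loss = F.cross_entropy(model(data), target)  # one forward pass on fake data
    loss.backward()                              # one backward pass

Because the data is synthetic, the benchmark measures pure compute and communication cost, with no dataset download or disk I/O inside the timed loop.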
32 changes: 24 additions & 8 deletions examples/pytorch_cifar10_resnet.py
@@ -61,7 +61,7 @@
"executing allreduce across workers; it multiplies "
"total batch size.",
)
parser.add_argument('--model', type=str, default='resnet50',
parser.add_argument('--model', type=str, default='resnet18',
help='model to benchmark')

# Default settings from https://arxiv.org/abs/1706.02677.
@@ -71,7 +71,7 @@
 parser.add_argument(
     "--val-batch-size", type=int, default=32, help="input batch size for validation"
 )
-parser.add_argument("--epochs", type=int, default=50,
+parser.add_argument("--epochs", type=int, default=90,
                     help="number of epochs to train")
 parser.add_argument(
     "--base-lr", type=float, default=0.0125, help="learning rate for a single GPU"
@@ -90,6 +90,13 @@
default=False, help="disables bluefog library")
parser.add_argument("--no-rma", action="store_true",
default=False, help="Do no use remote memory access(no window ops).")
parser.add_argument("--average-test-result", action="store_true",
default=False,
help=("Allreduce called to average test result. Warning this will " +
"force the algorithm to sync every end of epoch."))
parser.add_argument("--enable-dynamic-topology", action="store_true",
default=False, help=("Enable each iteration to transmit one neighbor " +
"per iteration dynamically."))

args = parser.parse_args()
args.cuda = (not args.no_cuda) and (torch.cuda.is_available())
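Both new flags default to False, so existing runs keep their behavior; test-result averaging and dynamic topology are strictly opt-in. An illustrative launch exercising both (assuming Bluefog's bfrun launcher and 4 workers):

    bfrun -np 4 python examples/pytorch_cifar10_resnet.py --model resnet18 --average-test-result --enable-dynamic-topology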
@@ -180,9 +187,13 @@
         ]
     ),
 )
-val_sampler = torch.utils.data.distributed.DistributedSampler(
-    val_dataset, num_replicas=bf.size(), rank=bf.rank()
-)
+if args.average_test_result:
+    val_sampler = torch.utils.data.distributed.DistributedSampler(
+        val_dataset, num_replicas=bf.size(), rank=bf.rank()
+    )
+else:
+    val_sampler = None
+
 val_loader = torch.utils.data.DataLoader(
     val_dataset, batch_size=args.val_batch_size, sampler=val_sampler, **kwargs
 )
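The branch above trades redundancy for synchronization: with --average-test-result, DistributedSampler gives each worker a disjoint 1/bf.size() shard of the validation set, so per-worker results are only meaningful after an allreduce; with sampler=None, every worker evaluates the full set on its own and no end-of-epoch sync is required. A minimal sketch of the sharding behavior (hypothetical 8-sample dataset across 4 workers, shuffling disabled for clarity):

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(8))  # hypothetical tiny dataset
    shard = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=False)
    print(list(shard))  # [0, 4] -- rank 0 sees only 2 of the 8 samples
    print(len(shard))   # 2 -- the per-worker count later used to normalize metrics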
@@ -251,7 +262,8 @@ def train(epoch):
               disable=not verbose,) as t:
         for batch_idx, (data, target) in enumerate(train_loader):
             adjust_learning_rate(epoch, batch_idx)
-            dynamic_topology_update(epoch, batch_idx)
+            if args.enable_dynamic_topology:
+                dynamic_topology_update(epoch, batch_idx)
 
             if args.cuda:
                 data, target = data.cuda(), target.cuda()
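dynamic_topology_update itself is defined elsewhere in the example; this hunk only gates the call behind the new flag. For intuition, a hypothetical sketch of such an update, rotating a single out-neighbor each iteration over a one-peer exponential graph (assumptions: a power-of-two bf.size() greater than 1, an assumed batches_per_epoch constant, and Bluefog's bf.set_topology accepting a networkx.DiGraph):

    import math
    import networkx as nx
    import bluefog.torch as bf

    def dynamic_topology_update(epoch, batch_idx, batches_per_epoch=391):
        # Cycle the hop distance 1, 2, 4, ... so information from any worker
        # can reach all others within log2(size) iterations.
        step = epoch * batches_per_epoch + batch_idx
        hop = 2 ** (step % int(math.log2(bf.size())))
        graph = nx.DiGraph()
        graph.add_edges_from(
            [(r, (r + hop) % bf.size()) for r in range(bf.size())])
        bf.set_topology(graph)  # each worker now sends to exactly one neighbor

Transmitting to one neighbor per iteration keeps per-step communication cost constant regardless of cluster size, which is the point of the flag.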
@@ -369,7 +381,10 @@ def __init__(self, name):
         self.n = torch.tensor(0.0)  # pylint: disable=not-callable
 
     def update(self, val):
-        self.sum += bf.allreduce(val.detach().cpu(), name=self.name)
+        if args.average_test_result:
+            self.sum += bf.allreduce(val.detach().cpu(), name=self.name)
+        else:
+            self.sum += val.detach().cpu()
         self.n += 1
 
     @property
@@ -379,5 +394,6 @@ def avg(self):

 for epoch in range(resume_from_epoch, args.epochs):
     train(epoch)
-    validate(epoch)
+    if epoch % 3 == 0:
+        validate(epoch)
     save_checkpoint(epoch)
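The Metric change mirrors the sampler change: when --average-test-result is set, every update is allreduced (averaged across workers), forcing a synchronization at each validation pass; otherwise each worker accumulates a purely local value. Validating only every third epoch further reduces how often that sync can occur. A minimal sketch of the averaging semantics (assumes bf.init() has run under the Bluefog launcher and that bf.allreduce averages by default, which is what the Metric code relies on):

    import torch
    import bluefog.torch as bf

    bf.init()
    local_acc = torch.tensor(0.90 + 0.02 * bf.rank())     # hypothetical per-rank value
    global_acc = bf.allreduce(local_acc, name="val_acc")  # mean over all ranks
    print(bf.rank(), float(global_acc))                   # identical on every rank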
16 changes: 9 additions & 7 deletions examples/pytorch_mnist.py
@@ -25,8 +25,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
 import torch.utils.data.distributed
+from torchvision import datasets, transforms
 
 sys.path.insert(0, os.path.abspath(
     os.path.join(os.path.dirname(__file__), "..")))
@@ -139,10 +139,12 @@
         [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
     ),
 )
-# Bluefog: use DistributedSampler to partition the test data.
-test_sampler = torch.utils.data.distributed.DistributedSampler(
-    test_dataset, num_replicas=bf.size(), rank=bf.rank()
-)
+test_sampler = None
+if args.average_test_result:
+    # Bluefog: use DistributedSampler to partition the test data.
+    test_sampler = torch.utils.data.distributed.DistributedSampler(
+        test_dataset, num_replicas=bf.size(), rank=bf.rank()
+    )
 test_loader = torch.utils.data.DataLoader(
     test_dataset, batch_size=args.test_batch_size, sampler=test_sampler, **kwargs
 )
@@ -260,8 +262,8 @@ def test():

     # Bluefog: use test_sampler to determine the number of examples in
     # this worker's partition.
-    test_loss /= len(test_sampler)
-    test_accuracy /= len(test_sampler)
+    test_loss /= len(test_sampler) if test_sampler else len(test_dataset)
+    test_accuracy /= len(test_sampler) if test_sampler else len(test_dataset)
 
     # Bluefog: average metric values across workers.
     if args.average_test_result: