Skip to content

Commit

Permalink
Enable CPU training for DyGraph MNIST Resnet (#4824)
Browse files Browse the repository at this point in the history
  • Loading branch information
arlesniak committed Sep 1, 2020
1 parent 64cde5d commit f9f0d30
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
24 changes: 20 additions & 4 deletions dygraph/mnist/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def parse_args():
)
parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')

args = parser.parse_args()
return args

Expand Down Expand Up @@ -149,8 +155,13 @@ def test_mnist(reader, model, batch_size):


def inference_mnist():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)

with fluid.dygraph.guard(place):
mnist_infer = MNIST()
# load checkpoint
Expand Down Expand Up @@ -180,8 +191,13 @@ def train_mnist(args):
epoch_num = args.epoch
BATCH_SIZE = 64

place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)

with fluid.dygraph.guard(place):
if args.ce:
print("ce mode")
Expand Down
15 changes: 13 additions & 2 deletions dygraph/resnet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def parse_args():
type=float,
default=[0.229, 0.224, 0.225],
help="The std of input image data")
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')

args = parser.parse_args()
return args
Expand Down Expand Up @@ -354,8 +359,14 @@ def eval(model, data):

def train_resnet():
epoch = args.epoch
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)

if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)

with fluid.dygraph.guard(place):
if args.ce:
print("ce mode")
Expand Down

0 comments on commit f9f0d30

Please sign in to comment.