From bdbb7a2183c7d77da9c64b67d912ed4f9be0d5e3 Mon Sep 17 00:00:00 2001 From: prometheus <57171759+NLGithubWP@users.noreply.github.com> Date: Fri, 14 Jun 2024 10:58:02 +0800 Subject: [PATCH] Add multiprocess version of the model --- .../cnn_ms/autograd/mnist_multiprocess.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 examples/cnn_ms/autograd/mnist_multiprocess.py diff --git a/examples/cnn_ms/autograd/mnist_multiprocess.py b/examples/cnn_ms/autograd/mnist_multiprocess.py new file mode 100644 index 000000000..a1a1fe58e --- /dev/null +++ b/examples/cnn_ms/autograd/mnist_multiprocess.py @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from mnist_cnn import * +import multiprocessing +import sys + +if __name__ == '__main__': + + # Generate a NCCL ID to be used for collective communication + nccl_id = singa.NcclIdHolder() + + # Number of GPUs to be used + world_size = int(sys.argv[1]) + + process = [] + for local_rank in range(0, world_size): + process.append( + multiprocessing.Process(target=train_mnist_cnn, + args=(True, local_rank, world_size, nccl_id))) + + for p in process: + p.start() \ No newline at end of file