Commit 3bfac8f

chore(core): switch backend to spawn (#429)

FateScript committed Aug 10, 2021
1 parent 026844e commit 3bfac8f
Showing 8 changed files with 102 additions and 159 deletions.
23 changes: 16 additions & 7 deletions README.md
@@ -109,19 +109,28 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o
* -b: total batch size, the recommended number for -b is num-gpu * 8
* --fp16: mixed precision training

When using -f, the above commands are equivalent to:
```shell
python tools/train.py -f exps/default/yolox_s.py -d 8 -b 64 --fp16 -o
                         exps/default/yolox_m.py
                         exps/default/yolox_l.py
                         exps/default/yolox_x.py
```

**Multi Machine Training**

We also support multi-node training. Just add the following args:
* --num\_machines: num of your total training nodes
* --machine\_rank: specify the rank of each node

Suppose you want to train YOLOX on 2 machines, where the master machine's IP is 123.123.123.123 and it will use TCP port 12312.
On the master machine, run
```shell
python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num-machines 2 --machine-rank 0
```
On the second machine, run
```shell
python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num-machines 2 --machine-rank 1
```

</details>
@@ -153,7 +162,7 @@ python tools/eval.py -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --fp16 --
</details>


<details open>
<details>
<summary>Tutorials</summary>

* [Training on custom data](docs/train_custom_data.md).
15 changes: 6 additions & 9 deletions tools/eval.py
@@ -14,7 +14,7 @@

from yolox.core import launch
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, setup_logger
from yolox.utils import configure_nccl, fuse_model, get_local_rank, get_model_info, setup_logger


def make_parser():
@@ -36,9 +36,6 @@ def make_parser():
parser.add_argument(
"-d", "--devices", default=None, type=int, help="device for training"
)
parser.add_argument(
"--local_rank", default=0, type=int, help="local rank for dist training"
)
parser.add_argument(
"--num_machines", default=1, type=int, help="num of node for training"
)
@@ -114,10 +111,10 @@ def main(exp, args, num_gpu):
is_distributed = num_gpu > 1

# set environment variables for distributed training
configure_nccl()
cudnn.benchmark = True

rank = args.local_rank
# rank = get_local_rank()
rank = get_local_rank()

file_name = os.path.join(exp.output_dir, args.experiment_name)

@@ -149,10 +146,9 @@ def main(exp, args, num_gpu):
ckpt_file = os.path.join(file_name, "best_ckpt.pth")
else:
ckpt_file = args.ckpt
logger.info("loading checkpoint")
logger.info("loading checkpoint from {}".format(ckpt_file))
loc = "cuda:{}".format(rank)
ckpt = torch.load(ckpt_file, map_location=loc)
# load the model state dict
model.load_state_dict(ckpt["model"])
logger.info("loaded checkpoint done.")

@@ -195,12 +191,13 @@ def main(exp, args, num_gpu):
num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
assert num_gpu <= torch.cuda.device_count()

dist_url = "auto" if args.dist_url is None else args.dist_url
launch(
main,
num_gpu,
args.num_machines,
args.machine_rank,
backend=args.dist_backend,
dist_url=args.dist_url,
dist_url=dist_url,
args=(exp, args, num_gpu),
)
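
With the spawn backend, eval.py no longer reads a --local_rank CLI flag; each worker asks yolox.utils.get_local_rank() for its rank, and a missing --dist-url now falls back to "auto". The sketch below shows one common way such a helper can be built on top of torch.distributed; it is an illustration under that assumption, not the actual yolox.utils implementation.

```python
# Hypothetical sketch -- NOT the actual yolox.utils implementation.
# It shows one common way to expose a per-machine rank once workers are
# created by mp.spawn instead of an external launcher setting LOCAL_RANK.
import torch.distributed as dist

# The launcher would populate this with a process group containing only the
# ranks that live on the same machine (see the commented-out block in launch.py).
_LOCAL_PROCESS_GROUP = None


def get_local_rank() -> int:
    """Rank of this process within its machine; 0 when not distributed."""
    if not dist.is_available() or not dist.is_initialized():
        return 0
    if _LOCAL_PROCESS_GROUP is None:
        # Single-machine case: global rank and local rank coincide.
        return dist.get_rank()
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
```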
9 changes: 5 additions & 4 deletions tools/train.py
@@ -12,6 +12,7 @@

from yolox.core import Trainer, launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, configure_omp


def make_parser():
@@ -33,9 +34,6 @@ def make_parser():
parser.add_argument(
"-d", "--devices", default=None, type=int, help="device for training"
)
parser.add_argument(
"--local_rank", default=0, type=int, help="local rank for dist training"
)
parser.add_argument(
"-f",
"--exp_file",
@@ -97,6 +95,8 @@ def main(exp, args):
)

# set environment variables for distributed training
configure_nccl()
configure_omp()
cudnn.benchmark = True

trainer = Trainer(exp, args)
@@ -114,12 +114,13 @@
num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
assert num_gpu <= torch.cuda.device_count()

dist_url = "auto" if args.dist_url is None else args.dist_url
launch(
main,
num_gpu,
args.num_machines,
args.machine_rank,
backend=args.dist_backend,
dist_url=args.dist_url,
dist_url=dist_url,
args=(exp, args),
)
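
train.py now configures NCCL and OpenMP inside each worker process via configure_nccl() and configure_omp(), instead of exporting environment variables from a parent subprocess as the old launch_by_subprocess did. The snippet below is a rough sketch of what helpers like these typically do; the variable choices and defaults are assumptions, not the yolox.utils code.

```python
# Hypothetical sketch -- not the real yolox.utils code, just the general idea:
# each spawned worker sets its own NCCL / OpenMP environment instead of
# inheriting it from a parent launcher process.
import os


def configure_nccl():
    """Set NCCL-related environment variables (illustrative defaults only)."""
    os.environ.setdefault("NCCL_LAUNCH_MODE", "PARALLEL")
    os.environ.setdefault("NCCL_IB_DISABLE", "1")  # assume no InfiniBand


def configure_omp(num_threads: int = 1):
    """Cap OpenMP threads so multiple workers per machine do not oversubscribe CPUs."""
    if "OMP_NUM_THREADS" not in os.environ:
        os.environ["OMP_NUM_THREADS"] = str(num_threads)
```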
165 changes: 34 additions & 131 deletions yolox/core/launch.py
@@ -5,22 +5,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Megvii, Inc. and its affiliates.

import os
import subprocess
import sys
import time
from datetime import timedelta
from loguru import logger

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import yolox.utils.dist as comm
from yolox.utils import configure_nccl

# import torch.multiprocessing as mp
__all__ = ["launch"]


__all__ = ["launch"]
DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
@@ -46,6 +43,7 @@ def launch(
backend="nccl",
dist_url=None,
args=(),
timeout=DEFAULT_TIMEOUT,
):
"""
Args:
@@ -59,128 +57,39 @@
"""
world_size = num_machines * num_gpus_per_machine
if world_size > 1:
if int(os.environ.get("WORLD_SIZE", "1")) > 1:
dist_url = "{}:{}".format(
os.environ.get("MASTER_ADDR", None),
os.environ.get("MASTER_PORT", "None"),
)
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
_distributed_worker(
local_rank,
main_func,
world_size,
num_gpus_per_machine,
num_machines,
machine_rank,
backend,
dist_url,
args,
)
exit()
launch_by_subprocess(
sys.argv,
world_size,
num_machines,
machine_rank,
num_gpus_per_machine,
dist_url,
args,
)
else:
main_func(*args)

# https://github.com/pytorch/pytorch/pull/14391
# TODO prctl in spawned processes

def launch_by_subprocess(
raw_argv,
world_size,
num_machines,
machine_rank,
num_gpus_per_machine,
dist_url,
args,
):
assert (
world_size > 1
), "subprocess mode doesn't support single GPU, use spawn mode instead"

if dist_url is None:
# ------------------------hack for multi-machine training -------------------- #
if num_machines > 1:
master_ip = subprocess.check_output(["hostname", "--fqdn"]).decode("utf-8")
master_ip = str(master_ip).strip()
dist_url = "tcp://{}".format(master_ip)
ip_add_file = "./" + args[1].experiment_name + "_ip_add.txt"
if machine_rank == 0:
port = _find_free_port()
with open(ip_add_file, "w") as ip_add:
ip_add.write(dist_url+'\n')
ip_add.write(str(port))
else:
while not os.path.exists(ip_add_file):
time.sleep(0.5)

with open(ip_add_file, "r") as ip_add:
dist_url = ip_add.readline().strip()
port = ip_add.readline()
else:
dist_url = "tcp://127.0.0.1"
if dist_url == "auto":
assert num_machines == 1, "dist_url=auto cannot work with distributed training."
port = _find_free_port()

# set PyTorch distributed related environmental variables
current_env = os.environ.copy()
current_env["MASTER_ADDR"] = dist_url
current_env["MASTER_PORT"] = str(port)
current_env["WORLD_SIZE"] = str(world_size)
assert num_gpus_per_machine <= torch.cuda.device_count()

if "OMP_NUM_THREADS" not in os.environ and num_gpus_per_machine > 1:
current_env["OMP_NUM_THREADS"] = str(1)
logger.info(
"\n*****************************************\n"
"Setting OMP_NUM_THREADS environment variable for each process "
"to be {} in default, to avoid your system being overloaded, "
"please further tune the variable for optimal performance in "
"your application as needed. \n"
"*****************************************".format(
current_env["OMP_NUM_THREADS"]
)
dist_url = f"tcp://127.0.0.1:{port}"

mp.spawn(
_distributed_worker,
nprocs=num_gpus_per_machine,
args=(
main_func, world_size, num_gpus_per_machine,
machine_rank, backend, dist_url, args
),
daemon=False,
)

processes = []
for local_rank in range(0, num_gpus_per_machine):
# each process's rank
dist_rank = machine_rank * num_gpus_per_machine + local_rank
current_env["RANK"] = str(dist_rank)
current_env["LOCAL_RANK"] = str(local_rank)

# spawn the processes
cmd = ["python3", *raw_argv]

process = subprocess.Popen(cmd, env=current_env)
processes.append(process)

for process in processes:
process.wait()
if process.returncode != 0:
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
else:
main_func(*args)


def _distributed_worker(
local_rank,
main_func,
world_size,
num_gpus_per_machine,
num_machines,
machine_rank,
backend,
dist_url,
args,
timeout=DEFAULT_TIMEOUT,
):
assert (
torch.cuda.is_available()
), "cuda is not available. Please check your installation."
configure_nccl()
assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
global_rank = machine_rank * num_gpus_per_machine + local_rank
logger.info("Rank {} initialization finished.".format(global_rank))
try:
@@ -189,32 +98,26 @@ def _distributed_worker(
init_method=dist_url,
world_size=world_size,
rank=global_rank,
timeout=timeout,
)
except Exception:
logger.error("Process group URL: {}".format(dist_url))
raise

# Setup the local process group (which contains ranks within the same machine)
assert comm._LOCAL_PROCESS_GROUP is None
num_machines = world_size // num_gpus_per_machine
for i in range(num_machines):
ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
comm._LOCAL_PROCESS_GROUP = pg

# synchronize is needed here to prevent a possible timeout after calling init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()

if global_rank == 0 and os.path.exists(
"./" + args[1].experiment_name + "_ip_add.txt"
):
os.remove("./" + args[1].experiment_name + "_ip_add.txt")

assert num_gpus_per_machine <= torch.cuda.device_count()
torch.cuda.set_device(local_rank)

args[1].local_rank = local_rank
args[1].num_machines = num_machines

# Setup the local process group (which contains ranks within the same machine)
# assert comm._LOCAL_PROCESS_GROUP is None
# num_machines = world_size // num_gpus_per_machine
# for i in range(num_machines):
# ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
# pg = dist.new_group(ranks_on_i)
# if i == machine_rank:
# comm._LOCAL_PROCESS_GROUP = pg

main_func(*args)
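
Taken together, the new launch() resolves dist_url (mapping "auto" to a free local port), starts one worker per GPU with torch.multiprocessing.spawn, and each worker calls dist.init_process_group with the configured timeout. The following self-contained sketch reproduces that pattern on CPU with a gloo backend; the worker body and the free-port helper are illustrative assumptions, not copies of the YOLOX code.

```python
# Minimal self-contained sketch of the spawn-based launch pattern used above.
import socket
from datetime import timedelta

import torch.distributed as dist
import torch.multiprocessing as mp


def _find_free_port() -> int:
    # Bind to port 0 so the OS picks an unused port, then release it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def _worker(local_rank, world_size, dist_url):
    dist.init_process_group(
        backend="gloo",            # "nccl" on GPU machines
        init_method=dist_url,
        world_size=world_size,
        rank=local_rank,           # single machine: global rank == local rank
        timeout=timedelta(minutes=30),
    )
    print(f"rank {dist.get_rank()} / {dist.get_world_size()} is up")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    dist_url = f"tcp://127.0.0.1:{_find_free_port()}"
    # Same call shape as launch(): one process per "GPU", non-daemonic workers.
    mp.spawn(_worker, nprocs=world_size, args=(world_size, dist_url), daemon=False)
```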
