From 3091ae33e3180137f5b0dee6f463a915a5b03a76 Mon Sep 17 00:00:00 2001 From: Ervin T Date: Fri, 11 Sep 2020 14:23:59 -0700 Subject: [PATCH] [bug-fix] Set number of threads based on allocated CPU count in Docker containers (#4471) * Set num threads properly for Docker * Pylint-friendly logic * Use f.read().rstrip() * Change function names --- ml-agents/mlagents/torch_utils/cpu_utils.py | 34 +++++++++++++++++++++ ml-agents/mlagents/torch_utils/torch.py | 4 ++- 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 ml-agents/mlagents/torch_utils/cpu_utils.py diff --git a/ml-agents/mlagents/torch_utils/cpu_utils.py b/ml-agents/mlagents/torch_utils/cpu_utils.py new file mode 100644 index 0000000000..e0272fad51 --- /dev/null +++ b/ml-agents/mlagents/torch_utils/cpu_utils.py @@ -0,0 +1,34 @@ +from typing import Optional + +import os + + +def get_num_threads_to_use() -> Optional[int]: + """ + Gets the number of threads to use. For most problems, 4 is all you + need, but for smaller machines, we'd like to scale to less than that. + By default, PyTorch uses 1/2 of the available cores. + """ + num_cpus = _get_num_available_cpus() + return max(min(num_cpus // 2, 4), 1) if num_cpus is not None else None + + +def _get_num_available_cpus() -> Optional[int]: + """ + Returns number of CPUs using cgroups if possible. This accounts + for Docker containers that are limited in cores. + """ + period = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_period_us") + quota = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_quota_us") + if period > 0 and quota > 0: + return int(quota // period) + else: + return os.cpu_count() + + +def _read_in_integer_file(filename: str) -> int: + try: + with open(filename) as f: + return int(f.read().rstrip()) + except FileNotFoundError: + return -1 diff --git a/ml-agents/mlagents/torch_utils/torch.py b/ml-agents/mlagents/torch_utils/torch.py index 98463fa3b6..f2fd8d18aa 100644 --- a/ml-agents/mlagents/torch_utils/torch.py +++ b/ml-agents/mlagents/torch_utils/torch.py @@ -1,5 +1,7 @@ import os +from mlagents.torch_utils import cpu_utils + # Detect availability of torch package here. # NOTE: this try/except is temporary until torch is required for ML-Agents. try: @@ -7,7 +9,7 @@ # Everywhere else is caught by the banned-modules setting for flake8 import torch # noqa I201 - torch.set_num_interop_threads(2) + torch.set_num_threads(cpu_utils.get_num_threads_to_use()) os.environ["KMP_BLOCKTIME"] = "0" # Known PyLint compatibility with PyTorch https://github.com/pytorch/pytorch/issues/701