diff --git a/ml-agents/mlagents/torch_utils/cpu_utils.py b/ml-agents/mlagents/torch_utils/cpu_utils.py
new file mode 100644
index 0000000000..e0272fad51
--- /dev/null
+++ b/ml-agents/mlagents/torch_utils/cpu_utils.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+import os
+
+
+def get_num_threads_to_use() -> Optional[int]:
+    """
+    Gets the number of threads to use. For most problems, 4 is all you
+    need, but for smaller machines, we'd like to scale to less than that.
+    By default, PyTorch uses 1/2 of the available cores.
+    """
+    num_cpus = _get_num_available_cpus()
+    return max(min(num_cpus // 2, 4), 1) if num_cpus is not None else None
+
+
+def _get_num_available_cpus() -> Optional[int]:
+    """
+    Returns number of CPUs using cgroups if possible. This accounts
+    for Docker containers that are limited in cores.
+    """
+    period = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_period_us")
+    quota = _read_in_integer_file("/sys/fs/cgroup/cpu/cpu.cfs_quota_us")
+    if period > 0 and quota > 0:
+        return int(quota // period)
+    else:
+        return os.cpu_count()
+
+
+def _read_in_integer_file(filename: str) -> int:
+    try:
+        with open(filename) as f:
+            return int(f.read().rstrip())
+    except FileNotFoundError:
+        return -1
diff --git a/ml-agents/mlagents/torch_utils/torch.py b/ml-agents/mlagents/torch_utils/torch.py
index 98463fa3b6..f2fd8d18aa 100644
--- a/ml-agents/mlagents/torch_utils/torch.py
+++ b/ml-agents/mlagents/torch_utils/torch.py
@@ -1,5 +1,7 @@
 import os
 
+from mlagents.torch_utils import cpu_utils
+
 # Detect availability of torch package here.
 # NOTE: this try/except is temporary until torch is required for ML-Agents.
 try:
@@ -7,7 +9,10 @@
     # Everywhere else is caught by the banned-modules setting for flake8
     import torch  # noqa I201
 
-    torch.set_num_interop_threads(2)
+    # get_num_threads_to_use() returns None when the CPU count is unknown
+    # (os.cpu_count() can return None); set_num_threads(None) raises TypeError.
+    num_threads = cpu_utils.get_num_threads_to_use()
+    if num_threads is not None:
+        torch.set_num_threads(num_threads)
     os.environ["KMP_BLOCKTIME"] = "0"
 
 # Known PyLint compatibility with PyTorch https://github.com/pytorch/pytorch/issues/701