From 52f165b04e369ad65bf6c782e42915df6ed0e644 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 21 Jun 2021 00:18:04 +0100 Subject: [PATCH] fixes loader jit Signed-off-by: Wenqi Li --- monai/_extensions/loader.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/monai/_extensions/loader.py b/monai/_extensions/loader.py index 20736cfd2b..5f77480ecc 100644 --- a/monai/_extensions/loader.py +++ b/monai/_extensions/loader.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import platform from _thread import interrupt_main from contextlib import contextmanager from glob import glob @@ -18,7 +19,7 @@ import torch -from monai.utils.module import optional_import +from monai.utils.module import get_torch_version_tuple, optional_import dir_path = path.dirname(path.realpath(__file__)) @@ -61,6 +62,8 @@ def load_module( if not path.exists(module_dir): raise ValueError(f"No extension module named {module_name}") + platform_str = f"_{platform.system()}_{platform.python_version()}_" + platform_str += "".join(f"{v}" for v in get_torch_version_tuple()[:2]) # Adding configuration to module name. if defines is not None: module_name = "_".join([module_name] + [f"{v}" for v in defines.values()]) @@ -69,6 +72,7 @@ def load_module( source = glob(path.join(module_dir, "**", "*.cpp"), recursive=True) if torch.cuda.is_available(): source += glob(path.join(module_dir, "**", "*.cu"), recursive=True) + platform_str += f"_{torch.version.cuda}" # Constructing compilation argument list. define_args = [] if not defines else [f"-D {key}={defines[key]}" for key in defines] @@ -78,8 +82,9 @@ def load_module( with timeout(build_timeout, "Build appears to be blocked. Is there a stopped process building the same extension?"): load, _ = optional_import("torch.utils.cpp_extension", name="load") # main trigger some JIT config in pytorch # This will either run the build or return the existing .so object. + name = module_name + platform_str.replace(".", "_") module = load( - name=module_name, + name=name, sources=source, extra_cflags=define_args, extra_cuda_cflags=define_args,