Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions monai/_extensions/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
from _thread import interrupt_main
from contextlib import contextmanager
from glob import glob
Expand All @@ -18,7 +19,7 @@

import torch

from monai.utils.module import optional_import
from monai.utils.module import get_torch_version_tuple, optional_import

dir_path = path.dirname(path.realpath(__file__))

Expand Down Expand Up @@ -61,6 +62,8 @@ def load_module(
if not path.exists(module_dir):
raise ValueError(f"No extension module named {module_name}")

platform_str = f"_{platform.system()}_{platform.python_version()}_"
platform_str += "".join(f"{v}" for v in get_torch_version_tuple()[:2])
# Adding configuration to module name.
if defines is not None:
module_name = "_".join([module_name] + [f"{v}" for v in defines.values()])
Expand All @@ -69,6 +72,7 @@ def load_module(
source = glob(path.join(module_dir, "**", "*.cpp"), recursive=True)
if torch.cuda.is_available():
source += glob(path.join(module_dir, "**", "*.cu"), recursive=True)
platform_str += f"_{torch.version.cuda}"

# Constructing compilation argument list.
define_args = [] if not defines else [f"-D {key}={defines[key]}" for key in defines]
Expand All @@ -78,8 +82,9 @@ def load_module(
with timeout(build_timeout, "Build appears to be blocked. Is there a stopped process building the same extension?"):
load, _ = optional_import("torch.utils.cpp_extension", name="load") # main trigger some JIT config in pytorch
# This will either run the build or return the existing .so object.
name = module_name + platform_str.replace(".", "_")
module = load(
name=module_name,
name=name,
sources=source,
extra_cflags=define_args,
extra_cuda_cflags=define_args,
Expand Down