diff --git a/openequivariance/extlib/__init__.py b/openequivariance/extlib/__init__.py index b6c16b24..b5e25372 100644 --- a/openequivariance/extlib/__init__.py +++ b/openequivariance/extlib/__init__.py @@ -120,6 +120,8 @@ def postprocess(kernel): extra_include_paths=include_dirs, extra_ldflags=extra_link_args, ) + if "generic_module" not in sys.modules: + sys.modules["generic_module"] = generic_module if not TORCH_COMPILE: warnings.warn(