Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,41 @@ def import_string(dotted_path: str):

Raise ImportError if the import failed.
"""
# TODO: Add support for nested classes. Currently, it only works for top-level classes.
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError:
raise ImportError(f"{dotted_path} doesn't look like a module path")

module = import_module(module_path)

nested_attributes = [class_name]
while True:
try:
module = import_module(module_path)
break
except (ModuleNotFoundError, ImportError) as e:
# Check if we can backtrack (need at least one dot to split)
if "." not in module_path:
# No more dots to split, this is a genuine import error
raise

# Check if this is a "not a package" error or module not found
# In both cases, we should try backtracking
error_msg = str(e)
if "is not a package" in error_msg or isinstance(e, ModuleNotFoundError):
# Backtrack: assume the last component is a class, not a module
module_path, nested_class_name = module_path.rsplit(".", 1)
nested_attributes.insert(0, nested_class_name)
else:
# Some other import error, don't backtrack
raise

attribute_path = ".".join(nested_attributes)
attribute_value = module
try:
return getattr(module, class_name)
for attribute_name in nested_attributes:
attribute_value = getattr(attribute_value, attribute_name)
return attribute_value
except AttributeError:
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute/class')
raise ImportError(f'Module "{module_path}" does not define a "{attribute_path}" attribute/class')


def qualname(o: object | Callable, use_qualname: bool = False, exclude_module: bool = False) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ def _sample_function():
pass


class _OuterClass:
class _InnerClass:
pass


class TestModuleImport:
def test_import_string(self):
cls = import_string("module_loading.test_module_loading._import_string")
assert cls == _import_string
nested_cls = import_string("module_loading.test_module_loading._OuterClass._InnerClass")
assert nested_cls == _OuterClass._InnerClass

# Test exceptions raised
with pytest.raises(ImportError):
Expand Down