diff --git a/src/lightning/app/utilities/load_app.py b/src/lightning/app/utilities/load_app.py index 1f49fed2ad888..7f50c344db842 100644 --- a/src/lightning/app/utilities/load_app.py +++ b/src/lightning/app/utilities/load_app.py @@ -112,7 +112,7 @@ def load_app_from_file(filepath: str, raise_exception: bool = False, mock_import ) # TODO: Remove this, downstream code shouldn't depend on side-effects here but it does - _patch_sys_path(os.path.dirname(os.path.abspath(filepath))).__enter__() + sys.path.append(os.path.dirname(os.path.abspath(filepath))) sys.modules["__main__"] = main_module if len(apps) > 1: diff --git a/tests/tests_app/core/scripts/app_with_local_import.py b/tests/tests_app/core/scripts/app_with_local_import.py new file mode 100644 index 0000000000000..d1cfefed788fa --- /dev/null +++ b/tests/tests_app/core/scripts/app_with_local_import.py @@ -0,0 +1,5 @@ +from app_metadata import RootFlow + +from lightning.app.core.app import LightningApp + +app = LightningApp(RootFlow()) diff --git a/tests/tests_app/utilities/test_load_app.py b/tests/tests_app/utilities/test_load_app.py index c8a08682c6dee..14da8a8acc1b9 100644 --- a/tests/tests_app/utilities/test_load_app.py +++ b/tests/tests_app/utilities/test_load_app.py @@ -1,15 +1,15 @@ import os +import sys from unittest.mock import ANY import pytest -import tests_app.core.scripts from lightning.app.utilities.exceptions import MisconfigurationException from lightning.app.utilities.load_app import extract_metadata_from_app, load_app_from_file -def test_load_app_from_file(): - test_script_dir = os.path.join(os.path.dirname(tests_app.core.__file__), "scripts") +def test_load_app_from_file_errors(): + test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts") with pytest.raises(MisconfigurationException, match="There should not be multiple apps instantiated within a file"): load_app_from_file(os.path.join(test_script_dir, "two_apps.py")) @@ -20,8 +20,19 @@ def test_load_app_from_file(): load_app_from_file(os.path.join(test_script_dir, "script_with_error.py")) +@pytest.mark.parametrize("app_path", ["app_metadata.py", "app_with_local_import.py"]) +def test_load_app_from_file(app_path): + """Test that apps load without error and that sys.path and main module are set.""" + original_main = sys.modules["__main__"] + test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts") + load_app_from_file(os.path.join(test_script_dir, app_path), raise_exception=True) + + assert test_script_dir in sys.path + assert sys.modules["__main__"] != original_main + + def test_extract_metadata_from_component(): - test_script_dir = os.path.join(os.path.dirname(tests_app.core.__file__), "scripts") + test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts") app = load_app_from_file(os.path.join(test_script_dir, "app_metadata.py")) metadata = extract_metadata_from_app(app) assert metadata == [