diff --git a/requirements/app/test.txt b/requirements/app/test.txt index d93aae4eaf143..9d2ed0af910ca 100644 --- a/requirements/app/test.txt +++ b/requirements/app/test.txt @@ -12,3 +12,4 @@ isort>=5.0 mypy>=0.720 httpx trio +pympler diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 34fdb9665f5aa..0f9838b1efe2e 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -24,3 +24,5 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated ### Fixed + +- Resolved a bug where the work statuses will grow quickly and be duplicated ([#13970](https://github.com/Lightning-AI/lightning/pull/13970)) diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 6599b53efcb95..ab41fb256ffe6 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -18,7 +18,7 @@ from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef from lightning_app.utilities.commands.base import _populate_commands_endpoint, _process_command_requests from lightning_app.utilities.component import _convert_paths_after_init -from lightning_app.utilities.enum import AppStage +from lightning_app.utilities.enum import AppStage, CacheCallsKeys from lightning_app.utilities.exceptions import CacheMissException, ExitAppException from lightning_app.utilities.layout import _collect_layout from lightning_app.utilities.proxies import ComponentDelta @@ -399,8 +399,8 @@ def _run(self) -> bool: if self.should_publish_changes_to_api and self.api_publish_state_queue: logger.debug("Publishing the state with changes") # Push two states to optimize start in the cloud. - self.api_publish_state_queue.put(self.state) - self.api_publish_state_queue.put(self.state) + self.api_publish_state_queue.put(self.state_vars) + self.api_publish_state_queue.put(self.state_vars) self._reset_run_time_monitor() @@ -412,7 +412,7 @@ def _run(self) -> bool: self._update_run_time_monitor() if self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue: - self.api_publish_state_queue.put(self.state) + self.api_publish_state_queue.put(self.state_vars) return True @@ -430,16 +430,12 @@ def _apply_restarting(self) -> bool: self.stage = AppStage.BLOCKING return False - def _collect_work_finish_status(self) -> dict: - work_finished_status = {} - for work in self.works: - work_finished_status[work.name] = False - for key in work._calls: - if key == "latest_call_hash": - continue - fn_metadata = work._calls[key] - work_finished_status[work.name] = fn_metadata["name"] == "run" and "ret" in fn_metadata + def _has_work_finished(self, work): + latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH] + return "ret" in work._calls[latest_call_hash] + def _collect_work_finish_status(self) -> dict: + work_finished_status = {work.name: self._has_work_finished(work) for work in self.works} assert len(work_finished_status) == len(self.works) return work_finished_status diff --git a/src/lightning_app/core/work.py b/src/lightning_app/core/work.py index 53c9e07e80020..e7c800c0d15fa 100644 --- a/src/lightning_app/core/work.py +++ b/src/lightning_app/core/work.py @@ -12,8 +12,15 @@ from lightning_app.storage.drive import _maybe_create_drive, Drive from lightning_app.storage.payload import Payload from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef -from lightning_app.utilities.component import _sanitize_state -from lightning_app.utilities.enum import make_status, WorkFailureReasons, WorkStageStatus, WorkStatus, WorkStopReasons +from lightning_app.utilities.component import _is_flow_context, _sanitize_state +from lightning_app.utilities.enum import ( + CacheCallsKeys, + make_status, + WorkFailureReasons, + WorkStageStatus, + WorkStatus, + WorkStopReasons, +) from lightning_app.utilities.exceptions import LightningWorkException from lightning_app.utilities.introspection import _is_init_context from lightning_app.utilities.network import find_free_network_port @@ -107,7 +114,21 @@ def __init__( # setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator self._setattr_replacement: Optional[Callable[[str, Any], None]] = None self._name = "" - self._calls = {"latest_call_hash": None} + # The ``self._calls`` is used to track whether the run + # method with a given set of input arguments has already been called. + # Example of its usage: + # { + # 'latest_call_hash': '167fe2e', + # '167fe2e': { + # 'statuses': [ + # {'stage': 'pending', 'timestamp': 1659433519.851271}, + # {'stage': 'running', 'timestamp': 1659433519.956482}, + # {'stage': 'stopped', 'timestamp': 1659433520.055768}]} + # ] + # }, + # ... + # } + self._calls = {CacheCallsKeys.LATEST_CALL_HASH: None} self._changes = {} self._raise_exception = raise_exception self._paths = {} @@ -215,13 +236,13 @@ def status(self) -> WorkStatus: All statuses are stored in the state. """ - call_hash = self._calls["latest_call_hash"] - if call_hash: + call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH] + if call_hash in self._calls: statuses = self._calls[call_hash]["statuses"] # deltas aren't necessarily coming in the expected order. statuses = sorted(statuses, key=lambda x: x["timestamp"]) latest_status = statuses[-1] - if latest_status["reason"] == WorkFailureReasons.TIMEOUT: + if latest_status.get("reason") == WorkFailureReasons.TIMEOUT: return self._aggregate_status_timeout(statuses) return WorkStatus(**latest_status) return WorkStatus(stage=WorkStageStatus.NOT_STARTED, timestamp=time.time()) @@ -229,8 +250,8 @@ def status(self) -> WorkStatus: @property def statuses(self) -> List[WorkStatus]: """Return all the status of the work.""" - call_hash = self._calls["latest_call_hash"] - if call_hash: + call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH] + if call_hash in self._calls: statuses = self._calls[call_hash]["statuses"] # deltas aren't necessarily coming in the expected order. statuses = sorted(statuses, key=lambda x: x["timestamp"]) @@ -398,10 +419,13 @@ def __getattr__(self, item): return path return self.__getattribute__(item) - def _call_hash(self, fn, args, kwargs): + def _call_hash(self, fn, args, kwargs) -> str: hash_args = args[1:] if len(args) > 0 and args[0] == self else args call_obj = {"args": hash_args, "kwargs": kwargs} - return f"{fn.__name__}:{DeepHash(call_obj)[call_obj]}" + # Note: Generate a hash as 167fe2e. + # Seven was selected after checking upon Github default SHA length + # and to minimize hidden state size. + return str(DeepHash(call_obj)[call_obj])[:7] def _wrap_run_for_caching(self, fn): @wraps(fn) @@ -415,11 +439,11 @@ def new_fn(*args, **kwargs): entry = self._calls[call_hash] return entry["ret"] - self._calls[call_hash] = {"name": fn.__name__, "call_hash": call_hash} + self._calls[call_hash] = {} result = fn(*args, **kwargs) - self._calls[call_hash] = {"name": fn.__name__, "call_hash": call_hash, "ret": result} + self._calls[call_hash] = {"ret": result} return result @@ -457,8 +481,40 @@ def set_state(self, provided_state): if isinstance(v, Dict): v = _maybe_create_drive(self.name, v) setattr(self, k, v) + self._changes = provided_state["changes"] - self._calls.update(provided_state["calls"]) + + # Note, this is handled by the flow only. + if _is_flow_context(): + self._cleanup_calls(provided_state["calls"]) + + self._calls = provided_state["calls"] + + @staticmethod + def _cleanup_calls(calls: Dict[str, Any]): + # 1: Collect all the in_progress call hashes + in_progress_call_hash = [k for k in list(calls) if k not in (CacheCallsKeys.LATEST_CALL_HASH)] + + for call_hash in in_progress_call_hash: + if "statuses" not in calls[call_hash]: + continue + + # 2: Filter the statuses by timestamp + statuses = sorted(calls[call_hash]["statuses"], key=lambda x: x["timestamp"]) + + # If the latest status is succeeded, then drop everything before. + if statuses[-1]["stage"] == WorkStageStatus.SUCCEEDED: + status = statuses[-1] + status["timestamp"] = int(status["timestamp"]) + calls[call_hash]["statuses"] = [status] + else: + # TODO: Some status are being duplicated, + # this seems related to the StateObserver. + final_statuses = [] + for status in statuses: + if status not in final_statuses: + final_statuses.append(status) + calls[call_hash]["statuses"] = final_statuses @abc.abstractmethod def run(self, *args, **kwargs): @@ -479,7 +535,7 @@ def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus: if succeeded_statuses: succeed_status_id = succeeded_statuses[-1] + 1 statuses = statuses[succeed_status_id:] - timeout_statuses = [status for status in statuses if status["reason"] == WorkFailureReasons.TIMEOUT] + timeout_statuses = [status for status in statuses if status.get("reason") == WorkFailureReasons.TIMEOUT] assert statuses[0]["stage"] == WorkStageStatus.PENDING status = {**timeout_statuses[-1], "timestamp": statuses[0]["timestamp"]} return WorkStatus(**status, count=len(timeout_statuses)) @@ -501,9 +557,8 @@ def stop(self): ) if self.status.stage == WorkStageStatus.STOPPED: return - latest_hash = self._calls["latest_call_hash"] - self._calls[latest_hash]["statuses"].append( - make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.PENDING) - ) + latest_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH] + stop_status = make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.PENDING) + self._calls[latest_hash]["statuses"].append(stop_status) app = _LightningAppRef().get_current() self._backend.stop_work(app, self) diff --git a/src/lightning_app/runners/runtime.py b/src/lightning_app/runners/runtime.py index 3e15f958b8538..123e16d89ede5 100644 --- a/src/lightning_app/runners/runtime.py +++ b/src/lightning_app/runners/runtime.py @@ -10,7 +10,7 @@ from lightning_app import LightningApp from lightning_app.core.constants import APP_SERVER_HOST, APP_SERVER_PORT from lightning_app.runners.backends import Backend, BackendType -from lightning_app.utilities.enum import AppStage, make_status, WorkStageStatus +from lightning_app.utilities.enum import AppStage, CacheCallsKeys, make_status, WorkStageStatus from lightning_app.utilities.load_app import load_app_from_file from lightning_app.utilities.proxies import WorkRunner @@ -133,9 +133,10 @@ def dispatch(self, *args, **kwargs): raise NotImplementedError def _add_stopped_status_to_work(self, work: "lightning_app.LightningWork") -> None: + if work.status.stage == WorkStageStatus.STOPPED: return - latest_hash = work._calls["latest_call_hash"] - if latest_hash is None: - return - work._calls[latest_hash]["statuses"].append(make_status(WorkStageStatus.STOPPED)) + + latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH] + if latest_call_hash in work._calls: + work._calls[latest_call_hash]["statuses"].append(make_status(WorkStageStatus.STOPPED)) diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index 10abdac4aad5d..dd34614a34353 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -22,6 +22,7 @@ from lightning_app.runners.multiprocess import MultiProcessRuntime from lightning_app.testing.config import Config from lightning_app.utilities.cloud import _get_project +from lightning_app.utilities.enum import CacheCallsKeys from lightning_app.utilities.imports import _is_playwright_available, requires from lightning_app.utilities.network import _configure_session, LightningClient from lightning_app.utilities.proxies import ProxyWorkRun @@ -114,8 +115,11 @@ def run_work_isolated(work, *args, start_server: bool = False, **kwargs): start_server=start_server, ).dispatch() # pop the stopped status. - call_hash = work._calls["latest_call_hash"] - work._calls[call_hash]["statuses"].pop(-1) + call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH] + + if call_hash in work._calls: + work._calls[call_hash]["statuses"].pop(-1) + if isinstance(work.run, ProxyWorkRun): work.run = work.run.work_run @@ -176,7 +180,7 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator: # 3. Launch the application in the cloud from the Lightning CLI. with tempfile.TemporaryDirectory() as tmpdir: env_copy = os.environ.copy() - env_copy["PREPARE_LIGHTING"] = "1" + env_copy["PACKAGE_LIGHTNING"] = "1" shutil.copytree(app_folder, tmpdir, dirs_exist_ok=True) # TODO - add -no-cache to the command line. process = Popen( @@ -216,7 +220,10 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator: record_har_path=Config.har_location, ) admin_page = context.new_page() - res = requests.post(Config.url + "/v1/auth/login", data=json.dumps(payload)) + url = Config.url + if url.endswith("/"): + url = url[:-1] + res = requests.post(url + "/v1/auth/login", data=json.dumps(payload)) token = res.json()["token"] print(f"The Lightning App Token is: {token}") print(f"The Lightning App user key is: {Config.key}") diff --git a/src/lightning_app/utilities/enum.py b/src/lightning_app/utilities/enum.py index 9469deffd925e..dbf20413aa9d9 100644 --- a/src/lightning_app/utilities/enum.py +++ b/src/lightning_app/utilities/enum.py @@ -59,9 +59,16 @@ def __post_init__(self): def make_status(stage: str, message: Optional[str] = None, reason: Optional[str] = None): - return { + status = { "stage": stage, - "message": message, - "reason": reason, "timestamp": datetime.now(tz=timezone.utc).timestamp(), } + if message: + status["message"] = message + if reason: + status["reason"] = reason + return status + + +class CacheCallsKeys: + LATEST_CALL_HASH = "latest_call_hash" diff --git a/src/lightning_app/utilities/network.py b/src/lightning_app/utilities/network.py index a9ebcf37ab564..7fd03750a515d 100644 --- a/src/lightning_app/utilities/network.py +++ b/src/lightning_app/utilities/network.py @@ -48,7 +48,7 @@ def _configure_session() -> Session: return http -def _check_service_url_is_ready(url: str, timeout: float = 1) -> bool: +def _check_service_url_is_ready(url: str, timeout: float = 100) -> bool: try: response = requests.get(url, timeout=timeout) return response.status_code in (200, 404) diff --git a/src/lightning_app/utilities/packaging/build_config.py b/src/lightning_app/utilities/packaging/build_config.py index 9231875d5d7fd..b776e202666de 100644 --- a/src/lightning_app/utilities/packaging/build_config.py +++ b/src/lightning_app/utilities/packaging/build_config.py @@ -110,7 +110,7 @@ def _find_requirements(self, work: "LightningWork") -> List[str]: file = inspect.getfile(work.__class__) # 2. Try to find a requirement file associated the file. - dirname = os.path.dirname(file) + dirname = os.path.dirname(file) or "." requirement_files = [os.path.join(dirname, f) for f in os.listdir(dirname) if f == "requirements.txt"] if not requirement_files: return [] @@ -126,7 +126,7 @@ def _find_dockerfile(self, work: "LightningWork") -> List[str]: file = inspect.getfile(work.__class__) # 2. Check for Dockerfile. - dirname = os.path.dirname(file) + dirname = os.path.dirname(file) or "." dockerfiles = [os.path.join(dirname, f) for f in os.listdir(dirname) if f == "Dockerfile"] if not dockerfiles: diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index c33e41bb70203..28d436f3e4a23 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -23,7 +23,13 @@ from lightning_app.utilities.app_helpers import affiliation from lightning_app.utilities.apply_func import apply_to_collection from lightning_app.utilities.component import _set_work_context -from lightning_app.utilities.enum import make_status, WorkFailureReasons, WorkStageStatus, WorkStopReasons +from lightning_app.utilities.enum import ( + CacheCallsKeys, + make_status, + WorkFailureReasons, + WorkStageStatus, + WorkStopReasons, +) from lightning_app.utilities.exceptions import CacheMissException, LightningSigtermStateException if TYPE_CHECKING: @@ -45,19 +51,13 @@ def unwrap(fn): return fn -def _send_data_to_caller_queue( - work: "LightningWork", caller_queue: "BaseQueue", data: Dict, call_hash: str, work_run: Callable, use_args: bool -) -> Dict: - if work._calls["latest_call_hash"] is None: - work._calls["latest_call_hash"] = call_hash +def _send_data_to_caller_queue(work: "LightningWork", caller_queue: "BaseQueue", data: Dict, call_hash: str) -> Dict: + + if work._calls[CacheCallsKeys.LATEST_CALL_HASH] is None: + work._calls[CacheCallsKeys.LATEST_CALL_HASH] = call_hash if call_hash not in work._calls: - work._calls[call_hash] = { - "name": work_run.__name__, - "call_hash": call_hash, - "use_args": use_args, - "statuses": [], - } + work._calls[call_hash] = {"statuses": []} else: # remove ret when relaunching the work. work._calls[call_hash].pop("ret", None) @@ -65,9 +65,19 @@ def _send_data_to_caller_queue( work._calls[call_hash]["statuses"].append(make_status(WorkStageStatus.PENDING)) work_state = work.state + + # There is no need to send all call hashes to the work. + calls = deepcopy(work_state["calls"]) + work_state["calls"] = { + k: v for k, v in work_state["calls"].items() if k in (call_hash, CacheCallsKeys.LATEST_CALL_HASH) + } + data.update({"state": work_state}) logger.debug(f"Sending to {work.name}: {data}") caller_queue.put(data) + + # Reset the calls entry. + work_state["calls"] = calls work._restarting = False return work_state @@ -85,9 +95,6 @@ def __post_init__(self): self.work_state = None def __call__(self, *args, **kwargs): - provided_none = len(args) == 1 and args[0] is None - use_args = len(kwargs) > 0 or (len(args) > 0 and not provided_none) - self._validate_call_args(args, kwargs) args, kwargs = self._process_call_args(args, kwargs) @@ -103,18 +110,18 @@ def __call__(self, *args, **kwargs): # for the readers. if self.cache_calls: if not entered or stopped_on_sigterm: - _send_data_to_caller_queue(self.work, self.caller_queue, data, call_hash, self.work_run, use_args) + _send_data_to_caller_queue(self.work, self.caller_queue, data, call_hash) else: if returned: return else: if not entered or stopped_on_sigterm: - _send_data_to_caller_queue(self.work, self.caller_queue, data, call_hash, self.work_run, use_args) + _send_data_to_caller_queue(self.work, self.caller_queue, data, call_hash) else: if returned or stopped_on_sigterm: # the previous task has completed and we can re-queue the next one. # overriding the return value for next loop iteration. - _send_data_to_caller_queue(self.work, self.caller_queue, data, call_hash, self.work_run, use_args) + _send_data_to_caller_queue(self.work, self.caller_queue, data, call_hash) if not self.parallel: raise CacheMissException("Task never called before. Triggered now") @@ -171,10 +178,9 @@ def sanitize(obj: Union[Path, Drive]) -> Union[Path, Dict]: class WorkStateObserver(Thread): - """This thread runs alongside LightningWork and periodically checks for state changes. - - If the state changed from one interval to the next, it will compute the delta and add it to the queue which is - connected to the Flow. This enables state changes to be captured that are not triggered through a setattr call. + """This thread runs alongside LightningWork and periodically checks for state changes. If the state changed + from one interval to the next, it will compute the delta and add it to the queue which is connected to the + Flow. This enables state changes to be captured that are not triggered through a setattr call. Args: work: The LightningWork for which the state should be monitored @@ -371,21 +377,24 @@ def run_once(self): self._proxy_setattr() # 8. Deepcopy the work state and send the first `RUNNING` status delta to the flow. - state = deepcopy(self.work.state) - self.work._calls["latest_call_hash"] = call_hash - self.work._calls[call_hash]["statuses"].append(make_status(WorkStageStatus.RUNNING)) - self.delta_queue.put(ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state)))) + reference_state = deepcopy(self.work.state) - # 9. Start the state observer thread. It will look for state changes and send them back to the Flow - # The observer has to be initialized here, after the set_state call above so that the thread can start with - # the proper initial state of the work - self.state_observer.start() + # 9. Inform the flow the work is running and add the delta to the deepcopy. + self.work._calls[CacheCallsKeys.LATEST_CALL_HASH] = call_hash + self.work._calls[call_hash]["statuses"].append(make_status(WorkStageStatus.RUNNING)) + delta = Delta(DeepDiff(reference_state, self.work.state)) + self.delta_queue.put(ComponentDelta(id=self.work_name, delta=delta)) # 10. Unwrap the run method if wrapped. work_run = self.work.run if hasattr(work_run, "__wrapped__"): work_run = work_run.__wrapped__ + # 11. Start the state observer thread. It will look for state changes and send them back to the Flow + # The observer has to be initialized here, after the set_state call above so that the thread can start with + # the proper initial state of the work + self.state_observer.start() + # 12. Run the `work_run` method. # If an exception is raised, send a `FAILED` status delta to the flow and call the `on_exception` hook. try: @@ -394,23 +403,26 @@ def run_once(self): raise e except BaseException as e: # 10.2 Send failed delta to the flow. + reference_state = deepcopy(self.work.state) self.work._calls[call_hash]["statuses"].append( make_status(WorkStageStatus.FAILED, message=str(e), reason=WorkFailureReasons.USER_EXCEPTION) ) - self.delta_queue.put(ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state)))) + self.delta_queue.put( + ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(reference_state, self.work.state))) + ) self.work.on_exception(e) print("########## CAPTURED EXCEPTION ###########") print(traceback.print_exc()) print("########## CAPTURED EXCEPTION ###########") return - # 14. Copy all artifacts to the shared storage so other Works can access them while this Work gets scaled down - persist_artifacts(work=self.work) - - # 15. Destroy the state observer. + # 13. Destroy the state observer. self.state_observer.join(0) self.state_observer = None + # 14. Copy all artifacts to the shared storage so other Works can access them while this Work gets scaled down + persist_artifacts(work=self.work) + # 15. An asynchronous work shouldn't return a return value. if ret is not None: raise RuntimeError( @@ -418,23 +430,24 @@ def run_once(self): "HINT: Use the Payload API instead." ) - # 16. DeepCopy the state and send the latest delta to the flow. + # 17. DeepCopy the state and send the latest delta to the flow. # use the latest state as we have already sent delta # during its execution. # inform the task has completed - state = deepcopy(self.work.state) + reference_state = deepcopy(self.work.state) self.work._calls[call_hash]["statuses"].append(make_status(WorkStageStatus.SUCCEEDED)) self.work._calls[call_hash]["ret"] = ret - self.delta_queue.put(ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(state, self.work.state)))) + self.delta_queue.put(ComponentDelta(id=self.work_name, delta=Delta(DeepDiff(reference_state, self.work.state)))) - # 17. Update the work for the next delta if any. + # 18. Update the work for the next delta if any. self._proxy_setattr(cleanup=True) def _sigterm_signal_handler(self, signum, frame, call_hash: str) -> None: """Signal handler used to react when spot instances are being retrived.""" - logger.debug("Received SIGTERM signal. Gracefully terminating...") + logger.info(f"Received SIGTERM signal. Gracefully terminating {self.work.name.replace('root.', '')}...") persist_artifacts(work=self.work) with _state_observer_lock: + self.work._calls[call_hash]["statuses"] = [] state = deepcopy(self.work.state) self.work._calls[call_hash]["statuses"].append( make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.SIGTERM_SIGNAL_HANDLER) diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py index 678655d6ee908..a8554e133e1a9 100644 --- a/tests/tests_app/components/python/test_python.py +++ b/tests/tests_app/components/python/test_python.py @@ -10,6 +10,7 @@ from lightning_app.testing.helpers import RunIf from lightning_app.testing.testing import run_work_isolated from lightning_app.utilities.component import _set_work_context +from lightning_app.utilities.enum import CacheCallsKeys COMPONENTS_SCRIPTS_FOLDER = str(os.path.join(_PROJECT_ROOT, "tests/tests_app/components/python/scripts/")) @@ -112,7 +113,7 @@ def test_tracer_component_with_code(): with open("file.py", "w") as f: f.write('raise Exception("An error")') - call_hash = python_script._calls["latest_call_hash"] + call_hash = python_script._calls[CacheCallsKeys.LATEST_CALL_HASH] python_script._calls[call_hash]["statuses"].pop(-1) python_script._calls[call_hash]["statuses"].pop(-1) diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index f55e7cb84b66a..a3a15085b98e3 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -6,6 +6,7 @@ import pytest from deepdiff import Delta +from pympler import asizeof from tests_app import _PROJECT_ROOT from lightning_app import LightningApp, LightningFlow, LightningWork # F401 @@ -486,12 +487,11 @@ def _dump_checkpoint(self): raise SuccessException -@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime]) -def test_snapshotting(runtime_cls, tmpdir): +def test_snap_shotting(): try: app = CheckpointLightningApp(FlowA()) app.checkpointing = True - runtime_cls(app, start_server=False).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() except SuccessException: pass checkpoint_dir = os.path.join(storage_root_dir(), "checkpoints") @@ -765,15 +765,17 @@ def run(self): def test_protected_attributes_not_in_state(): flow = ProtectedAttributesFlow() - MultiProcessRuntime(LightningApp(flow)).dispatch() + MultiProcessRuntime(LightningApp(flow), start_server=False).dispatch() class WorkExit(LightningWork): def __init__(self): - super().__init__() + super().__init__(raise_exception=False) + self.counter = 0 def run(self): - pass + self.counter += 1 + raise Exception("Hello") class FlowExit(LightningFlow): @@ -782,13 +784,14 @@ def __init__(self): self.work = WorkExit() def run(self): + if self.work.counter == 1: + self._exit() self.work.run() - self._exit() def test_lightning_app_exit(): app = LightningApp(FlowExit()) - MultiProcessRuntime(app).dispatch() + MultiProcessRuntime(app, start_server=False).dispatch() assert app.root.work.status.stage == WorkStageStatus.STOPPED @@ -860,12 +863,12 @@ def run(self): def test_slow_flow(): app0 = LightningApp(SleepyFlow(sleep_interval=0.5 * FLOW_DURATION_THRESHOLD)) - MultiProcessRuntime(app0).dispatch() + MultiProcessRuntime(app0, start_server=False).dispatch() app1 = LightningApp(SleepyFlow(sleep_interval=2 * FLOW_DURATION_THRESHOLD)) with pytest.warns(LightningFlowWarning): - MultiProcessRuntime(app1).dispatch() + MultiProcessRuntime(app1, start_server=False).dispatch() app0 = LightningApp( SleepyFlowWithWork( @@ -875,7 +878,7 @@ def test_slow_flow(): ) ) - MultiProcessRuntime(app0).dispatch() + MultiProcessRuntime(app0, start_server=False).dispatch() app1 = LightningApp( SleepyFlowWithWork( @@ -883,4 +886,36 @@ def test_slow_flow(): ) ) - MultiProcessRuntime(app1).dispatch() + MultiProcessRuntime(app1, start_server=False).dispatch() + + +class SizeWork(LightningWork): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.counter = 0 + + def run(self, signal: int): + self.counter += 1 + + +class SizeFlow(LightningFlow): + def __init__(self): + super().__init__() + self.work0 = SizeWork(parallel=True, cache_calls=True) + self._state_sizes = {} + + def run(self): + for idx in range(self.work0.counter + 2): + self.work0.run(idx) + + self._state_sizes[self.work0.counter] = asizeof.asizeof(self.state) + + if self.work0.counter >= 20: + self._exit() + + +def test_state_size_constant_growth(): + app = LightningApp(SizeFlow()) + MultiProcessRuntime(app, start_server=False).dispatch() + assert app.root._state_sizes[0] <= 5904 + assert app.root._state_sizes[20] <= 23736 diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py index 26841e057621b..1966c6d7b23d6 100644 --- a/tests/tests_app/core/test_lightning_flow.py +++ b/tests/tests_app/core/test_lightning_flow.py @@ -17,6 +17,7 @@ from lightning_app.storage.path import storage_root_dir from lightning_app.testing.helpers import EmptyFlow, EmptyWork from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef +from lightning_app.utilities.enum import CacheCallsKeys from lightning_app.utilities.exceptions import ExitAppException @@ -320,7 +321,7 @@ def run(self): "_restarting": False, "_internal_ip": "", }, - "calls": {"latest_call_hash": None}, + "calls": {CacheCallsKeys.LATEST_CALL_HASH: None}, "changes": {}, }, "work_a": { @@ -334,7 +335,7 @@ def run(self): "_restarting": False, "_internal_ip": "", }, - "calls": {"latest_call_hash": None}, + "calls": {CacheCallsKeys.LATEST_CALL_HASH: None}, "changes": {}, }, }, @@ -364,7 +365,7 @@ def run(self): "_restarting": False, "_internal_ip": "", }, - "calls": {"latest_call_hash": None}, + "calls": {CacheCallsKeys.LATEST_CALL_HASH: None}, "changes": {}, }, "work_a": { @@ -379,10 +380,8 @@ def run(self): "_internal_ip": "", }, "calls": { - "latest_call_hash": None, - "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c": { - "name": "run", - "call_hash": "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c", + CacheCallsKeys.LATEST_CALL_HASH: None, + "fe3fa0f": { "ret": None, }, }, @@ -435,7 +434,7 @@ def test_populate_changes_status_removed(): "work": { "vars": {}, "calls": { - "latest_call_hash": "run:fe3f", + CacheCallsKeys.LATEST_CALL_HASH: "run:fe3f", "run:fe3f": { "statuses": [ {"stage": "requesting", "message": None, "reason": None, "timestamp": 1}, diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py index 913fdf04c3299..14d8d26a458a6 100644 --- a/tests/tests_app/core/test_lightning_work.py +++ b/tests/tests_app/core/test_lightning_work.py @@ -8,7 +8,6 @@ from lightning_app.core.work import LightningWork, LightningWorkException from lightning_app.runners import MultiProcessRuntime from lightning_app.storage import Path -from lightning_app.storage.requests import GetRequest from lightning_app.testing.helpers import EmptyFlow, EmptyWork, MockQueue from lightning_app.utilities.enum import WorkStageStatus from lightning_app.utilities.proxies import ProxyWorkRun, WorkRunner @@ -130,8 +129,8 @@ def run(self): FlowFixed().run() -@pytest.mark.parametrize("raise_exception", [False, True]) @pytest.mark.parametrize("enable_exception", [False, True]) +@pytest.mark.parametrize("raise_exception", [False, True]) def test_lightning_status(enable_exception, raise_exception): class Work(EmptyWork): def __init__(self, raise_exception, enable_exception=True): @@ -143,17 +142,6 @@ def run(self): if self.enable_exception: raise Exception("Custom Exception") - class BlockingQueue(MockQueue): - """A Mock for the file copier queues that keeps blocking until we want to end the thread.""" - - keep_blocking = True - - def get(self, timeout: int = 0): - while BlockingQueue.keep_blocking: - pass - # A dummy request so the Copier gets something to process without an error - return GetRequest(source="src", name="dummy_path", path="test", hash="123", destination="dst") - work = Work(raise_exception, enable_exception=enable_exception) work._name = "root.w" assert work.status.stage == WorkStageStatus.NOT_STARTED @@ -163,9 +151,9 @@ def get(self, timeout: int = 0): error_queue = MockQueue("error_queue") request_queue = MockQueue("request_queue") response_queue = MockQueue("response_queue") - copy_request_queue = BlockingQueue("copy_request_queue") - copy_response_queue = BlockingQueue("copy_response_queue") - call_hash = "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c" + copy_request_queue = MockQueue("copy_request_queue") + copy_response_queue = MockQueue("copy_response_queue") + call_hash = "fe3fa0f" work._calls[call_hash] = { "args": (), "kwargs": {}, @@ -203,14 +191,13 @@ def get(self, timeout: int = 0): if enable_exception: exception_cls = Exception if raise_exception else Empty assert isinstance(error_queue._queue[0], exception_cls) - res[f"root['calls']['{call_hash}']['statuses'][0]"]["stage"] == "failed" - res[f"root['calls']['{call_hash}']['statuses'][0]"]["message"] == "Custom Exception" + res_end[f"root['calls']['{call_hash}']['statuses'][1]"]["stage"] == "failed" + res_end[f"root['calls']['{call_hash}']['statuses'][1]"]["message"] == "Custom Exception" else: assert res[f"root['calls']['{call_hash}']['statuses'][0]"]["stage"] == "running" assert res_end[f"root['calls']['{call_hash}']['statuses'][1]"]["stage"] == "succeeded" # Stop blocking and let the thread join - BlockingQueue.keep_blocking = False work_runner.copier.join() @@ -281,3 +268,15 @@ def run(self): assert flow.work.state["vars"]["none_to_path"] == Path("lit://none/to/path") assert flow.work.state["vars"]["path_to_none"] is None assert flow.work.state["vars"]["path_to_path"] == Path("lit://path/to/path") + + +def test_lightning_work_calls(): + class W(LightningWork): + def run(self, *args, **kwargs): + pass + + w = W() + assert len(w._calls) == 1 + w.run(1, [2], (3, 4), {"1": "3"}) + assert len(w._calls) == 2 + assert w._calls["0d824f7"] == {"ret": None} diff --git a/tests/tests_app/storage/test_payload.py b/tests/tests_app/storage/test_payload.py index 7a64750a01a92..2481320ff2d57 100644 --- a/tests/tests_app/storage/test_payload.py +++ b/tests/tests_app/storage/test_payload.py @@ -1,3 +1,4 @@ +import os import pathlib import pickle from copy import deepcopy @@ -146,3 +147,7 @@ def test_payload_works(tmpdir): with mock.patch("lightning_app.storage.path.storage_root_dir", lambda: pathlib.Path(tmpdir)): app = LightningApp(Flow(), debug=True) MultiProcessRuntime(app, start_server=False).dispatch() + + os.remove("value_all") + os.remove("value_b") + os.remove("value_c") diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index aaa7db18a5af2..18a6d372bfee9 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -8,7 +8,7 @@ from lightning_app.storage.payload import Payload from lightning_app.structures import Dict, List from lightning_app.testing.helpers import EmptyFlow -from lightning_app.utilities.enum import WorkStageStatus +from lightning_app.utilities.enum import CacheCallsKeys, WorkStageStatus def test_dict(): @@ -49,7 +49,7 @@ def run(self): for k in ("a", "b", "c", "d") ) assert all( - flow.state["structures"]["dict"]["works"][f"work_{k}"]["calls"] == {"latest_call_hash": None} + flow.state["structures"]["dict"]["works"][f"work_{k}"]["calls"] == {CacheCallsKeys.LATEST_CALL_HASH: None} for k in ("a", "b", "c", "d") ) assert all(flow.state["structures"]["dict"]["works"][f"work_{k}"]["changes"] == {} for k in ("a", "b", "c", "d")) @@ -95,7 +95,8 @@ def run(self): for k in ("a", "b", "c", "d") ) assert all( - flow.state_with_changes["structures"]["dict"]["works"][f"work_{k}"]["calls"] == {"latest_call_hash": None} + flow.state_with_changes["structures"]["dict"]["works"][f"work_{k}"]["calls"] + == {CacheCallsKeys.LATEST_CALL_HASH: None} for k in ("a", "b", "c", "d") ) assert all( @@ -169,7 +170,8 @@ def run(self): for i in range(4) ) assert all( - flow.state["structures"]["list"]["works"][str(i)]["calls"] == {"latest_call_hash": None} for i in range(4) + flow.state["structures"]["list"]["works"][str(i)]["calls"] == {CacheCallsKeys.LATEST_CALL_HASH: None} + for i in range(4) ) assert all(flow.state["structures"]["list"]["works"][str(i)]["changes"] == {} for i in range(4)) @@ -209,7 +211,8 @@ def run(self): for i in range(4) ) assert all( - flow.state_with_changes["structures"]["list"]["works"][str(i)]["calls"] == {"latest_call_hash": None} + flow.state_with_changes["structures"]["list"]["works"][str(i)]["calls"] + == {CacheCallsKeys.LATEST_CALL_HASH: None} for i in range(4) ) assert all(flow.state_with_changes["structures"]["list"]["works"][str(i)]["changes"] == {} for i in range(4)) diff --git a/tests/tests_app/utilities/test_login.py b/tests/tests_app/utilities/test_login.py index 43b10519e20ee..e0ad4b110c868 100644 --- a/tests/tests_app/utilities/test_login.py +++ b/tests/tests_app/utilities/test_login.py @@ -6,7 +6,7 @@ from lightning_app.utilities import login -LIGHTNING_CLOUD_URL = "https://lightning.ai" +LIGHTNING_CLOUD_URL = os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai") @pytest.fixture(autouse=True) diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py index 3331a5e69e42b..cd0dfd7026e09 100644 --- a/tests/tests_app/utilities/test_proxies.py +++ b/tests/tests_app/utilities/test_proxies.py @@ -18,7 +18,7 @@ from lightning_app.storage.requests import GetRequest from lightning_app.testing.helpers import EmptyFlow, MockQueue from lightning_app.utilities.component import _convert_paths_after_init -from lightning_app.utilities.enum import WorkFailureReasons, WorkStageStatus +from lightning_app.utilities.enum import CacheCallsKeys, WorkFailureReasons, WorkStageStatus from lightning_app.utilities.exceptions import CacheMissException, ExitAppException from lightning_app.utilities.proxies import ( ComponentDelta, @@ -240,7 +240,7 @@ class WorkRunnerPatch(WorkRunner): counter = 0 def __call__(self): - call_hash = "run:fe3fa0f34fc1317e152e5afb023332995392071046f1ea51c34c7c9766e3676c" + call_hash = "fe3fa0f" while True: try: called = self.caller_queue.get() @@ -267,7 +267,7 @@ def test_proxy_timeout(): app = LightningApp(FlowTimeout(), debug=True) MultiProcessRuntime(app, start_server=False).dispatch() - call_hash = app.root.work._calls["latest_call_hash"] + call_hash = app.root.work._calls[CacheCallsKeys.LATEST_CALL_HASH] assert len(app.root.work._calls[call_hash]["statuses"]) == 3 assert app.root.work._calls[call_hash]["statuses"][0]["stage"] == "pending" assert app.root.work._calls[call_hash]["statuses"][1]["stage"] == "failed" @@ -308,7 +308,7 @@ def run(self, *args, **kwargs): "state": { "vars": {"_paths": {}, "_urls": {}}, "calls": { - "latest_call_hash": "any", + CacheCallsKeys.LATEST_CALL_HASH: "any", "any": { "name": "run", "call_hash": "any", @@ -361,7 +361,7 @@ def run(self, *args, **kwargs): ], ) @mock.patch("lightning_app.utilities.proxies.Copier") -def test_path_attributes_to_transfer(_, monkeypatch, origin, exists_remote, expected_get): +def test_path_attributes_to_transfer(_, origin, exists_remote, expected_get): """Test that any Lightning Path objects passed to the run method get transferred automatically (if they exist).""" path_mock = Mock() @@ -399,7 +399,7 @@ def run(self): "state": { "vars": {"_paths": flow.work._paths, "_urls": {}}, "calls": { - "latest_call_hash": "any", + CacheCallsKeys.LATEST_CALL_HASH: "any", "any": { "name": "run", "call_hash": "any", @@ -550,9 +550,9 @@ def run(self, use_setattr=False, use_containers=False): ############################ work.run(use_setattr=True, use_containers=False) - # this is necessary only in this test where we siumulate the calls + # this is necessary only in this test where we simulate the calls work._calls.clear() - work._calls.update({"latest_call_hash": None}) + work._calls.update({CacheCallsKeys.LATEST_CALL_HASH: None}) delta = delta_queue.get().delta.to_dict() assert delta["values_changed"] == {"root['vars']['var']": {"new_value": 2}} @@ -583,7 +583,7 @@ def run(self, use_setattr=False, use_containers=False): # this is necessary only in this test where we siumulate the calls work._calls.clear() - work._calls.update({"latest_call_hash": None}) + work._calls.update({CacheCallsKeys.LATEST_CALL_HASH: None}) delta = delta_queue.get().delta.to_dict() assert delta == {"values_changed": {"root['vars']['var']": {"new_value": 3}}} diff --git a/tests/tests_app_examples/collect_failures/app.py b/tests/tests_app_examples/collect_failures/app.py index 7f82f2367775d..89e302b2e6723 100644 --- a/tests/tests_app_examples/collect_failures/app.py +++ b/tests/tests_app_examples/collect_failures/app.py @@ -11,7 +11,7 @@ class SimpleWork(LightningWork): def __init__(self): - super().__init__(cache_calls=False, parallel=True) + super().__init__(cache_calls=False, parallel=True, raise_exception=False) self.is_running_now = False def run(self):