Merged

35 commits
376f39e
update
tchaton Aug 1, 2022
94f3348
update
tchaton Aug 1, 2022
cd03c6d
update
tchaton Aug 1, 2022
192ece9
update
tchaton Aug 2, 2022
add4369
update
tchaton Aug 2, 2022
93494d7
update
tchaton Aug 2, 2022
7a04c2c
update
tchaton Aug 2, 2022
943826a
update
tchaton Aug 2, 2022
36a3527
update
tchaton Aug 2, 2022
3dffbf8
update
tchaton Aug 2, 2022
001f131
update
tchaton Aug 2, 2022
0331f82
Merge branch 'master' into reduce_state_size
tchaton Aug 2, 2022
1738bc9
update
tchaton Aug 2, 2022
04f8bee
Merge branch 'reduce_state_size' of https://github.com/Lightning-AI/l…
tchaton Aug 2, 2022
bad5f6c
update
tchaton Aug 2, 2022
8b9fd44
Merge branch 'master' into reduce_state_size
tchaton Aug 2, 2022
a506bd5
update
tchaton Aug 2, 2022
1e464fa
Merge branch 'reduce_state_size' of https://github.com/Lightning-AI/l…
tchaton Aug 2, 2022
d7cdc0e
update
tchaton Aug 2, 2022
ed0a08b
update
tchaton Aug 2, 2022
8a633bf
update
tchaton Aug 2, 2022
616b7df
Merge branch 'master' into reduce_state_size
tchaton Aug 2, 2022
0b85cad
update
tchaton Aug 2, 2022
9f69712
Merge branch 'reduce_state_size' of https://github.com/Lightning-AI/l…
tchaton Aug 2, 2022
1818874
update
tchaton Aug 2, 2022
935c3a1
update
tchaton Aug 2, 2022
ead2d5d
update
tchaton Aug 2, 2022
8529e28
update
tchaton Aug 2, 2022
b545ead
update
tchaton Aug 2, 2022
64699aa
update
tchaton Aug 3, 2022
19c7c9a
update
tchaton Aug 3, 2022
491858b
Merge branch 'master' into reduce_state_size
tchaton Aug 3, 2022
f34d5d7
update
tchaton Aug 3, 2022
1070df8
update
tchaton Aug 3, 2022
0256145
Merge branch 'reduce_state_size' of https://github.com/Lightning-AI/l…
tchaton Aug 3, 2022
1 change: 1 addition & 0 deletions requirements/app/test.txt
@@ -12,3 +12,4 @@ isort>=5.0
mypy>=0.720
httpx
trio
pympler
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -24,3 +24,5 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated

### Fixed

- Resolved a bug where work statuses would grow quickly and be duplicated ([#13970](https://github.com/Lightning-AI/lightning/pull/13970))
22 changes: 9 additions & 13 deletions src/lightning_app/core/app.py
@@ -18,7 +18,7 @@
from lightning_app.utilities.app_helpers import _delta_to_appstate_delta, _LightningAppRef
from lightning_app.utilities.commands.base import _populate_commands_endpoint, _process_command_requests
from lightning_app.utilities.component import _convert_paths_after_init
from lightning_app.utilities.enum import AppStage
from lightning_app.utilities.enum import AppStage, CacheCallsKeys
from lightning_app.utilities.exceptions import CacheMissException, ExitAppException
from lightning_app.utilities.layout import _collect_layout
from lightning_app.utilities.proxies import ComponentDelta
@@ -399,8 +399,8 @@ def _run(self) -> bool:
if self.should_publish_changes_to_api and self.api_publish_state_queue:
logger.debug("Publishing the state with changes")
# Push two states to optimize start in the cloud.
self.api_publish_state_queue.put(self.state)
self.api_publish_state_queue.put(self.state)
self.api_publish_state_queue.put(self.state_vars)
self.api_publish_state_queue.put(self.state_vars)

self._reset_run_time_monitor()

@@ -412,7 +412,7 @@ def _run(self) -> bool:
self._update_run_time_monitor()

if self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue:
self.api_publish_state_queue.put(self.state)
self.api_publish_state_queue.put(self.state_vars)

return True

@@ -430,16 +430,12 @@ def _apply_restarting(self) -> bool:
self.stage = AppStage.BLOCKING
return False

def _collect_work_finish_status(self) -> dict:
work_finished_status = {}
for work in self.works:
work_finished_status[work.name] = False
for key in work._calls:
if key == "latest_call_hash":
continue
fn_metadata = work._calls[key]
work_finished_status[work.name] = fn_metadata["name"] == "run" and "ret" in fn_metadata
def _has_work_finished(self, work) -> bool:
latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]
if latest_call_hash is None:
return False
return "ret" in work._calls[latest_call_hash]

def _collect_work_finish_status(self) -> dict:
work_finished_status = {work.name: self._has_work_finished(work) for work in self.works}
assert len(work_finished_status) == len(self.works)
return work_finished_status

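Note (not part of the diff): a minimal sketch of how the finished-check reads the per-work `_calls` cache. `FakeWork` and the hard-coded call hash are illustrative assumptions; only the lookup logic mirrors `_has_work_finished` above.

```python
from lightning_app.utilities.enum import CacheCallsKeys


class FakeWork:
    # Mirrors the ``_calls`` structure documented in work.py below (illustrative values).
    _calls = {
        CacheCallsKeys.LATEST_CALL_HASH: "167fe2e",
        "167fe2e": {"statuses": [], "ret": None},  # presence of "ret" means run() returned
    }


def has_work_finished(work) -> bool:
    latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]
    if latest_call_hash is None:
        return False  # the work was never scheduled to run
    return "ret" in work._calls[latest_call_hash]


assert has_work_finished(FakeWork())
```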
91 changes: 73 additions & 18 deletions src/lightning_app/core/work.py
@@ -12,8 +12,15 @@
from lightning_app.storage.drive import _maybe_create_drive, Drive
from lightning_app.storage.payload import Payload
from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef
from lightning_app.utilities.component import _sanitize_state
from lightning_app.utilities.enum import make_status, WorkFailureReasons, WorkStageStatus, WorkStatus, WorkStopReasons
from lightning_app.utilities.component import _is_flow_context, _sanitize_state
from lightning_app.utilities.enum import (
CacheCallsKeys,
make_status,
WorkFailureReasons,
WorkStageStatus,
WorkStatus,
WorkStopReasons,
)
from lightning_app.utilities.exceptions import LightningWorkException
from lightning_app.utilities.introspection import _is_init_context
from lightning_app.utilities.network import find_free_network_port
@@ -107,7 +114,21 @@ def __init__(
# setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator
self._setattr_replacement: Optional[Callable[[str, Any], None]] = None
self._name = ""
self._calls = {"latest_call_hash": None}
# ``self._calls`` tracks whether the run method has already been
# called with a given set of input arguments.
# Example of its contents:
# {
#     'latest_call_hash': '167fe2e',
#     '167fe2e': {
#         'statuses': [
#             {'stage': 'pending', 'timestamp': 1659433519.851271},
#             {'stage': 'running', 'timestamp': 1659433519.956482},
#             {'stage': 'stopped', 'timestamp': 1659433520.055768},
#         ],
#     },
#     ...
# }
self._calls = {CacheCallsKeys.LATEST_CALL_HASH: None}
self._changes = {}
self._raise_exception = raise_exception
self._paths = {}
@@ -215,22 +236,22 @@ def status(self) -> WorkStatus:

All statuses are stored in the state.
"""
call_hash = self._calls["latest_call_hash"]
if call_hash:
call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
if call_hash in self._calls:
statuses = self._calls[call_hash]["statuses"]
# Deltas don't necessarily arrive in the expected order.
statuses = sorted(statuses, key=lambda x: x["timestamp"])
latest_status = statuses[-1]
if latest_status["reason"] == WorkFailureReasons.TIMEOUT:
if latest_status.get("reason") == WorkFailureReasons.TIMEOUT:
return self._aggregate_status_timeout(statuses)
return WorkStatus(**latest_status)
return WorkStatus(stage=WorkStageStatus.NOT_STARTED, timestamp=time.time())

@property
def statuses(self) -> List[WorkStatus]:
"""Return all the status of the work."""
call_hash = self._calls["latest_call_hash"]
if call_hash:
call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
if call_hash in self._calls:
statuses = self._calls[call_hash]["statuses"]
# Deltas don't necessarily arrive in the expected order.
statuses = sorted(statuses, key=lambda x: x["timestamp"])
@@ -398,10 +419,13 @@ def __getattr__(self, item):
return path
return self.__getattribute__(item)

def _call_hash(self, fn, args, kwargs):
def _call_hash(self, fn, args, kwargs) -> str:
hash_args = args[1:] if len(args) > 0 and args[0] == self else args
call_obj = {"args": hash_args, "kwargs": kwargs}
return f"{fn.__name__}:{DeepHash(call_obj)[call_obj]}"
# Note: Generates a short hash such as 167fe2e.
# Seven characters were chosen to match GitHub's default short SHA length
# and to keep the hidden state size small.
return str(DeepHash(call_obj)[call_obj])[:7]

def _wrap_run_for_caching(self, fn):
@wraps(fn)
@@ -415,11 +439,11 @@ def new_fn(*args, **kwargs):
entry = self._calls[call_hash]
return entry["ret"]

self._calls[call_hash] = {"name": fn.__name__, "call_hash": call_hash}
self._calls[call_hash] = {}

result = fn(*args, **kwargs)

self._calls[call_hash] = {"name": fn.__name__, "call_hash": call_hash, "ret": result}
self._calls[call_hash] = {"ret": result}

return result

@@ -457,8 +481,40 @@ def set_state(self, provided_state):
if isinstance(v, Dict):
v = _maybe_create_drive(self.name, v)
setattr(self, k, v)

self._changes = provided_state["changes"]
self._calls.update(provided_state["calls"])

# Note: this cleanup is performed by the flow only.
if _is_flow_context():
self._cleanup_calls(provided_state["calls"])

self._calls = provided_state["calls"]

@staticmethod
def _cleanup_calls(calls: Dict[str, Any]):
# 1: Collect all the in-progress call hashes.
in_progress_call_hash = [k for k in list(calls) if k != CacheCallsKeys.LATEST_CALL_HASH]

for call_hash in in_progress_call_hash:
if "statuses" not in calls[call_hash]:
continue

# 2: Sort the statuses by timestamp.
statuses = sorted(calls[call_hash]["statuses"], key=lambda x: x["timestamp"])

# If the latest status is SUCCEEDED, drop everything that came before it.
if statuses[-1]["stage"] == WorkStageStatus.SUCCEEDED:
status = statuses[-1]
status["timestamp"] = int(status["timestamp"])
calls[call_hash]["statuses"] = [status]
else:
# TODO: Some statuses are being duplicated;
# this seems related to the StateObserver.
final_statuses = []
for status in statuses:
if status not in final_statuses:
final_statuses.append(status)
calls[call_hash]["statuses"] = final_statuses

@abc.abstractmethod
def run(self, *args, **kwargs):
@@ -479,7 +535,7 @@ def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus:
if succeeded_statuses:
succeed_status_id = succeeded_statuses[-1] + 1
statuses = statuses[succeed_status_id:]
timeout_statuses = [status for status in statuses if status["reason"] == WorkFailureReasons.TIMEOUT]
timeout_statuses = [status for status in statuses if status.get("reason") == WorkFailureReasons.TIMEOUT]
assert statuses[0]["stage"] == WorkStageStatus.PENDING
status = {**timeout_statuses[-1], "timestamp": statuses[0]["timestamp"]}
return WorkStatus(**status, count=len(timeout_statuses))
@@ -501,9 +557,8 @@ def stop(self):
)
if self.status.stage == WorkStageStatus.STOPPED:
return
latest_hash = self._calls["latest_call_hash"]
self._calls[latest_hash]["statuses"].append(
make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.PENDING)
)
latest_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
stop_status = make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.PENDING)
self._calls[latest_hash]["statuses"].append(stop_status)
app = _LightningAppRef().get_current()
self._backend.stop_work(app, self)
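Note (outside the diff): a minimal sketch of the shortened call hash introduced in `_call_hash` above. The sample arguments are illustrative; `DeepHash` comes from the `deepdiff` package the work already uses.

```python
from deepdiff import DeepHash


def short_call_hash(args: tuple, kwargs: dict) -> str:
    call_obj = {"args": args, "kwargs": kwargs}
    # Truncate to 7 characters (GitHub-style short hash) to keep the stored state small.
    return str(DeepHash(call_obj)[call_obj])[:7]


h = short_call_hash((1, 2), {"lr": 0.1})
assert len(h) == 7
print(h)  # e.g. '167fe2e'; the exact value depends on deepdiff's hashing
```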
11 changes: 6 additions & 5 deletions src/lightning_app/runners/runtime.py
@@ -10,7 +10,7 @@
from lightning_app import LightningApp
from lightning_app.core.constants import APP_SERVER_HOST, APP_SERVER_PORT
from lightning_app.runners.backends import Backend, BackendType
from lightning_app.utilities.enum import AppStage, make_status, WorkStageStatus
from lightning_app.utilities.enum import AppStage, CacheCallsKeys, make_status, WorkStageStatus
from lightning_app.utilities.load_app import load_app_from_file
from lightning_app.utilities.proxies import WorkRunner

@@ -133,9 +133,10 @@ def dispatch(self, *args, **kwargs):
raise NotImplementedError

def _add_stopped_status_to_work(self, work: "lightning_app.LightningWork") -> None:

if work.status.stage == WorkStageStatus.STOPPED:
return
latest_hash = work._calls["latest_call_hash"]
if latest_hash is None:
return
work._calls[latest_hash]["statuses"].append(make_status(WorkStageStatus.STOPPED))

latest_call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]
if latest_call_hash in work._calls:
work._calls[latest_call_hash]["statuses"].append(make_status(WorkStageStatus.STOPPED))
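Note (outside the diff): why the membership test above replaces the explicit `None` check — it covers both the never-ran case and a missing cache entry in one condition. The sample dicts are illustrative.

```python
# Work never ran: the latest call hash is None, which is never a key of the dict.
calls = {"latest_call_hash": None}
assert calls["latest_call_hash"] not in calls  # -> no stopped status appended

# Work ran: the hash points at an existing entry, so a stopped status can be appended.
calls = {"latest_call_hash": "167fe2e", "167fe2e": {"statuses": []}}
assert calls["latest_call_hash"] in calls
```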
15 changes: 11 additions & 4 deletions src/lightning_app/testing/testing.py
@@ -22,6 +22,7 @@
from lightning_app.runners.multiprocess import MultiProcessRuntime
from lightning_app.testing.config import Config
from lightning_app.utilities.cloud import _get_project
from lightning_app.utilities.enum import CacheCallsKeys
from lightning_app.utilities.imports import _is_playwright_available, requires
from lightning_app.utilities.network import _configure_session, LightningClient
from lightning_app.utilities.proxies import ProxyWorkRun
@@ -114,8 +115,11 @@ def run_work_isolated(work, *args, start_server: bool = False, **kwargs):
start_server=start_server,
).dispatch()
# pop the stopped status.
call_hash = work._calls["latest_call_hash"]
work._calls[call_hash]["statuses"].pop(-1)
call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]

if call_hash in work._calls:
work._calls[call_hash]["statuses"].pop(-1)

if isinstance(work.run, ProxyWorkRun):
work.run = work.run.work_run

@@ -176,7 +180,7 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator:
# 3. Launch the application in the cloud from the Lightning CLI.
with tempfile.TemporaryDirectory() as tmpdir:
env_copy = os.environ.copy()
env_copy["PREPARE_LIGHTING"] = "1"
env_copy["PACKAGE_LIGHTNING"] = "1"
shutil.copytree(app_folder, tmpdir, dirs_exist_ok=True)
# TODO - add -no-cache to the command line.
process = Popen(
@@ -216,7 +220,10 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator:
record_har_path=Config.har_location,
)
admin_page = context.new_page()
res = requests.post(Config.url + "/v1/auth/login", data=json.dumps(payload))
url = Config.url
if url.endswith("/"):
url = url[:-1]
res = requests.post(url + "/v1/auth/login", data=json.dumps(payload))
token = res.json()["token"]
print(f"The Lightning App Token is: {token}")
print(f"The Lightning App user key is: {Config.key}")
13 changes: 10 additions & 3 deletions src/lightning_app/utilities/enum.py
@@ -59,9 +59,16 @@ def __post_init__(self):


def make_status(stage: str, message: Optional[str] = None, reason: Optional[str] = None):
return {
status = {
"stage": stage,
"message": message,
"reason": reason,
"timestamp": datetime.now(tz=timezone.utc).timestamp(),
}
if message:
status["message"] = message
if reason:
status["reason"] = reason
return status


class CacheCallsKeys:
LATEST_CALL_HASH = "latest_call_hash"
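Note (not part of the diff): a sketch of the sparser status dict that `make_status` now produces — `message`/`reason` keys are stored only when set, which is also why the `status["reason"]` lookups in work.py switched to `status.get("reason")`.

```python
from datetime import datetime, timezone
from typing import Optional


def make_status(stage: str, message: Optional[str] = None, reason: Optional[str] = None) -> dict:
    status = {"stage": stage, "timestamp": datetime.now(tz=timezone.utc).timestamp()}
    if message:
        status["message"] = message
    if reason:
        status["reason"] = reason
    return status


print(make_status("running"))                   # no 'message'/'reason' keys stored
print(make_status("failed", reason="timeout"))  # 'reason' included only when provided
```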
2 changes: 1 addition & 1 deletion src/lightning_app/utilities/network.py
@@ -48,7 +48,7 @@ def _configure_session() -> Session:
return http


def _check_service_url_is_ready(url: str, timeout: float = 1) -> bool:
def _check_service_url_is_ready(url: str, timeout: float = 100) -> bool:
try:
response = requests.get(url, timeout=timeout)
return response.status_code in (200, 404)
4 changes: 2 additions & 2 deletions src/lightning_app/utilities/packaging/build_config.py
@@ -110,7 +110,7 @@ def _find_requirements(self, work: "LightningWork") -> List[str]:
file = inspect.getfile(work.__class__)

# 2. Try to find a requirement file associated the file.
dirname = os.path.dirname(file)
dirname = os.path.dirname(file) or "."
requirement_files = [os.path.join(dirname, f) for f in os.listdir(dirname) if f == "requirements.txt"]
if not requirement_files:
return []
@@ -126,7 +126,7 @@ def _find_dockerfile(self, work: "LightningWork") -> List[str]:
file = inspect.getfile(work.__class__)

# 2. Check for Dockerfile.
dirname = os.path.dirname(file)
dirname = os.path.dirname(file) or "."
dockerfiles = [os.path.join(dirname, f) for f in os.listdir(dirname) if f == "Dockerfile"]

if not dockerfiles:
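Note (outside the diff): a minimal illustration of why the `or "."` fallback matters — when a work class lives in a file referenced by a bare filename, `os.path.dirname` returns an empty string and `os.listdir("")` fails.

```python
import os

print(repr(os.path.dirname("app.py")))   # '' -> os.listdir('') raises FileNotFoundError
dirname = os.path.dirname("app.py") or "."
print(sorted(os.listdir(dirname))[:3])   # falls back to the current working directory
```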