Skip to content

Commit

Permalink
Fix cloudcomputes registration for structures (#15964)
Browse files Browse the repository at this point in the history
* fix cloudcomputes
* updates cloudcompute registration
* changelog

(cherry picked from commit 90a4c02)
  • Loading branch information
justusschock authored and Borda committed Dec 9, 2022
1 parent 843786f commit ad1de34
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed detection of a Lightning App running in debug mode ([#15951](https://github.com/Lightning-AI/lightning/pull/15951))
- Fixed `ImportError` on Multinode if package not present ([#15963](https://github.com/Lightning-AI/lightning/pull/15963))
- Fixed MultiNode Component to use separate cloud computes ([#15965](https://github.com/Lightning-AI/lightning/pull/15965))
- Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964))


## [1.8.3] - 2022-11-22
Expand Down
19 changes: 13 additions & 6 deletions src/lightning_app/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,19 @@ def __setattr__(self, name: str, value: Any) -> None:
elif isinstance(value, (Dict, List)):
self._structures.add(name)
_set_child_name(self, value, name)
if getattr(self, "_backend", None) is not None:
value._backend = self._backend
for flow in value.flows:
LightningFlow._attach_backend(flow, self._backend)
for work in value.works:
self._backend._wrap_run_method(_LightningAppRef().get_current(), work)

_backend = getattr(self, "backend", None)
if _backend is not None:
value._backend = _backend

for flow in value.flows:
if _backend is not None:
LightningFlow._attach_backend(flow, _backend)

for work in value.works:
work._register_cloud_compute()
if _backend is not None:
_backend._wrap_run_method(_LightningAppRef().get_current(), work)

elif isinstance(value, Path):
# In the init context, the full name of the Flow and Work is not known, i.e., we can't serialize
Expand Down
29 changes: 28 additions & 1 deletion tests/tests_app/core/test_lightning_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import pytest
from deepdiff import DeepDiff, Delta

from lightning_app import LightningApp
import lightning_app
from lightning_app import CloudCompute, LightningApp
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.runners import MultiProcessRuntime
Expand Down Expand Up @@ -901,3 +902,29 @@ def run_patch(method):
state = app.api_publish_state_queue.put._mock_call_args[0][0]
call_hash = state["works"]["w"]["calls"]["latest_call_hash"]
assert state["works"]["w"]["calls"][call_hash]["statuses"][0]["stage"] == "succeeded"


def test_structures_register_work_cloudcompute():
class MyDummyWork(LightningWork):
def run(self):
return

class MyDummyFlow(LightningFlow):
def __init__(self):
super().__init__()
self.w_list = LList(*[MyDummyWork(cloud_compute=CloudCompute("gpu")) for i in range(5)])
self.w_dict = LDict(**{str(i): MyDummyWork(cloud_compute=CloudCompute("gpu")) for i in range(5)})

def run(self):
for w in self.w_list:
w.run()

for w in self.w_dict.values():
w.run()

MyDummyFlow()
assert len(lightning_app.utilities.packaging.cloud_compute._CLOUD_COMPUTE_STORE) == 10
for v in lightning_app.utilities.packaging.cloud_compute._CLOUD_COMPUTE_STORE.values():
assert len(v.component_names) == 1
assert v.component_names[0][:-1] in ("root.w_list.", "root.w_dict.")
assert v.component_names[0][-1].isdigit()

0 comments on commit ad1de34

Please sign in to comment.