[App] Accelerate Multi Node Startup Time (#15650)
(cherry picked from commit 757413c)
tchaton authored and Borda committed Nov 16, 2022
1 parent fe302b2 commit 27afcd3
Showing 13 changed files with 144 additions and 152 deletions.
4 changes: 2 additions & 2 deletions examples/app_multi_node/README.md
@@ -28,9 +28,9 @@ lightning run app train_lite.py

Using Lite, you retain control over your loops while accessing all Lightning distributed strategies with minimal changes.

## Multi Node with PyTorch Lightning
## Multi Node with Lightning Trainer

Lightning supports running PyTorch Lightning from a script or within a Lightning Work.
Lightning supports running Lightning Trainer from a script or within a Lightning Work.

You can either run a script directly

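For reference, a minimal sketch of wiring up the `MultiNode` component modified later in this diff; `MyTrainingWork` is a hypothetical work, the import path `lightning_app.components.MultiNode` is assumed, and the `run()` signature mirrors the arguments that `MultiNode.run()` passes in the diff below:

```python
# Hypothetical sketch, not part of this commit: a two-node app built on the
# MultiNode component whose startup path this commit accelerates.
from lightning_app import CloudCompute, LightningApp, LightningWork
from lightning_app.components import MultiNode


class MyTrainingWork(LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # MultiNode passes rank 0's address/port so every node can rendezvous
        # before running the user code.
        print(f"node {node_rank}/{num_nodes} -> {main_address}:{main_port}")


app = LightningApp(
    MultiNode(
        MyTrainingWork,
        num_nodes=2,
        cloud_compute=CloudCompute("gpu"),
    )
)
```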
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -61,6 +61,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))


- Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))


## [1.8.0] - 2022-11-01

33 changes: 21 additions & 12 deletions src/lightning_app/components/database/server.py
@@ -4,6 +4,7 @@
import sys
import tempfile
import threading
import traceback
from typing import List, Optional, Type, Union

import uvicorn
@@ -36,6 +37,9 @@ def install_signal_handlers(self):
"""Ignore Uvicorn Signal Handlers."""


_lock = threading.Lock()


class Database(LightningWork):
def __init__(
self,
@@ -146,25 +150,29 @@ class CounterModel(SQLModel, table=True):
self._exit_event = None

def store_database(self):
with tempfile.TemporaryDirectory() as tmpdir:
tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))
try:
with tempfile.TemporaryDirectory() as tmpdir:
tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))

source = sqlite3.connect(self.db_filename)
dest = sqlite3.connect(tmp_db_filename)
source = sqlite3.connect(self.db_filename)
dest = sqlite3.connect(tmp_db_filename)

source.backup(dest)
source.backup(dest)

source.close()
dest.close()
source.close()
dest.close()

drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
drive.put(os.path.basename(tmp_db_filename))
drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
drive.put(os.path.basename(tmp_db_filename))

print("Stored the database to the Drive.")
print("Stored the database to the Drive.")
except Exception:
print(traceback.print_exc())

def periodic_store_database(self, store_interval):
while not self._exit_event.is_set():
self.store_database()
with _lock:
self.store_database()
self._exit_event.wait(store_interval)

def run(self, token: Optional[str] = None) -> None:
@@ -210,4 +218,5 @@ def db_url(self) -> Optional[str]:

def on_exit(self):
self._exit_event.set()
self.store_database()
with _lock:
self.store_database()
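The new module-level `_lock` exists so the periodic snapshot thread and `on_exit` never run `store_database` concurrently. A minimal standalone sketch of that pattern, with a `snapshot` function standing in for `store_database`:

```python
import threading
import time

_lock = threading.Lock()
_exit_event = threading.Event()


def snapshot() -> None:
    # Stand-in for store_database(): must never run concurrently with itself.
    print("snapshot taken")


def periodic_snapshot(interval: float) -> None:
    # Background thread: snapshot under the lock, then wait for the next
    # interval or for shutdown, whichever comes first.
    while not _exit_event.is_set():
        with _lock:
            snapshot()
        _exit_event.wait(interval)


def on_exit() -> None:
    # Shutdown path: stop the loop and take one final, serialized snapshot.
    _exit_event.set()
    with _lock:
        snapshot()


if __name__ == "__main__":
    thread = threading.Thread(target=periodic_snapshot, args=(0.5,), daemon=True)
    thread.start()
    time.sleep(1.2)
    on_exit()
    thread.join()
```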
50 changes: 17 additions & 33 deletions src/lightning_app/components/multi_node/base.py
@@ -3,7 +3,6 @@
from lightning_app import structures
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.enum import WorkStageStatus
from lightning_app.utilities.packaging.cloud_compute import CloudCompute


@@ -52,46 +51,31 @@ def run(
work_kwargs: Keyword arguments to be provided to the work on instantiation.
"""
super().__init__()
self.ws = structures.List()
self._work_cls = work_cls
self.num_nodes = num_nodes
self._cloud_compute = cloud_compute
self._work_args = work_args
self._work_kwargs = work_kwargs
self.has_started = False
self.ws = structures.List(
*[
work_cls(
*work_args,
cloud_compute=cloud_compute,
**work_kwargs,
parallel=True,
)
for _ in range(num_nodes)
]
)

def run(self) -> None:
if not self.has_started:

# 1. Create & start the works
if not self.ws:
for node_rank in range(self.num_nodes):
self.ws.append(
self._work_cls(
*self._work_args,
cloud_compute=self._cloud_compute,
**self._work_kwargs,
parallel=True,
)
)

# Starting node `node_rank` ...
self.ws[-1].start()

# 2. Wait for all machines to be started !
if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
return

self.has_started = True
# 1. Wait for all works to be started !
if not all(w.internal_ip for w in self.ws):
return

# Loop over all node machines
for node_rank in range(self.num_nodes):
# 2. Loop over all node machines
for node_rank in range(len(self.ws)):

# 3. Run the user code in a distributed way !
self.ws[node_rank].run(
main_address=self.ws[0].internal_ip,
main_port=self.ws[0].port,
num_nodes=self.num_nodes,
num_nodes=len(self.ws),
node_rank=node_rank,
)

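The core of the speed-up is that the node works are now created eagerly in `__init__` (inside a `structures.List`), so the cloud can provision every machine at startup, and `run()` only waits until each node has an `internal_ip` before fanning out. A toy, self-contained sketch of that readiness gate, where `FakeNode` is a hypothetical stand-in for a `LightningWork`:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeNode:
    # Hypothetical stand-in for a LightningWork node, used only by this sketch.
    internal_ip: Optional[str] = None
    port: int = 1234
    calls: List[dict] = field(default_factory=list)

    def run(self, **kwargs) -> None:
        self.calls.append(kwargs)


def launch_if_ready(nodes: List[FakeNode]) -> bool:
    # 1. Bail out until every node has been assigned an internal IP
    #    (mirrors the `if not all(w.internal_ip for w in self.ws)` gate above).
    if not all(node.internal_ip for node in nodes):
        return False

    # 2. Fan out: every node runs the user code against rank 0's address/port.
    for node_rank, node in enumerate(nodes):
        node.run(
            main_address=nodes[0].internal_ip,
            main_port=nodes[0].port,
            num_nodes=len(nodes),
            node_rank=node_rank,
        )
    return True


nodes = [FakeNode(), FakeNode()]
assert not launch_if_ready(nodes)  # no IPs assigned yet -> keep waiting
nodes[0].internal_ip = "10.0.0.1"
nodes[1].internal_ip = "10.0.0.2"
assert launch_if_ready(nodes)      # all ready -> user code dispatched once per node
```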
10 changes: 10 additions & 0 deletions src/lightning_app/core/app.py
@@ -472,6 +472,8 @@ def _run(self) -> bool:
self._original_state = deepcopy(self.state)
done = False

self._start_with_flow_works()

if self.should_publish_changes_to_api and self.api_publish_state_queue:
logger.debug("Publishing the state with changes")
# Push two states to optimize start in the cloud.
@@ -668,3 +670,11 @@ def _send_flow_to_work_deltas(self, state) -> None:
if deep_diff:
logger.debug(f"Sending deep_diff to {w.name} : {deep_diff}")
self.flow_to_work_delta_queues[w.name].put(deep_diff)

def _start_with_flow_works(self):
for w in self.works:
if w._start_with_flow:
parallel = w.parallel
w._parallel = True
w.start()
w._parallel = parallel
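`_start_with_flow_works` picks up works flagged `_start_with_flow` and starts them while the flow boots, temporarily forcing `parallel=True` so `start()` does not block. A minimal sketch of declaring such a work, assuming `start_with_flow` is accepted by the `LightningWork` constructor (the same flag toggled in the test change at the bottom of this diff):

```python
from lightning_app import LightningApp, LightningFlow, LightningWork


class EagerWork(LightningWork):
    def __init__(self):
        # start_with_flow=True (assumed constructor flag) asks the runtime to
        # spin this work's machine up together with the flow instead of
        # waiting for the first run() call.
        super().__init__(parallel=True, start_with_flow=True)

    def run(self):
        print("machine was already starting while the flow booted")


class RootFlow(LightningFlow):
    def __init__(self):
        super().__init__()
        self.work = EagerWork()

    def run(self):
        self.work.run()


app = LightningApp(RootFlow())
```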
129 changes: 64 additions & 65 deletions src/lightning_app/runners/cloud.py
@@ -142,78 +142,77 @@ def dispatch(
v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))

works: List[V1Work] = []
for flow in self.app.flows:
for work in flow.works(recurse=False):
if not work._start_with_flow:
continue

work_requirements = "\n".join(work.cloud_build_config.requirements)
build_spec = V1BuildSpec(
commands=work.cloud_build_config.build_commands(),
python_dependencies=V1PythonDependencyInfo(
package_manager=V1PackageManager.PIP, packages=work_requirements
),
image=work.cloud_build_config.image,
)
user_compute_config = V1UserRequestedComputeConfig(
name=work.cloud_compute.name,
count=1,
disk_size=work.cloud_compute.disk_size,
preemptible=work.cloud_compute.preemptible,
shm_size=work.cloud_compute.shm_size,
)
for work in self.app.works:
if not work._start_with_flow:
continue

drive_specs: List[V1LightningworkDrives] = []
for drive_attr_name, drive in [
(k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
]:
if drive.protocol == "lit://":
drive_type = V1DriveType.NO_MOUNT_S3
source_type = V1SourceType.S3
else:
raise RuntimeError(
f"unknown drive protocol `{drive.protocol}`. Please verify this "
f"drive type has been configured for use in the cloud dispatcher."
)
work_requirements = "\n".join(work.cloud_build_config.requirements)
build_spec = V1BuildSpec(
commands=work.cloud_build_config.build_commands(),
python_dependencies=V1PythonDependencyInfo(
package_manager=V1PackageManager.PIP, packages=work_requirements
),
image=work.cloud_build_config.image,
)
user_compute_config = V1UserRequestedComputeConfig(
name=work.cloud_compute.name,
count=1,
disk_size=work.cloud_compute.disk_size,
preemptible=work.cloud_compute.preemptible,
shm_size=work.cloud_compute.shm_size,
)

drive_specs.append(
V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name=f"{work.name}.{drive_attr_name}",
),
spec=V1DriveSpec(
drive_type=drive_type,
source_type=source_type,
source=f"{drive.protocol}{drive.id}",
),
status=V1DriveStatus(),
drive_specs: List[V1LightningworkDrives] = []
for drive_attr_name, drive in [
(k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
]:
if drive.protocol == "lit://":
drive_type = V1DriveType.NO_MOUNT_S3
source_type = V1SourceType.S3
else:
raise RuntimeError(
f"unknown drive protocol `{drive.protocol}`. Please verify this "
f"drive type has been configured for use in the cloud dispatcher."
)

drive_specs.append(
V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name=f"{work.name}.{drive_attr_name}",
),
spec=V1DriveSpec(
drive_type=drive_type,
source_type=source_type,
source=f"{drive.protocol}{drive.id}",
),
mount_location=str(drive.root_folder),
status=V1DriveStatus(),
),
)
mount_location=str(drive.root_folder),
),
)

# TODO: Move this to the CloudCompute class and update backend
if work.cloud_compute.mounts is not None:
mounts = work.cloud_compute.mounts
if isinstance(mounts, Mount):
mounts = [mounts]
for mount in mounts:
drive_specs.append(
_create_mount_drive_spec(
work_name=work.name,
mount=mount,
)
# TODO: Move this to the CloudCompute class and update backend
if work.cloud_compute.mounts is not None:
mounts = work.cloud_compute.mounts
if isinstance(mounts, Mount):
mounts = [mounts]
for mount in mounts:
drive_specs.append(
_create_mount_drive_spec(
work_name=work.name,
mount=mount,
)
)

random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
work_spec = V1LightningworkSpec(
build_spec=build_spec,
drives=drive_specs,
user_requested_compute_config=user_compute_config,
network_config=[V1NetworkConfig(name=random_name, port=work.port)],
)
works.append(V1Work(name=work.name, spec=work_spec))
random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
work_spec = V1LightningworkSpec(
build_spec=build_spec,
drives=drive_specs,
user_requested_compute_config=user_compute_config,
network_config=[V1NetworkConfig(name=random_name, port=work.port)],
)
works.append(V1Work(name=work.name, spec=work_spec))

# We need to collect a spec for each flow that contains a frontend so that the backend knows
# for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below)
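The other half of the fix is which works get collected for the cloud spec: the dispatcher now iterates the flattened `self.app.works` list instead of looping `flow.works(recurse=False)` per flow, which, per the changelog entry above, previously missed works held in structures such as the `structures.List` now used by `MultiNode`. A toy sketch of that distinction, using hypothetical stand-in classes rather than Lightning's:

```python
from typing import List


class ToyWork:
    def __init__(self, name: str, start_with_flow: bool = True):
        self.name = name
        self._start_with_flow = start_with_flow


class ToyFlow:
    # Hypothetical stand-in: one work attached directly, two held in a structure.
    def __init__(self):
        self.direct_work = ToyWork("direct")
        self.structured_works = [ToyWork("node-0"), ToyWork("node-1")]

    def works_non_recursive(self) -> List[ToyWork]:
        # Mimics the old per-flow collection: structure members are missed.
        return [self.direct_work]

    def all_works(self) -> List[ToyWork]:
        # Mimics the flattened `app.works` view: structure members included.
        return [self.direct_work, *self.structured_works]


flow = ToyFlow()
assert [w.name for w in flow.works_non_recursive()] == ["direct"]
assert [w.name for w in flow.all_works()] == ["direct", "node-0", "node-1"]
```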
6 changes: 2 additions & 4 deletions src/lightning_app/utilities/proxies.py
@@ -103,8 +103,6 @@ class ProxyWorkRun:
caller_queue: "BaseQueue"

def __post_init__(self):
self.cache_calls = self.work.cache_calls
self.parallel = self.work.parallel
self.work_state = None

def __call__(self, *args, **kwargs):
@@ -123,7 +121,7 @@ def __call__(self, *args, **kwargs):

# The if/else conditions are left un-compressed to simplify readability
# for the readers.
if self.cache_calls:
if self.work.cache_calls:
if not entered or stopped_on_sigterm:
_send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
else:
@@ -137,7 +135,7 @@
# the previous task has completed and we can re-queue the next one.
# overriding the return value for next loop iteration.
_send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
if not self.parallel:
if not self.work.parallel:
raise CacheMissException("Task never called before. Triggered now")

def _validate_call_args(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
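`ProxyWorkRun` now reads `cache_calls` and `parallel` from the wrapped work at call time instead of snapshotting them in `__post_init__`; this matters because `_start_with_flow_works` above temporarily flips the work's `parallel` flag. A toy sketch of why a cached snapshot goes stale, with hypothetical stand-in classes:

```python
from dataclasses import dataclass


@dataclass
class Work:
    # Hypothetical stand-in for a LightningWork with a mutable `parallel` flag.
    parallel: bool = False


@dataclass
class CachedProxy:
    # Old behaviour: snapshot the flag once at construction time.
    work: Work

    def __post_init__(self):
        self.parallel = self.work.parallel

    def is_parallel(self) -> bool:
        return self.parallel


@dataclass
class LiveProxy:
    # New behaviour: always read the flag from the wrapped work.
    work: Work

    def is_parallel(self) -> bool:
        return self.work.parallel


work = Work(parallel=False)
cached, live = CachedProxy(work), LiveProxy(work)
work.parallel = True  # e.g. _start_with_flow_works() temporarily flips the flag
assert cached.is_parallel() is False  # stale snapshot
assert live.is_parallel() is True     # reflects the current state
```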
11 changes: 7 additions & 4 deletions tests/tests_app/components/database/test_client_server.py
@@ -2,6 +2,7 @@
import sys
import tempfile
import time
import traceback
from pathlib import Path
from time import sleep
from typing import List, Optional
@@ -197,7 +198,9 @@ def run(self):
assert len(self._client.select_all()) == 1
self._exit()

with tempfile.TemporaryDirectory() as tmpdir:

app = LightningApp(Flow(tmpdir))
MultiProcessRuntime(app).dispatch()
try:
with tempfile.TemporaryDirectory() as tmpdir:
app = LightningApp(Flow(tmpdir))
MultiProcessRuntime(app).dispatch()
except Exception:
print(traceback.print_exc())
2 changes: 1 addition & 1 deletion tests/tests_app/core/test_lightning_api.py
@@ -42,7 +42,7 @@

class WorkA(LightningWork):
def __init__(self):
super().__init__(parallel=True)
super().__init__(parallel=True, start_with_flow=False)
self.var_a = 0
self.drive = Drive("lit://test_app_state_api")

