[App] Enable Python Server and Gradio Serve to run on accelerated device such as GPU CUDA / MPS #15813

Merged · 5 commits · Nov 25, 2022 · Changes from 4 commits
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed the work not being stopped when successful if passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))

- Fixed running PyTorch inference locally on GPUs ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))


## [1.8.2] - 2022-11-17

48 changes: 48 additions & 0 deletions src/lightning_app/components/serve/python_server.py
@@ -1,5 +1,6 @@
import abc
import base64
import os
from pathlib import Path
from typing import Any, Dict, Optional

@@ -9,12 +10,54 @@
from pydantic import BaseModel
from starlette.staticfiles import StaticFiles

from lightning_app.core.queues import MultiProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

logger = Logger(__name__)


class _PyTorchSpawnRunExecutor(WorkRunExecutor):
    """This executor enables moving PyTorch tensors to the GPU.

    Without it, the work would raise the following exception:

        RuntimeError: Cannot re-initialize CUDA in forked subprocess.
        To use CUDA with multiprocessing, you must use the 'spawn' start method
    """

enable_start_observer: bool = False

def __call__(self, *args: Any, **kwargs: Any):
import torch

with self.enable_spawn():
queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
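            # Spawn (rather than fork) a fresh process so CUDA can be
            # initialized there; a queue that is not multiprocessing-safe
            # was converted to a dict above so it can be pickled into the child.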
torch.multiprocessing.spawn(
self.dispatch_run,
args=(self.__class__, self.work, queue, args, kwargs),
nprocs=1,
)

@staticmethod
def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
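        # This runs inside the spawned process: rank 0 re-creates the queues
        # that were serialized to dicts and starts a WorkStateObserver so state
        # changes are propagated back to the main process through the delta queue.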
if local_rank == 0:
if isinstance(delta_queue, dict):
delta_queue = cls.process_queue(delta_queue)
work._request_queue = cls.process_queue(work._request_queue)
work._response_queue = cls.process_queue(work._response_queue)

state_observer = WorkStateObserver(work, delta_queue=delta_queue)
state_observer.start()
_proxy_setattr(work, delta_queue, state_observer)

unwrap(work.run)(*args, **kwargs)

if local_rank == 0:
state_observer.join(0)


class _DefaultInputData(BaseModel):
payload: str

@@ -106,6 +149,11 @@ def predict(self, request):
self._input_type = input_type
self._output_type = output_type

        # Note: Enables running inference on GPUs. In the cloud, where
        # LIGHTNING_CLOUD_APP_ID is set, the default executor is kept;
        # locally, the spawn-based executor avoids the CUDA fork error.
        self._run_executor_cls = (
            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
        )

def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
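For context, here is a minimal sketch of how this change is exercised locally. It assumes the public `PythonServer` class exported from `lightning_app.components.serve` and the default string `payload` input; the model, device selection, and app wiring are illustrative only and may vary by version:

```python
# Hypothetical example: a PythonServer subclass running inference on an
# accelerated device. Locally, PythonServer now selects _PyTorchSpawnRunExecutor,
# so moving tensors to CUDA no longer raises the fork re-initialization error.
import torch

from lightning_app import LightningApp
from lightning_app.components.serve import PythonServer


class MyServer(PythonServer):
    def setup(self, *args, **kwargs) -> None:
        # Pick whichever accelerated device is available (illustrative).
        if torch.cuda.is_available():
            self._device = "cuda"
        elif torch.backends.mps.is_available():
            self._device = "mps"
        else:
            self._device = "cpu"
        self._model = torch.nn.Linear(1, 1).to(self._device)

    def predict(self, request):
        # The default input type carries a single string ``payload`` field.
        x = torch.tensor([[float(request.payload)]], device=self._device)
        with torch.no_grad():
            y = self._model(x)
        return {"prediction": str(y.item())}


# Wrapping a work directly in LightningApp is assumed here; run with
# ``lightning run app app.py``. In the cloud, LIGHTNING_CLOUD_APP_ID is set
# and the default WorkRunExecutor is used instead.
app = LightningApp(MyServer())
```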