Skip to content

Commit

Permalink
[App] fix lightning open command & better redirects (#16794)
Browse files Browse the repository at this point in the history
* fix(app): URLs, create run on app run

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and lexierule committed Feb 21, 2023
1 parent 5a349fe commit 494ce3e
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 48 deletions.
46 changes: 20 additions & 26 deletions src/lightning_app/runners/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,26 +183,8 @@ def open(self, name: str, cluster_id: Optional[str] = None):
if getattr(run, "cluster_id", None):
print(f"Running on {run.cluster_id}")

# TODO: We shouldn't need to create an instance here
if existing_run_instance is not None:
run_instance = self._api_transfer_run_instance(
project.project_id,
run.id,
existing_run_instance.id,
V1LightningappInstanceState.STOPPED,
)
else:
run_instance = self._api_create_run_instance(
cluster_id,
project.project_id,
cloudspace_name,
cloudspace_id,
run.id,
V1LightningappInstanceState.STOPPED,
)

if "PYTEST_CURRENT_TEST" not in os.environ:
click.launch(self._get_app_url(project, cloudspace_name, run_instance, "code", needs_credits))
click.launch(self._get_cloudspace_url(project, cloudspace_name, "code", needs_credits))

except ApiException as e:
logger.error(e.body)
Expand Down Expand Up @@ -383,9 +365,7 @@ def dispatch(
# TODO: Remove testing dependency, but this would open a tab for each test...
if open_ui and "PYTEST_CURRENT_TEST" not in os.environ:
click.launch(
self._get_app_url(
project, cloudspace_name, run_instance, "logs" if run.is_headless else "web-ui", needs_credits
)
self._get_app_url(project, run_instance, "logs" if run.is_headless else "web-ui", needs_credits)
)
except ApiException as e:
logger.error(e.body)
Expand Down Expand Up @@ -1007,10 +987,24 @@ def _print_specs(run_body: CloudspaceIdRunsBody, print_format: str) -> None:
requirements_path = getattr(getattr(run_body.image_spec, "dependency_file_info", ""), "path", "")
logger.info(f"requirements_path: {requirements_path}")

def _get_cloudspace_url(
self, project: V1Membership, cloudspace_name: str, tab: str, need_credits: bool = False
) -> str:
user = self.backend.client.auth_service_get_user()
action = "?action=add_credits" if need_credits else ""
paths = [
user.username,
project.name,
"apps",
cloudspace_name,
tab,
]
path = "/".join([quote(path, safe="") for path in paths])
return f"{get_lightning_cloud_url()}/{path}{action}"

def _get_app_url(
self,
project: V1Membership,
cloudspace_name: str,
run_instance: Externalv1LightningappInstance,
tab: str,
need_credits: bool = False,
Expand All @@ -1021,8 +1015,8 @@ def _get_app_url(
paths = [
user.username,
project.name,
"apps",
cloudspace_name,
"jobs",
run_instance.name,
tab,
]
else:
Expand All @@ -1032,5 +1026,5 @@ def _get_app_url(
run_instance.id,
tab,
]
path = quote("/".join([path.replace(" ", "_").replace("/", "~") for path in paths]))
path = "/".join([quote(path, safe="") for path in paths])
return f"{get_lightning_cloud_url()}/{path}{action}"
64 changes: 42 additions & 22 deletions tests/tests_app/runners/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import pytest
from lightning_cloud.openapi import (
Body4,
CloudspaceIdRunsBody,
Externalv1Cluster,
Externalv1LightningappInstance,
Expand Down Expand Up @@ -1422,9 +1421,6 @@ def test_open(self, monkeypatch):

mock_client.cloud_space_service_create_cloud_space.return_value = V1CloudSpace(id="cloudspace_id")
mock_client.cloud_space_service_create_lightning_run.return_value = V1LightningRun(id="run_id")
mock_client.cloud_space_service_create_lightning_run_instance.return_value = Externalv1LightningappInstance(
id="instance_id"
)

mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([Externalv1Cluster(id="test")])
cloud_backend = mock.MagicMock()
Expand All @@ -1445,9 +1441,6 @@ def test_open(self, monkeypatch):
cloudspace_id="cloudspace_id",
body=mock.ANY,
)
mock_client.cloud_space_service_create_lightning_run_instance.assert_called_once_with(
project_id="test-project-id", cloudspace_id="cloudspace_id", id="run_id", body=mock.ANY
)

assert mock_client.cloud_space_service_create_cloud_space.call_args.kwargs["body"].name == "test_space"

Expand Down Expand Up @@ -1565,10 +1558,6 @@ def test_reopen(self, monkeypatch, capsys):
body=mock.ANY,
)

mock_client.lightningapp_instance_service_update_lightningapp_instance_release.assert_called_once_with(
project_id="test-project-id", id="instance_id", body=Body4(release_id="run_id")
)

out, _ = capsys.readouterr()
assert "will not overwrite the files in your CloudSpace." in out

Expand Down Expand Up @@ -2012,12 +2001,11 @@ def run(self):


@pytest.mark.parametrize(
"project, cloudspace_name, run_instance, user, tab, lightning_cloud_url, expected_url",
"project, run_instance, user, tab, lightning_cloud_url, expected_url",
[
# Old style
(
V1Membership(),
"any",
Externalv1LightningappInstance(id="test-app-id"),
V1GetUserResponse(username="tester", features=V1UserFeatures()),
"logs",
Expand All @@ -2026,7 +2014,6 @@ def run(self):
),
(
V1Membership(),
"any",
Externalv1LightningappInstance(id="test-app-id"),
V1GetUserResponse(username="tester", features=V1UserFeatures()),
"logs",
Expand All @@ -2036,25 +2023,58 @@ def run(self):
# New style
(
V1Membership(name="tester's project"),
"test/app",
Externalv1LightningappInstance(),
Externalv1LightningappInstance(name="test/job"),
V1GetUserResponse(username="tester", features=V1UserFeatures(project_selector=True)),
"logs",
"https://lightning.ai",
"https://lightning.ai/tester/tester%27s_project/apps/test~app/logs",
"https://lightning.ai/tester/tester%27s%20project/jobs/test%2Fjob/logs",
),
(
V1Membership(name="tester's project"),
"test/app",
Externalv1LightningappInstance(),
Externalv1LightningappInstance(name="test/job"),
V1GetUserResponse(username="tester", features=V1UserFeatures(project_selector=True)),
"logs",
"https://localhost:9800",
"https://localhost:9800/tester/tester%27s%20project/jobs/test%2Fjob/logs",
),
],
)
def test_get_app_url(monkeypatch, project, run_instance, user, tab, lightning_cloud_url, expected_url):
mock_client = mock.MagicMock()
mock_client.auth_service_get_user.return_value = user
cloud_backend = mock.MagicMock(client=mock_client)
monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend))

runtime = CloudRuntime()

with mock.patch(
"lightning.app.runners.cloud.get_lightning_cloud_url", mock.MagicMock(return_value=lightning_cloud_url)
):
assert runtime._get_app_url(project, run_instance, tab) == expected_url


@pytest.mark.parametrize(
"user, project, cloudspace_name, tab, lightning_cloud_url, expected_url",
[
(
V1GetUserResponse(username="tester", features=V1UserFeatures()),
V1Membership(name="default-project"),
"test/cloudspace",
"code",
"https://lightning.ai",
"https://lightning.ai/tester/default-project/apps/test%2Fcloudspace/code",
),
(
V1GetUserResponse(username="tester", features=V1UserFeatures()),
V1Membership(name="Awesome Project"),
"The Best CloudSpace ever",
"web-ui",
"http://localhost:9800",
"http://localhost:9800/tester/tester%27s_project/apps/test~app/logs",
"http://localhost:9800/tester/Awesome%20Project/apps/The%20Best%20CloudSpace%20ever/web-ui",
),
],
)
def test_get_app_url(monkeypatch, project, cloudspace_name, run_instance, user, tab, lightning_cloud_url, expected_url):
def test_get_cloudspace_url(monkeypatch, user, project, cloudspace_name, tab, lightning_cloud_url, expected_url):
mock_client = mock.MagicMock()
mock_client.auth_service_get_user.return_value = user
cloud_backend = mock.MagicMock(client=mock_client)
Expand All @@ -2065,4 +2085,4 @@ def test_get_app_url(monkeypatch, project, cloudspace_name, run_instance, user,
with mock.patch(
"lightning_app.runners.cloud.get_lightning_cloud_url", mock.MagicMock(return_value=lightning_cloud_url)
):
assert runtime._get_app_url(project, cloudspace_name, run_instance, tab) == expected_url
assert runtime._get_cloudspace_url(project, cloudspace_name, tab) == expected_url

0 comments on commit 494ce3e

Please sign in to comment.