Skip to content

Commit

Permalink
Support Injecting Secrets into Apps Running in the Cloud
Browse files Browse the repository at this point in the history
Adds a new '--secret' flag to 'lightning run app':

lightning run app --cloud --secret MY_SECRET=my-secret-name app.py

When the Lightning App runs in the cloud, the 'MY_SECRET'
environment variable will be populated with the value of the
referenced Secret. The value of the Secret is encrypted in the
database, and will only be decrypted and accessible to the
Flow/Work processes in the cloud.
  • Loading branch information
alecmerdler committed Sep 12, 2022
1 parent 925edbc commit f9529e5
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 6 deletions.
2 changes: 1 addition & 1 deletion requirements/app/base.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
lightning-cloud==0.5.3
lightning-cloud==0.5.4
packaging
deepdiff>=5.7.0, <=5.8.1
starsessions>=1.2.1, <2.0 # strict
Expand Down
13 changes: 12 additions & 1 deletion src/lightning_app/cli/lightning_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def _run_app(
blocking: bool,
open_ui: bool,
env: tuple,
secret: tuple,
):
file = _prepare_file(file)

Expand All @@ -298,10 +299,17 @@ def _run_app(
"Caching is a property of apps running in cloud. "
"Using the flag --no-cache in local execution is not supported."
)
if secret:
raise click.ClickException(
"Secrets can only be used for apps running in cloud. "
"Using the option --secret in local execution is not supported."
)

env_vars = _format_input_env_variables(env)
os.environ.update(env_vars)

secrets = _format_input_env_variables(secret)

def on_before_run(*args):
if open_ui and not without_server:
click.launch(get_app_url(runtime_type, *args))
Expand All @@ -320,6 +328,7 @@ def on_before_run(*args):
on_before_run=on_before_run,
name=name,
env_vars=env_vars,
secrets=secrets,
cluster_id=cluster_id,
)
if runtime_type == RuntimeType.CLOUD:
Expand All @@ -345,6 +354,7 @@ def run():
@click.option("--blocking", "blocking", type=bool, default=False)
@click.option("--open-ui", type=bool, default=True, help="Decide whether to launch the app UI in a web browser")
@click.option("--env", type=str, default=[], multiple=True, help="Env variables to be set for the app.")
@click.option("--secret", type=str, default=[], multiple=True, help="Secret variables to be set for the app.")
@click.option("--app_args", type=str, default=[], multiple=True, help="Collection of arguments for the app.")
def run_app(
file: str,
Expand All @@ -356,10 +366,11 @@ def run_app(
blocking: bool,
open_ui: bool,
env: tuple,
secret: tuple,
app_args: List[str],
):
"""Run an app from a file."""
_run_app(file, cloud, cluster_id, without_server, no_cache, name, blocking, open_ui, env)
_run_app(file, cloud, cluster_id, without_server, no_cache, name, blocking, open_ui, env, secret)


@_main.group(hidden=True)
Expand Down
8 changes: 7 additions & 1 deletion src/lightning_app/runners/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from lightning_app.utilities.dependency_caching import get_hash
from lightning_app.utilities.packaging.app_config import AppConfig, find_config_file
from lightning_app.utilities.packaging.lightning_utils import _prepare_lightning_wheels_and_requirements
from lightning_app.utilities.secrets import _names_to_ids

logger = Logger(__name__)

Expand Down Expand Up @@ -96,8 +97,13 @@ def dispatch(

print(f"The name of the app is: {app_config.name}")

work_reqs: List[V1Work] = []
secret_names_to_ids = _names_to_ids(self.secrets.values())
env_vars_from_secrets = [V1EnvVar(name=k, from_secret=secret_names_to_ids[v]) for k, v in self.secrets.items()]

v1_env_vars = [V1EnvVar(name=k, value=v) for k, v in self.env_vars.items()]
v1_env_vars.extend(env_vars_from_secrets)

work_reqs: List[V1Work] = []
for flow in self.app.flows:
for work in flow.works(recurse=False):
work_requirements = "\n".join(work.cloud_build_config.requirements)
Expand Down
11 changes: 10 additions & 1 deletion src/lightning_app/runners/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def dispatch(
on_before_run: Optional[Callable] = None,
name: str = "",
env_vars: Dict[str, str] = {},
secrets: Dict[str, str] = {},
cluster_id: str = None,
) -> Optional[Any]:
"""Bootstrap and dispatch the application to the target.
Expand All @@ -43,6 +44,7 @@ def dispatch(
on_before_run: Callable to be executed before run.
name: Name of app execution
env_vars: Dict of env variables to be set on the app
secrets: Dict of secrets to be passed as environment variables to the app
cluster_id: the Lightning AI cluster to run the app on. Defaults to managed Lightning AI cloud
"""
from lightning_app.runners.runtime_type import RuntimeType
Expand All @@ -58,7 +60,13 @@ def dispatch(
app.stage = AppStage.BLOCKING

runtime = runtime_cls(
app=app, entrypoint_file=entrypoint_file, start_server=start_server, host=host, port=port, env_vars=env_vars
app=app,
entrypoint_file=entrypoint_file,
start_server=start_server,
host=host,
port=port,
env_vars=env_vars,
secrets=secrets,
)
# a cloud dispatcher will return the result while local
# dispatchers will be running the app in the main process
Expand All @@ -78,6 +86,7 @@ class Runtime:
done: bool = False
backend: Optional[Union[str, Backend]] = "multiprocessing"
env_vars: Dict[str, str] = field(default_factory=dict)
secrets: Dict[str, str] = field(default_factory=dict)

def __post_init__(self):
if isinstance(self.backend, str):
Expand Down
26 changes: 26 additions & 0 deletions src/lightning_app/utilities/secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Dict, List

from lightning_app.utilities.network import LightningClient
from lightning_app.utilities.cloud import _get_project


def _names_to_ids(secret_names: List[str]) -> Dict[str, str]:
"""Returns the name/ID pair for each given Secret name.
Raises a `ValueError` if any of the given Secret names do not exist.
"""
lightning_client = LightningClient()

project = _get_project(lightning_client)
secrets = lightning_client.secret_service_list_secrets(project.project_id)

secret_names_to_ids: Dict[str, str] = {}
for secret in secrets.secrets:
if secret.name in secret_names:
secret_names_to_ids[secret.name] = secret.id

for secret_name in secret_names:
if secret_name not in secret_names_to_ids.keys():
raise ValueError(f"Secret with name '{secret.name}' not found")

return secret_names_to_ids
25 changes: 23 additions & 2 deletions tests/tests_app/cli/test_run_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def _lightning_app_run_and_logging(self, *args, **kwargs):

with caplog.at_level(logging.INFO):
with mock.patch("lightning_app.LightningApp._run", _lightning_app_run_and_logging):

runner = CliRunner()
result = runner.invoke(
run_app,
Expand Down Expand Up @@ -70,6 +69,7 @@ def test_lightning_run_cluster_without_cloud(monkeypatch):
open_ui=False,
no_cache=True,
env=("FOO=bar",),
secret=(),
)


Expand All @@ -80,7 +80,7 @@ def test_lightning_run_app_cloud(mock_dispatch: mock.MagicMock, open_ui, caplog,
"""This test validates the command has ran properly when --cloud argument is passed.
It tests it by checking if the click.launch is called with the right url if --open-ui was true and also checks the
call to `dispatch` for the right arguments
call to `dispatch` for the right arguments.
"""
monkeypatch.setattr("lightning_app.runners.cloud.logger", logging.getLogger())

Expand All @@ -95,6 +95,7 @@ def test_lightning_run_app_cloud(mock_dispatch: mock.MagicMock, open_ui, caplog,
open_ui=open_ui,
no_cache=True,
env=("FOO=bar",),
secret=("BAR=my-secret",),
)
# capture logs.
# TODO(yurij): refactor the test, check if the actual HTTP request is being sent and that the proper admin
Expand All @@ -108,5 +109,25 @@ def test_lightning_run_app_cloud(mock_dispatch: mock.MagicMock, open_ui, caplog,
name="",
no_cache=True,
env_vars={"FOO": "bar"},
secrets={"BAR": "my-secret"},
cluster_id="",
)


def test_lightning_run_app_secrets(monkeypatch):
"""Validates that running apps only supports the `--secrets` argument if the `--cloud` argument is passed."""
monkeypatch.setattr("lightning_app.runners.cloud.logger", logging.getLogger())

with pytest.raises(click.exceptions.ClickException):
_run_app(
file=os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"),
cloud=False,
cluster_id="test-cluster",
without_server=False,
name="",
blocking=False,
open_ui=False,
no_cache=True,
env=(),
secret=("FOO=my-secret"),
)
48 changes: 48 additions & 0 deletions tests/tests_app/utilities/test_secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Dict, List
from unittest import mock
from unittest.mock import MagicMock
import pytest
from lightning_cloud.openapi import V1ListMembershipsResponse, V1ListSecretsResponse, V1Membership, V1Secret

from lightning_app.utilities.secrets import _names_to_ids


@pytest.mark.parametrize(
"secret_names, secrets, expected, expected_exception",
[
([], [], {}, False),
(
["first-secret", "second-secret"],
[
V1Secret(name="first-secret", id="1234"),
],
{},
True,
),
(
["first-secret", "second-secret"],
[V1Secret(name="first-secret", id="1234"), V1Secret(name="second-secret", id="5678")],
{"first-secret": "1234", "second-secret": "5678"},
False,
),
],
)
@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.secret_service_list_secrets")
@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships")
def test_names_to_ids(
list_memberships: MagicMock,
list_secrets: MagicMock,
secret_names: List[str],
secrets: List[V1Secret],
expected: Dict[str, str],
expected_exception: bool,
):
list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
list_secrets.return_value = V1ListSecretsResponse(secrets=secrets)

if expected_exception:
with pytest.raises(ValueError):
_names_to_ids(secret_names)
else:
assert _names_to_ids(secret_names) == expected

0 comments on commit f9529e5

Please sign in to comment.