Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions gigl/env/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Environment-variable keys exported by ``launch_custom``.

These keys are set on the subprocess env (never on the parent
``os.environ``) by ``gigl.src.common.custom_launcher.launch_custom`` so
that receiving CLIs can ``os.environ.get(...)`` their runtime context.
"""

from typing import Final

# Env key carrying the ``applied_task_identifier`` string passed to ``launch_custom``.
GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY: Final[str] = "GIGL_APPLIED_TASK_IDENTIFIER"
# Env key carrying the task config URI, exported as ``str(task_config_uri)``.
GIGL_TASK_CONFIG_URI_ENV_KEY: Final[str] = "GIGL_TASK_CONFIG_URI"
# Env key carrying the resource config URI, exported as ``str(resource_config_uri)``.
GIGL_RESOURCE_CONFIG_URI_ENV_KEY: Final[str] = "GIGL_RESOURCE_CONFIG_URI"
Comment thread
kmontemayor2-sc marked this conversation as resolved.
# Env keys carrying the CPU / CUDA docker image URIs. ``launch_custom`` exports
# the caller-supplied value, or the matching ``DEFAULT_GIGL_RELEASE_SRC_IMAGE_*``
# constant when the caller passes ``None`` (or any other falsy value).
GIGL_CPU_DOCKER_URI_ENV_KEY: Final[str] = "GIGL_CPU_DOCKER_URI"
GIGL_CUDA_DOCKER_URI_ENV_KEY: Final[str] = "GIGL_CUDA_DOCKER_URI"
Comment thread
kmontemayor2-sc marked this conversation as resolved.
# Env key carrying the launched component, exported as ``component.name``
# (e.g. ``"Trainer"``).
GIGL_COMPONENT_ENV_KEY: Final[str] = "GIGL_COMPONENT"
76 changes: 55 additions & 21 deletions gigl/src/common/custom_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,34 @@
dynamic content (runtime URIs, image refs, etc.) is the caller's
responsibility — typically resolved at YAML-load time before the
proto reaches this module.

The dispatcher exports its context args as ``GIGL_*`` environment
variables on the subprocess env (see ``gigl.env.constants``) so
receiving CLIs can ``os.environ.get(...)`` whatever runtime context
they need. The parent process's ``os.environ`` is never mutated; the
``GIGL_*`` keys live only in the per-call env passed to
``subprocess.run``.
"""

import os
import shlex
import subprocess
from collections.abc import Mapping
from typing import Optional

from gigl.common import Uri
from gigl.common.constants import (
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA,
)
from gigl.common.logger import Logger
from gigl.env.constants import (
GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY,
GIGL_COMPONENT_ENV_KEY,
GIGL_CPU_DOCKER_URI_ENV_KEY,
GIGL_CUDA_DOCKER_URI_ENV_KEY,
GIGL_RESOURCE_CONFIG_URI_ENV_KEY,
GIGL_TASK_CONFIG_URI_ENV_KEY,
)
from gigl.src.common.constants.components import GiGLComponents
from snapchat.research.gbml.gigl_resource_config_pb2 import CustomLauncherConfig

Expand All @@ -36,8 +55,6 @@ def launch_custom(
applied_task_identifier: str,
task_config_uri: Uri,
resource_config_uri: Uri,
process_command: str,
process_runtime_args: Mapping[str, str],
cpu_docker_uri: Optional[str],
cuda_docker_uri: Optional[str],
component: GiGLComponents,
Expand All @@ -46,36 +63,41 @@ def launch_custom(

Composes a shell line as ``command`` followed by each ``args[]``
element passed through ``shlex.quote``, then invokes
``subprocess.run(shell_line, shell=True, check=True)``.
``subprocess.run(shell_line, shell=True, check=True, env=env)``.

The dispatcher takes ``command`` and ``args[]`` verbatim — no
template substitution of any kind. Any placeholder text in those
fields reaches ``subprocess.run`` literally; consumers that want
substitution should resolve it at YAML-load time before the proto
reaches this module.

``applied_task_identifier``, ``task_config_uri``,
``resource_config_uri``, ``process_command``,
``process_runtime_args``, ``cpu_docker_uri``, and ``cuda_docker_uri``
are accepted for API symmetry with the GLT-side Vertex AI launchers
but are intentionally not plumbed into the subprocess — the
receiving CLI is expected to source whatever context it needs from
the resource config it gets handed (or from env vars inherited from
the parent process).
The subprocess env is built per-call from ``os.environ.copy()`` plus
the ``GIGL_*`` keys defined in :mod:`gigl.env.constants`. The
parent process's ``os.environ`` is never mutated. When ``None`` is
passed for ``cpu_docker_uri`` / ``cuda_docker_uri``, the
corresponding env var falls back to
:data:`gigl.common.constants.DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU` /
:data:`gigl.common.constants.DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA`
so receivers always observe a usable image URI.

Args:
custom_launcher_config: Proto whose ``command`` is the shell
snippet to execute and whose ``args`` are positional
arguments appended verbatim.
applied_task_identifier: Accepted for back-compat; ignored.
task_config_uri: Accepted for back-compat; ignored.
resource_config_uri: Accepted for back-compat; ignored.
process_command: Accepted for back-compat; ignored.
process_runtime_args: Accepted for back-compat; ignored.
cpu_docker_uri: Accepted for back-compat; ignored.
cuda_docker_uri: Accepted for back-compat; ignored.
applied_task_identifier: Exported to the subprocess as
``GIGL_APPLIED_TASK_IDENTIFIER``.
task_config_uri: Exported to the subprocess as
``GIGL_TASK_CONFIG_URI`` (stringified).
resource_config_uri: Exported to the subprocess as
``GIGL_RESOURCE_CONFIG_URI`` (stringified).
cpu_docker_uri: Exported as ``GIGL_CPU_DOCKER_URI``. Falls back
to ``DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU`` when ``None``.
cuda_docker_uri: Exported as ``GIGL_CUDA_DOCKER_URI``. Falls
back to ``DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA`` when
``None``.
component: Which GiGL component is being launched. Must be in
``_LAUNCHABLE_COMPONENTS``.
``_LAUNCHABLE_COMPONENTS``. Exported as ``GIGL_COMPONENT``
using ``component.name`` (e.g. ``"Trainer"``).

Raises:
ValueError: If ``component`` is not Trainer or Inferencer, or if
Expand All @@ -91,6 +113,18 @@ def launch_custom(
command: str = custom_launcher_config.command
args: list[str] = list(custom_launcher_config.args)

env: dict[str, str] = os.environ.copy()
env[GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY] = applied_task_identifier
env[GIGL_TASK_CONFIG_URI_ENV_KEY] = str(task_config_uri)
env[GIGL_RESOURCE_CONFIG_URI_ENV_KEY] = str(resource_config_uri)
env[GIGL_COMPONENT_ENV_KEY] = component.name
env[GIGL_CPU_DOCKER_URI_ENV_KEY] = (
cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
)
env[GIGL_CUDA_DOCKER_URI_ENV_KEY] = (
cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
)

shell_line = " ".join([command, *(shlex.quote(a) for a in args)])
logger.info(f"Launching {component.name} via subprocess: {shell_line!r}")
subprocess.run(shell_line, shell=True, check=True)
subprocess.run(shell_line, shell=True, check=True, env=env)
Comment thread
kmontemayor2-sc marked this conversation as resolved.
107 changes: 99 additions & 8 deletions tests/unit/src/common/custom_launcher_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
"""Unit tests for ``gigl.src.common.custom_launcher``."""

import os
from unittest.mock import MagicMock, patch

from absl.testing import absltest

from gigl.common import Uri
from gigl.common.constants import (
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA,
)
from gigl.env.constants import (
GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY,
GIGL_COMPONENT_ENV_KEY,
GIGL_CPU_DOCKER_URI_ENV_KEY,
GIGL_CUDA_DOCKER_URI_ENV_KEY,
GIGL_RESOURCE_CONFIG_URI_ENV_KEY,
GIGL_TASK_CONFIG_URI_ENV_KEY,
)
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.custom_launcher import launch_custom
from snapchat.research.gbml import gigl_resource_config_pb2
Expand Down Expand Up @@ -43,8 +56,6 @@ def test_dispatches_subprocess_with_literal_command_and_args(
applied_task_identifier="job-42",
task_config_uri=Uri("gs://bucket/task.yaml"),
resource_config_uri=Uri("gs://bucket/resource.yaml"),
process_command="ignored",
process_runtime_args={"ignored": "v"},
cpu_docker_uri="gcr.io/p/cpu:tag",
cuda_docker_uri="gcr.io/p/cuda:tag",
component=GiGLComponents.Trainer,
Expand All @@ -68,8 +79,6 @@ def test_empty_command_raises_value_error(self, mock_run: MagicMock) -> None:
applied_task_identifier="job",
task_config_uri=Uri("gs://bucket/task.yaml"),
resource_config_uri=Uri("gs://bucket/resource.yaml"),
process_command="",
process_runtime_args={},
cpu_docker_uri=None,
cuda_docker_uri=None,
component=GiGLComponents.Trainer,
Expand All @@ -85,8 +94,6 @@ def test_invalid_component_raises_value_error(self, mock_run: MagicMock) -> None
applied_task_identifier="job",
task_config_uri=Uri("gs://bucket/task.yaml"),
resource_config_uri=Uri("gs://bucket/resource.yaml"),
process_command="echo 'hello, world!",
process_runtime_args={},
cpu_docker_uri=None,
cuda_docker_uri=None,
component=GiGLComponents.DataPreprocessor,
Expand All @@ -101,8 +108,6 @@ def test_args_with_spaces_are_shell_quoted(self, mock_run: MagicMock) -> None:
applied_task_identifier="job",
task_config_uri=Uri("gs://bucket/task.yaml"),
resource_config_uri=Uri("gs://bucket/resource.yaml"),
process_command="",
process_runtime_args={},
cpu_docker_uri=None,
cuda_docker_uri=None,
component=GiGLComponents.Trainer,
Expand All @@ -113,6 +118,92 @@ def test_args_with_spaces_are_shell_quoted(self, mock_run: MagicMock) -> None:
self.assertIn("'a b c'", shell_line)
self.assertIn("'--name=with space'", shell_line)

@patch("gigl.src.common.custom_launcher.subprocess.run")
def test_dispatch_sets_gigl_env_vars(self, mock_run: MagicMock) -> None:
    """Every ``GIGL_*`` key appears in the ``env`` kwarg given to ``subprocess.run``.

    ``subprocess.run`` is patched, so nothing is executed; the test only
    inspects the keyword arguments of the recorded call.
    """
    launcher_config = self._build_config(command="python -m my.cli")
    launch_custom(
        custom_launcher_config=launcher_config,
        applied_task_identifier="job-42",
        task_config_uri=Uri("gs://bucket/task.yaml"),
        resource_config_uri=Uri("gs://bucket/resource.yaml"),
        cpu_docker_uri="gcr.io/p/cpu:tag",
        cuda_docker_uri="gcr.io/p/cuda:tag",
        component=GiGLComponents.Trainer,
    )
    subprocess_env = mock_run.call_args.kwargs["env"]

    # Ordered (key, expected value) table; the component is exported via
    # ``.name`` (the enum member identifier), hence the bare "Trainer".
    expected_exports = (
        (GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY, "job-42"),
        (GIGL_TASK_CONFIG_URI_ENV_KEY, "gs://bucket/task.yaml"),
        (GIGL_RESOURCE_CONFIG_URI_ENV_KEY, "gs://bucket/resource.yaml"),
        (GIGL_CPU_DOCKER_URI_ENV_KEY, "gcr.io/p/cpu:tag"),
        (GIGL_CUDA_DOCKER_URI_ENV_KEY, "gcr.io/p/cuda:tag"),
        (GIGL_COMPONENT_ENV_KEY, "Trainer"),
    )
    for env_key, expected_value in expected_exports:
        self.assertEqual(subprocess_env[env_key], expected_value)

@patch("gigl.src.common.custom_launcher.subprocess.run")
def test_dispatch_defaults_optional_uris_to_release_images(
    self, mock_run: MagicMock
) -> None:
    """``None`` docker URIs are exported as the default release image constants.

    ``subprocess.run`` is patched; only the ``env`` kwarg of the recorded
    call is examined.
    """
    launch_custom(
        custom_launcher_config=self._build_config(command="echo"),
        applied_task_identifier="job",
        task_config_uri=Uri("gs://bucket/task.yaml"),
        resource_config_uri=Uri("gs://bucket/resource.yaml"),
        cpu_docker_uri=None,
        cuda_docker_uri=None,
        component=GiGLComponents.Inferencer,
    )
    subprocess_env = mock_run.call_args.kwargs["env"]

    # With no caller-supplied image URI, each docker env var must carry the
    # corresponding DEFAULT_GIGL_RELEASE_SRC_IMAGE_* value so receivers
    # always see a usable URI.
    expected_exports = (
        (GIGL_CPU_DOCKER_URI_ENV_KEY, DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU),
        (GIGL_CUDA_DOCKER_URI_ENV_KEY, DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA),
        (GIGL_COMPONENT_ENV_KEY, "Inferencer"),
    )
    for env_key, expected_value in expected_exports:
        self.assertEqual(subprocess_env[env_key], expected_value)

@patch("gigl.src.common.custom_launcher.subprocess.run")
def test_dispatch_isolates_subprocess_env_from_parent(
    self, mock_run: MagicMock
) -> None:
    """``launch_custom`` must not mutate the parent process's ``os.environ``.

    Verifies three things against the patched ``subprocess.run`` call:

    * ``os.environ`` is identical before and after the dispatch;
    * none of the ``GIGL_*`` keys leak into ``os.environ``;
    * entries already present in the parent env (a sentinel here) are
      forwarded to the subprocess via the ``env`` kwarg.
    """
    sentinel_key = "GIGL_TEST_PARENT_ENV_SENTINEL"
    sentinel_value = "preserved-value"
    gigl_env_keys = (
        GIGL_APPLIED_TASK_IDENTIFIER_ENV_KEY,
        GIGL_TASK_CONFIG_URI_ENV_KEY,
        GIGL_RESOURCE_CONFIG_URI_ENV_KEY,
        GIGL_CPU_DOCKER_URI_ENV_KEY,
        GIGL_CUDA_DOCKER_URI_ENV_KEY,
        GIGL_COMPONENT_ENV_KEY,
    )
    # patch.dict snapshots os.environ on entry and restores it on exit, so
    # the sentinel is removed (or a pre-existing value under the same key
    # is put back) and the keys popped below are reinstated even when an
    # assertion fails.
    with patch.dict(os.environ, {sentinel_key: sentinel_value}):
        # The test process itself may have been started with GIGL_* vars
        # set (e.g. when run from a launch_custom subprocess). Drop them
        # for the duration of the test so the leak check below reflects
        # what launch_custom did, not the ambient environment.
        for key in gigl_env_keys:
            os.environ.pop(key, None)
        snapshot = dict(os.environ)

        config = self._build_config(command="echo")
        launch_custom(
            custom_launcher_config=config,
            applied_task_identifier="job",
            task_config_uri=Uri("gs://bucket/task.yaml"),
            resource_config_uri=Uri("gs://bucket/resource.yaml"),
            cpu_docker_uri="gcr.io/p/cpu:tag",
            cuda_docker_uri="gcr.io/p/cuda:tag",
            component=GiGLComponents.Trainer,
        )

        # Parent os.environ is untouched; none of the GIGL_* keys leak
        # into it.
        self.assertEqual(dict(os.environ), snapshot)
        for key in gigl_env_keys:
            self.assertNotIn(key, os.environ)

        # Inherited parent env entries reach the subprocess env.
        env = mock_run.call_args.kwargs["env"]
        self.assertEqual(env.get(sentinel_key), sentinel_value)


if __name__ == "__main__":
absltest.main()