Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions python/gigl/common/services/vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class VertexAiJobConfig:
container_uri: str
command: list[str]
args: Optional[list[str]] = None
environment_variables: Optional[list[dict[str, str]]] = None
environment_variables: Optional[list[env_var.EnvVar]] = None
machine_type: str = "n1-standard-4"
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED"
accelerator_count: int = 0
Expand Down Expand Up @@ -167,12 +167,14 @@ def launch_job(self, job_config: VertexAiJobConfig) -> aiplatform.CustomJob:
datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
"leader_worker_internal_ip.txt",
)
env_vars = [
env_vars: list[env_var.EnvVar] = [
env_var.EnvVar(
name=LEADER_WORKER_INTERNAL_IP_FILE_PATH_ENV_KEY,
value=leader_worker_internal_ip_file_path.uri,
)
]
if job_config.environment_variables:
env_vars.extend(job_config.environment_variables)

container_spec = _create_container_spec(job_config, env_vars)

Expand Down Expand Up @@ -265,8 +267,16 @@ def launch_graph_store_job(
storage_disk_spec = _create_disk_spec(storage_pool_job_config)
compute_disk_spec = _create_disk_spec(compute_pool_job_config)

storage_container_spec = _create_container_spec(storage_pool_job_config)
compute_container_spec = _create_container_spec(compute_pool_job_config)
env_vars: list[env_var.EnvVar] = (
compute_pool_job_config.environment_variables or []
)

storage_container_spec = _create_container_spec(
storage_pool_job_config, env_vars
)
compute_container_spec = _create_container_spec(
compute_pool_job_config, env_vars
)

worker_pool_specs: list[Union[WorkerPoolSpec, dict]] = []

Expand Down
9 changes: 5 additions & 4 deletions python/gigl/src/inference/v2/glt_inferencer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from typing import Optional

from google.cloud.aiplatform_v1.types import accelerator_type
from google.cloud.aiplatform_v1.types import accelerator_type, env_var

from gigl.common import Uri, UriFactory
from gigl.common.constants import (
Expand Down Expand Up @@ -101,14 +101,15 @@ def __execute_VAI_inference(
command = inference_process_command.strip().split(" ")
logger.info(f"Running inference with command: {command}")
vai_job_name = f"gigl_infer_{applied_task_identifier}"
environment_variables: list[env_var.EnvVar] = [
env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
]
job_config = VertexAiJobConfig(
job_name=vai_job_name,
container_uri=container_uri,
command=command,
args=job_args,
environment_variables=[
{"name": "TF_CPP_MIN_LOG_LEVEL", "value": "3"},
],
environment_variables=environment_variables,
machine_type=inferencer_resource_config.machine_type,
accelerator_type=inferencer_resource_config.gpu_type.upper().replace(
"-", "_"
Expand Down
10 changes: 5 additions & 5 deletions python/gigl/src/training/v1/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import torch
from google.cloud.aiplatform_v1.types import accelerator_type
from google.cloud.aiplatform_v1.types import accelerator_type, env_var

from gigl.common import Uri, UriFactory
from gigl.common.constants import (
Expand Down Expand Up @@ -46,7 +46,9 @@ def run(
cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
container_uri = cpu_docker_uri if is_cpu_training else cuda_docker_uri

environment_variables: list[env_var.EnvVar] = [
env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
]
job_args = [
f"--job_name={applied_task_identifier}",
f"--task_config_uri={task_config_uri}",
Expand All @@ -58,9 +60,7 @@ def run(
container_uri=container_uri,
command=["python", "-m", "gigl.src.training.v1.lib.training_process"],
args=job_args,
environment_variables=[
{"name": "TF_CPP_MIN_LOG_LEVEL", "value": "3"},
],
environment_variables=environment_variables,
machine_type=trainer_config.machine_type,
accelerator_type=trainer_config.gpu_type.upper().replace("-", "_"),
accelerator_count=trainer_config.gpu_limit,
Expand Down
9 changes: 5 additions & 4 deletions python/gigl/src/training/v2/glt_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
from typing import Optional

from google.cloud.aiplatform_v1.types import accelerator_type
from google.cloud.aiplatform_v1.types import accelerator_type, env_var

from gigl.common import Uri, UriFactory
from gigl.common.constants import (
Expand Down Expand Up @@ -101,14 +101,15 @@ def __execute_VAI_training(
command = training_process_command.strip().split(" ")
logger.info(f"Running trainer with command: {command}")
vai_job_name = f"gigl_train_{applied_task_identifier}"
environment_variables: list[env_var.EnvVar] = [
env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
]
job_config = VertexAiJobConfig(
job_name=vai_job_name,
container_uri=container_uri,
command=command,
args=job_args,
environment_variables=[
{"name": "TF_CPP_MIN_LOG_LEVEL", "value": "3"},
],
environment_variables=environment_variables,
machine_type=trainer_resource_config.machine_type,
accelerator_type=trainer_resource_config.gpu_type.upper().replace("-", "_"),
accelerator_count=trainer_resource_config.gpu_limit,
Expand Down
12 changes: 10 additions & 2 deletions python/tests/integration/common/services/vertex_ai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid

import kfp
from google.cloud.aiplatform_v1.types import env_var
from parameterized import param, parameterized

from gigl.common import UriFactory
Expand Down Expand Up @@ -73,10 +74,17 @@ def test_launch_job(self):
command = ["python", "-c", "import logging; logging.info('Hello, World!')"]

job_config = VertexAiJobConfig(
job_name=job_name, container_uri=container_uri, command=command
job_name=job_name,
container_uri=container_uri,
command=command,
environment_variables=[env_var.EnvVar(name="FOO", value="BAR")],
)

self._vertex_ai_service.launch_job(job_config)
job = self._vertex_ai_service.launch_job(job_config)
self.assertIn(
env_var.EnvVar(name="FOO", value="BAR"),
job.job_spec.worker_pool_specs[0].container_spec.env,
)

@parameterized.expand(
[
Expand Down