From 62ca624c25aec7c9ffaf427dc45fedc3eae8c4da Mon Sep 17 00:00:00 2001 From: Molly He Date: Tue, 21 Apr 2026 11:43:53 -0700 Subject: [PATCH 1/2] allow SAGEMAKER_HUB_NAME env var override for HUB_NAME constant --- .../train/common_utils/finetune_utils.py | 3 ++- .../train/common_utils/model_resolution.py | 3 ++- .../src/sagemaker/train/constants.py | 2 +- .../train/evaluate/benchmark_evaluator.py | 3 ++- .../train/evaluate/custom_scorer_evaluator.py | 5 ++-- .../tests/unit/train/test_constants.py | 26 +++++++++++++++++++ 6 files changed, 36 insertions(+), 6 deletions(-) create mode 100644 sagemaker-train/tests/unit/train/test_constants.py diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index c6e89e19c8..d6eafdb214 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -15,6 +15,7 @@ from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig from sagemaker.train.configs import InputData, OutputDataConfig from sagemaker.train.defaults import TrainDefaults +from sagemaker.train.constants import HUB_NAME logger = logging.getLogger(__name__) @@ -317,7 +318,7 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session, - hub_name: str = "SageMakerPublicHub") -> tuple: + hub_name: str = HUB_NAME) -> tuple: """Get fine-tuning options and model ARN for given customization technique. Returns: diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py index 2ce6ea7198..005d5d02ca 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py @@ -14,6 +14,7 @@ from enum import Enum import re from sagemaker.train.base_trainer import BaseTrainer +from sagemaker.train.constants import HUB_NAME from sagemaker.core.utils.utils import Unassigned @@ -52,7 +53,7 @@ class _ModelResolver: and fine-tuned ModelPackage objects/ARNs. """ - DEFAULT_HUB_NAME = "SageMakerPublicHub" + DEFAULT_HUB_NAME = HUB_NAME def __init__(self, sagemaker_session=None): """ diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py index 68b0f6c474..dfe4e1f423 100644 --- a/sagemaker-train/src/sagemaker/train/constants.py +++ b/sagemaker-train/src/sagemaker/train/constants.py @@ -40,7 +40,7 @@ + f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}", ] -HUB_NAME = "SageMakerPublicHub" +HUB_NAME = os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub") # Allowed reward model IDs for RLAIF trainer with region restrictions _ALLOWED_REWARD_MODEL_IDS = { diff --git a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py index d6bad422c6..d12dec9c59 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py @@ -21,6 +21,7 @@ from .execution import EvaluationPipelineExecution from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature +from sagemaker.train.constants import HUB_NAME _logger = logging.getLogger(__name__) @@ -466,7 +467,7 @@ def hyperparameters(self): override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=HUB_NAME, evaluation_type=evaluation_type, region=region, session=boto_session diff --git a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py index 78d297006c..9e9e59fe85 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py @@ -16,6 +16,7 @@ from .execution import EvaluationPipelineExecution from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature +from sagemaker.train.constants import HUB_NAME _logger = logging.getLogger(__name__) @@ -240,7 +241,7 @@ def hyperparameters(self): override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=HUB_NAME, evaluation_type="DeterministicEvaluation", region=region, session=boto_session @@ -365,7 +366,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict: _logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}") override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=HUB_NAME, evaluation_type="DeterministicEvaluation", region=region, session=session diff --git a/sagemaker-train/tests/unit/train/test_constants.py b/sagemaker-train/tests/unit/train/test_constants.py new file mode 100644 index 0000000000..f71d07b458 --- /dev/null +++ b/sagemaker-train/tests/unit/train/test_constants.py @@ -0,0 +1,26 @@ +"""Tests for SAGEMAKER_HUB_NAME env-var override of the HUB_NAME constant.""" +from __future__ import absolute_import + +import importlib +import os +from unittest.mock import patch + + +def _reload_hub_name(): + """Reload the constants module under the current env and return HUB_NAME.""" + from sagemaker.train import constants + importlib.reload(constants) + return constants.HUB_NAME + + +def test_hub_name_defaults_to_public_hub(): + """When SAGEMAKER_HUB_NAME is unset, HUB_NAME is SageMakerPublicHub.""" + env = {k: v for k, v in os.environ.items() if k != "SAGEMAKER_HUB_NAME"} + with patch.dict(os.environ, env, clear=True): + assert _reload_hub_name() == "SageMakerPublicHub" + + +def test_hub_name_overridden_by_env_var(): + """When SAGEMAKER_HUB_NAME is set, HUB_NAME reflects the override.""" + with patch.dict(os.environ, {"SAGEMAKER_HUB_NAME": "MyPrivateHub"}): + assert _reload_hub_name() == "MyPrivateHub" From c5c44d886e16807566ab7c8ba9c5d8fd6f2e64be Mon Sep 17 00:00:00 2001 From: Molly He Date: Tue, 21 Apr 2026 12:07:03 -0700 Subject: [PATCH 2/2] Resolve env var runtime --- .../train/common_utils/finetune_utils.py | 8 ++++--- .../train/common_utils/model_resolution.py | 6 ++--- .../src/sagemaker/train/constants.py | 8 ++++++- .../src/sagemaker/train/dpo_trainer.py | 4 ++-- .../train/evaluate/benchmark_evaluator.py | 4 ++-- .../train/evaluate/custom_scorer_evaluator.py | 6 ++--- .../src/sagemaker/train/rlaif_trainer.py | 6 ++--- .../src/sagemaker/train/rlvr_trainer.py | 4 ++-- .../src/sagemaker/train/sft_trainer.py | 4 ++-- .../common_utils/test_model_resolution.py | 1 - .../tests/unit/train/test_constants.py | 22 +++++++------------ 11 files changed, 36 insertions(+), 37 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index d6eafdb214..c0d3a44791 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -15,7 +15,7 @@ from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig from sagemaker.train.configs import InputData, OutputDataConfig from sagemaker.train.defaults import TrainDefaults -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name logger = logging.getLogger(__name__) @@ -318,13 +318,15 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session, - hub_name: str = HUB_NAME) -> tuple: + hub_name: Optional[str] = None) -> tuple: """Get fine-tuning options and model ARN for given customization technique. Returns: tuple: (FineTuningOptions, model_arn, is_gated_model) """ - + if hub_name is None: + hub_name = get_sagemaker_hub_name() + try: hub_content = _get_hub_content_metadata( diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py index 005d5d02ca..f39a7a165e 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py @@ -14,7 +14,7 @@ from enum import Enum import re from sagemaker.train.base_trainer import BaseTrainer -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name from sagemaker.core.utils.utils import Unassigned @@ -53,8 +53,6 @@ class _ModelResolver: and fine-tuned ModelPackage objects/ARNs. """ - DEFAULT_HUB_NAME = HUB_NAME - def __init__(self, sagemaker_session=None): """ Initialize the resolver. @@ -90,7 +88,7 @@ def resolve_model_info( if base_model.startswith("arn:aws:sagemaker:") and ":model-package/" in base_model: return self._resolve_model_package_arn(base_model) else: - return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME) + return self._resolve_jumpstart_model(base_model, hub_name or get_sagemaker_hub_name()) # Handle BaseTrainer type elif isinstance(base_model, BaseTrainer): if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job, diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py index dfe4e1f423..212a57e642 100644 --- a/sagemaker-train/src/sagemaker/train/constants.py +++ b/sagemaker-train/src/sagemaker/train/constants.py @@ -40,7 +40,13 @@ + f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}", ] -HUB_NAME = os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub") +def get_sagemaker_hub_name() -> str: + """Return the SageMaker Hub name, honoring SAGEMAKER_HUB_NAME env var override. + + Resolved at call time so tests and dev workflows can override the hub + without re-importing this module. Defaults to ``"SageMakerPublicHub"``. + """ + return os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub") # Allowed reward model IDs for RLAIF trainer with region restrictions _ALLOWED_REWARD_MODEL_IDS = { diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index 75a450c3c8..b9f7449354 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -23,7 +23,7 @@ ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -244,7 +244,7 @@ def train(self, ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py index d12dec9c59..4e4e522b3a 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py @@ -21,7 +21,7 @@ from .execution import EvaluationPipelineExecution from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name _logger = logging.getLogger(__name__) @@ -467,7 +467,7 @@ def hyperparameters(self): override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name=HUB_NAME, + hub_name=get_sagemaker_hub_name(), evaluation_type=evaluation_type, region=region, session=boto_session diff --git a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py index 9e9e59fe85..9c768c3891 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py @@ -16,7 +16,7 @@ from .execution import EvaluationPipelineExecution from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name _logger = logging.getLogger(__name__) @@ -241,7 +241,7 @@ def hyperparameters(self): override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name=HUB_NAME, + hub_name=get_sagemaker_hub_name(), evaluation_type="DeterministicEvaluation", region=region, session=boto_session @@ -366,7 +366,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict: _logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}") override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name=HUB_NAME, + hub_name=get_sagemaker_hub_name(), evaluation_type="DeterministicEvaluation", region=region, session=session diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index db19a5e1d9..d4cbc7cf8f 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -27,7 +27,7 @@ ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME, _ALLOWED_REWARD_MODEL_IDS +from sagemaker.train.constants import get_sagemaker_hub_name, _ALLOWED_REWARD_MODEL_IDS logger = logging.getLogger(__name__) @@ -263,7 +263,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) # Build TrainingJob.create() arguments create_args = { @@ -358,7 +358,7 @@ def _process_non_builtin_reward_prompt(self): sagemaker_session=self.sagemaker_session ) hub_content = _get_hub_content_metadata( - hub_name=HUB_NAME, + hub_name=get_sagemaker_hub_name(), hub_content_type="JsonDoc", hub_content_name=self.reward_prompt, session=session.boto_session, diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 8a11cfb0d8..93a5105f8e 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -25,7 +25,7 @@ ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name logger = logging.getLogger(__name__) @@ -251,7 +251,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 80465c061d..cc67469406 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -24,7 +24,7 @@ ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -245,7 +245,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py index 31a827e3f0..7e145dbb21 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py @@ -63,7 +63,6 @@ def test_resolver_initialization(self): """Test ModelResolver initialization.""" resolver = _ModelResolver() assert resolver.sagemaker_session is None - assert resolver.DEFAULT_HUB_NAME == "SageMakerPublicHub" def test_resolver_with_session(self): """Test ModelResolver with custom session.""" diff --git a/sagemaker-train/tests/unit/train/test_constants.py b/sagemaker-train/tests/unit/train/test_constants.py index f71d07b458..4cb3fc6dec 100644 --- a/sagemaker-train/tests/unit/train/test_constants.py +++ b/sagemaker-train/tests/unit/train/test_constants.py @@ -1,26 +1,20 @@ -"""Tests for SAGEMAKER_HUB_NAME env-var override of the HUB_NAME constant.""" +"""Tests for SAGEMAKER_HUB_NAME env-var override via get_sagemaker_hub_name.""" from __future__ import absolute_import -import importlib import os from unittest.mock import patch +from sagemaker.train.constants import get_sagemaker_hub_name -def _reload_hub_name(): - """Reload the constants module under the current env and return HUB_NAME.""" - from sagemaker.train import constants - importlib.reload(constants) - return constants.HUB_NAME - -def test_hub_name_defaults_to_public_hub(): - """When SAGEMAKER_HUB_NAME is unset, HUB_NAME is SageMakerPublicHub.""" +def test_get_sagemaker_hub_name_defaults_to_public_hub(): + """When SAGEMAKER_HUB_NAME is unset, returns SageMakerPublicHub.""" env = {k: v for k, v in os.environ.items() if k != "SAGEMAKER_HUB_NAME"} with patch.dict(os.environ, env, clear=True): - assert _reload_hub_name() == "SageMakerPublicHub" + assert get_sagemaker_hub_name() == "SageMakerPublicHub" -def test_hub_name_overridden_by_env_var(): - """When SAGEMAKER_HUB_NAME is set, HUB_NAME reflects the override.""" +def test_get_sagemaker_hub_name_overridden_by_env_var(): + """When SAGEMAKER_HUB_NAME is set, returns the override value.""" with patch.dict(os.environ, {"SAGEMAKER_HUB_NAME": "MyPrivateHub"}): - assert _reload_hub_name() == "MyPrivateHub" + assert get_sagemaker_hub_name() == "MyPrivateHub"