diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index c6e89e19c8..c0d3a44791 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -15,6 +15,7 @@ from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig from sagemaker.train.configs import InputData, OutputDataConfig from sagemaker.train.defaults import TrainDefaults +from sagemaker.train.constants import get_sagemaker_hub_name logger = logging.getLogger(__name__) @@ -317,13 +318,15 @@ def _resolve_model_package_arn(model_package) -> Optional[str]: def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session, - hub_name: str = "SageMakerPublicHub") -> tuple: + hub_name: Optional[str] = None) -> tuple: """Get fine-tuning options and model ARN for given customization technique. Returns: tuple: (FineTuningOptions, model_arn, is_gated_model) """ - + if hub_name is None: + hub_name = get_sagemaker_hub_name() + try: hub_content = _get_hub_content_metadata( diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py index 2ce6ea7198..f39a7a165e 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py @@ -14,6 +14,7 @@ from enum import Enum import re from sagemaker.train.base_trainer import BaseTrainer +from sagemaker.train.constants import get_sagemaker_hub_name from sagemaker.core.utils.utils import Unassigned @@ -52,8 +53,6 @@ class _ModelResolver: and fine-tuned ModelPackage objects/ARNs. """ - DEFAULT_HUB_NAME = "SageMakerPublicHub" - def __init__(self, sagemaker_session=None): """ Initialize the resolver. @@ -89,7 +88,7 @@ def resolve_model_info( if base_model.startswith("arn:aws:sagemaker:") and ":model-package/" in base_model: return self._resolve_model_package_arn(base_model) else: - return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME) + return self._resolve_jumpstart_model(base_model, hub_name or get_sagemaker_hub_name()) # Handle BaseTrainer type elif isinstance(base_model, BaseTrainer): if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job, diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py index 68b0f6c474..212a57e642 100644 --- a/sagemaker-train/src/sagemaker/train/constants.py +++ b/sagemaker-train/src/sagemaker/train/constants.py @@ -40,7 +40,13 @@ + f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}", ] -HUB_NAME = "SageMakerPublicHub" +def get_sagemaker_hub_name() -> str: + """Return the SageMaker Hub name, honoring SAGEMAKER_HUB_NAME env var override. + + Resolved at call time so tests and dev workflows can override the hub + without re-importing this module. Defaults to ``"SageMakerPublicHub"``. + """ + return os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub") # Allowed reward model IDs for RLAIF trainer with region restrictions _ALLOWED_REWARD_MODEL_IDS = { diff --git a/sagemaker-train/src/sagemaker/train/dpo_trainer.py b/sagemaker-train/src/sagemaker/train/dpo_trainer.py index 75a450c3c8..b9f7449354 100644 --- a/sagemaker-train/src/sagemaker/train/dpo_trainer.py +++ b/sagemaker-train/src/sagemaker/train/dpo_trainer.py @@ -23,7 +23,7 @@ ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -244,7 +244,7 @@ def train(self, ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py index d6bad422c6..4e4e522b3a 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/benchmark_evaluator.py @@ -21,6 +21,7 @@ from .execution import EvaluationPipelineExecution from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature +from sagemaker.train.constants import get_sagemaker_hub_name _logger = logging.getLogger(__name__) @@ -466,7 +467,7 @@ def hyperparameters(self): override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=get_sagemaker_hub_name(), evaluation_type=evaluation_type, region=region, session=boto_session diff --git a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py index 78d297006c..9c768c3891 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py @@ -16,6 +16,7 @@ from .execution import EvaluationPipelineExecution from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature +from sagemaker.train.constants import get_sagemaker_hub_name _logger = logging.getLogger(__name__) @@ -240,7 +241,7 @@ def hyperparameters(self): override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=get_sagemaker_hub_name(), evaluation_type="DeterministicEvaluation", region=region, session=boto_session @@ -365,7 +366,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict: _logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}") override_params = _get_evaluation_override_params( hub_content_name=hub_content_name, - hub_name="SageMakerPublicHub", + hub_name=get_sagemaker_hub_name(), evaluation_type="DeterministicEvaluation", region=region, session=session diff --git a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py index db19a5e1d9..d4cbc7cf8f 100644 --- a/sagemaker-train/src/sagemaker/train/rlaif_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlaif_trainer.py @@ -27,7 +27,7 @@ ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME, _ALLOWED_REWARD_MODEL_IDS +from sagemaker.train.constants import get_sagemaker_hub_name, _ALLOWED_REWARD_MODEL_IDS logger = logging.getLogger(__name__) @@ -263,7 +263,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) # Build TrainingJob.create() arguments create_args = { @@ -358,7 +358,7 @@ def _process_non_builtin_reward_prompt(self): sagemaker_session=self.sagemaker_session ) hub_content = _get_hub_content_metadata( - hub_name=HUB_NAME, + hub_name=get_sagemaker_hub_name(), hub_content_type="JsonDoc", hub_content_name=self.reward_prompt, session=session.boto_session, diff --git a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py index 8a11cfb0d8..93a5105f8e 100644 --- a/sagemaker-train/src/sagemaker/train/rlvr_trainer.py +++ b/sagemaker-train/src/sagemaker/train/rlvr_trainer.py @@ -25,7 +25,7 @@ ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name logger = logging.getLogger(__name__) @@ -251,7 +251,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/src/sagemaker/train/sft_trainer.py b/sagemaker-train/src/sagemaker/train/sft_trainer.py index 80465c061d..cc67469406 100644 --- a/sagemaker-train/src/sagemaker/train/sft_trainer.py +++ b/sagemaker-train/src/sagemaker/train/sft_trainer.py @@ -24,7 +24,7 @@ ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature -from sagemaker.train.constants import HUB_NAME +from sagemaker.train.constants import get_sagemaker_hub_name logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -245,7 +245,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati ) vpc_config = self.networking if self.networking else None - tags = _get_studio_tags(self._model_name, HUB_NAME) + tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name()) # Build TrainingJob.create() arguments create_args = { diff --git a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py index 31a827e3f0..7e145dbb21 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py @@ -63,7 +63,6 @@ def test_resolver_initialization(self): """Test ModelResolver initialization.""" resolver = _ModelResolver() assert resolver.sagemaker_session is None - assert resolver.DEFAULT_HUB_NAME == "SageMakerPublicHub" def test_resolver_with_session(self): """Test ModelResolver with custom session.""" diff --git a/sagemaker-train/tests/unit/train/test_constants.py b/sagemaker-train/tests/unit/train/test_constants.py new file mode 100644 index 0000000000..4cb3fc6dec --- /dev/null +++ b/sagemaker-train/tests/unit/train/test_constants.py @@ -0,0 +1,20 @@ +"""Tests for SAGEMAKER_HUB_NAME env-var override via get_sagemaker_hub_name.""" +from __future__ import absolute_import + +import os +from unittest.mock import patch + +from sagemaker.train.constants import get_sagemaker_hub_name + + +def test_get_sagemaker_hub_name_defaults_to_public_hub(): + """When SAGEMAKER_HUB_NAME is unset, returns SageMakerPublicHub.""" + env = {k: v for k, v in os.environ.items() if k != "SAGEMAKER_HUB_NAME"} + with patch.dict(os.environ, env, clear=True): + assert get_sagemaker_hub_name() == "SageMakerPublicHub" + + +def test_get_sagemaker_hub_name_overridden_by_env_var(): + """When SAGEMAKER_HUB_NAME is set, returns the override value.""" + with patch.dict(os.environ, {"SAGEMAKER_HUB_NAME": "MyPrivateHub"}): + assert get_sagemaker_hub_name() == "MyPrivateHub"