Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sagemaker.core.shapes import ServerlessJobConfig, Channel, DataSource, ModelPackageConfig, MlflowConfig
from sagemaker.train.configs import InputData, OutputDataConfig
from sagemaker.train.defaults import TrainDefaults
from sagemaker.train.constants import get_sagemaker_hub_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -317,13 +318,15 @@ def _resolve_model_package_arn(model_package) -> Optional[str]:


def _get_fine_tuning_options_and_model_arn(model_name: str, customization_technique: str, training_type, sagemaker_session,
hub_name: str = "SageMakerPublicHub") -> tuple:
hub_name: Optional[str] = None) -> tuple:
"""Get fine-tuning options and model ARN for given customization technique.

Returns:
tuple: (FineTuningOptions, model_arn, is_gated_model)
"""

if hub_name is None:
hub_name = get_sagemaker_hub_name()

try:

hub_content = _get_hub_content_metadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from enum import Enum
import re
from sagemaker.train.base_trainer import BaseTrainer
from sagemaker.train.constants import get_sagemaker_hub_name
from sagemaker.core.utils.utils import Unassigned


Expand Down Expand Up @@ -52,8 +53,6 @@ class _ModelResolver:
and fine-tuned ModelPackage objects/ARNs.
"""

DEFAULT_HUB_NAME = "SageMakerPublicHub"

def __init__(self, sagemaker_session=None):
"""
Initialize the resolver.
Expand Down Expand Up @@ -89,7 +88,7 @@ def resolve_model_info(
if base_model.startswith("arn:aws:sagemaker:") and ":model-package/" in base_model:
return self._resolve_model_package_arn(base_model)
else:
return self._resolve_jumpstart_model(base_model, hub_name or self.DEFAULT_HUB_NAME)
return self._resolve_jumpstart_model(base_model, hub_name or get_sagemaker_hub_name())
# Handle BaseTrainer type
elif isinstance(base_model, BaseTrainer):
if hasattr(base_model, '_latest_training_job') and hasattr(base_model._latest_training_job,
Expand Down
8 changes: 7 additions & 1 deletion sagemaker-train/src/sagemaker/train/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@
+ f"&& {SM_DRIVERS_CONTAINER_PATH}/{TRAIN_SCRIPT}",
]

HUB_NAME = "SageMakerPublicHub"
def get_sagemaker_hub_name() -> str:
"""Return the SageMaker Hub name, honoring SAGEMAKER_HUB_NAME env var override.

Resolved at call time so tests and dev workflows can override the hub
without re-importing this module. Defaults to ``"SageMakerPublicHub"``.
"""
return os.environ.get("SAGEMAKER_HUB_NAME", "SageMakerPublicHub")

# Allowed reward model IDs for RLAIF trainer with region restrictions
_ALLOWED_REWARD_MODEL_IDS = {
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/src/sagemaker/train/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import HUB_NAME
from sagemaker.train.constants import get_sagemaker_hub_name

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -244,7 +244,7 @@ def train(self,
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, HUB_NAME)
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())

# Build TrainingJob.create() arguments
create_args = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .execution import EvaluationPipelineExecution
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import get_sagemaker_hub_name

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -466,7 +467,7 @@ def hyperparameters(self):

override_params = _get_evaluation_override_params(
hub_content_name=hub_content_name,
hub_name="SageMakerPublicHub",
hub_name=get_sagemaker_hub_name(),
evaluation_type=evaluation_type,
region=region,
session=boto_session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .execution import EvaluationPipelineExecution
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import get_sagemaker_hub_name

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -240,7 +241,7 @@ def hyperparameters(self):

override_params = _get_evaluation_override_params(
hub_content_name=hub_content_name,
hub_name="SageMakerPublicHub",
hub_name=get_sagemaker_hub_name(),
evaluation_type="DeterministicEvaluation",
region=region,
session=boto_session
Expand Down Expand Up @@ -365,7 +366,7 @@ def _get_inference_params_from_hub(self, region: str) -> dict:
_logger.info(f"Fetching evaluation recipe override parameters from hub for model: {hub_content_name}")
override_params = _get_evaluation_override_params(
hub_content_name=hub_content_name,
hub_name="SageMakerPublicHub",
hub_name=get_sagemaker_hub_name(),
evaluation_type="DeterministicEvaluation",
region=region,
session=session
Expand Down
6 changes: 3 additions & 3 deletions sagemaker-train/src/sagemaker/train/rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import HUB_NAME, _ALLOWED_REWARD_MODEL_IDS
from sagemaker.train.constants import get_sagemaker_hub_name, _ALLOWED_REWARD_MODEL_IDS

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -263,7 +263,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, HUB_NAME)
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())

# Build TrainingJob.create() arguments
create_args = {
Expand Down Expand Up @@ -358,7 +358,7 @@ def _process_non_builtin_reward_prompt(self):
sagemaker_session=self.sagemaker_session
)
hub_content = _get_hub_content_metadata(
hub_name=HUB_NAME,
hub_name=get_sagemaker_hub_name(),
hub_content_type="JsonDoc",
hub_content_name=self.reward_prompt,
session=session.boto_session,
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/src/sagemaker/train/rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import HUB_NAME
from sagemaker.train.constants import get_sagemaker_hub_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -251,7 +251,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, HUB_NAME)
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())

# Build TrainingJob.create() arguments
create_args = {
Expand Down
4 changes: 2 additions & 2 deletions sagemaker-train/src/sagemaker/train/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train.constants import HUB_NAME
from sagemaker.train.constants import get_sagemaker_hub_name

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -245,7 +245,7 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
)

vpc_config = self.networking if self.networking else None
tags = _get_studio_tags(self._model_name, HUB_NAME)
tags = _get_studio_tags(self._model_name, get_sagemaker_hub_name())

# Build TrainingJob.create() arguments
create_args = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def test_resolver_initialization(self):
"""Test ModelResolver initialization."""
resolver = _ModelResolver()
assert resolver.sagemaker_session is None
assert resolver.DEFAULT_HUB_NAME == "SageMakerPublicHub"

def test_resolver_with_session(self):
"""Test ModelResolver with custom session."""
Expand Down
20 changes: 20 additions & 0 deletions sagemaker-train/tests/unit/train/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Tests for SAGEMAKER_HUB_NAME env-var override via get_sagemaker_hub_name."""
from __future__ import absolute_import

import os
from unittest.mock import patch

from sagemaker.train.constants import get_sagemaker_hub_name


def test_get_sagemaker_hub_name_defaults_to_public_hub():
"""When SAGEMAKER_HUB_NAME is unset, returns SageMakerPublicHub."""
env = {k: v for k, v in os.environ.items() if k != "SAGEMAKER_HUB_NAME"}
with patch.dict(os.environ, env, clear=True):
assert get_sagemaker_hub_name() == "SageMakerPublicHub"


def test_get_sagemaker_hub_name_overridden_by_env_var():
"""When SAGEMAKER_HUB_NAME is set, returns the override value."""
with patch.dict(os.environ, {"SAGEMAKER_HUB_NAME": "MyPrivateHub"}):
assert get_sagemaker_hub_name() == "MyPrivateHub"
Loading