Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: populate default config name to model #4617

Merged
merged 14 commits into from
Apr 26, 2024
26 changes: 26 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,31 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
return kwargs


def _add_config_name_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
    """Populate ``kwargs.config_name`` with the top-ranked default inference config.

    Resolves the model specs for the requested model/version/region and, when the
    specs expose inference configs with a ranked top choice, fills in the config
    name — keeping any name the caller already supplied. Returns the full kwargs.
    """
    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
        config_name=kwargs.config_name,
    )

    inference_configs = specs.inference_configs
    if inference_configs:
        # Hoist the ranked lookup so it runs once instead of twice.
        default_config_name = inference_configs.get_top_config_from_ranking().config_name
        if default_config_name:
            # An explicitly passed config name always wins over the default.
            kwargs.config_name = kwargs.config_name or default_config_name

    return kwargs


def get_deploy_kwargs(
model_id: str,
model_version: Optional[str] = None,
Expand Down Expand Up @@ -808,5 +833,6 @@ def get_init_kwargs(
model_init_kwargs = _add_model_package_arn_to_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
model_init_kwargs = _add_config_name_to_kwargs(kwargs=model_init_kwargs)

return model_init_kwargs
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _validate_model_id_and_type():
self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
self.region = model_init_kwargs.region
self.sagemaker_session = model_init_kwargs.sagemaker_session
self.config_name = config_name
self.config_name = model_init_kwargs.config_name

if self.model_type == JumpStartModelType.PROPRIETARY:
self.log_subscription_warning()
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,10 +1076,12 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
"benchmark_metrics",
"config_components",
"resolved_metadata_config",
"config_name",
]

def __init__(
self,
config_name: str,
base_fields: Dict[str, Any],
config_components: Dict[str, JumpStartConfigComponent],
benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]],
Expand All @@ -1098,6 +1100,7 @@ def __init__(
self.config_components: Dict[str, JumpStartConfigComponent] = config_components
self.benchmark_metrics: Dict[str, List[JumpStartBenchmarkStat]] = benchmark_metrics
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
self.config_name: Optional[str] = config_name

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartMetadataConfig object."""
Expand Down Expand Up @@ -1251,6 +1254,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
inference_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
json_obj,
(
{
Expand Down Expand Up @@ -1303,6 +1307,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
training_configs_dict: Optional[Dict[str, JumpStartMetadataConfig]] = (
{
alias: JumpStartMetadataConfig(
alias,
json_obj,
(
{
Expand Down
21 changes: 18 additions & 3 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,8 @@ def test_model_initialization_with_config_name(

model = JumpStartModel(model_id=model_id, config_name="neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

mock_model_deploy.assert_called_once_with(
Expand Down Expand Up @@ -1594,6 +1596,8 @@ def test_model_set_deployment_config(

model = JumpStartModel(model_id=model_id)

assert model.config_name is None

model.deploy()

mock_model_deploy.assert_called_once_with(
Expand All @@ -1612,6 +1616,8 @@ def test_model_set_deployment_config(
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
model.set_deployment_config("neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

mock_model_deploy.assert_called_once_with(
Expand Down Expand Up @@ -1654,6 +1660,8 @@ def test_model_unset_deployment_config(

model = JumpStartModel(model_id=model_id, config_name="neuron-inference")

assert model.config_name == "neuron-inference"

model.deploy()

mock_model_deploy.assert_called_once_with(
Expand Down Expand Up @@ -1789,7 +1797,6 @@ def test_model_retrieve_deployment_config(
):
model_id, _ = "pytorch-eqa-bert-base-cased", "*"

mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
mock_verify_model_region_and_return_specs.side_effect = (
lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks()
)
Expand All @@ -1804,15 +1811,23 @@ def test_model_retrieve_deployment_config(
)
mock_model_deploy.return_value = default_predictor

expected = get_base_deployment_configs()[0]
config_name = expected.get("DeploymentConfigName")
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(
model_id, config_name
)

mock_session.return_value = sagemaker_session

model = JumpStartModel(model_id=model_id)

expected = get_base_deployment_configs()[0]
model.set_deployment_config(expected.get("DeploymentConfigName"))
model.set_deployment_config(config_name)

self.assertEqual(model.deployment_config, expected)

mock_get_init_kwargs.reset_mock()
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)

# Unset
model.set_deployment_config(None)
self.assertIsNone(model.deployment_config)
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import copy
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
import boto3

from sagemaker.compute_resource_requirements import ResourceRequirements
Expand Down Expand Up @@ -237,7 +237,7 @@ def get_base_spec_with_prototype_configs_with_missing_benchmarks(
copy_inference_configs = copy.deepcopy(INFERENCE_CONFIGS)
copy_inference_configs["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None

inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}
inference_configs = {**copy_inference_configs, **INFERENCE_CONFIG_RANKINGS}
training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS}

spec.update(inference_configs)
Expand Down Expand Up @@ -335,7 +335,9 @@ def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, An
return configs


def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs:
def get_mock_init_kwargs(
model_id: str, config_name: Optional[str] = None
) -> JumpStartModelInitKwargs:
return JumpStartModelInitKwargs(
model_id=model_id,
model_type=JumpStartModelType.OPEN_WEIGHTS,
Expand All @@ -344,4 +346,5 @@ def get_mock_init_kwargs(model_id) -> JumpStartModelInitKwargs:
instance_type=INIT_KWARGS.get("instance_type"),
env=INIT_KWARGS.get("env"),
resources=ResourceRequirements(),
config_name=config_name,
)
Loading