diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py
index 65311b2385..1c3f09a43e 100644
--- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py
+++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py
@@ -58,10 +58,10 @@ def __init__(self, sagemaker_session=None):
 
         Args:
             sagemaker_session: SageMaker session to use for API calls.
-                If None, will be created with beta endpoint if configured.
+                If None, a session will be created, using the custom endpoint if one is configured.
         """
         self.sagemaker_session = sagemaker_session
-        self._beta_endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
+        self._endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
 
     def resolve_model_info(
         self,
@@ -188,8 +188,21 @@ def _resolve_model_package_object(self, model_package: 'ModelPackage') -> _Model
                 base_model_name = hub_content_name
                 if hasattr(container.base_model, 'hub_content_arn'):
                     base_model_arn = container.base_model.hub_content_arn
+
+                # If hub_content_arn is not present, construct it from hub_content_name and version
+                if not base_model_arn and hasattr(container.base_model, 'hub_content_version'):
+                    hub_content_version = container.base_model.hub_content_version
+                    model_pkg_arn = getattr(model_package, 'model_package_arn', None)
+
+                    if hub_content_name and hub_content_version and model_pkg_arn:
+                        # Extract region from model package ARN
+                        arn_parts = model_pkg_arn.split(':')
+                        if len(arn_parts) >= 4:
+                            region = arn_parts[3]
+                            # Construct hub content ARN for SageMaker public hub
+                            base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
 
-        # If we couldn't extract base model ARN, this is not a supported model package
+        # If we couldn't extract or construct base model ARN, this is not a supported model package
         if not base_model_arn:
             raise ValueError(
                 f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
@@ -234,82 +247,23 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo:
             # Validate ARN format
             self._validate_model_package_arn(model_package_arn)
 
-            # TODO: Switch to sagemaker_core ModelPackage.get() once the bug is fixed
-            # Currently, ModelPackage.get() has a Pydantic validation issue where
-            # the transform() function doesn't include model_package_name in the response,
-            # causing: "1 validation error for ModelPackage - model_package_name: Field required"
-            # Using boto3 directly as a workaround.
-
-            # Use the sagemaker client from the session (which has the correct endpoint configured)
-            sm_client = session.sagemaker_client if hasattr(session, 'sagemaker_client') else session.boto_session.client('sagemaker')
-            response = sm_client.describe_model_package(ModelPackageName=model_package_arn)
-
-            # Extract base model info from response
-            base_model_name = None
-            base_model_arn = None
-            hub_content_name = None
+            # Use sagemaker.core ModelPackage.get() to retrieve model package information
+            from sagemaker.core.resources import ModelPackage
 
-            # Check inference specification
-            if 'InferenceSpecification' not in response:
-                raise ValueError(
-                    f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
-                    f"The provided model package (ARN: {model_package_arn}) "
-                    f"does not have an inference_specification."
-                )
+            import logging
+            logger = logging.getLogger(__name__)
 
-            inf_spec = response['InferenceSpecification']
-            if 'Containers' not in inf_spec or len(inf_spec['Containers']) == 0:
-                raise ValueError(
-                    f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
-                    f"The provided model package (ARN: {model_package_arn}) "
-                    f"does not have any containers in its inference_specification."
-                )
-
-            container = inf_spec['Containers'][0]
-
-            # Extract base model info
-            if 'BaseModel' not in container:
-                raise ValueError(
-                    f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
-                    f"The provided model package (ARN: {model_package_arn}) "
-                    f"does not have base_model metadata in its inference_specification.containers[0]. "
-                    f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
-                )
-
-            base_model_info = container['BaseModel']
-            hub_content_name = base_model_info.get('HubContentName')
-            hub_content_version = base_model_info.get('HubContentVersion')
-            base_model_arn = base_model_info.get('HubContentArn')
-
-            # If HubContentArn is None, construct it from HubContentName and version
-            # This handles cases where the API doesn't return the full ARN
-            if not base_model_arn and hub_content_name and hub_content_version:
-                # Extract region from model_package_arn
-                arn_parts = model_package_arn.split(':')
-                if len(arn_parts) >= 4:
-                    region = arn_parts[3]
-                    # Construct hub content ARN for SageMaker public hub
-                    base_model_arn = f"arn:aws:sagemaker:{region}:aws:hub-content/SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
-
-            if not base_model_arn:
-                raise ValueError(
-                    f"NotSupported: Evaluation is only supported for model packages customized by SageMaker's fine-tuning flows. "
-                    f"The provided model package (ARN: {model_package_arn}) "
-                    f"does not have base_model metadata with HubContentArn or sufficient information to construct it. "
-                    f"Please ensure the model was created using SageMaker's fine-tuning capabilities."
-                )
+            # Get the model package using sagemaker.core
+            model_package = ModelPackage.get(
+                model_package_name=model_package_arn,
+                session=session.boto_session,
+                region=session.boto_session.region_name
+            )
 
-            # Use hub_content_name as base_model_name
-            base_model_name = hub_content_name if hub_content_name else response.get('ModelPackageGroupName', 'unknown')
+            logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}")
 
-            return _ModelInfo(
-                base_model_name=base_model_name,
-                base_model_arn=base_model_arn,
-                source_model_package_arn=model_package_arn,
-                model_type=_ModelType.FINE_TUNED,
-                hub_content_name=hub_content_name,
-                additional_metadata={}
-            )
+            # Now use the existing _resolve_model_package_object method to extract base model info
+            return self._resolve_model_package_object(model_package)
 
         except ValueError:
             # Re-raise ValueError as-is (our custom error messages)
@@ -342,7 +296,7 @@ def _validate_model_package_arn(self, arn: str) -> bool:
 
     def _get_session(self):
         """
-        Get or create SageMaker session with beta endpoint support.
+        Get or create a SageMaker session with custom endpoint support.
 
         Returns:
             SageMaker session
@@ -352,12 +306,11 @@
 
         from sagemaker.core.helper.session_helper import Session
 
-        # Check for beta endpoint in environment variable
-        if self._beta_endpoint:
+        # Check for endpoint in environment variable
+        if self._endpoint:
             sm_client = boto3.client(
                 'sagemaker',
-                endpoint_url=self._beta_endpoint,
-                region_name=os.environ.get('AWS_REGION', 'us-west-2')
+                endpoint_url=self._endpoint
             )
             return Session(sagemaker_client=sm_client)
 
diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
index df68711dc7..620b7ffe34 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
@@ -546,6 +546,8 @@ def _get_or_create_artifact_arn(self, source_uri: str, region: str) -> str:
             properties['HubContentArn'] = source_uri
         else:
             properties['SourceUri'] = source_uri
+
+        _logger.info(f"source_uri: {source_uri}, region: {region}, properties: {properties}")
 
         # Create artifact using Artifact.create()
         artifact = Artifact.create(
diff --git a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py
index cf5da3a2e8..290a6f80ba 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/custom_scorer_evaluator.py
@@ -308,6 +308,10 @@ def _get_custom_scorer_template_additions(self, evaluator_config: dict) -> dict:
             'evaluator_arn': evaluator_config['evaluator_arn'],
         }
 
+        # Add lambda_type for Nova models
+        if is_nova:
+            custom_scorer_context['lambda_type'] = 'rft'
+
         # Add preset_reward_function if present
         if evaluator_config['preset_reward_function']:
             custom_scorer_context['preset_reward_function'] = evaluator_config['preset_reward_function']
diff --git a/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py b/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py
index c3bbc20518..ea5b10b5ed 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/pipeline_templates.py
@@ -632,7 +632,8 @@
             "task": "{{ task }}",
             "strategy": "{{ strategy }}"{% if metric is defined %},
             "metric": "{{ metric }}"{% elif evaluation_metric is defined %},
-            "evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
+            "evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
+            "lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
             "max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
             "temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
             "top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
@@ -694,7 +695,8 @@
             "task": "{{ task }}",
             "strategy": "{{ strategy }}"{% if metric is defined %},
             "metric": "{{ metric }}"{% elif evaluation_metric is defined %},
-            "evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
+            "evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
+            "lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
             "max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
             "temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
             "top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
@@ -872,7 +874,8 @@
             "task": "{{ task }}",
             "strategy": "{{ strategy }}"{% if metric is defined %},
             "metric": "{{ metric }}"{% elif evaluation_metric is defined %},
-            "evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if max_new_tokens is defined %},
+            "evaluation_metric": "{{ evaluation_metric }}"{% endif %}{% if lambda_type is defined %},
+            "lambda_type": "{{ lambda_type }}"{% endif %}{% if max_new_tokens is defined %},
             "max_new_tokens": "{{ max_new_tokens }}"{% endif %}{% if temperature is defined %},
             "temperature": "{{ temperature }}"{% endif %}{% if top_k is defined %},
             "top_k": "{{ top_k }}"{% endif %}{% if top_p is defined %},
diff --git a/sagemaker-train/tests/integ/train/__init__.py b/sagemaker-train/tests/integ/train/__init__.py
index e69de29bb2..c35f2f33e9 100644
--- a/sagemaker-train/tests/integ/train/__init__.py
+++ b/sagemaker-train/tests/integ/train/__init__.py
@@ -0,0 +1,14 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Integration tests for SageMaker Modules Evaluate"""
+from __future__ import absolute_import
diff --git a/tests/integ/sagemaker/modules/evaluate/test_benchmark_evaluator.py b/sagemaker-train/tests/integ/train/test_benchmark_evaluator.py
similarity index 99%
rename from tests/integ/sagemaker/modules/evaluate/test_benchmark_evaluator.py
rename to sagemaker-train/tests/integ/train/test_benchmark_evaluator.py
index ac8cfb18f9..ee6fa631ec 100644
--- a/tests/integ/sagemaker/modules/evaluate/test_benchmark_evaluator.py
+++ b/sagemaker-train/tests/integ/train/test_benchmark_evaluator.py
@@ -72,6 +72,7 @@
 }
 
 
+@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
 class TestBenchmarkEvaluatorIntegration:
     """Integration tests for BenchmarkEvaluator with fine-tuned model package"""
 
diff --git a/tests/integ/sagemaker/modules/evaluate/test_custom_scorer_evaluator.py b/sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py
similarity index 99%
rename from tests/integ/sagemaker/modules/evaluate/test_custom_scorer_evaluator.py
rename to sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py
index 63aac99356..0af4ca1838 100644
--- a/tests/integ/sagemaker/modules/evaluate/test_custom_scorer_evaluator.py
+++ b/sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py
@@ -55,6 +55,7 @@
 }
 
 
+@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
 class TestCustomScorerEvaluatorIntegration:
     """Integration tests for CustomScorerEvaluator with custom evaluator"""
 
diff --git a/tests/integ/sagemaker/modules/evaluate/test_llm_as_judge_evaluator.py b/sagemaker-train/tests/integ/train/test_llm_as_judge_evaluator.py
similarity index 99%
rename from tests/integ/sagemaker/modules/evaluate/test_llm_as_judge_evaluator.py
rename to sagemaker-train/tests/integ/train/test_llm_as_judge_evaluator.py
index a6ecaf56a1..49a68c22d9 100644
--- a/tests/integ/sagemaker/modules/evaluate/test_llm_as_judge_evaluator.py
+++ b/sagemaker-train/tests/integ/train/test_llm_as_judge_evaluator.py
@@ -84,6 +84,7 @@
 }
 
 
+@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
 class TestLLMAsJudgeEvaluatorIntegration:
     """Integration tests for LLMAsJudgeEvaluator"""
 
diff --git a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py
index 31440f29e5..d0cc5990a8 100644
--- a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py
+++ b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py
@@ -73,7 +73,7 @@ def test_resolver_with_session(self):
 
     def test_resolver_with_beta_endpoint(self):
         """Test ModelResolver detects beta endpoint."""
         resolver = _ModelResolver()
-        assert resolver._beta_endpoint == 'https://beta.endpoint'
+        assert resolver._endpoint == 'https://beta.endpoint'
 
 
 class TestResolveModelInfo:
@@ -307,33 +307,34 @@ def test_resolve_package_fallback_name(self):
 class TestResolveModelPackageArn:
     """Tests for _resolve_model_package_arn method."""
 
+    @patch('sagemaker.core.resources.ModelPackage')
     @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session')
     @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn')
-    def test_resolve_arn_success(self, mock_validate, mock_get_session):
-        """Test successful ARN resolution."""
+    def test_resolve_arn_success(self, mock_validate, mock_get_session, mock_model_package_class):
+        """Test successful ARN resolution using ModelPackage.get()."""
         arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1"
 
-        # Mock session and client
+        # Mock session
         mock_session = MagicMock()
-        mock_client = MagicMock()
-        mock_session.sagemaker_client = mock_client
+        mock_session.boto_session.region_name = 'us-west-2'
         mock_get_session.return_value = mock_session
 
-        # Mock describe_model_package response
-        mock_client.describe_model_package.return_value = {
-            'InferenceSpecification': {
-                'Containers': [
-                    {
-                        'BaseModel': {
-                            'HubContentName': 'base-model',
-                            'HubContentVersion': '1.0',
-                            'HubContentArn': 'arn:aws:sagemaker:us-west-2:aws:hub-content/base'
-                        }
-                    }
-                ]
-            },
-            'ModelPackageGroupName': 'my-model'
-        }
+        # Mock ModelPackage.get() return value
+        mock_package = MagicMock()
+        mock_package.model_package_arn = arn
+
+        # Mock inference specification with hub_content_arn
+        mock_container = MagicMock()
+        mock_base_model = MagicMock()
+        mock_base_model.hub_content_name = 'base-model'
+        mock_base_model.hub_content_version = '1.0'
+        mock_base_model.hub_content_arn = 'arn:aws:sagemaker:us-west-2:aws:hub-content/base'
+        mock_container.base_model = mock_base_model
+
+        mock_package.inference_specification = MagicMock()
+        mock_package.inference_specification.containers = [mock_container]
+
+        mock_model_package_class.get.return_value = mock_package
 
         resolver = _ModelResolver()
         result = resolver._resolve_model_package_arn(arn)
@@ -342,73 +343,96 @@ def test_resolve_arn_success(self, mock_validate, mock_get_session):
         assert result.hub_content_name == "base-model"
         assert result.source_model_package_arn == arn
         assert result.model_type == _ModelType.FINE_TUNED
+        mock_model_package_class.get.assert_called_once_with(
+            model_package_name=arn,
+            session=mock_session.boto_session,
+            region='us-west-2'
+        )
 
+    @patch('sagemaker.core.resources.ModelPackage')
     @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session')
     @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn')
-    def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_session):
+    def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_session, mock_model_package_class):
         """Test ARN resolution when HubContentArn needs to be constructed."""
         arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1"
 
+        # Mock session
         mock_session = MagicMock()
-        mock_client = MagicMock()
-        mock_session.sagemaker_client = mock_client
+        mock_session.boto_session.region_name = 'us-west-2'
         mock_get_session.return_value = mock_session
 
-        # Mock response without HubContentArn
-        mock_client.describe_model_package.return_value = {
-            'InferenceSpecification': {
-                'Containers': [
-                    {
-                        'BaseModel': {
-                            'HubContentName': 'base-model',
-                            'HubContentVersion': '1.0'
-                        }
-                    }
-                ]
-            }
-        }
+        # Mock ModelPackage without hub_content_arn (needs to be constructed)
+        mock_package = MagicMock()
+        mock_package.model_package_arn = arn
+
+        mock_container = MagicMock()
+        mock_base_model = MagicMock()
+        mock_base_model.hub_content_name = 'base-model'
+        mock_base_model.hub_content_version = '1.0'
+        mock_base_model.hub_content_arn = None  # Not provided, needs construction
+        mock_container.base_model = mock_base_model
+
+        mock_package.inference_specification = MagicMock()
+        mock_package.inference_specification.containers = [mock_container]
+
+        mock_model_package_class.get.return_value = mock_package
 
         resolver = _ModelResolver()
         result = resolver._resolve_model_package_arn(arn)
 
-        # Should construct ARN from region and hub content name
+        # Should construct ARN from region and hub content name/version
         expected_arn = "arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/base-model/1.0"
         assert result.base_model_arn == expected_arn
+        assert result.base_model_name == "base-model"
+        assert result.hub_content_name == "base-model"
 
+    @patch('sagemaker.core.resources.ModelPackage')
     @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session')
     @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn')
-    def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session):
+    def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session, mock_model_package_class):
         """Test error when InferenceSpecification is missing."""
         arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1"
 
+        # Mock session
         mock_session = MagicMock()
-        mock_client = MagicMock()
-        mock_session.sagemaker_client = mock_client
+        mock_session.boto_session.region_name = 'us-west-2'
         mock_get_session.return_value = mock_session
 
-        mock_client.describe_model_package.return_value = {}
+        # Mock ModelPackage without inference_specification
+        mock_package = MagicMock()
+        mock_package.model_package_arn = arn
+        mock_package.inference_specification = None
+
+        mock_model_package_class.get.return_value = mock_package
 
         resolver = _ModelResolver()
 
         with pytest.raises(ValueError, match="NotSupported.*does not have an inference_specification"):
             resolver._resolve_model_package_arn(arn)
 
+    @patch('sagemaker.core.resources.ModelPackage')
     @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session')
     @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn')
-    def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session):
+    def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session, mock_model_package_class):
         """Test error when BaseModel is missing."""
         arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1"
 
+        # Mock session
         mock_session = MagicMock()
-        mock_client = MagicMock()
-        mock_session.sagemaker_client = mock_client
+        mock_session.boto_session.region_name = 'us-west-2'
         mock_get_session.return_value = mock_session
 
-        mock_client.describe_model_package.return_value = {
-            'InferenceSpecification': {
-                'Containers': [{}]
-            }
-        }
+        # Mock ModelPackage with container but no base_model
+        mock_package = MagicMock()
+        mock_package.model_package_arn = arn
+
+        mock_container = MagicMock()
+        mock_container.base_model = None
+
+        mock_package.inference_specification = MagicMock()
+        mock_package.inference_specification.containers = [mock_container]
+
+        mock_model_package_class.get.return_value = mock_package
 
         resolver = _ModelResolver()
 
@@ -465,7 +489,7 @@ def test_get_default_session(self, mock_session_class):
         assert result == mock_session
         mock_session_class.assert_called_once()
 
-    @patch.dict(os.environ, {'SAGEMAKER_ENDPOINT': 'https://beta.endpoint', 'AWS_REGION': 'us-east-1'})
+    @patch.dict(os.environ, {'SAGEMAKER_ENDPOINT': 'https://beta.endpoint'})
     @patch('boto3.client')
     @patch('sagemaker.core.helper.session_helper.Session')
     def test_get_session_with_beta_endpoint(self, mock_session_class, mock_boto_client):
@@ -481,8 +505,7 @@ def test_get_session_with_beta_endpoint(self, mock_session_class, mock_boto_clie
 
         mock_boto_client.assert_called_once_with(
             'sagemaker',
-            endpoint_url='https://beta.endpoint',
-            region_name='us-east-1'
+            endpoint_url='https://beta.endpoint'
         )
         mock_session_class.assert_called_once_with(sagemaker_client=mock_sm_client)
 
diff --git a/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py
index 73a202cebc..1f37632903 100644
--- a/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py
+++ b/sagemaker-train/tests/unit/train/evaluate/test_custom_scorer_evaluator.py
@@ -1016,3 +1016,116 @@ def test_custom_scorer_evaluator_get_custom_scorer_template_additions_custom_arn
     assert additions['postprocessing'] == 'True'
     # Verify aggregation is set from configured params
     assert additions['aggregation'] == 'median'
+
+
+@patch('sagemaker.train.common_utils.recipe_utils._is_nova_model')
+@patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn')
+@patch('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options')
+@patch('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params')
+@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
+@patch('sagemaker.core.resources.Artifact')
+def test_custom_scorer_evaluator_lambda_type_for_nova_models(
+    mock_artifact, mock_resolve, mock_get_params, mock_extract_options, mock_resolve_mlflow, mock_is_nova
+):
+    """Test that lambda_type is added for Nova models."""
+    mock_resolve_mlflow.return_value = DEFAULT_MLFLOW_ARN
+    mock_info = Mock()
+    mock_info.base_model_name = "nova-textgeneration-micro"
+    mock_info.base_model_arn = "arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/nova-textgeneration-micro/1.0.0"
+    mock_info.source_model_package_arn = None
+    mock_resolve.return_value = mock_info
+
+    mock_artifact.get_all.return_value = iter([])
+    mock_artifact_instance = Mock()
+    mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN
+    mock_artifact.create.return_value = mock_artifact_instance
+
+    mock_session = Mock()
+    mock_session.boto_region_name = DEFAULT_REGION
+    mock_session.boto_session = Mock()
+    mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE
+    mock_session.sagemaker_config = None
+
+    # Mock recipe utils
+    mock_get_params.return_value = {'temperature': 0.7}
+    mock_extract_options.return_value = {'temperature': {'value': 0.7}}
+
+    # Mock is_nova_model to return True
+    mock_is_nova.return_value = True
+
+    evaluator = CustomScorerEvaluator(
+        evaluator=_BuiltInMetric.PRIME_MATH,
+        dataset=DEFAULT_DATASET,
+        model="nova-textgeneration-micro",
+        s3_output_path=DEFAULT_S3_OUTPUT,
+        mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
+        model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
+        sagemaker_session=mock_session,
+    )
+
+    evaluator_config = {'evaluator_arn': None, 'preset_reward_function': 'prime_math'}
+    additions = evaluator._get_custom_scorer_template_additions(evaluator_config)
+
+    # Verify lambda_type is present for Nova models
+    assert 'lambda_type' in additions
+    assert additions['lambda_type'] == 'rft'
+    # Verify 'metric' key is used instead of 'evaluation_metric' for Nova
+    assert 'metric' in additions
+    assert additions['metric'] == 'all'
+    assert 'evaluation_metric' not in additions
+
+
+@patch('sagemaker.train.common_utils.recipe_utils._is_nova_model')
+@patch('sagemaker.train.common_utils.finetune_utils._resolve_mlflow_resource_arn')
+@patch('sagemaker.train.common_utils.recipe_utils._extract_eval_override_options')
+@patch('sagemaker.train.common_utils.recipe_utils._get_evaluation_override_params')
+@patch('sagemaker.train.common_utils.model_resolution._resolve_base_model')
+@patch('sagemaker.core.resources.Artifact')
+def test_custom_scorer_evaluator_no_lambda_type_for_non_nova_models(
+    mock_artifact, mock_resolve, mock_get_params, mock_extract_options, mock_resolve_mlflow, mock_is_nova
+):
+    """Test that lambda_type is NOT added for non-Nova models."""
+    mock_resolve_mlflow.return_value = DEFAULT_MLFLOW_ARN
+    mock_info = Mock()
+    mock_info.base_model_name = DEFAULT_MODEL
+    mock_info.base_model_arn = DEFAULT_BASE_MODEL_ARN
+    mock_info.source_model_package_arn = None
+    mock_resolve.return_value = mock_info
+
+    mock_artifact.get_all.return_value = iter([])
+    mock_artifact_instance = Mock()
+    mock_artifact_instance.artifact_arn = DEFAULT_ARTIFACT_ARN
+    mock_artifact.create.return_value = mock_artifact_instance
+
+    mock_session = Mock()
+    mock_session.boto_region_name = DEFAULT_REGION
+    mock_session.boto_session = Mock()
+    mock_session.get_caller_identity_arn.return_value = DEFAULT_ROLE
+    mock_session.sagemaker_config = None
+
+    # Mock recipe utils
+    mock_get_params.return_value = {'temperature': 0.7}
+    mock_extract_options.return_value = {'temperature': {'value': 0.7}}
+
+    # Mock is_nova_model to return False
+    mock_is_nova.return_value = False
+
+    evaluator = CustomScorerEvaluator(
+        evaluator=_BuiltInMetric.PRIME_MATH,
+        dataset=DEFAULT_DATASET,
+        model=DEFAULT_MODEL,
+        s3_output_path=DEFAULT_S3_OUTPUT,
+        mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
+        model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
+        sagemaker_session=mock_session,
+    )
+
+    evaluator_config = {'evaluator_arn': None, 'preset_reward_function': 'prime_math'}
+    additions = evaluator._get_custom_scorer_template_additions(evaluator_config)
+
+    # Verify lambda_type is NOT present for non-Nova models
+    assert 'lambda_type' not in additions
+    # Verify 'evaluation_metric' key is used instead of 'metric' for non-Nova
+    assert 'evaluation_metric' in additions
+    assert additions['evaluation_metric'] == 'all'
+    assert 'metric' not in additions
diff --git a/tests/integ/sagemaker/modules/evaluate/__init__.py b/tests/integ/sagemaker/modules/evaluate/__init__.py
deleted file mode 100644
index c35f2f33e9..0000000000
--- a/tests/integ/sagemaker/modules/evaluate/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"). You
-# may not use this file except in compliance with the License. A copy of
-# the License is located at
-#
-#     http://aws.amazon.com/apache2.0/
-#
-# or in the "license" file accompanying this file. This file is
-# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
-# ANY KIND, either express or implied. See the License for the specific
-# language governing permissions and limitations under the License.
-"""Integration tests for SageMaker Modules Evaluate"""
-from __future__ import absolute_import
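
The hub-content ARN fallback introduced in _resolve_model_package_object (and verified by test_resolve_arn_construct_hub_content_arn) is pure string manipulation, so it can be sketched in isolation. The helper name below is illustrative, not part of the change:

from __future__ import annotations

# Sketch of the fallback: rebuild the public-hub content ARN from the model
# package ARN's region plus the hub content name/version, mirroring the
# construction added in model_resolution.py.
def construct_hub_content_arn(model_package_arn: str, hub_content_name: str, hub_content_version: str):
    arn_parts = model_package_arn.split(':')
    if len(arn_parts) < 4:
        return None  # malformed ARN; the production code raises a NotSupported ValueError later
    region = arn_parts[3]  # e.g. 'us-west-2'
    return (
        f"arn:aws:sagemaker:{region}:aws:hub-content/"
        f"SageMakerPublicHub/Model/{hub_content_name}/{hub_content_version}"
    )

print(construct_hub_content_arn(
    "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1",
    "base-model",
    "1.0",
))
# arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/base-model/1.0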
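
The beta-to-generic endpoint rename also changes how the session client is built: the region is no longer pinned via AWS_REGION, so boto3 resolves it through its normal chain. A minimal sketch of the endpoint branch of _get_session after this diff (only the code paths shown in the patch are reproduced):

import os
import boto3
from sagemaker.core.helper.session_helper import Session

# Mirrors the endpoint branch of _get_session after the rename: the sagemaker
# client is built against SAGEMAKER_ENDPOINT when set, with no hardcoded
# region_name fallback.
endpoint = os.environ.get('SAGEMAKER_ENDPOINT')
if endpoint:
    sm_client = boto3.client('sagemaker', endpoint_url=endpoint)
    session = Session(sagemaker_client=sm_client)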
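
The pipeline_templates.py change only serializes lambda_type into the recipe JSON when the rendering context defines it, which is why the non-Nova unit test can assert its absence. A standalone jinja2 sketch of that behavior (the fragment is condensed and the strategy/metric values are illustrative):

from jinja2 import Template

# Condensed fragment of the template change: lambda_type is emitted only when
# defined in the context, following the same pattern as metric/evaluation_metric.
fragment = Template(
    '"strategy": "{{ strategy }}"'
    '{% if metric is defined %}, "metric": "{{ metric }}"{% endif %}'
    '{% if lambda_type is defined %}, "lambda_type": "{{ lambda_type }}"{% endif %}'
)

# Nova path: _get_custom_scorer_template_additions adds lambda_type='rft'
print(fragment.render(strategy="custom_scorer", metric="all", lambda_type="rft"))
# "strategy": "custom_scorer", "metric": "all", "lambda_type": "rft"

# Non-Nova path: no lambda_type in the context, so the key is omitted entirely
print(fragment.render(strategy="custom_scorer", metric="all"))
# "strategy": "custom_scorer", "metric": "all"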
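
The updated unit tests patch 'sagemaker.core.resources.ModelPackage' rather than a name inside model_resolution. That works because the production code imports ModelPackage inside _resolve_model_package_arn at call time, so the patched module attribute is what gets looked up. A minimal illustration of the pattern, assuming the sagemaker-core package is importable (the resolve() helper is hypothetical):

from unittest.mock import MagicMock, patch

def resolve(arn):
    # Late import, as in _resolve_model_package_arn: the name is resolved on
    # sagemaker.core.resources at call time, so patching that module works.
    from sagemaker.core.resources import ModelPackage
    return ModelPackage.get(model_package_name=arn)

with patch('sagemaker.core.resources.ModelPackage') as mock_cls:
    mock_cls.get.return_value = MagicMock(model_package_arn='arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1')
    pkg = resolve('arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1')
    mock_cls.get.assert_called_once()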