From c9dd68121664d271d8bebcf4c35f0f3e96aab3d9 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Fri, 29 May 2026 15:56:05 -0700 Subject: [PATCH] feat: implement test_custom_scorer_base_model_only integration test Add integration test for CustomScorerEvaluator base-model-only evaluation path. This test uses a JumpStart model ID directly (instead of a model package ARN) to exercise the CUSTOM_SCORER_TEMPLATE_BASE_MODEL_ONLY pipeline template. The test covers evaluator creation, hyperparameter access, evaluation execution, result verification, and execution retrieval. Also added BASE_MODEL_ONLY_CONFIG with dedicated test configuration using the meta-textgeneration-llama-3-2-1b-instruct model and existing test account resources (729646638167). --- .../train/test_custom_scorer_evaluator.py | 132 +++++++++++++++++- 1 file changed, 127 insertions(+), 5 deletions(-) diff --git a/sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py b/sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py index a71af868a7..363cec14ba 100644 --- a/sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py +++ b/sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py @@ -54,6 +54,16 @@ "region": "us-west-2", } +# Base model only evaluation configuration (uses JumpStart model ID directly, no model package) +BASE_MODEL_ONLY_CONFIG = { + "base_model_id": "meta-textgeneration-llama-3-2-1b-instruct", + "evaluator_arn": "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/eval-lambda-test/0.0.1", + "dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl", + "s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/", + "mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX", + "region": "us-west-2", +} + # @pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/") @pytest.mark.xdist_group("custom_scorer_evaluator") @@ -288,13 +298,125 @@ def test_custom_scorer_with_builtin_metric(self): logger.info("Built-in metric evaluation completed successfully") # @pytest.mark.skip(reason="Base model only evaluation - not working yet per notebook") + @pytest.mark.gpu_intensive def test_custom_scorer_base_model_only(self): """ Test custom scorer evaluation with base model only (no fine-tuned model). - Note: Per the notebook, "Evaluation with Base Model Only is yet to be - implemented/tested - Not Working currently". This test is skipped until - that functionality is available. + This test uses a JumpStart model ID directly instead of a model package ARN, + which triggers the CUSTOM_SCORER_TEMPLATE_BASE_MODEL_ONLY template path. + The evaluation runs against only the base model without any fine-tuned weights. + + This test covers: + 1. Creating CustomScorerEvaluator with a JumpStart model ID (base model only) + 2. Accessing hyperparameters + 3. Starting evaluation + 4. Monitoring execution + 5. Waiting for completion + 6. Viewing results + 7. Retrieving execution by ARN """ - logger.info("Base model only evaluation - not yet implemented") - pass + # Step 1: Create CustomScorerEvaluator with JumpStart model ID + logger.info("Creating CustomScorerEvaluator with base model only (JumpStart model ID)") + + evaluator = CustomScorerEvaluator( + evaluator=BASE_MODEL_ONLY_CONFIG["evaluator_arn"], + dataset=BASE_MODEL_ONLY_CONFIG["dataset_s3_uri"], + model=BASE_MODEL_ONLY_CONFIG["base_model_id"], + s3_output_path=BASE_MODEL_ONLY_CONFIG["s3_output_path"], + evaluate_base_model=False, + ) + + # Verify evaluator was created with base model ID + assert evaluator is not None + assert evaluator.evaluator == BASE_MODEL_ONLY_CONFIG["evaluator_arn"] + assert evaluator.model == BASE_MODEL_ONLY_CONFIG["base_model_id"] + assert evaluator.dataset == BASE_MODEL_ONLY_CONFIG["dataset_s3_uri"] + + logger.info(f"Created evaluator with base model: {BASE_MODEL_ONLY_CONFIG['base_model_id']}") + + # Step 2: Access hyperparameters + logger.info("Accessing hyperparameters") + hyperparams = evaluator.hyperparameters.to_dict() + + # Verify hyperparameters structure + assert isinstance(hyperparams, dict) + assert "max_new_tokens" in hyperparams + assert "temperature" in hyperparams + + logger.info(f"Hyperparameters: {hyperparams}") + + # Step 3: Start evaluation + logger.info("Starting evaluation execution") + execution = evaluator.evaluate() + + # Verify execution was created + assert execution is not None + assert execution.arn is not None + assert execution.name is not None + assert execution.eval_type is not None + + logger.info(f"Pipeline Execution ARN: {execution.arn}") + logger.info(f"Initial Status: {execution.status.overall_status}") + + # Step 4: Monitor execution + logger.info("Refreshing execution status") + execution.refresh() + + # Verify status was updated + assert execution.status.overall_status is not None + + # Log step details if available + if execution.status.step_details: + logger.info("Step Details:") + for step in execution.status.step_details: + logger.info(f" {step.name}: {step.status}") + + # Step 5: Wait for completion + logger.info(f"Waiting for evaluation to complete (timeout: {EVALUATION_TIMEOUT_SECONDS}s / {EVALUATION_TIMEOUT_SECONDS//3600}h)") + + try: + execution.wait(target_status="Succeeded", poll=30, timeout=EVALUATION_TIMEOUT_SECONDS) + logger.info(f"Final Status: {execution.status.overall_status}") + + # Verify completion + assert execution.status.overall_status == "Succeeded" + + # Step 6: View results + logger.info("Displaying results") + execution.show_results() + + # Verify S3 output path is set + assert execution.s3_output_path is not None + logger.info(f"Results stored at: {execution.s3_output_path}") + + except Exception as e: + logger.error(f"Evaluation failed or timed out: {e}") + logger.error(f"Final status: {execution.status.overall_status}") + if execution.status.failure_reason: + logger.error(f"Failure reason: {execution.status.failure_reason}") + + # Log step failures + if execution.status.step_details: + for step in execution.status.step_details: + if "failed" in step.status.lower(): + logger.error(f"Failed step: {step.name}") + if step.failure_reason: + logger.error(f" Reason: {step.failure_reason}") + + # Re-raise to fail the test + raise + + # Step 7: Retrieve execution by ARN + logger.info("Retrieving execution by ARN") + retrieved_execution = EvaluationPipelineExecution.get( + arn=execution.arn, + region=BASE_MODEL_ONLY_CONFIG["region"] + ) + + # Verify retrieved execution matches + assert retrieved_execution.arn == execution.arn + assert retrieved_execution.status.overall_status == "Succeeded" + + logger.info(f"Retrieved execution status: {retrieved_execution.status.overall_status}") + logger.info("Base model only evaluation completed successfully")