Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 127 additions & 5 deletions sagemaker-train/tests/integ/train/test_custom_scorer_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@
"region": "us-west-2",
}

# Base model only evaluation configuration (uses JumpStart model ID directly, no model package)
BASE_MODEL_ONLY_CONFIG = {
"base_model_id": "meta-textgeneration-llama-3-2-1b-instruct",
"evaluator_arn": "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/eval-lambda-test/0.0.1",
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
"region": "us-west-2",
}


# @pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
@pytest.mark.xdist_group("custom_scorer_evaluator")
Expand Down Expand Up @@ -288,13 +298,125 @@ def test_custom_scorer_with_builtin_metric(self):
logger.info("Built-in metric evaluation completed successfully")

# @pytest.mark.skip(reason="Base model only evaluation - not working yet per notebook")
@pytest.mark.gpu_intensive
def test_custom_scorer_base_model_only(self):
"""
Test custom scorer evaluation with base model only (no fine-tuned model).

Note: Per the notebook, "Evaluation with Base Model Only is yet to be
implemented/tested - Not Working currently". This test is skipped until
that functionality is available.
This test uses a JumpStart model ID directly instead of a model package ARN,
which triggers the CUSTOM_SCORER_TEMPLATE_BASE_MODEL_ONLY template path.
The evaluation runs against only the base model without any fine-tuned weights.

This test covers:
1. Creating CustomScorerEvaluator with a JumpStart model ID (base model only)
2. Accessing hyperparameters
3. Starting evaluation
4. Monitoring execution
5. Waiting for completion
6. Viewing results
7. Retrieving execution by ARN
"""
logger.info("Base model only evaluation - not yet implemented")
pass
# Step 1: Create CustomScorerEvaluator with JumpStart model ID
logger.info("Creating CustomScorerEvaluator with base model only (JumpStart model ID)")

evaluator = CustomScorerEvaluator(
evaluator=BASE_MODEL_ONLY_CONFIG["evaluator_arn"],
dataset=BASE_MODEL_ONLY_CONFIG["dataset_s3_uri"],
model=BASE_MODEL_ONLY_CONFIG["base_model_id"],
s3_output_path=BASE_MODEL_ONLY_CONFIG["s3_output_path"],
evaluate_base_model=False,
)

# Verify evaluator was created with base model ID
assert evaluator is not None
assert evaluator.evaluator == BASE_MODEL_ONLY_CONFIG["evaluator_arn"]
assert evaluator.model == BASE_MODEL_ONLY_CONFIG["base_model_id"]
assert evaluator.dataset == BASE_MODEL_ONLY_CONFIG["dataset_s3_uri"]

logger.info(f"Created evaluator with base model: {BASE_MODEL_ONLY_CONFIG['base_model_id']}")

# Step 2: Access hyperparameters
logger.info("Accessing hyperparameters")
hyperparams = evaluator.hyperparameters.to_dict()

# Verify hyperparameters structure
assert isinstance(hyperparams, dict)
assert "max_new_tokens" in hyperparams
assert "temperature" in hyperparams

logger.info(f"Hyperparameters: {hyperparams}")

# Step 3: Start evaluation
logger.info("Starting evaluation execution")
execution = evaluator.evaluate()

# Verify execution was created
assert execution is not None
assert execution.arn is not None
assert execution.name is not None
assert execution.eval_type is not None

logger.info(f"Pipeline Execution ARN: {execution.arn}")
logger.info(f"Initial Status: {execution.status.overall_status}")

# Step 4: Monitor execution
logger.info("Refreshing execution status")
execution.refresh()

# Verify status was updated
assert execution.status.overall_status is not None

# Log step details if available
if execution.status.step_details:
logger.info("Step Details:")
for step in execution.status.step_details:
logger.info(f" {step.name}: {step.status}")

# Step 5: Wait for completion
logger.info(f"Waiting for evaluation to complete (timeout: {EVALUATION_TIMEOUT_SECONDS}s / {EVALUATION_TIMEOUT_SECONDS//3600}h)")

try:
execution.wait(target_status="Succeeded", poll=30, timeout=EVALUATION_TIMEOUT_SECONDS)
logger.info(f"Final Status: {execution.status.overall_status}")

# Verify completion
assert execution.status.overall_status == "Succeeded"

# Step 6: View results
logger.info("Displaying results")
execution.show_results()

# Verify S3 output path is set
assert execution.s3_output_path is not None
logger.info(f"Results stored at: {execution.s3_output_path}")

except Exception as e:
logger.error(f"Evaluation failed or timed out: {e}")
logger.error(f"Final status: {execution.status.overall_status}")
if execution.status.failure_reason:
logger.error(f"Failure reason: {execution.status.failure_reason}")

# Log step failures
if execution.status.step_details:
for step in execution.status.step_details:
if "failed" in step.status.lower():
logger.error(f"Failed step: {step.name}")
if step.failure_reason:
logger.error(f" Reason: {step.failure_reason}")

# Re-raise to fail the test
raise

# Step 7: Retrieve execution by ARN
logger.info("Retrieving execution by ARN")
retrieved_execution = EvaluationPipelineExecution.get(
arn=execution.arn,
region=BASE_MODEL_ONLY_CONFIG["region"]
)

# Verify retrieved execution matches
assert retrieved_execution.arn == execution.arn
assert retrieved_execution.status.overall_status == "Succeeded"

logger.info(f"Retrieved execution status: {retrieved_execution.status.overall_status}")
logger.info("Base model only evaluation completed successfully")
Loading