diff --git a/src/sagemaker/experiments/_environment.py b/src/sagemaker/experiments/_environment.py index 441661ae5a..4006742875 100644 --- a/src/sagemaker/experiments/_environment.py +++ b/src/sagemaker/experiments/_environment.py @@ -18,12 +18,13 @@ import logging import os +from sagemaker import Session from sagemaker.experiments import trial_component from sagemaker.utils import retry_with_backoff TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN" PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json" -TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH" +TRANSFORM_JOB_ARN_ENV = "TRANSFORM_JOB_ARN" MAX_RETRY_ATTEMPTS = 7 logger = logging.getLogger(__name__) @@ -40,7 +41,7 @@ class _EnvironmentType(enum.Enum): class _RunEnvironment(object): """Retrieves job specific data from the environment.""" - def __init__(self, environment_type, source_arn): + def __init__(self, environment_type: _EnvironmentType, source_arn: str): """Init for _RunEnvironment. Args: @@ -53,9 +54,9 @@ def __init__(self, environment_type, source_arn): @classmethod def load( cls, - training_job_arn_env=TRAINING_JOB_ARN_ENV, - processing_job_config_path=PROCESSING_JOB_CONFIG_PATH, - transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR, + training_job_arn_env: str = TRAINING_JOB_ARN_ENV, + processing_job_config_path: str = PROCESSING_JOB_CONFIG_PATH, + transform_job_arn_env: str = TRANSFORM_JOB_ARN_ENV, ): """Loads source arn of current job from environment. @@ -64,8 +65,8 @@ def load( (default: `TRAINING_JOB_ARN`). processing_job_config_path (str): The processing job config path (default: `/opt/ml/config/processingjobconfig.json`). - transform_job_batch_var (str): The environment variable indicating if - it is a transform job (default: `SAGEMAKER_BATCH`). + transform_job_arn_env (str): The environment key for transform job ARN + (default: `TRANSFORM_JOB_ARN`). Returns: _RunEnvironment: Job data loaded from the environment. None if config does not exist. 
@@ -78,16 +79,15 @@ def load( environment_type = _EnvironmentType.SageMakerProcessingJob source_arn = json.loads(open(processing_job_config_path).read())["ProcessingJobArn"] return _RunEnvironment(environment_type, source_arn) - if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true": + if transform_job_arn_env in os.environ: environment_type = _EnvironmentType.SageMakerTransformJob - # TODO: need to figure out how to get source_arn from job env - # with Transform team's help. - source_arn = "" + # TODO: need to update to get source_arn from config file once Transform side ready + source_arn = os.environ.get(transform_job_arn_env) return _RunEnvironment(environment_type, source_arn) return None - def get_trial_component(self, sagemaker_session): + def get_trial_component(self, sagemaker_session: Session): """Retrieves the trial component from the job in the environment. Args: @@ -99,14 +99,6 @@ def get_trial_component(self, sagemaker_session): Returns: _TrialComponent: The trial component created from the job. None if not found. """ - # TODO: Remove this condition check once we have a way to retrieve source ARN - # from transform job env - if self.environment_type == _EnvironmentType.SageMakerTransformJob: - logger.error( - "Currently getting the job trial component from the transform job environment " - "is not supported. Returning None." 
- ) - return None def _get_trial_component(): summaries = list( diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py index f80c43f337..31dd679cc8 100644 --- a/src/sagemaker/experiments/_metrics.py +++ b/src/sagemaker/experiments/_metrics.py @@ -14,7 +14,6 @@ from __future__ import absolute_import import datetime -import json import logging import os import time @@ -35,85 +34,6 @@ logger = logging.getLogger(__name__) -# TODO: remove this _SageMakerFileMetricsWriter class -# when _MetricsManager is fully ready -class _SageMakerFileMetricsWriter(object): - """Write metric data to file.""" - - def __init__(self, metrics_file_path=None): - """Construct a `_SageMakerFileMetricsWriter` object""" - self._metrics_file_path = metrics_file_path - self._file = None - self._closed = False - - def log_metric(self, metric_name, value, timestamp=None, step=None): - """Write a metric to file. - - Args: - metric_name (str): The name of the metric. - value (float): The value of the metric. - timestamp (datetime.datetime): Timestamp of the metric. - If not specified, the current UTC time will be used. - step (int): Iteration number of the metric (default: None). - - Raises: - SageMakerMetricsWriterException: If the metrics file is closed. - AttributeError: If file has been initialized and the writer hasn't been closed. 
- """ - raw_metric_data = _RawMetricData( - metric_name=metric_name, value=value, timestamp=timestamp, step=step - ) - try: - logger.debug("Writing metric: %s", raw_metric_data) - self._file.write(json.dumps(raw_metric_data.to_record())) - self._file.write("\n") - except AttributeError as attr_err: - if self._closed: - raise SageMakerMetricsWriterException("log_metric called on a closed writer") - if not self._file: - self._file = open(self._get_metrics_file_path(), "a", buffering=1) - self._file.write(json.dumps(raw_metric_data.to_record())) - self._file.write("\n") - else: - raise attr_err - - def close(self): - """Closes the metric file.""" - if not self._closed and self._file: - self._file.close() - self._file = None # invalidate reference, causing subsequent log_metric to fail. - self._closed = True - - def __enter__(self): - """Return self""" - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - """Execute self.close()""" - self.close() - - def __del__(self): - """Execute self.close()""" - self.close() - - def _get_metrics_file_path(self): - """Get file path to store metrics""" - pid_filename = "{}.json".format(str(os.getpid())) - metrics_file_path = self._metrics_file_path or os.path.join(METRICS_DIR, pid_filename) - logger.debug("metrics_file_path = %s", metrics_file_path) - return metrics_file_path - - -class SageMakerMetricsWriterException(Exception): - """SageMakerMetricsWriterException""" - - def __init__(self, message, errors=None): - """Construct a `SageMakerMetricsWriterException` instance""" - super().__init__(message) - if errors: - self.errors = errors - - class _RawMetricData(object): """A Raw Metric Data Object""" diff --git a/src/sagemaker/experiments/_utils.py b/src/sagemaker/experiments/_utils.py index 5ef5d99dad..d1df535335 100644 --- a/src/sagemaker/experiments/_utils.py +++ b/src/sagemaker/experiments/_utils.py @@ -127,11 +127,9 @@ def get_tc_and_exp_config_from_job_env( num_attempts=4, ) else: # 
environment.environment_type == _EnvironmentType.SageMakerTransformJob - raise RuntimeError( - "Failed to load the Run as loading experiment config " - "from transform job environment is not currently supported. " - "As a workaround, please explicitly pass in " - "the experiment_name and run_name in load_run." + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_transform_job(job_name), + num_attempts=4, ) job_exp_config = job_response.get("ExperimentConfig", dict()) diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py index 69e06419f2..fc87ab7804 100644 --- a/src/sagemaker/experiments/run.py +++ b/src/sagemaker/experiments/run.py @@ -120,19 +120,18 @@ def __init__( estimator.fit(job_name="my-job") # Create a training job In order to reuse an existing run to log extra data, ``load_run`` is recommended. + For example, instead of the ``Run`` constructor, the ``load_run`` is recommended to use + in a job script to load the existing run created before the job launch. + Otherwise, a new run may be created each time you launch a job. + The code snippet below displays how to load the run initialized above in a custom training job script, where no ``run_name`` or ``experiment_name`` is presented as they are automatically retrieved from the experiment config in the job environment. - Note: - Instead of the ``Run`` constructor, the ``load_run`` is recommended to use - in a job script to load the existing run created before the job launch. - Otherwise, a new run may be created each time you launch a job. - .. code:: python - with load_run() as run: + with load_run(sagemaker_session=sagemaker_session) as run: run.log_metric(...) ... 
diff --git a/tests/data/experiment/inference.py b/tests/data/experiment/inference.py index cdb9a7b8c6..edfbd013ce 100644 --- a/tests/data/experiment/inference.py +++ b/tests/data/experiment/inference.py @@ -46,6 +46,9 @@ def model_fn(model_dir): run.log_parameters({"p3": 3.0, "p4": 4.0}) run.log_metric("test-job-load-log-metric", 0.1) + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameters({"p5": 5.0, "p6": 6}) + model_file = "xgboost-model" booster = pkl.load(open(os.path.join(model_dir, model_file), "rb")) return booster diff --git a/tests/integ/sagemaker/experiments/test_metrics.py b/tests/integ/sagemaker/experiments/test_metrics.py index 15c0c2f9dc..e621a4d727 100644 --- a/tests/integ/sagemaker/experiments/test_metrics.py +++ b/tests/integ/sagemaker/experiments/test_metrics.py @@ -30,8 +30,7 @@ def verify_metrics(): sagemaker_session=sagemaker_session, ) metrics = updated_tc.metrics - # TODO: revert to len(metrics) == 2 once backend fix reaches prod - assert len(metrics) > 0 + assert len(metrics) == 2 assert list(filter(lambda x: x.metric_name == "test-x-step", metrics)) assert list(filter(lambda x: x.metric_name == "test-x-timestamp", metrics)) diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py index 96aad30dc0..d57a4f63b9 100644 --- a/tests/integ/sagemaker/experiments/test_run.py +++ b/tests/integ/sagemaker/experiments/test_run.py @@ -444,10 +444,9 @@ def test_run_from_processing_job_and_override_default_exp_config( def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, xgboost_latest_version): # Notes: # 1. The 1st Run (run) created locally - # 2. In the inference script running in a transform job, load the 1st Run - # via explicitly passing the experiment_name and run_name of the 1st Run - # TODO: once we're able to retrieve exp config from the transform job env, - # we should expand this test and add the load_run() without explicitly supplying the names + # 2. 
In the inference script running in a transform job, load the 1st Run twice and log data + # 1) via explicitly passing the experiment_name and run_name of the 1st Run + # 2) use load_run() without explicitly supplying the names # 3. All data are logged in the Run either locally or in the transform job exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) xgb_model_data_s3 = sagemaker_session.upload_data( @@ -494,6 +493,7 @@ def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, xgboost_latest_v content_type="text/libsvm", split_type="Line", wait=True, + logs=False, job_name=f"transform-job-{name()}", ) @@ -506,7 +506,7 @@ def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, xgboost_latest_v experiment_name=run.experiment_name, run_name=run.run_name ) _check_run_from_job_result( - tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False, has_extra_load=True ) @@ -636,8 +636,7 @@ def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True assert "s3://Input" == tc.input_artifacts[artifact_name].value assert not tc.input_artifacts[artifact_name].media_type - # TODO: revert to len(tc.metrics) == 1 once backend fix reaches prod - assert len(tc.metrics) > 0 + assert len(tc.metrics) == 1 metric_summary = tc.metrics[0] assert metric_summary.metric_name == metric_name assert metric_summary.max == 9.0 @@ -651,9 +650,7 @@ def validate_tc_updated_in_init(): assert tc.status.primary_status == _TrialComponentStatusType.Completed.value assert tc.parameters["p1"] == 1.0 assert tc.parameters["p2"] == 2.0 - # TODO: revert to assert len(tc.metrics) == 5 once - # backend fix hits prod - assert len(tc.metrics) > 0 + assert len(tc.metrics) == 5 for metric_summary in tc.metrics: # metrics deletion is not supported at this point # so its count would accumulate diff --git a/tests/unit/sagemaker/experiments/test_environment.py 
b/tests/unit/sagemaker/experiments/test_environment.py index 8bb23db7b6..effca7a7f7 100644 --- a/tests/unit/sagemaker/experiments/test_environment.py +++ b/tests/unit/sagemaker/experiments/test_environment.py @@ -21,6 +21,7 @@ import pytest from sagemaker.experiments import _environment +from sagemaker.experiments._environment import TRANSFORM_JOB_ARN_ENV, TRAINING_JOB_ARN_ENV from sagemaker.utils import retry_with_backoff @@ -33,22 +34,22 @@ def tempdir(): @pytest.fixture def training_job_env(): - old_value = os.environ.get("TRAINING_JOB_ARN") - os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe" + old_value = os.environ.get(TRAINING_JOB_ARN_ENV) + os.environ[TRAINING_JOB_ARN_ENV] = "arn:1234aBcDe" yield os.environ - del os.environ["TRAINING_JOB_ARN"] + del os.environ[TRAINING_JOB_ARN_ENV] if old_value: - os.environ["TRAINING_JOB_ARN"] = old_value + os.environ[TRAINING_JOB_ARN_ENV] = old_value @pytest.fixture def transform_job_env(): - old_value = os.environ.get("SAGEMAKER_BATCH") - os.environ["SAGEMAKER_BATCH"] = "true" + old_value = os.environ.get(TRANSFORM_JOB_ARN_ENV) + os.environ[TRANSFORM_JOB_ARN_ENV] = "arn:1234aBcDe" yield os.environ - del os.environ["SAGEMAKER_BATCH"] + del os.environ[TRANSFORM_JOB_ARN_ENV] if old_value: - os.environ["SAGEMAKER_BATCH"] = old_value + os.environ[TRANSFORM_JOB_ARN_ENV] = old_value def test_processing_job_environment(tempdir): @@ -70,8 +71,7 @@ def test_training_job_environment(training_job_env): def test_transform_job_environment(transform_job_env): environment = _environment._RunEnvironment.load() assert _environment._EnvironmentType.SageMakerTransformJob == environment.environment_type - # TODO: update if we figure out how to get source_arn from the transform job - assert not environment.source_arn + assert "arn:1234aBcDe" == environment.source_arn def test_no_environment(): @@ -100,8 +100,3 @@ def test_resolve_trial_component_fails(mock_retry, sagemaker_session, training_j client.list_trial_components.side_effect = 
Exception("Failed test") environment = _environment._RunEnvironment.load() assert environment.get_trial_component(sagemaker_session) is None - - -def test_resolve_transform_job_trial_component_fail(transform_job_env, sagemaker_session): - environment = _environment._RunEnvironment.load() - assert environment.get_trial_component(sagemaker_session) is None diff --git a/tests/unit/sagemaker/experiments/test_metrics.py b/tests/unit/sagemaker/experiments/test_metrics.py index 21556f70fd..969085971a 100644 --- a/tests/unit/sagemaker/experiments/test_metrics.py +++ b/tests/unit/sagemaker/experiments/test_metrics.py @@ -18,14 +18,9 @@ import shutil import datetime import dateutil -import json import time -from sagemaker.experiments._metrics import ( - _RawMetricData, - _SageMakerFileMetricsWriter, - SageMakerMetricsWriterException, -) +from sagemaker.experiments._metrics import _RawMetricData @pytest.fixture @@ -104,75 +99,3 @@ def test_raw_metric_data_invalid_timestamp(): with pytest.raises(ValueError) as error2: _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() + 10000) assert "Timestamps must be between two weeks before and two hours from now" in str(error2) - - -def test_file_metrics_writer_log_metric(timestamp, filepath): - now = datetime.datetime.now(datetime.timezone.utc) - writer = _SageMakerFileMetricsWriter(filepath) - writer.log_metric(metric_name="foo", value=1.0) - writer.log_metric(metric_name="foo", value=2.0, step=1) - writer.log_metric(metric_name="foo", value=3.0, timestamp=timestamp) - writer.log_metric(metric_name="foo", value=4.0, timestamp=timestamp, step=2) - writer.close() - - lines = [x for x in open(filepath).read().split("\n") if x] - [entry_one, entry_two, entry_three, entry_four] = [json.loads(line) for line in lines] - - assert "foo" == entry_one["MetricName"] - assert 1.0 == entry_one["Value"] - assert (now.timestamp() - entry_one["Timestamp"]) < 1 - assert "Step" not in entry_one - - assert 1 == entry_two["Step"] - assert 
timestamp.timestamp() == entry_three["Timestamp"] - assert 2 == entry_four["Step"] - - -def test_file_metrics_writer_flushes_buffer_every_line_log_metric(filepath): - writer = _SageMakerFileMetricsWriter(filepath) - - writer.log_metric(metric_name="foo", value=1.0) - - lines = [x for x in open(filepath).read().split("\n") if x] - [entry_one] = [json.loads(line) for line in lines] - assert "foo" == entry_one["MetricName"] - assert 1.0 == entry_one["Value"] - - writer.log_metric(metric_name="bar", value=2.0) - lines = [x for x in open(filepath).read().split("\n") if x] - [entry_one, entry_two] = [json.loads(line) for line in lines] - assert "bar" == entry_two["MetricName"] - assert 2.0 == entry_two["Value"] - - writer.log_metric(metric_name="biz", value=3.0) - lines = [x for x in open(filepath).read().split("\n") if x] - [entry_one, entry_two, entry_three] = [json.loads(line) for line in lines] - assert "biz" == entry_three["MetricName"] - assert 3.0 == entry_three["Value"] - - writer.close() - - -def test_file_metrics_writer_context_manager(timestamp, filepath): - with _SageMakerFileMetricsWriter(filepath) as writer: - writer.log_metric("foo", value=1.0, timestamp=timestamp) - entry = json.loads(open(filepath, "r").read().strip()) - assert { - "MetricName": "foo", - "Value": 1.0, - "Timestamp": timestamp.timestamp(), - }.items() <= entry.items() - - -def test_file_metrics_writer_fail_write_on_close(filepath): - writer = _SageMakerFileMetricsWriter(filepath) - writer.log_metric(metric_name="foo", value=1.0) - writer.close() - with pytest.raises(SageMakerMetricsWriterException): - writer.log_metric(metric_name="foo", value=1.0) - - -def test_file_metrics_writer_no_write(filepath): - writer = _SageMakerFileMetricsWriter(filepath) - writer.close() - assert not os.path.exists(filepath) diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py index 0e4ebee181..b41aee75d5 100644 --- 
a/tests/unit/sagemaker/experiments/test_run.py +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -299,21 +299,45 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session): client.describe_processing_job.assert_called_once_with(ProcessingJobName=job_name) +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +@patch.object(_TrialComponent, "save", MagicMock(return_value=None)) @patch("sagemaker.experiments.run._RunEnvironment") def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session): - # TODO: update this test once figure out how to get source_arn from transform job + client = sagemaker_session.sagemaker_client + job_name = "my-transform-job" rv = unittest.mock.Mock() + rv.source_arn = f"arn:1234/{job_name}" rv.environment_type = _environment._EnvironmentType.SageMakerTransformJob - rv.source_arn = "" mock_run_env.load.return_value = rv - with pytest.raises(RuntimeError) as err: - with load_run(sagemaker_session=sagemaker_session): - pass + expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}" + exp_config = { + EXPERIMENT_NAME: TEST_EXP_NAME, + TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME), + RUN_NAME: expected_tc_name, + } + client.describe_transform_job.return_value = { + "TransformJobName": "transform-job-experiments", + # The Run object has been created else where + "ExperimentConfig": exp_config, + } - assert ( - "loading experiment config from transform job environment is not currently supported" - ) in str(err) + with load_run(sagemaker_session=sagemaker_session): + pass + + 
client.describe_transform_job.assert_called_once_with(TransformJobName=job_name) def test_log_parameter_outside_run_context(run_obj):