Skip to content

Commit

Permalink
fix: Move sagemaker pysdk version check after bootstrap in remote job
Browse files Browse the repository at this point in the history
  • Loading branch information
qidewenwhen committed Mar 8, 2024
1 parent 615a8ad commit b97ebb5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def main(sys_args=None):
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")

RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
client_sagemaker_pysdk_version
)

user = getpass.getuser()
if user != "root":
Expand All @@ -89,6 +86,10 @@ def main(sys_args=None):
client_python_version, conda_env, dependency_settings
)

RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
client_sagemaker_pysdk_version
)

exit_code = SUCCESS_EXIT_CODE
except Exception as e: # pylint: disable=broad-except
logger.exception("Error encountered while bootstrapping runtime environment: %s", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import dataclasses
import json

import sagemaker


class _UTCFormatter(logging.Formatter):
"""Class that overrides the default local time provider in log formatter."""
Expand Down Expand Up @@ -330,6 +328,7 @@ def _current_python_version(self):

def _current_sagemaker_pysdk_version(self):
"""Returns the current sagemaker python sdk version where program is running"""
import sagemaker

return sagemaker.__version__

Expand Down Expand Up @@ -366,10 +365,10 @@ def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
):
logger.warning(
"Inconsistent sagemaker versions found: "
"sagemaker pysdk version found in the container is "
"sagemaker python sdk version found in the container is "
"'%s' which does not match the '%s' on the local client. "
"Please make sure that the python version used in the training container "
"is the same as the local python version in case of unexpected behaviors.",
"Please make sure that the sagemaker version used in the training container "
"is the same as the local sagemaker version in case of unexpected behaviors.",
job_sagemaker_pysdk_version,
client_sagemaker_pysdk_version,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,6 @@ def test_main_success_pipeline_step_with_root_user(
_exit_process.assert_called_with(0)


@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
Expand Down Expand Up @@ -260,7 +256,6 @@ def test_main_failure_remote_job_with_root_user(
write_failure,
_exit_process,
validate_python,
validate_sagemaker,
):
runtime_err = RuntimeEnvironmentError("some failure reason")
bootstrap_runtime.side_effect = runtime_err
Expand All @@ -269,17 +264,12 @@ def test_main_failure_remote_job_with_root_user(

change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
run_pre_exec_script.assert_not_called()
bootstrap_runtime.assert_called()
write_failure.assert_called_with(str(runtime_err))
_exit_process.assert_called_with(1)


@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
Expand Down Expand Up @@ -308,7 +298,6 @@ def test_main_failure_pipeline_step_with_root_user(
write_failure,
_exit_process,
validate_python,
validate_sagemaker,
):
runtime_err = RuntimeEnvironmentError("some failure reason")
bootstrap_runtime.side_effect = runtime_err
Expand All @@ -317,7 +306,6 @@ def test_main_failure_pipeline_step_with_root_user(

change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
run_pre_exec_script.assert_not_called()
bootstrap_runtime.assert_called()
write_failure.assert_called_with(str(runtime_err))
Expand Down

0 comments on commit b97ebb5

Please sign in to comment.