Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
PROCESSING_JOB_ROLE_ARN_PATH,
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
)
from sagemaker.fw_utils import UploadedCode
from sagemaker.job import _Job
from sagemaker.local import LocalSession
from sagemaker.network import NetworkConfig
Expand Down Expand Up @@ -298,7 +299,8 @@ def _normalize_args(
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
code (str): This can be an S3 URI or a local path to a file with the framework
code (str or :class:`~sagemaker.fw_utils.UploadedCode`): This can be
an S3 URI or a local path to a file with the framework
script to run (default: None). A no op in the base class.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
Expand Down Expand Up @@ -615,7 +617,7 @@ def get_run_args(
@runnable_by_pipeline
def run(
self,
code: str,
code: Union[str, "UploadedCode"],
inputs: Optional[List["ProcessingInput"]] = None,
outputs: Optional[List["ProcessingOutput"]] = None,
arguments: Optional[List[Union[str, PipelineVariable]]] = None,
Expand All @@ -628,8 +630,8 @@ def run(
"""Runs a processing job.

Args:
code (str): This can be an S3 URI or a local path to
a file with the framework script to run.
code (str or :class:`~sagemaker.fw_utils.UploadedCode`): This can be
an S3 URI or a local path to a file with the framework script to run.
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
Expand Down Expand Up @@ -683,7 +685,9 @@ def run(
if wait:
self.latest_job.wait(logs=logs)

def _include_code_in_inputs(self, inputs, code, kms_key=None):
def _include_code_in_inputs(
self, inputs: List["ProcessingInput"], code: Union[str, UploadedCode], kms_key=None
):
"""Converts code to appropriate input and includes in input list.

Side effects include:
Expand All @@ -694,7 +698,7 @@ def _include_code_in_inputs(self, inputs, code, kms_key=None):
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects.
code (str): This can be an S3 URI or a local path to a file with the framework
code (str or UploadedCode): This can be an S3 URI or a local path to a file
with the framework script to run (default: None).
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
Expand All @@ -703,10 +707,15 @@ def _include_code_in_inputs(self, inputs, code, kms_key=None):
list[:class:`~sagemaker.processing.ProcessingInput`]: inputs together with the
code as `ProcessingInput`.
"""
user_code_s3_uri = self._handle_user_code_url(code, kms_key)
user_script_name = self._get_user_code_name(code)

inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri)
if isinstance(code, UploadedCode):
# If the caller supplied an UploadedCode, then that indicates the
# code has already been uploaded (presumably via a ProcessingInput)
inputs_with_code = inputs or []
user_script_name = code.script_name
else:
user_code_s3_uri = self._handle_user_code_url(code, kms_key)
user_script_name = self._get_user_code_name(code)
inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri)

self._set_entrypoint(self.command, user_script_name)
return inputs_with_code
Expand Down Expand Up @@ -1600,7 +1609,7 @@ def get_run_args(
code (str): This can be an S3 URI or a local path to a file with the framework
script to run. See the ``code`` argument in
`sagemaker.processing.FrameworkProcessor.run()`.
source_dir (str): Path (absolute, relative, or an S3 URI) to a directory wit
source_dir (str): Path (absolute, relative, or an S3 URI) to a directory with
any other processing source code dependencies aside from the entrypoint
file (default: None). See the ``source_dir`` argument in
`sagemaker.processing.FrameworkProcessor.run()`
Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/workflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from sagemaker import Session
from sagemaker.estimator import EstimatorBase, _TrainingJob
from sagemaker.fw_utils import UploadedCode
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput
from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel
Expand Down Expand Up @@ -765,7 +766,7 @@ def __init__(
inputs: List[ProcessingInput] = None,
outputs: List[ProcessingOutput] = None,
job_arguments: List[str] = None,
code: str = None,
code: Union[str, UploadedCode] = None,
property_files: List[PropertyFile] = None,
cache_config: CacheConfig = None,
depends_on: Optional[List[Union[str, Step, "StepCollection"]]] = None,
Expand All @@ -789,7 +790,7 @@ def __init__(
instances. Defaults to `None`.
job_arguments (List[str]): A list of strings to be passed into the processing job.
Defaults to `None`.
code (str): This can be an S3 URI or a local path to a file with the framework
code (str or UploadedCode): This can be an S3 URI or a local path to a file
with the framework script to run. Defaults to `None`.
property_files (List[PropertyFile]): A list of property files that workflow looks
for and resolves from the configured processing output list.
Expand Down Expand Up @@ -835,7 +836,7 @@ def __init__(
# arguments attribute. Refactor `Processor.run`, if possible.
self.processor.arguments = job_arguments

if code:
if code and not isinstance(code, UploadedCode):
if is_pipeline_variable(code):
raise ValueError(
"code argument has to be a valid S3 URI or local file path "
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/workflow/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from contextlib import contextmanager
from _hashlib import HASH as Hash

from sagemaker.fw_utils import UploadedCode
from sagemaker.workflow.parameters import Parameter
from sagemaker.workflow.pipeline_context import _StepArguments, _PipelineConfig
from sagemaker.workflow.entities import (
Expand Down Expand Up @@ -168,19 +169,18 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
Returns:
str: A hash string representing the unique code artifact(s) for the step
"""

# FrameworkProcessor
if source_dir:
source_dir_url = urlparse(source_dir)
if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
# Include code in the hash when possible
if code:
if code and not isinstance(code, UploadedCode):
code_url = urlparse(code)
if code_url.scheme == "" or code_url.scheme == "file":
return hash_files_or_dirs([code, source_dir] + dependencies)
return hash_files_or_dirs([source_dir] + dependencies)
# Other Processors - Spark, Script, Base, etc.
if code:
if code and not isinstance(code, UploadedCode):
code_url = urlparse(code)
if code_url.scheme == "" or code_url.scheme == "file":
return hash_files_or_dirs([code] + dependencies)
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/sagemaker/workflow/test_processing_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from copy import deepcopy

from sagemaker.estimator import Estimator
from sagemaker.fw_utils import UploadedCode
from sagemaker.parameter import IntegerParameter
from sagemaker.transformer import Transformer
from sagemaker.tuner import HyperparameterTuner
Expand Down Expand Up @@ -70,6 +71,10 @@

DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "workflow/abalone/preprocessing.py")
UPLOADED_SCRIPT_INPUT = UploadedCode(
s3_prefix=DUMMY_S3_SCRIPT_PATH,
script_name="/opt/ml/processing/code/script_from_s3.py",
)
SPARK_APP_JAR_PATH = os.path.join(
DATA_DIR, "spark/code/java/hello-java-spark/HelloJavaSparkApp.jar"
)
Expand Down Expand Up @@ -379,7 +384,9 @@ def test_processing_step_with_processor_and_step_args(


@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
@pytest.mark.parametrize("code_artifact", [DUMMY_S3_SCRIPT_PATH, LOCAL_SCRIPT_PATH])
@pytest.mark.parametrize(
"code_artifact", [DUMMY_S3_SCRIPT_PATH, LOCAL_SCRIPT_PATH, UPLOADED_SCRIPT_INPUT]
)
def test_processing_step_with_script_processor(
pipeline_session, processing_input, network_config, code_artifact
):
Expand Down