In [2]:
import os
import boto3
import sagemaker
import sagemaker.session
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.model_metrics import MetricsSource, ModelMetrics
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.functions import JsonGet
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.properties import PropertyFile
from sagemaker.workflow.step_collections import RegisterModel
from sagemaker.workflow.steps import ProcessingStep, TrainingStep, CacheConfig

In [10]:
BUCKET_NAME = "aravind-aws-ml-sagemaker"


def get_session(region, default_bucket=BUCKET_NAME):
    """Build a SageMaker session bound to ``region``.

    Args:
        region: AWS region name passed to ``boto3.Session``.
        default_bucket: Bucket name handed to the SageMaker session as its
            default bucket; defaults to ``BUCKET_NAME``.

    Returns:
        A ``sagemaker.session.Session`` wrapping a boto3 session for
        ``region`` together with explicit "sagemaker" and
        "sagemaker-runtime" clients created from that boto3 session.
    """
    session = boto3.Session(region_name=region)
    return sagemaker.session.Session(
        boto_session=session,
        sagemaker_client=session.client("sagemaker"),
        sagemaker_runtime_client=session.client("sagemaker-runtime"),
        default_bucket=default_bucket,
    )

# Session pinned to us-east-1; the execution role is resolved through it.
sagemaker_session = get_session("us-east-1")
role = sagemaker.session.get_execution_role(sagemaker_session)

# Pipeline parameters controlling the compute used by the pre-processing job.
print("Setting parameters for pipeline")
processing_instance_count = ParameterInteger(
    name="ProcessingInstanceCount",
    default_value=1,
)
processing_instance_type = ParameterString(
    name="ProcessingInstanceType",
    default_value="ml.m5.xlarge",
)

Setting parameters for pipeline


In [13]:
# Inspection only: a bare expression makes the notebook display the
# "ProcessingInstanceType" pipeline parameter created earlier.
processing_instance_type

sagemaker.workflow.parameters.ParameterString

In [11]:
# Processor used by the pre-processing step (PyTorch framework version 1.8).
# instance_type / instance_count are pipeline parameters; because the image
# URI lookup cannot accept a pipeline variable, the SDK warns and falls back
# to the parameter's default_value when resolving the container image.
processor = PyTorchProcessor(
    role=role,
    framework_version="1.8",
    instance_type=processing_instance_type,
    instance_count=processing_instance_count,
    sagemaker_session=sagemaker_session,
)

The input argument instance_type of function (sagemaker.image_uris.retrieve) is a pipeline variable (<class 'sagemaker.workflow.parameters.ParameterString'>), which is not allowed. The default_value of this Parameter object will be used to override it. Please make sure the default_value is valid.


In [21]:
# Enable step caching; `expire_after` is a duration string bounding how old a
# previous execution may be and still be reused.
cache_config = CacheConfig(enable_caching=True, expire_after="T1H")

# Pre-processing step: runs `code` on `processor`, exposing one input channel
# ("raw-data") and two output channels ("train" and "val").
processing_step = ProcessingStep(
    name="pre-processing-step",
    processor=processor,
    # `inputs` must be a list of ProcessingInput objects: the SDK enumerates
    # it when building the step arguments, so a bare ProcessingInput fails
    # with a TypeError at pipeline-definition time.
    inputs=[
        ProcessingInput(
            input_name="raw-data",
            # NOTE(review): "/" and the relative "input" path look like
            # placeholders - presumably this should be an S3 URI (or local
            # data path) and a container path under /opt/ml/processing/;
            # confirm the real locations.
            source="/",
            destination="input",
        ),
    ],
    # NOTE(review): `code` points at a notebook file; the processor presumably
    # expects a runnable script - confirm the intended entry point.
    code="sample.ipynb",
    outputs=[
        # NOTE(review): output `source` is presumably meant to be a container
        # path under /opt/ml/processing/ and `destination` an S3 URI - the
        # values below look like placeholders; confirm.
        ProcessingOutput(
            output_name="train",
            source="train",
            destination="dest",
        ),
        ProcessingOutput(
            output_name="val",
            source="val",
            destination="val",
        ),
    ],
    cache_config=cache_config,
)