# Build a machine learning workflow to train a model with Amazon SageMaker and AWS Step Functions

This notebook creates an AWS Step Functions state machine that preprocesses the training data and then trains a model, using the container images previously pushed to Amazon ECR.

## Import modules

In [1]:
import uuid
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.processing import Processor, ProcessingInput, ProcessingOutput
import stepfunctions
from stepfunctions.inputs import ExecutionInput
from stepfunctions.workflow import Workflow
from stepfunctions.steps import (
    TrainingStep, 
    Chain,
    ProcessingStep,
)

## Setup

Modify the values below according to your configuration.

In [2]:
# S3 bucket under which every input and output path of the workflow lives.
bucket: str = "hermione-sagemaker"

In [3]:
# AWS region for the default boto3 session; clients created later through
# boto3's default session (e.g. the STS client below) use this region.
region_name: str = "us-east-1"
boto3.setup_default_session(region_name=region_name)

In [4]:
# IAM role of the current SageMaker user; the processor and estimator run under it.
role: str = get_execution_role()

In [5]:
# IAM role that AWS Step Functions uses to create and execute the workflow.
# Paste the ARN of the AmazonSageMaker-StepFunctionsWorkflowExecutionRole here
# before creating the workflow further below.
workflow_execution_role: str = ""

In [6]:
# SageMaker requires a unique name for every job, model and endpoint; reusing a
# name makes the execution fail. ExecutionInput declares placeholders whose
# values are supplied each time an execution starts, so every run can receive
# freshly generated job names.
execution_input = ExecutionInput(
    schema=dict(
        PreprocessingJobName=str,
        TrainingJobName=str,
    )
)

In [7]:
# AWS account ID of the current credentials, used to build the ECR image URIs.
sts_client = boto3.client("sts")
account_number = sts_client.get_caller_identity()["Account"]

In [8]:
# Name of the ECR repository holding the previously uploaded preprocessing image.
image_name_processor: str = "hermione-processor"

In [9]:
# Name of the ECR repository holding the previously uploaded training image.
image_name_train: str = "hermione-train"

In [10]:
# S3 locations read and written by the workflow, keyed by logical name.
paths = dict(
    train_raw=f"s3://{bucket}/TRAIN_RAW",
    expectations=f"s3://{bucket}/PREPROCESSING/EXPECTATIONS",
    preprocessing=f"s3://{bucket}/PREPROCESSING/PREPROCESSING",
    train_processed=f"s3://{bucket}/PREPROCESSING/TRAIN_PROCESSED",
    val_processed=f"s3://{bucket}/PREPROCESSING/VAL_PROCESSED",
    model=f"s3://{bucket}/PREPROCESSING/MODEL",
)

In [11]:
# SageMaker instance types for the preprocessing job and the training job.
instance_type_preprocessing, instance_type_train = "ml.t3.medium", "ml.m5.large"

## Preprocessing Step

In [12]:
# Full URI of the preprocessing image previously pushed to this account's ECR
# registry in the configured region.
ecr_registry = f"{account_number}.dkr.ecr.{region_name}.amazonaws.com"
image_uri_processor = "/".join([ecr_registry, image_name_processor])

In [13]:
# Processor that runs the preprocessing image on a single instance under the
# SageMaker execution role.
processor = Processor(
    role=role,
    image_uri=image_uri_processor,
    instance_type=instance_type_preprocessing,
    instance_count=1,
)

In [14]:
# I/O of the processing job: the raw training data in S3 maps to a container
# input path, and each container output directory maps to its S3 destination
# in `paths`.
inputs = [
    ProcessingInput(
        source=paths["train_raw"],
        destination="/opt/ml/processing/input/raw_data",
        input_name="raw_data",
    )
]

# (container path, key in `paths`, output name) for each processing output.
_output_specs = (
    ("/opt/ml/processing/output/expectations", "expectations", "expectations"),
    ("/opt/ml/processing/output/preprocessing", "preprocessing", "preprocessing"),
    ("/opt/ml/processing/output/processed/train", "train_processed", "train_data"),
    ("/opt/ml/processing/output/processed/val", "val_processed", "val_data"),
)
outputs = [
    ProcessingOutput(source=source, destination=paths[key], output_name=name)
    for source, key, name in _output_specs
]

In [15]:
# Step Functions task that runs the processing job. The job name comes from the
# execution input, and the arguments "--step train" are passed to the container.
processing_step = ProcessingStep(
    "Preprocessing step",
    processor=processor,
    job_name=execution_input["PreprocessingJobName"],
    container_arguments=["--step", "train"],
    inputs=inputs,
    outputs=outputs,
)

## Training Step

In [16]:
# Full URI of the training image previously pushed to this account's ECR
# registry in the configured region.
ecr_registry = f"{account_number}.dkr.ecr.{region_name}.amazonaws.com"
image_uri_train = "/".join([ecr_registry, image_name_train])

In [17]:
# Training job I/O: the processed train and validation CSV data as input
# channels, plus the S3 prefix where the trained model is written.
train_config, val_config = (
    sagemaker.inputs.TrainingInput(paths[channel], content_type="text/csv")
    for channel in ("train_processed", "val_processed")
)
output_path = paths["model"]

In [18]:
# Estimator that runs the training image from ECR on managed spot instances.
# Both the run limit and the wait limit are one day, expressed in seconds.
est = sagemaker.estimator.Estimator(
    image_uri_train,
    role,
    instance_type=instance_type_train,
    instance_count=1,
    volume_size=30,
    output_path=output_path,
    base_job_name="Hermione-Train",
    use_spot_instances=True,  # use SPOT instances
    max_run=24 * 60 * 60,
    max_wait=24 * 60 * 60,  # timeout in seconds; required if use_spot_instances == True
)

In [19]:
# Step Functions task that runs the training job, feeding the "train" and
# "validation" channels. The job name comes from the execution input.
training_channels = {"train": train_config, "validation": val_config}
training_step = TrainingStep(
    "TrainStep",
    estimator=est,
    job_name=execution_input["TrainingJobName"],
    data=training_channels,
)

## Create Workflow and Execute

In [20]:
# Terminal Fail state targeted by the error handlers below; reaching it marks
# the execution as failed with the cause "SageMakerProcessingJobFailed".
failed_state_sagemaker_processing_failure = stepfunctions.steps.states.Fail(
    "ML Workflow failed",
    cause="SageMakerProcessingJobFailed",
)

In [21]:
# Error handling: a failed task routes the execution to a Fail state whose
# cause identifies the failing job. The processing step uses the Fail state
# defined above. The training step gets its own Fail state, so a failed
# training job is not reported as "SageMakerProcessingJobFailed".
catch_state_processing = stepfunctions.steps.states.Catch(
    error_equals=["States.TaskFailed"],
    next_step=failed_state_sagemaker_processing_failure,
)

# State names must be unique within the state machine, hence the distinct name.
failed_state_sagemaker_training_failure = stepfunctions.steps.states.Fail(
    "ML Workflow training failed", cause="SageMakerTrainingJobFailed"
)
catch_state_training = stepfunctions.steps.states.Catch(
    error_equals=["States.TaskFailed"],
    next_step=failed_state_sagemaker_training_failure,
)

processing_step.add_catch(catch_state_processing)
training_step.add_catch(catch_state_training)

In [None]:
# Chain the preprocessing and training steps into one state machine and
# register it in AWS Step Functions under workflow_execution_role.
if not workflow_execution_role:
    # Fail fast with an actionable message instead of an opaque AWS error when
    # the role ARN placeholder in the Setup section was left empty.
    raise ValueError(
        "workflow_execution_role is empty: paste the ARN of the "
        "AmazonSageMaker-StepFunctionsWorkflowExecutionRole in the Setup section."
    )

workflow_graph = Chain([processing_step, training_step])
branching_workflow = Workflow(
    name="SFN_Hermione_Train",
    definition=workflow_graph,
    role=workflow_execution_role,
)
# NOTE(review): if a state machine named "SFN_Hermione_Train" already exists,
# create() appears to reuse it without overwriting its definition. Confirm this
# against the stepfunctions SDK, and call branching_workflow.update(...) if
# changes to the steps must take effect on re-runs.
branching_workflow.create()

In [23]:
# Generate unique names for the preprocessing and training jobs; each execution
# needs new names because SageMaker does not accept a job name that was
# already used. uuid4 is random, whereas uuid1 would embed this host's MAC
# address and a timestamp in names that are visible in the AWS console.
# The resulting names (55 and 50 characters) stay within SageMaker's
# 63-character job-name limit.
preprocessing_job_name = f"Hermione-Preprocessing-{uuid.uuid4().hex}"
training_job_name = f"Hermione-Training-{uuid.uuid4().hex}"

In [24]:
# Start an execution with the freshly generated job names and render its
# progress. get_output is called with wait=False, so it presumably does not
# block until the execution finishes (output may be empty while running).
execution = branching_workflow.execute(
    inputs=dict(
        PreprocessingJobName=preprocessing_job_name,
        TrainingJobName=training_job_name,
    )
)
execution_output = execution.get_output(wait=False)
execution.render_progress()