In [1]:
import stepfunctions
import sagemaker
import logging

from stepfunctions.steps import Pass, Chain, Choice, ChoiceRule, LambdaStep, Catch
import stepfunctions.steps.sagemaker as smsteps
from stepfunctions.workflow import Workflow

# Configure the stepfunctions SDK's stream logger at INFO level.
stepfunctions.set_stream_logger(level=logging.INFO)

In [2]:
from sagemaker.processing import ScriptProcessor
from sagemaker.processing import ProcessingInput, ProcessingOutput
from stepfunctions.inputs import ExecutionInput
from sagemaker.session import s3_input
from sagemaker import get_execution_role

In [3]:
# Resolve the execution role once; it is passed to the Workflow constructed further down.
role = get_execution_role()

In [4]:
# Schema for the execution input placeholder: each key names an expected
# field and maps it to the Python type of its value.
execution_input = ExecutionInput(
    schema=dict(
        experiment_name=str,
        experiment_description=str,
        lr=float,
        epochs=int,
        loss_function=str,
        input_data_path=str,
        output_model_path=str,
        code_path=str,
    )
)

In [5]:
# First state of the workflow: invoke the `start_lightfm_train_eval` Lambda,
# handing it the execution input under the "input" key of the payload.
# Later steps read their dynamic fields from this step's `Payload.body` output.
launch_workflow = LambdaStep(
    state_id="Start training and evaluation",
    timeout_seconds=60,
    result_path="$",
    parameters={
        "FunctionName": "start_lightfm_train_eval",
        "Payload": {"input": execution_input},
    },
)

In [6]:
# Processor shared by the "Train Model" and "Evaluate Model" processing steps:
# one ml.m5.xlarge instance, python3 as the command, and a 7200-second cap on
# each job's runtime.
# NOTE(review): image_uri is an empty string here — presumably a placeholder or
# a redacted value; confirm a real container image URI is supplied before use.
model_training_evaluation = ScriptProcessor(
    image_uri='',
    # Reuse the role resolved earlier in the notebook instead of issuing a
    # second get_execution_role() call; the Workflow below uses the same value.
    role=role,
    instance_count=1,
    instance_type='ml.m5.xlarge',
    command=["python3"],
    base_job_name="lightfm-training",
    max_runtime_in_seconds=7200,
)

In [7]:
# Placeholder for the `Payload.body` object of the launch Lambda's output;
# every dynamic field of the training job below is read from it.
launch_body = launch_workflow.output()["Payload"]["body"]

# Processing step that runs train.py: code and data are mounted as processing
# inputs and the contents of /opt/ml/processing/output are written to the
# output model path.
run_training_step = smsteps.ProcessingStep(
    state_id="Train Model",
    processor=model_training_evaluation,
    job_name=launch_body["training_job_name"],
    inputs=[
        ProcessingInput(
            input_name="code",
            source=launch_body["code_path"],
            destination="/opt/ml/processing/code",
        ),
        ProcessingInput(
            input_name="data",
            source=launch_body["input_data_path"],
            destination='/opt/ml/processing/input/',
        ),
    ],
    outputs=[
        ProcessingOutput(
            output_name="model_output",
            source='/opt/ml/processing/output',
            destination=launch_body["output_model_path"],
        ),
    ],
    container_entrypoint=["python3", "/opt/ml/processing/code/train.py"],
    container_arguments=launch_body["script_args"],
    wait_for_completion=True,
    # The step's result is stored under its own key, `$.trainresult`.
    result_path="$.trainresult",
)

In [8]:
# Placeholder for `Payload.body` as seen from the training step's output.
# NOTE(review): the training step stores its own result under `$.trainresult`,
# so these `Payload.body.*` fields presumably still resolve against the launch
# Lambda's payload that is carried through the state data — confirm.
eval_body = run_training_step.output()["Payload"]["body"]

# Processing step that runs evaluate.py: code, data and the trained model
# artifact are mounted as processing inputs.
run_evaluation_step = smsteps.ProcessingStep(
    state_id="Evaluate Model",
    processor=model_training_evaluation,
    job_name=eval_body["eval_job_name"],
    inputs=[
        ProcessingInput(
            input_name="code",
            source=eval_body["code_path"],
            destination="/opt/ml/processing/code",
        ),
        ProcessingInput(
            input_name="data",
            source=eval_body["input_data_path"],
            destination='/opt/ml/processing/input/',
        ),
        ProcessingInput(
            input_name="model_artifact",
            source=eval_body["output_model_path"],
            destination='/opt/ml/processing/model/',
        ),
    ],
    outputs=[
        # The evaluation output goes to the same `output_model_path` that the
        # model artifact input above is read from.
        ProcessingOutput(
            output_name="model_output",
            source='/opt/ml/processing/output',
            destination=eval_body["output_model_path"],
        ),
    ],
    container_entrypoint=["python3", "/opt/ml/processing/code/evaluate.py"],
    container_arguments=eval_body["script_args"],
    wait_for_completion=True,
    # The step's result is stored under its own key, `$.evalresult`.
    result_path="$.evalresult",
)

In [9]:
# Placeholder for `Payload.body` as seen from the evaluation step's output.
results_body = run_evaluation_step.output()["Payload"]["body"]

# Final state: invoke the `lightfm-results` Lambda with the experiment and
# trial names so it can publish results and send the notification.
send_notification = LambdaStep(
    state_id="Publish results and send notification",
    timeout_seconds=60,
    parameters={
        "FunctionName": "lightfm-results",
        "Payload": {
            "input": {
                "experiment_name": results_body["experiment_name"],
                "trial_name": results_body["trial_name"],
            }
        },
    },
)

In [10]:
# Link the four states into one linear sequence:
# launch -> train -> evaluate -> notify.
train_eval_path = Chain(
    [
        launch_workflow,
        run_training_step,
        run_evaluation_step,
        send_notification,
    ]
)

In [11]:
# Wrap the chained states in a Workflow named "lightfm-workflow", passing the
# execution role resolved earlier in the notebook.
# NOTE(review): the variable is called `inference_workflow`, but the definition
# it holds is the train/evaluate chain (`train_eval_path`).
inference_workflow = Workflow(
    name="lightfm-workflow",
    definition=train_eval_path,
    role=role
)

In [12]:
# Print the workflow definition as pretty-printed JSON for inspection.
print(inference_workflow.definition.to_json(pretty=True))

{
    "StartAt": "Start training and evaluation",
    "States": {
        "Start training and evaluation": {
            "TimeoutSeconds": 60,
            "ResultPath": "$",
            "Parameters": {
                "FunctionName": "start_lightfm_train_eval",
                "Payload": {
                    "input.$": "$$.Execution.Input"
                }
            },
            "Resource": "arn:aws:states:::lambda:invoke",
            "Type": "Task",
            "Next": "Train Model"
        },
        "Train Model": {
            "ResultPath": "$.trainresult",
            "Resource": "arn:aws:states:::sagemaker:createProcessingJob.sync",
            "Parameters": {
                "ProcessingJobName.$": "$['Payload']['body']['training_job_name']",
                "ProcessingInputs": [
                    {
                        "InputName": "code",
                        "S3Input": {
                            "S3Uri.$": "$['Payload']['body']['code_path']",
                