In [1]:
import os
from uuid import uuid4
import json
from utils import DecimalEncoder
import boto3
import stepfunctions
from stepfunctions import steps
from stepfunctions.inputs import ExecutionInput
from stepfunctions.template import TrainingPipeline
from stepfunctions.template.utils import replace_parameters_with_jsonpath
from stepfunctions.workflow import Workflow
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.session import Session

from dotenv import load_dotenv

In [2]:
# Load project settings from the .env file three directories up. Later cells read
# BUCKET, DB_NAME, SAGEMAKER_ARN, ... from os.environ (presumably supplied by this file).
# NOTE: the path is resolved relative to the kernel's current working directory.
DOTENV_PATH = "../../../.env"
if not load_dotenv(DOTENV_PATH):
    # An explicit raise (unlike `assert`) survives `python -O` and explains the failure.
    raise RuntimeError(
        f"load_dotenv found no variables at {DOTENV_PATH!r} (missing or empty file?); "
        "check the path relative to the current working directory."
    )

In [None]:
# BUCKET_ARTIFACT = os.environ["ARTIFACT_BUCKET"]

In [3]:
# Project configuration and AWS handles shared by the cells below.
BUCKET = os.environ["BUCKET"]
SESSION = boto3.session.Session()
DYNAMODB_RESOURCE = SESSION.resource("dynamodb")
# NOTE(review): not referenced elsewhere in this notebook; kept for interactive inspection.
DYNAMO_TABLE = DYNAMODB_RESOURCE.Table(os.environ["DB_NAME"])
# Built from SESSION (not boto3's implicit default session) so every AWS handle in the
# notebook shares one configuration.
SFN_CLIENT = SESSION.client("stepfunctions")
# Role assumed by the SageMaker processing job.
EXECUTION_ROLE = os.environ["SAGEMAKER_ARN"]
# Role assumed by the Step Functions state machine itself.
WORKFLOW_EXECUTION_ROLE = os.environ["STEP_FUNCTION_ARN"]

In [4]:
# Keys a caller supplies when starting an execution. Indexing `execution_input[...]`
# below yields placeholders that are resolved from the execution's input at run time.
execution_input = ExecutionInput(
    schema=dict(
        SOURCE_TO_TRANSLATE=str,    # mounted at /opt/ml/processing/input
        SOURCE_MODEL_ARTIFACT=str,  # mounted at /opt/ml/processing/input/model
        DESTINATION_OUTPUT=str,     # where /opt/ml/processing/output is uploaded
        ProcessingJobName=str,      # SageMaker processing job name
        input_code=str,             # mounted at /opt/ml/processing/code (holds main.py)
        job_pk=str,
        job_sk=str,
    )
)

In [5]:
# Container input channels: (execution-input key, mount point in the container, channel name).
input_meta = [
    ProcessingInput(
        source=execution_input[input_key],
        destination=mount_point,
        input_name=channel,
    )
    for input_key, mount_point, channel in (
        ("SOURCE_TO_TRANSLATE", "/opt/ml/processing/input", "input"),
        ("SOURCE_MODEL_ARTIFACT", "/opt/ml/processing/input/model", "model"),
        # Entry-point code; the processing step runs main.py from this mount.
        ("input_code", "/opt/ml/processing/code", "code"),
    )
]

# Everything the container writes to /opt/ml/processing/output is uploaded to the
# destination supplied in the execution input.
output_meta = [
    ProcessingOutput(
        output_name="output",
        source="/opt/ml/processing/output",
        destination=execution_input["DESTINATION_OUTPUT"],
    ),
]

In [6]:
def get_processing_container_config(
    instance_type: str = "ml.t3.medium",
    instance_count: int = 1,
    framework_version: str = "0.20.0",
):
    """Build the scikit-learn processor that runs the translation container.

    Called with no arguments this returns exactly the original configuration;
    the parameters let callers resize the job without editing the function.

    Args:
        instance_type: SageMaker instance type for the processing job.
        instance_count: number of instances to launch.
        framework_version: scikit-learn container version.

    Returns:
        A ``sagemaker.sklearn.processing.SKLearnProcessor`` that runs as
        ``EXECUTION_ROLE``. Its session's default bucket is the ``BATCH_JOB_BUCKET``
        environment variable.
    """
    # Use the batch-job bucket as the session's default bucket instead of the SDK default.
    session = Session(default_bucket=os.environ["BATCH_JOB_BUCKET"])
    return SKLearnProcessor(
        framework_version=framework_version,
        role=EXECUTION_ROLE,
        instance_type=instance_type,
        instance_count=instance_count,
        sagemaker_session=session,
    )

In [7]:
# SageMaker processing step: runs main.py from the "code" input channel, passing it
# --bucket and --file on the command line.
processing_step = steps.ProcessingStep(
    "SageMakerTranslationJob",
    processor=get_processing_container_config(),
    job_name=execution_input["ProcessingJobName"],
    inputs=input_meta,
    outputs=output_meta,
    container_arguments=[
        # Reuse the BUCKET constant from the config cell instead of re-reading the env var.
        "--bucket", BUCKET,
        # NOTE(review): the file name is fixed for every execution. Confirm this is intended,
        # or promote it to an ExecutionInput field.
        "--file", "document.docx",
    ],
    container_entrypoint=["python3", "/opt/ml/processing/code/main.py"],
)

In [8]:
def create_update_job_data_steps(
    job_type: str,
    function_name: str = "quack-tsln-update-jobstatus-step",
) -> dict:
    """Create one Lambda status-update step per terminal outcome of a job.

    Args:
        job_type: label embedded in the state names
            (``update-dynamodb-<job_type>-<Status>``). State names must be unique
            within a state machine, so use a distinct label per job.
        function_name: Lambda function each step invokes. Defaults to the
            originally hard-coded function.

    Returns:
        ``{"completed": LambdaStep, "failed": LambdaStep}``. Each step invokes
        ``function_name`` with the whole execution input under ``"inputs"``, and
        ``"execStatus"`` set to ``"TaskCompleted"`` or ``"TaskFailed"``.
    """
    def _make_step(outcome: str):
        # Build the status-update LambdaStep for one outcome ("completed" / "failed").
        status = outcome.capitalize()
        return steps.compute.LambdaStep(
            f"update-dynamodb-{job_type}-{status}",
            parameters={
                "FunctionName": function_name,
                "Payload": {
                    # Forward the full execution input so the Lambda can identify the job
                    # (presumably via job_pk / job_sk; confirm against the Lambda).
                    "inputs": execution_input,
                    "execStatus": f"Task{status}",
                },
            },
        )

    return {outcome: _make_step(outcome) for outcome in ("completed", "failed")}

# Build the Main Workflow

In [9]:
# Status-update Lambda steps for this job: "completed" follows a successful processing
# job; "failed" is the target of the catch-all below.
update_job_data_repo = create_update_job_data_steps(job_type="job1")

# Route any error raised by the processing job (States.ALL) to the "failed" status update.
# NOTE(review): add_catch appears to append. Re-running this cell against the same
# processing_step would then attach a second catcher, so re-run the processing_step cell first.
catch_state_processing = stepfunctions.steps.states.Catch(
    next_step=update_job_data_repo["failed"],
    error_equals=["States.ALL"],
)
processing_step.add_catch(catch_state_processing)

In [10]:
# Happy path: processing job, then the "completed" status update. Errors leave this
# chain through the catcher attached to processing_step.
# NOTE(review): the "failed" branch has no successor, so an execution whose processing job
# failed still ends as Succeeded. Chain a steps.Fail state after update_job_data_repo["failed"]
# if callers should see the failure.
workflow_graph = steps.Chain(
    [
        processing_step,
        update_job_data_repo["completed"],
    ]
)
workflow = Workflow(
    name="quack-tsln_ProcessingJob",
    role=WORKFLOW_EXECUTION_ROLE,
    definition=workflow_graph,
    execution_input=execution_input,
)

In [11]:
# Render the workflow definition as a diagram (shown as this cell's output).
workflow.render_graph()
# Uncomment to create a brand-new state machine; this notebook instead updates an existing
# one via SFN_CLIENT.update_state_machine further down.
# workflow_arn = workflow.create()

In [21]:
# workflow_arn

In [12]:
def workflow_to_json(workflow, directory: str = ".") -> str:
    """Write a workflow's state-machine definition to ``<directory>/<workflow.name>.json``.

    Args:
        workflow: a ``stepfunctions.workflow.Workflow``. Only ``workflow.name`` and
            ``workflow.definition.to_json(pretty=True)`` are used.
        directory: folder to write into. Defaults to the current working directory,
            which matches the original behaviour.

    Returns:
        The path of the file that was written.
    """
    path = os.path.join(directory, f"{workflow.name}.json")
    # Serialise before opening: opening with "w" truncates the file, so a failing
    # to_json() would otherwise leave an empty definition behind.
    definition_json = workflow.definition.to_json(pretty=True)
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(definition_json)
    return path

In [13]:
# Snapshot the generated definition to "<workflow name>.json" in the current working directory.
workflow_to_json(workflow=workflow)

# Update the Existing Step Function

In [16]:
# Push the freshly generated definition to the already-deployed state machine whose ARN
# is in SFN_WORKFLOW_ARN. The boto3 response is displayed as the cell output.
SFN_CLIENT.update_state_machine(
    stateMachineArn=os.environ["SFN_WORKFLOW_ARN"],
    definition=workflow.definition.to_json(),
)

{'updateDate': datetime.datetime(2022, 12, 9, 19, 44, 13, 58000, tzinfo=tzlocal()),
 'ResponseMetadata': {'RequestId': '7e6d3be4-89b5-4cb6-ad25-68210f8d6a78',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '7e6d3be4-89b5-4cb6-ad25-68210f8d6a78',
   'date': 'Fri, 09 Dec 2022 18:44:13 GMT',
   'content-type': 'application/x-amz-json-1.0',
   'content-length': '31'},
  'RetryAttempts': 0}}

# Delete a Step Function Pipeline

In [19]:
def _confirm_deletion() -> bool:
    """Prompt for confirmation of a destructive action; only "y"/"yes" (any case) counts."""
    answer = input("Please confirm with y/n: ")
    return answer.strip().lower() in ("y", "yes")


def delete_redundant_functions(arn_list: list) -> None:
    """Delete every Step Functions state machine in ``arn_list`` after a single confirmation.

    Args:
        arn_list: state-machine ARNs to delete. If it is empty, nothing is asked or deleted.
    """
    arns = list(arn_list)
    if not arns:
        print("Nothing to delete.")
        return
    if not _confirm_deletion():
        # Only report cancellation when the user actually declined. Previously this message
        # was printed unconditionally, even after a successful deletion.
        print("Deletion cancelled.")
        return
    for arn in arns:
        resp = SFN_CLIENT.delete_state_machine(stateMachineArn=arn)
        print(f"Deleted {resp}")


def delete_function_by_arn(arn: str) -> None:
    """Delete a single Step Functions state machine after confirmation.

    Args:
        arn: ARN of the state machine to delete.
    """
    delete_redundant_functions([arn])

In [20]:
# delete_function_by_arn(workflow_arn)

In [17]:
# ARNs of the old germandossier_ProcessingJob2 state machines, versions v1..v5
# (presumably the input for delete_redundant_functions).
# NOTE(review): region and account id are hard-coded; derive them from the session/env if reused.
arn_list = [
    "arn:aws:states:eu-west-1:240911078895:stateMachine:"
    f"germandossier_ProcessingJob2_v{version}"
    for version in range(1, 6)
]

In [20]:
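## Example of Adding a Parallel Step (stale sketch: processing_step_repo, catch_state and retry_step are not defined in this notebook)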
# parallel_state = steps.Parallel("TranslateAllSections")
# parallel_state.add_branch(processing_step_repo["heading"])
# parallel_state.add_branch(processing_step_repo["footnote"])
# parallel_state.add_branch(processing_step_repo["phrase"])
# parallel_state.add_catch(catch_state)
# parallel_state.add_retry(retry_step)
# processing_step_repo.add_catch(catch_state)
# temp_step = processing_step_repo["heading"]
# temp_step.add_catch(catch_state)
# workflow_graph = steps.Chain([temp_step, update_job_data_repo["completed"]])