In [9]:
import sagemaker
import boto3
import os

from sagemaker.workflow.parameters import (
    ParameterInteger,
    ParameterString,
    ParameterFloat,
)
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline import PipelineDefinitionConfig
from sagemaker import image_uris
from steps.processor import get_processor_step
from steps.evaluator import get_evaluator_step
from steps.trainer import get_trainer_step

In [10]:
def get_parameters() -> dict:
    """Build the pipeline parameters, keyed by the names the step builders expect.

    The parameters are grouped by the stage that consumes them (common,
    processing, training) and then merged into a single dict. Insertion order
    is kept stable: common first, then processing, then training.

    Returns:
        dict: maps a short key (e.g. ``"train_batch_size"``) to the matching
        ``ParameterString`` / ``ParameterInteger`` / ``ParameterFloat`` object.
    """
    # - Common --------------------------------------
    # Note: the seed is declared as a string parameter, not an integer.
    common = {
        "random_seed": ParameterString(name="RandomSeed", default_value="1"),
    }

    # - Processing ----------------------------------
    processing = {
        "process_instance_count_param": ParameterInteger(
            name="ProcessingInstanceCount",
            default_value=1,
        ),
        "process_instance_type_param": ParameterString(
            name="ProcessingInstanceType",
            default_value="ml.m5.large",
        ),
    }

    # - Trainer -------------------------------------
    training = {
        "tracking_server_arn": ParameterString(
            name="TrackingServerArn",
            default_value="arn:aws:sagemaker:eu-central-1:567821811420:mlflow-tracking-server/wildfire-mj",
        ),
        "train_instance_count_param": ParameterInteger(
            name="TrainInstanceCount",
            default_value=1,
        ),
        "train_instance_type_param": ParameterString(
            name="TrainInstanceType",
            default_value="ml.p3.2xlarge",
        ),
        "train_epochs_num": ParameterInteger(
            name="NumberOfEpochs",
            default_value=10,
        ),
        "train_batch_size": ParameterInteger(
            name="BatchSize",
            default_value=32,
        ),
        "train_learning_rate": ParameterFloat(
            name="LearningRate",
            default_value=0.1,
        ),
    }
    # -----------------------------------------------

    return {**common, **processing, **training}

In [11]:
def get_pipeline(
    session: sagemaker.Session,
    parameters: dict,
    constants: dict,
    sklearn_image_uri: str,
):
    """Assemble the processing -> training pipeline for the project.

    Args:
        session: SageMaker session attached to the returned ``Pipeline`` so
            that pipeline API calls (e.g. ``upsert``) go through it.
        parameters: Mapping of pipeline parameter objects. Every value is
            registered on the pipeline; the keys read here are
            ``process_instance_count_param``, ``process_instance_type_param``,
            ``tracking_server_arn``, ``train_instance_count_param``,
            ``train_instance_type_param``, ``train_epochs_num``,
            ``train_batch_size``, ``train_learning_rate`` and ``random_seed``.
        constants: Static settings; the keys read here are ``project``,
            ``bucket_name`` and ``region``.
        sklearn_image_uri: Container image URI handed to the processing step.

    Returns:
        A ``Pipeline`` named ``"<project>-pipeline"`` containing the processor
        step followed by the trainer step.
    """
    pipeline_def_config = PipelineDefinitionConfig(use_custom_job_prefix=True)

    # - Processing ----------------------------------
    processor_step = get_processor_step(
        project=constants["project"],
        bucket_name=constants["bucket_name"],
        process_instance_count_param=parameters["process_instance_count_param"],
        process_instance_type_param=parameters["process_instance_type_param"],
        sklearn_image_uri=sklearn_image_uri,
        region=constants["region"],
        seed=parameters["random_seed"]
    )

    # - Trainer -------------------------------------
    trainer_step = get_trainer_step(
        project=constants["project"],
        bucket_name=constants["bucket_name"],
        tracking_server_arn=parameters["tracking_server_arn"],
        train_instance_count_param=parameters["train_instance_count_param"],
        train_instance_type_param=parameters["train_instance_type_param"],
        region=constants["region"],
        epochs_num=parameters["train_epochs_num"],
        batch_size=parameters["train_batch_size"],
        learning_rate=parameters["train_learning_rate"],
        seed=parameters["random_seed"]
    )

    # - Evaluator ----------------------------------
    # NOTE(review): the evaluator step is built but deliberately NOT added to
    # the pipeline below (it was commented out of the step list). The call is
    # kept because get_evaluator_step is defined elsewhere and may have side
    # effects -- confirm before removing. Its inputs are hard-coded
    # placeholders rather than pipeline parameters.
    evaluator_step = get_evaluator_step(
        project=constants["project"],
        bucket_name=constants["bucket_name"],
        process_instance_count_param=parameters["process_instance_count_param"],
        process_instance_type_param=parameters["process_instance_type_param"],
        evaluation_image_uri='763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-ec2',
        region=constants["region"],

        test_metadata_prefix='sagemaker/fire-image-classification',
        best_model_prefix='models',
        test_metadata_file='test.pkl',
        best_model_file='model_resnet18.tar.gz',
        result_prefix='evaluation/dummy',
        data_dir='s3://wildfires/sagemaker/fire-image-classification',
        model_package_arn='arn:aws:sagemaker:eu-central-1:567821811420:model-package/first-fire-mlflow-ee0049/1'
    )

    # ------------------------------------------------
    # Training must not start until processing has finished.
    trainer_step.add_depends_on([processor_step])

    return Pipeline(
        name=f"{constants['project']}-pipeline",
        parameters=list(parameters.values()),
        pipeline_definition_config=pipeline_def_config,
        steps=[
            processor_step,
            trainer_step,
            # evaluator_step is intentionally left out for now (see note above).
        ],
        # Previously `session` was accepted but never used, so the pipeline
        # silently fell back to a default-constructed session.
        sagemaker_session=session,
    )

In [12]:
# Show the working directory -- presumably to confirm the notebook runs from
# the folder that contains the local `steps` package (TODO confirm).
print(os.getcwd())

# Static, non-tunable settings shared by every step.
constants = dict(
    region="eu-central-1",
    project="wildfire-project",
    bucket_name="wildfires",
    sklearn_image_uri_version="1.2-1",
)

# Tunable pipeline parameters (instance types, hyperparameters, ...).
parameters = get_parameters()

# SageMaker session pinned to the configured region.
boto_session = boto3.Session(region_name=constants["region"])
session = sagemaker.Session(boto_session)

# Resolve the scikit-learn container image for the configured region/version.
sklearn_image_uri = image_uris.retrieve(
    framework="sklearn",
    region=constants["region"],
    version=constants["sklearn_image_uri_version"],
)

pipeline = get_pipeline(
    session=session,
    parameters=parameters,
    constants=constants,
    sklearn_image_uri=sklearn_image_uri,
)

# Create or update the pipeline definition; the response is left as the
# cell's last expression so the notebook displays it.
execution_role_arn = sagemaker.get_execution_role()
pipeline.upsert(role_arn=execution_role_arn)

INFO:sagemaker.image_uris:Defaulting to only available Python version: py3
INFO:sagemaker.image_uris:Defaulting to only supported image scope: cpu.


/home/sagemaker-user/src/pipeline


INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.


{'PipelineArn': 'arn:aws:sagemaker:eu-central-1:567821811420:pipeline/wildfire-project-pipeline',
 'ResponseMetadata': {'RequestId': '2ba80f8f-21d9-4adc-854e-8f9d7955088e',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '2ba80f8f-21d9-4adc-854e-8f9d7955088e',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '96',
   'date': 'Thu, 27 Jun 2024 23:43:46 GMT'},
  'RetryAttempts': 0}}