In [1]:
import sagemaker
import boto3
import os

from sagemaker.workflow.parameters import (
    ParameterInteger,
    ParameterString
)
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline import PipelineDefinitionConfig
from sagemaker import image_uris
from steps.processor import get_processor_step
from steps.evaluator import get_evaluator_step

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


In [2]:
def get_parameters() -> dict:
    """Create the runtime parameters exposed by the SageMaker pipeline.

    Returns:
        dict with two entries:
          - ``"process_instance_count_param"``: ``ParameterInteger`` named
            ``ProcessingInstanceCount`` (default ``1``).
          - ``"process_instance_type_param"``: ``ParameterString`` named
            ``ProcessingInstanceType`` (default ``"ml.m5.large"``).
        Every value is later registered on the ``Pipeline``.
    """
    # TODO: add trainer parameters here once the training step is wired in.
    return {
        # - Processing ----------------------------------
        "process_instance_count_param": ParameterInteger(
            name="ProcessingInstanceCount",
            default_value=1,
        ),
        "process_instance_type_param": ParameterString(
            name="ProcessingInstanceType",
            default_value="ml.m5.large",
        ),
    }

In [3]:
def get_pipeline(
    session: sagemaker.Session,
    parameters: dict,
    constants: dict,
    sklearn_image_uri: str,
    trainer_step=None,
):
    """Assemble the wildfire image-classification SageMaker pipeline.

    Args:
        session: SageMaker session the pipeline is bound to. It is passed to
            ``Pipeline`` so ``upsert``/``start`` use the same region-pinned
            session the caller configured.
        parameters: Mapping of pipeline parameter objects, as returned by
            ``get_parameters()``. Every value is registered on the pipeline.
        constants: Project configuration. Required keys: ``project``,
            ``bucket_name``, ``region``. Optional evaluation overrides (each
            falls back to the value this notebook previously hard-coded):
            ``evaluation_image_uri``, ``test_metadata_prefix``,
            ``best_model_prefix``, ``test_metadata_file``, ``best_model_file``,
            ``evaluation_result_prefix``, ``evaluation_data_dir``,
            ``model_package_arn``.
        sklearn_image_uri: Container image URI used by the processing step.
        trainer_step: Optional, already-built training step placed between
            processing and evaluation. ``None`` (the default) omits it.

    Returns:
        A ``Pipeline`` named ``"<project>-pipeline"``.
    """
    pipeline_def_config = PipelineDefinitionConfig(use_custom_job_prefix=True)

    # - Processing ----------------------------------
    processor_step = get_processor_step(
        project=constants["project"],
        bucket_name=constants["bucket_name"],
        process_instance_count_param=parameters["process_instance_count_param"],
        process_instance_type_param=parameters["process_instance_type_param"],
        sklearn_image_uri=sklearn_image_uri,
        region=constants["region"],
    )

    # - Trainer -------------------------------------
    # The former `get_trainer_step()` call referenced a function that was never
    # imported (NameError on a fresh kernel) and passed no arguments. Until a
    # trainer module exists, a prebuilt step can be injected via `trainer_step`.

    # - Evaluator ----------------------------------
    test_metadata_prefix = constants.get(
        "test_metadata_prefix", "sagemaker/fire-image-classification"
    )
    evaluator_step = get_evaluator_step(
        project=constants["project"],
        bucket_name=constants["bucket_name"],
        process_instance_count_param=parameters["process_instance_count_param"],
        process_instance_type_param=parameters["process_instance_type_param"],
        # NOTE(review): this is a GPU image while the default processing
        # instance type is ml.m5.large (CPU) — confirm this pairing is intended.
        evaluation_image_uri=constants.get(
            "evaluation_image_uri",
            '763104351884.dkr.ecr.eu-central-1.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-ec2',
        ),
        region=constants["region"],

        test_metadata_prefix=test_metadata_prefix,
        best_model_prefix=constants.get("best_model_prefix", 'models'),
        test_metadata_file=constants.get("test_metadata_file", 'test.pkl'),
        best_model_file=constants.get("best_model_file", 'model_resnet18.tar.gz'),
        result_prefix=constants.get("evaluation_result_prefix", 'evaluation/dummy'),
        # Derived from the configured bucket instead of hard-coding "wildfires".
        data_dir=constants.get(
            "evaluation_data_dir",
            f"s3://{constants['bucket_name']}/{test_metadata_prefix}",
        ),
        model_package_arn=constants.get(
            "model_package_arn",
            'arn:aws:sagemaker:eu-central-1:567821811420:model-package/first-fire-mlflow-ee0049/1',
        ),
    )

    # ------------------------------------------------

    # Skip steps that were not provided (currently only the optional trainer).
    steps = [
        step
        for step in (processor_step, trainer_step, evaluator_step)
        if step is not None
    ]

    return Pipeline(
        name=f"{constants['project']}-pipeline",
        parameters=list(parameters.values()),
        pipeline_definition_config=pipeline_def_config,
        steps=steps,
        sagemaker_session=session,
    )

In [4]:
# Show the working directory; the `steps.*` imports in the first cell
# presumably resolve relative to it (expected: .../src/pipeline).
print(os.getcwd())

parameters = get_parameters()

constants = {
    "region": "eu-central-1",
    "project": "wildfire-project",
    "bucket_name": "wildfires",
    "sklearn_image_uri_version": "1.2-1",
}

# Pin all SageMaker calls made through this session to the configured region.
session = sagemaker.Session(
    boto_session=boto3.Session(region_name=constants["region"])
)

sklearn_image_uri = image_uris.retrieve(
    framework="sklearn",
    region=constants["region"],
    version=constants["sklearn_image_uri_version"],
)

pipeline = get_pipeline(
    session=session,
    parameters=parameters,
    constants=constants,
    sklearn_image_uri=sklearn_image_uri,
)

# Resolve the execution role through the same session instead of letting the
# SDK build an implicit default one; the upsert response is the cell output.
pipeline.upsert(role_arn=sagemaker.get_execution_role(sagemaker_session=session))

INFO:sagemaker.image_uris:Defaulting to only available Python version: py3
INFO:sagemaker.image_uris:Defaulting to only supported image scope: cpu.


/home/sagemaker-user/src/pipeline


{'PipelineArn': 'arn:aws:sagemaker:eu-central-1:567821811420:pipeline/wildfire-project-pipeline',
 'ResponseMetadata': {'RequestId': '4e77de58-ccf3-47ac-bd08-6c06ee666de3',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '4e77de58-ccf3-47ac-bd08-6c06ee666de3',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '96',
   'date': 'Tue, 25 Jun 2024 17:33:59 GMT'},
  'RetryAttempts': 0}}