In [None]:
# Dependency install: uncomment and run once on a fresh environment.
# Prefer `%pip install` (installs into the kernel's environment) and pin versions for reproducibility.
# !pip install --ignore-installed amzn-mods-workflow-helper amzn-mods-python-sdk
# !pip install --ignore-installed amzn-secure-ai-sandbox-workflow-python-sdk

In [None]:
from secure_ai_sandbox_python_lib.session import Session as SaisSession

# Secure AI Sandbox session rooted at the current directory; it supplies the team bucket,
# KMS key and VPC settings used by the pipeline below.
sais_session = SaisSession('.')

from mods_workflow_helper.sagemaker_pipeline_helper import SecurityConfig

# Encryption key and network placement for the pipeline jobs, all taken from the sandbox session.
security_config = SecurityConfig(
    kms_key=sais_session.get_team_owned_bucket_kms_key(),
    security_group=sais_session.sandbox_vpc_security_group(),
    vpc_subnets=sais_session.sandbox_vpc_subnets(),
)

from sagemaker.workflow.pipeline_context import PipelineSession

# Pipeline session whose default S3 bucket is the team-owned bucket.
team_bucket_name = sais_session.team_owned_s3_bucket_name()
session = PipelineSession(default_bucket=team_bucket_name)

## Build the pipeline (change the model class name below when adapting this template)

In [None]:
# Build the SageMaker pipeline from the MODS template defined in eu_tsa_sq_model.py.
from eu_tsa_sq_model import EUTSASuspectQueueModel
# The model template is bound to the PipelineSession created in the setup cell.
model = EUTSASuspectQueueModel(sagemaker_session=session)
pipeline = model.generate_pipeline()

### Prepare execution document

In [None]:
from datetime import date, datetime, timedelta

# Data windows for the training / validation / calibration Cradle pulls, expressed as
# (start, end) in days before today. All windows are anchored to a single date.today()
# call so they stay mutually consistent even if the cell runs across midnight.
TRAIN_DAYS_AGO = (180, 90)
VALIDATION_DAYS_AGO = (89, 79)
CALIBRATION_DAYS_AGO = (28, 14)
# A window spanning more than this many days (inclusive) is pulled as a split Cradle job.
# Keep in sync with JobSplitOptions(days_per_split=7) in the request cells.
DAYS_PER_SPLIT = 7

date_format = '%Y-%m-%dT%H:%M:%S'
_today = date.today()


def _days_ago(n_days):
    """Return midnight of the day `n_days` before `_today` as 'YYYY-MM-DDT00:00:00'."""
    return (_today - timedelta(days=n_days)).strftime("%Y-%m-%d") + 'T00:00:00'


def _window_stats(start, end):
    """Return (delta_days, split_job) for a window given as `date_format` strings.

    delta_days is (end - start) in whole days; split_job is True when the inclusive
    window (delta_days + 1 days) is longer than DAYS_PER_SPLIT.
    """
    delta_days = (datetime.strptime(end, date_format) - datetime.strptime(start, date_format)).days
    return delta_days, (delta_days + 1) > DAYS_PER_SPLIT


train_start_date, train_end_date = (_days_ago(d) for d in TRAIN_DAYS_AGO)
validation_start_date, validation_end_date = (_days_ago(d) for d in VALIDATION_DAYS_AGO)
calibration_start_date, calibration_end_date = (_days_ago(d) for d in CALIBRATION_DAYS_AGO)

# Sanity check: every window is non-empty, and the windows are chronological and non-overlapping.
# ISO-formatted strings compare lexically in chronological order.
assert (train_start_date < train_end_date < validation_start_date < validation_end_date
        < calibration_start_date < calibration_end_date), "data windows overlap or are out of order"

print("train_start_date : ", train_start_date)
print("train_end_date : ", train_end_date)
print("validation_start_date : ", validation_start_date)
print("validation_end_date : ", validation_end_date)
print("calibration_start_date : ", calibration_start_date)
print("calibration_end_date : ", calibration_end_date)

train_delta_days, train_split_job = _window_stats(train_start_date, train_end_date)
validation_delta_days, validation_split_job = _window_stats(validation_start_date, validation_end_date)
calibration_delta_days, calibration_split_job = _window_stats(calibration_start_date, calibration_end_date)

print(train_split_job, validation_split_job, calibration_split_job)

In [None]:
from com.amazon.secureaisandboxproxyservice.models.createcradledataloadjobrequest import CreateCradleDataLoadJobRequest
from com.amazon.secureaisandboxproxyservice.models.datasourcesspecification import DataSourcesSpecification
from com.amazon.secureaisandboxproxyservice.models.mdsdatasourceproperties import MdsDataSourceProperties
from com.amazon.secureaisandboxproxyservice.models.andesdatasourceproperties import AndesDataSourceProperties
from com.amazon.secureaisandboxproxyservice.models.transformspecification import TransformSpecification
from com.amazon.secureaisandboxproxyservice.models.outputspecification import OutputSpecification
from com.amazon.secureaisandboxproxyservice.models.cradlejobspecification import CradleJobSpecification
from com.amazon.secureaisandboxproxyservice.models.edxdatasourceproperties import EdxDataSourceProperties

from com.amazon.secureaisandboxproxyservice.models.jobsplitoptions import JobSplitOptions
from com.amazon.secureaisandboxproxyservice.models.field import Field
from com.amazon.secureaisandboxproxyservice.models.datasource import DataSource
from secure_ai_sandbox_python_lib.utils import coral_utils

#### Load the feature-variable lists (scripts/params.py) used to build the Cradle data-pull requests

In [None]:
import os
import sys

# Make the helper scripts shipped with the model (e.g. scripts/params.py) importable.
model_dir = './'
scripts_dir = os.path.join(model_dir, 'scripts')
# Append only once, so re-running this cell does not keep growing sys.path
# (plain string concatenation also produced the malformed './/scripts/').
if scripts_dir not in sys.path:
    sys.path.append(scripts_dir)

# Feature-variable lists used below for the Cradle pull schema and the MIMS input schema.
from params import (
    seq_cat_vars, seq_num_vars, dense_num_vars,
    input_data_seq_cat_otf_vars, input_data_seq_num_otf_vars,
    input_data_seq_cat_vars, input_data_seq_num_vars,
    input_data_dense_num_vars, numerical_cat_vars,
)

In [None]:
# Model input variables from params.py, in their declared order.
model_var_list = seq_cat_vars + seq_num_vars + dense_num_vars + input_data_seq_cat_otf_vars + input_data_seq_num_otf_vars
# Extra MDS fields pulled alongside the model inputs (e.g. objectId for the TAGS join,
# marketplaceCountryCode / isQueued / isSidelined for the SQL filters, transactionDate for
# the merge-step dedup).
extra_pull_vars = ['objectId', 'orderDate', 'transactionDate', 'marketplaceCountryCode', 'marketplaceId',
                   'isQueued', 'paymeth', 'ictry_cd', 'bctry_cd', 'sctry_cd', 'cctry_cd', 'isSidelined', 'emailorg',
                   'creditCardIds']
# De-duplicate while keeping first-seen order. list(set(...)) yielded an order that depends on
# the per-process string hash seed, so the Cradle output_schema / TSV column order changed from
# one kernel session to the next.
var_list = list(dict.fromkeys(extra_pull_vars + model_var_list))
print(len(model_var_list), len(var_list))

#### Create Cradle requests

In [None]:
# Output path passed to the Cradle requests below. Left empty as a placeholder that will not
# actually be used (presumably the pipeline's data-download steps set the real location -- confirm).
output_path=''

In [None]:
# Cradle data-load request for the TRAINING window (train_start_date .. train_end_date).
# Pulls raw FORTRESS_RETAIL (EU) MDS records and joins them to the fraud-tags-eu Andes table on
# objectId = order_id. The WHERE clause keeps every tagged fraud order (is_frd=1) plus a ~2.5%
# random sample of tagged non-fraud orders (is_frd=0). Because it filters on TAGS.is_frd, orders
# without a matching tag are dropped, so the LEFT JOIN effectively behaves like an inner join.
# Only orders that were queued (GB/DE/FR/IT/ES) or sidelined (TR/NL/SE/SA/AE/EG/PL/BE) are kept.
request_training = CreateCradleDataLoadJobRequest(
    data_sources=DataSourcesSpecification(
        start_date=train_start_date,  # window start ('YYYY-MM-DDT00:00:00')
        end_date=train_end_date,  # window end ('YYYY-MM-DDT00:00:00')
        data_sources = [
            DataSource(
                data_source_name='RAW_MDS',  # must be unique in this list; the SQL below uses it as the table name
                data_source_type='MDS',  # one of 'MDS' / 'ANDES' / 'EDX'; set the matching *_properties field
                mds_data_source_properties=MdsDataSourceProperties(
                    service_name='FORTRESS_RETAIL',
                    org_id='2',
                    region='EU',
                    # every pulled MDS field is read as STRING
                    output_schema=[Field(field_name=f, field_type='STRING') for f in var_list],
                    # True would switch to the hourly EDX data set, which reduces throttling on the hot
                    # data set but may not contain all data -- verify availability before enabling. Example:
                    # https://edx.corp.amazon.com/providers/cmls-raw-hourly-data/subjects/fortress-retail/datasets/na-1
                    use_hourly_edx_data_set=False,
                )
            ),
            DataSource(
                data_source_name='TAGS',
                data_source_type='ANDES',  # fraud labels come from an Andes table
                andes_data_source_properties=AndesDataSourceProperties(
                    provider='26b27bde-3847-49c6-a07c-0289c17d9c33',
                    table_name='fraud-tags-eu',
                )
            ),
        ]
    ),
    # The transform SQL refers to the data sources above by their data_source_name.
    # ${startDate} / ${endDate} are placeholders left for Cradle to substitute
    # (presumably with the window / split dates -- confirm against the Cradle docs).
    transform_specification=TransformSpecification(
        transform_sql="""
        select 
            RAW_MDS.*, 
            TAGS.is_frd AS IS_FRD
        from RAW_MDS 
        left join TAGS 
            on RAW_MDS.objectId=TAGS.order_id  
            AND TAGS.order_day_timestamp >= TO_TIMESTAMP('${startDate}', 'yyyy-MM-dd')  
            AND TAGS.order_day_timestamp <= TO_TIMESTAMP('${endDate}', 'yyyy-MM-dd')
        where (TAGS.is_frd=1 OR (TAGS.is_frd=0 and rand()<0.025)) and (
            ((marketplaceCountryCode in ('GB','DE','FR','IT','ES') and isQueued=1) or 
            (marketplaceCountryCode in ('TR','NL','SE','SA','AE','EG','PL','BE') and isSidelined=1)) 
        ) 
        """,
        # When split_job is True (window longer than 7 days, see train_split_job), the pull runs in
        # days_per_split-day chunks and merge_sql is applied to the combined result, exposed as INPUT.
        # merge_sql keeps only the latest record per objectId (ordered by the parsed transactionDate).
        # NOTE(review): Spark 3 rejects the 'E' (day-of-week) pattern letter when parsing with
        # TO_TIMESTAMP unless the legacy time-parser policy is enabled -- confirm on Cradle's Spark.
        job_split_options=JobSplitOptions(
            split_job=train_split_job,
            days_per_split=7,
            merge_sql="""
                WITH data AS (
                    SELECT INPUT.*,
                        TO_TIMESTAMP(INPUT.transactionDate, 'EEE MMM dd HH:mm:ss zzz yyyy') AS transactionDateV2
                    FROM INPUT
                    ),
                dedup AS (
                    SELECT *
                    FROM (
                        SELECT *, ROW_NUMBER() OVER (PARTITION BY objectId ORDER BY transactionDateV2 DESC) AS __rownum__
                        FROM data
                        )
                    WHERE __rownum__ = 1
                    )
                select * from dedup
            """
            )
    ),
    output_specification=OutputSpecification(
        output_schema=var_list+['IS_FRD'],  # final output columns: pulled MDS fields plus the fraud label
        output_path=output_path,  # empty placeholder (see the output_path cell)
        output_format='UNESCAPED_TSV',  # alternatives: CSV (default), JSON, ION, PARQUET
        output_save_mode='ERRORIFEXISTS',  # alternatives: OVERWRITE, APPEND, IGNORE; ERRORIFEXISTS is the default
        # Number of output files: too many can cause S3 throttling, too few hurts performance.
        # NOTE(review): 0 presumably means "use Cradle's default" (the template mentions ~30/day) -- confirm.
        output_file_count=0,
        keep_dot_in_output_schema=True,  # keep '.' in header names instead of replacing it with '__DOT__'
        include_header_in_s3_output=True  # write a header row (only supported for S3 output)
    ),
    cradle_job_specification=CradleJobSpecification(
        cluster_type='LARGE',
        cradle_account='BRP-ML-Payment-Generate-Data',
        extra_spark_job_arguments='',  # e.g. Spark driver-memory overrides, if ever needed
        job_retry_count=4,  # retries on failure (Cradle retries once by default)
    )
)

# Plain-dict form of the request; it becomes this step's STEP_CONFIG in the execution document.
cradle_training_request_dict = coral_utils.convert_coral_to_dict(request_training)

In [None]:
# Cradle data-load request for the VALIDATION window (validation_start_date .. validation_end_date).
# Same sources, join and queued/sidelined filter as the training request, but keeps a ~15% random
# sample of orders whose tag is not -1 (rows without a matching tag fail this filter and are dropped).
# The merge step additionally down-samples the de-duplicated result to roughly 8,000,000 rows when
# it is larger than that.
request_validation = CreateCradleDataLoadJobRequest(
    data_sources=DataSourcesSpecification(
        start_date=validation_start_date,  # window start ('YYYY-MM-DDT00:00:00')
        end_date=validation_end_date,  # window end ('YYYY-MM-DDT00:00:00')
        data_sources = [
            DataSource(
                data_source_name='RAW_MDS',  # must be unique in this list; the SQL below uses it as the table name
                data_source_type='MDS',  # one of 'MDS' / 'ANDES' / 'EDX'; set the matching *_properties field
                mds_data_source_properties=MdsDataSourceProperties(
                    service_name='FORTRESS_RETAIL',
                    org_id='2',
                    region='EU',
                    # every pulled MDS field is read as STRING
                    output_schema=[Field(field_name=f, field_type='STRING') for f in var_list],
                    # True would switch to the hourly EDX data set, which reduces throttling on the hot
                    # data set but may not contain all data -- verify availability before enabling. Example:
                    # https://edx.corp.amazon.com/providers/cmls-raw-hourly-data/subjects/fortress-retail/datasets/na-1
                    use_hourly_edx_data_set=False,
                )
            ),
            DataSource(
                data_source_name='TAGS',
                data_source_type='ANDES',  # fraud labels come from an Andes table
                andes_data_source_properties=AndesDataSourceProperties(
                    provider='26b27bde-3847-49c6-a07c-0289c17d9c33',
                    table_name='fraud-tags-eu',
                )
            ),
        ]
    ),
    # The transform SQL refers to the data sources above by their data_source_name.
    # ${startDate} / ${endDate} are placeholders left for Cradle to substitute
    # (presumably with the window / split dates -- confirm against the Cradle docs).
    transform_specification=TransformSpecification(
        transform_sql="""
        select 
            RAW_MDS.*, 
            TAGS.is_frd AS IS_FRD
        from RAW_MDS 
        left join TAGS 
            on RAW_MDS.objectId=TAGS.order_id  
            AND TAGS.order_day_timestamp >= TO_TIMESTAMP('${startDate}', 'yyyy-MM-dd')  
            AND TAGS.order_day_timestamp <= TO_TIMESTAMP('${endDate}', 'yyyy-MM-dd')
        where TAGS.is_frd!=-1 and rand()<0.15 and (
            ((marketplaceCountryCode in ('GB','DE','FR','IT','ES') and isQueued=1) or 
            (marketplaceCountryCode in ('TR','NL','SE','SA','AE','EG','PL','BE') and isSidelined=1)) 
        ) 
        """,
        # When split_job is True (window longer than 7 days, see validation_split_job), the pull runs in
        # days_per_split-day chunks and merge_sql is applied to the combined result, exposed as INPUT.
        # merge_sql keeps the latest record per objectId, then randomly keeps each row with probability
        # 8,000,000 / row_count (so larger results are capped at roughly 8M rows).
        # NOTE(review): Spark 3 rejects the 'E' (day-of-week) pattern letter when parsing with
        # TO_TIMESTAMP unless the legacy time-parser policy is enabled -- confirm on Cradle's Spark.
        job_split_options=JobSplitOptions(
            split_job=validation_split_job,
            days_per_split=7,
            merge_sql="""
                WITH data AS (
                    SELECT INPUT.*,
                        TO_TIMESTAMP(INPUT.transactionDate, 'EEE MMM dd HH:mm:ss zzz yyyy') AS transactionDateV2
                    FROM INPUT
                    ),
                dedup AS (
                    SELECT *
                    FROM (
                        SELECT *, ROW_NUMBER() OVER (PARTITION BY objectId ORDER BY transactionDateV2 DESC) AS __rownum__
                        FROM data
                        )
                    WHERE __rownum__ = 1
                    )
                select * from dedup
                where rand() < (select 8000000 / count(*) from dedup)
            """
            )
    ),
    output_specification=OutputSpecification(
        output_schema=var_list+['IS_FRD'],  # final output columns: pulled MDS fields plus the fraud label
        output_path=output_path,  # empty placeholder (see the output_path cell)
        output_format='UNESCAPED_TSV',  # alternatives: CSV (default), JSON, ION, PARQUET
        output_save_mode='ERRORIFEXISTS',  # alternatives: OVERWRITE, APPEND, IGNORE; ERRORIFEXISTS is the default
        # Number of output files: too many can cause S3 throttling, too few hurts performance.
        # NOTE(review): 0 presumably means "use Cradle's default" (the template mentions ~30/day) -- confirm.
        output_file_count=0,
        keep_dot_in_output_schema=True,  # keep '.' in header names instead of replacing it with '__DOT__'
        include_header_in_s3_output=True  # write a header row (only supported for S3 output)
    ),
    cradle_job_specification=CradleJobSpecification(
        cluster_type='LARGE',
        cradle_account='BRP-ML-Payment-Generate-Data',
        extra_spark_job_arguments='',  # e.g. Spark driver-memory overrides, if ever needed
        job_retry_count=4,  # retries on failure (Cradle retries once by default)
    )
)

# Plain-dict form of the request; it becomes this step's STEP_CONFIG in the execution document.
cradle_validation_request_dict = coral_utils.convert_coral_to_dict(request_validation)

In [None]:
# Cradle data-load request for the CALIBRATION window (calibration_start_date .. calibration_end_date).
# Same sources, join and queued/sidelined filter as the training request, with a ~6% random sample
# of ALL matching orders. Unlike training/validation there is no filter on TAGS.is_frd, so orders
# without a matching tag are kept with IS_FRD = NULL. The merge step down-samples to roughly
# 8,000,000 rows, as for validation.
request_calibration = CreateCradleDataLoadJobRequest(
    data_sources=DataSourcesSpecification(
        start_date=calibration_start_date,  # window start ('YYYY-MM-DDT00:00:00')
        end_date=calibration_end_date,  # window end ('YYYY-MM-DDT00:00:00')
        data_sources = [
            DataSource(
                data_source_name='RAW_MDS',  # must be unique in this list; the SQL below uses it as the table name
                data_source_type='MDS',  # one of 'MDS' / 'ANDES' / 'EDX'; set the matching *_properties field
                mds_data_source_properties=MdsDataSourceProperties(
                    service_name='FORTRESS_RETAIL',
                    org_id='2',
                    region='EU',
                    # every pulled MDS field is read as STRING
                    output_schema=[Field(field_name=f, field_type='STRING') for f in var_list],
                    # True would switch to the hourly EDX data set, which reduces throttling on the hot
                    # data set but may not contain all data -- verify availability before enabling. Example:
                    # https://edx.corp.amazon.com/providers/cmls-raw-hourly-data/subjects/fortress-retail/datasets/na-1
                    use_hourly_edx_data_set=False,
                )
            ),
            DataSource(
                data_source_name='TAGS',
                data_source_type='ANDES',  # fraud labels come from an Andes table
                andes_data_source_properties=AndesDataSourceProperties(
                    provider='26b27bde-3847-49c6-a07c-0289c17d9c33',
                    table_name='fraud-tags-eu',
                )
            ),
        ]
    ),
    # The transform SQL refers to the data sources above by their data_source_name.
    # ${startDate} / ${endDate} are placeholders left for Cradle to substitute
    # (presumably with the window / split dates -- confirm against the Cradle docs).
    transform_specification=TransformSpecification(
        transform_sql="""
        select 
            RAW_MDS.*, 
            TAGS.is_frd AS IS_FRD
        from RAW_MDS 
        left join TAGS 
            on RAW_MDS.objectId=TAGS.order_id  
            AND TAGS.order_day_timestamp >= TO_TIMESTAMP('${startDate}', 'yyyy-MM-dd')  
            AND TAGS.order_day_timestamp <= TO_TIMESTAMP('${endDate}', 'yyyy-MM-dd')
        where rand() < 0.06 and (
            ((marketplaceCountryCode in ('GB','DE','FR','IT','ES') and isQueued=1) or 
            (marketplaceCountryCode in ('TR','NL','SE','SA','AE','EG','PL','BE') and isSidelined=1)) 
        ) 
        """,
        # When split_job is True (window longer than 7 days, see calibration_split_job), the pull runs in
        # days_per_split-day chunks and merge_sql is applied to the combined result, exposed as INPUT.
        # merge_sql keeps the latest record per objectId, then randomly keeps each row with probability
        # 8,000,000 / row_count (so larger results are capped at roughly 8M rows).
        # NOTE(review): Spark 3 rejects the 'E' (day-of-week) pattern letter when parsing with
        # TO_TIMESTAMP unless the legacy time-parser policy is enabled -- confirm on Cradle's Spark.
        job_split_options=JobSplitOptions(
            split_job=calibration_split_job,
            days_per_split=7,
            merge_sql="""
                WITH data AS (
                    SELECT INPUT.*,
                        TO_TIMESTAMP(INPUT.transactionDate, 'EEE MMM dd HH:mm:ss zzz yyyy') AS transactionDateV2
                    FROM INPUT
                    ),
                dedup AS (
                    SELECT *
                    FROM (
                        SELECT *, ROW_NUMBER() OVER (PARTITION BY objectId ORDER BY transactionDateV2 DESC) AS __rownum__
                        FROM data
                        )
                    WHERE __rownum__ = 1
                    )
                select * from dedup
                where rand() < (select 8000000 / count(*) from dedup)
            """
            )
    ),
    output_specification=OutputSpecification(
        output_schema=var_list+['IS_FRD'],  # final output columns: pulled MDS fields plus the fraud label
        output_path=output_path,  # empty placeholder (see the output_path cell)
        output_format='UNESCAPED_TSV',  # alternatives: CSV (default), JSON, ION, PARQUET
        output_save_mode='ERRORIFEXISTS',  # alternatives: OVERWRITE, APPEND, IGNORE; ERRORIFEXISTS is the default
        # Number of output files: too many can cause S3 throttling, too few hurts performance.
        # NOTE(review): 0 presumably means "use Cradle's default" (the template mentions ~30/day) -- confirm.
        output_file_count=0,
        keep_dot_in_output_schema=True,  # keep '.' in header names instead of replacing it with '__DOT__'
        include_header_in_s3_output=True  # write a header row (only supported for S3 output)
    ),
    cradle_job_specification=CradleJobSpecification(
        cluster_type='LARGE',
        cradle_account='BRP-ML-Payment-Generate-Data',
        extra_spark_job_arguments='',  # e.g. Spark driver-memory overrides, if ever needed
        job_retry_count=4,  # retries on failure (Cradle retries once by default)
    )
)

# Plain-dict form of the request; it becomes this step's STEP_CONFIG in the execution document.
cradle_calibration_request_dict = coral_utils.convert_coral_to_dict(request_calibration)

#### Prepare the inputs for model registration (input/output schemas and load-test payload)

In [None]:
# Inference input schema for MIMS registration: variable name -> "TEXT" or "NUMERIC".
# Groups are applied in order; a variable that appears in a later group gets that group's
# type but keeps the position where it was first inserted.
_var_type_groups = [
    (input_data_seq_cat_otf_vars, "TEXT"),
    (input_data_seq_cat_vars, "TEXT"),
    (input_data_seq_num_otf_vars, "TEXT"),
    (input_data_seq_num_vars, "NUMERIC"),
    (input_data_dense_num_vars, "NUMERIC"),
    (['objectId', 'orderDate'], "TEXT"),
    (numerical_cat_vars, "NUMERIC"),
]

input_var_dict = {}
for _group_vars, _group_type in _var_type_groups:
    for _var_name in _group_vars:
        input_var_dict[_var_name] = _group_type

# objectId is always excluded from the registered input schema.
del input_var_dict['objectId']

In [None]:
# Output variables registered with MIMS; every one is NUMERIC.
output_var_dict = dict.fromkeys(["score-percentile", "legacy-score", "ProbabilityScore"], 'NUMERIC')

In [None]:
# S3 location of the sample request payload used by the registration load test.
# NOTE(review): this is a personal sandbox bucket -- confirm it is readable by whoever runs this notebook.
sample_payload_s3_bucket, sample_payload_s3_key = (
    "sandboxuserdependency-maxueyu-personals3bucket-ysvmoa568sen",
    "EUTSAModel/pzkwcif1mual/AddInferenceDependencies/payload.tar.gz",
)

In [None]:
# Reference only: integration-test registration config (NA region, us-east-1 image/region).
# Kept commented out for comparison; the EU config actually used is defined in the next cell.
# model_registration_config = {
#     "model_domain": "integration-test",
#     "model_objective": "TestObjective",
#     "source_model_inference_content_types": ['application/json'],
#     "source_model_inference_response_types": ["application/json"],
#     "source_model_inference_input_variable_list": input_var_dict,
#     "source_model_inference_output_variable_list": output_var_dict,
#     "model_registration_region": "NA",
#     "source_model_inference_image_arn": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.1.0-cpu-py310",
#     "source_model_region": "us-east-1",
#     "model_owner": "amzn1.abacus.team.5y3aajyhecgqmg6rjxga",
#     "source_model_environment_variable_map": {
#         "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
#         "SAGEMAKER_PROGRAM": "pytorch_inference_handler.py",
#         "SAGEMAKER_REGION": "us-east-1"
#     },
#     "load_testing_info_map": {
#         "expected_tps": 100,
#         "max_latency_in_millisecond": 100,
#         "sample_payload_s3_bucket": sample_payload_s3_bucket,
#         "sample_payload_s3_key": sample_payload_s3_key,
#         "instance_type_list": ["ml.m5.xlarge"],
#         # Maximum error rate load test will accept, test will fail if error rate is higher than the number
#         "max_acceptable_error_rate": 0.2
#     }
# }

In [None]:
# Environment variables for the PyTorch inference container (log level, entry-point script, region).
_inference_container_env = {
    "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
    "SAGEMAKER_PROGRAM": "pytorch_inference_handler.py",
    "SAGEMAKER_REGION": "eu-west-1",
}

# Load-test settings. The load test fails if the observed error rate is higher than
# max_acceptable_error_rate.
_load_test_settings = {
    "expected_tps": 100,
    "max_latency_in_millisecond": 100,
    "sample_payload_s3_bucket": sample_payload_s3_bucket,
    "sample_payload_s3_key": sample_payload_s3_key,
    "instance_type_list": ["ml.m5.xlarge"],
    "max_acceptable_error_rate": 0.2,
}

# MIMS registration config for the EU suspect-queue model (eu-west-1 PyTorch CPU inference image).
model_registration_config = {
    "model_domain": "FORTRESS_RETAIL",
    "model_objective": "EUTSASuspectQueueModel",
    "source_model_inference_content_types": ['application/json'],
    "source_model_inference_response_types": ["application/json"],
    "source_model_inference_input_variable_list": input_var_dict,
    "source_model_inference_output_variable_list": output_var_dict,
    "model_registration_region": "EU",
    "source_model_inference_image_arn": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/pytorch-inference:2.1.0-cpu-py310",
    "source_model_region": "eu-west-1",
    "model_owner": "amzn1.abacus.team.5y3aajyhecgqmg6rjxga",
    "source_model_environment_variable_map": _inference_container_env,
    "load_testing_info_map": _load_test_settings,
}

#### Overwrite the default execution document

In [None]:
import copy

from mods_workflow_helper.sagemaker_pipeline_helper import SagemakerPipelineHelper, SecurityConfig

# Start from the pipeline's default execution document and work on a deep copy. A plain
# assignment only aliases the same dict, so the edits below would also silently modify
# default_execution_doc.
default_execution_doc = SagemakerPipelineHelper.get_pipeline_default_execution_document(pipeline)
test_execution_doc = copy.deepcopy(default_execution_doc)

step_configs = test_execution_doc['PIPELINE_STEP_CONFIGS']

# Replace (not merge) each Cradle data-download step's config with the request built above.
for step_name, cradle_request_dict in (
    ('Training_Data_Download', cradle_training_request_dict),
    ('Validation_Data_Download', cradle_validation_request_dict),
    ('Calibration_Data_Download', cradle_calibration_request_dict),
):
    step_configs[step_name] = {'STEP_CONFIG': cradle_request_dict}

# Likewise for the MIMS model-registration step.
step_configs['MimsModelRegistrationProcessingStep'] = {'STEP_CONFIG': model_registration_config}

In [None]:
# Display the final execution document for a visual check before starting the pipeline run.
test_execution_doc

In [None]:
from mods_workflow_helper.sagemaker_pipeline_helper import SagemakerPipelineHelper

# Everything the run needs: the pipeline, its security settings and session, a local scratch
# root for preparation, and the customized execution document.
pipeline_execution_kwargs = dict(
    pipeline=pipeline,
    secure_config=security_config,
    sagemaker_session=session,
    preparation_space_local_root="/tmp",
    pipeline_execution_document=test_execution_doc,
)

# Start the run. This is kept as the cell's last expression so any return value is displayed.
SagemakerPipelineHelper.start_pipeline_execution(**pipeline_execution_kwargs)

## Manually Set Scaling Policy

In [None]:
# Optional manual step: set the endpoint auto-scaling policy of a registered model through the
# MIMS registrar resource. Commented out (presumably so "Run All" does not touch a live endpoint);
# update model_id to the model you want to scale before uncommenting.
# from secure_ai_sandbox_python_lib.session import Session
#
# # Initialize sandbox_session
# sandbox_session = Session(session_folder='/tmp/temp_folder', retail_region='EU')
#
# # Create the MIMS resource
# mims = sandbox_session.resource('MIMSModelRegistrar')
#
# # the Scaling Policy of the Endpoint
# endpoint_scaling_config_map = {
#             "instance_type": "ml.m5.xlarge",
#             "min_capacity": 18,
#             "max_capacity": 36,
#             "scaling_policy_map": {
#                 "version": "1.0",
#                 "target_value": 1000.0, # target traffic in TPM (transactions per minute); 6000 TPM = 100 TPS
#                 "scale_in_cooldown": 300, # seconds to wait before removing an instance
#                 "scale_out_cooldown": 60, # seconds to wait before adding an instance
#                 "disable_scale_in": False
#             }
#         }
#
#
# model_domain='FORTRESS_RETAIL'
# model_objective='EUTSASuspectQueueModel'
# model_id = "2024-02-14-68203-rapid-cork"
#
# response = mims.set_scaling_policy(model_region="EU",
#                                     model_domain=model_domain,
#                                     model_objective=model_objective,
#                                     model_id=model_id,
#                                     endpoint_scaling_config_map=endpoint_scaling_config_map)
#
# response