In [None]:
from sagemaker_core.main.shapes import TrainingJob

from sagemaker import Session, get_execution_role

# A single SageMaker session for the notebook; region and default bucket are
# read off it so everything follows the kernel's configured AWS region.
sagemaker_session = Session()
# Execution role ARN, later handed to ModelBuilder as `role_arn`.
role = get_execution_role()
region, bucket = (
    sagemaker_session.boto_region_name,
    sagemaker_session.default_bucket(),
)

In [None]:

from sagemaker.modules.configs import SourceCode
from sagemaker.modules.train.model_trainer import ModelTrainer

import warnings

# The legacy XGBoost container used here lives in the us-west-2 ECR registry
# (account 433757028032). SageMaker expects training/inference images to be in
# the same region as the job, so warn early if the session runs elsewhere
# instead of letting train()/deploy() fail later with an image-pull error.
XGBOOST_IMAGE_REGION = "us-west-2"
xgboost_image = f"433757028032.dkr.ecr.{XGBOOST_IMAGE_REGION}.amazonaws.com/xgboost:latest"
if region != XGBOOST_IMAGE_REGION:
    warnings.warn(
        f"xgboost_image is hosted in {XGBOOST_IMAGE_REGION} but the SageMaker "
        f"session region is {region}; the training job will likely fail to "
        "pull the image. Use an image URI for your region."
    )

# Smoke-test job: no real training, the container just prints a greeting and
# dumps its environment to the job logs.
source_code = SourceCode(
    command="echo 'Hello World' && env",
)
model_trainer = ModelTrainer(
    training_image=xgboost_image,
    source_code=source_code,
)

model_trainer.train()

In [None]:
import numpy as np
from sagemaker.serve.builder.schema_builder import SchemaBuilder
import pandas as pd
from xgboost import XGBClassifier
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve import ModelBuilder

# Tiny placeholder frame that only describes the request/response format for
# ModelBuilder; the very same frame is used as sample input and sample output.
# NOTE(review): XGBoostSpec.invoke returns an ndarray of class indices, not a
# DataFrame, so sample_output may not reflect the real response - confirm.
data = {"Name": ["Alice", "Bob", "Charlie"]}
df = pd.DataFrame(data)
schema_builder = SchemaBuilder(sample_input=df, sample_output=df)


class XGBoostSpec(InferenceSpec):
    """Inference spec for an XGBoost classifier stored as ``xgboost-model``.

    ``load`` restores the classifier from the model directory, and ``invoke``
    converts per-class probabilities into predicted class indices.
    """

    def load(self, model_dir: str):
        """Load and return the ``xgboost-model`` file found in ``model_dir``."""
        print(model_dir)  # echo the resolved model directory to the logs
        artifact_path = model_dir + "/xgboost-model"
        classifier = XGBClassifier()
        classifier.load_model(artifact_path)
        return classifier

    def invoke(self, input_object: object, model: object):
        """Return the index of the most probable class for each input row."""
        return np.argmax(model.predict_proba(input_object), axis=1)

# Build a deployable model straight from the ModelTrainer: ModelBuilder is
# expected to pick up the trainer's latest job artifacts as its model data.
model_builder = ModelBuilder(
    model=model_trainer,  # ModelTrainer object passed onto ModelBuilder directly
    role_arn=role,
    image_uri=xgboost_image,
    inference_spec=XGBoostSpec(),
    schema_builder=schema_builder,
    instance_type="ml.c6i.xlarge",
)
model = model_builder.build()

# Check the resolved artifacts right after build() and *before* deploy(), so a
# wrong model_data fails fast instead of after an endpoint has been provisioned.
assert model.model_data == model_trainer._latest_training_job.model_artifacts.s3_model_artifacts
print(model.model_data)

predictor = model_builder.deploy()
# NOTE: deploy() stands up a hosted endpoint (presumably real-time, given the
# instance_type) that incurs cost until removed; call
# predictor.delete_endpoint() once it is no longer needed.

In [None]:
# Hand ModelBuilder the raw sagemaker-core TrainingJob (the trainer's most
# recent job) instead of the ModelTrainer wrapper, and confirm build() resolves
# the model data to that job's S3 model artifacts.
training_job: TrainingJob = model_trainer._latest_training_job

model_builder = ModelBuilder(
    model=training_job,  # Sagemaker core's TrainingJob object passed onto ModelBuilder directly
    role_arn=role,
    image_uri=xgboost_image,
    schema_builder=schema_builder,
    inference_spec=XGBoostSpec(),
    instance_type="ml.c6i.xlarge",
)
model = model_builder.build()

expected_model_data = training_job.model_artifacts.s3_model_artifacts
assert model.model_data == expected_model_data
print(model.model_data)