# Register ML Model using SageMaker
After training an ML model, you might want to evaluate its performance and have it reviewed by a data scientist or MLOps engineer in your organization before using it in production. To do this, you can register your model versions to the SageMaker model registry. The SageMaker model registry is a repository that data scientists or engineers can use to catalog machine learning (ML) models and manage model versions and their associated metadata, such as training metrics. They can also manage and log the approval status of a model.

After you register your model versions to the SageMaker model registry, a data scientist or your MLOps team can access them through SageMaker Studio, evaluate your model, and update its approval status. If the model doesn't meet their requirements, they can set the status to `Rejected`. If it does, they can set the status to `Approved`. They can then deploy your model to an endpoint or automate model deployment with CI/CD pipelines.
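The approval step can also be done programmatically. The following is a minimal sketch, assuming you already have the ARN of a registered model package version. `set_approval_status` is a hypothetical helper, and `sm_client` is assumed to be a boto3 SageMaker client (for example, `boto3.client("sagemaker")`). It calls the `UpdateModelPackage` API with the `ModelApprovalStatus` and optional `ApprovalDescription` fields.

```python
# Approval states a versioned model package can have in the SageMaker model registry.
VALID_STATUSES = {"Approved", "Rejected", "PendingManualApproval"}


def set_approval_status(sm_client, model_package_arn, status, description=""):
    """Update the approval status of one model package version (hypothetical helper).

    sm_client: assumed to be a boto3 SageMaker client, e.g. boto3.client("sagemaker").
    """
    if status not in VALID_STATUSES:
        raise ValueError(f"Unsupported approval status: {status!r}")
    request = {
        "ModelPackageArn": model_package_arn,
        "ModelApprovalStatus": status,
    }
    # ApprovalDescription is optional; only send it when a reason is given.
    if description:
        request["ApprovalDescription"] = description
    return sm_client.update_model_package(**request)
```

A reviewer who rejects a model could call `set_approval_status(sm_client, arn, "Rejected", "Recall below target")`, which records the reason alongside the status change.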

You can use the SageMaker model registry to integrate your models into the MLOps processes in your organization.

The following diagram summarizes an example of registering a model version built in SageMaker Studio to the SageMaker model registry for integration into an MLOps workflow.

![register model](img/sagemaker-mlops-register-model-diagram.jpg)

## Model Registration
With a model registry, you can version the models that you train, much as you version software code. With versioning in place, you can track model performance over time and make informed decisions about which model to serve in your production environment. The SageMaker Model Registry is organized into Model Groups (also called Model Package Groups), each containing model packages. These Model Groups can optionally be added to one or more Collections. Each model package in a Model Group corresponds to a trained model. Model package versions are integers that start at 1 and increase by one with each new model package added to a Model Group. For example, if five model packages are added to a Model Group, their versions are 1, 2, 3, 4, and 5.
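Because each model package in a Model Group carries its own version number, you can list the versions in a group programmatically. The following is a minimal sketch: `list_model_versions` is a hypothetical helper, and `sm_client` is assumed to be a boto3 SageMaker client. It pages through the `ListModelPackages` API using `NextToken` and returns `(version, approval status)` pairs sorted by version.

```python
def list_model_versions(sm_client, model_package_group_name):
    """Return sorted (version, approval_status) pairs for a Model Group (hypothetical helper).

    sm_client: assumed to be a boto3 SageMaker client, e.g. boto3.client("sagemaker").
    """
    versions = []
    request = {
        "ModelPackageGroupName": model_package_group_name,
        "SortBy": "CreationTime",
        "SortOrder": "Ascending",
    }
    while True:
        response = sm_client.list_model_packages(**request)
        for summary in response.get("ModelPackageSummaryList", []):
            versions.append(
                (summary["ModelPackageVersion"], summary.get("ModelApprovalStatus"))
            )
        # The API is paginated; keep requesting pages until no NextToken is returned.
        next_token = response.get("NextToken")
        if not next_token:
            break
        request["NextToken"] = next_token
    # Version numbers are unique within a group, so sorting orders by version.
    return sorted(versions)
```

For the five-package example above, the returned list would contain versions 1 through 5, each paired with its current approval status.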

The following diagram depicts how model versioning is organized in SageMaker Model Registry:

![sm model registry](img/sagemaker-model-registry-diagram.jpg)


In [None]:
%pip install sagemaker mlflow==2.13.2 sagemaker-mlflow==0.1.0

Import the relevant libraries.

In [None]:
import json
import os
import sagemaker
import boto3
import mlflow
from time import gmtime, strftime
from sagemaker.model import Model
from sagemaker.model_metrics import (
    MetricsSource,
    ModelMetrics,
)
from sagemaker.model_card.model_card import ModelCard, TrainingDetails, ModelOverview
from botocore.exceptions import ClientError

# Define Helper Functions

In [None]:
def download_from_s3(s3_client, local_file_path, bucket_name, s3_file_path):
    try:
        # Download the file
        s3_client.download_file(bucket_name, s3_file_path, local_file_path)
        print(f"File downloaded successfully to {local_file_path}")
        return True
    except ClientError as e:
        if e.response['Error']['Code'] == "404":
            print("The object does not exist.")
        else:
            print(f"An error occurred: {e}")
        return False
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return False

def upload_to_s3(s3_client, local_file_path, bucket_name, s3_file_path=None):
    # If S3 file path is not specified, use the basename of the local file
    if s3_file_path is None:
        s3_file_path = os.path.basename(local_file_path)

    try:
        # Upload the file
        s3_client.upload_file(local_file_path, bucket_name, s3_file_path)
        print(f"File {local_file_path} uploaded successfully to {bucket_name}/{s3_file_path}")
        return True
    except ClientError as e:
        print(f"ClientError: {e}")
        return False
    except FileNotFoundError:
        print(f"The file {local_file_path} was not found")
        return False
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return False
        
def write_params(s3_client, step_name, params, notebook_param_s3_bucket_prefix):
    local_file_path = f"{step_name}.json"
    with open(local_file_path, "w") as f:
        f.write(json.dumps(params))
    base_local_file_path = os.path.basename(local_file_path)
    bucket_name = notebook_param_s3_bucket_prefix.split("/")[2] # Format: s3://<bucket_name>/..
    s3_file_path = os.path.join("/".join(notebook_param_s3_bucket_prefix.split("/")[3:]), base_local_file_path)
    upload_to_s3(s3_client, local_file_path, bucket_name, s3_file_path)
    
def read_params(s3_client, notebook_param_s3_bucket_prefix, step_name):
    local_file_path = f"{step_name}.json"
    base_local_file_path = os.path.basename(local_file_path)
    bucket_name = notebook_param_s3_bucket_prefix.split("/")[2] # Format: s3://<bucket_name>/..
    s3_file_path = os.path.join("/".join(notebook_param_s3_bucket_prefix.split("/")[3:]),  base_local_file_path)
    downloaded = download_from_s3(s3_client, local_file_path, bucket_name, s3_file_path)
    # Fail fast rather than reading a missing or stale local copy of the file
    if not downloaded:
        raise RuntimeError(f"Could not download parameters for step '{step_name}' from s3://{bucket_name}/{s3_file_path}")
    with open(local_file_path, "r") as f:
        data = f.read()
        params = json.loads(data)
    return params
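
Both `write_params` and `read_params` derive the bucket name and object key by splitting the S3 URI on `/`. The following sketch isolates that logic so you can check it without AWS access; `split_s3_uri` and `params_object_key` are hypothetical helpers, not part of this notebook. It uses `posixpath.join` rather than `os.path.join` because S3 keys always use forward slashes, whatever the local operating system.

```python
import posixpath


def split_s3_uri(s3_uri):
    """Split "s3://bucket/some/prefix" into ("bucket", "some/prefix") (hypothetical helper)."""
    if not s3_uri.startswith("s3://"):
        raise ValueError(f"Not an S3 URI: {s3_uri}")
    parts = s3_uri.split("/")  # ["s3:", "", "<bucket>", "<prefix parts>", ...]
    return parts[2], "/".join(parts[3:])


def params_object_key(notebook_param_s3_bucket_prefix, step_name):
    """Build the S3 key where a step's <step_name>.json parameter file is stored."""
    _, prefix = split_s3_uri(notebook_param_s3_bucket_prefix)
    # posixpath always joins with "/", which is what S3 keys expect.
    return posixpath.join(prefix, f"{step_name}.json")
```

For example, `params_object_key("s3://my-bucket/player-churn/xgboost/params", "03-train")` returns `"player-churn/xgboost/params/03-train.json"`.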


# Initialize Variables
As in the previous notebooks in this repository, the variables defined in the following cell are used throughout this notebook. Besides serving as hardcoded values, these variables can be passed into the notebook as parameters when the notebook is scheduled to run remotely, for example as a SageMaker Pipelines job or in a CI/CD pipeline through a SageMaker Project. We'll dive into how to pass parameters into this notebook in the next lab. Refer to [this](https://docs.aws.amazon.com/sagemaker/latest/dg/notebook-auto-run-troubleshoot-override.html) documentation for more information on notebook parameterization.

As in the `02-preprocess.ipynb` notebook, you can obtain the following variables from the SageMaker Studio launcher. That notebook includes instructions and screenshots to guide you if you need additional assistance.

In [None]:
region = "us-east-1"
os.environ["AWS_DEFAULT_REGION"] = region
boto_session = boto3.Session(region_name=region)
sess = sagemaker.Session(boto_session=boto_session)
bucket_name = sess.default_bucket()
bucket_prefix = "player-churn/xgboost"
notebook_param_s3_bucket_prefix=f"s3://{bucket_name}/{bucket_prefix}/params"
experiment_name = "player-churn-model-experiment"
run_id = None
model_package_group_name = "player-churn-model-group" # Provide a new model package group name. For example: player-churn-model-group
mlflow_tracking_server_arn = "" # Provide a valid mlflow tracking server ARN. You can find the value in the output from 00-start-here.ipynb
model_approval_status = "PendingManualApproval"
model_statistics_s3_path = None
model_constraints_s3_path = None
model_data_statistics_s3_path = None
model_data_constraints_s3_path = None
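
When the notebook runs as a scheduled job, the hardcoded values in the cell above act as defaults that can be overridden. The following is a minimal sketch of how such an override works, assuming papermill-style parameter injection, in which the scheduler inserts a new cell of assignments right after the cell tagged `parameters`. The `render_injected_cell` helper is hypothetical, and `exec` only simulates running the two cells in order; see the linked documentation for how SageMaker applies parameters to your job.

```python
def render_injected_cell(overrides):
    """Render the source of a cell that re-assigns notebook parameters.

    Assumption: papermill-style tools insert a cell like this right after the
    cell tagged "parameters", so injected values replace the hardcoded defaults.
    """
    return "\n".join(f"{name} = {value!r}" for name, value in overrides.items())


# Hardcoded defaults, as they might appear in the parameters cell.
defaults_cell = 'region = "us-east-1"\nmodel_approval_status = "PendingManualApproval"'
injected_cell = render_injected_cell({"model_approval_status": "Approved"})

namespace = {}
exec(defaults_cell, namespace)  # run the parameters cell first
exec(injected_cell, namespace)  # then the injected overrides
print(namespace["region"], namespace["model_approval_status"])  # prints: us-east-1 Approved
```

Only the variables you pass are reassigned; every other variable keeps its hardcoded default.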

In [None]:
assert len(model_package_group_name) > 0
assert len(mlflow_tracking_server_arn) > 0

Retrieve the step parameters saved by the previous notebooks.

In [None]:
preprocess_step_name = "02-preprocess"
train_step_name = "03-train"
evaluation_step_name = "04-evaluation"

s3_client = boto3.client("s3", region_name=region)
preprocess_step_params = read_params(s3_client, notebook_param_s3_bucket_prefix, preprocess_step_name)
train_step_params = read_params(s3_client, notebook_param_s3_bucket_prefix, train_step_name)
evaluation_step_params = read_params(s3_client, notebook_param_s3_bucket_prefix, evaluation_step_name)
experiment_name = preprocess_step_params["experiment_name"]

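The following cell connects this notebook to the MLflow tracking server and starts an MLflow run for the model registration step, or resumes the run identified by `run_id` if one is provided.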

In [None]:
suffix = strftime('%d-%H-%M-%S', gmtime())
mlflow.set_tracking_uri(mlflow_tracking_server_arn)
experiment = mlflow.set_experiment(experiment_name=experiment_name)
run = mlflow.start_run(run_id=run_id) if run_id else mlflow.start_run(run_name=f"register-{suffix}", nested=True)

Download the model evaluation metrics produced by the `04-evaluation` step. These metrics describe the performance of the particular model being registered in the SageMaker model registry, and they are logged to the MLflow run below.

In [None]:
local_file_path = "evaluation.json"
evaluation_result_s3_path = evaluation_step_params["evaluation_result_s3_path"]
s3_file_path = "/".join(evaluation_result_s3_path.split("/")[3:])
evaluation_result_bucket_name = evaluation_result_s3_path.split("/")[2]

In [None]:
download_from_s3(s3_client, local_file_path, evaluation_result_bucket_name, s3_file_path)

In [None]:
mlflow.log_artifact(local_path=local_file_path)

Capture the model baseline metrics (statistics and constraints), if available.

In [None]:
model_metrics = ModelMetrics(
    model_statistics=MetricsSource(
        s3_uri=model_statistics_s3_path,
        content_type="application/json",
    ) if model_statistics_s3_path else None,
    model_constraints=MetricsSource(
        s3_uri=model_constraints_s3_path,
        content_type="application/json",
    ) if model_constraints_s3_path else None,
    model_data_statistics=MetricsSource(
        s3_uri=model_data_statistics_s3_path,
        content_type="application/json",
    ) if model_data_statistics_s3_path else None,
    model_data_constraints=MetricsSource(
        s3_uri=model_data_constraints_s3_path,
        content_type="application/json",
    ) if model_data_constraints_s3_path else None,
)
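
When a `ModelMetrics` object is passed to `register`, the SDK sends it as the `ModelMetrics` field of the `CreateModelPackage` API. Based on that API's documented shape, `model_statistics` and `model_constraints` map to `ModelQuality`, and the two data-quality sources map to `ModelDataQuality`. The following plain-Python sketch, using a hypothetical `model_metrics_request` helper, builds that structure so you can see which fields are sent when only some baseline paths are available.

```python
def model_metrics_request(
    model_statistics_s3_path=None,
    model_constraints_s3_path=None,
    model_data_statistics_s3_path=None,
    model_data_constraints_s3_path=None,
    content_type="application/json",
):
    """Build a CreateModelPackage-style ModelMetrics dict, skipping missing paths.

    Assumption: this mirrors the request structure the SageMaker SDK sends;
    it is an illustration, not the SDK's own code.
    """
    layout = {
        ("ModelQuality", "Statistics"): model_statistics_s3_path,
        ("ModelQuality", "Constraints"): model_constraints_s3_path,
        ("ModelDataQuality", "Statistics"): model_data_statistics_s3_path,
        ("ModelDataQuality", "Constraints"): model_data_constraints_s3_path,
    }
    request = {}
    for (section, kind), s3_uri in layout.items():
        if s3_uri:  # None or "" means no baseline is available for this slot
            request.setdefault(section, {})[kind] = {
                "ContentType": content_type,
                "S3Uri": s3_uri,
            }
    return request
```

With all four paths left as `None`, as in the default variables of this notebook, the result is an empty dict, so no baseline metrics are attached to the model package.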



Next, collect the model artifact details from the SageMaker training job to use in the model registration process.

In [None]:
XGBOOST_IMAGE_URI = sagemaker.image_uris.retrieve(
    "xgboost",
    region=region,
    version="1.7-1",
)
model_data = train_step_params["model_s3_path"]
model_role = sagemaker.get_execution_role()
model_name = "player-churn-model"
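
This notebook reads `model_s3_path` from the parameters saved by `03-train`. If you only have the training job name, a sketch like the following could look up the same artifact location with the `DescribeTrainingJob` API. `model_artifact_from_training_job` is a hypothetical helper, and `sm_client` is assumed to be a boto3 SageMaker client.

```python
def model_artifact_from_training_job(sm_client, training_job_name):
    """Return the S3 URI of a completed training job's model artifact (hypothetical helper).

    sm_client: assumed to be a boto3 SageMaker client, e.g. boto3.client("sagemaker").
    """
    job = sm_client.describe_training_job(TrainingJobName=training_job_name)
    status = job["TrainingJobStatus"]
    # Only a completed job is guaranteed to have produced a usable artifact.
    if status != "Completed":
        raise RuntimeError(f"Training job {training_job_name} is {status}, not Completed")
    return job["ModelArtifacts"]["S3ModelArtifacts"]
```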

In [None]:
# Wrap the trained XGBoost artifact in a SageMaker Model object
model = Model(image_uri=XGBOOST_IMAGE_URI, sagemaker_session=sess, model_data=model_data,
              role=model_role, name=model_name)
# Find the training job that produced this artifact and record its datasets
training_details = TrainingDetails.from_model_s3_artifacts(
    model_artifacts=[model_data], sagemaker_session=sess
)
training_job_details = training_details.training_job_details
training_datasets = [preprocess_step_params["train_data"], preprocess_step_params["validation_data"]]
training_job_details.training_datasets = training_datasets

model_card = ModelCard(
    name="estimator_card",
    training_details=training_details,
    sagemaker_session=sess,
)

model_overview = ModelOverview(model_artifact=[model_data])
model_card.model_overview = model_overview

## Register a Model Package Using a SageMaker Model Object

In [None]:
model_package = model.register(
    content_types=["text/csv"],
    response_types=["text/csv"],
    inference_instances=["ml.m5.xlarge", "ml.m5.large"],
    transform_instances=["ml.m5.xlarge", "ml.m5.large"],
    model_package_group_name=model_package_group_name,
    approval_status=model_approval_status,
    model_metrics=model_metrics,
    domain="MACHINE_LEARNING",
    task="CLASSIFICATION",
    model_card=model_card
)
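
To confirm which version number the registry assigned to the new model package, you could describe it with the `DescribeModelPackage` API. The following is a minimal sketch: `describe_registered_version` is a hypothetical helper, and `sm_client` is assumed to be a boto3 SageMaker client, such as `sess.sagemaker_client`.

```python
def describe_registered_version(sm_client, model_package_arn):
    """Summarize a registered model package version (hypothetical helper).

    sm_client: assumed to be a boto3 SageMaker client, e.g. sess.sagemaker_client.
    """
    package = sm_client.describe_model_package(ModelPackageName=model_package_arn)
    return {
        "group": package["ModelPackageGroupName"],
        "version": package["ModelPackageVersion"],
        "approval_status": package["ModelApprovalStatus"],
        "package_status": package["ModelPackageStatus"],
    }
```

For example, `describe_registered_version(sess.sagemaker_client, model_package.model_package_arn)` should report `PendingManualApproval` as the approval status when the default `model_approval_status` is used.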

# SageMaker Model Package Group
To view the newly registered model, navigate to the SageMaker Studio Launcher and select `Models` in the left pane. The model group is shown in the right pane. The following diagram shows a new version of the XGBoost model created in the given model package group, as displayed in the SageMaker Studio console.

![sagemaker model registry](img/sagamaker-model-registry-diagram.jpg)

In [None]:
mlflow.log_params({
    "model_package_arn":model_package.model_package_arn,
    "model_statistics_uri":model_statistics_s3_path if model_statistics_s3_path else '',
    "model_constraints_uri":model_constraints_s3_path if model_constraints_s3_path else '',
    "data_statistics_uri":model_data_statistics_s3_path if model_data_statistics_s3_path else '',
    "data_constraints_uri":model_data_constraints_s3_path if model_data_constraints_s3_path else '',
})