# Train a Scikit-Learn model in SageMaker and track with MLflow

## Intro

The main objective of this notebook is to show how you can securely interact with an MLflow server from Amazon SageMaker Studio and SageMaker Training jobs.
This notebook is meant to be used with the SageMaker Studio user profile `mlflow-admin` and the MLflow server created by [this CDK deployment](https://github.com/aws-samples/sagemaker-studio-mlflow-integration.git), which ensures the user has the permissions needed to run the lab.
We will train a model in SageMaker, use MLflow to track the experiments and register the model, and then deploy the resulting model to SageMaker managed infrastructure.

## Pre-Requisites

* A successful deployment of the CDK sample in [this repository](https://github.com/aws-samples/sagemaker-studio-mlflow-integration.git).
* Access to the `mlflow-admin` user profile in the created SageMaker Studio domain, using the `Base Python 2.0` image with a `Python 3` kernel.

## The Machine Learning Problem

In this example, we will solve a regression problem which aims to answer the question: "what is the expected price of a house in the California area?".
The target variable is the house value for California districts, expressed in units of $100,000.

## Install and/or update required libraries

At the time of writing, we used version 2 of the `sagemaker` SDK. The MLflow SDK version matches our MLflow server version, i.e., `2.8.0`.
We install `mlflow[gateway]==2.8.0` to ensure that all required dependencies are installed.

In [None]:
%pip install -q --upgrade pip setuptools wheel
%pip install -q sagemaker 
%pip install -q requests_auth_aws_sigv4 boto3 mlflow[gateway]==2.8.0
%pip install -q langchain==0.0.332

Let's start by specifying:

- The S3 bucket and prefix that you want to use for training and model data. This should be in the same region as the notebook, training, and hosting.
- The IAM role ARN associated with the user profile (`sagemaker.get_execution_role()`), which we will use to train in SageMaker, track the experiment in MLflow, register a model in MLflow, and host an MLflow model in SageMaker. See the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/using-identity-based-policies.html) for more details on creating these. The SageMaker execution role associated with the user profile `mlflow-admin` has the appropriate permissions for all of these operations.
- The tracking URI where the MLflow server runs.
- The experiment name, as the logical entity that keeps our tests grouped and organized.

If you examine the SageMaker execution role of the `mlflow-admin` user profile, you will see that it has an inline policy called `restApiAdmin` attached. This policy grants admin permissions on all resources and methods of the REST API Gateway that shields MLflow, and it looks like the following:

```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Action": "execute-api:Invoke",
            "Resource": "arn:aws:execute-api:<AWS_REGION>:<AWS_ACCOUNT>:<REST_API_GATEWAY_ID>/*/*/*",
            "Effect": "Allow"
        }
    ]
}
```
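
The three wildcards at the end of the `Resource` ARN stand for the API stage, the HTTP method, and the resource path, in that order. As a minimal sketch, the hypothetical helper `build_invoke_policy` below (not part of the CDK sample) assembles the same policy document in Python and shows how a narrower variant could be expressed. The region, account, and API ID values are placeholders.

```python
import json


def build_invoke_policy(region, account, api_id, stage="*", method="*", path="*"):
    """Return an IAM policy dict allowing execute-api:Invoke on one REST API.

    Hypothetical helper for illustration; the defaults reproduce the
    wildcard policy shown above.
    """
    resource = (
        f"arn:aws:execute-api:{region}:{account}:{api_id}/{stage}/{method}/{path}"
    )
    return {
        "Version": "2012-10-17",
        "Statement": [
            {"Action": "execute-api:Invoke", "Resource": resource, "Effect": "Allow"}
        ],
    }


# Placeholder values, not real identifiers
admin_policy = build_invoke_policy("us-east-1", "111122223333", "abc123")
# A narrower variant: only GET requests on the "prod" stage
read_only_policy = build_invoke_policy(
    "us-east-1", "111122223333", "abc123", stage="prod", method="GET"
)

print(json.dumps(admin_policy, indent=4))
```

Restricting `method` to `GET` is only one way to scope the policy down; which MLflow operations remain usable under such a policy is not covered here.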

In [None]:
import os
import pandas as pd
import json
import random
import boto3
import logging

# SageMaker SDK
import sagemaker

# MLflow libraries
import mlflow
from mlflow.tracking.client import MlflowClient
import mlflow.sagemaker

logging.getLogger("mlflow").setLevel(logging.INFO)

ssm = boto3.client('ssm')

sess = sagemaker.Session()
bucket = sess.default_bucket()
region = sess.boto_region_name
tracking_uri = ssm.get_parameter(Name="mlflow-restApiUrl")['Parameter']['Value']
api_gw_id = tracking_uri.split('//')[1].split('.')[0]
experiment_name = 'DEMO-sigv4'
model_name = 'california-housing-model'

print('bucket: {}'.format(bucket))
print("Using AWS Region: {}".format(region))
print("MLflow server URI: {}".format(tracking_uri))
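
The cell above derives `api_gw_id` by splitting the tracking URI on `//` and `.`. A slightly more defensive sketch using `urllib.parse` is shown below. It assumes the tracking URI has the usual API Gateway layout `https://<api-id>.execute-api.<region>.amazonaws.com/<stage>`; `parse_api_gateway_url` is a hypothetical helper and the example URL is a placeholder.

```python
from urllib.parse import urlparse


def parse_api_gateway_url(url):
    """Split an API Gateway invoke URL into (api_id, region, stage)."""
    parsed = urlparse(url)
    host_parts = (parsed.hostname or "").split(".")
    # Expected host layout: <api-id>.execute-api.<region>.amazonaws.com
    if len(host_parts) < 5 or host_parts[1] != "execute-api":
        raise ValueError(f"Not an API Gateway invoke URL: {url}")
    api_id, region = host_parts[0], host_parts[2]
    # The first path segment is the stage; empty if the URL has no path
    stage = parsed.path.strip("/").split("/")[0]
    return api_id, region, stage


# Placeholder URL for illustration only
example_url = "https://abc123.execute-api.us-east-1.amazonaws.com/prod"
print(parse_api_gateway_url(example_url))  # ('abc123', 'us-east-1', 'prod')
```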

## Test MLFlow server accessibility

### Without using SigV4 (no env variable set) - should fail

Uncomment the cell below to try the MLflow SDK without the environment variable `MLFLOW_TRACKING_AWS_SIGV4` set, and verify that you cannot interact with the MLflow server.

In [None]:
# try:
#     del os.environ['MLFLOW_TRACKING_AWS_SIGV4']
# except KeyError:
#     print('env variable not there')
# mlflow.set_tracking_uri(tracking_uri)
# mlflow.set_experiment(experiment_name)

### With env variable set - should succeed if the SageMaker execution role has permission to call the MLflow endpoint

In [None]:
os.environ['MLFLOW_TRACKING_AWS_SIGV4'] = "True"
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(experiment_name)
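
To switch between the two cases without leaving the environment in an unexpected state, one option is a small context manager that sets or unsets the variable and restores the previous value afterwards. This is a minimal sketch: `sigv4_env` is a hypothetical helper, and whether an already-created MLflow client re-reads the variable on each request is not verified here.

```python
import os
from contextlib import contextmanager

SIGV4_ENV_VAR = "MLFLOW_TRACKING_AWS_SIGV4"


@contextmanager
def sigv4_env(enabled):
    """Temporarily set or unset the SigV4 flag, restoring the old value on exit."""
    previous = os.environ.get(SIGV4_ENV_VAR)
    if enabled:
        os.environ[SIGV4_ENV_VAR] = "True"
    else:
        os.environ.pop(SIGV4_ENV_VAR, None)
    try:
        yield
    finally:
        # Put the environment back exactly as it was before the block
        if previous is None:
            os.environ.pop(SIGV4_ENV_VAR, None)
        else:
            os.environ[SIGV4_ENV_VAR] = previous


with sigv4_env(False):
    # MLflow calls placed here would run without the SigV4 flag
    print(SIGV4_ENV_VAR in os.environ)  # False
```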

## MLflow server access

In [None]:
!python -m requests_auth_aws_sigv4 "https://{api_gw_id}.execute-api.{region}.amazonaws.com/prod/api/2.0/mlflow/experiments/get?experiment_id=0" -v

## MLflow Gateway access

In [None]:
!python -m requests_auth_aws_sigv4 https://{api_gw_id}.execute-api.{region}.amazonaws.com/prod/api/2.0/gateway/routes/ -v
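
The two shell commands above can also be expressed in Python. The sketch below builds the same URLs and, separately, shows how a signed request might be sent with `requests` and the `AWSSigV4` auth class from `requests_auth_aws_sigv4`. The helper names `mlflow_api_url` and `signed_get` are hypothetical, the `prod` stage is taken from the commands above, and the API ID and region in the example are placeholders. Only the URL builder is exercised here, since the request itself needs network access and AWS credentials.

```python
from urllib.parse import urlencode


def mlflow_api_url(api_gw_id, region, path, stage="prod", **query):
    """Build the API Gateway URL for a REST path such as 'mlflow/experiments/get'."""
    base = f"https://{api_gw_id}.execute-api.{region}.amazonaws.com/{stage}/api/2.0/"
    url = base + path.lstrip("/")
    return f"{url}?{urlencode(query)}" if query else url


def signed_get(url, region):
    """Send a SigV4-signed GET request (needs network access and AWS credentials)."""
    # Imported lazily so the URL helper above works without these packages
    import requests
    from requests_auth_aws_sigv4 import AWSSigV4

    return requests.get(url, auth=AWSSigV4("execute-api", region=region))


# Placeholder API ID and region for illustration only
url = mlflow_api_url("abc123", "us-east-1", "mlflow/experiments/get", experiment_id=0)
print(url)
```

In the notebook, calling `signed_get(mlflow_api_url(api_gw_id, region, "gateway/routes/"), region)` would roughly correspond to the gateway check above, assuming the execution role is allowed to invoke the API.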

# MLflow AI Gateway

## Client API

In [None]:
from mlflow.gateway import MlflowGatewayClient

gateway_client = MlflowGatewayClient(tracking_uri)

### List all routes

In [None]:
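routes = gateway_client.search_routes()
for route in routes:
    print(route)
# Keep the name of the last route listed; it is used in the queries below.
# This raises a NameError if the gateway has no routes configured.
route_name = route.name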
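
Relying on the loop variable means the route used later depends on listing order. A more explicit selection could look like the sketch below. `pick_route` is a hypothetical helper, and the stand-in objects only mimic the `.name` attribute used above; real route objects come from `gateway_client.search_routes()`.

```python
from types import SimpleNamespace


def pick_route(routes, name=None):
    """Return the route called `name`, or the first route when no name is given."""
    routes = list(routes)
    if not routes:
        raise ValueError("The gateway returned no routes")
    if name is None:
        return routes[0]
    for route in routes:
        if route.name == name:
            return route
    raise KeyError(f"No route named {name!r}")


# Stand-in objects with made-up names, for illustration only
fake_routes = [SimpleNamespace(name="completions"), SimpleNamespace(name="chat")]
print(pick_route(fake_routes, "chat").name)  # chat
```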

### Query a route

In [None]:
import json

response = gateway_client.query(
    route_name,
    data={'prompt': 'Tell me a funny story about a fish'},
)

json_formatted = json.dumps(response, indent=1)
print(json_formatted)

## Fluent API

In [None]:
from mlflow.gateway import query, set_gateway_uri

set_gateway_uri(gateway_uri=tracking_uri)

response = query(
    route_name,
    data={'prompt':'Tell me a funny story about a fish'},
)

json_formatted = json.dumps(response, indent=1)
print(json_formatted)

## Langchain

In [None]:
import mlflow
from langchain import LLMChain, PromptTemplate
from langchain.llms import MlflowAIGateway

gateway = MlflowAIGateway(
    gateway_uri=tracking_uri,
    route=route_name,
    params={
        "temperature": 0.0,
    },
)

llm_chain = LLMChain(
    llm=gateway,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
result = llm_chain.run(adjective="funny")
print(result)

with mlflow.start_run():
    model_info = mlflow.langchain.log_model(llm_chain, "model")

model = mlflow.pyfunc.load_model(model_info.model_uri)
print(model.predict([{"adjective": "funny"}]))
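
As a rough illustration of how the prompt is built before it reaches the gateway route, the template above has its `{adjective}` placeholder filled from the input dictionary. The plain-Python sketch below mimics that substitution with `str.format`. It is a stand-in, not LangChain's implementation, and `render_prompt` is a hypothetical name.

```python
import string


def render_prompt(template, inputs):
    """Fill a '{name}'-style template, failing loudly on missing variables."""
    # Collect the placeholder names that appear in the template
    required = {
        field for _, field, _, _ in string.Formatter().parse(template) if field
    }
    missing = required - set(inputs)
    if missing:
        raise KeyError(f"Missing input variables: {sorted(missing)}")
    return template.format(**inputs)


print(render_prompt("Tell me a {adjective} joke", {"adjective": "funny"}))
```

The dictionary keys play the same role as the keys in the records passed to `model.predict` above: each key has to match a placeholder in the template.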