# Model STAGING validation JOB

This notebook is triggered automatically by an MLflow webhook. It is defined as a **job** and programmatically validates the model before moving it to STAGING.

<img src="https://github.com/QuentinAmbard/databricks-demo/raw/main/product_demos/mlops-end2end-flow-5.png" width="1200">

<!-- [metadata={"description":"MLOps end2end workflow: Trigger Model testing and validation job.",
 "authors":["quentin.ambard@databricks.com"],
 "db_resources":{},
  "search_tags":{"vertical": "retail", "step": "Model testing", "components": ["mlflow"]},
                 "canonicalUrl": {"AWS": "", "Azure": "", "GCP": ""}}] -->

### A cluster has been created for this demo
To run this demo, just select the cluster `dbdemos-mlops-end2end-shawnzou2020` from the dropdown menu ([open cluster configuration](https://dbc-abdbb8e0-f50f.cloud.databricks.com/#setting/clusters/0410-014028-ndqe9et5/configuration)). <br />
*Note: If the cluster was deleted after 30 days, you can re-create it with `dbdemos.create_cluster('mlops-end2end')` or re-install the demo: `dbdemos.install('mlops-end2end')`*


## General Validation Checks

<img style="float: right" src="https://github.com/QuentinAmbard/databricks-demo/raw/main/retail/resources/images/churn-mlflow-webhook-1.png" width=600 >

In MLOps, testing a model involves more than measuring its accuracy. To ensure the stability of our ML system and compliance with any regulatory requirements, we subject each model added to the registry to a series of validation checks. These include, but are not limited to:
<br><br>
* __Inference on production data__
* __Input schema ("signature") compatibility with current model version__
* __Accuracy on multiple slices of the training data__
* __Model documentation__

In this notebook we explore some approaches to performing these tests, and how to use tags to record on each model version whether it passed a given test.

These checks are typically specific to your line of business and quality requirements.

For each test, we record the outcome as a tag on the model version so we know what has been validated. We can also add comments if needed.

In [0]:
%run ./_resources/00-setup $reset_all_data=false $catalog="hive_metastore"



USE CATALOG `hive_metastore`
using cloud_storage_path /Users/quentin.ambard@databricks.com/demos/retail
using catalog.database `hive_metastore`.`retail_quentin_ambard`


## Fetch Model information

Remember how webhooks can send data from one webservice to another?  With MLflow webhooks we send data about a model, and in the following cell we fetch that data to know which model is meant to be tested. 

This is done by reading the `event_message` parameter passed by the MLflow webhook: `dbutils.widgets.get('event_message')`

To keep things simple we use a helper function `fetch_webhook_data`, the details of which are found in the _API_Helpers_ notebook.  
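
As a rough illustration of what such a helper might do, the sketch below parses a JSON `event_message` string and pulls out the model name and version. The function name `parse_event_message`, the payload field names (`model_name`, `version`) and the sample payload are assumptions made for this example; the actual `fetch_webhook_data` implementation lives in the _API_Helpers_ notebook and may differ.

```python
import json
from typing import Tuple


def parse_event_message(event_message: str) -> Tuple[str, str]:
    """Extract (model_name, model_version) from a JSON webhook payload.

    Hypothetical helper: the field names used here are assumed, not taken
    from the _API_Helpers_ notebook.
    """
    payload = json.loads(event_message)
    try:
        return payload["model_name"], str(payload["version"])
    except KeyError as missing:
        raise ValueError(f"webhook payload is missing field {missing}") from None


# Sample payload, made up for illustration
sample = (
    '{"event": "TRANSITION_REQUEST_CREATED", '
    '"model_name": "dbdemos_mlops_churn", "version": "3", "to_stage": "Staging"}'
)
name, version = parse_event_message(sample)
print(name, version)
```

In the job itself, the raw string would come from `dbutils.widgets.get('event_message')` rather than a hard-coded sample.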

In [0]:
from mlflow.tracking import MlflowClient

# Get the model in transition, its name and version from the metadata received by the webhook
model_name, model_version = fetch_webhook_data()

client = MlflowClient()
# Note: the demo overrides the model name returned by the webhook with a fixed value
model_name = "dbdemos_mlops_churn"
model_details = client.get_model_version(model_name, model_version)
run_info = client.get_run(run_id=model_details.run_id)


#### Validate prediction

We want to verify that the model can predict on production data. So, we load the model and the latest features from the feature store and try making some predictions.

In [0]:
from databricks.feature_store import FeatureStoreClient

fs = FeatureStoreClient()

# Read from feature store 
data_source = run_info.data.tags['db_table']
features = fs.read_table(data_source)

# Load model as a Spark UDF
model_uri = f'models:/{model_name}/{model_version}'
loaded_model = mlflow.pyfunc.spark_udf(spark, model_uri=model_uri)

# Select the feature table cols by model input schema
input_column_names = loaded_model.metadata.get_input_schema().input_names()

# Predict on a Spark DataFrame
try:
  display(features.withColumn('predictions', loaded_model(*input_column_names)))
  client.set_model_version_tag(name=model_name, version=model_version, key="predicts", value=1)
except Exception as e:
  # Report the underlying error so the failure can be diagnosed from the job run
  print(f"Unable to predict on features: {e}")
  client.set_model_version_tag(name=model_name, version=model_version, key="predicts", value=0)

2023/06/21 12:58:30 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'


customer_id,senior_citizen,tenure,monthly_charges,total_charges,churn,gender_female,gender_male,partner_no,partner_yes,dependents_no,dependents_yes,phone_service_no,phone_service_yes,multiple_lines_no,multiple_lines_no_phone_service,multiple_lines_yes,internet_service_dsl,internet_service_fiber_optic,internet_service_no,online_security_no,online_security_no_internet_service,online_security_yes,online_backup_no,online_backup_no_internet_service,online_backup_yes,device_protection_no,device_protection_no_internet_service,device_protection_yes,tech_support_no,tech_support_no_internet_service,tech_support_yes,streaming_tv_no,streaming_tv_no_internet_service,streaming_tv_yes,streaming_movies_no,streaming_movies_no_internet_service,streaming_movies_yes,contract_month_to_month,contract_one_year,contract_two_year,paperless_billing_no,paperless_billing_yes,payment_method_bank_transfer__automatic_,payment_method_credit_card__automatic_,payment_method_electronic_check,payment_method_mailed_check,predictions
7590-VHVEG,0,1,29.85,29.85,0,1,0,0,1,1,0,1,0,0,1,0,1,0,0,1,0,0,0,0,1,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,1,0,1.0
5575-GNVDE,0,34,56.95,1889.5,0,0,1,1,0,1,0,0,1,1,0,0,1,0,0,0,0,1,1,0,0,0,0,1,1,0,0,1,0,0,1,0,0,0,1,0,1,0,0,0,0,1,0.0
3668-QPYBK,0,2,53.85,108.15,1,0,1,1,0,1,0,0,1,1,0,0,1,0,0,0,0,1,0,0,1,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,0,1,0.0
7795-CFOCW,0,45,42.3,1840.75,0,0,1,1,0,1,0,1,0,0,1,0,1,0,0,0,0,1,1,0,0,0,0,1,0,0,1,1,0,0,1,0,0,0,1,0,1,0,1,0,0,0,0.0
9237-HQITU,0,2,70.7,151.65,1,1,0,1,0,1,0,0,1,1,0,0,0,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,0,1,0,1.0
9305-CDSKC,0,8,99.65,820.5,1,1,0,1,0,1,0,0,1,0,0,1,0,1,0,1,0,0,1,0,0,0,0,1,1,0,0,0,0,1,0,0,1,1,0,0,0,1,0,0,1,0,1.0
1452-KIOVK,0,22,89.1,1949.4,0,0,1,1,0,0,1,0,1,0,0,1,0,1,0,1,0,0,0,0,1,1,0,0,1,0,0,0,0,1,1,0,0,1,0,0,0,1,0,1,0,0,0.0
6713-OKOMC,0,10,29.75,301.9,0,1,0,1,0,1,0,1,0,0,1,0,1,0,0,0,0,1,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,0,1,0.0
7892-POOKP,0,28,104.8,3046.05,1,1,0,0,1,1,0,0,1,0,0,1,0,1,0,1,0,0,1,0,0,0,0,1,0,0,1,0,0,1,0,0,1,1,0,0,0,1,0,0,1,0,1.0
6388-TABGU,0,62,56.15,3487.95,0,0,1,1,0,0,1,0,1,1,0,0,1,0,0,0,0,1,0,0,1,1,0,0,1,0,0,1,0,0,1,0,0,0,1,0,1,0,1,0,0,0,0.0


#### Signature check

When working with ML models you often need to know some basic functional properties of the model at hand, such as “What inputs does it expect?” and “What output does it produce?”.  The model **signature** defines the schema of a model’s inputs and outputs. Model inputs and outputs can be either column-based or tensor-based. 

See [here](https://mlflow.org/docs/latest/models.html#signature-enforcement) for more details.
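
To make the idea concrete, here is a minimal, library-free sketch of a column-based compatibility check: given the input column names a signature declares and the columns available in a feature table, it reports which required inputs are missing. The function `missing_inputs` and the sample column lists are illustrative assumptions, not part of MLflow's API; MLflow applies its own schema enforcement at prediction time.

```python
from typing import Iterable, List


def missing_inputs(signature_inputs: Iterable[str], table_columns: Iterable[str]) -> List[str]:
    """Return the signature input names absent from the table, in signature order."""
    available = set(table_columns)
    return [name for name in signature_inputs if name not in available]


# Column lists made up for illustration
signature_inputs = ["tenure", "monthly_charges", "total_charges"]
table_columns = ["customer_id", "tenure", "monthly_charges", "churn"]

print(missing_inputs(signature_inputs, table_columns))  # ['total_charges']
```

An empty result means every column the model expects is present; extra table columns such as `customer_id` are simply ignored.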

In [0]:
if not loaded_model.metadata.signature:
  print("This model version is missing a signature.  Please push a new version with a signature!  See https://mlflow.org/docs/latest/models.html#model-metadata for more details.")
  client.set_model_version_tag(name=model_name, version=model_version, key="has_signature", value=0)
else:
  client.set_model_version_tag(name=model_name, version=model_version, key="has_signature", value=1)


#### Demographic accuracy

How does the model perform across various slices of the customer base?

In [0]:
import numpy as np
features = features.withColumn('predictions', loaded_model(*input_column_names)).toPandas()
features['accurate'] = np.where(features.churn == features.predictions, 1, 0)

# Check run tags for demographic columns and accuracy in each segment
try:
  demographics = run_info.data.tags['demographic_vars'].split(",")
  slices = features.groupby(demographics).accurate.agg(acc = 'sum', obs = lambda x:len(x), pct_acc = lambda x:sum(x)/len(x))
  
  # Every slice must exceed the 55% accuracy threshold to pass
  demo_test = "pass" if (slices['pct_acc'] > 0.55).all() else "fail"
  
  # Set tags in registry
  client.set_model_version_tag(name=model_name, version=model_version, key="demo_test", value=demo_test)

  print(slices)
except KeyError:
  print("KeyError: No demographics_vars tagged with this model version.")
  client.set_model_version_tag(name=model_name, version=model_version, key="demo_test", value="none")
  pass

KeyError: No demographics_vars tagged with this model version.
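
The per-slice computation can also be sketched without pandas. The example below groups rows by one or more demographic columns, computes the share of correct predictions in each group, and applies the 55% threshold. The `slice_accuracy` helper and the sample rows are made up for illustration, and requiring every slice to clear the threshold is an assumption of this sketch.

```python
from collections import defaultdict
from typing import Dict, Iterable, Mapping, Sequence, Tuple


def slice_accuracy(
    rows: Iterable[Mapping[str, object]],
    slice_cols: Sequence[str],
    label: str = "churn",
    pred: str = "predictions",
) -> Dict[Tuple[object, ...], float]:
    """Return the fraction of correct predictions for each combination of slice values."""
    totals = defaultdict(lambda: [0, 0])  # slice key -> [correct, observed]
    for row in rows:
        key = tuple(row[col] for col in slice_cols)
        totals[key][0] += int(row[label] == row[pred])
        totals[key][1] += 1
    return {key: correct / observed for key, (correct, observed) in totals.items()}


# Sample rows, made up for illustration
rows = [
    {"senior_citizen": 0, "churn": 0, "predictions": 0.0},
    {"senior_citizen": 0, "churn": 1, "predictions": 1.0},
    {"senior_citizen": 0, "churn": 1, "predictions": 0.0},
    {"senior_citizen": 1, "churn": 1, "predictions": 0.0},
    {"senior_citizen": 1, "churn": 0, "predictions": 0.0},
]

acc = slice_accuracy(rows, ["senior_citizen"])
demo_test = "pass" if all(value > 0.55 for value in acc.values()) else "fail"
print(acc, demo_test)
```

Here the `senior_citizen = 1` slice is only 50% accurate, so the check fails even though the other slice is above the threshold.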


#### Description check

Has the data scientist provided a description of the model being submitted?

In [0]:
# If there's no description or an insufficient number of characters, tag accordingly
if not model_details.description:
  client.set_model_version_tag(name=model_name, version=model_version, key="has_description", value=0)
  print("Did you forget to add a description?")
elif len(model_details.description) <= 20:
  client.set_model_version_tag(name=model_name, version=model_version, key="has_description", value=0)
  print("Your description is too basic, sorry.  Please resubmit with more detail (more than 20 characters required).")
else:
  client.set_model_version_tag(name=model_name, version=model_version, key="has_description", value=1)

#### Artifact check
Has the data scientist logged supplemental artifacts along with the original model?

In [0]:
import os

# Create local directory if it does not already exist
local_dir = "/tmp/model_artifacts"
os.makedirs(local_dir, exist_ok=True)

# Download artifacts from tracking server - no need to specify DBFS path here
local_path = client.download_artifacts(run_info.info.run_id, "", local_dir)

# Tag model version as possessing artifacts or not
if not os.listdir(local_path):
  client.set_model_version_tag(name=model_name, version=model_version, key="has_artifacts", value=0)
  print("There are no artifacts associated with this model.  Please include some data visualization or data profiling.  MLflow supports HTML, .png, and more.")
else:
  client.set_model_version_tag(name=model_name, version=model_version, key = "has_artifacts", value = 1)
  print("Artifacts downloaded in: {}".format(local_path))
  print("Artifacts: {}".format(os.listdir(local_path)))


Artifacts downloaded in: /tmp/model_artifacts/
Artifacts: ['estimator.html', 'precision_recall_curve_plot.png', 'training_precision_recall_curve.png', 'training_confusion_matrix.png', 'confusion_matrix.png', 'lift_curve_plot.png', 'model', 'roc_curve_plot.png', 'training_roc_curve.png']


## Results

Here's a summary of the testing results:

In [0]:
results = client.get_model_version(model_name, model_version)
results.tags

Out[27]: {'demo_test': 'none',
 'has_artifacts': '1',
 'has_description': '1',
 'has_signature': '1',
 'predicts': '1'}

Notify the Slack channel with the same webhook used to alert on transition change in MLflow.

In [0]:
slack_message = f"Registered model <b>{model_name}</b> version <b>{model_version}</b> baseline test results: {results.tags}"
send_notification(slack_message)

slack isn't properly setup in this workspace.


## Move to Staging or Archived

The next stage in this model's lifecycle is either `Staging` or `Archived`, depending on how it fared in testing.

In [0]:
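# If any checks failed, reject and move to Archived
# (inspect the tag values; the ModelVersion object itself never contains '0' or 'fail')
tag_values = results.tags.values()
if '0' in tag_values or 'fail' in tag_values: 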
  print("Rejecting transition...")
  reject_transition(model_name,
                   model_version,
                   stage='Archived',
                   comment='Tests failed, moving to archived.  Check the tags or the job run to see what happened.')
  
else: 
  print("Accepting transition...")
  accept_transition(model_name,
                   model_version,
                   stage='Staging',
                   comment='All tests passed!  Moving to staging.')

Accepting transition...
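
The decision rule can be isolated as a small pure function, which makes it easy to test outside the job. This is a sketch under assumptions: `decide_stage` and `FAILURE_VALUES` are names introduced here for illustration, and the tag dictionary mirrors the results summary shown earlier in this notebook.

```python
from typing import Mapping

# Tag values that mark a failed check
FAILURE_VALUES = {"0", "fail"}


def decide_stage(tags: Mapping[str, str]) -> str:
    """Return 'Archived' if any tag carries a failing value, otherwise 'Staging'."""
    failed = [key for key, value in tags.items() if str(value) in FAILURE_VALUES]
    return "Archived" if failed else "Staging"


tags = {
    "demo_test": "none",
    "has_artifacts": "1",
    "has_description": "1",
    "has_signature": "1",
    "predicts": "1",
}

print(decide_stage(tags))                              # Staging
print(decide_stage({**tags, "has_signature": "0"}))    # Archived
```

Note that a `demo_test` value of `none` (no demographic columns tagged) is not treated as a failure here, which matches the accepted transition above.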


### Congratulations, our model is now automatically tested and will be transitioned accordingly

We can now be confident that a model reaching Staging meets our quality standards.


Next: [Run batch inference from our STAGING model]($./06_staging_inference)