# Model Monitoring

In [0]:
%%sql

-- Make `main` the current catalog for the statements that follow.
use catalog main;

-- Make `dbdemos_mlops` the current schema, so unqualified table names used below resolve inside it.
use schema dbdemos_mlops;

In [0]:
import os
import time
from datetime import datetime, timedelta

import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType, StructField

from mlflow import MlflowClient

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import (
    MonitorInferenceLog,
    MonitorInferenceLogProblemType,
    MonitorMetric,
    MonitorMetricType,
    MonitorInfoStatus,
    MonitorRefreshInfoState
)

### Create an inference table

The inference table stores the inference data (features and predictions) together with the model information (name, version, alias), so that the monitor can detect data, label, and model drift.

In [0]:
# Look up the registered churn model and the version currently aliased "Champion".
client = MlflowClient()
model_name = "advanced_mlops_churn"
model = client.get_registered_model(name=model_name)
model_version = int(
    client.get_model_version_by_alias(name=model_name, alias="Champion").version
)

features_df = spark.read.table('advanced_churn_cust_ids')

# Timestamp stamped on every row: two days before the moment this cell runs.
backdated_timestamp = datetime.now() - timedelta(days=2)

# Columns appended to the feature rows, in this order. The prediction is a
# NULL string placeholder; the remaining columns identify the model.
extra_columns = [
    ("prediction", F.lit(None).cast('string')),
    ("model_name", F.lit(model_name)),
    ("model_version", F.lit(model_version)),
    ("model_alias", F.lit("Champion")),
    ("inference_timestamp", F.lit(backdated_timestamp)),
]

inference_df = features_df
for column_name, column_expr in extra_columns:
    inference_df = inference_df.withColumn(column_name, column_expr)

# limit(0) keeps the schema but no rows, so the table is (re)created empty.
empty_inference_df = inference_df.limit(0)
empty_inference_df.write.mode('overwrite').saveAsTable('advanced_churn_inference_table')

In [0]:
%sql

-- Turn on the Delta change data feed for the inference table.
-- NOTE(review): presumably required so the monitor can process newly appended rows incrementally — confirm against the Lakehouse Monitoring docs.
ALTER TABLE advanced_churn_inference_table SET TBLPROPERTIES (delta.enableChangeDataFeed = true)

### Create a baseline table

The baseline table gives the monitor a reference against which the inference, prediction, and model information in the inference table can be compared.

In [0]:
from databricks.feature_engineering import FeatureEngineeringClient

# URI of the model version currently carrying the "Champion" alias.
model_name = "advanced_mlops_churn"
model_uri = f"models:/{model_name}@Champion"

# Rows to score; read from the same table the inference table schema was built from.
inference_df = spark.read.table("advanced_churn_cust_ids")

# Batch-score the rows with the Champion model, requesting string-typed results.
fe = FeatureEngineeringClient()
preds_df = fe.score_batch(
    df=inference_df,
    model_uri=model_uri,
    result_type="string",
)
display(preds_df)

In [0]:
# Build the baseline from the scored rows: tag each row with the model name and
# version, then drop the identifier and timestamp columns.
# The chain is wrapped in parentheses rather than using backslash continuations:
# the previous version ended with a dangling "\" that only parsed because the
# following line happened to be blank.
baseline_df = (
    preds_df
    .withColumn("model_name", F.lit(model_name))
    .withColumn("model_version", F.lit(model_version))
    .drop('customer_id', 'transaction_ts')
)

# Replace any existing baseline table with the freshly scored rows.
baseline_df.write.mode('overwrite').saveAsTable('advanced_churn_baseline')

### Create a custom metric

In [0]:
# SQL aggregate for the custom metric: the average, over rows, of the negated
# monthly_charges when the prediction differs from the label and the label is
# 'Yes', and 0 otherwise. {{prediction_col}} / {{label_col}} are template
# placeholders left in the string as-is.
expected_loss_sql = """avg(CASE
    WHEN {{prediction_col}} != {{label_col}} AND {{label_col}} = 'Yes' THEN -monthly_charges
    ELSE 0 END
    )"""

# JSON description of the metric's output column: a double field named "output".
expected_loss_output_type = StructField("output", DoubleType()).json()

expected_loss_metric = [
  MonitorMetric(
    name="expected_loss",
    type=MonitorMetricType.CUSTOM_METRIC_TYPE_AGGREGATE,
    input_columns=[":table"],
    definition=expected_loss_sql,
    output_data_type=expected_loss_output_type,
  )
]

### Create monitor

In [0]:
w = WorkspaceClient()

try:
  # Describe how the inference table is laid out: a classification problem with
  # daily granularity, keyed by model version, with "churn" as the label column.
  inference_log_config = MonitorInferenceLog(
    problem_type=MonitorInferenceLogProblemType.PROBLEM_TYPE_CLASSIFICATION,
    prediction_col="prediction",
    timestamp_col="inference_timestamp",
    granularities=["1 day"],
    model_id_col="model_version",
    label_col="churn",
  )

  # Create the monitor on the inference table, comparing against the baseline
  # table and including the custom expected-loss metric.
  info = w.quality_monitors.create(
    table_name="main.dbdemos_mlops.advanced_churn_inference_table",
    inference_log=inference_log_config,
    assets_dir=f"{os.getcwd()}/monitoring",
    output_schema_name="main.dbdemos_mlops",
    baseline_table_name="main.dbdemos_mlops.advanced_churn_baseline",
    slicing_exprs=["senior_citizen='Yes'", "contract"],
    custom_metrics=expected_loss_metric,
  )

except Exception as lhm_exception:
  # Only an "already exist(s)" failure is recoverable; anything else propagates.
  monitor_already_exists = "already exist" in str(lhm_exception).lower()
  if not monitor_already_exists:
    raise

  print("Monitor for advanced_churn_inference_table already exists, retrieving monitor info:")
  info = w.quality_monitors.get(table_name="main.dbdemos_mlops.advanced_churn_inference_table")

In [0]:
# Poll until the monitor leaves the PENDING state.
# Sleep *before* each poll rather than after it: the loop then exits as soon as
# a non-pending status is fetched, instead of waiting another 10 seconds.
while info.status == MonitorInfoStatus.MONITOR_STATUS_PENDING:
  time.sleep(10)
  info = w.quality_monitors.get(table_name="main.dbdemos_mlops.advanced_churn_inference_table")

# Any final status other than ACTIVE means the monitor could not be created.
assert info.status == MonitorInfoStatus.MONITOR_STATUS_ACTIVE, "Error creating monitor"

In [0]:
def get_refreshes():
  """Return the refreshes listed for the inference-table monitor.

  Returns:
    The `refreshes` list from the list-refreshes response, or an empty list
    when the response carries no refreshes (the field may be unset).
  """
  response = w.quality_monitors.list_refreshes(table_name="main.dbdemos_mlops.advanced_churn_inference_table")
  return response.refreshes or []

refreshes = get_refreshes()
if len(refreshes) == 0:
  # No refresh yet: queue one and track the refresh info returned by the call
  # itself. Re-listing after a fixed sleep could still come back empty and make
  # the `[0]` lookup fail.
  run_info = w.quality_monitors.run_refresh(table_name="main.dbdemos_mlops.advanced_churn_inference_table")
  refreshes = [run_info]
else:
  run_info = refreshes[0]

# Poll until the refresh reaches a terminal state. Report and sleep before each
# poll so that the loop exits as soon as a terminal state has been fetched.
while run_info.state in (MonitorRefreshInfoState.PENDING, MonitorRefreshInfoState.RUNNING):
  print(f"waiting for refresh to complete {run_info.state}...")
  time.sleep(30)
  run_info = w.quality_monitors.get_refresh(table_name="main.dbdemos_mlops.advanced_churn_inference_table", refresh_id=run_info.refresh_id)

# Any terminal state other than SUCCESS is treated as a failure.
assert run_info.state == MonitorRefreshInfoState.SUCCESS, "Monitor refresh failed"

In [0]:
# Fetch the monitor's current info; as the last expression in the cell it is rendered as the cell output.
w.quality_monitors.get(table_name=f"main.dbdemos_mlops.advanced_churn_inference_table")