# Data Drift Detection

In [0]:
%%sql
-- Select the catalog and schema that unqualified table names resolve against.
USE CATALOG main;

USE SCHEMA dbdemos_mlops;

In [0]:
import time
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import MonitorInfoStatus, MonitorRefreshInfoState

# Databricks SDK client, used below to look up the quality monitor. It is created
# with no explicit arguments, so it presumably relies on the notebook's ambient
# credentials -- TODO confirm when running outside a notebook.
w = WorkspaceClient()

In [0]:
# Three-level name of the monitored inference table.
catalog, db = 'main', 'dbdemos_mlops'
inference_table_name = ".".join([catalog, db, "advanced_churn_inference_table"])

# Fetch the quality monitor attached to that table and keep the names of the
# two metric tables it exposes; the queries below read from them.
monitor_info = w.quality_monitors.get(table_name=inference_table_name)
profile_table_name = monitor_info.profile_metrics_table_name
drift_table_name = monitor_info.drift_metrics_table_name

In [0]:
# Notebook parameters: the performance metric and drift metric to analyse, and
# the model id to filter on.
perf_metric_choices = [
    "accuracy_score",
    "precision.weighted",
    "recall.weighted",
    "f1_score.macro",
]
drift_metric_choices = [
    "chi_squared_test.statistic",
    "chi_squared_test.pvalue",
    "tv_distance",
    "l_infinity_distance",
    "js_distance",
]

dbutils.widgets.dropdown("perf_metric", "f1_score.macro", perf_metric_choices)
dbutils.widgets.dropdown("drift_metric", "js_distance", drift_metric_choices)
dbutils.widgets.text("model_id", "*", "Model Id")

In [0]:
# Read the current widget values, in the same order the widgets were declared.
metric, drift, model_id = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("perf_metric", "drift_metric", "model_id")
)

In [0]:
# Table-level performance metrics per time window for the selected model, read
# from the monitor's profile metrics table.
#
# `metric` and `profile_table_name` are identifiers, which cannot be bound as
# query parameters, so they remain interpolated into the statement. `model_id`
# comes from a free-text widget and is bound as a named parameter instead of
# being spliced into the SQL text, so a value containing a quote can neither
# break nor alter the query. Named parameter markers need Spark 3.4+.
performance_metrics_df = spark.sql(
    f"""
SELECT
  window.start as time,
  {metric} AS performance_metric,
  expected_loss,
  Model_Version AS `Model Id`
FROM {profile_table_name}
WHERE
  window.start >= "2024-06-01"
	AND log_type = "INPUT"
  AND column_name = ":table"
  AND slice_key is null
  AND slice_value is null
  AND Model_Version = :model_id
ORDER BY
  window.start
""",
    args={"model_id": model_id},
)
display(performance_metrics_df)

In [0]:
# Consecutive-window drift of the `prediction` and `churn` columns for the
# selected model, read from the monitor's drift metrics table.
#
# `drift` and `drift_table_name` are identifiers, which cannot be bound as query
# parameters, so they remain interpolated into the statement. `model_id` comes
# from a free-text widget and is bound as a named parameter instead of being
# spliced into the SQL text, so a value containing a quote can neither break
# nor alter the query. Named parameter markers need Spark 3.4+.
drift_metrics_df = spark.sql(
    f"""
  SELECT
  window.start AS time,
  column_name,
  {drift} AS drift_metric,
  Model_Version AS `Model Id`
FROM {drift_table_name}
WHERE
  column_name IN ('prediction', 'churn')
  AND window.start >= "2024-06-01"
  AND slice_key is null
  AND slice_value is null
  AND Model_Version = :model_id
  AND drift_type = "CONSECUTIVE"
ORDER BY
  window.start
""",
    args={"model_id": model_id},
)
display(drift_metrics_df)

In [0]:
from pyspark.sql.functions import first

# Reshape the drift metrics from one row per (time, model, column) to one row per
# (time, model) with a `churn` and a `prediction` column.
# If there is no drift data for the label or prediction at all, skip this step.
if not drift_metrics_df.isEmpty():
    unstacked_drift_metrics_df = (
        drift_metrics_df.groupBy("time", "`Model Id`")
        # Pass the pivot values explicitly: without them only the values present
        # in the data become columns, so a result holding drift rows for just one
        # of the two columns would lack the other and break the later filter on
        # `churn` / `prediction`. A column with no data is now filled with nulls.
        # Listing the values also spares Spark the extra job that discovers them.
        .pivot("column_name", ["churn", "prediction"])
        .agg(first("drift_metric"))
        .orderBy("time")
    )
    display(unstacked_drift_metrics_df)

In [0]:
# Combine performance and drift metrics per (time, model). When there are no
# drift rows, the performance metrics are used on their own.
if drift_metrics_df.isEmpty():
    all_metrics_df = performance_metrics_df
else:
    join_keys = ["time", "Model Id"]
    # Inner join: only windows present in both frames are kept.
    all_metrics_df = performance_metrics_df.join(
        unstacked_drift_metrics_df, on=join_keys, how="inner"
    )

display(all_metrics_df)

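### Count the total number of violations and save it as a task value
A time window counts as a performance violation when both of these hold:
- performance metric < 0.5
- average expected loss per customer (custom business metric) > $30.0

A time window counts as a drift violation when the drift metric is above 0.19 for both the label (`churn`) and the `prediction` column.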

In [0]:
from pyspark.sql.functions import col, abs


# Thresholds that define a violation.
min_performance = 0.5
max_expected_loss = 30
max_drift = 0.19

# A window is a performance violation when the metric is below its threshold AND
# the magnitude of the expected loss is above its threshold.
# Note: `abs` here is the Spark column function imported in this cell.
low_performance = col("performance_metric") < min_performance
high_expected_loss = abs(col("expected_loss")) > max_expected_loss
performance_violation_count = all_metrics_df.filter(
    low_performance & high_expected_loss
).count()

# A window is a drift violation when both the label and the prediction drift
# exceed the threshold. Without drift data there is nothing to count.
if drift_metrics_df.isEmpty():
    drift_violation_count = 0
else:
    label_drifted = col("churn") > max_drift
    prediction_drifted = col("prediction") > max_drift
    drift_violation_count = all_metrics_df.filter(
        label_drifted & prediction_drifted
    ).count()

all_violations_count = performance_violation_count + drift_violation_count

print(f"Total number of joint violations: {all_violations_count}")

### Next: Trigger model retraining
If violations are detected, we should automatically:
- retrain the machine learning model
- send an alert to the owners via email

In [0]:
# Store the total violation count as a job task value under the key
# 'all_violations_count'; presumably a downstream task reads it to decide
# whether to retrain -- TODO confirm against the job definition.
dbutils.jobs.taskValues.set(key = 'all_violations_count', value = all_violations_count)