# Drift detection

In this step, we will define drift detection rules to run periodically on the inference data.

**Drift detection** is the process of identifying changes over time in the statistical properties of a model's inputs, predictions, or labels, which can degrade model performance. It is crucial for keeping models accurate and reliable in dynamic environments, because it allows timely interventions such as retraining the model or adapting it to new data distributions.

To simulate data drift, we use the [_dbldatagen_ library](https://github.com/databrickslabs/dbldatagen), a Databricks Labs Python library for generating synthetic data with Spark.

We will simulate label drift using the data generator package.
**Label drift** occurs when the distribution of the ground truth labels changes over time, which can happen due to shifts in labeling criteria or the introduction of labeling errors.
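Lakehouse Monitoring can quantify label drift with metrics such as `js_distance`, the default of the `drift_metric` widget defined below. To build intuition, here is a minimal sketch of the Jensen-Shannon distance between two label distributions. It assumes base-2 logarithms, so the distance lies in [0, 1]; the monitor's own implementation (log base, handling of empty categories) may differ, and the label proportions are hypothetical.

```python
import math


def js_distance(p: dict, q: dict) -> float:
    """Jensen-Shannon distance between two categorical distributions.

    p and q map category -> probability (each should sum to 1).
    Uses base-2 logarithms, so the result lies in [0, 1].
    """
    categories = set(p) | set(q)
    # Mixture distribution M = (P + Q) / 2
    m = {c: 0.5 * (p.get(c, 0.0) + q.get(c, 0.0)) for c in categories}

    def kl(a: dict) -> float:
        # KL(A || M); categories where A has zero mass contribute 0 by convention
        return sum(a[c] * math.log2(a[c] / m[c]) for c in categories if a.get(c, 0.0) > 0)

    divergence = 0.5 * kl(p) + 0.5 * kl(q)
    # Clamp tiny negative values caused by floating-point rounding before the square root
    return math.sqrt(max(divergence, 0.0))


# Hypothetical label distributions: a baseline window vs. a window where every label is "Yes"
baseline = {"No": 0.73, "Yes": 0.27}
drifted = {"Yes": 1.0}
print(round(js_distance(baseline, drifted), 3))
```

Identical distributions give a distance of 0 and distributions with no overlapping categories give 1, so collapsing all labels to a single class produces a large distance.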

_We will set all labels to True_

### A cluster has been created for this demo
To run this demo, just select the cluster `dbdemos-mlops-end2end-edgar_aguilerarod` from the dropdown menu ([open cluster configuration](https://dbc-07122dbb-1c85.cloud.databricks.com/#setting/clusters/0102-173414-9ev1v92w/configuration)). <br />
*Note: If the cluster was deleted after 30 days, you can re-create it with `dbdemos.create_cluster('mlops-end2end')` or re-install the demo: `dbdemos.install('mlops-end2end')`*



<img src="https://github.com/databricks-demos/dbdemos-resources/blob/main/images/product/mlops/advanced/banners/mlflow-uc-end-to-end-advanced-8.png?raw=true" width="1200">

<!-- [metadata={"description":"MLOps end2end workflow: Batch to automatically retrain model on a monthly basis.",
 "authors":["quentin.ambard@databricks.com"],
 "db_resources":{},
  "search_tags":{"vertical": "retail", "step": "Model testing", "components": ["mlflow"]},
                 "canonicalUrl": {"AWS": "", "Azure": "", "GCP": ""}}] -->

In [0]:
%pip install -qU "databricks-sdk>=0.28.0"
%pip install -qU dbldatagen
dbutils.library.restartPython()

In [0]:
dbutils.widgets.dropdown("perf_metric", "f1_score.macro", ["accuracy_score", "precision.weighted", "recall.weighted", "f1_score.macro"])
dbutils.widgets.dropdown("drift_metric", "js_distance", ["chi_squared_test.statistic", "chi_squared_test.pvalue", "tv_distance", "l_infinity_distance", "js_distance"])
dbutils.widgets.text("model_id", "*", "Model Id")

Run setup notebook & generate synthetic data

In [0]:
%run ../_resources/00-setup $reset_all_data=false $adv_mlops=true $gen_synthetic_data=true

## Refresh the monitor 

The previous step writes the synthetic data to the inference table, so we refresh the monitor to recompute the metrics.

**PS:** A refresh is only necessary if the monitored table has changed.

In [0]:
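import time
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import MonitorRefreshInfoState

w = WorkspaceClient()
inference_table_name = f"{catalog}.{db}.advanced_churn_inference_table"

# Trigger a refresh so the monitor recomputes its metrics on the new synthetic data
refresh_info = w.quality_monitors.run_refresh(table_name=inference_table_name)

# Poll every 30 seconds until the refresh leaves the PENDING/RUNNING states
while refresh_info.state in (MonitorRefreshInfoState.PENDING, MonitorRefreshInfoState.RUNNING):
    time.sleep(30)
    refresh_info = w.quality_monitors.get_refresh(table_name=inference_table_name, refresh_id=refresh_info.refresh_id)

# Stop here if the refresh did not complete successfully
assert refresh_info.state == MonitorRefreshInfoState.SUCCESS, f"Monitor refresh ended in state {refresh_info.state}"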
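The polling loop above waits indefinitely if a refresh hangs. Below is a minimal sketch of the same pattern with an upper bound on the wait. `poll_until_terminal` is a hypothetical helper, not part of the Databricks SDK, and the state strings are assumed to match the `.value` of the `MonitorRefreshInfoState` members. Elapsed time is approximated by counting sleeps rather than reading a clock.

```python
import time
from typing import Callable

# Assumed terminal states of a monitor refresh
TERMINAL_STATES = {"SUCCESS", "FAILED", "CANCELED"}


def poll_until_terminal(fetch_state: Callable[[], str],
                        interval_s: float = 30.0,
                        timeout_s: float = 3600.0,
                        sleep: Callable[[float], None] = time.sleep) -> str:
    """Call fetch_state() every interval_s seconds until it returns a terminal state.

    Raises TimeoutError if no terminal state is reached within timeout_s seconds
    (measured as the number of sleeps times interval_s).
    """
    waited = 0.0
    state = fetch_state()
    while state not in TERMINAL_STATES:
        if waited >= timeout_s:
            raise TimeoutError(f"Still in state {state!r} after {waited:.0f}s")
        sleep(interval_s)
        waited += interval_s
        state = fetch_state()
    return state


# Usage with a fake sequence of states and a no-op sleep
states = iter(["PENDING", "RUNNING", "SUCCESS"])
print(poll_until_terminal(lambda: next(states), sleep=lambda s: None))
```

In the notebook, `fetch_state` could wrap `w.quality_monitors.get_refresh(...)` and return `.state.value`; injecting `sleep` keeps the helper testable without real delays.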

In [0]:
monitor_info = w.quality_monitors.get(table_name=f"{catalog}.{db}.advanced_churn_inference_table")
drift_table_name = monitor_info.drift_metrics_table_name
profile_table_name = monitor_info.profile_metrics_table_name


## Inspect dashboard

After the monitor refresh completes, reload the monitoring dashboard to see the latest model performance metrics. Evaluated against the latest labelled data, the model has poor accuracy, weighted F1 score and recall, yet a weighted precision of 1.

This is expected: the latest labelled data contains only the `churn = Yes` class, so every `Yes` prediction is correct. Weighted precision weights each class's precision by its number of true instances (its support), and the `No` class has zero support, so it contributes nothing. The weighted precision therefore equals the precision of the `Yes` class, which is 1.

<br>

<img src="https://github.com/databricks-demos/dbdemos-resources/blob/main/images/product/mlops/advanced/08_model_kpis.png?raw=true" width="1200">

<br>
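The arithmetic can be checked on a few hypothetical predictions. This is a minimal pure-Python sketch of weighted precision (per-class precision averaged with weights equal to class support), not the monitor's implementation; counting a never-predicted class as precision 0 follows a common convention.

```python
from collections import Counter


def weighted_precision(y_true, y_pred):
    """Average of per-class precision, weighted by each class's support in y_true."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    # Classes absent from y_true have zero weight, so only classes with support are visited
    for cls, n_true in support.items():
        # True labels of the samples the model predicted as `cls`
        hits = [t for t, p in zip(y_true, y_pred) if p == cls]
        if hits:  # a class that is never predicted contributes a precision of 0
            score += (hits.count(cls) / len(hits)) * n_true / total
    return score


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical batch: every ground-truth label is "Yes", but the model still predicts some "No"
y_true = ["Yes"] * 4
y_pred = ["Yes", "No", "No", "Yes"]
print(weighted_precision(y_true, y_pred), accuracy(y_true, y_pred))  # prints 1.0 0.5
```

Half of the predictions are wrong, yet weighted precision is 1 because the only class with support, `Yes`, is never mispredicted.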

Below, we illustrate how to programmatically retrieve the drift metrics and trigger model retraining.

However, it is worth noting that the confusion matrix in the monitoring dashboard shows that the latest labelled data contains only the `Yes` label, i.e. every customer has churned. This is an unlikely scenario, and it should lead us to question whether labelling was done correctly or whether there were data quality issues upstream. Label drift with these causes does not call for model retraining.

<br>

<img src="https://github.com/databricks-demos/dbdemos-resources/blob/main/images/product/mlops/advanced/08_confusion_matrix.png?raw=true" width="1200">

<br>
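Before treating label drift as a reason to retrain, a cheap sanity check can route suspicious batches to a data-quality review instead. A minimal sketch, assuming the latest labels are available as a Python list; `label_distribution_collapsed` and its `min_minority_share` tolerance are hypothetical and not part of Lakehouse Monitoring.

```python
from collections import Counter


def label_distribution_collapsed(labels, min_minority_share=0.01):
    """Return True if one class accounts for (almost) all labels in the batch.

    min_minority_share is a hypothetical tolerance: if all classes other than the
    majority class together make up less than this share, the batch is flagged.
    """
    if not labels:
        return False
    counts = Counter(labels)
    majority_count = counts.most_common(1)[0][1]
    minority_share = 1 - majority_count / len(labels)
    return minority_share < min_minority_share


# A batch where every customer is labelled as churned gets flagged for review
print(label_distribution_collapsed(["Yes"] * 1000))
print(label_distribution_collapsed(["Yes"] * 90 + ["No"] * 10))
```

A flagged batch would go to whoever owns the labelling pipeline rather than straight into retraining.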

## Retrieve drift metrics

Query the profile and drift metrics tables that Lakehouse Monitoring generates for the monitored inference table.
Here we test whether these metrics cross thresholds defined by the business:
1. Prediction drift (Jensen–Shannon distance) > 0.2
2. Label drift (Jensen–Shannon distance) > 0.2
3. Expected loss (daily average per user) > 30
4. Performance (e.g. F1 score) < 0.4
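The four rules can be written as a small predicate over the metrics of one time window. A minimal sketch, assuming the metrics have already been collected into a plain Python dict with hypothetical keys; the Spark cells that follow apply the check directly to the monitoring tables.

```python
# Hypothetical thresholds mirroring the business rules listed above
THRESHOLDS = {
    "prediction_drift": 0.2,  # Jensen-Shannon distance, violation if greater
    "label_drift": 0.2,       # Jensen-Shannon distance, violation if greater
    "expected_loss": 30.0,    # daily average per user, violation if greater
    "performance": 0.4,       # e.g. F1 score, violation if lower
}


def violated_rules(window: dict) -> list:
    """Return the names of the rules that one monitoring window violates.

    `window` is a hypothetical dict of metric values for a single time window;
    metrics that are missing or None are not evaluated.
    """
    violations = []
    for name in ("prediction_drift", "label_drift", "expected_loss"):
        value = window.get(name)
        if value is not None and value > THRESHOLDS[name]:
            violations.append(name)
    perf = window.get("performance")
    if perf is not None and perf < THRESHOLDS["performance"]:
        violations.append("performance")
    return violations


# Hypothetical window with strong drift and poor performance but low expected loss
window = {"prediction_drift": 0.35, "label_drift": 0.72, "expected_loss": 12.0, "performance": 0.31}
print(violated_rules(window))
```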

In [0]:
metric = dbutils.widgets.get("perf_metric")
drift = dbutils.widgets.get("drift_metric")
model_id = dbutils.widgets.get("model_id")

Construct a DataFrame from the profile metrics table generated by Lakehouse Monitoring to detect performance degradation.

In [0]:
performance_metrics_df = spark.sql(f"""
SELECT
  window.start as time,
  {metric} AS performance_metric,
  expected_loss,
  Model_Version AS `Model Id`
FROM {profile_table_name}
WHERE
  window.start >= "2024-06-01"
  AND log_type = "INPUT"
  AND column_name = ":table"
  AND slice_key is null
  AND slice_value is null
  AND Model_Version = '{model_id}'
ORDER BY
  window.start
"""
)
display(performance_metrics_df)

Construct a DataFrame from the drift metrics table generated by Lakehouse Monitoring to detect drift.

In [0]:
drift_metrics_df = spark.sql(f"""
SELECT
  window.start AS time,
  column_name,
  {drift} AS drift_metric,
  Model_Version AS `Model Id`
FROM {drift_table_name}
WHERE
  column_name IN ('prediction', 'churn')
  AND window.start >= "2024-06-01"
  AND slice_key is null
  AND slice_value is null
  AND Model_Version = '{model_id}'
  AND drift_type = "CONSECUTIVE"
ORDER BY
  window.start
"""
)
display(drift_metrics_df)

In [0]:
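from pyspark.sql.functions import first

# Drift metrics can be missing (e.g. CONSECUTIVE drift needs a previous window to compare against); skip the pivot in that case
if not drift_metrics_df.isEmpty():
    # One row per (time, Model Id), with one drift column per monitored column (churn, prediction)
    unstacked_drift_metrics_df = (
        drift_metrics_df.groupBy("time", "`Model Id`")
        .pivot("column_name")
        .agg(first("drift_metric"))
        .orderBy("time")
    )
    display(unstacked_drift_metrics_df)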
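The pivot reshapes the long drift table, with one row per (time window, model, monitored column), into a wide table with a `churn` column and a `prediction` column per window, which is the shape the threshold checks need. A pure-Python sketch of the same reshape, for illustration only; `unstack` is a hypothetical helper and the sample rows are made up.

```python
from collections import defaultdict


def unstack(rows):
    """Pivot long rows (window_start, model_id, column_name, drift_metric) into one dict per (window_start, model_id)."""
    wide = defaultdict(dict)
    for window_start, model_id, column_name, value in rows:
        # Like first(): keep the first value seen for each (window_start, model_id, column_name)
        wide[(window_start, model_id)].setdefault(column_name, value)
    # Sort by (window_start, model_id), mirroring orderBy("time")
    return [
        {"time": t, "Model Id": m, **metrics}
        for (t, m), metrics in sorted(wide.items())
    ]


rows = [
    ("2024-06-02", "*", "churn", 0.72),
    ("2024-06-01", "*", "churn", 0.05),
    ("2024-06-01", "*", "prediction", 0.03),
    ("2024-06-02", "*", "prediction", 0.35),
]
for row in unstack(rows):
    print(row)
```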

In [0]:
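# Combine performance and drift metrics per time window and model; the inner join keeps only windows that have both
all_metrics_df = performance_metrics_df
if not drift_metrics_df.isEmpty():
    all_metrics_df = performance_metrics_df.join(
        unstacked_drift_metrics_df, on=["time", "Model Id"], how="inner"
    )

display(all_metrics_df)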


## Count total violations and save as task value

Here we define the thresholds that qualify a time window as a violation. Each condition pair counts as one violation per window:
- Performance metric < 0.5 **and** average expected loss per customer (our custom business metric) > 30 dollars
- Label (`churn`) drift > 0.19 **and** prediction drift > 0.19

In [0]:
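# Alias abs to avoid shadowing Python's built-in abs()
from pyspark.sql.functions import col, abs as spark_abs

# Performance violation: low performance metric AND high expected loss per customer
performance_violation_count = all_metrics_df.where(
    (col("performance_metric") < 0.5) & (spark_abs(col("expected_loss")) > 30)
).count()

# Drift violation: label (churn) drift AND prediction drift both above the threshold
drift_violation_count = 0
if not drift_metrics_df.isEmpty():
    drift_violation_count = all_metrics_df.where(
        (col("churn") > 0.19) & (col("prediction") > 0.19)
    ).count()

all_violations_count = drift_violation_count + performance_violation_count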

In [0]:
print(f"Total number of joint violations: {all_violations_count}")

## Next: Trigger model retraining

Once violations are detected, we should automate follow-up actions, such as:
- Retraining the machine learning model
- Sending an alert to the owners via Slack or email

One way of performing this in Databricks is to add branching logic to your job with [the If/else condition task](https://docs.databricks.com/en/jobs/conditional-tasks.html#add-branching-logic-to-your-job-with-the-ifelse-condition-task). 


<img src="https://github.com/databricks-demos/dbdemos-resources/raw/main/images/product/mlops/advanced/08_view_retraining_workflow.png?raw=true" width="1200">

In order to do that, we should save the number of violations in a [task value](https://docs.databricks.com/en/jobs/share-task-context.html) to be consumed in the If/else condition. 

In our workflow, the retraining branch is a Run Job task that triggers the model training job.
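For illustration, the branch could be declared in a Jobs API task list roughly as follows. This is a hedged sketch: the task keys and job ID are hypothetical, the field names follow the Jobs API 2.1 `condition_task`, `run_job_task`, and `depends_on[].outcome` settings as we understand them (check the Jobs API reference for your workspace), and the `{{tasks.<task_key>.values.<key>}}` reference reads the task value set in the cell below.

```python
import json

# Hypothetical task key of this drift detection notebook and ID of the training job
DRIFT_TASK_KEY = "drift_detection"
TRAIN_JOB_ID = 123

tasks = [
    {
        # If/else condition task that reads the task value set by the drift detection task
        "task_key": "check_violations",
        "depends_on": [{"task_key": DRIFT_TASK_KEY}],
        "condition_task": {
            "op": "GREATER_THAN",
            "left": "{{tasks." + DRIFT_TASK_KEY + ".values.all_violations_count}}",
            "right": "0",
        },
    },
    {
        # Run Job task that only runs on the "true" branch of the condition
        "task_key": "retrain_model",
        "depends_on": [{"task_key": "check_violations", "outcome": "true"}],
        "run_job_task": {"job_id": TRAIN_JOB_ID},
    },
]

print(json.dumps(tasks, indent=2))
```

String concatenation builds the `{{...}}` reference, which avoids escaping double braces inside an f-string.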

In [0]:
# Expose the violation count to downstream tasks, such as the If/else condition task
dbutils.jobs.taskValues.set(key="all_violations_count", value=all_violations_count)