In [0]:
##################################################################################
# Model Validation Notebook
##
# This notebook uses the mlflow.evaluate API to validate a newly trained model
# version registered in the model registry, before it is promoted to the "champion" alias.
#
# Parameters:
#
# * env                         - Environment (dev, staging, prod)
# * run_mode                    - disabled/dry_run/enabled
# * enable_baseline_comparison  - Compare against champion model
# * catalog/schema/table        - Validation data location
# * forecast_horizon            - Forecast horizon
# * model_name                  - Three-level UC model name
# * model_version               - Model version to validate
# * experiment_name             - MLflow experiment
##################################################################################

# MAGIC %load_ext autoreload
# MAGIC %autoreload 2

In [None]:
# DBTITLE 1, Install dependencies
# MAGIC %pip install prophet databricks-sdk mlflow pandas
# Restart the Python process after the %pip install above (presumably so the
# freshly installed packages are the ones imported by the following cells).
# NOTE(review): the packages are installed without version pins — consider
# pinning them so validation runs are reproducible.
dbutils.library.restartPython()

In [0]:
# DBTITLE 1, Notebook arguments
# Declare every notebook widget (parameter) with its default value and label.
# Widgets are created in the same order as before: the experiment name first,
# then the two dropdowns, then the remaining free-text parameters.
dbutils.widgets.text("experiment_name", "/dev-prophet-forecast-experiment", "Experiment Name")

# (name, default, allowed choices, label)
for _name, _default, _choices, _label in [
    ("run_mode", "dry_run", ["disabled", "dry_run", "enabled"], "Run Mode"),
    ("enable_baseline_comparison", "false", ["true", "false"], "Enable Baseline Comparison"),
]:
    dbutils.widgets.dropdown(_name, _default, _choices, _label)

# (name, default, label)
for _name, _default, _label in [
    ("catalog", "johannes_oehler", "Data Catalog"),
    ("schema", "vectorlab", "Data Schema"),
    ("table", "forecast_data", "Data Table"),
    ("forecast_horizon", "10", "Forecast Horizon"),
    ("model_name", "johannes_oehler.vectorlab.prophet_forecast", "Model Name"),
    ("model_version", "", "Model Version"),
]:
    dbutils.widgets.text(_name, _default, label=_label)

In [None]:
# DBTITLE 1, Check run mode
# Resolve the run mode for this validation run:
#   disabled - leave the notebook immediately without validating
#   dry_run  - validate, but failures are only reported
#   enabled  - validate, and failures are re-raised to the caller
VALID_RUN_MODES = ("disabled", "dry_run", "enabled")

run_mode = dbutils.widgets.get("run_mode").lower()
# Fail with an explicit message (like the other parameter checks in this
# notebook) instead of a bare AssertionError.
assert run_mode in VALID_RUN_MODES, (
    f"run_mode notebook parameter must be one of {VALID_RUN_MODES}, got '{run_mode}'"
)

if run_mode == "disabled":
    print("Model validation is in DISABLED mode. Exit model validation without blocking model deployment.")
    dbutils.notebook.exit(0)

# Read by the evaluation cell to decide whether a failure is re-raised.
dry_run = run_mode == "dry_run"

if dry_run:
    print("Model validation is in DRY_RUN mode. Validation threshold failures will not block deployment.")
else:
    print("Model validation is in ENABLED mode. Validation threshold failures will block deployment.")


In [None]:
# DBTITLE 1, Setup MLflow and get model info
import mlflow
from mlflow.tracking.client import MlflowClient

# Point both the explicit client and the fluent mlflow API at the
# "databricks-uc" model registry.
client = MlflowClient(registry_uri="databricks-uc")
mlflow.set_registry_uri('databricks-uc')

# Set experiment
experiment_name = dbutils.widgets.get("experiment_name")
mlflow.set_experiment(experiment_name)

# Get model info from training task or widgets
model_uri = dbutils.jobs.taskValues.get("Train", "model_uri", debugValue="")
model_name = dbutils.jobs.taskValues.get("Train", "model_name", debugValue="")
model_version = dbutils.jobs.taskValues.get("Train", "model_version", debugValue="")

# Fall back to widgets if not running in a workflow
use_widget_values = model_uri == ""
if use_widget_values:
    model_name = dbutils.widgets.get("model_name")
    model_version = dbutils.widgets.get("model_version")

# Validate the inputs BEFORE they are used to build any URI. Previously the
# checks ran afterwards, so a URI such as "models:/<name>/" was assembled from
# an empty version first and the model_uri check could never fail on this path.
assert model_name != "", "model_name notebook parameter must be specified"
assert model_version != "", "model_version notebook parameter must be specified"

if use_widget_values:
    model_uri = f"models:/{model_name}/{model_version}"

assert model_uri != "", "model_uri notebook parameter must be specified"

# Alias-based URI of the model version currently tagged "champion"; only passed
# to the evaluation when baseline comparison is enabled.
baseline_model_uri = f"models:/{model_name}@champion"

print(f"Validating model: {model_uri}")
print(f"Model name: {model_name}")
print(f"Model version: {model_version}")


In [None]:
# DBTITLE 1, Get validation parameters
# Widget values arrive as strings; convert and validate them here so that bad
# input fails with a clear message instead of an obscure error further down.
enable_baseline_comparison = dbutils.widgets.get("enable_baseline_comparison")
assert enable_baseline_comparison in ("true", "false"), (
    "enable_baseline_comparison notebook parameter must be 'true' or 'false', "
    f"got '{enable_baseline_comparison}'"
)
enable_baseline_comparison = enable_baseline_comparison == "true"

# Get data parameters
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
table = dbutils.widgets.get("table")
forecast_horizon = int(dbutils.widgets.get("forecast_horizon"))
# forecast_horizon is used as the row limit when selecting validation data, so
# zero or a negative value cannot produce a usable dataset.
assert forecast_horizon > 0, (
    f"forecast_horizon notebook parameter must be a positive integer, got {forecast_horizon}"
)

print(f"Validation data source: {catalog}.{schema}.{table}")
print(f"Forecast horizon: {forecast_horizon}")
print(f"Baseline comparison enabled: {enable_baseline_comparison}")


In [None]:
# DBTITLE 1, Load and prepare validation data
import pandas as pd
from pyspark.sql.functions import col, lit, to_date

# Load data: total sales per (date, store) from the configured table.
# NOTE(review): the table identifier is interpolated from notebook parameters;
# make sure only trusted callers can set catalog/schema/table.
query = f"SELECT date, store, SUM(sales) as sales FROM {catalog}.{schema}.{table} GROUP BY date, store ORDER BY date desc"
df = spark.sql(query)

# Filter to single store for simplicity
df = df.filter(df.store == 1)

# Get test data (most recent forecast_horizon days)
test_df = df.orderBy(df.date.desc()).limit(forecast_horizon)

# Clean data - remove missing values
cleaned_df = test_df.na.drop(subset=["sales"])

# Remove outliers: the upper fence follows the IQR rule (Q3 + 1.5 * IQR); the
# lower bound is fixed at 0, so non-positive sales are dropped as well.
quartiles = cleaned_df.approxQuantile("sales", [0.25, 0.75], 0.05)
if len(quartiles) < 2:
    # No usable rows (empty selection or only null sales): fail with a clear
    # message instead of an IndexError on quartiles[1].
    raise ValueError(
        f"No rows with non-null sales found for store 1 in {catalog}.{schema}.{table}; "
        "cannot compute outlier bounds for the validation data."
    )
IQR = quartiles[1] - quartiles[0]
lower_bound = 0
upper_bound = quartiles[1] + 1.5 * IQR

no_outliers_df = cleaned_df.filter(
    (col("sales") > lit(lower_bound))
    & (col("sales") <= lit(upper_bound))
)

# Prepare data in format expected by Prophet model (ds, y columns)
validation_df = no_outliers_df.select(
    to_date(col("date")).alias("ds"),
    col("sales").alias("y").cast("double")
).orderBy(col("ds").asc())

# Convert to Pandas for mlflow.evaluate
validation_data = validation_df.toPandas()
validation_data["ds"] = pd.to_datetime(validation_data["ds"])

print(f"Validation dataset size: {len(validation_data)} records")
validation_data.head()


In [None]:
# DBTITLE 1, Define validation thresholds
from mlflow.models import MetricThreshold

# Upper limits the candidate model's error metrics are checked against.
# Every metric here is an error measure, so lower is better.
_max_allowed_error = {
    "mean_squared_error": 1000,  # MSE should be <= 1000
    "mean_absolute_error": 25,  # MAE should be <= 25
    "root_mean_squared_error": 30,  # RMSE should be <= 30
}

# One MetricThreshold per metric, in the same order as the limits above.
validation_thresholds = {
    name: MetricThreshold(threshold=limit, greater_is_better=False)
    for name, limit in _max_allowed_error.items()
}

# No custom metrics are defined; the empty list is passed to the evaluation as-is.
custom_metrics = []

# No evaluator-specific settings; the empty dict is passed to the evaluation as-is.
evaluator_config = {}

print("Validation thresholds configured:")
for metric_name, threshold in validation_thresholds.items():
    print(f"  {metric_name}: <= {threshold.threshold}")


In [None]:
# DBTITLE 1, Helper functions
def get_run_link(run_info) -> str:
    """Return a Markdown link labelled "Run" for the given run.

    Args:
        run_info: Object exposing ``experiment_id`` and ``run_id`` attributes.

    Returns:
        A Markdown link whose target is the relative fragment
        ``#mlflow/experiments/<experiment_id>/runs/<run_id>``.
    """
    link_template = "[Run](#mlflow/experiments/{0}/runs/{1})"
    experiment_id = run_info.experiment_id
    run_id = run_info.run_id
    return link_template.format(experiment_id, run_id)


def get_training_run(model_name, model_version):
    """Fetch the MLflow run recorded on a registered model version.

    Args:
        model_name: Name of the registered model.
        model_version: Version of the registered model to look up.

    Returns:
        The run whose id is stored in the model version's ``run_id``.
    """
    registered_version = client.get_model_version(model_name, model_version)
    source_run_id = registered_version.run_id
    return mlflow.get_run(run_id=source_run_id)


def generate_run_name(training_run):
    """Derive the validation run name from the training run.

    Args:
        training_run: The training run, or a falsy value when there is none.

    Returns:
        The training run's name with "-validation" appended, or ``None`` when
        no training run is given.
    """
    if not training_run:
        return None
    base_name = training_run.info.run_name
    return base_name + "-validation"


def generate_description(training_run):
    """Build the validation run description pointing back at the training run.

    Args:
        training_run: The training run, or a falsy value when there is none.

    Returns:
        A one-line description containing a link to the training run, or
        ``None`` when no training run is given.
    """
    if not training_run:
        return None
    training_link = get_run_link(training_run.info)
    return "Model Training Details: {0}\n".format(training_link)


def log_to_model_description(run, success):
    """Append the validation outcome to the model version's description.

    Reads the current description of the model version identified by the
    notebook-level ``model_name`` / ``model_version``, appends a status entry
    linking to ``run`` (separated from existing text by a "---" rule), and
    writes the result back.

    Args:
        run: The validation run; its ``info`` is used to build the link.
        success: Truthy to record SUCCESS, falsy to record FAILURE.
    """
    validation_link = get_run_link(run.info)
    existing = client.get_model_version(model_name, model_version).description
    outcome = "SUCCESS" if success else "FAILURE"

    # Keep any existing description and set the new entry apart with a rule.
    prefix = existing + "\n\n---\n\n" if existing else ""
    entry = "Model Validation Status: {0}\nValidation Details: {1}".format(
        outcome, validation_link
    )

    client.update_model_version(
        name=model_name, version=model_version, description=prefix + entry
    )


In [None]:
# DBTITLE 1, Run model validation with mlflow.evaluate
import os
import tempfile
import traceback

# Look up the run that produced the model version so the validation run can be
# named after it and link back to it.
training_run = get_training_run(model_name, model_version)

# Evaluate inside a dedicated MLflow run. The temporary directory holds the text
# artifacts until they have been logged and is removed when the block exits.
with mlflow.start_run(
    run_name=generate_run_name(training_run),
    description=generate_description(training_run),
) as run, tempfile.TemporaryDirectory() as tmp_dir:

    # Log validation thresholds
    validation_thresholds_file = os.path.join(tmp_dir, "validation_thresholds.txt")
    with open(validation_thresholds_file, "w") as f:
        if validation_thresholds:
            for metric_name in validation_thresholds:
                f.write(
                    "{0:30}  {1}\n".format(
                        metric_name, str(validation_thresholds[metric_name])
                    )
                )
    mlflow.log_artifact(validation_thresholds_file)

    try:
        # Run mlflow.evaluate; the baseline model is only supplied when
        # baseline comparison is enabled.
        eval_result = mlflow.evaluate(
            model=model_uri,
            data=validation_data,
            targets="y",
            model_type="regressor",
            evaluators="default",
            validation_thresholds=validation_thresholds,
            custom_metrics=custom_metrics,
            baseline_model=None
            if not enable_baseline_comparison
            else baseline_model_uri,
            evaluator_config=evaluator_config,
        )

        # Log metrics comparison (candidate vs. baseline) as a text table and
        # mirror each available baseline value as a "baseline_<metric>" metric.
        metrics_file = os.path.join(tmp_dir, "metrics.txt")
        with open(metrics_file, "w") as f:
            f.write(
                "{0:30}  {1:30}  {2}\n".format("metric_name", "candidate", "baseline")
            )
            for metric in eval_result.metrics:
                candidate_metric_value = str(eval_result.metrics[metric])
                baseline_metric_value = "N/A"
                if metric in eval_result.baseline_model_metrics:
                    mlflow.log_metric(
                        "baseline_" + metric, eval_result.baseline_model_metrics[metric]
                    )
                    baseline_metric_value = str(
                        eval_result.baseline_model_metrics[metric]
                    )
                f.write(
                    "{0:30}  {1:30}  {2}\n".format(
                        metric, candidate_metric_value, baseline_metric_value
                    )
                )
        mlflow.log_artifact(metrics_file)
        log_to_model_description(run, True)

        # Assign "challenger" alias to indicate model version has passed validation checks
        print("Validation checks passed. Assigning 'challenger' alias to model version.")
        client.set_registered_model_alias(model_name, "challenger", model_version)

        print("\n=== Validation Results ===")
        print(f"Model: {model_uri}")
        print("Status: PASSED")
        print("\nKey Metrics:")
        for metric_name in ["mean_squared_error", "mean_absolute_error", "root_mean_squared_error"]:
            if metric_name in eval_result.metrics:
                print(f"  {metric_name}: {eval_result.metrics[metric_name]:.4f}")

    except Exception as err:
        # Capture the traceback of the validation error first, before any
        # further work in this handler.
        error_details = traceback.format_exc()

        # Recording the failure on the model version is best-effort: if it
        # fails too, it must not hide the validation error, skip the error
        # artifact, or bypass the DRY_RUN handling below.
        try:
            log_to_model_description(run, False)
        except Exception as description_err:
            print(
                f"Warning: could not record validation status on the model version: {description_err}"
            )

        error_file = os.path.join(tmp_dir, "error.txt")
        with open(error_file, "w") as f:
            f.write("Validation failed : " + str(err) + "\n")
            f.write(error_details)
        mlflow.log_artifact(error_file)

        print("\n=== Validation Results ===")
        print(f"Model: {model_uri}")
        print("Status: FAILED")
        print(f"Error: {str(err)}")

        if not dry_run:
            # Re-raise the original exception with its traceback intact.
            raise
        else:
            print(
                "\nModel validation failed in DRY_RUN mode. It will not block model deployment."
            )
