In [0]:
"""
Model Validation Notebook

This notebook performs model validation using the MLflow model validation API after training and registering a model in the model registry, prior to deployment to the "champion" alias. It is designed to run as part of a continuous deployment (CD) pipeline, triggered by an automated model training job, followed by validation and deployment, as defined in `mlops_dbx/resources/model-workflow-resource.yml`.

Parameters:
    - experiment_name: Name of the MLflow experiment.
    - run_mode: Mode for model validation. Options:
        - disabled: Skip validation and allow deployment.
        - dry_run: Run validation, ignore failures, and allow deployment.
        - enabled: Run validation, block deployment if validation fails.
    - enable_baseline_comparison: Whether to load the current "champion" model as baseline for comparison.
    - validation_input: Input table for validation data.
    - model_type: Type of model ("regressor" or "classifier").
    - targets: Name of the column containing evaluation labels.
    - custom_metrics_loader_function: Function name to load custom metrics.
    - validation_thresholds_loader_function: Function name to load validation thresholds.
    - evaluator_config_loader_function: Function name to load evaluator config.
    - model_name: Full model name in registry.
    - model_version: Registry alias of the candidate model version (e.g. "staging"); it is resolved by alias, not by version number.

Workflow:
    1. Reads parameters from widgets.
    2. Sets up MLflow experiment and model URIs.
    3. Loads validation data and configuration.
    4. Loads custom metrics, thresholds, and evaluator config.
    5. Runs model evaluation using MLflow, optionally compares with baseline.
    6. Logs validation results and metrics as artifacts.
    7. Validates evaluation results against thresholds.
    8. Appends the validation status to the model version description and, if validation passes, assigns the "challenger" alias and removes the candidate ("staging") alias.
    9. Handles validation failures according to run_mode.

Notes:
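    - Baseline comparison is currently always disabled: the enable_baseline_comparison widget is overridden to false because comparison is not supported for models registered with Feature Store.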
    - Uses Databricks Feature Store for batch scoring.
    - Artifacts and metrics are logged to MLflow for traceability.
    - Model validation status is appended to model version description in registry.

References:
    - MLflow evaluate API: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.evaluate
    - Model Validation documentation: https://mlflow.org/docs/latest/models.html#model-validation
    - Feature Store limitation: https://github.com/databricks/mlops-stacks/issues/70
"""

In [0]:
# Declare the notebook's input widgets as (name, default value, display label),
# creating them in a fixed order: experiment name, the two dropdowns, then the
# remaining free-text parameters.
dbutils.widgets.text(
    "experiment_name", "/Workspace/Shared/mlops_talk/telco_churn_model", "Experiment Name"
)

for _widget_name, _default, _choices, _label in (
    ("run_mode", "enabled", ["disabled", "dry_run", "enabled"], "Run Mode"),
    ("enable_baseline_comparison", "false", ["true", "false"], "Enable Baseline Comparison"),
):
    dbutils.widgets.dropdown(_widget_name, _default, _choices, _label)

for _widget_name, _default, _label in (
    ("validation_input", "mlops_dbx_talk_dev.churn.telco_churn_validation", "Validation Input"),
    ("model_type", "classifier", "Model Type"),
    ("targets", "churn", "Targets"),
    ("custom_metrics_loader_function", "custom_metrics", "Custom Metrics Loader Function"),
    ("validation_thresholds_loader_function", "validation_thresholds", "Validation Thresholds Loader Function"),
    ("evaluator_config_loader_function", "evaluator_config", "Evaluator Config Loader Function"),
    ("model_name", "mlops_dbx_talk_dev.churn.telco_churn_model", "Full (Three-Level) Model Name"),
    ("model_version", "staging", "Candidate Model Alias"),
):
    dbutils.widgets.text(_widget_name, _default, _label)

# Drop the loop variables so they do not linger in the notebook namespace.
del _widget_name, _default, _choices, _label

In [0]:
# Decide how validation failures are treated:
#   disabled -> exit immediately without validating (deployment not blocked)
#   dry_run  -> validate, but failures do not block deployment
#   enabled  -> validate, and failures block deployment
VALID_RUN_MODES = ("disabled", "dry_run", "enabled")
run_mode = dbutils.widgets.get("run_mode").strip().lower()
assert run_mode in VALID_RUN_MODES, "run_mode must be one of {0}, got {1!r}".format(
    VALID_RUN_MODES, run_mode
)

if run_mode == "disabled":
    print(
        "Model validation is in DISABLED mode. Exit model validation without blocking model deployment."
    )
    dbutils.notebook.exit(0)
dry_run = run_mode == "dry_run"

if dry_run:
    print(
        "Model validation is in DRY_RUN mode. Validation threshold validation failures will not block model deployment."
    )
else:
    print(
        "Model validation is in ENABLED mode. Validation threshold validation failures will block model deployment."
    )

In [0]:
import importlib
import mlflow
import os
import tempfile
import traceback

from mlflow.tracking.client import MlflowClient
from mlflow.models.evaluation.base import EvaluationResult

# Registry client and fluent API both target Unity Catalog.
client = MlflowClient(registry_uri="databricks-uc")
mlflow.set_registry_uri("databricks-uc")

# The validation run is logged under this experiment.
experiment_name = dbutils.widgets.get("experiment_name")
mlflow.set_experiment(experiment_name)

# The candidate is identified by a registry *alias* (e.g. "staging"); the
# helpers below resolve it with `get_model_version_by_alias`.
model_name = dbutils.widgets.get("model_name")
model_version = dbutils.widgets.get("model_version")

# Validate the inputs before building URIs from them. (A check on model_uri
# itself can never fail, because the "models:/" prefix makes it non-empty.)
assert model_name != "", "model_name notebook parameter must be specified"
assert model_version != "", "model_version notebook parameter must be specified"

model_uri = "models:/" + model_name + "@" + model_version
# Currently deployed model; only scored when baseline comparison is enabled.
baseline_model_uri = "models:/" + model_name + "@champion"

evaluators = "default"

In [0]:
# take input
# Read and validate the requested baseline-comparison flag. Baseline comparison
# is currently forced off because it is not supported for models registered
# with Feature Store (https://github.com/databricks/mlops-stacks/issues/70).
enable_baseline_comparison = dbutils.widgets.get("enable_baseline_comparison").strip().lower()
assert enable_baseline_comparison in ("true", "false"), (
    "enable_baseline_comparison must be 'true' or 'false', got {0!r}".format(
        enable_baseline_comparison
    )
)

if enable_baseline_comparison == "true":
    # Tell the user why their request is being overridden.
    print(
        "Currently baseline model comparison is not supported for models registered with feature store. Please refer to "
        "issue https://github.com/databricks/mlops-stacks/issues/70 for more details."
    )
enable_baseline_comparison = False

# Load the labelled validation table and the evaluation settings; each
# parameter must be non-empty.
validation_input = dbutils.widgets.get("validation_input")
assert validation_input
data = spark.table(validation_input)

model_type, targets = (dbutils.widgets.get(param) for param in ("model_type", "targets"))
assert model_type and targets

# Resolve the custom-metrics, threshold and evaluator-config loaders by the
# function names given in the widgets, from the project's validation module.
custom_metrics_loader_function_name = dbutils.widgets.get("custom_metrics_loader_function")
validation_thresholds_loader_function_name = dbutils.widgets.get("validation_thresholds_loader_function")
evaluator_config_loader_function_name = dbutils.widgets.get("evaluator_config_loader_function")
assert custom_metrics_loader_function_name, "custom_metrics_loader_function must be specified"
assert validation_thresholds_loader_function_name, "validation_thresholds_loader_function must be specified"
assert evaluator_config_loader_function_name, "evaluator_config_loader_function must be specified"

import sys

# Make `validation.validation` importable from two directories up. Guard the
# append so re-running this cell does not keep adding duplicate entries.
if "../.." not in sys.path:
    sys.path.append("../..")

validation_module = importlib.import_module("validation.validation")

# Honour the widget-provided names (previously they were read but ignored).
# Loaders are called in the same order as before: thresholds, metrics, config.
validation_thresholds = getattr(validation_module, validation_thresholds_loader_function_name)()
custom_metrics = getattr(validation_module, custom_metrics_loader_function_name)()
evaluator_config = getattr(validation_module, evaluator_config_loader_function_name)()

In [0]:
# helper methods
def get_run_link(run_info):
    """Build a Markdown link labelled "Run" pointing at an MLflow run.

    Args:
        run_info: object exposing ``experiment_id`` and ``run_id``
            (e.g. ``run.info`` of an MLflow run).

    Returns:
        A string of the form ``[Run](#mlflow/experiments/<exp>/runs/<run>)``
        (a relative, fragment-style link).
    """
    target = "#mlflow/experiments/%s/runs/%s" % (run_info.experiment_id, run_info.run_id)
    return "[Run](" + target + ")"


def get_training_run(model_name, model_version):
    """Return the MLflow run that produced the model version behind an alias.

    Args:
        model_name: registered model name.
        model_version: registry alias (e.g. "staging") pointing at the version.

    Returns:
        The ``mlflow`` Run whose id is the version's ``run_id``.
    """
    source_run_id = client.get_model_version_by_alias(model_name, model_version).run_id
    return mlflow.get_run(run_id=source_run_id)


def generate_run_name(training_run):
    """Derive the validation run name from the training run.

    Returns ``None`` when there is no training run (so MLflow picks a name),
    otherwise the training run's name with ``-validation`` appended.
    """
    if not training_run:
        return None
    return training_run.info.run_name + "-validation"


def generate_description(training_run):
    """Build the validation run description linking back to the training run.

    Returns ``None`` when there is no training run, otherwise
    ``"Model Training Details: <markdown link>\\n"``.
    """
    if not training_run:
        return None
    training_link = get_run_link(training_run.info)
    return "Model Training Details: {0}\n".format(training_link)


def log_to_model_description(run, success):
    """Append the validation outcome to the candidate model version's description.

    Resolves the version behind the ``model_version`` alias of ``model_name``
    (notebook-level parameters), appends a status line plus a link to the
    validation ``run``, and writes the description back to the registry.

    Args:
        run: the MLflow validation run (its ``info`` is used for the link).
        success: True records "SUCCESS", False records "FAILURE".
    """
    # Resolve the alias once, so the description we extend and the version we
    # update are guaranteed to belong to the same model version.
    candidate_version = client.get_model_version_by_alias(model_name, model_version)
    # A version without a description may report None rather than "".
    description = candidate_version.description or ""
    if description:
        description += "\n\n---\n\n"
    description += "Model Validation Status: {0}\nValidation Details: {1}".format(
        "SUCCESS" if success else "FAILURE", get_run_link(run.info)
    )
    client.update_model_version(
        name=model_name, version=candidate_version.version, description=description
    )


In [0]:


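# Temporary workaround: a Feature Store model cannot be scored as a plain pyfunc model.
# mlflow.evaluate can take a callable instead of a model URI, but that does not work
# for the baseline model, which requires a model_uri (baseline comparison is forced off).
# Instead, predictions are precomputed with Feature Store batch scoring and passed to
# mlflow.evaluate as a "prediction" column.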

from databricks.feature_store import FeatureStoreClient

def get_fs_model(df, model_uri):
    """Score ``df`` with the Feature Store model at ``model_uri``.

    Despite its name, this returns predictions, not a model. The pandas input
    is converted to a Spark DataFrame, scored with
    ``FeatureStoreClient.score_batch``, and collected back to pandas. Callers
    read the predictions from the ``prediction`` column.

    Args:
        df: pandas DataFrame of validation rows.
        model_uri: ``models:/<name>@<alias>`` URI of the model to score with.

    Returns:
        pandas DataFrame with the scored rows.
    """
    scored = FeatureStoreClient().score_batch(model_uri, spark.createDataFrame(df))
    return scored.toPandas()


def _write_thresholds_file(path, thresholds):
    """Write one ``<metric>  <threshold>`` line per validation threshold (may be empty)."""
    with open(path, "w") as f:
        for metric_name, threshold in (thresholds or {}).items():
            f.write("{0:30}  {1}\n".format(metric_name, str(threshold)))


def _write_metrics_file(path, candidate_result, baseline_result):
    """Write a candidate-vs-baseline metrics table and log baseline metrics.

    Baseline values come from ``baseline_result`` (a separate evaluation);
    "N/A" is written when there is no baseline or it lacks the metric.
    """
    with open(path, "w") as f:
        f.write("{0:30}  {1:30}  {2}\n".format("metric_name", "candidate", "baseline"))
        for metric, candidate_value in candidate_result.metrics.items():
            baseline_value = "N/A"
            if baseline_result is not None and metric in baseline_result.metrics:
                mlflow.log_metric("baseline_" + metric, baseline_result.metrics[metric])
                baseline_value = str(baseline_result.metrics[metric])
            f.write("{0:30}  {1:30}  {2}\n".format(metric, str(candidate_value), baseline_value))


training_run = get_training_run(model_name, model_version)

# Evaluate the candidate, validate it against thresholds, then promote it.
with mlflow.start_run(
    run_name=generate_run_name(training_run),
    description=generate_description(training_run),
) as run, tempfile.TemporaryDirectory() as tmp_dir:
    validation_thresholds_file = os.path.join(tmp_dir, "validation_thresholds.txt")
    _write_thresholds_file(validation_thresholds_file, validation_thresholds)
    mlflow.log_artifact(validation_thresholds_file)

    try:
        # Collect the validation data once; both evaluations reuse it.
        validation_pdf = data.toPandas()
        eval_result = mlflow.evaluate(
            data=get_fs_model(validation_pdf, model_uri),
            targets=targets,
            predictions="prediction",
            model_type=model_type,
            evaluators=evaluators,
            extra_metrics=custom_metrics,
            evaluator_config=evaluator_config,
        )
        if enable_baseline_comparison:
            # mlflow.evaluate logs its metrics to the active run. Evaluate the
            # baseline in a nested run so it does not overwrite the candidate's
            # metrics on the validation run.
            with mlflow.start_run(run_name="baseline", nested=True):
                baseline_eval_result = mlflow.evaluate(
                    data=get_fs_model(validation_pdf, baseline_model_uri),
                    targets=targets,
                    predictions="prediction",
                    model_type=model_type,
                    evaluators=evaluators,
                    extra_metrics=custom_metrics,
                    evaluator_config=evaluator_config,
                )
        else:
            baseline_eval_result = None

        metrics_file = os.path.join(tmp_dir, "metrics.txt")
        _write_metrics_file(metrics_file, eval_result, baseline_eval_result)
        mlflow.log_artifact(metrics_file)

        mlflow.validate_evaluation_results(validation_thresholds, eval_result, baseline_eval_result)
    except Exception as err:
        # Record the failure on the model version and as a run artifact.
        log_to_model_description(run, False)
        error_file = os.path.join(tmp_dir, "error.txt")
        with open(error_file, "w") as f:
            f.write("Validation failed : " + str(err) + "\n")
            f.write(traceback.format_exc())
        mlflow.log_artifact(error_file)
        # In ENABLED mode the failure must block deployment; re-raise as-is so
        # the original exception type and traceback are preserved.
        if not dry_run:
            raise
        print("Model validation failed in DRY_RUN. It will not block model deployment.")
    else:
        # Promotion runs outside the try, so registry errors here are reported
        # as-is instead of being recorded as a validation failure.
        log_to_model_description(run, True)
        version = client.get_model_version_by_alias(model_name, model_version).version

        # Assign "challenger" alias to indicate model version has passed validation checks
        print("Validation checks passed. Assigning 'challenger' alias to model version.")
        client.set_registered_model_alias(model_name, "challenger", version)
        # Remove the candidate alias (by default "staging") now that the version
        # is tracked as the challenger, unless the candidate alias *is* "challenger".
        if model_version != "challenger":
            client.delete_registered_model_alias(name=model_name, alias=model_version)