In [0]:
import mlflow
import mlflow.pyfunc
import random
import pandas as pd
from mlflow.models.signature import infer_signature
from pyspark.sql.functions import current_timestamp
from mlflow.tracking import MlflowClient
from datetime import datetime


class DummyModel(mlflow.pyfunc.PythonModel):
    """Toy pyfunc model that computes ``coefficient * x + bias`` for each input value.

    The instance also carries training metadata (``model_id``, ``accuracy``),
    but only ``coefficient`` and ``bias`` affect predictions.
    """

    def __init__(self, model_id, coefficient, bias, accuracy):
        # Identifier of the model this instance was trained for.
        self.model_id = model_id
        # Slope of the linear transformation.
        self.coefficient = coefficient
        # Intercept of the linear transformation.
        self.bias = bias
        # Dummy accuracy drawn at training time; not used by predict().
        self.accuracy = accuracy

    def predict(self, context, model_input):
        """Apply ``coefficient * x + bias`` element-wise.

        Args:
            context: MLflow ``PythonModelContext`` (unused).
            model_input: A plain iterable of numbers, a pandas Series, or a
                pandas DataFrame. For a DataFrame, the ``"x"`` column is used
                when present, otherwise the first column. The pyfunc wrapper
                is expected to pass a DataFrame here, because the signature
                logged in this notebook is column-based (``{"x": ...}``).

        Returns:
            list: One prediction per input value/row; ``[]`` for a DataFrame
            with no columns.
        """
        if isinstance(model_input, pd.DataFrame):
            # Iterating a DataFrame yields its column *labels*, not its values,
            # so the feature column must be extracted explicitly.
            if model_input.shape[1] == 0:
                return []
            if "x" in model_input.columns:
                feature = model_input["x"]
            else:
                feature = model_input.iloc[:, 0]
            model_input = feature.tolist()
        return [self.coefficient * x + self.bias for x in model_input]

def train_dummy_model_for_model_id(model_id, seed=None):
    """Train a dummy model for ``model_id``, log it to MLflow and register it.

    The model is logged in a new MLflow run (params ``model_id``,
    ``coefficient`` and ``bias``; metric ``accuracy``) and registered under the
    name ``model_id``. The new version's description is then set to ``"prod"``
    if it is the model's only version, otherwise to
    ``"challenger_<YYYYmmddHHMMSS>"``. Descriptions are used instead of
    aliases; ``generate_master_model_log`` reads them back.

    Args:
        model_id: Registered model name (used verbatim as
            ``registered_model_name``) and suffix of the artifact path.
        seed: Optional seed for the random coefficient/bias/accuracy draws.
            ``None`` (default) uses the module-level ``random`` generator as
            before, so a global ``random.seed`` still applies.

    Returns:
        DummyModel: The in-memory model instance that was logged.

    Raises:
        RuntimeError: If no versions of the registered model are visible
            right after registration.
    """
    # A private generator when seeded gives reproducible draws without
    # touching the global random state.
    rng = random if seed is None else random.Random(seed)

    # Generate dummy parameters and accuracy.
    coefficient = rng.uniform(0.5, 1.5)
    bias = rng.uniform(-1, 1)
    accuracy = rng.uniform(0.7, 0.8)

    model = DummyModel(model_id, coefficient, bias, accuracy)

    # Example input/output, used only to infer the model signature.
    example_input = pd.DataFrame({"x": [0, 1, 2, 3, 4, 5]})
    example_output = model.predict(None, example_input["x"].tolist())
    signature = infer_signature(example_input, example_output)

    artifact_path = f"dummy_model_{model_id}"
    registered_model_name = model_id

    with mlflow.start_run() as run:
        mlflow.log_param("model_id", model_id)
        mlflow.log_param("coefficient", coefficient)
        mlflow.log_param("bias", bias)
        mlflow.log_metric("accuracy", accuracy)

        model_info = mlflow.pyfunc.log_model(
            artifact_path=artifact_path,
            python_model=model,
            signature=signature,
            registered_model_name=registered_model_name
        )

        run_id = run.info.run_id
        print(f"Logged DummyModel for model_id '{model_id}' with run ID: {run_id} and accuracy: {accuracy:.3f}")

        client = MlflowClient()
        # Retrieve all versions of this registered model.
        model_versions = client.search_model_versions(f"name='{registered_model_name}'")
        if not model_versions:
            # Without this guard, max() below would fail with an opaque
            # "empty sequence" error.
            raise RuntimeError(
                f"No versions found for registered model '{registered_model_name}' "
                f"after logging run {run_id}."
            )

        # Prefer the version reported by log_model (when the installed MLflow
        # exposes it on ModelInfo) so a concurrent registration cannot be
        # mistaken for ours; otherwise fall back to the highest version number.
        new_version = getattr(model_info, "registered_model_version", None)
        if new_version is None:
            new_version = max(int(mv.version) for mv in model_versions)

        # The first-ever version becomes "prod"; later ones are timestamped challengers.
        if len(model_versions) == 1:
            description = "prod"
        else:
            description = f"challenger_{datetime.now().strftime('%Y%m%d%H%M%S')}"

        client.update_model_version(
            name=registered_model_name,
            version=str(new_version),
            description=description
        )
        print(f"Set description for model '{registered_model_name}' version {new_version} to '{description}'.")

        # Registry URI that can be used to load exactly this version.
        model_location = f"models:/{registered_model_name}/{new_version}"
        print(f"Model URI: {model_location}")

    return model

In [0]:
# Registered model names (catalog.schema.model) for which a dummy model is
# trained and registered.
model_ids = [
    "field_demos.ml_ops.DummyModel_churn_prediction",
    "field_demos.ml_ops.DummyModel_fraud_detection",
    "field_demos.ml_ops.DummyModel_customer_segmentation",
    "field_demos.ml_ops.DummyModel_demand_forecasting",
    "field_demos.ml_ops.DummyModel_product_recommendation"
]

# Point the MLflow model registry at Unity Catalog through the public API.
# Assigning over a private MLflow function silently does nothing if that
# internal name is renamed or removed.
mlflow.set_registry_uri("databricks-uc")

# Train, log and register one dummy model per model_id.
models = {model_id: train_dummy_model_for_model_id(model_id) for model_id in model_ids}

In [0]:
def _iter_all_pages(search_fn, *args, **kwargs):
    """Yield every item from a paginated MLflow client ``search_*`` call, following ``page_token``s."""
    page_token = None
    while True:
        page = search_fn(*args, page_token=page_token, **kwargs)
        yield from page
        # PagedList.token is falsy on the last page.
        page_token = page.token
        if not page_token:
            return


def _collect_version_records(client, model_prefix):
    """Return one record per version of every registered model under ``model_prefix``.

    A model with no versions gets a single placeholder record whose
    version/run_id/accuracy/description are ``None``.
    """
    # Match whole name components so "a.b" does not also match "a.b_other.model".
    prefix = model_prefix.rstrip(".") + "."
    records = []
    for registered_model in _iter_all_pages(client.search_registered_models):
        if not registered_model.name.startswith(prefix):
            continue
        versions = list(
            _iter_all_pages(client.search_model_versions, f"name='{registered_model.name}'")
        )
        if not versions:
            records.append({
                "model_name": registered_model.name,
                "version": None,
                "run_id": None,
                "accuracy": None,
                "description": None
            })
            continue
        for version in versions:
            # The accuracy metric lives on the run that produced the version.
            run = client.get_run(version.run_id)
            accuracy = run.data.metrics.get("accuracy")
            records.append({
                "model_name": registered_model.name,
                # str() so the value always fits the StringType column.
                "version": str(version.version),
                "run_id": version.run_id,
                "accuracy": float(accuracy) if accuracy is not None else None,
                "description": version.description
            })
    return records


def _build_model_report(records, accuracy_threshold):
    """Aggregate version records into one report row per model.

    - has_prod_model: some version has description "prod".
    - prod_accuracy: highest accuracy among prod versions (None if none).
    - needs_retrained: no prod version, or prod accuracy missing or below
      ``accuracy_threshold``.
    - needs_inspected: the best non-prod (challenger) version beats the best
      prod accuracy; only then are challenger_accuracy/version filled in.
    """
    aggregates = {}
    for record in records:
        agg = aggregates.setdefault(record["model_name"], {
            "number_of_versions": 0,
            "has_prod_model": False,
            "max_prod_accuracy": None,
            "best_challenger": None
        })
        if record["version"] is None:
            # Placeholder for a model without versions: keep it in the report
            # but do not count it as a version.
            continue
        agg["number_of_versions"] += 1

        accuracy = record["accuracy"]
        if record["description"] == "prod":
            agg["has_prod_model"] = True
            if accuracy is not None and (agg["max_prod_accuracy"] is None or accuracy > agg["max_prod_accuracy"]):
                agg["max_prod_accuracy"] = accuracy
        elif accuracy is not None:
            best = agg["best_challenger"]
            if best is None or accuracy > best["accuracy"]:
                agg["best_challenger"] = {"accuracy": accuracy, "version": record["version"]}

    report = []
    for model_name, agg in aggregates.items():
        has_prod = agg["has_prod_model"]
        prod_accuracy = agg["max_prod_accuracy"]
        best = agg["best_challenger"]

        needs_retrained = (not has_prod) or prod_accuracy is None or prod_accuracy < accuracy_threshold
        needs_inspected = (
            has_prod
            and best is not None
            and prod_accuracy is not None
            and best["accuracy"] > prod_accuracy
        )

        report.append({
            "model_name": model_name,
            "number_of_versions": agg["number_of_versions"],
            "has_prod_model": has_prod,
            "needs_retrained": needs_retrained,
            "needs_inspected": needs_inspected,
            "prod_accuracy": prod_accuracy,
            "challenger_accuracy": best["accuracy"] if needs_inspected else None,
            "challenger_version": best["version"] if needs_inspected else None
        })
    return report


def generate_master_model_log(model_prefix="field_demos.ml_ops",
                              output_schema="field_demos.ml_ops",
                              accuracy_threshold=0.75):
    """Write the version-level master log and the model-level report as Delta tables.

    Both tables are overwritten on each call:
    ``<output_schema>.master_model_log`` (one row per model version) and
    ``<output_schema>.master_model_report`` (one row per model, with retrain
    and inspect flags). Uses the notebook-provided global ``spark`` session.

    Args:
        model_prefix: Only registered models whose name starts with
            ``"<model_prefix>."`` are included.
        output_schema: ``catalog.schema`` the two tables are written to.
        accuracy_threshold: A prod accuracy below this value marks the model
            as ``needs_retrained``.
    """
    from mlflow.tracking import MlflowClient
    import mlflow
    from pyspark.sql.types import StructType, StructField, StringType, DoubleType, BooleanType, IntegerType

    # Use the public API to target the Unity Catalog registry.
    mlflow.set_registry_uri("databricks-uc")
    client = MlflowClient()

    results = _collect_version_records(client, model_prefix)

    master_table = f"{output_schema}.master_model_log"
    master_schema = StructType([
        StructField("model_name", StringType(), True),
        StructField("version", StringType(), True),
        StructField("run_id", StringType(), True),
        StructField("accuracy", DoubleType(), True),
        StructField("description", StringType(), True)
    ])
    master_df = spark.createDataFrame(results, schema=master_schema)
    master_df.write.format("delta") \
        .mode("overwrite") \
        .saveAsTable(master_table)
    print(f"Master model log successfully written to Delta table '{master_table}'.")

    report_results = _build_model_report(results, accuracy_threshold)

    report_table = f"{output_schema}.master_model_report"
    report_schema = StructType([
        StructField("model_name", StringType(), True),
        StructField("number_of_versions", IntegerType(), True),
        StructField("has_prod_model", BooleanType(), True),
        StructField("needs_retrained", BooleanType(), True),
        StructField("needs_inspected", BooleanType(), True),
        StructField("prod_accuracy", DoubleType(), True),
        StructField("challenger_accuracy", DoubleType(), True),
        StructField("challenger_version", StringType(), True)
    ])
    report_df = spark.createDataFrame(report_results, schema=report_schema)
    report_df.write.format("delta") \
        .mode("overwrite") \
        .saveAsTable(report_table)
    print(f"Model report successfully written to Delta table '{report_table}'.")

In [0]:
# Build the version-level log and model-level report tables from the registry state right after initial training.
generate_master_model_log()

In [0]:
%sql
-- Version-level log: one row per registered model version (placeholder row with NULLs for a model without versions).
select * from field_demos.ml_ops.master_model_log

In [0]:
%sql
-- Model-level report: one row per model with needs_retrained / needs_inspected flags.
select * from field_demos.ml_ops.master_model_report

In [0]:
def automated_retrain(report_table="field_demos.ml_ops.master_model_report",
                      log_table="field_demos.ml_ops.model_retrain_log"):
    """Retrain every model flagged ``needs_retrained`` in the report table and log each attempt.

    One row per attempted model (model_name, retrain_time, status, details) is
    appended to ``log_table``. A failure while retraining one model is
    recorded with a ``"Failed: ..."`` status instead of aborting the run, so
    attempts already made are still logged. Uses the notebook-provided global
    ``spark`` session.

    NOTE(review): ``train_dummy_model_for_model_id`` only labels a version
    "prod" when it is the model's only version, so retraining an existing
    model adds a challenger. The model will likely still be flagged on the
    next report until a version is promoted.

    Args:
        report_table: Table produced by ``generate_master_model_log``.
        log_table: Delta table the retrain events are appended to.
    """
    import datetime
    from pyspark.sql.types import StructType, StructField, StringType, TimestampType

    def retrain_model(model_name):
        """Retrain one model; return "Success", or "Failed: <error>" instead of raising."""
        print(f"Retraining model '{model_name}'...")
        try:
            train_dummy_model_for_model_id(model_name)
        except Exception as e:
            print(f"Retraining failed for model '{model_name}': {e}")
            return f"Failed: {e}"
        return "Success"

    # collect() is inside the try: with lazily analyzed DataFrames (e.g. Spark
    # Connect), a missing table may only surface when the query executes.
    try:
        models_to_retrain = (
            spark.table(report_table)
            .filter("needs_retrained = true")
            .collect()
        )
    except Exception as e:
        print(f"Error reading {report_table} table:", e)
        return

    retrain_logs = []
    for row in models_to_retrain:
        model_name = row["model_name"]
        status = retrain_model(model_name)
        if status == "Success":
            details = f"Retraining executed for model {model_name}."
        else:
            details = f"Retraining failed for model {model_name}."
        retrain_logs.append({
            "model_name": model_name,
            "retrain_time": datetime.datetime.now(),
            "status": status,
            "details": details
        })

    if not retrain_logs:
        print("No models required retraining.")
        return

    log_schema = StructType([
        StructField("model_name", StringType(), True),
        StructField("retrain_time", TimestampType(), True),
        StructField("status", StringType(), True),
        StructField("details", StringType(), True)
    ])
    log_df = spark.createDataFrame(retrain_logs, schema=log_schema)

    # Append so the table accumulates a history of retrain events.
    log_df.write.format("delta") \
        .mode("append") \
        .saveAsTable(log_table)

    print(f"Retraining events logged to Delta table '{log_table}'.")

In [0]:
# Retrain every model flagged needs_retrained in the report and append the attempts to the retrain log table.
automated_retrain()

In [0]:
# Rebuild both tables so they include the versions registered by automated_retrain().
generate_master_model_log()

In [0]:
%sql
-- Version-level log after retraining: newly registered versions appear here.
select * from field_demos.ml_ops.master_model_log

In [0]:
%sql
-- Model-level report after retraining: compare version counts and flags with the earlier report.
select * from field_demos.ml_ops.master_model_report