In [0]:

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, BooleanType, IntegerType

# Point the MLflow model registry at Unity Catalog through the public API.
# (Monkey-patching a private mlflow helper silently stops working whenever
# mlflow renames its internals.)
mlflow.set_registry_uri("databricks-uc")

# Initialize the MLflow client (it resolves the registry URI set above)
client = MlflowClient()

# Only models registered under this catalog.schema are reported. The trailing
# dot keeps sibling schemas such as "field_demos.ml_ops_v2" from matching.
MODEL_NAME_PREFIX = "field_demos.ml_ops."


def _search_all(search_fn, *args, **kwargs):
    """Return every result of a paginated MlflowClient search call.

    MlflowClient search methods return a single page of results (a PagedList)
    whose ``token`` attribute holds the next page's token; calling them once
    silently truncates large registries, so keep requesting pages until the
    token comes back empty.
    """
    items = []
    page_token = None
    while True:
        page = search_fn(*args, page_token=page_token, **kwargs)
        items.extend(page)
        page_token = page.token
        if not page_token:
            return items


def _get_accuracy(run_id):
    """Return the run's "accuracy" metric as a float, or None if unavailable.

    None is returned when the version has no source run, when the run has no
    "accuracy" metric, or when the run cannot be fetched (e.g. deleted or not
    visible from this workspace) -- one bad version should not abort the whole
    report.
    """
    if not run_id:
        return None
    try:
        run = client.get_run(run_id)
    except MlflowException as exc:
        print(f"Warning: could not fetch run {run_id}: {exc}")
        return None
    accuracy = run.data.metrics.get("accuracy")
    return float(accuracy) if accuracy is not None else None


# Rows for the master model log (version-level). A model with no versions
# contributes a single placeholder row whose version/run_id/accuracy/description
# are all None.
results = []

for model in _search_all(client.search_registered_models):
    if not model.name.startswith(MODEL_NAME_PREFIX):
        continue

    model_versions = _search_all(client.search_model_versions, f"name='{model.name}'")
    if not model_versions:
        results.append({
            "model_name": model.name,
            "version": None,
            "run_id": None,
            "accuracy": None,
            "description": None,
        })
        continue

    for version in model_versions:
        results.append({
            "model_name": model.name,
            # The log stores versions as StringType; normalize in case the
            # client hands back an int.
            "version": str(version.version),
            "run_id": version.run_id,
            "accuracy": _get_accuracy(version.run_id),
            "description": version.description,
        })

# Persist the version-level rows from `results` as the master model log Delta
# table, replacing any previous contents. Every column is nullable because a
# model without versions contributes a row whose non-name fields are None.
_MASTER_LOG_TABLE = "field_demos.ml_ops.master_model_log"

master_schema = StructType([
    StructField(column, spark_type, True)
    for column, spark_type in (
        ("model_name", StringType()),
        ("version", StringType()),
        ("run_id", StringType()),
        ("accuracy", DoubleType()),
        ("description", StringType()),
    )
])

master_df = spark.createDataFrame(results, schema=master_schema)
(
    master_df.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable(_MASTER_LOG_TABLE)
)

print(f"Master model log successfully written to Delta table '{_MASTER_LOG_TABLE}'.")

# Build the model-level aggregates used by the report. For each model we track:
# - number_of_versions: count of real versions (the placeholder row emitted for
#   a model with no versions is NOT counted, so such models report 0)
# - has_prod_model: True if any version has description exactly "prod"
# - max_prod_accuracy: highest accuracy among prod versions (None if none known)
# - best_challenger: best non-prod version seen, as {"accuracy", "version"}
report_dict = {}
for record in results:
    # Every model gets an entry, even one with no versions at all.
    agg = report_dict.setdefault(record["model_name"], {
        "number_of_versions": 0,
        "has_prod_model": False,
        "max_prod_accuracy": None,   # Highest accuracy among prod versions
        "best_challenger": None,     # Best non-prod version seen so far
    })

    # A model without versions appears in `results` as a single row whose
    # "version" is None; it must not be counted as a version.
    if record["version"] is None:
        continue
    agg["number_of_versions"] += 1

    accuracy = record["accuracy"]
    if record["description"] == "prod":
        agg["has_prod_model"] = True
        current_max = agg["max_prod_accuracy"]
        if accuracy is not None and (current_max is None or accuracy > current_max):
            agg["max_prod_accuracy"] = accuracy
    elif accuracy is not None:
        # Non-prod versions compete to be the challenger; the first version
        # seen wins ties.
        best = agg["best_challenger"]
        if best is None or accuracy > best["accuracy"]:
            agg["best_challenger"] = {
                "accuracy": accuracy,
                "version": record["version"],
            }

# Best prod accuracy below this value flags the model for retraining.
_RETRAIN_ACCURACY_THRESHOLD = 0.75


def _summarize_model(name, agg):
    """Turn one model's aggregate from `report_dict` into a report row.

    needs_retrained is True when the model has no prod version, when no prod
    version reports an accuracy, or when the best prod accuracy is below the
    retraining threshold.

    needs_inspected is True only when a prod accuracy is known and the best
    non-prod (challenger) accuracy is strictly higher; the challenger columns
    are populated only in that case and are None otherwise.
    """
    has_prod = agg["has_prod_model"]
    prod_acc = agg["max_prod_accuracy"]
    challenger = agg["best_challenger"]

    needs_retrained = (
        not has_prod
        or prod_acc is None
        or prod_acc < _RETRAIN_ACCURACY_THRESHOLD
    )
    needs_inspected = (
        has_prod
        and challenger is not None
        and prod_acc is not None
        and challenger["accuracy"] > prod_acc
    )

    return {
        "model_name": name,
        "number_of_versions": agg["number_of_versions"],
        "has_prod_model": has_prod,
        "needs_retrained": needs_retrained,
        "needs_inspected": needs_inspected,
        "prod_accuracy": prod_acc,
        "challenger_accuracy": challenger["accuracy"] if needs_inspected else None,
        "challenger_version": challenger["version"] if needs_inspected else None,
    }


# One report row per model, in the order models were first seen.
report_results = [_summarize_model(name, agg) for name, agg in report_dict.items()]

# Persist the per-model report rows (including the challenger columns) as a
# Delta table, replacing any previous contents.
_REPORT_TABLE = "field_demos.ml_ops.master_model_report"

report_schema = StructType([
    StructField(column, spark_type, True)
    for column, spark_type in (
        ("model_name", StringType()),
        ("number_of_versions", IntegerType()),
        ("has_prod_model", BooleanType()),
        ("needs_retrained", BooleanType()),
        ("needs_inspected", BooleanType()),
        ("prod_accuracy", DoubleType()),
        ("challenger_accuracy", DoubleType()),
        ("challenger_version", StringType()),
    )
])

report_df = spark.createDataFrame(report_results, schema=report_schema)
(
    report_df.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable(_REPORT_TABLE)
)

print(f"Model report successfully written to Delta table '{_REPORT_TABLE}'.")