In [0]:
from pyspark.sql.functions import col

# Create widgets that let the caller choose the range synthetic accuracy values are drawn from.
dbutils.widgets.text("min_accuracy", ".5", "Min Accuracy")
dbutils.widgets.text("max_accuracy", ".8", "Max Accuracy")

# Widget values always come back as strings; parse them here.
min_accuracy = float(dbutils.widgets.get("min_accuracy"))
max_accuracy = float(dbutils.widgets.get("max_accuracy"))

# Fail fast on a nonsensical range (reversed bounds, values outside [0, 1], or NaN)
# so the update cell never logs out-of-range "accuracy" metrics to registered runs.
if not 0.0 <= min_accuracy <= max_accuracy <= 1.0:
    raise ValueError(
        "Expected 0 <= min_accuracy <= max_accuracy <= 1, got "
        f"min_accuracy={min_accuracy}, max_accuracy={max_accuracy}"
    )

In [0]:
import random
from mlflow.tracking import MlflowClient
import mlflow

# Point the MLflow model registry at Unity Catalog via the public API instead of
# monkey-patching a private MLflow helper (private internals change between releases).
mlflow.set_registry_uri("databricks-uc")

# Only models under this <catalog>.<schema> get their accuracy metric rewritten.
MODEL_NAMESPACE = "field_demos.ml_ops"
# Set to an int to make the generated accuracy values reproducible; None keeps them random.
RANDOM_SEED = None

# Initialize the MLflow client (picks up the registry URI set above).
client = MlflowClient()
rng = random.Random(RANDOM_SEED)


def _iter_all_pages(search_fn, **kwargs):
    """Yield every item from a paginated MLflow ``search_*`` client call.

    MLflow search methods return a single page (a ``PagedList``) per call;
    ``page.token`` is the token for the next page and is empty on the last page.
    Iterating only the first page silently drops the remaining results.
    """
    page_token = None
    while True:
        page = search_fn(page_token=page_token, **kwargs)
        yield from page
        page_token = page.token
        if not page_token:
            return


def _log_accuracy(run_id, accuracy):
    """Log ``accuracy`` as the "accuracy" metric on ``run_id``; return True on success.

    ``MlflowClient.log_metric`` writes to the run without re-opening it, so the run's
    status and end time are left untouched and no active-run conflict can occur.
    Failures are reported and swallowed so one bad run doesn't stop the whole sweep.
    """
    try:
        client.log_metric(run_id, "accuracy", accuracy)
    except Exception as e:  # deliberate best-effort: report and move on
        print(f"  Failed to update run {run_id}: {e}")
        return False
    return True


# Trailing dot so e.g. "field_demos.ml_ops" does not also match "field_demos.ml_ops_dev".
namespace_prefix = MODEL_NAMESPACE + "."

# Iterate over all registered models (every page), keep those in the namespace,
# and log a fresh random accuracy to the run behind each of their versions.
for model in _iter_all_pages(client.search_registered_models):
    if not model.name.startswith(namespace_prefix):
        continue
    print(f"Processing model: {model.name}")
    model_versions = list(
        _iter_all_pages(
            client.search_model_versions, filter_string=f"name='{model.name}'"
        )
    )
    if not model_versions:
        print("  No versions found.")
    for version in model_versions:
        if not version.run_id:
            # A version registered without a source run has nowhere to log to; skip it
            # rather than writing the metric to some unrelated (or newly created) run.
            print(f"  Skipping version {version.version}: no source run recorded.")
            continue
        new_accuracy = round(rng.uniform(min_accuracy, max_accuracy), 3)
        print(f"  Updating version: {version.version} (Run ID: {version.run_id}) with new accuracy: {new_accuracy}")
        _log_accuracy(version.run_id, new_accuracy)
    print("-" * 50)