In [0]:
# Create a text widget for the registered-model name to train/register.
# The default is "" so the user must fill it in before running the notebook.
dbutils.widgets.text("model_name", "")

In [0]:
# Read the target registered-model name from the widget. Stray whitespace is
# stripped, and the cell fails fast if the widget was left at its "" default,
# instead of failing later inside MLflow model registration.
model_name = dbutils.widgets.get("model_name").strip()
if not model_name:
    raise ValueError("Set the 'model_name' widget to the registered model name before running.")

In [0]:
# Echo the widget value as the cell output so the run shows which model is targeted.
model_name

In [0]:
import mlflow
import mlflow.pyfunc
import random
import pandas as pd
from mlflow.models.signature import infer_signature
from pyspark.sql.functions import current_timestamp
from mlflow.tracking import MlflowClient
from datetime import datetime


class DummyModel(mlflow.pyfunc.PythonModel):
    """Toy MLflow pyfunc model computing ``coefficient * x + bias`` per input value.

    ``model_id`` and ``accuracy`` are stored as metadata only; ``predict`` does not
    use them.
    """

    def __init__(self, model_id, coefficient, bias, accuracy):
        self.model_id = model_id
        self.coefficient = coefficient
        self.bias = bias
        self.accuracy = accuracy

    def predict(self, context, model_input):
        """Apply the linear transformation ``coefficient * x + bias`` to each input value.

        Args:
            context: MLflow ``PythonModelContext``; unused.
            model_input: A pandas DataFrame (what a loaded pyfunc model passes in;
                the logged signature has a single column ``"x"``), a pandas Series,
                or any iterable of numbers.

        Returns:
            A list with one prediction per input value.
        """
        if isinstance(model_input, pd.DataFrame):
            # Iterating a DataFrame yields its column labels, not its values, so pick
            # the feature column explicitly: "x" if present, otherwise the first one.
            column = "x" if "x" in model_input.columns else model_input.columns[0]
            model_input = model_input[column]
        return [self.coefficient * x + self.bias for x in model_input]

def train_dummy_model_for_model_id(model_id, seed=None):
    """
    Train a DummyModel with random parameters, log it to MLflow, and register it.

    The model is logged in a new MLflow run and registered as a new version of the
    registered model named ``model_id``. The new version's description is then set
    to ``"prod"`` if it is the only version of that model, otherwise to
    ``"challenger_<YYYYmmddHHMMSS>"``. The ``models:/<name>/<version>`` URI of the
    new version is printed.

    Args:
        model_id: Registered model name. For Unity Catalog this is presumably the
            three-level ``catalog.schema.model`` name (TODO confirm). Must be non-blank.
        seed: Optional seed for drawing coefficient, bias and accuracy. With the
            default ``None`` the global ``random`` module is used.

    Returns:
        The in-memory ``DummyModel`` instance that was logged.

    Raises:
        ValueError: if ``model_id`` is empty or only whitespace.
        RuntimeError: if no versions of the registered model are found after logging.
    """
    if not model_id or not str(model_id).strip():
        raise ValueError("model_id must be a non-empty registered model name")

    # Use a dedicated RNG only when a seed is given, so unseeded calls behave as before.
    rng = random if seed is None else random.Random(seed)

    # Generate dummy parameters and accuracy.
    coefficient = rng.uniform(0.5, 1.5)
    bias = rng.uniform(-1, 1)
    accuracy = rng.uniform(0.75, 0.8)

    model = DummyModel(model_id, coefficient, bias, accuracy)

    # Create an example input DataFrame for signature inference.
    example_input = pd.DataFrame({"x": [0, 1, 2, 3, 4, 5]})
    example_output = model.predict(None, example_input["x"].tolist())
    signature = infer_signature(example_input, example_output)

    artifact_path = f"dummy_model_{model_id}"
    registered_model_name = model_id

    with mlflow.start_run() as run:
        mlflow.log_param("model_id", model_id)
        mlflow.log_param("coefficient", coefficient)
        mlflow.log_param("bias", bias)
        mlflow.log_metric("accuracy", accuracy)

        mlflow.pyfunc.log_model(
            artifact_path=artifact_path,
            python_model=model,
            signature=signature,
            registered_model_name=registered_model_name
        )

        run_id = run.info.run_id
        print(f"Logged DummyModel for model_id '{model_id}' with run ID: {run_id} and accuracy: {accuracy:.3f}")

        client = MlflowClient()
        # Find the version just created: take the highest version number of this model.
        # NOTE(review): concurrent registrations of the same name could race here.
        model_versions = client.search_model_versions(f"name='{registered_model_name}'")
        versions = [int(mv.version) for mv in model_versions]
        if not versions:
            raise RuntimeError(
                f"No versions found for registered model '{registered_model_name}' after logging."
            )
        new_version = max(versions)

        # Tag the version via its description (not an alias): the first-ever version is
        # "prod", later ones are timestamped challengers.
        if len(versions) == 1:
            description = "prod"
        else:
            description = f"challenger_{datetime.now().strftime('%Y%m%d%H%M%S')}"

        client.update_model_version(
            name=registered_model_name,
            version=str(new_version),
            description=description
        )
        print(f"Set description for model '{registered_model_name}' version {new_version} to '{description}'.")

        # URI that can be passed to mlflow.pyfunc.load_model to load this exact version.
        model_location = f"models:/{registered_model_name}/{new_version}"
        print(f"Model version can be loaded from '{model_location}'.")

    return model

In [0]:
# Ensure MLflow uses the Databricks Unity Catalog registry, so registering a model
# creates a Unity Catalog model version. Use the public API rather than patching
# private MLflow internals, which can change between releases.
mlflow.set_registry_uri("databricks-uc")

train_dummy_model_for_model_id(model_name)