# Import Dependencies

In [0]:
import mlflow
import mlflow.sklearn
from mlflow.tracking import MlflowClient
from mlflow.models.signature import infer_signature

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

import os
import pandas as pd

# Read Job Parameters

In [0]:
# Job parameters come from Databricks widgets when available; otherwise from
# environment variables, and finally from the supplied default.
def get_param(name, default):
    """Look up the job parameter ``name``.

    The Databricks widget called ``name`` is tried first (``dbutils`` is only
    defined when running as a notebook task). If that lookup raises anything,
    the environment variable ``name.upper()`` is used instead, falling back to
    ``str(default)`` when that variable is unset.

    Callers are expected to convert the result themselves,
    e.g. ``int(get_param("n_estimators", 100))``.
    """
    try:
        widget_value = dbutils.widgets.get(name)
    except Exception:
        # Local/dev run (or no usable widget): consult the environment instead.
        env_key = name.upper()
        return os.getenv(env_key, str(default))
    return widget_value

# Declare the text widgets through which a Job task passes its parameters
# (each widget is labelled with its own name). Setup is best-effort: any error,
# e.g. ``dbutils`` being undefined outside Databricks, is deliberately ignored.
_WIDGET_DEFAULTS = (
    ("n_estimators", "100"),
    ("accuracy_threshold", "0.85"),
    ("git_ref", ""),
    ("git_sha", ""),
)
try:
    for _widget_name, _widget_default in _WIDGET_DEFAULTS:
        dbutils.widgets.text(_widget_name, _widget_default, _widget_name)
except Exception:
    pass

# Train, Evaluate, and Register Model

In [0]:
# Resolve job parameters (widget -> environment variable -> default).
n_estimators = int(get_param("n_estimators", 100))
accuracy_threshold = float(get_param("accuracy_threshold", 0.85))
git_ref = get_param("git_ref", "")
git_sha = get_param("git_sha", "")

experiment_path = "/Shared/mlflow-ci-cd-poc"
mlflow.set_experiment(experiment_path)

with mlflow.start_run() as run:
    # Data: load Iris once and reuse it for both the arrays and column names.
    iris = load_iris()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Named columns give the logged model a column-based signature.
    X_train_df = pd.DataFrame(X_train, columns=iris.feature_names)
    X_test_df = pd.DataFrame(X_test, columns=iris.feature_names)

    # Train
    model = RandomForestClassifier(n_estimators=n_estimators, random_state=42)
    model.fit(X_train_df, y_train)

    # Evaluate on the held-out 20% split
    preds = model.predict(X_test_df)
    acc = accuracy_score(y_test, preds)

    # Log params, metric and provenance tags (git tags only when supplied)
    mlflow.log_param("n_estimators", n_estimators)
    mlflow.log_param("accuracy_threshold", accuracy_threshold)
    mlflow.log_metric("accuracy", acc)
    mlflow.set_tag("env", "dev")
    if git_ref:
        mlflow.set_tag("git_ref", git_ref)
    if git_sha:
        mlflow.set_tag("git_sha", git_sha)

    print(f"[INFO] accuracy={acc:.4f}")

    signature = infer_signature(X_train_df, model.predict(X_train_df))

    mlflow.sklearn.log_model(
        sk_model=model,
        artifact_path="model",
        signature=signature,
        input_example=X_train_df.iloc[:5],
    )

    # Quality gate: accuracy_threshold was previously read but never used, so
    # every model got registered. Refuse to register a model below the
    # threshold. Raising fails the job task; MLflow ends a run with FAILED
    # status when an exception leaves the start_run() block, while the metric
    # and model artifact logged above remain available for inspection.
    if acc < accuracy_threshold:
        raise RuntimeError(
            f"accuracy={acc:.4f} is below accuracy_threshold="
            f"{accuracy_threshold:.4f}; model not registered"
        )

    # Register in Model Registry
    model_name = "mlflow_ci_cd_poc_model"
    registered = mlflow.register_model(
        model_uri=f"runs:/{run.info.run_id}/model",
        name=model_name,
    )


[INFO] accuracy=1.0000


Registered model 'mlflow_ci_cd_poc_model' already exists. Creating a new version of this model...
Created version '2' of model 'workspace.default.mlflow_ci_cd_poc_model'.
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection
INFO:py4j.clientserver:Closing down clientserver connection


# Tag Registered Model Version

In [0]:
# Attach environment, accuracy and (optional) git provenance tags to the
# model version created by register_model in the training cell.
client = MlflowClient()
version = registered.version

version_tags = {"env": "dev", "accuracy": str(acc)}
# Git metadata is optional: only tag values the job actually supplied.
version_tags.update(
    (tag_key, tag_value)
    for tag_key, tag_value in (("git_ref", git_ref), ("git_sha", git_sha))
    if tag_value
)
# Dict insertion order keeps the calls in env, accuracy, git_ref, git_sha order.
for tag_key, tag_value in version_tags.items():
    client.set_model_version_tag(model_name, version, tag_key, tag_value)