In [None]:
import mlflow
import mlflow.sklearn
from mlflow.models.signature import infer_signature
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, f1_score


In [None]:
# MLflow configuration: the tracking server that records runs and the
# experiment under which this notebook's runs are grouped.
TRACKING_URI = "http://127.0.0.1:5000"
EXPERIMENT_NAME = "mlflow_comparemodels"

mlflow.set_tracking_uri(uri=TRACKING_URI)
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)


In [None]:
# Load the Iris features (X) and labels (Y), then hold out 20% of the rows
# for testing. The fixed random_state keeps the split identical across reruns.
X, Y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y,
    test_size=0.2,
    random_state=42,
)


In [None]:
# Candidate classifiers to compare, keyed by the display name of each model.
# Estimators with a stochastic component are given a fixed seed (42, the same
# value used for the train/test split) so that a fresh Restart & Run All
# reproduces the same fitted models and metrics.
RANDOM_STATE = 42

models = {
    "Logistic_regression": LogisticRegression(max_iter=1000),
    # Bootstrap sampling and feature sub-sampling are randomized.
    "Random_forest": RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE),
    # probability=True enables probability estimates, whose internal
    # calibration step is randomized, so this estimator is seeded as well.
    "Support Vector machine": SVC(probability=True, random_state=RANDOM_STATE),
}

In [None]:
# Train each model and record it in MLflow: its hyperparameters, test-set
# metrics, a descriptive tag, and the fitted model itself (also registered
# under a per-model name).

# Number of training rows stored with each logged model as its input example.
# A handful of rows is enough to illustrate the expected input; storing the
# whole training matrix would duplicate the dataset in every run's artifacts.
INPUT_EXAMPLE_ROWS = 5

for model_name, model in models.items():
    with mlflow.start_run(run_name=model_name):
        # Fit the model
        model.fit(X_train, y_train)

        # Predict on the held-out test set
        predictions = model.predict(X_test)

        # Calculate performance metrics (macro F1 averages the per-class scores)
        accuracy = accuracy_score(y_test, predictions)
        f1 = f1_score(y_test, predictions, average='macro')

        # Log the model name plus the estimator's full hyperparameter set, so
        # runs can be compared on configuration and not only on scores
        mlflow.log_param("model_name", model_name)
        mlflow.log_params(model.get_params())
        mlflow.log_metric("accuracy", accuracy)
        mlflow.log_metric("f1_score", f1)

        # Add a descriptive tag for the run
        mlflow.set_tag("Training Info", f"{model_name} model for Iris dataset")

        # Infer the model signature from the test inputs and the predictions
        # already computed above, avoiding a second predict() pass
        signature = infer_signature(X_test, predictions)

        # Log the trained model with its signature and a small input example,
        # and register it under a name derived from the model's display name
        mlflow.sklearn.log_model(
            sk_model=model,
            artifact_path=f"{model_name.lower().replace(' ', '_')}_model",
            signature=signature,
            input_example=X_train[:INPUT_EXAMPLE_ROWS],
            registered_model_name=f"{model_name.replace(' ', '_')}_tracking_example"
        )
