In [1]:
import json

import mlflow
import mlflow.sklearn
import pandas as pd
from mlflow.models import infer_signature
from mlflow.models import validate_serving_input
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

In [None]:
## Point the MLflow client at the local tracking server
mlflow.set_tracking_uri(uri="http://127.0.0.1:5000")

In [None]:
## Load the iris dataset as a feature matrix X and a label vector y
X, y = datasets.load_iris(return_X_y=True)

## Model hyperparameters (these are also what gets logged to MLflow later on)
# NOTE(review): `multi_class` is reported as deprecated by recent scikit-learn
# releases; "auto" is the default there, so it can likely be dropped -- confirm
# against the installed version before removing it from the logged params.
params = {"solver": "newton-cg", "max_iter": 1000, "multi_class": "auto", "random_state": 1000}
# Alternative configuration for a comparison run:
#params = {"penalty":"l2", "solver":"lbfgs", "max_iter":1000, "multi_class":"auto","random_state":8888}

## Split the dataset into training (80%) and testing (20%) sets.
# The split is seeded so that Restart & Run All reproduces the same partition,
# and therefore the same accuracy and logged model, on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=params["random_state"]
)

## Train the model
lr = LogisticRegression(**params)
lr.fit(X_train, y_train)

In [None]:
## Predict class labels for the held-out test set and display them
y_pred = lr.predict(X=X_test)
y_pred

In [None]:
## Fraction of test samples whose predicted label matches the true label
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print(f"Model accuracy: {accuracy}")

## MLflow Tracking

In [None]:
## MLflow tracking setup
mlflow.set_tracking_uri(uri="http://127.0.0.1:5000")

## Select the experiment that the run below is recorded under
mlflow.set_experiment("MLFlow Iris Experiment")

# Start an MLflow run; everything logged inside the block belongs to this run
with mlflow.start_run():
    # Log the model hyperparameters
    mlflow.log_params(params)

    # Log the accuracy that was actually measured on the test set.
    # (A hard-coded value here would make the tracked metric meaningless.)
    mlflow.log_metric("accuracy", accuracy)

    # Set a tag for the run
    mlflow.set_tag("Training Info-2", "Basic Logistic Regression model for iris data")

    # Infer the model signature (input/output schema) from training data
    signature = infer_signature(X_train, lr.predict(X_train))

    # Log the model together with its signature and a small input example.
    # Only the first few rows are stored as the example; passing all of
    # X_train would save the entire training set alongside the model.
    model_info = mlflow.sklearn.log_model(
        sk_model=lr,
        artifact_path="model",
        signature=signature,
        input_example=X_train[:5],
        #registered_model_name="tracking_iris_model" #register the model only after validating
    )

In [None]:
# Display the URI of the logged model; later cells pass this URI to the
# validation and loading helpers.
model_info.model_uri

## Inferencing and Validating Model

In [None]:
model_uri = model_info.model_uri

# Build the serving payload from a handful of current training rows instead of
# pasting a hard-coded snapshot of an earlier run's data. The JSON has the same
# {"inputs": [[f1, f2, f3, f4], ...]} layout: one inner list per sample.
serving_payload = json.dumps({"inputs": X_train[:5].tolist()})

# Validate the serving payload works on the model
validate_serving_input(model_uri, serving_payload)

## Load the model back for prediction as a generic Python function (pyfunc) model

In [None]:
# Reload the logged model through the generic pyfunc interface and score the test set
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
predictions = loaded_model.predict(X_test)

iris_features_name = datasets.load_iris().feature_names

# Test features side by side with the true and the predicted class labels
result = pd.DataFrame(X_test, columns=iris_features_name).assign(
    actual_class=y_test,
    predicted_class=predictions,
)

In [None]:
# Preview the first 5 rows of the result DataFrame
result.head()

## Inference from the model registry

In [None]:
# Load a model from the MLflow model registry by name and version.
# NOTE(review): nothing in this notebook registers "tracking_iris_model" (the
# `registered_model_name` argument of the logging call is commented out), so this
# cell presumably relies on the model having been registered separately, e.g. via
# the MLflow UI -- confirm before running the notebook end to end.
model_name = "tracking_iris_model"
model_version = "latest"

# This rebinds `model_uri`, which previously held the run-scoped model URI.
model_uri = f"models:/{model_name}/{model_version}"

model = mlflow.sklearn.load_model(model_uri)
model

In [None]:
# Display the registry URI that was built from model_name and model_version
model_uri

In [None]:
# Predict on the test set with the model loaded from the registry and display the labels
y_pred_new = model.predict(X_test)
y_pred_new