In [17]:
import mlflow
from mlflow.models import infer_signature
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

In [18]:
## load the dataset
X,y=datasets.load_iris(return_X_y=True)

## split the data into training and tests sets
X_train,X_test,y_traint,y_test= train_test_split(X,y,test_size=0.20)

# define the model hyperparameter
params = {"penalty":"l2", "random_state":8888, "solver":"lbfgs", "max_iter":1000,"multi_class":"auto"}

# train the model
lr = LogisticRegression(**params)
lr.fit(X_train,y_traint)



In [19]:
# prediction on the test set

y_pred = lr.predict(X_test)
y_pred

array([1, 0, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 0, 1, 0, 1, 1, 2, 0, 2, 0, 0,
       2, 2, 1, 1, 0, 2, 2, 0])

In [20]:
accuracy = accuracy_score(y_test,y_pred)
accuracy

0.9666666666666667

In [21]:
# MLFLOW Tracking
mlflow.set_tracking_uri(uri="http://127.0.0.1:5000")

In [None]:
# New MLFLOW experilent
mlflow.set_experiment("MlFLOW Quickstart")

## Start and Mlflow run
with mlflow.start_run():
    mlflow.log_params(params)

    ## Log the accuracy metrics
    mlflow.log_metric("accuracy",accuracy)
    mlflow.set_tag("Training Info", "Basic LR model for iris Data")

    # Infer the model signature
    signature = infer_signature(X_train,lr.predict(X_train))

    # Log the model
    model_info = mlflow.sklearn.log_model(
        sk_model = lr,
        artifact_path = "iris_model",
        signature = signature,
        input_example = X_train,
        registered_model_name = "tracking-quistart",
    )



Successfully registered model 'tracking-quistart'.
2024/12/30 12:46:28 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quistart, version 1


🏃 View run sedate-sheep-415 at: http://127.0.0.1:5000/#/experiments/525465168285721666/runs/4702f4a7919640a6aa27f6818e851f97
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/525465168285721666


Created version '1' of model 'tracking-quistart'.


In [25]:
# define the model hyperparameter
params = {"random_state":8888, "solver":"newton-cg", "max_iter":1000,"multi_class":"auto"}

# train the model
lr = LogisticRegression(**params)
lr.fit(X_train,y_traint)



In [26]:
# prediction on the test set

y_pred = lr.predict(X_test)
y_pred

array([1, 0, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 0, 1, 0, 1, 1, 2, 0, 2, 0, 0,
       2, 2, 1, 1, 0, 2, 2, 0])

In [27]:
accuracy = accuracy_score(y_test,y_pred)
accuracy

0.9666666666666667

In [28]:
# New MLFLOW experilent
mlflow.set_experiment("MlFLOW Quickstart")

## Start and Mlflow run
with mlflow.start_run():
    mlflow.log_params(params)

    ## Log the accuracy metrics
    mlflow.log_metric("accuracy",accuracy)
    mlflow.set_tag("Training Info", "Basic LR model for iris Data")

    # Infer the model signature
    signature = infer_signature(X_train,lr.predict(X_train))

    # Log the model
    model_info = mlflow.sklearn.log_model(
        sk_model = lr,
        artifact_path = "iris_model",
        signature = signature,
        input_example = X_train,
        registered_model_name = "tracking-quistart",
    )

Registered model 'tracking-quistart' already exists. Creating a new version of this model...
2024/12/30 14:50:49 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quistart, version 2


🏃 View run delightful-fowl-832 at: http://127.0.0.1:5000/#/experiments/525465168285721666/runs/022047dde6374141825695fa8b7f80d8
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/525465168285721666


Created version '2' of model 'tracking-quistart'.


## Inferencing and validate model

In [31]:
from mlflow.models import validate_serving_input

model_uri = 'runs:/022047dde6374141825695fa8b7f80d8/iris_model'

# The model is logged with an input example. MLflow converts
# it into the serving payload format for the deployed model endpoint,
# and saves it to 'serving_input_payload.json'
serving_payload = """{
  "inputs": [
    [
      5.1,
      3.8,
      1.6,
      0.2
    ],
    [
      5.2,
      3.4,
      1.4,
      0.2
    ],
    [
      6.0,
      3.0,
      4.8,
      1.8
    ],
    [
      6.3,
      2.8,
      5.1,
      1.5
    ],
    [
      6.9,
      3.1,
      4.9,
      1.5
    ],
    [
      6.1,
      3.0,
      4.6,
      1.4
    ],
    [
      5.8,
      2.7,
      4.1,
      1.0
    ],
    [
      4.8,
      3.0,
      1.4,
      0.1
    ],
    [
      5.8,
      2.7,
      5.1,
      1.9
    ],
    [
      5.1,
      3.8,
      1.9,
      0.4
    ],
    [
      7.2,
      3.2,
      6.0,
      1.8
    ],
    [
      6.3,
      2.5,
      4.9,
      1.5
    ],
    [
      7.7,
      2.6,
      6.9,
      2.3
    ],
    [
      5.4,
      3.9,
      1.7,
      0.4
    ],
    [
      6.9,
      3.1,
      5.4,
      2.1
    ],
    [
      5.7,
      2.5,
      5.0,
      2.0
    ],
    [
      6.7,
      3.1,
      5.6,
      2.4
    ],
    [
      5.1,
      3.7,
      1.5,
      0.4
    ],
    [
      5.5,
      2.5,
      4.0,
      1.3
    ],
    [
      6.7,
      3.3,
      5.7,
      2.1
    ],
    [
      6.0,
      2.2,
      5.0,
      1.5
    ],
    [
      5.7,
      3.0,
      4.2,
      1.2
    ],
    [
      5.6,
      2.7,
      4.2,
      1.3
    ],
    [
      5.2,
      3.5,
      1.5,
      0.2
    ],
    [
      6.7,
      3.1,
      4.4,
      1.4
    ],
    [
      5.5,
      2.3,
      4.0,
      1.3
    ],
    [
      5.2,
      4.1,
      1.5,
      0.1
    ],
    [
      5.8,
      2.6,
      4.0,
      1.2
    ],
    [
      6.1,
      2.8,
      4.7,
      1.2
    ],
    [
      6.7,
      3.0,
      5.2,
      2.3
    ],
    [
      6.5,
      2.8,
      4.6,
      1.5
    ],
    [
      5.0,
      3.4,
      1.6,
      0.4
    ],
    [
      5.0,
      3.6,
      1.4,
      0.2
    ],
    [
      7.4,
      2.8,
      6.1,
      1.9
    ],
    [
      6.3,
      3.4,
      5.6,
      2.4
    ],
    [
      7.7,
      3.0,
      6.1,
      2.3
    ],
    [
      5.5,
      3.5,
      1.3,
      0.2
    ],
    [
      6.7,
      3.3,
      5.7,
      2.5
    ],
    [
      5.7,
      2.8,
      4.5,
      1.3
    ],
    [
      5.1,
      3.3,
      1.7,
      0.5
    ],
    [
      4.6,
      3.1,
      1.5,
      0.2
    ],
    [
      5.9,
      3.2,
      4.8,
      1.8
    ],
    [
      5.0,
      3.5,
      1.3,
      0.3
    ],
    [
      5.6,
      2.9,
      3.6,
      1.3
    ],
    [
      6.6,
      3.0,
      4.4,
      1.4
    ],
    [
      4.9,
      3.1,
      1.5,
      0.2
    ],
    [
      5.7,
      2.9,
      4.2,
      1.3
    ],
    [
      7.7,
      2.8,
      6.7,
      2.0
    ],
    [
      6.5,
      3.0,
      5.5,
      1.8
    ],
    [
      6.1,
      3.0,
      4.9,
      1.8
    ],
    [
      4.6,
      3.4,
      1.4,
      0.3
    ],
    [
      6.4,
      2.9,
      4.3,
      1.3
    ],
    [
      7.2,
      3.0,
      5.8,
      1.6
    ],
    [
      6.3,
      2.9,
      5.6,
      1.8
    ],
    [
      5.6,
      3.0,
      4.1,
      1.3
    ],
    [
      5.4,
      3.7,
      1.5,
      0.2
    ],
    [
      6.8,
      3.0,
      5.5,
      2.1
    ],
    [
      4.5,
      2.3,
      1.3,
      0.3
    ],
    [
      6.0,
      2.2,
      4.0,
      1.0
    ],
    [
      5.0,
      3.5,
      1.6,
      0.6
    ],
    [
      5.8,
      2.8,
      5.1,
      2.4
    ],
    [
      5.6,
      3.0,
      4.5,
      1.5
    ],
    [
      6.6,
      2.9,
      4.6,
      1.3
    ],
    [
      6.2,
      2.2,
      4.5,
      1.5
    ],
    [
      5.7,
      2.8,
      4.1,
      1.3
    ],
    [
      5.6,
      2.8,
      4.9,
      2.0
    ],
    [
      7.0,
      3.2,
      4.7,
      1.4
    ],
    [
      6.9,
      3.1,
      5.1,
      2.3
    ],
    [
      6.4,
      3.2,
      4.5,
      1.5
    ],
    [
      4.8,
      3.4,
      1.9,
      0.2
    ],
    [
      4.4,
      2.9,
      1.4,
      0.2
    ],
    [
      5.4,
      3.4,
      1.7,
      0.2
    ],
    [
      5.2,
      2.7,
      3.9,
      1.4
    ],
    [
      7.3,
      2.9,
      6.3,
      1.8
    ],
    [
      4.9,
      3.1,
      1.5,
      0.1
    ],
    [
      5.4,
      3.4,
      1.5,
      0.4
    ],
    [
      6.3,
      2.7,
      4.9,
      1.8
    ],
    [
      6.7,
      2.5,
      5.8,
      1.8
    ],
    [
      6.7,
      3.1,
      4.7,
      1.5
    ],
    [
      6.4,
      3.1,
      5.5,
      1.8
    ],
    [
      6.1,
      2.9,
      4.7,
      1.4
    ],
    [
      4.8,
      3.0,
      1.4,
      0.3
    ],
    [
      4.9,
      2.4,
      3.3,
      1.0
    ],
    [
      5.1,
      3.4,
      1.5,
      0.2
    ],
    [
      5.3,
      3.7,
      1.5,
      0.2
    ],
    [
      5.1,
      3.5,
      1.4,
      0.2
    ],
    [
      5.7,
      3.8,
      1.7,
      0.3
    ],
    [
      4.3,
      3.0,
      1.1,
      0.1
    ],
    [
      5.7,
      2.6,
      3.5,
      1.0
    ],
    [
      4.8,
      3.1,
      1.6,
      0.2
    ],
    [
      6.0,
      2.7,
      5.1,
      1.6
    ],
    [
      6.4,
      3.2,
      5.3,
      2.3
    ],
    [
      5.1,
      3.8,
      1.5,
      0.3
    ],
    [
      4.7,
      3.2,
      1.3,
      0.2
    ],
    [
      4.9,
      2.5,
      4.5,
      1.7
    ],
    [
      5.0,
      3.0,
      1.6,
      0.2
    ],
    [
      5.8,
      4.0,
      1.2,
      0.2
    ],
    [
      5.5,
      4.2,
      1.4,
      0.2
    ],
    [
      6.5,
      3.0,
      5.2,
      2.0
    ],
    [
      6.4,
      2.8,
      5.6,
      2.2
    ],
    [
      5.1,
      3.5,
      1.4,
      0.3
    ],
    [
      6.4,
      2.8,
      5.6,
      2.1
    ],
    [
      5.8,
      2.7,
      5.1,
      1.9
    ],
    [
      6.5,
      3.2,
      5.1,
      2.0
    ],
    [
      6.8,
      2.8,
      4.8,
      1.4
    ],
    [
      4.6,
      3.6,
      1.0,
      0.2
    ],
    [
      6.1,
      2.8,
      4.0,
      1.3
    ],
    [
      5.4,
      3.0,
      4.5,
      1.5
    ],
    [
      7.6,
      3.0,
      6.6,
      2.1
    ],
    [
      5.0,
      3.3,
      1.4,
      0.2
    ],
    [
      5.7,
      4.4,
      1.5,
      0.4
    ],
    [
      7.9,
      3.8,
      6.4,
      2.0
    ],
    [
      5.0,
      2.0,
      3.5,
      1.0
    ],
    [
      6.0,
      3.4,
      4.5,
      1.6
    ],
    [
      4.7,
      3.2,
      1.6,
      0.2
    ],
    [
      5.5,
      2.6,
      4.4,
      1.2
    ],
    [
      6.5,
      3.0,
      5.8,
      2.2
    ],
    [
      6.0,
      2.9,
      4.5,
      1.5
    ],
    [
      5.5,
      2.4,
      3.7,
      1.0
    ],
    [
      4.6,
      3.2,
      1.4,
      0.2
    ]
  ]
}"""

# Validate the serving payload works on the model
validate_serving_input(model_uri, serving_payload)

array([0, 0, 2, 2, 1, 1, 1, 0, 2, 0, 2, 1, 2, 0, 2, 2, 2, 0, 1, 2, 1, 1,
       1, 0, 1, 1, 0, 1, 1, 2, 1, 0, 0, 2, 2, 2, 0, 2, 1, 0, 0, 2, 0, 1,
       1, 0, 1, 2, 2, 2, 0, 1, 2, 2, 1, 0, 2, 0, 1, 0, 2, 1, 1, 1, 1, 2,
       1, 2, 1, 0, 0, 0, 1, 2, 0, 0, 2, 2, 1, 2, 1, 0, 1, 0, 0, 0, 0, 0,
       1, 0, 2, 2, 0, 0, 1, 0, 0, 0, 2, 2, 0, 2, 2, 2, 1, 0, 1, 1, 2, 0,
       0, 2, 1, 1, 0, 1, 2, 1, 1, 0])

In [32]:
model_info.model_uri

'runs:/022047dde6374141825695fa8b7f80d8/iris_model'

## Load the model back for prediction as a generic python function model

In [None]:
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)