# End to End ML Project

In [1]:
# import libraries
import mlflow
import pandas as pd
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from mlflow.models import infer_signature

In [2]:
## set the tracking uri
mlflow.set_tracking_uri(uri="http://127.0.0.1:5000")

In [3]:
# load the dataset
X, y = datasets.load_iris(return_X_y=True)
X.shape, y.shape

((150, 4), (150,))

In [4]:
# train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

X_train.shape, X_test.shape, y_train.shape, y_test.shape

((120, 4), (30, 4), (120,), (30,))

In [5]:
# define the model hyperparameter
params = {"penalty":"l2", "solver":"lbfgs", "max_iter":1000, "multi_class":"auto", "random_state":8888}

# initialize the model
model = LogisticRegression(**params)

# fit the model
model.fit(X_train, y_train)



In [6]:
# Prediction on test set
y_pred = model.predict(X_test)

y_pred

array([2, 0, 2, 2, 0, 2, 2, 1, 2, 0, 2, 2, 0, 1, 0, 2, 1, 2, 2, 0, 1, 2,
       1, 1, 0, 2, 0, 1, 2, 2])

In [7]:
accuracy = accuracy_score(y_test, y_pred)
print(accuracy)

0.9666666666666667


## MLFlow Tracking

In [9]:
# Create a new MLFlow Experiment
mlflow.set_experiment("MLFlow Quickstart")

# Start a MLFlow run
with mlflow.start_run():
    # log the hyperparameters
    mlflow.log_params(params)

    # log the accuracy metrics
    mlflow.log_metric("accuracy", accuracy)

    # set a tag that we can use to reming ourselves what this run was for
    mlflow.set_tag("Training Info", "Basic LR Model for Iris data")

    # Infer the model signature    
    signature = infer_signature(X_train, model.predict(X_test)) 

    # log the model
    model_info = mlflow.sklearn.log_model(
        sk_model=model,
        artifact_path="iris_model",
        signature=signature,
        input_example=X_train,
        registered_model_name="tracking-quickstart"
    )



2024/10/26 11:33:07 INFO mlflow.tracking.fluent: Experiment with name 'MLFlow Quickstart' does not exist. Creating a new experiment.


Successfully registered model 'tracking-quickstart'.
2024/10/26 11:33:13 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-quickstart, version 1
Created version '1' of model 'tracking-quickstart'.
2024/10/26 11:33:13 INFO mlflow.tracking._tracking_service.client: 🏃 View run worried-stag-287 at: http://127.0.0.1:5000/#/experiments/775594214204317554/runs/9e77193d95b14a94b4979f3d0105a1cf.
2024/10/26 11:33:13 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/775594214204317554.


## Inferencing and Model Validation

Model inferencing is the process of using a trained model to make predictions on new data.

In [10]:
from mlflow.models import validate_serving_input

model_uri = 'runs:/9e77193d95b14a94b4979f3d0105a1cf/iris_model' # this is also the path stored in model_info.model_uri

# The model is logged with an input example. MLflow converts
# it into the serving payload format for the deployed model endpoint,
# and saves it to 'serving_input_payload.json'
serving_payload = """{
  "inputs": [
    [
      5.4,
      3.9,
      1.7,
      0.4
    ],
    [
      6.3,
      2.5,
      5.0,
      1.9
    ],
    [
      6.0,
      2.9,
      4.5,
      1.5
    ],
    [
      7.1,
      3.0,
      5.9,
      2.1
    ],
    [
      5.8,
      2.7,
      3.9,
      1.2
    ],
    [
      5.1,
      3.5,
      1.4,
      0.3
    ],
    [
      6.3,
      2.5,
      4.9,
      1.5
    ],
    [
      5.2,
      3.5,
      1.5,
      0.2
    ],
    [
      5.5,
      2.3,
      4.0,
      1.3
    ],
    [
      4.9,
      2.4,
      3.3,
      1.0
    ],
    [
      5.1,
      3.8,
      1.5,
      0.3
    ],
    [
      6.8,
      3.0,
      5.5,
      2.1
    ],
    [
      6.3,
      3.3,
      4.7,
      1.6
    ],
    [
      6.2,
      3.4,
      5.4,
      2.3
    ],
    [
      7.7,
      3.0,
      6.1,
      2.3
    ],
    [
      7.7,
      2.8,
      6.7,
      2.0
    ],
    [
      5.0,
      3.2,
      1.2,
      0.2
    ],
    [
      4.6,
      3.1,
      1.5,
      0.2
    ],
    [
      4.8,
      3.0,
      1.4,
      0.1
    ],
    [
      6.4,
      3.2,
      4.5,
      1.5
    ],
    [
      5.5,
      2.6,
      4.4,
      1.2
    ],
    [
      5.4,
      3.9,
      1.3,
      0.4
    ],
    [
      7.3,
      2.9,
      6.3,
      1.8
    ],
    [
      4.6,
      3.6,
      1.0,
      0.2
    ],
    [
      6.3,
      2.7,
      4.9,
      1.8
    ],
    [
      5.2,
      3.4,
      1.4,
      0.2
    ],
    [
      6.4,
      2.8,
      5.6,
      2.1
    ],
    [
      5.5,
      2.4,
      3.8,
      1.1
    ],
    [
      6.3,
      2.8,
      5.1,
      1.5
    ],
    [
      6.3,
      2.3,
      4.4,
      1.3
    ],
    [
      5.7,
      2.5,
      5.0,
      2.0
    ],
    [
      6.2,
      2.2,
      4.5,
      1.5
    ],
    [
      4.9,
      3.1,
      1.5,
      0.1
    ],
    [
      7.2,
      3.6,
      6.1,
      2.5
    ],
    [
      7.0,
      3.2,
      4.7,
      1.4
    ],
    [
      5.1,
      3.8,
      1.9,
      0.4
    ],
    [
      5.4,
      3.7,
      1.5,
      0.2
    ],
    [
      6.9,
      3.1,
      5.1,
      2.3
    ],
    [
      4.4,
      2.9,
      1.4,
      0.2
    ],
    [
      5.6,
      3.0,
      4.1,
      1.3
    ],
    [
      5.0,
      3.3,
      1.4,
      0.2
    ],
    [
      6.5,
      3.2,
      5.1,
      2.0
    ],
    [
      5.0,
      2.0,
      3.5,
      1.0
    ],
    [
      6.4,
      3.1,
      5.5,
      1.8
    ],
    [
      6.7,
      3.3,
      5.7,
      2.5
    ],
    [
      5.1,
      3.4,
      1.5,
      0.2
    ],
    [
      6.4,
      2.9,
      4.3,
      1.3
    ],
    [
      5.7,
      2.6,
      3.5,
      1.0
    ],
    [
      4.8,
      3.4,
      1.6,
      0.2
    ],
    [
      6.1,
      3.0,
      4.6,
      1.4
    ],
    [
      6.2,
      2.8,
      4.8,
      1.8
    ],
    [
      4.8,
      3.1,
      1.6,
      0.2
    ],
    [
      6.8,
      2.8,
      4.8,
      1.4
    ],
    [
      6.0,
      3.4,
      4.5,
      1.6
    ],
    [
      6.6,
      3.0,
      4.4,
      1.4
    ],
    [
      5.7,
      4.4,
      1.5,
      0.4
    ],
    [
      5.1,
      3.5,
      1.4,
      0.2
    ],
    [
      5.1,
      3.7,
      1.5,
      0.4
    ],
    [
      6.7,
      2.5,
      5.8,
      1.8
    ],
    [
      4.9,
      2.5,
      4.5,
      1.7
    ],
    [
      6.4,
      3.2,
      5.3,
      2.3
    ],
    [
      6.5,
      3.0,
      5.5,
      1.8
    ],
    [
      6.6,
      2.9,
      4.6,
      1.3
    ],
    [
      5.8,
      4.0,
      1.2,
      0.2
    ],
    [
      7.7,
      3.8,
      6.7,
      2.2
    ],
    [
      5.6,
      2.5,
      3.9,
      1.1
    ],
    [
      5.8,
      2.7,
      4.1,
      1.0
    ],
    [
      5.6,
      2.7,
      4.2,
      1.3
    ],
    [
      4.8,
      3.0,
      1.4,
      0.3
    ],
    [
      7.4,
      2.8,
      6.1,
      1.9
    ],
    [
      4.8,
      3.4,
      1.9,
      0.2
    ],
    [
      6.0,
      2.2,
      4.0,
      1.0
    ],
    [
      4.6,
      3.4,
      1.4,
      0.3
    ],
    [
      5.3,
      3.7,
      1.5,
      0.2
    ],
    [
      5.0,
      2.3,
      3.3,
      1.0
    ],
    [
      6.5,
      2.8,
      4.6,
      1.5
    ],
    [
      6.0,
      2.7,
      5.1,
      1.6
    ],
    [
      5.1,
      3.3,
      1.7,
      0.5
    ],
    [
      4.4,
      3.2,
      1.3,
      0.2
    ],
    [
      5.4,
      3.0,
      4.5,
      1.5
    ],
    [
      5.8,
      2.6,
      4.0,
      1.2
    ],
    [
      6.1,
      2.6,
      5.6,
      1.4
    ],
    [
      5.6,
      3.0,
      4.5,
      1.5
    ],
    [
      7.7,
      2.6,
      6.9,
      2.3
    ],
    [
      5.4,
      3.4,
      1.5,
      0.4
    ],
    [
      6.4,
      2.8,
      5.6,
      2.2
    ],
    [
      5.7,
      3.0,
      4.2,
      1.2
    ],
    [
      5.8,
      2.7,
      5.1,
      1.9
    ],
    [
      6.0,
      2.2,
      5.0,
      1.5
    ],
    [
      5.7,
      2.9,
      4.2,
      1.3
    ],
    [
      5.0,
      3.4,
      1.6,
      0.4
    ],
    [
      5.5,
      2.5,
      4.0,
      1.3
    ],
    [
      5.0,
      3.6,
      1.4,
      0.2
    ],
    [
      6.4,
      2.7,
      5.3,
      1.9
    ],
    [
      4.7,
      3.2,
      1.6,
      0.2
    ],
    [
      6.1,
      2.8,
      4.7,
      1.2
    ],
    [
      5.9,
      3.0,
      4.2,
      1.5
    ],
    [
      4.9,
      3.1,
      1.5,
      0.2
    ],
    [
      5.1,
      3.8,
      1.6,
      0.2
    ],
    [
      4.3,
      3.0,
      1.1,
      0.1
    ],
    [
      5.6,
      2.8,
      4.9,
      2.0
    ],
    [
      6.7,
      3.0,
      5.0,
      1.7
    ],
    [
      6.5,
      3.0,
      5.2,
      2.0
    ],
    [
      5.0,
      3.5,
      1.3,
      0.3
    ],
    [
      6.9,
      3.2,
      5.7,
      2.3
    ],
    [
      5.2,
      4.1,
      1.5,
      0.1
    ],
    [
      5.5,
      3.5,
      1.3,
      0.2
    ],
    [
      5.7,
      2.8,
      4.1,
      1.3
    ],
    [
      5.5,
      2.4,
      3.7,
      1.0
    ],
    [
      6.3,
      3.4,
      5.6,
      2.4
    ],
    [
      5.6,
      2.9,
      3.6,
      1.3
    ],
    [
      6.1,
      2.8,
      4.0,
      1.3
    ],
    [
      5.0,
      3.4,
      1.5,
      0.2
    ],
    [
      6.7,
      3.1,
      5.6,
      2.4
    ],
    [
      5.0,
      3.0,
      1.6,
      0.2
    ],
    [
      6.9,
      3.1,
      4.9,
      1.5
    ],
    [
      6.7,
      3.3,
      5.7,
      2.1
    ],
    [
      5.4,
      3.4,
      1.7,
      0.2
    ],
    [
      4.9,
      3.6,
      1.4,
      0.1
    ],
    [
      6.9,
      3.1,
      5.4,
      2.1
    ]
  ]
}"""

# Validate the serving payload works on the model
validate_serving_input(model_uri, serving_payload)

array([0, 2, 1, 2, 1, 0, 1, 0, 1, 1, 0, 2, 1, 2, 2, 2, 0, 0, 0, 1, 1, 0,
       2, 0, 2, 0, 2, 1, 2, 1, 2, 1, 0, 2, 1, 0, 0, 2, 0, 1, 0, 2, 1, 2,
       2, 0, 1, 1, 0, 1, 2, 0, 1, 1, 1, 0, 0, 0, 2, 1, 2, 2, 1, 0, 2, 1,
       1, 1, 0, 2, 0, 1, 0, 0, 1, 1, 2, 0, 0, 1, 1, 2, 1, 2, 0, 2, 1, 2,
       2, 1, 0, 1, 0, 2, 0, 1, 1, 0, 0, 0, 2, 1, 2, 0, 2, 0, 0, 1, 1, 2,
       1, 1, 0, 2, 0, 1, 2, 0, 0, 2])

In [11]:
# the path that is used for model inferencing can also be obtained as
model_info.model_uri

'runs:/9e77193d95b14a94b4979f3d0105a1cf/iris_model'

## Load the model back for prediction

There is one more way to inference / load the model as a generic python function and validate the output

In [15]:
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

predictions = loaded_model.predict(X_test)

# load the original dataset
iris_features_name = datasets.load_iris().feature_names

result = pd.DataFrame(X_test, columns=iris_features_name)

result["actual_class"] = y_test
result["predicted_class"] = predictions

result


Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),actual_class,predicted_class
0,6.7,3.0,5.2,2.3,2,2
1,4.4,3.0,1.3,0.2,0,0
2,5.8,2.7,5.1,1.9,2,2
3,6.5,3.0,5.8,2.2,2,2
4,5.0,3.5,1.6,0.6,0,0
5,6.3,2.9,5.6,1.8,2,2
6,6.0,3.0,4.8,1.8,2,2
7,6.2,2.9,4.3,1.3,1,1
8,7.2,3.0,5.8,1.6,2,2
9,4.9,3.0,1.4,0.2,0,0


In [16]:
result[:5]

Unnamed: 0,sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),actual_class,predicted_class
0,6.7,3.0,5.2,2.3,2,2
1,4.4,3.0,1.3,0.2,0,0
2,5.8,2.7,5.1,1.9,2,2
3,6.5,3.0,5.8,2.2,2,2
4,5.0,3.5,1.6,0.6,0,0
