In [1]:
import shutil

from joblib import dump
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

import mlflow
from mlflow.models import infer_signature
from mlflow.pyfunc import PythonModel

In [3]:
def get_or_create_experiment(experiment_name):
    """Return the ID of the experiment called ``experiment_name``, creating it if needed.

    The experiment is looked up by name with ``mlflow.get_experiment_by_name``.
    If that lookup returns a falsy value, a new experiment is created with
    ``mlflow.create_experiment`` and whatever that call returns is handed back.

    Args:
        experiment_name: Name of the MLflow experiment to find or create.

    Returns:
        The ``experiment_id`` of the existing experiment, or the return value
        of ``mlflow.create_experiment`` for a newly created one.
    """
    existing = mlflow.get_experiment_by_name(experiment_name)
    if not existing:
        return mlflow.create_experiment(experiment_name)
    return existing.experiment_id

In [4]:
# Connection settings for this notebook, kept in one place so they are easy to change.
TRACKING_URI = "http://127.0.0.1:5000"
EXPERIMENT_NAME = "Custom Predict Test"

# Point the client at the tracking server, then activate the experiment
# (looked up by name, created on first use).
mlflow.set_tracking_uri(TRACKING_URI)
experiment_id = get_or_create_experiment(EXPERIMENT_NAME)
mlflow.set_experiment(experiment_id=experiment_id)

<Experiment: artifact_location='mlflow-artifacts:/320722760442736137', creation_time=1729732474296, experiment_id='320722760442736137', last_update_time=1729732474296, lifecycle_stage='active', name='Custom Predict Test', tags={}>

### Load Data and Train the Model

In [9]:
# Use only the feature columns from index 2 onward as model inputs.
iris = load_iris()
x, y = iris.data[:, 2:], iris.target

# Hold out 20% of the rows for testing; the fixed seed keeps the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    test_size=0.2,
    random_state=9001,
)

# Build the classifier first, then fit it (fit returns the estimator itself).
model = LogisticRegression(solver="newton-cg", max_iter=5_000, random_state=0)
model = model.fit(x_train, y_train)

In [10]:
# Predicted class labels for the first five test rows.
model.predict(x_test)[:5]

array([1, 2, 2, 1, 0])

In [11]:
# Per-class probabilities for the same five test rows (one column per class).
model.predict_proba(x_test)[:5]

array([[2.63986038e-03, 6.62183383e-01, 3.35176756e-01],
       [1.24644993e-04, 8.36504104e-02, 9.16224945e-01],
       [1.30873544e-04, 1.37649920e-01, 8.62219207e-01],
       [3.70883962e-03, 7.13033599e-01, 2.83257562e-01],
       [9.82596897e-01, 1.74030230e-02, 7.99199819e-08]])

In [12]:
# Log-probabilities for the same five test rows.
model.predict_log_proba(x_test)[:5]

array([[ -5.93702925,  -0.41221275,  -1.09309726],
       [ -8.99004092,  -2.48110895,  -0.08749337],
       [ -8.94127902,  -1.98304163,  -0.14824574],
       [ -5.59703622,  -0.33822674,  -1.26139868],
       [ -0.01755632,  -4.05111135, -16.34223993]])

### MLflow Setup

In [13]:
sklearn_path = "/tmp/sklearn_model"

# save_model refuses to write into an existing, non-empty directory, so remove
# any copy left behind by a previous execution. This keeps the cell re-runnable
# (Restart Kernel -> Run All); ignore_errors covers the first run, when the
# directory does not exist yet.
shutil.rmtree(sklearn_path, ignore_errors=True)

with mlflow.start_run() as run:
    # Writes the model to the local path above; the first two training rows
    # are stored as the input example.
    mlflow.sklearn.save_model(
        sk_model=model,
        path=sklearn_path,
        input_example=x_train[:2],
    )

2024/10/24 09:19:31 INFO mlflow.tracking._tracking_service.client: 🏃 View run colorful-skink-636 at: http://127.0.0.1:5000/#/experiments/320722760442736137/runs/dde24f7275f843a98ad7d7f850bbead9.
2024/10/24 09:19:31 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/320722760442736137.


In [14]:
# Reload the saved scikit-learn model through the generic pyfunc interface.
loaded_logreg_model = mlflow.pyfunc.load_model(model_uri=sklearn_path)

In [15]:
# The pyfunc-loaded model is called with no params here; the output below is class labels.
loaded_logreg_model.predict(x_test)

array([1, 2, 2, 1, 0, 1, 2, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 2, 1,
       1, 0, 1, 1, 0, 0, 1, 2])

### Customizing the Behavior of the Model

In [17]:
# Serialize the fitted estimator with joblib.
# NOTE: despite its name, model_directory holds the path of a single file.
model_directory = "/tmp/sklearn_model.joblib"
dump(value=model, filename=model_directory)

['/tmp/sklearn_model.joblib']

In [18]:
# Define Custom Python Model
class ModelWrapper(PythonModel):
    """Custom pyfunc model whose prediction method is chosen per call.

    The wrapped estimator is loaded from the ``"model_path"`` artifact in
    ``load_context``. ``predict`` then forwards to one of the estimator's
    ``predict``, ``predict_proba`` or ``predict_log_proba`` methods, selected
    through ``params["predict_method"]``.
    """

    # Names of the estimator methods a caller is allowed to select.
    SUPPORTED_METHODS = ("predict", "predict_proba", "predict_log_proba")
    # Method used when the caller does not specify one.
    DEFAULT_METHOD = "predict"

    def __init__(self):
        # Populated by load_context; None until the artifact has been loaded.
        self.model = None

    def load_context(self, context):
        """Load the joblib-serialized estimator referenced by the model context.

        Args:
            context: Object whose ``artifacts`` mapping holds the path of the
                serialized estimator under the key ``"model_path"``.
        """
        # Imported locally rather than relying on notebook globals -- presumably
        # so the wrapper stays self-contained when it is unpickled elsewhere.
        from joblib import load

        self.model = load(context.artifacts["model_path"])

    def predict(self, context, model_input, params=None):
        """Run the selected prediction method of the wrapped estimator.

        Args:
            context: Model context (unused by this method).
            model_input: Data passed unchanged to the estimator method.
            params: Optional mapping. ``params["predict_method"]`` names the
                estimator method to call. When ``params`` is empty/``None`` or
                lacks that key, ``"predict"`` is used.

        Returns:
            Whatever the selected estimator method returns for ``model_input``.

        Raises:
            ValueError: If the requested method is not one of
                ``SUPPORTED_METHODS``.
        """
        # Supplying the default in .get() means a params dict that carries
        # other keys but no "predict_method" falls back to "predict" instead
        # of failing with "'None' is not supported".
        predict_method = (params or {}).get("predict_method", self.DEFAULT_METHOD)

        # Whitelist check before getattr so arbitrary attribute names cannot
        # be invoked on the estimator.
        if predict_method not in self.SUPPORTED_METHODS:
            raise ValueError(f"The prediction method '{predict_method}' is not supported.")

        return getattr(self.model, predict_method)(model_input)

In [19]:
# Artifacts saved with the custom pyfunc; ModelWrapper.load_context reads the
# "model_path" entry to locate the serialized estimator.
artifacts = dict(model_path=model_directory)

# Signature inferred from the training features. Declaring the "predict_method"
# param here registers "predict_proba" as its default value.
signature = infer_signature(
    model_input=x_train,
    params={"predict_method": "predict_proba"},
)


In [20]:
# Display the inferred signature: inputs, outputs and the declared params with their defaults.
signature

inputs: 
  [Tensor('float64', (-1, 2))]
outputs: 
  None
params: 
  ['predict_method': string (default: predict_proba)]

In [21]:
pyfunc_path = "/tmp/dynamic_regressor"

# save_model refuses to write into an existing, non-empty directory, so remove
# any copy left behind by a previous execution. This keeps the cell re-runnable
# (Restart Kernel -> Run All); ignore_errors covers the first run, when the
# directory does not exist yet.
shutil.rmtree(pyfunc_path, ignore_errors=True)

with mlflow.start_run() as run:
    mlflow.pyfunc.save_model(
        path=pyfunc_path,
        python_model=ModelWrapper(),
        # A couple of rows are enough for an example; this matches the slice
        # used when saving the plain scikit-learn model.
        input_example=x_train[:2],
        signature=signature,
        artifacts=artifacts,
        # The PyPI distribution is "scikit-learn"; "sklearn" is only the import
        # name, and the PyPI package of that name is a deprecated stub that
        # fails to install.
        pip_requirements=["joblib", "scikit-learn"],
    )


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

2024/10/24 09:31:06 INFO mlflow.tracking._tracking_service.client: 🏃 View run capable-mole-842 at: http://127.0.0.1:5000/#/experiments/320722760442736137/runs/f141378be7754b4f8ba305b330d11203.
2024/10/24 09:31:06 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/320722760442736137.


In [22]:
# Reload the custom pyfunc model from the local directory it was saved to.
loaded_dynamic = mlflow.pyfunc.load_model(model_uri=pyfunc_path)

In [23]:
# No params are passed here. As the output below shows, the result is class
# probabilities: the signature's default predict_method ("predict_proba") is
# applied, not the wrapper's own "predict" fallback.
loaded_dynamic.predict(x_test)

array([[2.63986038e-03, 6.62183383e-01, 3.35176756e-01],
       [1.24644993e-04, 8.36504104e-02, 9.16224945e-01],
       [1.30873544e-04, 1.37649920e-01, 8.62219207e-01],
       [3.70883962e-03, 7.13033599e-01, 2.83257562e-01],
       [9.82596897e-01, 1.74030230e-02, 7.99199819e-08],
       [6.53971812e-03, 7.53980623e-01, 2.39479659e-01],
       [2.30147058e-06, 1.29718438e-02, 9.87025855e-01],
       [9.71313215e-01, 2.86866206e-02, 1.64768103e-07],
       [3.36795311e-01, 6.61251371e-01, 1.95331849e-03],
       [9.81875377e-01, 1.81244830e-02, 1.40238274e-07],
       [9.70731372e-01, 2.92684102e-02, 2.18211689e-07],
       [6.53971812e-03, 7.53980623e-01, 2.39479659e-01],
       [1.06934443e-02, 8.88082803e-01, 1.01223753e-01],
       [3.35035729e-03, 6.57571588e-01, 3.39078055e-01],
       [9.82239738e-01, 1.77601562e-02, 1.05867483e-07],
       [9.82596897e-01, 1.74030230e-02, 7.99199819e-08],
       [1.62646625e-03, 5.43435159e-01, 4.54938375e-01],
       [9.82596897e-01, 1.74030

In [24]:
# Override the default for this call to get class labels instead.
loaded_dynamic.predict(x_test, params={"predict_method": "predict"})

array([1, 2, 2, 1, 0, 1, 2, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 2, 1,
       1, 0, 1, 1, 0, 0, 1, 2])

In [25]:
# Same per-call override, this time selecting log-probabilities.
loaded_dynamic.predict(x_test, params={"predict_method": "predict_log_proba"})

array([[-5.93702925e+00, -4.12212748e-01, -1.09309726e+00],
       [-8.99004092e+00, -2.48110895e+00, -8.74933717e-02],
       [-8.94127902e+00, -1.98304163e+00, -1.48245740e-01],
       [-5.59703622e+00, -3.38226736e-01, -1.26139868e+00],
       [-1.75563172e-02, -4.05111135e+00, -1.63422399e+01],
       [-5.02986122e+00, -2.82388610e-01, -1.42928680e+00],
       [-1.29819623e+01, -4.34497414e+00, -1.30590446e-02],
       [-2.91062935e-02, -3.55132445e+00, -1.56187268e+01],
       [-1.08827992e+00, -4.13621222e-01, -6.23822556e+00],
       [-1.82908862e-02, -4.01049160e+00, -1.57799229e+01],
       [-2.97055002e-02, -3.53124650e+00, -1.53378002e+01],
       [-5.02986122e+00, -2.82388610e-01, -1.42928680e+00],
       [-4.53812441e+00, -1.18690294e-01, -2.29042184e+00],
       [-5.69868828e+00, -4.19201642e-01, -1.08152495e+00],
       [-1.79198682e-02, -4.03079774e+00, -1.60610777e+01],
       [-1.75563172e-02, -4.05111135e+00, -1.63422399e+01],
       [-6.42134556e+00, -6.09844882e-01