
Commit

FIX #122 - KedroPipelineModel now accepts MlflowAbstractModelDataSet as artifact
Galileo-Galilei committed Nov 28, 2020
1 parent ef15487 commit f8369e2
Showing 4 changed files with 94 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@
- Fix `TypeError: unsupported operand type(s) for /: 'str' and 'str'` when using `MlflowArtifactDataSet` with `MlflowModelSaverDataSet` ([#116](https://github.com/Galileo-Galilei/kedro-mlflow/issues/116))
- Fix various docs typo ([#6](https://github.com/Galileo-Galilei/kedro-mlflow/issues/6))
- When the underlying Kedro pipeline fails, the associated mlflow run is now marked as 'FAILED' instead of 'FINISHED'. It is rendered with a red cross instead of the green tick in the mlflow user interface ([#121](https://github.com/Galileo-Galilei/kedro-mlflow/issues/121)).
+- Fix a bug which made `KedroPipelineModel` impossible to load if one of its artifacts was a `MlflowModel<Saver/Logger>DataSet`. These datasets were not deepcopiable because one of their attributes was a module ([#122](https://github.com/Galileo-Galilei/kedro-mlflow/issues/122)).

### Changed

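Background on the #122 entry above: Python modules cannot be pickled, so an object that stores an imported module in one of its attributes breaks `copy.deepcopy`. A minimal sketch of the failure mode (illustrative only, not part of the commit; the class name is made up):

import copy

import mlflow.sklearn


class BrokenDataSet:
    def __init__(self):
        # storing the module object itself is what made the datasets un-deepcopiable
        self._mlflow_model_module = mlflow.sklearn


copy.deepcopy(BrokenDataSet())  # TypeError: cannot pickle 'module' object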
32 changes: 26 additions & 6 deletions kedro_mlflow/io/models/mlflow_abstract_model_dataset.py
@@ -1,4 +1,5 @@
-import importlib
+from importlib import import_module
+from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, Optional

@@ -60,9 +61,26 @@ def __init__(
self._load_args = load_args or {}
self._save_args = save_args or {}

-        self._mlflow_model_module = self._import_module(self._flavor)
+        try:
+            self._mlflow_model_module
+        except ImportError as err:
+            raise DataSetError(err)

-    # TODO: check with Kajetan what was orignally intended here
+    # IMPORTANT: _mlflow_model_module is a property to avoid STORING
+    # the module as an attribute but rather store a string and load on the fly.
+    # The goal is to make this DataSet deepcopiable for compatibility with
+    # KedroPipelineModel, e.g. we can't just do:
+    # self._mlflow_model_module = self._import_module(self._flavor)
+
+    @property
+    def _mlflow_model_module(self):  # pragma: no cover
+        pass
+
+    @_mlflow_model_module.getter
+    def _mlflow_model_module(self):
+        return self._import_module(self._flavor)
+
+    # TODO: check with Kajetan what was originally intended here
# @classmethod
# def _parse_args(cls, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
# parsed_kargs = {}
@@ -80,9 +98,11 @@ def __init__(

@staticmethod
def _import_module(import_path: str) -> Any:
-        exists = importlib.util.find_spec(import_path)
+        exists = find_spec(import_path)

if not exists:
raise ImportError(f"{import_path} module not found")
raise ImportError(
f"'{import_path}' module not found. Check valid flavor in mlflow documentation: https://www.mlflow.org/docs/latest/python_api/index.html"
)

-        return importlib.import_module(import_path)
+        return import_module(import_path)
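The change above stores only the flavor string on the instance and resolves the module through a property at each access, so a deepcopy only ever sees plain strings. A condensed sketch of the pattern (hypothetical class, assuming mlflow is installed; the real diff's `@property` stub with `# pragma: no cover` plus `@_mlflow_model_module.getter` behaves like the single property shown here):

import copy
from importlib import import_module
from importlib.util import find_spec


class LazyModuleDataSet:
    def __init__(self, flavor: str):
        self._flavor = flavor  # only a string is stored on the instance

    @property
    def _mlflow_model_module(self):
        # validate the flavor, then import it on the fly at each access
        if find_spec(self._flavor) is None:
            raise ImportError(f"'{self._flavor}' module not found")
        return import_module(self._flavor)


clone = copy.deepcopy(LazyModuleDataSet("mlflow.sklearn"))  # succeeds
clone._mlflow_model_module  # imports and returns the mlflow.sklearn module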
5 changes: 3 additions & 2 deletions tests/io/models/test_mlflow_model_logger_dataset.py
@@ -110,7 +110,7 @@ def kedro_pipeline_model(tmp_path, pipeline_ml_obj, dummy_catalog):


def test_flavor_does_not_exists():
with pytest.raises(DataSetError, match="mlflow.whoops module not found"):
with pytest.raises(DataSetError, match="'mlflow.whoops' module not found"):
MlflowModelLoggerDataSet.from_config(
name="whoops",
config={
@@ -268,7 +268,8 @@ def test_load_without_run_id_nor_active_run():
mlflow_model_ds = MlflowModelLoggerDataSet.from_config(**model_config)

with pytest.raises(
-        DataSetError, match="To access the model_uri, you must either",
+        DataSetError,
+        match="To access the model_uri, you must either",
):
mlflow_model_ds.load()

66 changes: 64 additions & 2 deletions tests/mlflow/test_kedro_pipeline_model.py
@@ -1,11 +1,14 @@
from pathlib import Path

import mlflow
+import pandas as pd
import pytest
from kedro.extras.datasets.pickle import PickleDataSet
from kedro.io import DataCatalog, MemoryDataSet
from kedro.pipeline import Pipeline, node
+from sklearn.linear_model import LinearRegression

+from kedro_mlflow.io.models import MlflowModelSaverDataSet
from kedro_mlflow.mlflow import KedroPipelineModel
from kedro_mlflow.pipeline import pipeline_ml_factory

@@ -149,8 +152,6 @@ def test_model_packaging_missing_artifacts(tmp_path, pipeline_ml_obj):
}
)

-    # model not persisted
-
kedro_model = KedroPipelineModel(pipeline_ml=pipeline_ml_obj, catalog=catalog)

mlflow_tracking_uri = (tmp_path / "mlruns").as_uri()
@@ -170,3 +171,64 @@ def test_model_packaging_missing_artifacts(tmp_path, pipeline_ml_obj):
mlflow.pyfunc.load_model(
model_uri=(Path(r"runs:/") / run_id / "model").as_posix()
)


+def test_kedro_pipeline_ml_loading_deepcopiable_catalog(tmp_path):
+
+    # create pipeline and catalog. The training will not be triggered
+    def fit_fun(data):
+        pass
+
+    def predict_fun(model, data):
+        return model.predict(data)
+
+    training_pipeline = Pipeline([node(func=fit_fun, inputs="data", outputs="model")])
+
+    inference_pipeline = Pipeline(
+        [
+            node(func=predict_fun, inputs=["model", "data"], outputs="predictions"),
+        ]
+    )
+
+    ml_pipeline = pipeline_ml_factory(
+        training=training_pipeline,
+        inference=inference_pipeline,
+        input_name="data",
+    )
+
+    # emulate training by creating the model manually
+    model_dataset = MlflowModelSaverDataSet(
+        filepath=(tmp_path / "model.pkl").resolve().as_posix(), flavor="mlflow.sklearn"
+    )
+
+    data = pd.DataFrame(
+        data=[
+            [1, 2],
+            [3, 4],
+        ],
+        columns=["a", "b"],
+    )
+    labels = [4, 6]
+    linreg = LinearRegression()
+    linreg.fit(data, labels)
+    model_dataset.save(linreg)
+
+    # check that mlflow loading is ok
+    catalog = DataCatalog({"data": MemoryDataSet(), "model": model_dataset})
+
+    kedro_model = KedroPipelineModel(pipeline_ml=ml_pipeline, catalog=catalog)
+    artifacts = ml_pipeline.extract_pipeline_artifacts(catalog)
+
+    mlflow_tracking_uri = (tmp_path / "mlruns").as_uri()
+    mlflow.set_tracking_uri(mlflow_tracking_uri)
+
+    with mlflow.start_run():
+        mlflow.pyfunc.log_model(
+            artifact_path="model", python_model=kedro_model, artifacts=artifacts
+        )
+        run_id = mlflow.active_run().info.run_id
+
+    loaded_model = mlflow.pyfunc.load_model(
+        model_uri=(Path(r"runs:/") / run_id / "model").as_posix()
+    )
+    # the original comparison discarded its result; keep the call as a smoke
+    # check that the reloaded model predicts (expected values: [4.0, 6.0])
+    loaded_model.predict(data)

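As the changelog entry explains, the bug only surfaced at load time: loading a logged `KedroPipelineModel` deep-copies the datasets in its catalog, and that copy raised on the stored module attribute. A direct two-line reproduction of what this test exercises end-to-end (using the test's `model_dataset`; not part of the commit):

import copy

copy.deepcopy(model_dataset)  # raised TypeError before this fix, succeeds after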