Skip to content

Commit

Permalink
✨ ⬆️ Make kedro-mlflow compatible with pydantic v2 (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Oct 27, 2023
1 parent df985f6 commit 31b47e5
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Added

- :sparkles: Add support for python 3.11 ([#450, rxm7706](https://github.com/Galileo-Galilei/kedro-mlflow/pull/450))
- :sparkles: :arrow_up: Add support for pydantic v2 ([#476](https://github.com/Galileo-Galilei/kedro-mlflow/pull/476))

### Changed

Expand Down
7 changes: 6 additions & 1 deletion kedro_mlflow/framework/hooks/mlflow_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mlflow.models import infer_signature
from mlflow.tracking import MlflowClient
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH
from pydantic import __version__ as pydantic_version

from kedro_mlflow.config.kedro_mlflow_config import KedroMlflowConfig
from kedro_mlflow.framework.hooks.utils import (
Expand Down Expand Up @@ -74,7 +75,11 @@ def after_context_created(
# but we got an empty dict
conf_mlflow_yml = {}

mlflow_config = KedroMlflowConfig.parse_obj({**conf_mlflow_yml})
mlflow_config = (
KedroMlflowConfig.model_validate({**conf_mlflow_yml})
if pydantic_version > "2.0.0"
else KedroMlflowConfig.parse_obj({**conf_mlflow_yml})
)

self._already_active_mlflow = False
if mlflow.active_run():
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
kedro>=0.18.1, <0.19.0
kedro_datasets
mlflow>=1.0.0, <3.0.0
pydantic>=1.0.0, <2.0.0
pydantic>=1.0.0, <3.0.0
8 changes: 4 additions & 4 deletions tests/config/test_get_mlflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def request_headers(self):
credentials=None,
request_header_provider=dict(
type="custom_rhp.CustomRequestHeaderProviderInitKwargs",
init_kwargs=dict(a=1),
init_kwargs=dict(a="a"),
),
),
tracking=dict(
Expand Down Expand Up @@ -389,7 +389,7 @@ def request_headers(self):
mtrr._request_header_provider_registry._registry[-1].__class__.__name__
== "CustomRequestHeaderProviderInitKwargs"
)
assert mtrr._request_header_provider_registry._registry[-1].a == "1"
assert mtrr._request_header_provider_registry._registry[-1].a == "a"
assert not hasattr(mtrr._request_header_provider_registry._registry[-1], "context")


Expand Down Expand Up @@ -423,7 +423,7 @@ def request_headers(self):
request_header_provider=dict(
type="custom_rhp.CustomRequestHeaderProviderInitKwargsKedroContext",
pass_context=True,
init_kwargs=dict(b=2),
init_kwargs=dict(b="b"),
),
),
tracking=dict(
Expand Down Expand Up @@ -452,7 +452,7 @@ def request_headers(self):
mtrr._request_header_provider_registry._registry[-1].__class__.__name__
== "CustomRequestHeaderProviderInitKwargsKedroContext"
)
assert mtrr._request_header_provider_registry._registry[-1].b == "2"
assert mtrr._request_header_provider_registry._registry[-1].b == "b"
assert hasattr(mtrr._request_header_provider_registry._registry[-1], "context")
assert isinstance(
mtrr._request_header_provider_registry._registry[-1].context, KedroContext
Expand Down

0 comments on commit 31b47e5

Please sign in to comment.