Skip to content

Commit

Permalink
✅ Update MlflowArtifactDataset to keep enabling passing an instance o…
Browse files Browse the repository at this point in the history
…f the dataset instead of a string (#391)
  • Loading branch information
Galileo-Galilei committed Nov 19, 2023
1 parent 6a42d45 commit 94ad30e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import shutil
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, Union

import mlflow
from kedro.io import AbstractVersionedDataset
from kedro.io import AbstractDataset, AbstractVersionedDataset
from kedro.io.core import parse_dataset_definition
from mlflow.tracking import MlflowClient

Expand All @@ -20,13 +21,16 @@ def __new__(
artifact_path: str = None,
credentials: Dict[str, Any] = None,
):
dataset, dataset_args = parse_dataset_definition(config=dataset)
if isclass(dataset["type"]) and issubclass(dataset["type"], AbstractDataset):
# parse_dataset_definition needs type to be a string, not the class itself
dataset["type"] = f"{dataset['type'].__module__}.{dataset['type'].__name__}"
dataset_obj, dataset_args = parse_dataset_definition(config=dataset)

# fake inheritance : this mlflow class should be a mother class which wraps
# all dataset (i.e. it should replace AbstractVersionedDataset)
# instead and since we can't modify the core package,
# we create a subclass which inherits dynamically from the dataset class
class MlflowArtifactDatasetChildren(dataset):
class MlflowArtifactDatasetChildren(dataset_obj):
def __init__(self, run_id, artifact_path):
super().__init__(**dataset_args)
self.run_id = run_id
Expand Down Expand Up @@ -134,7 +138,7 @@ def _load(self) -> Any: # pragma: no cover
return super()._load()

# rename the class
parent_name = dataset.__name__
parent_name = dataset_obj.__name__
MlflowArtifactDatasetChildren.__name__ = f"Mlflow{parent_name}"
MlflowArtifactDatasetChildren.__qualname__ = (
f"{parent_name}.Mlflow{parent_name}"
Expand Down
2 changes: 1 addition & 1 deletion tests/io/artifacts/test_mlflow_artifact_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import mlflow
import pandas as pd
import pytest
from kedro.io import PartitionedDataset
from kedro_datasets.pandas import CSVDataset
from kedro_datasets.partitions import PartitionedDataset
from kedro_datasets.pickle import PickleDataset
from pytest_lazyfixture import lazy_fixture

Expand Down

0 comments on commit 94ad30e

Please sign in to comment.