Skip to content

Commit

Permalink
Add a mixin to make the task skippable (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed Dec 20, 2021
1 parent 0901ed5 commit 080b954
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 0 deletions.
63 changes: 63 additions & 0 deletions data_validation_framework/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
import traceback
import warnings
from functools import partial
from pathlib import Path

import luigi
Expand All @@ -21,13 +22,15 @@
from numpy import VisibleDeprecationWarning

from data_validation_framework.report import make_report
from data_validation_framework.result import ValidationResult
from data_validation_framework.result import ValidationResultSet
from data_validation_framework.target import ReportTarget
from data_validation_framework.target import TaggedOutputLocalTarget
from data_validation_framework.util import apply_to_df

L = logging.getLogger(__name__)
INDEX_LABEL = "__index_label__"
SKIP_COMMENT = "Skipped by user."


class ValidationError(Exception):
Expand Down Expand Up @@ -660,3 +663,63 @@ def validation_function(*args, **kwargs):
This method should usually do nothing for :class:`ValidationWorkflow` as this class is only
supposed to gather validation steps.
"""


def _skippable_element_validation_function(validation_function, skip, *args, **kwargs):
"""Skipping wrapper for an element validation function."""
if skip:
return ValidationResult(is_valid=True, comment=SKIP_COMMENT)
return validation_function(*args, **kwargs)


def _skippable_set_validation_function(validation_function, skip, *args, **kwargs):
"""Skipping wrapper for a set validation function."""
df = kwargs.get("df", args[0])
if skip:
df.loc[df["is_valid"], "comment"] = SKIP_COMMENT
else:
validation_function(*args, **kwargs)


def SkippableMixin(default_value=False):
"""Create a mixin class to add a ``skip`` parameter.
This mixin must be applied to a :class:`data_validation_framework.ElementValidationTask`.
It will create a ``skip`` parameter and wrap the validation function to just skip it if the
``skip`` argument is set to ``True``. If skipped, it will keep the ``is_valid`` values as is and
add a specific comment to inform the user.
Args:
default_value (bool): The default value for the ``skip`` argument.
"""

class Mixin:
"""A mixin to add a ``skip`` parameter to a :class:`luigi.task`."""

skip = BoolParameter(default=default_value, description=":bool: Skip the task")

def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)

if isinstance(self, ElementValidationTask):
new_validation_function = partial(
_skippable_element_validation_function,
self.validation_function,
self.skip,
)
elif isinstance(self, SetValidationTask) and not isinstance(self, ValidationWorkflow):
new_validation_function = partial(
_skippable_set_validation_function,
self.validation_function,
self.skip,
)
else:
raise TypeError(
"The SkippableMixin can only be associated with childs of ElementValidationTask"
" or SetValidationTask"
)
self._skippable_validation_function = self.validation_function
self.validation_function = new_validation_function

return Mixin
154 changes: 154 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,3 +1568,157 @@ def test_nested_workflows(
data_dir / "test_report_before_run" / "report_rst2pdf_nested.pdf",
threshold=25,
)


class TestSkippableMixin:
"""Test the data_validation_framework.task.SkippableMixin class."""

def test_fail_parent_type(self):
err_msg = (
"The SkippableMixin can only be associated with childs of ElementValidationTask"
" or SetValidationTask"
)

class TestTask1(task.SkippableMixin(), luigi.Task):
pass

with pytest.raises(
TypeError,
match=err_msg,
):
TestTask1()

class TestTask2(task.SkippableMixin(), task.ValidationWorkflow):
pass

with pytest.raises(
TypeError,
match=err_msg,
):
TestTask2()

def test_skip_element_task(self, dataset_df_path, tmpdir):
class TestSkippableTask(task.SkippableMixin(), task.ElementValidationTask):
@staticmethod
# pylint: disable=arguments-differ
def validation_function(row, output_path, *args, **kwargs):
if row["a"] <= 1:
return result.ValidationResult(is_valid=True)
if row["a"] <= 2:
return result.ValidationResult(is_valid=False, comment="bad value")
raise ValueError(f"Incorrect value {row['a']}")

# Test with no given skip value (should be False by default)
assert luigi.build(
[
TestSkippableTask(
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_default")
)
],
local_scheduler=True,
)

report_data = pd.read_csv(tmpdir / "out_default" / "TestSkippableTask" / "report.csv")
assert (report_data["is_valid"] == [True, False]).all()
assert (report_data["comment"].isnull() == [True, False]).all()
assert report_data.loc[1, "comment"] == "bad value"
assert report_data["exception"].isnull().all()

# Test with no given skip value (should be False by default)
assert luigi.build(
[
TestSkippableTask(
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_no_skip"), skip=False
)
],
local_scheduler=True,
)

report_data = pd.read_csv(tmpdir / "out_no_skip" / "TestSkippableTask" / "report.csv")
assert (report_data["is_valid"] == [True, False]).all()
assert (report_data["comment"].isnull() == [True, False]).all()
assert report_data.loc[1, "comment"] == "bad value"
assert report_data["exception"].isnull().all()

# Test with no given skip value (should be False by default)
assert luigi.build(
[
TestSkippableTask(
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_skip"), skip=True
)
],
local_scheduler=True,
)

report_data = pd.read_csv(tmpdir / "out_skip" / "TestSkippableTask" / "report.csv")
assert (
report_data["is_valid"] == True # noqa ; pylint: disable=singleton-comparison
).all()
assert (report_data["comment"] == "Skipped by user.").all()
assert report_data["exception"].isnull().all()

def test_skip_set_task(self, dataset_df_path, tmpdir):
class TestSkippableTask(task.SkippableMixin(), task.SetValidationTask):
@staticmethod
def validation_function(df, output_path, *args, **kwargs):
# pylint: disable=no-member
df["a"] *= 10
df.loc[1, "is_valid"] = False
df.loc[1, "ret_code"] = 1
df[["a", "b"]].to_csv(output_path / "test.csv")

# Test with no given skip value (should be False by default)
assert luigi.build(
[
TestSkippableTask(
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_default")
)
],
local_scheduler=True,
)

res = pd.read_csv(tmpdir / "out_default" / "TestSkippableTask" / "data" / "test.csv")
expected = pd.read_csv(tmpdir / "dataset.csv")
expected["a"] *= 10
assert res.equals(expected)
report_data = pd.read_csv(tmpdir / "out_default" / "TestSkippableTask" / "report.csv")
assert (report_data["is_valid"] == [True, False]).all()
assert report_data["comment"].isnull().all()
assert report_data["exception"].isnull().all()

# Test with skip = False
assert luigi.build(
[
TestSkippableTask(
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_no_skip"), skip=False
)
],
local_scheduler=True,
)

res = pd.read_csv(tmpdir / "out_no_skip" / "TestSkippableTask" / "data" / "test.csv")
expected = pd.read_csv(tmpdir / "dataset.csv")
expected["a"] *= 10
assert res.equals(expected)
report_data = pd.read_csv(tmpdir / "out_no_skip" / "TestSkippableTask" / "report.csv")
assert (report_data["is_valid"] == [True, False]).all()
assert report_data["comment"].isnull().all()
assert report_data["exception"].isnull().all()

# Test with skip = True
assert luigi.build(
[
TestSkippableTask(
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_skip"), skip=True
)
],
local_scheduler=True,
)

assert not (tmpdir / "out_skip" / "TestSkippableTask" / "data" / "test.csv").exists()
report_data = pd.read_csv(tmpdir / "out_skip" / "TestSkippableTask" / "report.csv")
assert (
report_data["is_valid"] == True # noqa ; pylint: disable=singleton-comparison
).all()
assert (report_data["comment"] == "Skipped by user.").all()
assert report_data["exception"].isnull().all()

0 comments on commit 080b954

Please sign in to comment.