diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index fd5264e90e..c8ab67e565 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -11,12 +11,13 @@ Future Release * Documentation Changes * Minor fix to release notes (:pr:`2444`) * Testing Changes + * Add test that checks for Natural Language primitives timing out against edge-case input (:pr:`2429`) * Fix test compatibility with composeml 0.10 (:pr:`2439`) * Minimum dependency unit test jobs do not abort if one job fails (:pr:`2437`) * Run Looking Glass performance tests on merge to main (:pr:`2440`, :pr:`2441`) Thanks to the following people for contributing to this release: - :user:`gsheni`, :user:`rwedge`, :user: `thehomebrewnerd` + :user:`gsheni`, :user:`rwedge`, :user:`sbadithe`, :user:`thehomebrewnerd` v1.20.0 Jan 5, 2023 =================== diff --git a/featuretools/primitives/utils.py b/featuretools/primitives/utils.py index 8c7bdc0078..9cccf6b714 100644 --- a/featuretools/primitives/utils.py +++ b/featuretools/primitives/utils.py @@ -6,6 +6,7 @@ import pandas as pd from woodwork import list_logical_types, list_semantic_tags from woodwork.column_schema import ColumnSchema +from woodwork.logical_types import NaturalLanguage import featuretools from featuretools.primitives import NumberOfCommonWords @@ -17,28 +18,57 @@ from featuretools.utils.gen_utils import Library, find_descendents -# returns all aggregation primitives, regardless of compatibility -def get_aggregation_primitives(): - aggregation_primitives = set([]) +def _get_primitives(primitive_kind): + """Helper function that selects all primitives + that are instances of `primitive_kind` + """ + primitives = set() for attribute_string in dir(featuretools.primitives): attribute = getattr(featuretools.primitives, attribute_string) if isclass(attribute): - if issubclass(attribute, featuretools.primitives.AggregationPrimitive): - if attribute.name: - aggregation_primitives.add(attribute) - return {prim.name.lower(): prim for prim in aggregation_primitives} + if issubclass(attribute, primitive_kind) and attribute.name: + primitives.add(attribute) + return {prim.name.lower(): prim for prim in primitives} + + +def get_aggregation_primitives(): + """Returns all aggregation primitives, regardless + of compatibility + """ + return _get_primitives(featuretools.primitives.AggregationPrimitive) -# returns all transform primitives, regardless of compatibility def get_transform_primitives(): - transform_primitives = set([]) - for attribute_string in dir(featuretools.primitives): - attribute = getattr(featuretools.primitives, attribute_string) - if isclass(attribute): - if issubclass(attribute, featuretools.primitives.TransformPrimitive): - if attribute.name: - transform_primitives.add(attribute) - return {prim.name.lower(): prim for prim in transform_primitives} + """Returns all transform primitives, regardless + of compatibility + """ + return _get_primitives(featuretools.primitives.TransformPrimitive) + + +def _get_natural_language_primitives(): + """Returns all Natural Language transform primitives, + regardless of compatibility + """ + transform_primitives = get_transform_primitives() + + def _natural_language_in_input_type(primitive): + for input_type in primitive.input_types: + if isinstance(input_type, list): + if any( + isinstance(column_schema.logical_type, NaturalLanguage) + for column_schema in input_type + ): + return True + else: + if isinstance(input_type.logical_type, NaturalLanguage): + return True + return False + + return { + name: primitive + for name, primitive in transform_primitives.items() + if _natural_language_in_input_type(primitive) + } def list_primitives(): diff --git a/featuretools/tests/conftest.py b/featuretools/tests/conftest.py index 1e7f9a5c67..f29fd752e5 100644 --- a/featuretools/tests/conftest.py +++ b/featuretools/tests/conftest.py @@ -864,3 +864,11 @@ class TestTransform(TransformPrimitive): stack_on = [] return TestTransform + + +@pytest.fixture +def strings_that_have_triggered_errors_before(): + return [ + " ", + '"This Borderlands game here"" is the perfect conclusion to the ""Borderlands 3"" line, which focuses on the fans ""favorite character and gives the players the opportunity to close for a long time some very important questions about\'s character and the memorable scenery with which the players interact.', + ] diff --git a/featuretools/tests/primitive_tests/natural_language_primitives_tests/test_natural_language_primitives_terminate.py b/featuretools/tests/primitive_tests/natural_language_primitives_tests/test_natural_language_primitives_terminate.py new file mode 100644 index 0000000000..b667ff556e --- /dev/null +++ b/featuretools/tests/primitive_tests/natural_language_primitives_tests/test_natural_language_primitives_terminate.py @@ -0,0 +1,22 @@ +import pandas as pd +import pytest + +from featuretools.primitives.utils import _get_natural_language_primitives + +TIMEOUT_THRESHOLD = 20 + + +class TestNaturalLanguagePrimitivesTerminate: + + # need to sort primitives to avoid pytest collection error + primitives = sorted(_get_natural_language_primitives().items()) + + @pytest.mark.timeout(TIMEOUT_THRESHOLD) + @pytest.mark.parametrize("primitive", [prim for _, prim in primitives]) + def test_natlang_primitive_does_not_timeout( + self, + strings_that_have_triggered_errors_before, + primitive, + ): + for text in strings_that_have_triggered_errors_before: + primitive().get_function()(pd.Series(text)) diff --git a/pyproject.toml b/pyproject.toml index 9876734c77..5d27f7689a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ test = [ "pytest-xdist >= 2.5.0", "smart-open >= 5.0.0", "urllib3 >= 1.26.5", + "pytest-timeout >= 2.1.0" ] spark = [ "woodwork[spark] >= 0.18.0",