diff --git a/.gitmodules b/.gitmodules index 1f6cd95..1d804ea 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "tests/engine_tests/engine-test-data"] path = tests/engine_tests/engine-test-data url = https://github.com/flagsmith/engine-test-data.git - tag = v3.5.0 + tag = v3.6.0 diff --git a/flag_engine/segments/evaluator.py b/flag_engine/segments/evaluator.py index 78cd44d..0b8f8fc 100644 --- a/flag_engine/segments/evaluator.py +++ b/flag_engine/segments/evaluator.py @@ -279,23 +279,9 @@ def context_matches_condition( context_value = get_context_value(context, condition_property) if condition_operator == constants.IN: - if isinstance(segment_value := condition["value"], list): - in_values = segment_value - else: - try: - in_values = json.loads(segment_value) - # Only accept JSON lists. - # Ideally, we should use something like pydantic.TypeAdapter[list[str]], - # but we aim to ditch the pydantic dependency in the future. - if not isinstance(in_values, list): - raise ValueError - except ValueError: - in_values = segment_value.split(",") - in_values = [str(value) for value in in_values] + in_values = _get_in_values(condition["value"]) # Guard against comparing boolean values to numeric strings. - if isinstance(context_value, int) and not ( - context_value is True or context_value is False - ): + if type(context_value) is int: context_value = str(context_value) return context_value in in_values @@ -348,6 +334,30 @@ def _matches_context_value( return False +@lru_cache(maxsize=1024) +def _parse_in_values_str(segment_value: str) -> frozenset[str]: + """ + Parse a string-form IN condition value into a frozenset of strings. + A bracketed value is tried as JSON first (with CSV fallback on parse + error); anything else is split on commas directly. + """ + if segment_value.startswith("["): + try: + parsed: list[typing.Any] = json.loads(segment_value) + except ValueError: + return frozenset(segment_value.split(",")) + return frozenset(v if type(v) is str else str(v) for v in parsed) + return frozenset(segment_value.split(",")) + + +def _get_in_values( + segment_value: typing.Union[str, list[typing.Any]], +) -> frozenset[str]: + if isinstance(segment_value, list): + return frozenset(v if type(v) is str else str(v) for v in segment_value) + return _parse_in_values_str(segment_value) + + def _evaluate_not_contains( segment_value: typing.Optional[str], context_value: ContextValue, diff --git a/tests/engine_tests/engine-test-data b/tests/engine_tests/engine-test-data index 7840a13..9307930 160000 --- a/tests/engine_tests/engine-test-data +++ b/tests/engine_tests/engine-test-data @@ -1 +1 @@ -Subproject commit 7840a1349b601df3b6b4a089f40864f659801afb +Subproject commit 9307930e9e64482a35e7d6b254225addb6e44687 diff --git a/tests/engine_tests/test_engine.py b/tests/engine_tests/test_engine.py index 3b5a67b..fb98f4b 100644 --- a/tests/engine_tests/test_engine.py +++ b/tests/engine_tests/test_engine.py @@ -36,6 +36,9 @@ def _extract_benchmark_contexts( for file_path in [ "test_0cfd0d72-4de4-4ed7-9cfb-d80dc3dacead__default.json", "test_1bde8445-ca19-4bda-a9d5-3543a800fc0f__context_values.json", + "test_in_condition_json_array_format__should_match.jsonc", + "test_in_condition_numeric_comma_separated__should_match.jsonc", + "test_in_condition_array_matching_value__should_match.jsonc", ]: yield pyjson5.loads((test_cases_dir_path / file_path).read_text())["context"]