From aa2f0b5924bcccb7fc8a7c6e3db46ccb458b3b3c Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 14 May 2026 23:40:40 +0200 Subject: [PATCH 1/4] Strip deprecated calculate inputs --- .../port-deprecated-input-stripping.fixed.md | 1 + policyengine_api/endpoints/household.py | 18 +- policyengine_api/utils/deprecated_inputs.py | 157 ++++++++++++++ .../test_calculate_deprecated_inputs.py | 179 ++++++++++++++++ tests/unit/utils/test_deprecated_inputs.py | 197 ++++++++++++++++++ 5 files changed, 551 insertions(+), 1 deletion(-) create mode 100644 changelog.d/port-deprecated-input-stripping.fixed.md create mode 100644 policyengine_api/utils/deprecated_inputs.py create mode 100644 tests/unit/endpoints/test_calculate_deprecated_inputs.py create mode 100644 tests/unit/utils/test_deprecated_inputs.py diff --git a/changelog.d/port-deprecated-input-stripping.fixed.md b/changelog.d/port-deprecated-input-stripping.fixed.md new file mode 100644 index 000000000..5b8ae7b8b --- /dev/null +++ b/changelog.d/port-deprecated-input-stripping.fixed.md @@ -0,0 +1 @@ +Strip deprecated `medical_out_of_pocket_expenses` inputs from `/calculate` requests before simulation and return a warning instead of surfacing an engine error. 
diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index 92663ac42..1215c627b 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -4,6 +4,7 @@ from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS import logging from datetime import date +from policyengine_api.utils.deprecated_inputs import drop_deprecated_inputs from policyengine_api.utils.payload_validators import validate_country @@ -216,6 +217,13 @@ def get_calculate(country_id: str, add_missing: bool = False) -> dict: # Add in any missing yearly variables to household_json household_json = add_yearly_variables(household_json, country_id) + # Strip deprecated inputs from a copy before the engine runs so + # partners who still pass removed/renamed variables get a warning + + # working response instead of a `VariableNotFoundError` HTTP 500. + deprecated_inputs = drop_deprecated_inputs(household_json) + household_json = deprecated_inputs.household + deprecation_warnings = deprecated_inputs.warnings + country = get_countries().get(country_id) try: @@ -232,8 +240,16 @@ def get_calculate(country_id: str, add_missing: bool = False) -> dict: mimetype="application/json", ) - return dict( + response_body = dict( status="ok", message=None, result=result, ) + + warning_messages = [w.message for w in deprecation_warnings] + if warning_messages: + # Serialize to strings on the wire; the structured dataclasses + # stay available for any future caller that wants the fields. + response_body["warnings"] = warning_messages + + return response_body diff --git a/policyengine_api/utils/deprecated_inputs.py b/policyengine_api/utils/deprecated_inputs.py new file mode 100644 index 000000000..79dbecdf7 --- /dev/null +++ b/policyengine_api/utils/deprecated_inputs.py @@ -0,0 +1,157 @@ +"""Detect and drop deprecated input variables before they reach the engine. + +Without this, a partner who passes a removed model variable (e.g. 
+``medical_out_of_pocket_expenses``, deleted in policyengine-us 1.673.0) +crashes the simulation with ``VariableNotFoundError`` (HTTP 500). Dropping +the input and surfacing a structured warning gives partners a soft +landing - every other output computes normally; only outputs that +depended on the deprecated input fall back to defaults. +""" + +import copy +from dataclasses import dataclass + + +# Registry of removed/renamed model variables that legacy partner traffic +# may still pass. The value is a one-line migration hint surfaced verbatim +# in the warning message - keep it actionable. +DEPRECATED_VARIABLES: dict[str, str] = { + "medical_out_of_pocket_expenses": ( + "Removed in policyengine-us 1.673.0. Migrate non-premium spending " + "to `other_medical_expenses` and premium spending to " + "`health_insurance_premiums`." + ), +} + + +@dataclass(frozen=True) +class DeprecatedVariableWarning: + """A removed/renamed variable was supplied; dropped before the engine ran.""" + + variable: str + entity_plural: str + entity_id: str + hint: str + + @property + def message(self) -> str: + location = f"`{self.entity_plural}/{self.entity_id}`" + if self.entity_plural == "axes": + location = f"`axes[{self.entity_id}].name`" + return ( + f"Input `{self.variable}` on {location} is deprecated and was " + f"ignored for this calculation. {self.hint}" + ) + + +@dataclass(frozen=True) +class DeprecatedInputsResult: + """A household copy with deprecated inputs removed plus warnings.""" + + household: dict + warnings: list[DeprecatedVariableWarning] + + +def drop_deprecated_inputs( + household: dict, +) -> DeprecatedInputsResult: + """Return a household copy with deprecated input keys stripped. + + Returns one warning per (entity, deprecated-key) occurrence. The + caller's ``household`` is never mutated; downstream simulation + receives the returned copy. + + Non-dict inputs are returned unchanged with no warnings; downstream + code retains its existing bad-shape behavior. 
+ """ + warnings: list[DeprecatedVariableWarning] = [] + + if not isinstance(household, dict): + return DeprecatedInputsResult(household=household, warnings=warnings) + + cleaned_household = copy.deepcopy(household) + + for entity_plural, entity_group in cleaned_household.items(): + if entity_plural == "axes": + continue + if not isinstance(entity_group, dict): + continue + for entity_id, variables in entity_group.items(): + if not isinstance(variables, dict): + continue + for variable_name in list(variables.keys()): + hint = DEPRECATED_VARIABLES.get(variable_name) + if hint is None: + continue + warnings.append( + DeprecatedVariableWarning( + variable=variable_name, + entity_plural=entity_plural, + entity_id=entity_id, + hint=hint, + ) + ) + del variables[variable_name] + + _drop_deprecated_axes(cleaned_household, warnings) + + return DeprecatedInputsResult(household=cleaned_household, warnings=warnings) + + +def _drop_deprecated_axes( + household: dict, warnings: list[DeprecatedVariableWarning] +) -> None: + axes = household.get("axes") + if not isinstance(axes, list): + return + + changed = False + retained_entries = [] + + for entry_index, entry in enumerate(axes): + if isinstance(entry, list): + retained_axes = [] + for axis_index, axis in enumerate(entry): + location = f"{entry_index}][{axis_index}" + if _is_deprecated_axis(axis, location, warnings): + changed = True + continue + retained_axes.append(axis) + if retained_axes: + retained_entries.append(retained_axes) + continue + + location = str(entry_index) + if _is_deprecated_axis(entry, location, warnings): + changed = True + continue + retained_entries.append(entry) + + if not changed: + return + if retained_entries: + household["axes"] = retained_entries + else: + del household["axes"] + + +def _is_deprecated_axis( + axis, location: str, warnings: list[DeprecatedVariableWarning] +) -> bool: + if not isinstance(axis, dict): + return False + + variable_name = axis.get("name") + hint = 
DEPRECATED_VARIABLES.get(variable_name) + if hint is None: + return False + + warnings.append( + DeprecatedVariableWarning( + variable=variable_name, + entity_plural="axes", + entity_id=location, + hint=hint, + ) + ) + return True diff --git a/tests/unit/endpoints/test_calculate_deprecated_inputs.py b/tests/unit/endpoints/test_calculate_deprecated_inputs.py new file mode 100644 index 000000000..f3545f2f3 --- /dev/null +++ b/tests/unit/endpoints/test_calculate_deprecated_inputs.py @@ -0,0 +1,179 @@ +from flask import Flask +import pytest + +from policyengine_api.endpoints import household as household_endpoint + + +class DummyCountry: + def __init__(self): + self.household = None + self.policy = None + + def calculate(self, household, policy): + self.household = household + self.policy = policy + return {"household": household, "policy": policy} + + +@pytest.fixture +def calculate_client(monkeypatch): + country = DummyCountry() + monkeypatch.setattr( + household_endpoint, + "get_countries", + lambda: {"us": country}, + ) + + app = Flask(__name__) + app.add_url_rule( + "/<country_id>/calculate", + "calculate", + household_endpoint.get_calculate, + methods=["POST"], + ) + return app.test_client(), country + + +def test__calculate__drops_deprecated_input_with_warning(calculate_client): + client, country = calculate_client + household = { + "people": { + "you": { + "age": {"2025": 49}, + "medical_out_of_pocket_expenses": {"2025": 0}, + } + } + } + + response = client.post("/us/calculate", json={"household": household}) + + assert response.status_code == 200 + payload = response.get_json() + assert "medical_out_of_pocket_expenses" not in country.household["people"]["you"] + assert country.household["people"]["you"]["age"] == {"2025": 49} + assert payload["result"]["household"] == country.household + assert any( + "medical_out_of_pocket_expenses" in warning + and "people/you" in warning + and "deprecated" in warning.lower() + for warning in payload["warnings"] + ) + assert 
"medical_out_of_pocket_expenses" in household["people"]["you"] + + +def test__calculate__drops_deprecated_axis_with_warning(calculate_client): + client, country = calculate_client + household = { + "people": {"you": {"age": {"2025": 49}}}, + "axes": [ + [ + { + "name": "medical_out_of_pocket_expenses", + "period": "2025", + "min": 0, + "max": 100, + "count": 2, + }, + { + "name": "employment_income", + "period": "2025", + "min": 0, + "max": 100, + "count": 2, + }, + ] + ], + } + + response = client.post("/us/calculate", json={"household": household}) + + assert response.status_code == 200 + payload = response.get_json() + assert country.household["axes"] == [ + [ + { + "name": "employment_income", + "period": "2025", + "min": 0, + "max": 100, + "count": 2, + } + ] + ] + assert any( + "medical_out_of_pocket_expenses" in warning and "axes[0][0].name" in warning + for warning in payload["warnings"] + ) + + +def test__calculate__omits_warnings_without_deprecated_input(calculate_client): + client, country = calculate_client + household = { + "people": { + "you": { + "age": {"2025": 49}, + } + } + } + + response = client.post("/us/calculate", json={"household": household}) + + assert response.status_code == 200 + payload = response.get_json() + assert "warnings" not in payload + assert country.household == household + + +def test__calculate_full__drops_deprecated_input_after_add_missing(monkeypatch): + country = DummyCountry() + monkeypatch.setattr( + household_endpoint, + "get_countries", + lambda: {"us": country}, + ) + monkeypatch.setattr( + household_endpoint, + "add_yearly_variables", + lambda household, country_id: { + **household, + "people": { + **household["people"], + "you": { + **household["people"]["you"], + "employment_income": {"2025": None}, + }, + }, + }, + ) + + app = Flask(__name__) + + def calculate_full(country_id): + return household_endpoint.get_calculate(country_id, add_missing=True) + + app.add_url_rule( + "/<country_id>/calculate-full", + 
"calculate_full", + 
calculate_full, + methods=["POST"], + ) + client = app.test_client() + household = { + "people": { + "you": { + "age": {"2025": 49}, + "medical_out_of_pocket_expenses": {"2025": 0}, + } + } + } + + response = client.post("/us/calculate-full", json={"household": household}) + + assert response.status_code == 200 + payload = response.get_json() + assert "medical_out_of_pocket_expenses" not in country.household["people"]["you"] + assert country.household["people"]["you"]["employment_income"] == {"2025": None} + assert any( + "medical_out_of_pocket_expenses" in warning and "deprecated" in warning.lower() + for warning in payload["warnings"] + ) diff --git a/tests/unit/utils/test_deprecated_inputs.py b/tests/unit/utils/test_deprecated_inputs.py new file mode 100644 index 000000000..b665ed76d --- /dev/null +++ b/tests/unit/utils/test_deprecated_inputs.py @@ -0,0 +1,197 @@ +"""Unit tests for the deprecated-input warn-and-drop helper. + +`drop_deprecated_inputs` returns a household copy with removed/renamed +model variables stripped before validation so partners who still pass +them get a structured warning + working response instead of a +`VariableNotFoundError` HTTP 500 from the engine. 
+""" + +import copy + +from policyengine_api.utils.deprecated_inputs import ( + DEPRECATED_VARIABLES, + DeprecatedInputsResult, + DeprecatedVariableWarning, + drop_deprecated_inputs, +) + + +class TestDropDeprecatedInputs: + def test__deprecated_variable__is_dropped_with_warning(self): + household = { + "people": { + "you": { + "age": {"2025": 49}, + "medical_out_of_pocket_expenses": {"2025": 0}, + } + } + } + original = copy.deepcopy(household) + + result = drop_deprecated_inputs(household) + cleaned = result.household + warnings = result.warnings + + assert "medical_out_of_pocket_expenses" not in cleaned["people"]["you"] + assert household == original + assert cleaned["people"]["you"]["age"] == {"2025": 49} + assert len(warnings) == 1 + assert isinstance(result, DeprecatedInputsResult) + assert isinstance(warnings[0], DeprecatedVariableWarning) + assert warnings[0].variable == "medical_out_of_pocket_expenses" + assert warnings[0].entity_plural == "people" + assert warnings[0].entity_id == "you" + assert "other_medical_expenses" in warnings[0].hint + + def test__no_deprecated_variables__returns_empty(self): + household = { + "people": {"you": {"age": {"2025": 49}}}, + "households": {"household": {"members": ["you"]}}, + } + original = copy.deepcopy(household) + + result = drop_deprecated_inputs(household) + + assert result.warnings == [] + assert result.household == household + assert result.household is not household + assert household == original + + def test__deprecated_variable_on_multiple_people__warns_each(self): + household = { + "people": { + "you": { + "medical_out_of_pocket_expenses": {"2025": 100}, + }, + "spouse": { + "medical_out_of_pocket_expenses": {"2025": 200}, + }, + } + } + original = copy.deepcopy(household) + + result = drop_deprecated_inputs(household) + warnings = result.warnings + + assert len(warnings) == 2 + ids = {w.entity_id for w in warnings} + assert ids == {"you", "spouse"} + assert result.household["people"]["you"] == {} + assert 
result.household["people"]["spouse"] == {} + assert household == original + + def test__list_valued_entity_group__is_skipped_safely(self): + # `axes` is a list at the household level, not an entity dict. + # The helper must skip it without raising. + household = { + "people": {"you": {"age": {"2025": 49}}}, + "axes": [[{"name": "employment_income", "count": 5}]], + } + original = copy.deepcopy(household) + + result = drop_deprecated_inputs(household) + + assert result.warnings == [] + assert result.household["axes"] == [[{"name": "employment_income", "count": 5}]] + assert household == original + + def test__deprecated_axis_name__is_dropped_with_warning(self): + household = { + "people": {"you": {"age": {"2025": 49}}}, + "axes": [ + [ + { + "name": "medical_out_of_pocket_expenses", + "period": "2025", + "min": 0, + "max": 100, + "count": 2, + }, + { + "name": "employment_income", + "period": "2025", + "min": 0, + "max": 100, + "count": 2, + }, + ] + ], + } + original = copy.deepcopy(household) + + result = drop_deprecated_inputs(household) + warnings = result.warnings + + assert result.household["axes"] == [ + [ + { + "name": "employment_income", + "period": "2025", + "min": 0, + "max": 100, + "count": 2, + } + ] + ] + assert len(warnings) == 1 + assert warnings[0].entity_plural == "axes" + assert warnings[0].entity_id == "0][0" + assert "axes[0][0].name" in warnings[0].message + assert household == original + + def test__only_deprecated_axis_names__removes_axes_key(self): + household = { + "people": {"you": {"age": {"2025": 49}}}, + "axes": [ + [ + { + "name": "medical_out_of_pocket_expenses", + "period": "2025", + "min": 0, + "max": 100, + "count": 2, + } + ] + ], + } + original = copy.deepcopy(household) + + result = drop_deprecated_inputs(household) + + assert "axes" not in result.household + assert len(result.warnings) == 1 + assert household == original + + def test__list_valued_variable__is_not_misinterpreted(self): + # `members` is a list-valued slot on 
entity groups; the helper must + # not try to mutate it. + household = { + "spm_units": { + "spm_unit": { + "members": ["you"], + "snap": {"2025": None}, + } + } + } + original = copy.deepcopy(household) + + result = drop_deprecated_inputs(household) + + assert result.warnings == [] + assert result.household["spm_units"]["spm_unit"]["members"] == ["you"] + assert household == original + + def test__warning_message_includes_variable_entity_and_hint(self): + warning = DeprecatedVariableWarning( + variable="medical_out_of_pocket_expenses", + entity_plural="people", + entity_id="you", + hint=DEPRECATED_VARIABLES["medical_out_of_pocket_expenses"], + ) + + msg = warning.message + + assert "medical_out_of_pocket_expenses" in msg + assert "people/you" in msg + assert "other_medical_expenses" in msg + assert "deprecated" in msg.lower() From 3a59b4d14083ec0414b88464466077bdf804567c Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Thu, 14 May 2026 18:00:05 -0400 Subject: [PATCH 2/4] Return 400 for unrecognized calculate inputs --- .../port-deprecated-input-stripping.fixed.md | 2 +- policyengine_api/endpoints/household.py | 48 ++++- policyengine_api/utils/input_validation.py | 202 ++++++++++++++++++ .../test_calculate_deprecated_inputs.py | 134 ++++++++++++ tests/unit/utils/test_input_validation.py | 69 ++++++ 5 files changed, 453 insertions(+), 2 deletions(-) create mode 100644 policyengine_api/utils/input_validation.py create mode 100644 tests/unit/utils/test_input_validation.py diff --git a/changelog.d/port-deprecated-input-stripping.fixed.md b/changelog.d/port-deprecated-input-stripping.fixed.md index 5b8ae7b8b..560e1f39b 100644 --- a/changelog.d/port-deprecated-input-stripping.fixed.md +++ b/changelog.d/port-deprecated-input-stripping.fixed.md @@ -1 +1 @@ -Strip deprecated `medical_out_of_pocket_expenses` inputs from `/calculate` requests before simulation and return a warning instead of surfacing an engine error. 
+Strip deprecated `medical_out_of_pocket_expenses` inputs from calculate requests before simulation and return a warning, and return structured `400` errors for unrecognized calculate inputs instead of surfacing engine errors. diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index 1215c627b..24e14661d 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -5,6 +5,10 @@ import logging from datetime import date from policyengine_api.utils.deprecated_inputs import drop_deprecated_inputs +from policyengine_api.utils.input_validation import ( + find_unrecognized_inputs, + format_unrecognized_inputs_message, +) from policyengine_api.utils.payload_validators import validate_country @@ -49,6 +53,28 @@ def add_yearly_variables(household, country_id, countries=None): return household +def get_invalid_inputs_response(household_json, policy_json, country): + invalid_inputs = find_unrecognized_inputs( + household_json, + policy_json, + country.metadata, + ) + if not invalid_inputs: + return None + + response_body = dict( + status="error", + message=format_unrecognized_inputs_message(invalid_inputs), + result=None, + errors=[invalid_input.to_dict() for invalid_input in invalid_inputs], + ) + return Response( + json.dumps(response_body), + status=400, + mimetype="application/json", + ) + + def get_household_year(household): """Given a household dict, get the household's year @@ -131,6 +157,8 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st household["household_json"] = add_yearly_variables( household["household_json"], country_id ) + deprecated_inputs = drop_deprecated_inputs(household["household_json"]) + household["household_json"] = deprecated_inputs.household # Retrieve from the policy table @@ -154,6 +182,13 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st ) country = get_countries().get(country_id) + 
invalid_inputs_response = get_invalid_inputs_response( + household["household_json"], + policy["policy_json"], + country, + ) + if invalid_inputs_response is not None: + return invalid_inputs_response try: result = country.calculate( @@ -194,11 +229,15 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st (json.dumps(result), country_id, household_id, policy_id), ) - return dict( + response_body = dict( status="ok", message=None, result=result, ) + warning_messages = [w.message for w in deprecated_inputs.warnings] + if warning_messages: + response_body["warnings"] = warning_messages + return response_body @validate_country @@ -225,6 +264,13 @@ def get_calculate(country_id: str, add_missing: bool = False) -> dict: deprecation_warnings = deprecated_inputs.warnings country = get_countries().get(country_id) + invalid_inputs_response = get_invalid_inputs_response( + household_json, + policy_json, + country, + ) + if invalid_inputs_response is not None: + return invalid_inputs_response try: result = country.calculate(household_json, policy_json) diff --git a/policyengine_api/utils/input_validation.py b/policyengine_api/utils/input_validation.py new file mode 100644 index 000000000..68daa3c5f --- /dev/null +++ b/policyengine_api/utils/input_validation.py @@ -0,0 +1,202 @@ +"""Validate calculate payload input names before simulation creation.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class UnrecognizedInput: + input_type: str + name: str + path: str + expected_entity_plural: str | None = None + actual_entity_plural: str | None = None + + @property + def message(self) -> str: + if self.input_type == "household_entity": + return ( + f"Unrecognized household entity group `{self.name}` at `{self.path}`." 
+ ) + if self.input_type == "household_variable_wrong_entity": + return ( + f"Household variable `{self.name}` belongs on " + f"`{self.expected_entity_plural}`, not `{self.actual_entity_plural}`, " + f"at `{self.path}`." + ) + if self.input_type == "household_axis_variable": + return ( + f"Unrecognized household axis variable `{self.name}` at `{self.path}`." + ) + if self.input_type == "policy_parameter": + return f"Unrecognized policy parameter `{self.name}` at `{self.path}`." + return f"Unrecognized household variable `{self.name}` at `{self.path}`." + + def to_dict(self) -> dict: + data = { + "type": self.input_type, + "name": self.name, + "path": self.path, + "message": self.message, + } + if self.expected_entity_plural is not None: + data["expected_entity_plural"] = self.expected_entity_plural + if self.actual_entity_plural is not None: + data["actual_entity_plural"] = self.actual_entity_plural + return data + + +def find_unrecognized_inputs( + household: dict, + policy: dict, + metadata: dict, +) -> list[UnrecognizedInput]: + """Return calculate payload names that the country package cannot resolve.""" + + return [ + *find_unrecognized_household_inputs(household, metadata), + *find_unrecognized_policy_inputs(policy, metadata), + ] + + +def find_unrecognized_household_inputs( + household: dict, + metadata: dict, +) -> list[UnrecognizedInput]: + if not isinstance(household, dict): + return [] + + errors: list[UnrecognizedInput] = [] + variables = metadata.get("variables", {}) + entities = metadata.get("entities", {}) + entity_by_plural = { + entity["plural"]: entity for entity in entities.values() if "plural" in entity + } + + for entity_plural, entity_group in household.items(): + if entity_plural == "axes": + errors.extend(_find_unrecognized_axes(entity_group, variables)) + continue + + entity = entity_by_plural.get(entity_plural) + if entity is None: + errors.append( + UnrecognizedInput( + input_type="household_entity", + name=entity_plural, + 
path=f"household.{entity_plural}", + ) + ) + continue + if not isinstance(entity_group, dict): + continue + + relationship_keys = { + role["plural"] + for role in entity.get("roles", {}).values() + if "plural" in role + } + for entity_id, entity_inputs in entity_group.items(): + if not isinstance(entity_inputs, dict): + continue + for variable_name in entity_inputs: + if variable_name in relationship_keys: + continue + variable = variables.get(variable_name) + path = f"household.{entity_plural}.{entity_id}.{variable_name}" + if variable is None: + errors.append( + UnrecognizedInput( + input_type="household_variable", + name=variable_name, + path=path, + ) + ) + continue + expected_entity = entities.get(variable["entity"], {}) + expected_entity_plural = expected_entity.get("plural") + if ( + expected_entity_plural is not None + and expected_entity_plural != entity_plural + ): + errors.append( + UnrecognizedInput( + input_type="household_variable_wrong_entity", + name=variable_name, + path=path, + expected_entity_plural=expected_entity_plural, + actual_entity_plural=entity_plural, + ) + ) + + return errors + + +def find_unrecognized_policy_inputs( + policy: dict, + metadata: dict, +) -> list[UnrecognizedInput]: + if not isinstance(policy, dict): + return [] + + parameters = metadata.get("parameters", {}) + return [ + UnrecognizedInput( + input_type="policy_parameter", + name=parameter_name, + path=f"policy.{parameter_name}", + ) + for parameter_name in policy + if parameter_name not in parameters + ] + + +def format_unrecognized_inputs_message(errors: list[UnrecognizedInput]) -> str: + return "Unrecognized calculate input(s): " + "; ".join( + error.message for error in errors + ) + + +def _find_unrecognized_axes(axes, variables: dict) -> list[UnrecognizedInput]: + if not isinstance(axes, list): + return [] + + errors: list[UnrecognizedInput] = [] + for entry_index, entry in enumerate(axes): + if isinstance(entry, list): + for axis_index, axis in enumerate(entry): + 
errors.extend( + _find_unrecognized_axis( + axis, + f"household.axes[{entry_index}][{axis_index}].name", + variables, + ) + ) + continue + errors.extend( + _find_unrecognized_axis( + entry, + f"household.axes[{entry_index}].name", + variables, + ) + ) + return errors + + +def _find_unrecognized_axis( + axis, + path: str, + variables: dict, +) -> list[UnrecognizedInput]: + if not isinstance(axis, dict) or "name" not in axis: + return [] + + variable_name = axis["name"] + if variable_name in variables: + return [] + return [ + UnrecognizedInput( + input_type="household_axis_variable", + name=variable_name, + path=path, + ) + ] diff --git a/tests/unit/endpoints/test_calculate_deprecated_inputs.py b/tests/unit/endpoints/test_calculate_deprecated_inputs.py index f3545f2f3..0d7799abd 100644 --- a/tests/unit/endpoints/test_calculate_deprecated_inputs.py +++ b/tests/unit/endpoints/test_calculate_deprecated_inputs.py @@ -8,6 +8,30 @@ class DummyCountry: def __init__(self): self.household = None self.policy = None + self.metadata = { + "variables": { + "age": {"entity": "person"}, + "employment_income": {"entity": "person"}, + "snap": {"entity": "spm_unit"}, + }, + "entities": { + "person": { + "plural": "people", + "roles": {}, + }, + "spm_unit": { + "plural": "spm_units", + "roles": { + "member": { + "plural": "members", + }, + }, + }, + }, + "parameters": { + "gov.irs.income.exemption.amount": {}, + }, + } def calculate(self, household, policy): self.household = household @@ -124,6 +148,116 @@ def test__calculate__omits_warnings_without_deprecated_input(calculate_client): assert country.household == household +def test__calculate__returns_400_for_unrecognized_household_variable( + calculate_client, +): + client, country = calculate_client + household = { + "people": { + "you": { + "age": {"2025": 49}, + "employmnt_income": {"2025": 1_000}, + } + } + } + + response = client.post("/us/calculate", json={"household": household}) + + assert response.status_code == 400 + 
payload = response.get_json() + assert payload["status"] == "error" + assert payload["result"] is None + assert "employmnt_income" in payload["message"] + assert payload["errors"] == [ + { + "type": "household_variable", + "name": "employmnt_income", + "path": "household.people.you.employmnt_income", + "message": ( + "Unrecognized household variable `employmnt_income` at " + "`household.people.you.employmnt_income`." + ), + } + ] + assert country.household is None + + +def test__calculate__returns_400_for_variable_on_wrong_entity(calculate_client): + client, country = calculate_client + household = { + "people": { + "you": { + "age": {"2025": 49}, + "snap": {"2025": None}, + } + } + } + + response = client.post("/us/calculate", json={"household": household}) + + assert response.status_code == 400 + payload = response.get_json() + assert payload["errors"][0]["type"] == "household_variable_wrong_entity" + assert payload["errors"][0]["name"] == "snap" + assert payload["errors"][0]["expected_entity_plural"] == "spm_units" + assert payload["errors"][0]["actual_entity_plural"] == "people" + assert country.household is None + + +def test__calculate__returns_400_for_unrecognized_axis_variable(calculate_client): + client, country = calculate_client + household = { + "people": {"you": {"age": {"2025": 49}}}, + "axes": [[{"name": "employmnt_income", "count": 2}]], + } + + response = client.post("/us/calculate", json={"household": household}) + + assert response.status_code == 400 + payload = response.get_json() + assert payload["errors"][0]["type"] == "household_axis_variable" + assert payload["errors"][0]["name"] == "employmnt_income" + assert payload["errors"][0]["path"] == "household.axes[0][0].name" + assert country.household is None + + +def test__calculate__returns_400_for_unrecognized_policy_parameter( + calculate_client, +): + client, country = calculate_client + household = {"people": {"you": {"age": {"2025": 49}}}} + policy = {"gov.irs.income.exemption.amunt": 
{"2025-01-01.2025-12-31": 100}} + + response = client.post( + "/us/calculate", + json={"household": household, "policy": policy}, + ) + + assert response.status_code == 400 + payload = response.get_json() + assert payload["errors"][0]["type"] == "policy_parameter" + assert payload["errors"][0]["name"] == "gov.irs.income.exemption.amunt" + assert country.household is None + + +def test__calculate__preserves_relationship_fields(calculate_client): + client, country = calculate_client + household = { + "people": {"you": {"age": {"2025": 49}}}, + "spm_units": { + "spm_unit": { + "members": ["you"], + "snap": {"2025": None}, + } + }, + } + + response = client.post("/us/calculate", json={"household": household}) + + assert response.status_code == 200 + assert country.household == household + + def test__calculate_full__drops_deprecated_input_after_add_missing(monkeypatch): country = DummyCountry() monkeypatch.setattr( diff --git a/tests/unit/utils/test_input_validation.py b/tests/unit/utils/test_input_validation.py new file mode 100644 index 000000000..9cd70efa2 --- /dev/null +++ b/tests/unit/utils/test_input_validation.py @@ -0,0 +1,69 @@ +from policyengine_api.utils.input_validation import find_unrecognized_inputs + + +METADATA = { + "variables": { + "age": {"entity": "person"}, + "employment_income": {"entity": "person"}, + "snap": {"entity": "spm_unit"}, + }, + "entities": { + "person": { + "plural": "people", + "roles": {}, + }, + "spm_unit": { + "plural": "spm_units", + "roles": { + "member": { + "plural": "members", + }, + }, + }, + }, + "parameters": { + "gov.irs.income.exemption.amount": {}, + }, +} + + +def test__find_unrecognized_inputs__accepts_known_household_and_policy_inputs(): + household = { + "people": {"you": {"age": {"2025": 49}}}, + "spm_units": { + "spm_unit": { + "members": ["you"], + "snap": {"2025": None}, + } + }, + "axes": [[{"name": "employment_income", "count": 2}]], + } + policy = {"gov.irs.income.exemption.amount": {"2025-01-01.2025-12-31": 
100}} + + errors = find_unrecognized_inputs(household, policy, METADATA) + + assert errors == [] + + +def test__find_unrecognized_inputs__reports_unknown_entity_variable_axis_and_policy(): + household = { + "person": {"you": {"age": {"2025": 49}}}, + "people": {"you": {"age": {"2025": 49}, "snap": {"2025": None}}}, + "axes": [[{"name": "employmnt_income", "count": 2}]], + } + policy = {"gov.irs.income.exemption.amunt": {"2025-01-01.2025-12-31": 100}} + + errors = find_unrecognized_inputs(household, policy, METADATA) + + assert [error.input_type for error in errors] == [ + "household_entity", + "household_variable_wrong_entity", + "household_axis_variable", + "policy_parameter", + ] + assert [error.path for error in errors] == [ + "household.person", + "household.people.you.snap", + "household.axes[0][0].name", + "policy.gov.irs.income.exemption.amunt", + ] From 04a75a197a66df977b4d2e4622b97620dc76e169 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 15 May 2026 00:00:32 +0200 Subject: [PATCH 3/4] Add live deprecated calculate input test --- tests/integration/test_live_calculate.py | 29 ++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/integration/test_live_calculate.py b/tests/integration/test_live_calculate.py index 70259d73b..e3b34c2fc 100644 --- a/tests/integration/test_live_calculate.py +++ b/tests/integration/test_live_calculate.py @@ -29,3 +29,32 @@ def test_live_calculate_us_2(api_client): assert response.status_code == 200, response.text payload = response.json() assert payload is not None + + +def test_live_calculate_drops_deprecated_medical_input( + api_client, + integration_probe_id, +): + response = api_client.post( + "/us/calculate", + json={ + "staging_probe": f"{integration_probe_id}-deprecated-medical-input", + "household": { + "people": { + "you": { + "age": {"2026": 40}, + "medical_out_of_pocket_expenses": {"2026": 0}, + } + } + }, + }, + ) + + assert response.status_code == 200, response.text + payload = 
response.json() + assert payload["status"] == "ok", payload + assert "medical_out_of_pocket_expenses" not in payload["result"]["people"]["you"] + assert any( + "medical_out_of_pocket_expenses" in warning and "deprecated" in warning.lower() + for warning in payload["warnings"] + ) From 661538da5b48cddee4a0e4375c0150c25080c0d6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 15 May 2026 00:03:20 +0200 Subject: [PATCH 4/4] Keep live calculate probe payload valid --- tests/integration/test_live_calculate.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_live_calculate.py b/tests/integration/test_live_calculate.py index e3b34c2fc..092c464ae 100644 --- a/tests/integration/test_live_calculate.py +++ b/tests/integration/test_live_calculate.py @@ -35,15 +35,18 @@ def test_live_calculate_drops_deprecated_medical_input( api_client, integration_probe_id, ): + deprecated_input_value = int(integration_probe_id.rsplit("-", 1)[-1], 16) + response = api_client.post( "/us/calculate", json={ - "staging_probe": f"{integration_probe_id}-deprecated-medical-input", "household": { "people": { "you": { "age": {"2026": 40}, - "medical_out_of_pocket_expenses": {"2026": 0}, + "medical_out_of_pocket_expenses": { + "2026": deprecated_input_value + }, } } },