Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/port-deprecated-input-stripping.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Strip deprecated `medical_out_of_pocket_expenses` inputs from calculate requests before simulation and return a warning, and return structured `400` errors for unrecognized calculate inputs instead of surfacing engine errors.
66 changes: 64 additions & 2 deletions policyengine_api/endpoints/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
import logging
from datetime import date
from policyengine_api.utils.deprecated_inputs import drop_deprecated_inputs
from policyengine_api.utils.input_validation import (
find_unrecognized_inputs,
format_unrecognized_inputs_message,
)
from policyengine_api.utils.payload_validators import validate_country


Expand Down Expand Up @@ -48,6 +53,28 @@ def add_yearly_variables(household, country_id, countries=None):
return household


def get_invalid_inputs_response(household_json, policy_json, country):
invalid_inputs = find_unrecognized_inputs(
household_json,
policy_json,
country.metadata,
)
if not invalid_inputs:
return None

response_body = dict(
status="error",
message=format_unrecognized_inputs_message(invalid_inputs),
result=None,
errors=[invalid_input.to_dict() for invalid_input in invalid_inputs],
)
return Response(
json.dumps(response_body),
status=400,
mimetype="application/json",
)


def get_household_year(household):
"""Given a household dict, get the household's year

Expand Down Expand Up @@ -130,6 +157,8 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st
household["household_json"] = add_yearly_variables(
household["household_json"], country_id
)
deprecated_inputs = drop_deprecated_inputs(household["household_json"])
household["household_json"] = deprecated_inputs.household

# Retrieve from the policy table

Expand All @@ -153,6 +182,13 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st
)

country = get_countries().get(country_id)
invalid_inputs_response = get_invalid_inputs_response(
household["household_json"],
policy["policy_json"],
country,
)
if invalid_inputs_response is not None:
return invalid_inputs_response

try:
result = country.calculate(
Expand Down Expand Up @@ -193,11 +229,15 @@ def get_household_under_policy(country_id: str, household_id: str, policy_id: st
(json.dumps(result), country_id, household_id, policy_id),
)

return dict(
response_body = dict(
status="ok",
message=None,
result=result,
)
warning_messages = [w.message for w in deprecated_inputs.warnings]
if warning_messages:
response_body["warnings"] = warning_messages
return response_body


@validate_country
Expand All @@ -216,7 +256,21 @@ def get_calculate(country_id: str, add_missing: bool = False) -> dict:
# Add in any missing yearly variables to household_json
household_json = add_yearly_variables(household_json, country_id)

# Strip deprecated inputs from a copy before the engine runs so
# partners who still pass removed/renamed variables get a warning +
# working response instead of a `VariableNotFoundError` HTTP 500.
deprecated_inputs = drop_deprecated_inputs(household_json)
household_json = deprecated_inputs.household
deprecation_warnings = deprecated_inputs.warnings

country = get_countries().get(country_id)
invalid_inputs_response = get_invalid_inputs_response(
household_json,
policy_json,
country,
)
if invalid_inputs_response is not None:
return invalid_inputs_response

try:
result = country.calculate(household_json, policy_json)
Expand All @@ -232,8 +286,16 @@ def get_calculate(country_id: str, add_missing: bool = False) -> dict:
mimetype="application/json",
)

return dict(
response_body = dict(
status="ok",
message=None,
result=result,
)

warning_messages = [w.message for w in deprecation_warnings]
if warning_messages:
# Serialize to strings on the wire; the structured dataclasses
# stay available for any future caller that wants the fields.
response_body["warnings"] = warning_messages

return response_body
157 changes: 157 additions & 0 deletions policyengine_api/utils/deprecated_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Detect and drop deprecated input variables before they reach the engine.

Without this, a partner who passes a removed model variable (e.g.
``medical_out_of_pocket_expenses``, deleted in policyengine-us 1.673.0)
crashes the simulation with ``VariableNotFoundError`` (HTTP 500). Dropping
the input and surfacing a structured warning gives partners a soft
landing - every other output computes normally; only outputs that
depended on the deprecated input fall back to defaults.
"""

import copy
from dataclasses import dataclass


# Registry of removed/renamed model variables that legacy partner traffic
# may still pass. The value is a one-line migration hint surfaced verbatim
# in the warning message - keep it actionable.
DEPRECATED_VARIABLES: dict[str, str] = {
"medical_out_of_pocket_expenses": (
"Removed in policyengine-us 1.673.0. Migrate non-premium spending "
"to `other_medical_expenses` and premium spending to "
"`health_insurance_premiums`."
),
}


@dataclass(frozen=True)
class DeprecatedVariableWarning:
"""A removed/renamed variable was supplied; dropped before the engine ran."""

variable: str
entity_plural: str
entity_id: str
hint: str

@property
def message(self) -> str:
location = f"`{self.entity_plural}/{self.entity_id}`"
if self.entity_plural == "axes":
location = f"`axes[{self.entity_id}].name`"
return (
f"Input `{self.variable}` on {location} is deprecated and was "
f"ignored for this calculation. {self.hint}"
)


@dataclass(frozen=True)
class DeprecatedInputsResult:
"""A household copy with deprecated inputs removed plus warnings."""

household: dict
warnings: list[DeprecatedVariableWarning]


def drop_deprecated_inputs(
household: dict,
) -> DeprecatedInputsResult:
"""Return a household copy with deprecated input keys stripped.

Returns one warning per (entity, deprecated-key) occurrence. The
caller's ``household`` is never mutated; downstream simulation
receives the returned copy.

Non-dict inputs are returned unchanged with no warnings; downstream
code retains its existing bad-shape behavior.
"""
warnings: list[DeprecatedVariableWarning] = []

if not isinstance(household, dict):
return DeprecatedInputsResult(household=household, warnings=warnings)

cleaned_household = copy.deepcopy(household)

for entity_plural, entity_group in cleaned_household.items():
if entity_plural == "axes":
continue
if not isinstance(entity_group, dict):
continue
for entity_id, variables in entity_group.items():
if not isinstance(variables, dict):
continue
for variable_name in list(variables.keys()):
hint = DEPRECATED_VARIABLES.get(variable_name)
if hint is None:
continue
warnings.append(
DeprecatedVariableWarning(
variable=variable_name,
entity_plural=entity_plural,
entity_id=entity_id,
hint=hint,
)
)
del variables[variable_name]

_drop_deprecated_axes(cleaned_household, warnings)

return DeprecatedInputsResult(household=cleaned_household, warnings=warnings)


def _drop_deprecated_axes(
household: dict, warnings: list[DeprecatedVariableWarning]
) -> None:
axes = household.get("axes")
if not isinstance(axes, list):
return

changed = False
retained_entries = []

for entry_index, entry in enumerate(axes):
if isinstance(entry, list):
retained_axes = []
for axis_index, axis in enumerate(entry):
location = f"{entry_index}][{axis_index}"
if _is_deprecated_axis(axis, location, warnings):
changed = True
continue
retained_axes.append(axis)
if retained_axes:
retained_entries.append(retained_axes)
continue

location = str(entry_index)
if _is_deprecated_axis(entry, location, warnings):
changed = True
continue
retained_entries.append(entry)

if not changed:
return
if retained_entries:
household["axes"] = retained_entries
else:
del household["axes"]


def _is_deprecated_axis(
axis, location: str, warnings: list[DeprecatedVariableWarning]
) -> bool:
if not isinstance(axis, dict):
return False

variable_name = axis.get("name")
hint = DEPRECATED_VARIABLES.get(variable_name)
if hint is None:
return False

warnings.append(
DeprecatedVariableWarning(
variable=variable_name,
entity_plural="axes",
entity_id=location,
hint=hint,
)
)
return True
Loading
Loading