diff --git a/changelog.d/household-axes.added.md b/changelog.d/household-axes.added.md new file mode 100644 index 00000000..0c99f42e --- /dev/null +++ b/changelog.d/household-axes.added.md @@ -0,0 +1 @@ +Expose axes in the US and UK household calculator helpers. diff --git a/docs/households.md b/docs/households.md index cae9501a..29488727 100644 --- a/docs/households.md +++ b/docs/households.md @@ -103,6 +103,44 @@ result = pe.us.calculate_household( ) ``` +## Axes + +Use `axes` to evaluate one household across a grid of input values. Pass either +the lower-level nested shape or a flat list of axis dictionaries; missing +`period` values default to `year`. + +```python +result = pe.us.calculate_household( + people=[ + { + "age": 35, + "employment_income": 60_000, + "is_tax_unit_head": True, + "charitable_cash_donations": 0, + } + ], + tax_unit={"filing_status": "SINGLE"}, + household={"state_code": "CA"}, + year=2026, + axes=[ + { + "name": "charitable_cash_donations", + "min": 0, + "max": 10_000, + "count": 3, + } + ], + extra_variables=["charitable_cash_donations"], +) + +result.person[0].charitable_cash_donations # [0, 5000, 10000] +result.tax_unit.income_tax # one value per axis point +``` + +When axes are present, result values are lists ordered by the axis grid instead +of scalars. For person results, each person still has their own result object; +each variable on that person is its own axis series. + ## Accessing the result ```python diff --git a/src/policyengine/tax_benefit_models/common/__init__.py b/src/policyengine/tax_benefit_models/common/__init__.py index 744bf21d..76c0c674 100644 --- a/src/policyengine/tax_benefit_models/common/__init__.py +++ b/src/policyengine/tax_benefit_models/common/__init__.py @@ -5,6 +5,8 @@ their public ``calculate_household`` / ``analyze_reform`` entry points. """ +from .axes import normalize_axes as normalize_axes +from .axes import values_for_entity as values_for_entity from .extra_variables import dispatch_extra_variables as dispatch_extra_variables from .household import ( validate_annual_household_inputs as validate_annual_household_inputs, diff --git a/src/policyengine/tax_benefit_models/common/axes.py b/src/policyengine/tax_benefit_models/common/axes.py new file mode 100644 index 00000000..94a5facf --- /dev/null +++ b/src/policyengine/tax_benefit_models/common/axes.py @@ -0,0 +1,81 @@ +"""Axes helpers for household calculators.""" + +from __future__ import annotations + +from collections.abc import Mapping +from difflib import get_close_matches +from typing import Any, Optional + +from policyengine.core.tax_benefit_model_version import TaxBenefitModelVersion + + +def normalize_axes( + *, + axes: Optional[list[Any]], + year: int, + model_version: TaxBenefitModelVersion, +) -> Optional[list[list[dict[str, Any]]]]: + """Validate and periodise household-calculator axes. + + The country packages expect the lower-level OpenFisca/PolicyEngine Core + shape: ``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]``. + For convenience, callers may also pass a flat list of axis dictionaries. + Missing ``period`` values default to the household calculator's ``year``. + """ + if axes is None: + return None + if not isinstance(axes, list) or not axes: + raise ValueError("axes must be a non-empty list of axis dictionaries.") + + axis_groups = axes if isinstance(axes[0], list) else [axes] + normalized: list[list[dict[str, Any]]] = [] + variables_by_name = model_version.variables_by_name + + for group in axis_groups: + if not isinstance(group, list) or not group: + raise ValueError("each axes group must be a non-empty list.") + + normalized_group: list[dict[str, Any]] = [] + for axis in group: + if not isinstance(axis, Mapping): + raise ValueError("each axis must be a dictionary.") + + axis_dict = dict(axis) + name = axis_dict.get("name") + if not isinstance(name, str): + raise ValueError("each axis must include a string 'name'.") + if name not in variables_by_name: + suggestions = get_close_matches( + name, list(variables_by_name), n=1, cutoff=0.7 + ) + suggestion = ( + f" (did you mean '{suggestions[0]}'?)" if suggestions else "" + ) + raise ValueError( + f"axis variable '{name}' is not defined on " + f"{model_version.model.id} {model_version.version}{suggestion}" + ) + + for required_key in ("min", "max", "count"): + if required_key not in axis_dict: + raise ValueError(f"axis '{name}' must include '{required_key}'.") + + axis_dict.setdefault("period", year) + normalized_group.append(axis_dict) + + normalized.append(normalized_group) + + return normalized + + +def values_for_entity( + values: list[Any], + *, + entity_index: int, + entity_count: int, + axes_active: bool, +): + """Return scalar or axis-series values for one entity member.""" + if not axes_active: + return values[entity_index] + return values[entity_index::entity_count] diff --git a/src/policyengine/tax_benefit_models/uk/household.py b/src/policyengine/tax_benefit_models/uk/household.py index 53c386d1..41832336 100644 --- a/src/policyengine/tax_benefit_models/uk/household.py +++ b/src/policyengine/tax_benefit_models/uk/household.py @@ -28,7 +28,9 @@ HouseholdResult, compile_reform, dispatch_extra_variables, + normalize_axes, validate_annual_household_inputs, + values_for_entity, ) from policyengine.utils.household_validation import validate_household_input @@ -63,6 +65,7 @@ def _build_situation( benunit: Mapping[str, Any], household: Mapping[str, Any], year: int, + axes: Optional[list[Any]] = None, ) -> dict[str, Any]: year_str = str(year) @@ -75,15 +78,18 @@ def _periodise(spec: Mapping[str, Any]) -> dict[str, dict[str, Any]]: def _group(spec: Mapping[str, Any]) -> dict[str, Any]: return {"members": list(person_ids), **_periodise(spec)} - return { + situation = { "people": persons, "benunits": {"benunit_0": _group(benunit)}, "households": {"household_0": _group(household)}, } + if axes is not None: + situation["axes"] = axes + return situation _ALLOWED_KWARGS = frozenset( - {"people", "benunit", "household", "year", "reform", "extra_variables"} + {"people", "benunit", "household", "year", "reform", "extra_variables", "axes"} ) @@ -100,7 +106,7 @@ def _raise_unexpected_kwargs(unexpected: Mapping[str, Any]) -> None: ) lines.append(f" - '{name}'{hint}") lines.append( - "Valid kwargs: people, benunit, household, year, reform, extra_variables." + "Valid kwargs: people, benunit, household, year, reform, extra_variables, axes." ) raise TypeError("\n".join(lines)) @@ -113,6 +119,7 @@ def calculate_household( year: int = 2026, reform: Optional[Mapping[str, Any]] = None, extra_variables: Optional[list[str]] = None, + axes: Optional[list[Any]] = None, **unexpected: Any, ) -> HouseholdResult: """Compute tax and benefit variables for a single UK household. @@ -127,6 +134,11 @@ def calculate_household( close-match suggestion. extra_variables: Flat list of extra UK variables to compute; the library dispatches each to its entity. + axes: Optional household-calculator axes. Pass either the lower-level + ``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]`` + shape or a flat list of axis dictionaries. Missing ``period`` + values default to ``year``. When axes are present, result values + are lists ordered by the axis grid instead of scalars. Returns: :class:`HouseholdResult` with dot-accessible entity results. @@ -170,6 +182,8 @@ def calculate_household( ) output_columns = _default_output_columns(extra_by_entity) reform_dict = compile_reform(reform, year=year, model_version=uk_latest) + normalized_axes = normalize_axes(axes=axes, year=year, model_version=uk_latest) + axes_active = normalized_axes is not None simulation = Simulation( situation=_build_situation( @@ -177,6 +191,7 @@ def calculate_household( benunit=benunit_dict, household=household_dict, year=year, + axes=normalized_axes, ), reform=reform_dict, ) @@ -190,12 +205,27 @@ def calculate_household( if entity == "person": result["person"] = [ EntityResult( - {variable: _safe_convert(raw[variable][i]) for variable in columns} + { + variable: values_for_entity( + [_safe_convert(value) for value in raw[variable]], + entity_index=i, + entity_count=len(people), + axes_active=axes_active, + ) + for variable in columns + } ) for i in range(len(people)) ] else: result[entity] = EntityResult( - {variable: _safe_convert(raw[variable][0]) for variable in columns} + { + variable: ( + [_safe_convert(value) for value in raw[variable]] + if axes_active + else _safe_convert(raw[variable][0]) + ) + for variable in columns + } ) return result diff --git a/src/policyengine/tax_benefit_models/us/household.py b/src/policyengine/tax_benefit_models/us/household.py index 76902f00..5122bb15 100644 --- a/src/policyengine/tax_benefit_models/us/household.py +++ b/src/policyengine/tax_benefit_models/us/household.py @@ -45,7 +45,9 @@ HouseholdResult, compile_reform, dispatch_extra_variables, + normalize_axes, validate_annual_household_inputs, + values_for_entity, ) from policyengine.utils.household_validation import validate_household_input @@ -66,7 +68,7 @@ def _raise_unexpected_kwargs(unexpected: Mapping[str, Any]) -> None: lines.append(f" - '{name}'{hint}") lines.append( "Valid kwargs: people, marital_unit, family, spm_unit, tax_unit, " - "household, year, reform, extra_variables." + "household, year, reform, extra_variables, axes." ) raise TypeError("\n".join(lines)) @@ -102,6 +104,7 @@ def _build_situation( tax_unit: Mapping[str, Any], household: Mapping[str, Any], year: int, + axes: Optional[list[Any]] = None, ) -> dict[str, Any]: year_str = str(year) @@ -114,7 +117,7 @@ def _periodise(spec: Mapping[str, Any]) -> dict[str, dict[str, Any]]: def _group(spec: Mapping[str, Any]) -> dict[str, Any]: return {"members": list(person_ids), **_periodise(spec)} - return { + situation = { "people": persons, "marital_units": {"marital_unit_0": _group(marital_unit)}, "families": {"family_0": _group(family)}, @@ -122,6 +125,9 @@ def _group(spec: Mapping[str, Any]) -> dict[str, Any]: "tax_units": {"tax_unit_0": _group(tax_unit)}, "households": {"household_0": _group(household)}, } + if axes is not None: + situation["axes"] = axes + return situation _ALLOWED_KWARGS = frozenset( @@ -135,6 +141,7 @@ def _group(spec: Mapping[str, Any]) -> dict[str, Any]: "year", "reform", "extra_variables", + "axes", } ) @@ -150,6 +157,7 @@ def calculate_household( year: int = 2026, reform: Optional[Mapping[str, Any]] = None, extra_variables: Optional[list[str]] = None, + axes: Optional[list[Any]] = None, **unexpected: Any, ) -> HouseholdResult: """Compute tax and benefit variables for a single US household. @@ -171,6 +179,11 @@ def calculate_household( the default output columns; the library dispatches each name to its entity. Unknown names raise ``ValueError`` with a close-match suggestion. + axes: Optional household-calculator axes. Pass either the lower-level + ``[[{"name": ..., "min": ..., "max": ..., "count": ...}]]`` + shape or a flat list of axis dictionaries. Missing ``period`` + values default to ``year``. When axes are present, result values + are lists ordered by the axis grid instead of scalars. Returns: :class:`HouseholdResult` with dot-accessible per-entity @@ -220,6 +233,8 @@ def calculate_household( ) output_columns = _default_output_columns(extra_by_entity) reform_dict = compile_reform(reform, year=year, model_version=us_latest) + normalized_axes = normalize_axes(axes=axes, year=year, model_version=us_latest) + axes_active = normalized_axes is not None simulation = Simulation( situation=_build_situation( @@ -230,6 +245,7 @@ def calculate_household( tax_unit=entities["tax_unit"], household=entities["household"], year=year, + axes=normalized_axes, ), reform=reform_dict, ) @@ -243,12 +259,27 @@ def calculate_household( if entity == "person": result["person"] = [ EntityResult( - {variable: _safe_convert(raw[variable][i]) for variable in columns} + { + variable: values_for_entity( + [_safe_convert(value) for value in raw[variable]], + entity_index=i, + entity_count=len(people), + axes_active=axes_active, + ) + for variable in columns + } ) for i in range(len(people)) ] else: result[entity] = EntityResult( - {variable: _safe_convert(raw[variable][0]) for variable in columns} + { + variable: ( + [_safe_convert(value) for value in raw[variable]] + if axes_active + else _safe_convert(raw[variable][0]) + ) + for variable in columns + } ) return result diff --git a/tests/test_household_impact.py b/tests/test_household_impact.py index 88444ebc..c03ac8dd 100644 --- a/tests/test_household_impact.py +++ b/tests/test_household_impact.py @@ -97,6 +97,30 @@ def test__periodized_group_input__then_raises_before_calculation(self): year=2026, ) + def test__axes__then_result_values_are_axis_series(self): + result = pe.uk.calculate_household( + people=[ + { + "age": 35, + "employment_income": 50000, + "gift_aid": 0, + } + ], + year=2026, + axes=[ + { + "name": "gift_aid", + "min": 0, + "max": 10000, + "count": 3, + } + ], + extra_variables=["gift_aid"], + ) + assert result.person[0].gift_aid == [0, 5000, 10000] + assert len(result.person[0].income_tax) == 3 + assert len(result.household.household_tax) == 3 + class TestUSCalculateHousehold: def test__single_adult__then_returns_result_with_net_income(self): @@ -143,6 +167,57 @@ def test__extra_variables_flat_list__then_values_appear_on_entity(self): assert "adjusted_gross_income" in result.tax_unit assert result.tax_unit.adjusted_gross_income > 0 + def test__axes__then_result_values_are_axis_series(self): + result = pe.us.calculate_household( + people=[ + { + "age": 35, + "employment_income": 60000, + "is_tax_unit_head": True, + "charitable_cash_donations": 0, + } + ], + tax_unit={"filing_status": "SINGLE"}, + household={"state_code": "CA"}, + year=2026, + axes=[ + { + "name": "charitable_cash_donations", + "min": 0, + "max": 10000, + "count": 3, + } + ], + extra_variables=["charitable_cash_donations"], + ) + assert result.person[0].charitable_cash_donations == [0, 5000, 10000] + assert result.tax_unit.income_tax == [5020, 4900, 4900] + assert len(result.household.household_net_income) == 3 + + def test__nested_axes_shape__then_supported(self): + result = pe.us.calculate_household( + people=[ + { + "age": 35, + "employment_income": 0, + "is_tax_unit_head": True, + } + ], + tax_unit={"filing_status": "SINGLE"}, + year=2026, + axes=[ + [ + { + "name": "employment_income", + "min": 0, + "max": 10000, + "count": 2, + } + ] + ], + ) + assert result.person[0].employment_income == [0, 10000] + def test__reform_compiles_effective_date_form(self): result = pe.us.calculate_household( people=[{"age": 30, "is_tax_unit_head": True}], @@ -244,6 +319,21 @@ def test__unknown_reform_path__then_raises_with_close_match(self): reform={"gov.irs.not_a_real_parameter": 0}, ) + def test__unknown_axis_variable__then_raises_with_suggestion(self): + with pytest.raises(ValueError, match="axis variable"): + pe.us.calculate_household( + people=[{"age": 35, "is_tax_unit_head": True}], + year=2026, + axes=[ + { + "name": "employment_incme", + "min": 0, + "max": 10000, + "count": 3, + } + ], + ) + def test__us_kwarg_on_uk__then_raises_with_uk_hint(self): with pytest.raises(TypeError, match="US-only"): pe.uk.calculate_household(