From 3507009333a17b73a42b0d851f986b0343f91256 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 11:59:40 +0000 Subject: [PATCH 1/8] Enhance entity mapping with flexible aggregation methods and custom values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for custom values and multiple aggregation methods to the entity mapping system, making it more flexible for complex analysis workflows. Features added: - values parameter: Map custom value arrays instead of existing columns - Extended how parameter with new aggregation methods: * Person → Group: 'sum' (default), 'first' * Group → Person: 'project' (default), 'divide' * Group → Group: 'sum', 'first', 'project', 'divide' Refactoring: - Created base YearData class to eliminate code duplication - UKYearData and USYearData now inherit from base class - Removed duplicate map_to_entity implementations Documentation: - Added comprehensive entity mapping section to core-concepts.md - Added examples to UK and US model documentation - Documented all aggregation methods with use cases All existing tests pass, confirming backward compatibility. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- docs/core-concepts.md | 113 ++++++++++++ docs/country-models-uk.md | 36 ++++ docs/country-models-us.md | 32 ++++ src/policyengine/core/__init__.py | 1 + src/policyengine/core/dataset.py | 166 +++++++++++++++++- .../tax_benefit_models/uk/datasets.py | 34 +--- .../tax_benefit_models/us/datasets.py | 34 +--- 7 files changed, 354 insertions(+), 62 deletions(-) diff --git a/docs/core-concepts.md b/docs/core-concepts.md index bdba1f39..52c3f290 100644 --- a/docs/core-concepts.md +++ b/docs/core-concepts.md @@ -327,6 +327,119 @@ agg = Aggregate( When you request a household-level variable at person level: 1. 
Replicates household values to all persons in that household (expansion) +### Direct entity mapping + +You can also map data between entities directly using the `map_to_entity` method: + +```python +# Map person income to household level (sum) +household_income = dataset.data.map_to_entity( + source_entity="person", + target_entity="household", + columns=["employment_income"], + how="sum" +) + +# Map household rent to person level (project/broadcast) +person_rent = dataset.data.map_to_entity( + source_entity="household", + target_entity="person", + columns=["rent"], + how="project" +) +``` + +#### Mapping with custom values + +You can map custom value arrays instead of existing columns: + +```python +# Map custom per-person values to household level +import numpy as np + +# Create custom values (e.g., imputed data) +custom_values = np.array([100, 200, 150, 300]) + +household_totals = dataset.data.map_to_entity( + source_entity="person", + target_entity="household", + values=custom_values, + how="sum" +) +``` + +#### Aggregation methods + +The `how` parameter controls how values are mapped: + +**Person → Group (aggregation):** +- `how='sum'` (default): Sum values within each group +- `how='first'`: Take first person's value in each group + +```python +# Sum person incomes to household level +household_income = data.map_to_entity( + source_entity="person", + target_entity="household", + columns=["employment_income"], + how="sum" +) + +# Take first person's age as household reference +household_age = data.map_to_entity( + source_entity="person", + target_entity="household", + columns=["age"], + how="first" +) +``` + +**Group → Person (expansion):** +- `how='project'` (default): Broadcast group value to all members +- `how='divide'`: Split group value equally among members + +```python +# Broadcast household rent to each person +person_rent = data.map_to_entity( + source_entity="household", + target_entity="person", + columns=["rent"], + how="project" +) + +# Split 
household savings equally per person +person_savings = data.map_to_entity( + source_entity="household", + target_entity="person", + columns=["total_savings"], + how="divide" +) +``` + +**Group → Group (via person entity):** +- `how='sum'` (default): Sum through person entity +- `how='first'`: Take first source group's value +- `how='project'`: Broadcast first source group's value +- `how='divide'`: Split proportionally based on person counts + +```python +# UK: Sum benunit benefits to household level +household_benefits = data.map_to_entity( + source_entity="benunit", + target_entity="household", + columns=["universal_credit"], + how="sum" +) + +# US: Map tax unit income to household, splitting by members +household_from_tax = data.map_to_entity( + source_entity="tax_unit", + target_entity="household", + columns=["taxable_income"], + how="divide" +) +``` + ## Visualisation The package includes utilities for creating PolicyEngine-branded visualisations: diff --git a/docs/country-models-uk.md b/docs/country-models-uk.md index 27d7dae7..bd9d1fbd 100644 --- a/docs/country-models-uk.md +++ b/docs/country-models-uk.md @@ -295,6 +295,42 @@ Valid region values: - `SCOTLAND` - `NORTHERN_IRELAND` +## Entity mapping + +The UK model has a simpler entity structure than the US, with three levels: person → benunit → household. 
+ +### Direct entity mapping + +You can map data between entities using the `map_to_entity` method: + +```python +# Map person income to benunit level +benunit_income = dataset.data.map_to_entity( + source_entity="person", + target_entity="benunit", + columns=["employment_income"], + how="sum" +) + +# Split household rent equally among persons +person_rent_share = dataset.data.map_to_entity( + source_entity="household", + target_entity="person", + columns=["rent"], + how="divide" +) + +# Map benunit UC to household level +household_uc = dataset.data.map_to_entity( + source_entity="benunit", + target_entity="household", + columns=["universal_credit"], + how="sum" +) +``` + +See the [Entity mapping section](core-concepts.md#entity-mapping) in Core Concepts for full documentation on aggregation methods. + ## Data sources The UK model can use several data sources: diff --git a/docs/country-models-us.md b/docs/country-models-us.md index 927966ee..547a4f3b 100644 --- a/docs/country-models-us.md +++ b/docs/country-models-us.md @@ -368,6 +368,38 @@ Household variables are replicated to all household members: # Each person in household gets the same household_net_income value ``` +### Direct entity mapping + +For complex multi-entity scenarios, you can use `map_to_entity` directly: + +```python +# Map SPM unit SNAP benefits to household level +household_snap = dataset.data.map_to_entity( + source_entity="spm_unit", + target_entity="household", + columns=["snap"], + how="sum" +) + +# Split tax unit income equally among persons +person_tax_income = dataset.data.map_to_entity( + source_entity="tax_unit", + target_entity="person", + columns=["taxable_income"], + how="divide" +) + +# Map custom analysis values +custom_analysis = dataset.data.map_to_entity( + source_entity="person", + target_entity="tax_unit", + values=custom_values_array, + how="sum" +) +``` + +See the [Entity mapping section](core-concepts.md#entity-mapping) in Core Concepts for full documentation on 
aggregation methods. + ## Data sources The US model can use several data sources: diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index b96e8edd..18e9648e 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -1,5 +1,6 @@ from .dataset import Dataset from .dataset import map_to_entity as map_to_entity +from .dataset import YearData as YearData from .dataset_version import DatasetVersion as DatasetVersion from .dynamic import Dynamic as Dynamic from .output import Output as Output diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index a79c0b6d..23c5129d 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -8,6 +8,61 @@ from .tax_benefit_model import TaxBenefitModel +class YearData(BaseModel): + """Base class for entity-level data for a single year.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @property + def entity_data(self) -> dict[str, MicroDataFrame]: + """Return a dictionary of entity names to their data. + + This should be implemented by subclasses to return the appropriate entities. + """ + raise NotImplementedError("Subclasses must implement entity_data property") + + @property + def person_entity(self) -> str: + """Return the name of the person-level entity. + + Defaults to 'person' but can be overridden by subclasses. + """ + return "person" + + def map_to_entity( + self, + source_entity: str, + target_entity: str, + columns: list[str] = None, + values: list = None, + how: str = "sum", + ) -> MicroDataFrame: + """Map data from source entity to target entity using join keys. + + Args: + source_entity (str): The source entity name. + target_entity (str): The target entity name. + columns (list[str], optional): List of column names to map. If None, maps all columns. + values (list, optional): List of values to use instead of column data. 
+ how (str): Aggregation method (default 'sum'): 'sum' or 'first' for person → group; 'project' or 'divide' for group → person ('sum' is treated as 'project'); any of the four for group → group. + + Returns: + MicroDataFrame: The mapped data at the target entity level. + + Raises: + ValueError: If source or target entity is invalid, or if `how` is not supported for the mapping direction. + """ + return map_to_entity( + entity_data=self.entity_data, + source_entity=source_entity, + target_entity=target_entity, + person_entity=self.person_entity, + columns=columns, + values=values, + how=how, + ) + + class Dataset(BaseModel): """Base class for datasets. @@ -43,6 +98,8 @@ def map_to_entity( target_entity: str, person_entity: str = "person", columns: list[str] | None = None, + values: list | None = None, + how: str = "sum", ) -> MicroDataFrame: """Map data from source entity to target entity using join keys. @@ -58,12 +115,17 @@ def map_to_entity( target_entity: The target entity name person_entity: The name of the person entity (default "person") columns: List of column names to map. If None, maps all columns + values: List of values to use instead of column data. 
If provided, creates a single column named '__mapped_value' + how: Aggregation method (default 'sum') + - For person → group: 'sum' (aggregate), 'first' (take first value) + - For group → person: 'project' (broadcast), 'divide' (split equally) + - For group → group: 'sum', 'first', 'project', 'divide' Returns: MicroDataFrame: The mapped data at the target entity level Raises: - ValueError: If source or target entity is invalid + ValueError: If source or target entity is invalid, or if the aggregation method is unsupported """ valid_entities = set(entity_data.keys()) @@ -79,6 +141,18 @@ def map_to_entity( # Get source data (convert to plain DataFrame to avoid weighted operations during mapping) source_df = pd.DataFrame(entity_data[source_entity]) + # Handle values parameter - create a temporary column with the provided values + if values is not None: + if len(values) != len(source_df): + raise ValueError( + f"Length of values ({len(values)}) must match source entity length ({len(source_df)})" + ) + # Create a temporary DataFrame with just ID columns and the values column + id_cols = {col for col in source_df.columns if col.endswith("_id")} + source_df = source_df[[col for col in id_cols]] + source_df["__mapped_value"] = values + columns = ["__mapped_value"] + if columns: # Select only requested columns (keep all ID columns for joins) id_cols = {col for col in source_df.columns if col.endswith("_id")} @@ -118,10 +192,17 @@ def map_to_entity( if c not in id_cols and c not in weight_cols ] - # Group by join key and sum - aggregated = source_df.groupby(join_key, as_index=False)[ - agg_cols - ].sum() + # Group by join key and aggregate + if how == "sum": + aggregated = source_df.groupby(join_key, as_index=False)[ + agg_cols + ].sum() + elif how == "first": + aggregated = source_df.groupby(join_key, as_index=False)[ + agg_cols + ].first() + else: + raise ValueError(f"Unsupported aggregation method: {how}") # Rename join key to target key if needed if join_key != target_key: @@ -146,6 +227,10 @@ 
def map_to_entity( # Group entity to person: expand group-level data to person level if source_entity != person_entity and target_entity == person_entity: + # Default to 'project' (broadcast) for group -> person if 'sum' was provided + if how == "sum": + how = "project" + source_key = f"{source_entity}_id" # Check for both naming patterns person_source_key = f"{person_entity}_{source_entity}_id" @@ -163,6 +248,38 @@ def map_to_entity( source_df = source_df.rename(columns={source_key: join_key}) result = target_pd.merge(source_df, on=join_key, how="left") + + # Handle divide operation + if how == "divide": + # Get columns to divide (exclude ID and weight columns) + id_cols = {col for col in result.columns if col.endswith("_id")} + weight_cols = { + col for col in result.columns if col.endswith("_weight") + } + value_cols = [ + c + for c in result.columns + if c not in id_cols and c not in weight_cols + ] + + # Count members in each group + group_counts = ( + target_pd.groupby(join_key, as_index=False) + .size() + .rename(columns={"size": "__group_count"}) + ) + result = result.merge(group_counts, on=join_key, how="left") + + # Divide values by group count + for col in value_cols: + result[col] = result[col] / result["__group_count"] + + result = result.drop(columns=["__group_count"]) + elif how not in ["project"]: + raise ValueError( + f"Unsupported aggregation method for group->person: {how}. Use 'project' or 'divide'." 
+ ) + return MicroDataFrame(result, weights=target_weight) # Group to group: go through person table @@ -228,9 +345,42 @@ def map_to_entity( if c not in id_cols and c not in weight_cols ] - aggregated = source_with_target.groupby( - target_link_key, as_index=False - )[agg_cols].sum() + if how == "sum": + aggregated = source_with_target.groupby( + target_link_key, as_index=False + )[agg_cols].sum() + elif how == "first": + aggregated = source_with_target.groupby( + target_link_key, as_index=False + )[agg_cols].first() + elif how == "project": + # Just take first value (broadcast to target groups) + aggregated = source_with_target.groupby( + target_link_key, as_index=False + )[agg_cols].first() + elif how == "divide": + # Count persons in each source group + source_group_counts = ( + person_df.groupby(source_link_key, as_index=False) + .size() + .rename(columns={"size": "__source_count"}) + ) + source_with_target = source_with_target.merge( + source_group_counts, on=source_link_key, how="left" + ) + + # Divide values by source group count (per-person share) + for col in agg_cols: + source_with_target[col] = ( + source_with_target[col] / source_with_target["__source_count"] + ) + + # Now aggregate (sum of per-person shares) to target level + aggregated = source_with_target.groupby( + target_link_key, as_index=False + )[agg_cols].sum() + else: + raise ValueError(f"Unsupported aggregation method: {how}") # Rename target link key to target key if needed if target_link_key != target_key: diff --git a/src/policyengine/tax_benefit_models/uk/datasets.py b/src/policyengine/tax_benefit_models/uk/datasets.py index 113d4b57..bdf89d9e 100644 --- a/src/policyengine/tax_benefit_models/uk/datasets.py +++ b/src/policyengine/tax_benefit_models/uk/datasets.py @@ -2,12 +2,12 @@ import pandas as pd from microdf import MicroDataFrame -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict -from policyengine.core import Dataset, map_to_entity +from policyengine.core 
import Dataset, YearData -class UKYearData(BaseModel): +class UKYearData(YearData): """Entity-level data for a single year.""" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -16,34 +16,14 @@ class UKYearData(BaseModel): benunit: MicroDataFrame household: MicroDataFrame - def map_to_entity( - self, source_entity: str, target_entity: str, columns: list[str] = None - ) -> MicroDataFrame: - """Map data from source entity to target entity using join keys. - - Args: - source_entity (str): The source entity name ('person', 'benunit', 'household'). - target_entity (str): The target entity name ('person', 'benunit', 'household'). - columns (list[str], optional): List of column names to map. If None, maps all columns. - - Returns: - MicroDataFrame: The mapped data at the target entity level. - - Raises: - ValueError: If source or target entity is invalid. - """ - entity_data = { + @property + def entity_data(self) -> dict[str, MicroDataFrame]: + """Return a dictionary of entity names to their data.""" + return { "person": self.person, "benunit": self.benunit, "household": self.household, } - return map_to_entity( - entity_data=entity_data, - source_entity=source_entity, - target_entity=target_entity, - person_entity="person", - columns=columns, - ) class PolicyEngineUKDataset(Dataset): diff --git a/src/policyengine/tax_benefit_models/us/datasets.py b/src/policyengine/tax_benefit_models/us/datasets.py index 676e08e3..53643cc5 100644 --- a/src/policyengine/tax_benefit_models/us/datasets.py +++ b/src/policyengine/tax_benefit_models/us/datasets.py @@ -3,12 +3,12 @@ import pandas as pd from microdf import MicroDataFrame -from pydantic import BaseModel, ConfigDict +from pydantic import ConfigDict -from policyengine.core import Dataset, map_to_entity +from policyengine.core import Dataset, YearData -class USYearData(BaseModel): +class USYearData(YearData): """Entity-level data for a single year.""" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -20,23 +20,10 
@@ class USYearData(BaseModel): tax_unit: MicroDataFrame household: MicroDataFrame - def map_to_entity( - self, source_entity: str, target_entity: str, columns: list[str] = None - ) -> MicroDataFrame: - """Map data from source entity to target entity using join keys. - - Args: - source_entity (str): The source entity name. - target_entity (str): The target entity name. - columns (list[str], optional): List of column names to map. If None, maps all columns. - - Returns: - MicroDataFrame: The mapped data at the target entity level. - - Raises: - ValueError: If source or target entity is invalid. - """ - entity_data = { + @property + def entity_data(self) -> dict[str, MicroDataFrame]: + """Return a dictionary of entity names to their data.""" + return { "person": self.person, "marital_unit": self.marital_unit, "family": self.family, @@ -44,13 +31,6 @@ def map_to_entity( "tax_unit": self.tax_unit, "household": self.household, } - return map_to_entity( - entity_data=entity_data, - source_entity=source_entity, - target_entity=target_entity, - person_entity="person", - columns=columns, - ) class PolicyEngineUSDataset(Dataset): From c6a6178a0998d4ebbb1a96dbeee0b072f3d7dffc Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 12:02:54 +0000 Subject: [PATCH 2/8] Apply code formatting fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/policyengine/core/dataset.py | 11 ++++++++--- src/policyengine/core/dynamic.py | 5 ++++- src/policyengine/core/policy.py | 5 ++++- src/policyengine/utils/plotting.py | 1 - 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/policyengine/core/dataset.py b/src/policyengine/core/dataset.py index 23c5129d..35b74570 100644 --- a/src/policyengine/core/dataset.py +++ b/src/policyengine/core/dataset.py @@ -19,7 +19,9 @@ def entity_data(self) -> dict[str, MicroDataFrame]: This should be 
implemented by subclasses to return the appropriate entities. """ - raise NotImplementedError("Subclasses must implement entity_data property") + raise NotImplementedError( + "Subclasses must implement entity_data property" + ) @property def person_entity(self) -> str: @@ -252,7 +254,9 @@ def map_to_entity( # Handle divide operation if how == "divide": # Get columns to divide (exclude ID and weight columns) - id_cols = {col for col in result.columns if col.endswith("_id")} + id_cols = { + col for col in result.columns if col.endswith("_id") + } weight_cols = { col for col in result.columns if col.endswith("_weight") } @@ -372,7 +376,8 @@ def map_to_entity( # Divide values by source group count (per-person share) for col in agg_cols: source_with_target[col] = ( - source_with_target[col] / source_with_target["__source_count"] + source_with_target[col] + / source_with_target["__source_count"] ) # Now aggregate (sum of per-person shares) to target level diff --git a/src/policyengine/core/dynamic.py b/src/policyengine/core/dynamic.py index 3b6ba553..81ef62b7 100644 --- a/src/policyengine/core/dynamic.py +++ b/src/policyengine/core/dynamic.py @@ -23,7 +23,10 @@ def __add__(self, other: "Dynamic") -> "Dynamic": # Combine simulation modifiers combined_modifier = None - if self.simulation_modifier is not None and other.simulation_modifier is not None: + if ( + self.simulation_modifier is not None + and other.simulation_modifier is not None + ): def combined_modifier(sim): sim = self.simulation_modifier(sim) diff --git a/src/policyengine/core/policy.py b/src/policyengine/core/policy.py index 3aeb19b9..bfb4ca9e 100644 --- a/src/policyengine/core/policy.py +++ b/src/policyengine/core/policy.py @@ -23,7 +23,10 @@ def __add__(self, other: "Policy") -> "Policy": # Combine simulation modifiers combined_modifier = None - if self.simulation_modifier is not None and other.simulation_modifier is not None: + if ( + self.simulation_modifier is not None + and other.simulation_modifier is 
not None + ): def combined_modifier(sim): sim = self.simulation_modifier(sim) diff --git a/src/policyengine/utils/plotting.py b/src/policyengine/utils/plotting.py index 77aed94f..c3c0ff28 100644 --- a/src/policyengine/utils/plotting.py +++ b/src/policyengine/utils/plotting.py @@ -1,6 +1,5 @@ """Plotting utilities for PolicyEngine visualisations.""" - import plotly.graph_objects as go # PolicyEngine brand colours From a1108dbf44fba538294b95b3a43a246a86ca12cf Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 12:36:06 +0000 Subject: [PATCH 3/8] Fix import sorting order MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/policyengine/core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/policyengine/core/__init__.py b/src/policyengine/core/__init__.py index 18e9648e..630620a0 100644 --- a/src/policyengine/core/__init__.py +++ b/src/policyengine/core/__init__.py @@ -1,6 +1,6 @@ from .dataset import Dataset -from .dataset import map_to_entity as map_to_entity from .dataset import YearData as YearData +from .dataset import map_to_entity as map_to_entity from .dataset_version import DatasetVersion as DatasetVersion from .dynamic import Dynamic as Dynamic from .output import Output as Output From c55198fc9cef4e549886e47c63eff2772899a55e Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 12:51:24 +0000 Subject: [PATCH 4/8] Add Claude-friendly documentation and quick reference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add comprehensive guides for AI assistants to use policyengine.py: - .claude/policyengine-guide.md: Detailed patterns and examples - .claude/quick-reference.md: Quick lookup for common operations Includes: - 7 common workflow patterns (synthetic scenarios, parameter sweeps, reforms) - Minimal working examples for 
UK and US - Entity mapping examples with all aggregation methods - Critical fields reference - Common parameters cheat sheet - Troubleshooting guide These guides help AI assistants quickly understand and use the package for tax-benefit microsimulation analysis. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .claude/policyengine-guide.md | 568 ++++++++++++++++++++++++++++++++++ .claude/quick-reference.md | 367 ++++++++++++++++++++++ 2 files changed, 935 insertions(+) create mode 100644 .claude/policyengine-guide.md create mode 100644 .claude/quick-reference.md diff --git a/.claude/policyengine-guide.md b/.claude/policyengine-guide.md new file mode 100644 index 00000000..922c542c --- /dev/null +++ b/.claude/policyengine-guide.md @@ -0,0 +1,568 @@ +# PolicyEngine.py - Claude Guide + +This guide helps you use the policyengine.py package to perform tax-benefit microsimulation analysis. + +## Core workflow + +1. **Create or load a dataset** with microdata (person, household, etc.) +2. **Run a simulation** applying tax-benefit rules to the dataset +3. **Extract results** using output classes (Aggregate, ChangeAggregate) +4. 
**Visualise** using built-in plotting utilities + +## Package structure + +``` +policyengine +├── core/ +│ ├── Dataset, YearData # Data containers +│ ├── Simulation # Runs tax-benefit calculations +│ ├── Policy, Parameter # Define reforms +│ └── map_to_entity() # Entity mapping utility +├── outputs/ +│ ├── Aggregate # Calculate statistics +│ └── ChangeAggregate # Analyse reforms +├── tax_benefit_models/ +│ ├── uk/ # UK-specific models +│ └── us/ # US-specific models +└── utils/ + └── plotting # Visualisation tools +``` + +## Quick start patterns + +### Pattern 1: Synthetic scenario analysis + +Use when: User wants to analyse specific household scenarios + +```python +import pandas as pd +from microdf import MicroDataFrame +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, + uk_latest +) +from policyengine.core import Simulation + +# Create synthetic person data +person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [0, 1, 2], + "person_household_id": [0, 0, 1], + "person_benunit_id": [0, 0, 1], + "age": [35, 8, 40], + "employment_income": [30000, 0, 50000], + "person_weight": [1.0, 1.0, 1.0], + }), + weights="person_weight" +) + +# Create household data +household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [0, 1], + "region": ["LONDON", "SOUTH_EAST"], + "rent": [15000, 12000], + "household_weight": [1.0, 1.0], + }), + weights="household_weight" +) + +# Create benunit data (UK only) +benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [0, 1], + "would_claim_uc": [True, True], + "benunit_weight": [1.0, 1.0], + }), + weights="benunit_weight" +) + +# Package into dataset +dataset = PolicyEngineUKDataset( + name="Custom scenario", + description="Analysis scenario", + filepath="./custom.h5", + year=2026, + data=UKYearData( + person=person_df, + household=household_df, + benunit=benunit_df, + ) +) + +# Run simulation +simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, +) 
+simulation.run() + +# Access results +output = simulation.output_dataset.data +print(output.household[["household_id", "household_net_income"]]) +``` + +### Pattern 2: US synthetic scenario + +```python +from policyengine.tax_benefit_models.us import ( + PolicyEngineUSDataset, + USYearData, + us_latest +) + +# Create person data (note: US has more entity types) +person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [0, 1, 2, 3], + "person_household_id": [0, 0, 0, 0], + "person_tax_unit_id": [0, 0, 0, 0], + "person_spm_unit_id": [0, 0, 0, 0], + "person_family_id": [0, 0, 0, 0], + "person_marital_unit_id": [0, 0, 1, 2], + "age": [35, 33, 8, 5], + "employment_income": [60000, 40000, 0, 0], + "person_weight": [1.0, 1.0, 1.0, 1.0], + }), + weights="person_weight" +) + +# Create entity dataframes (tax_unit, spm_unit, family, marital_unit, household) +# ... (see examples/employment_income_variation_us.py for full pattern) + +dataset = PolicyEngineUSDataset( + name="US scenario", + year=2024, + filepath="./us_scenario.h5", + data=USYearData( + person=person_df, + tax_unit=tax_unit_df, + spm_unit=spm_unit_df, + family=family_df, + marital_unit=marital_unit_df, + household=household_df, + ) +) +``` + +### Pattern 3: Parameter sweep analysis + +Use when: User wants to vary one parameter across many values + +```python +import numpy as np + +# Create N scenarios with varying parameter +n_scenarios = 43 +income_values = np.linspace(0, 100000, n_scenarios) + +# Create person data with all scenarios +person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": range(n_scenarios), + "person_household_id": range(n_scenarios), + "person_benunit_id": range(n_scenarios), + "age": [35] * n_scenarios, + "employment_income": income_values, + "person_weight": [1.0] * n_scenarios, + }), + weights="person_weight" +) + +# Create matching household/benunit data +household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": range(n_scenarios), + "region": ["LONDON"] * n_scenarios, + 
"rent": [15000] * n_scenarios, + "household_weight": [1.0] * n_scenarios, + }), + weights="household_weight" +) + +# ... create dataset and run simulation once for all scenarios +``` + +### Pattern 4: Policy reform analysis + +Use when: User wants to compare baseline vs reform + +```python +from policyengine.core import Policy, Parameter, ParameterValue +import datetime + +# Define reform +parameter = Parameter( + name="gov.hmrc.income_tax.allowances.personal_allowance.amount", + tax_benefit_model_version=uk_latest, + description="Personal allowance", + data_type=float, +) + +policy = Policy( + name="Increase personal allowance", + description="Raises PA to £15,000", + parameter_values=[ + ParameterValue( + parameter=parameter, + start_date=datetime.date(2026, 1, 1), + end_date=datetime.date(2026, 12, 31), + value=15000, + ) + ], +) + +# Run baseline +baseline_sim = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, +) +baseline_sim.run() + +# Run reform +reform_sim = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, +) +reform_sim.run() + +# Analyse impact +from policyengine.outputs.change_aggregate import ( + ChangeAggregate, + ChangeAggregateType +) + +winners = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="household_net_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=1, # Gain at least £1 +) +winners.run() +print(f"Winners: {winners.result:,.0f}") +``` + +### Pattern 5: Extract aggregates + +Use when: User wants statistics from simulation results + +```python +from policyengine.outputs.aggregate import Aggregate, AggregateType + +# Total spending on a benefit +total_uc = Aggregate( + simulation=simulation, + variable="universal_credit", + entity="benunit", + aggregate_type=AggregateType.SUM, +) +total_uc.run() +print(f"Total UC: £{total_uc.result / 1e9:.1f}bn") + +# Mean income in top decile +top_decile_income = Aggregate( + 
simulation=simulation, + variable="household_net_income", + entity="household", + aggregate_type=AggregateType.MEAN, + filter_variable="household_net_income", + quantile=10, + quantile_eq=10, # 10th decile only +) +top_decile_income.run() +print(f"Top decile mean income: £{top_decile_income.result:,.0f}") + +# Count households below poverty line +poverty_count = Aggregate( + simulation=simulation, + variable="household_id", + entity="household", + aggregate_type=AggregateType.COUNT, + filter_variable="in_absolute_poverty_bhc", + filter_eq=True, +) +poverty_count.run() +print(f"Households in poverty: {poverty_count.result:,.0f}") +``` + +### Pattern 6: Entity mapping + +Use when: User needs to map data between entity levels + +```python +# Map person income to household level (sum) +household_income = dataset.data.map_to_entity( + source_entity="person", + target_entity="household", + columns=["employment_income"], + how="sum" +) + +# Map household rent to person level (broadcast) +person_rent = dataset.data.map_to_entity( + source_entity="household", + target_entity="person", + columns=["rent"], + how="project" +) + +# Split household savings equally per person +person_savings_share = dataset.data.map_to_entity( + source_entity="household", + target_entity="person", + columns=["total_savings"], + how="divide" +) + +# Map custom values +import numpy as np +custom_values = np.array([100, 200, 150]) +household_totals = dataset.data.map_to_entity( + source_entity="person", + target_entity="household", + values=custom_values, + how="sum" +) +``` + +### Pattern 7: Visualisation + +```python +from policyengine.utils.plotting import format_fig, COLORS +import plotly.graph_objects as go + +fig = go.Figure() +fig.add_trace(go.Scatter( + x=income_values, + y=net_income_values, + mode='lines', + name='Net income', + line=dict(color=COLORS["primary"], width=3) +)) + +format_fig( + fig, + title="Net income by employment income", + xaxis_title="Employment income (£)", + 
yaxis_title="Net income (£)", + height=600, + width=1000, +) +fig.show() +``` + +## Entity structures + +### UK entities +``` +household + └── benunit (benefit unit - family claiming benefits together) + └── person +``` + +### US entities +``` +household + ├── tax_unit (federal tax filing unit) + ├── spm_unit (Supplemental Poverty Measure unit) + ├── family (Census definition) + └── marital_unit (married couple or single) + └── person +``` + +## Key concepts + +### 1. MicroDataFrame +All entity data uses `MicroDataFrame` which automatically handles survey weights: +```python +df = MicroDataFrame(pd_dataframe, weights="weight_column_name") +df.sum() # Automatically weighted +``` + +### 2. Entity mapping +When variables are at different entity levels, automatic mapping occurs: +- **Person → Group**: Sum values within each group +- **Group → Person**: Replicate group value to all members + +### 3. Required fields + +**UK person:** +- `person_id`, `person_household_id`, `person_benunit_id` +- `age`, `employment_income` +- `person_weight` + +**UK household:** +- `household_id` +- `region` (e.g., "LONDON", "SOUTH_EAST") +- `rent` (annual) +- `household_weight` + +**UK benunit:** +- `benunit_id` +- `would_claim_uc` (boolean - CRITICAL for UC calculations) +- `benunit_weight` + +**US person:** +- `person_id`, `person_household_id`, `person_tax_unit_id`, `person_spm_unit_id`, `person_family_id`, `person_marital_unit_id` +- `age`, `employment_income` +- `person_weight` + +**US household:** +- `household_id` +- `state_code` (e.g., "CA", "NY") +- `household_weight` + +### 4. 
Common pitfalls + +**Always set would_claim variables:** +```python +"would_claim_uc": [True] * n_benunits # UK +``` + +**Set disability variables to avoid spikes:** +```python +"is_disabled_for_benefits": [False] * n_people +"uc_limited_capability_for_WRA": [False] * n_people +``` + +**Use consistent ID linkages:** +```python +# Person 0 must link to valid household_id and benunit_id +person_df["person_household_id"] = [0, 0, 1] # Persons 0,1 in household 0 +``` + +## Finding parameters + +### UK common parameters +``` +gov.hmrc.income_tax.allowances.personal_allowance.amount +gov.hmrc.national_insurance.class_1.rates.main +gov.dwp.universal_credit.means_test.reduction_rate +gov.dwp.universal_credit.elements.child.first_child +gov.dwp.child_benefit.amount.first_child +``` + +### US common parameters +``` +gov.irs.income.standard_deduction.single +gov.irs.income.standard_deduction.joint +gov.irs.credits.ctc.amount.base +gov.irs.credits.ctc.refundable.amount.max +gov.irs.credits.eitc.max[0] # 0 children +gov.usda.snap.normal_allotment.max[1] # 1 person +``` + +## Aggregation methods for entity mapping + +- `how='sum'`: Aggregate by summing (person → group default) +- `how='first'`: Take first value in group +- `how='project'`: Broadcast group value to members (group → person default) +- `how='divide'`: Split equally among members + +## Response patterns + +When user asks to: + +1. **"Analyse a family with £X income"** → Use Pattern 1 (synthetic scenario) +2. **"How does income vary from £0 to £100k"** → Use Pattern 3 (parameter sweep) +3. **"What if we increased personal allowance?"** → Use Pattern 4 (policy reform) +4. **"How many people benefit?"** → Use Pattern 5 (extract aggregates) +5. **"Compare US vs UK"** → Create both datasets, run separately +6. **"Show me the phase-out"** → Use Pattern 3 + Pattern 7 (sweep + visualise) + +## Debugging tips + +1. **Check dataset shape**: `len(dataset.data.person)` should match expectations +2. 
**Verify linkages**: All person IDs should map to valid household IDs +3. **Check weights**: `dataset.data.household["household_weight"].sum()` +4. **Inspect output columns**: `list(simulation.output_dataset.data.person.columns)` +5. **Test small first**: Use 3-5 scenarios before scaling to 100+ + +## Example responses + +**User: "What's the net income of a single person earning £30k in London?"** + +```python +# I'll create a synthetic dataset with one person earning £30k in London +# and run a UK simulation to calculate their net income. + +import pandas as pd +from microdf import MicroDataFrame +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, UKYearData, uk_latest +) +from policyengine.core import Simulation + +# Create person data +person_df = MicroDataFrame( + pd.DataFrame({ + "person_id": [0], + "person_household_id": [0], + "person_benunit_id": [0], + "age": [30], + "employment_income": [30000], + "person_weight": [1.0], + }), + weights="person_weight" +) + +# Create household data +household_df = MicroDataFrame( + pd.DataFrame({ + "household_id": [0], + "region": ["LONDON"], + "rent": [12000], # Typical London rent + "household_weight": [1.0], + }), + weights="household_weight" +) + +# Create benunit data +benunit_df = MicroDataFrame( + pd.DataFrame({ + "benunit_id": [0], + "would_claim_uc": [True], + "benunit_weight": [1.0], + }), + weights="benunit_weight" +) + +# Create and run simulation +dataset = PolicyEngineUKDataset( + name="Single person £30k", + filepath="./temp_scenario.h5", + year=2026, + data=UKYearData( + person=person_df, + household=household_df, + benunit=benunit_df, + ) +) + +simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, +) +simulation.run() + +# Extract results +output = simulation.output_dataset.data +net_income = output.household["household_net_income"].iloc[0] +income_tax = output.person["income_tax"].iloc[0] +ni = output.person["national_insurance"].iloc[0] + 
+print(f"Employment income: £30,000") +print(f"Income tax: £{income_tax:,.0f}") +print(f"National Insurance: £{ni:,.0f}") +print(f"Net income: £{net_income:,.0f}") +``` + +## Additional resources + +- Full examples in `examples/` directory +- Core concepts: `docs/core-concepts.md` +- UK model: `docs/country-models-uk.md` +- US model: `docs/country-models-us.md` diff --git a/.claude/quick-reference.md b/.claude/quick-reference.md new file mode 100644 index 00000000..d5b5c872 --- /dev/null +++ b/.claude/quick-reference.md @@ -0,0 +1,367 @@ +# PolicyEngine.py Quick Reference + +## Imports cheat sheet + +```python +# Core +from policyengine.core import Simulation, Policy, Parameter, ParameterValue + +# UK +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, + UKYearData, + uk_latest +) + +# US +from policyengine.tax_benefit_models.us import ( + PolicyEngineUSDataset, + USYearData, + us_latest +) + +# Outputs +from policyengine.outputs.aggregate import Aggregate, AggregateType +from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType + +# Utilities +from policyengine.utils.plotting import format_fig, COLORS +from microdf import MicroDataFrame +import pandas as pd +import numpy as np +``` + +## Minimal working example (UK) + +```python +import pandas as pd +from microdf import MicroDataFrame +from policyengine.tax_benefit_models.uk import ( + PolicyEngineUKDataset, UKYearData, uk_latest +) +from policyengine.core import Simulation + +# Person data +person_df = MicroDataFrame(pd.DataFrame({ + "person_id": [0], + "person_household_id": [0], + "person_benunit_id": [0], + "age": [30], + "employment_income": [30000], + "person_weight": [1.0], +}), weights="person_weight") + +# Household data +household_df = MicroDataFrame(pd.DataFrame({ + "household_id": [0], + "region": ["LONDON"], + "rent": [12000], + "household_weight": [1.0], +}), weights="household_weight") + +# Benunit data +benunit_df = MicroDataFrame(pd.DataFrame({ + 
"benunit_id": [0], + "would_claim_uc": [True], + "benunit_weight": [1.0], +}), weights="benunit_weight") + +# Create dataset +dataset = PolicyEngineUKDataset( + name="Example", + filepath="./temp.h5", + year=2026, + data=UKYearData(person=person_df, household=household_df, benunit=benunit_df) +) + +# Run simulation +sim = Simulation(dataset=dataset, tax_benefit_model_version=uk_latest) +sim.run() + +# Get results +output = sim.output_dataset.data +print(output.household[["household_net_income"]]) +``` + +## Minimal working example (US) + +```python +import pandas as pd +from microdf import MicroDataFrame +from policyengine.tax_benefit_models.us import ( + PolicyEngineUSDataset, USYearData, us_latest +) +from policyengine.core import Simulation + +# Person data (US requires more entity links) +person_df = MicroDataFrame(pd.DataFrame({ + "person_id": [0, 1], + "person_household_id": [0, 0], + "person_tax_unit_id": [0, 0], + "person_spm_unit_id": [0, 0], + "person_family_id": [0, 0], + "person_marital_unit_id": [0, 0], + "age": [35, 33], + "employment_income": [60000, 40000], + "person_weight": [1.0, 1.0], +}), weights="person_weight") + +# Create minimal entity dataframes +entities = {} +for entity in ["tax_unit", "spm_unit", "family", "marital_unit"]: + entities[entity] = MicroDataFrame(pd.DataFrame({ + f"{entity}_id": [0], + f"{entity}_weight": [1.0], + }), weights=f"{entity}_weight") + +household_df = MicroDataFrame(pd.DataFrame({ + "household_id": [0], + "state_code": ["CA"], + "household_weight": [1.0], +}), weights="household_weight") + +# Create dataset +dataset = PolicyEngineUSDataset( + name="Example", + filepath="./temp.h5", + year=2024, + data=USYearData( + person=person_df, + tax_unit=entities["tax_unit"], + spm_unit=entities["spm_unit"], + family=entities["family"], + marital_unit=entities["marital_unit"], + household=household_df, + ) +) + +# Run simulation +sim = Simulation(dataset=dataset, tax_benefit_model_version=us_latest) +sim.run() + +# Get 
results +print(sim.output_dataset.data.household[["household_net_income"]]) +``` + +## Common patterns + +### Parameter sweep (vary one input) +```python +n = 50 +incomes = np.linspace(0, 100000, n) + +person_df = MicroDataFrame(pd.DataFrame({ + "person_id": range(n), + "person_household_id": range(n), + "person_benunit_id": range(n), + "age": [30] * n, + "employment_income": incomes, + "person_weight": [1.0] * n, +}), weights="person_weight") + +# Create matching household/benunit data with n rows +# ... then run simulation once for all scenarios +``` + +### Policy reform +```python +import datetime +from policyengine.core import Policy, Parameter, ParameterValue + +parameter = Parameter( + name="gov.hmrc.income_tax.allowances.personal_allowance.amount", + tax_benefit_model_version=uk_latest, + description="Personal allowance", + data_type=float, +) + +policy = Policy( + name="Reform", + description="Change PA", + parameter_values=[ParameterValue( + parameter=parameter, + start_date=datetime.date(2026, 1, 1), + end_date=datetime.date(2026, 12, 31), + value=15000, + )] +) + +# Run with policy +reform_sim = Simulation(dataset=dataset, tax_benefit_model_version=uk_latest, policy=policy) +``` + +### Extract aggregate statistics +```python +from policyengine.outputs.aggregate import Aggregate, AggregateType + +# Sum +total = Aggregate( + simulation=sim, + variable="universal_credit", + entity="benunit", + aggregate_type=AggregateType.SUM, +) +total.run() + +# Mean +avg = Aggregate( + simulation=sim, + variable="household_net_income", + entity="household", + aggregate_type=AggregateType.MEAN, +) +avg.run() + +# Count with filter +count = Aggregate( + simulation=sim, + variable="person_id", + entity="person", + aggregate_type=AggregateType.COUNT, + filter_variable="age", + filter_geq=65, # Age >= 65 +) +count.run() +``` + +### Compare baseline vs reform +```python +from policyengine.outputs.change_aggregate import ChangeAggregate, ChangeAggregateType + +winners = 
ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="household_net_income", + aggregate_type=ChangeAggregateType.COUNT, + change_geq=1, +) +winners.run() + +revenue = ChangeAggregate( + baseline_simulation=baseline_sim, + reform_simulation=reform_sim, + variable="household_tax", + aggregate_type=ChangeAggregateType.SUM, +) +revenue.run() +``` + +### Entity mapping +```python +# Sum person income to household +household_income = dataset.data.map_to_entity( + source_entity="person", + target_entity="household", + columns=["employment_income"], + how="sum" +) + +# Broadcast household rent to persons +person_rent = dataset.data.map_to_entity( + source_entity="household", + target_entity="person", + columns=["rent"], + how="project" +) + +# Divide household value equally per person +per_person = dataset.data.map_to_entity( + source_entity="household", + target_entity="person", + columns=["total_savings"], + how="divide" +) + +# Map custom values +custom_totals = dataset.data.map_to_entity( + source_entity="person", + target_entity="household", + values=custom_array, + how="sum" +) +``` + +## Critical fields + +### UK +- **Person**: `person_id`, `person_household_id`, `person_benunit_id`, `age`, `employment_income`, `person_weight` +- **Household**: `household_id`, `region`, `rent`, `household_weight` +- **Benunit**: `benunit_id`, `would_claim_uc`, `benunit_weight` + +### US +- **Person**: `person_id`, `person_household_id`, `person_tax_unit_id`, `person_spm_unit_id`, `person_family_id`, `person_marital_unit_id`, `age`, `employment_income`, `person_weight` +- **Household**: `household_id`, `state_code`, `household_weight` +- **Other entities**: Each needs `{entity}_id` and `{entity}_weight` + +## Common UK regions +```python +["LONDON", "SOUTH_EAST", "SOUTH_WEST", "EAST_OF_ENGLAND", + "WEST_MIDLANDS", "EAST_MIDLANDS", "YORKSHIRE", + "NORTH_WEST", "NORTH_EAST", "WALES", "SCOTLAND", "NORTHERN_IRELAND"] +``` + +## Common US state 
codes +```python +["CA", "NY", "TX", "FL", "PA", "IL", "OH", "GA", "NC", "MI", ...] +``` + +## Aggregate filter options +```python +# Exact match +filter_eq=value + +# Greater than/equal +filter_geq=value + +# Less than/equal +filter_leq=value + +# Quantile filtering (deciles) +quantile=10 # Split into 10 groups +quantile_eq=1 # First decile only +quantile_geq=9 # Top two deciles +quantile_leq=2 # Bottom two deciles +``` + +## Common parameters + +### UK +``` +gov.hmrc.income_tax.allowances.personal_allowance.amount +gov.hmrc.income_tax.rates.uk[0].rate # Basic rate +gov.hmrc.national_insurance.class_1.rates.main +gov.dwp.universal_credit.means_test.reduction_rate +gov.dwp.universal_credit.elements.child.first_child +gov.dwp.child_benefit.amount.first_child +``` + +### US +``` +gov.irs.income.standard_deduction.single +gov.irs.income.standard_deduction.joint +gov.irs.credits.ctc.amount.base +gov.irs.credits.eitc.max[0] +gov.ssa.payroll.rate.employee +gov.usda.snap.normal_allotment.max[1] +``` + +## Troubleshooting + +| Issue | Solution | +|-------|----------| +| No UC calculated | Set `would_claim_uc=True` | +| Random UC spikes | Set `is_disabled_for_benefits=False`, `uc_limited_capability_for_WRA=False` | +| KeyError on column | Check variable name in docs, may be different entity level | +| Empty results | Check weights sum correctly, verify ID linkages | +| Slow performance | Use parameter sweep pattern (one simulation for N scenarios) | + +## Visualisation template +```python +from policyengine.utils.plotting import format_fig, COLORS +import plotly.graph_objects as go + +fig = go.Figure() +fig.add_trace(go.Scatter(x=x_vals, y=y_vals, line=dict(color=COLORS["primary"]))) +format_fig(fig, title="Title", xaxis_title="X", yaxis_title="Y") +fig.show() +``` From 63f3839fd8e96fbf5a79facaf460df03001bffcf Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 13:00:52 +0000 Subject: [PATCH 5/8] Add get_parameter and get_variable methods to
TaxBenefitModelVersion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add convenience methods to look up parameters and variables by name: - get_parameter(name): Returns Parameter object by name - get_variable(name): Returns Variable object by name - Both raise ValueError if not found with helpful error messages Tests added (12 tests, all passing): - UK and US variable lookup tests - UK and US parameter lookup tests - Error handling tests for non-existent parameters/variables - Multiple parameter/variable lookup tests Usage: var = uk_latest.get_variable('income_tax') param = uk_latest.get_parameter('gov.hmrc.income_tax.allowances.personal_allowance.amount') 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../core/tax_benefit_model_version.py | 38 +++++ tests/test_get_parameter_variable.py | 130 ++++++++++++++++++ 2 files changed, 168 insertions(+) create mode 100644 tests/test_get_parameter_variable.py diff --git a/src/policyengine/core/tax_benefit_model_version.py b/src/policyengine/core/tax_benefit_model_version.py index 8555c6f6..b03c73eb 100644 --- a/src/policyengine/core/tax_benefit_model_version.py +++ b/src/policyengine/core/tax_benefit_model_version.py @@ -29,6 +29,44 @@ def run(self, simulation: "Simulation") -> "Simulation": "The TaxBenefitModel class must define a method to execute simulations." ) + def get_parameter(self, name: str) -> "Parameter": + """Get a parameter by name. + + Args: + name: The parameter name (e.g., "gov.hmrc.income_tax.allowances.personal_allowance.amount") + + Returns: + Parameter: The matching parameter + + Raises: + ValueError: If parameter not found + """ + for param in self.parameters: + if param.name == name: + return param + raise ValueError( + f"Parameter '{name}' not found in {self.model.id} version {self.version}" + ) + + def get_variable(self, name: str) -> "Variable": + """Get a variable by name. 
+ + Args: + name: The variable name (e.g., "income_tax", "household_net_income") + + Returns: + Variable: The matching variable + + Raises: + ValueError: If variable not found + """ + for var in self.variables: + if var.name == name: + return var + raise ValueError( + f"Variable '{name}' not found in {self.model.id} version {self.version}" + ) + def __repr__(self) -> str: # Give the id and version, and the number of variables, parameters, parameter values return f"" diff --git a/tests/test_get_parameter_variable.py b/tests/test_get_parameter_variable.py new file mode 100644 index 00000000..1de2dd0a --- /dev/null +++ b/tests/test_get_parameter_variable.py @@ -0,0 +1,130 @@ +"""Tests for get_parameter and get_variable methods on TaxBenefitModelVersion.""" + +import pytest + +from policyengine.tax_benefit_models.uk import uk_latest +from policyengine.tax_benefit_models.us import us_latest + + +def test_uk_get_variable(): + """Test getting a variable by name from UK model.""" + # Get a known variable + var = uk_latest.get_variable("income_tax") + + assert var is not None + assert var.name == "income_tax" + assert var.entity == "person" + assert var.tax_benefit_model_version == uk_latest + + +def test_uk_get_variable_not_found(): + """Test error handling when variable doesn't exist.""" + with pytest.raises(ValueError, match="Variable 'nonexistent_variable' not found"): + uk_latest.get_variable("nonexistent_variable") + + +def test_uk_get_parameter(): + """Test getting a parameter by name from UK model.""" + # Get a known parameter + param = uk_latest.get_parameter( + "gov.hmrc.income_tax.allowances.personal_allowance.amount" + ) + + assert param is not None + assert param.name == "gov.hmrc.income_tax.allowances.personal_allowance.amount" + assert param.tax_benefit_model_version == uk_latest + + +def test_uk_get_parameter_not_found(): + """Test error handling when parameter doesn't exist.""" + with pytest.raises(ValueError, match="Parameter 'nonexistent.parameter' not 
found"): + uk_latest.get_parameter("nonexistent.parameter") + + +def test_us_get_variable(): + """Test getting a variable by name from US model.""" + # Get a known variable + var = us_latest.get_variable("income_tax") + + assert var is not None + assert var.name == "income_tax" + assert var.entity == "tax_unit" + assert var.tax_benefit_model_version == us_latest + + +def test_us_get_variable_not_found(): + """Test error handling when variable doesn't exist.""" + with pytest.raises(ValueError, match="Variable 'nonexistent_variable' not found"): + us_latest.get_variable("nonexistent_variable") + + +def test_us_get_parameter(): + """Test getting a parameter by name from US model.""" + # Get a known parameter + param = us_latest.get_parameter( + "gov.irs.investment.net_investment_income_tax.rate" + ) + + assert param is not None + assert param.name == "gov.irs.investment.net_investment_income_tax.rate" + assert param.tax_benefit_model_version == us_latest + + +def test_us_get_parameter_not_found(): + """Test error handling when parameter doesn't exist.""" + with pytest.raises(ValueError, match="Parameter 'nonexistent.parameter' not found"): + us_latest.get_parameter("nonexistent.parameter") + + +def test_uk_multiple_variables(): + """Test getting multiple different variables.""" + vars_to_test = [ + "income_tax", + "national_insurance", + "universal_credit", + "household_net_income", + ] + + for var_name in vars_to_test: + var = uk_latest.get_variable(var_name) + assert var.name == var_name + + +def test_us_multiple_variables(): + """Test getting multiple different variables.""" + vars_to_test = [ + "income_tax", + "employee_payroll_tax", + "eitc", + "household_net_income", + ] + + for var_name in vars_to_test: + var = us_latest.get_variable(var_name) + assert var.name == var_name + + +def test_uk_multiple_parameters(): + """Test getting multiple different parameters.""" + params_to_test = [ + "gov.hmrc.income_tax.allowances.personal_allowance.amount", + 
"gov.hmrc.income_tax.rates.uk[0].rate", + "gov.dwp.universal_credit.means_test.reduction_rate", + ] + + for param_name in params_to_test: + param = uk_latest.get_parameter(param_name) + assert param.name == param_name + + +def test_us_multiple_parameters(): + """Test getting multiple different parameters.""" + params_to_test = [ + "gov.irs.investment.net_investment_income_tax.rate", + "gov.irs.self_employment.rate.social_security", + "gov.irs.vita.eligibility.income_limit", + ] + + for param_name in params_to_test: + param = us_latest.get_parameter(param_name) + assert param.name == param_name From f73fa129dd5fd5d4a7fbc9d2d107f1854853f3c6 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 13:01:10 +0000 Subject: [PATCH 6/8] Apply formatting to test_get_parameter_variable.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/test_get_parameter_variable.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/test_get_parameter_variable.py b/tests/test_get_parameter_variable.py index 1de2dd0a..9c2b48d5 100644 --- a/tests/test_get_parameter_variable.py +++ b/tests/test_get_parameter_variable.py @@ -19,7 +19,9 @@ def test_uk_get_variable(): def test_uk_get_variable_not_found(): """Test error handling when variable doesn't exist.""" - with pytest.raises(ValueError, match="Variable 'nonexistent_variable' not found"): + with pytest.raises( + ValueError, match="Variable 'nonexistent_variable' not found" + ): uk_latest.get_variable("nonexistent_variable") @@ -31,13 +33,18 @@ def test_uk_get_parameter(): ) assert param is not None - assert param.name == "gov.hmrc.income_tax.allowances.personal_allowance.amount" + assert ( + param.name + == "gov.hmrc.income_tax.allowances.personal_allowance.amount" + ) assert param.tax_benefit_model_version == uk_latest def 
test_uk_get_parameter_not_found(): """Test error handling when parameter doesn't exist.""" - with pytest.raises(ValueError, match="Parameter 'nonexistent.parameter' not found"): + with pytest.raises( + ValueError, match="Parameter 'nonexistent.parameter' not found" + ): uk_latest.get_parameter("nonexistent.parameter") @@ -54,7 +61,9 @@ def test_us_get_variable(): def test_us_get_variable_not_found(): """Test error handling when variable doesn't exist.""" - with pytest.raises(ValueError, match="Variable 'nonexistent_variable' not found"): + with pytest.raises( + ValueError, match="Variable 'nonexistent_variable' not found" + ): us_latest.get_variable("nonexistent_variable") @@ -72,7 +81,9 @@ def test_us_get_parameter(): def test_us_get_parameter_not_found(): """Test error handling when parameter doesn't exist.""" - with pytest.raises(ValueError, match="Parameter 'nonexistent.parameter' not found"): + with pytest.raises( + ValueError, match="Parameter 'nonexistent.parameter' not found" + ): us_latest.get_parameter("nonexistent.parameter") From 91baf988dcc1e11e1404caaf42ad0f566a20e909 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 17 Nov 2025 17:26:59 +0000 Subject: [PATCH 7/8] Add parameter labels --- src/policyengine/core/parameter.py | 1 + src/policyengine/tax_benefit_models/uk/model.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/policyengine/core/parameter.py b/src/policyengine/core/parameter.py index 54e3e116..79c4f175 100644 --- a/src/policyengine/core/parameter.py +++ b/src/policyengine/core/parameter.py @@ -8,6 +8,7 @@ class Parameter(BaseModel): id: str = Field(default_factory=lambda: str(uuid4())) name: str + label: str | None = None description: str | None = None data_type: type | None = None tax_benefit_model_version: TaxBenefitModelVersion diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py index 18f1ef25..ee0aa08a 100644 --- a/src/policyengine/tax_benefit_models/uk/model.py 
+++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -83,6 +83,7 @@ def __init__(self, **kwargs: dict): parameter = Parameter( id=self.id + "-" + param_node.name, name=param_node.name, + label=param_node.metadata.get("label", param_node.name), tax_benefit_model_version=self, description=param_node.description, data_type=type( From ceaee0ab998f265ee1e3a3e53398b5cac2cf4772 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Tue, 18 Nov 2025 11:06:29 +0000 Subject: [PATCH 8/8] Final updates --- changelog_entry.yaml | 4 + examples/employment_income_variation_uk.py | 32 ---- examples/employment_income_variation_us.py | 46 ------ src/policyengine/core/simulation.py | 13 +- .../core/tax_benefit_model_version.py | 10 ++ .../tax_benefit_models/uk/model.py | 154 ++++++++++-------- .../tax_benefit_models/us/model.py | 126 +++++++------- tests/test_us_simulation.py | 13 -- 8 files changed, 174 insertions(+), 224 deletions(-) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..1311929d 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - Standardised saving and loading of simulations. 
diff --git a/examples/employment_income_variation_uk.py b/examples/employment_income_variation_uk.py index 173c78ff..22bcb93c 100644 --- a/examples/employment_income_variation_uk.py +++ b/examples/employment_income_variation_uk.py @@ -162,42 +162,10 @@ def create_dataset_with_varied_employment_income( def run_simulation(dataset: PolicyEngineUKDataset) -> Simulation: """Run a single simulation for all employment income variations.""" - # Specify additional variables to calculate beyond defaults - variables = { - "household": [ - # Default variables - "household_id", - "household_weight", - "household_net_income", - "hbai_household_net_income", - "household_benefits", - "household_tax", - ], - "person": [ - "person_id", - "benunit_id", - "household_id", - "person_weight", - "employment_income", - "age", - ], - "benunit": [ - "benunit_id", - "benunit_weight", - # Individual benefits (at benunit level) - "universal_credit", - "child_benefit", - "working_tax_credit", - "child_tax_credit", - "pension_credit", - "income_support", - ], - } simulation = Simulation( dataset=dataset, tax_benefit_model_version=uk_latest, - variables=variables, ) simulation.run() return simulation diff --git a/examples/employment_income_variation_us.py b/examples/employment_income_variation_us.py index 863d8018..f4ceb80e 100644 --- a/examples/employment_income_variation_us.py +++ b/examples/employment_income_variation_us.py @@ -171,56 +171,10 @@ def create_dataset_with_varied_employment_income( def run_simulation(dataset: PolicyEngineUSDataset) -> Simulation: """Run a single simulation for all employment income variations.""" - # Specify variables to calculate - variables = { - "household": [ - "household_id", - "household_weight", - "household_net_income", - "household_benefits", - "household_tax", - "household_market_income", - ], - "person": [ - "person_id", - "household_id", - "marital_unit_id", - "family_id", - "spm_unit_id", - "tax_unit_id", - "person_weight", - "employment_income", - 
"age", - ], - "spm_unit": [ - "spm_unit_id", - "spm_unit_weight", - "snap", - "tanf", - "spm_unit_net_income", - ], - "tax_unit": [ - "tax_unit_id", - "tax_unit_weight", - "income_tax", - "employee_payroll_tax", - "eitc", - "ctc", - ], - "marital_unit": [ - "marital_unit_id", - "marital_unit_weight", - ], - "family": [ - "family_id", - "family_weight", - ], - } simulation = Simulation( dataset=dataset, tax_benefit_model_version=us_latest, - variables=variables, ) simulation.run() return simulation diff --git a/src/policyengine/core/simulation.py b/src/policyengine/core/simulation.py index 1e493b9a..f7c214e4 100644 --- a/src/policyengine/core/simulation.py +++ b/src/policyengine/core/simulation.py @@ -21,10 +21,13 @@ class Simulation(BaseModel): tax_benefit_model_version: TaxBenefitModelVersion = None output_dataset: Dataset | None = None - variables: dict[str, list[str]] | None = Field( - default=None, - description="Optional dictionary mapping entity names to lists of variable names to calculate. If None, uses model defaults.", - ) - def run(self): self.tax_benefit_model_version.run(self) + + def save(self): + """Save the simulation's output dataset.""" + self.tax_benefit_model_version.save(self) + + def load(self): + """Load the simulation's output dataset.""" + self.tax_benefit_model_version.load(self) diff --git a/src/policyengine/core/tax_benefit_model_version.py b/src/policyengine/core/tax_benefit_model_version.py index b03c73eb..53b9936e 100644 --- a/src/policyengine/core/tax_benefit_model_version.py +++ b/src/policyengine/core/tax_benefit_model_version.py @@ -29,6 +29,16 @@ def run(self, simulation: "Simulation") -> "Simulation": "The TaxBenefitModel class must define a method to execute simulations." ) + def save(self, simulation: "Simulation"): + raise NotImplementedError( + "The TaxBenefitModel class must define a method to save simulations." 
+ ) + + def load(self, simulation: "Simulation"): + raise NotImplementedError( + "The TaxBenefitModel class must define a method to load simulations." + ) + def get_parameter(self, name: str) -> "Parameter": """Get a parameter by name. diff --git a/src/policyengine/tax_benefit_models/uk/model.py b/src/policyengine/tax_benefit_models/uk/model.py index ee0aa08a..6b8c5c76 100644 --- a/src/policyengine/tax_benefit_models/uk/model.py +++ b/src/policyengine/tax_benefit_models/uk/model.py @@ -153,77 +153,72 @@ def run(self, simulation: "Simulation") -> "Simulation": ) modifier(microsim) - # Allow custom variable selection, or use defaults - if simulation.variables is not None: - entity_variables = simulation.variables - else: - # Default comprehensive variable set - entity_variables = { - "person": [ - # IDs and weights - "person_id", - "benunit_id", - "household_id", - "person_weight", - # Demographics - "age", - "gender", - "is_adult", - "is_SP_age", - "is_child", - # Income - "employment_income", - "self_employment_income", - "pension_income", - "private_pension_income", - "savings_interest_income", - "dividend_income", - "property_income", - "total_income", - "earned_income", - # Benefits - "universal_credit", - "child_benefit", - "pension_credit", - "income_support", - "working_tax_credit", - "child_tax_credit", - # Tax - "income_tax", - "national_insurance", - ], - "benunit": [ - # IDs and weights - "benunit_id", - "benunit_weight", - # Structure - "family_type", - # Income and benefits - "universal_credit", - "child_benefit", - "working_tax_credit", - "child_tax_credit", - ], - "household": [ - # IDs and weights - "household_id", - "household_weight", - # Income measures - "household_net_income", - "hbai_household_net_income", - "equiv_hbai_household_net_income", - "household_market_income", - "household_gross_income", - # Benefits and tax - "household_benefits", - "household_tax", - "vat", - # Housing - "rent", - "council_tax", - "tenure_type", - ], - } + 
entity_variables = { + "person": [ + # IDs and weights + "person_id", + "benunit_id", + "household_id", + "person_weight", + # Demographics + "age", + "gender", + "is_adult", + "is_SP_age", + "is_child", + # Income + "employment_income", + "self_employment_income", + "pension_income", + "private_pension_income", + "savings_interest_income", + "dividend_income", + "property_income", + "total_income", + "earned_income", + # Benefits + "universal_credit", + "child_benefit", + "pension_credit", + "income_support", + "working_tax_credit", + "child_tax_credit", + # Tax + "income_tax", + "national_insurance", + ], + "benunit": [ + # IDs and weights + "benunit_id", + "benunit_weight", + # Structure + "family_type", + # Income and benefits + "universal_credit", + "child_benefit", + "working_tax_credit", + "child_tax_credit", + ], + "household": [ + # IDs and weights + "household_id", + "household_weight", + # Income measures + "household_net_income", + "hbai_household_net_income", + "equiv_hbai_household_net_income", + "household_market_income", + "household_gross_income", + # Benefits and tax + "household_benefits", + "household_tax", + "vat", + # Housing + "rent", + "council_tax", + "tenure_type", + ], + } data = { "person": pd.DataFrame(), @@ -248,6 +243,7 @@ def run(self, simulation: "Simulation") -> "Simulation": ) simulation.output_dataset = PolicyEngineUKDataset( + id=simulation.id, name=dataset.name, description=dataset.description, filepath=str( @@ -263,7 +259,23 @@ def run(self, simulation: "Simulation") -> "Simulation": ), ) + def save(self, simulation: "Simulation"): + """Save the simulation's output dataset.""" simulation.output_dataset.save() + def load(self, simulation: "Simulation"): + """Load the simulation's output dataset.""" + simulation.output_dataset = PolicyEngineUKDataset( + id=simulation.id, + name=simulation.dataset.name, + description=simulation.dataset.description, + filepath=str( + Path(simulation.dataset.filepath).parent + / (simulation.id + 
".h5") + ), + year=simulation.dataset.year, + is_output_dataset=True, + ) + uk_latest = PolicyEngineUKLatest() diff --git a/src/policyengine/tax_benefit_models/us/model.py b/src/policyengine/tax_benefit_models/us/model.py index 5e2068c5..a5a267a2 100644 --- a/src/policyengine/tax_benefit_models/us/model.py +++ b/src/policyengine/tax_benefit_models/us/model.py @@ -156,63 +156,58 @@ def run(self, simulation: "Simulation") -> "Simulation": ) modifier(microsim) - # Allow custom variable selection, or use defaults - if simulation.variables is not None: - entity_variables = simulation.variables - else: - # Default comprehensive variable set - entity_variables = { - "person": [ - # IDs and weights - "person_id", - "marital_unit_id", - "family_id", - "spm_unit_id", - "tax_unit_id", - "household_id", - "person_weight", - # Demographics - "age", - # Income - "employment_income", - # Benefits - "ssi", - "social_security", - "medicaid", - "unemployment_compensation", - ], - "marital_unit": [ - "marital_unit_id", - "marital_unit_weight", - ], - "family": [ - "family_id", - "family_weight", - ], - "spm_unit": [ - "spm_unit_id", - "spm_unit_weight", - "snap", - "tanf", - "spm_unit_net_income", - ], - "tax_unit": [ - "tax_unit_id", - "tax_unit_weight", - "income_tax", - "employee_payroll_tax", - "eitc", - "ctc", - ], - "household": [ - "household_id", - "household_weight", - "household_net_income", - "household_benefits", - "household_tax", - "household_market_income", - ], - } + entity_variables = { + "person": [ + # IDs and weights + "person_id", + "marital_unit_id", + "family_id", + "spm_unit_id", + "tax_unit_id", + "household_id", + "person_weight", + # Demographics + "age", + # Income + "employment_income", + # Benefits + "ssi", + "social_security", + "medicaid", + "unemployment_compensation", + ], + "marital_unit": [ + "marital_unit_id", + "marital_unit_weight", + ], + "family": [ + "family_id", + "family_weight", + ], + "spm_unit": [ + "spm_unit_id", + "spm_unit_weight", + 
"snap", + "tanf", + "spm_unit_net_income", + ], + "tax_unit": [ + "tax_unit_id", + "tax_unit_weight", + "income_tax", + "employee_payroll_tax", + "eitc", + "ctc", + ], + "household": [ + "household_id", + "household_weight", + "household_net_income", + "household_benefits", + "household_tax", + "household_market_income", + ], + } data = { "person": pd.DataFrame(), @@ -291,6 +286,7 @@ def run(self, simulation: "Simulation") -> "Simulation": ) simulation.output_dataset = PolicyEngineUSDataset( + id=simulation.id, name=dataset.name, description=dataset.description, filepath=str( @@ -309,8 +305,24 @@ def run(self, simulation: "Simulation") -> "Simulation": ), ) + def save(self, simulation: "Simulation"): + """Save the simulation's output dataset.""" simulation.output_dataset.save() + def load(self, simulation: "Simulation"): + """Load the simulation's output dataset.""" + simulation.output_dataset = PolicyEngineUSDataset( + id=simulation.id, + name=simulation.dataset.name, + description=simulation.dataset.description, + filepath=str( + Path(simulation.dataset.filepath).parent + / (simulation.id + ".h5") + ), + year=simulation.dataset.year, + is_output_dataset=True, + ) + def _build_simulation_from_dataset(self, microsim, dataset, system): """Build a PolicyEngine Core simulation from dataset entity IDs. 
diff --git a/tests/test_us_simulation.py b/tests/test_us_simulation.py index 4de79691..aad5f9bb 100644 --- a/tests/test_us_simulation.py +++ b/tests/test_us_simulation.py @@ -227,19 +227,6 @@ def test_us_simulation_from_dataset(): simulation = Simulation( dataset=dataset, tax_benefit_model_version=us_latest, - variables={ - "person": [ - "person_id", - "person_weight", - "age", - "employment_income", - ], - "household": ["household_id", "household_weight"], - "marital_unit": ["marital_unit_id", "marital_unit_weight"], - "family": ["family_id", "family_weight"], - "spm_unit": ["spm_unit_id", "spm_unit_weight"], - "tax_unit": ["tax_unit_id", "tax_unit_weight"], - }, ) simulation.run()