# Proving the Bug in policyengine_uk's build_from_dataframe Method

This notebook demonstrates that the UK country-filtering bug is caused by `build_from_dataframe` in `policyengine_uk`,
which passes person-level columns to `set_input()` without aggregating them to the variable's entity level.

## The Bug Location
**File:** `policyengine_uk/simulation.py`  
**Method:** `build_from_dataframe()`  
**Lines:** 281-286

```python
# Set input values for each variable and time period
for column in df:
    variable, time_period = column.split("__")
    if variable not in self.tax_benefit_system.variables:
        continue
    self.set_input(variable, time_period, df[column])  # <-- BUG: No entity-level check!
```
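
The loop relies on every column name having the form `variable__period`. As a minimal sketch of that parsing step, the hypothetical helper `parse_column` below (not part of `policyengine_uk`) uses `str.partition`, so a name without the separator is reported as `None` instead of failing during tuple unpacking.

```python
from typing import Optional, Tuple


def parse_column(column: str) -> Optional[Tuple[str, str]]:
    """Split 'variable__period' into (variable, period); None if malformed."""
    variable, separator, period = column.partition("__")
    if not separator or not variable or not period:
        return None
    return variable, period


print(parse_column("corporate_wealth__2025"))  # ('corporate_wealth', '2025')
print(parse_column("no_separator"))            # None
```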

## The Problem
1. `to_input_dataframe()` exports ALL variables at **person level** (one row per person).
2. `build_from_dataframe()` correctly builds the entity structure with the proper counts.
3. It then calls `set_input()` with person-level arrays, even for household-level variables.
4. `set_input()` raises a `ValueError` because the array has one value per person rather than one per household.
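
The mismatch described in the list above can be reproduced on a toy table without loading any PolicyEngine data. The sketch below assumes a made-up export with five people in two households; the column names follow the `variable__period` pattern, and all values are invented for illustration.

```python
import pandas as pd

# Toy person-level export: 5 people in 2 households (all values made up).
df = pd.DataFrame(
    {
        "person_id__2025": [1, 2, 3, 4, 5],
        "person_household_id__2025": [10, 10, 10, 20, 20],
        # Household-level variable, repeated once per household member.
        "corporate_wealth__2025": [500.0, 500.0, 500.0, 80.0, 80.0],
    }
)

n_persons = len(df)
n_households = df["person_household_id__2025"].nunique()

# The household variable has one value per person, not one per household.
print(n_persons, n_households)  # 5 2
```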

## Step 1: Setup

In [1]:
import inspect

import numpy as np
import pandas as pd

import policyengine_uk
from policyengine import Simulation
from policyengine_uk import Simulation as UKSimulation

# Show which policyengine_uk installation is in use
version = getattr(policyengine_uk, "__version__", "unknown")
print(f"policyengine_uk version: {version}")
print(f"policyengine_uk location: {policyengine_uk.__file__}")

policyengine_uk version: unknown
policyengine_uk location: /opt/miniconda3/envs/py-3.13/lib/python3.13/site-packages/policyengine_uk/__init__.py


## Step 2: Examine the Buggy Code

Let's look at the actual `build_from_dataframe` method to confirm the bug.

In [2]:
# Show the source code of build_from_dataframe
print("=== build_from_dataframe source code ===")
print(inspect.getsource(UKSimulation.build_from_dataframe))

=== build_from_dataframe source code ===
    def build_from_dataframe(self, df: pd.DataFrame) -> None:
        """Build simulation from a pandas DataFrame.

        Args:
            df: DataFrame with columns in format "variable_name__time_period"
        """

        def get_first_array(variable_name: str) -> pd.Series:
            """Extract the first array for a given variable name pattern."""
            columns = df.columns[df.columns.str.contains(variable_name + "__")]
            return df[columns[0]]

        # Extract ID columns
        (
            person_id,
            person_benunit_id,
            person_household_id,
            benunit_id,
            household_id,
        ) = map(
            get_first_array,
            [
                "person_id",
                "person_benunit_id",
                "person_household_id",
                "benunit_id",
                "household_id",
            ],
        )

        # Build entity structure
        self.build_fro

## Step 3: Create Test Data

Create a UK-wide simulation, export its inputs to a DataFrame, and keep only the rows for people in Wales.

In [3]:
# Create UK-wide simulation
print("Creating UK-wide simulation...")
sim_uk = Simulation(country="uk", scope="macro")
underlying_sim = sim_uk.baseline_simulation

print("\nUK-wide entity counts:")
print(f"  Persons: {underlying_sim.persons.count:,}")
print(f"  Households: {underlying_sim.household.count:,}")

Creating UK-wide simulation...


No data provided, using default dataset: gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5
Using dataset: gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5
Downloading enhanced_frs_2023_24.h5 from bucket policyengine-uk-data-private



UK-wide entity counts:
  Persons: 115,612
  Households: 53,508


In [4]:
# Export to DataFrame and filter to Wales
print("Exporting to DataFrame...")
df = underlying_sim.to_input_dataframe()

# Filter to Wales
country_person = underlying_sim.calculate("country", map_to="person").values
wales_mask = country_person == "WALES"
df_wales = df[wales_mask]

print("\nFiltered DataFrame:")
print(f"  Rows (Welsh persons): {len(df_wales):,}")
print(f"  Columns: {len(df_wales.columns):,}")

Exporting to DataFrame...

Filtered DataFrame:
  Rows (Welsh persons): 8,470
  Columns: 1,127
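
The Wales filter works because `country` is mapped to person level before building the mask, so every member of a household shares the same value and households are kept or dropped as a whole. The following is a rough numpy sketch of that idea on made-up data; it is not how `calculate(..., map_to="person")` is implemented, and the lookup via `np.searchsorted` assumes the household IDs are sorted.

```python
import numpy as np
import pandas as pd

# Toy household table (sorted IDs) and the household each person belongs to.
household_ids = np.array([10, 20, 30])
household_country = np.array(["WALES", "ENGLAND", "WALES"])
person_household_id = np.array([10, 10, 20, 30, 30, 30])

# Map each household's value onto its members via the household's position.
position = np.searchsorted(household_ids, person_household_id)
country_person = household_country[position]

# Person-level boolean mask, applied to the person-level DataFrame.
wales_mask = country_person == "WALES"
df = pd.DataFrame({"person_household_id__2025": person_household_id})
df_wales = df[wales_mask]

print(len(df_wales))                                  # 5
print(df_wales["person_household_id__2025"].nunique())  # 2
```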


## Step 4: Prove the DataFrame Has Person-Level Data for Household Variables

Key point: `to_input_dataframe()` exports every variable at person level, including variables that are defined on households.

In [5]:
# Find household-level variables in the DataFrame
print("=== Household-Level Variables in DataFrame ===")

tax_benefit_system = underlying_sim.tax_benefit_system
household_vars_in_df = []

for col in df_wales.columns:
    var_name = col.split("__")[0]
    if var_name in tax_benefit_system.variables:
        var_meta = tax_benefit_system.get_variable(var_name)
        if var_meta.entity.key == "household":
            household_vars_in_df.append((col, var_name))

print(f"Found {len(household_vars_in_df)} household-level variable columns in DataFrame")
print("\nFirst 10 household variables:")
for col, var_name in household_vars_in_df[:10]:
    print(f"  - {col}")

=== Household-Level Variables in DataFrame ===
Found 392 household-level variable columns in DataFrame

First 10 household variables:
  - corporate_wealth__2023
  - corporate_wealth__2024
  - corporate_wealth__2025
  - corporate_wealth__2026
  - corporate_wealth__2027
  - corporate_wealth__2028
  - corporate_wealth__2029
  - corporate_wealth__2030
  - non_residential_property_value__2023
  - non_residential_property_value__2024


In [6]:
# Show the mismatch: DataFrame rows vs expected household count
print("=== THE CRITICAL MISMATCH ===")
print()

# Get expected Welsh household count from person_household_id
phh_col = [c for c in df_wales.columns if c.startswith('person_household_id__')][0]
welsh_household_count = df_wales[phh_col].nunique()

print(f"DataFrame rows (person-level): {len(df_wales):,}")
print(f"Expected Welsh households: {welsh_household_count:,}")
print()

# Show a specific household variable
example_var = "corporate_wealth__2025" if "corporate_wealth__2025" in df_wales.columns else household_vars_in_df[0][0]
print(f"Example: '{example_var}'")
print(f"  Data length in DataFrame: {len(df_wales[example_var]):,}")
print(f"  Should be (household count): {welsh_household_count:,}")
print()
print(f"  MISMATCH: {len(df_wales[example_var]):,} != {welsh_household_count:,}")
print()
print("This is why set_input() fails!")

=== THE CRITICAL MISMATCH ===

DataFrame rows (person-level): 8,470
Expected Welsh households: 4,108

Example: 'corporate_wealth__2025'
  Data length in DataFrame: 8,470
  Should be (household count): 4,108

  MISMATCH: 8,470 != 4,108

This is why set_input() fails!
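
Collapsing a person-level column to one value per household only loses nothing if the value is identical for every member of a household. The notebook does not run this check itself; the sketch below shows one way to test the assumption with a pandas `groupby`, using made-up data.

```python
import pandas as pd

# Toy person-level export with a household variable repeated per member.
df = pd.DataFrame(
    {
        "person_household_id__2025": [10, 10, 10, 20, 20],
        "corporate_wealth__2025": [500.0, 500.0, 500.0, 80.0, 80.0],
    }
)

# Count distinct values of the variable within each household.
distinct_per_household = df.groupby("person_household_id__2025")[
    "corporate_wealth__2025"
].nunique()

# True when every household carries exactly one distinct value.
is_constant = bool((distinct_per_household == 1).all())
print(is_constant)  # True
```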


## Step 5: Trace Through build_from_dataframe Step-by-Step

Let's manually execute what `build_from_dataframe` does to see exactly where it fails.

In [7]:
# Step 5a: Extract ID columns (lines 249-270 of build_from_dataframe)
print("=== Step 5a: Extract ID columns ===")

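def get_first_array(df, variable_name):
    # Mirrors the library helper. Note that str.contains() is a substring match,
    # so "household_id__" also matches "person_household_id__" columns.
    columns = df.columns[df.columns.str.contains(variable_name + "__")]
    return df[columns[0]]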

person_id = get_first_array(df_wales, "person_id")
person_benunit_id = get_first_array(df_wales, "person_benunit_id")
person_household_id = get_first_array(df_wales, "person_household_id")
benunit_id = get_first_array(df_wales, "benunit_id")
household_id = get_first_array(df_wales, "household_id")

print(f"person_id length: {len(person_id)}")
print(f"person_household_id length: {len(person_household_id)}")
print(f"person_household_id unique values: {person_household_id.nunique()}")
print(f"household_id length: {len(household_id)}")
print(f"household_id unique values: {household_id.nunique()}")

=== Step 5a: Extract ID columns ===
person_id length: 8470
person_household_id length: 8470
person_household_id unique values: 4108
household_id length: 8470
household_id unique values: 4108
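
One detail of `get_first_array` is worth illustrating: it selects columns by substring, so the pattern for `household_id` also matches `person_household_id` columns. The toy example below (made-up column list) contrasts `str.contains` with `str.startswith`. Whether this changes anything for the real export depends on column order, which the output above cannot distinguish because both columns have 4,108 unique values.

```python
import pandas as pd

# Made-up column index in the "variable__period" format.
columns = pd.Index(
    [
        "person_id__2025",
        "person_household_id__2025",
        "household_id__2025",
    ]
)

# Substring match: also picks up the person-to-household link column.
substring_hits = list(columns[columns.str.contains("household_id__")])
# Prefix match: only the household ID column itself.
prefix_hits = list(columns[columns.str.startswith("household_id__")])

print(substring_hits)  # ['person_household_id__2025', 'household_id__2025']
print(prefix_hits)     # ['household_id__2025']
```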


In [8]:
# Step 5b: Build entity structure (lines 273-279 - build_from_ids)
print("\n=== Step 5b: Build entity structure (build_from_ids) ===")

from policyengine_core.simulations.simulation_builder import SimulationBuilder
from policyengine_uk.tax_benefit_system import CountryTaxBenefitSystem

# Create a fresh simulation to test
test_tbs = CountryTaxBenefitSystem()
builder = SimulationBuilder()
builder.populations = test_tbs.instantiate_entities()

# Declare entities - this is what build_from_ids does
builder.declare_person_entity("person", person_id.values)
builder.declare_entity("benunit", np.unique(benunit_id.values))
builder.declare_entity("household", np.unique(household_id.values))

print(f"Person entity count: {len(builder.populations['person'].ids)}")
print(f"Benunit entity count: {len(builder.populations['benunit'].ids)}")
print(f"Household entity count: {len(builder.populations['household'].ids)}")
print()
print(f"Entity structure is CORRECT! {len(builder.populations['household'].ids)} households were created.")


=== Step 5b: Build entity structure (build_from_ids) ===
Person entity count: 8470
Benunit entity count: 4664
Household entity count: 4108

Entity structure is CORRECT! 4108 households were created.
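
Step 5b declares each group entity from `np.unique(...)` of a person-level ID column, which is why the household count comes out right even though the column has one row per person. A small numpy illustration on made-up IDs:

```python
import numpy as np

# Person-level household IDs, deliberately unsorted.
person_household_id = np.array([20, 10, 10, 20, 10])

# np.unique returns each ID once, in sorted order, plus the group sizes.
household_ids, household_sizes = np.unique(
    person_household_id, return_counts=True
)

print(household_ids.tolist())    # [10, 20]
print(household_sizes.tolist())  # [3, 2]
```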


In [9]:
# Step 5c: Complete entity setup with joins
print("\n=== Step 5c: Complete entity setup ===")

builder.join_with_persons(
    builder.populations["benunit"],
    person_benunit_id.values,
    np.array(["member"] * len(person_benunit_id)),
)
builder.join_with_persons(
    builder.populations["household"],
    person_household_id.values,
    np.array(["member"] * len(person_household_id)),
)

# Create simulation with these populations
from policyengine_core.simulations import Simulation as CoreSimulation
from policyengine_core.tracers import SimpleTracer

class TestSimulation(CoreSimulation):
    default_input_period = 2025
    default_calculation_period = 2025

test_sim = TestSimulation.__new__(TestSimulation)
test_sim.tax_benefit_system = test_tbs
test_sim.branch_name = "default"
test_sim.invalidated_caches = set()
test_sim.branches = {}

# Initialize required attributes that build_from_populations expects
test_sim.debug = False
test_sim.trace = False
test_sim.tracer = SimpleTracer()
test_sim.opt_out_cache = False
test_sim.max_spiral_loops = 10
test_sim.memory_config = None
test_sim._data_storage_dir = None

test_sim.build_from_populations(builder.populations)

print("Test simulation created:")
print(f"  Persons: {test_sim.persons.count}")
print(f"  Households: {test_sim.household.count}")
print()
print("Entity counts are CORRECT at this point!")


=== Step 5c: Complete entity setup ===
Test simulation created:
  Persons: 8470
  Households: 4108

Entity counts are CORRECT at this point!


In [10]:
# Step 5d: THE BUG - Try to set_input for a household variable with person-level data
print("\n=== Step 5d: THE BUG - set_input without aggregation ===")
print()

# This is what build_from_dataframe does at lines 281-286:
# for column in df:
#     variable, time_period = column.split("__")
#     if variable not in self.tax_benefit_system.variables:
#         continue
#     self.set_input(variable, time_period, df[column])  # <-- BUG!

# Let's simulate this for a household variable
test_column = example_var
variable_name, time_period = test_column.split("__")

print(f"Attempting to set '{variable_name}' for period {time_period}")
print(f"  Variable entity: {test_tbs.get_variable(variable_name).entity.key}")
print(f"  Data length: {len(df_wales[test_column])}")
print(f"  Household count: {test_sim.household.count}")
print()

try:
    test_sim.set_input(variable_name, time_period, df_wales[test_column].values)
    print("SUCCESS - No error (unexpected!)")
except ValueError as e:
    print(f"ERROR (expected): {e}")
    print()
    print("="*60)
    print("BUG PROVEN!")
    print("="*60)
    print()
    print("The build_from_dataframe method calls set_input() with")
    print(f"person-level data ({len(df_wales[test_column])} values) for a household-level")
    print(f"variable, but there are only {test_sim.household.count} households.")


=== Step 5d: THE BUG - set_input without aggregation ===

Attempting to set 'corporate_wealth' for period 2025
  Variable entity: household
  Data length: 8470
  Household count: 4108

ERROR (expected): Unable to set value "[ 42531.723   42531.723   42531.723  ... 145237.94   145237.94
   6483.3296]" for variable "corporate_wealth", as its length is 8470 while there are 4108 households in the simulation.

BUG PROVEN!

The build_from_dataframe method calls set_input() with
person-level data (8470 values) for a household-level
variable, but there are only 4108 households.
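
The error above comes from a length guard: the input array must hold exactly one value per entity. The sketch below imitates that kind of guard with a hypothetical function, `check_input_length`. It is an illustration only, not the `policyengine_core` implementation, and its message is shortened.

```python
import numpy as np


def check_input_length(values, entity_count, entity_label):
    """Raise ValueError when `values` does not hold one entry per entity."""
    if len(values) != entity_count:
        raise ValueError(
            f"length is {len(values)} while there are "
            f"{entity_count} {entity_label} in the simulation"
        )


# Five person-level values offered for a variable with two households.
person_level = np.array([500.0, 500.0, 500.0, 80.0, 80.0])

try:
    check_input_length(person_level, entity_count=2, entity_label="households")
    message = None
except ValueError as error:
    message = str(error)

print(message)
```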


## Step 6: Show What the Fix Should Look Like

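The fix is to detect, before calling `set_input()`, that a column holds one value per person while the variable belongs to a group entity, and to aggregate it to one value per entity first. Taking the first member's value should be sufficient here because the export appears to repeat each household's value for every member, as the repeated values in the Step 5d error message suggest.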

In [11]:
# Demonstrate the correct approach: aggregate before set_input
print("=== The Fix: Aggregate Before set_input ===")
print()

variable_name, time_period = example_var.split("__")
var_meta = test_tbs.get_variable(variable_name)
entity = var_meta.entity
population = test_sim.get_population(entity.plural)

data = df_wales[example_var].values

print(f"Variable: {variable_name}")
print(f"Entity: {entity.key}")
print(f"Data length: {len(data)}")
print(f"Population count: {population.count}")
print()

# Check if aggregation is needed
if len(data) != population.count:
    print(f"Aggregation needed: {len(data)} != {population.count}")
    print()
    
    # Use value_from_first_person to aggregate
    aggregated_data = population.value_from_first_person(data)
    print(f"After aggregation: {len(aggregated_data)} values")
    print()
    
    # Now set_input should work
    try:
        test_sim.set_input(variable_name, time_period, aggregated_data)
        print("SUCCESS! set_input worked with aggregated data.")
    except Exception as e:
        print(f"Still failed: {e}")
else:
    print("No aggregation needed")

=== The Fix: Aggregate Before set_input ===

Variable: corporate_wealth
Entity: household
Data length: 8470
Population count: 4108

Aggregation needed: 8470 != 4108

After aggregation: 4108 values

SUCCESS! set_input worked with aggregated data.
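
To make the aggregation step concrete, here is a plain numpy sketch of the "take the first member's value" idea on made-up data. It is only an illustration of the concept and does not claim to match how `value_from_first_person` is implemented in `policyengine_core`.

```python
import numpy as np

# Person-level link column and a household variable repeated per member.
person_household_id = np.array([10, 10, 10, 20, 20])
person_level = np.array([500.0, 500.0, 500.0, 80.0, 80.0])

# return_index gives the position of each household's first member,
# ordered by sorted household ID.
household_ids, first_member = np.unique(
    person_household_id, return_index=True
)
household_level = person_level[first_member]

print(household_ids.tolist())    # [10, 20]
print(household_level.tolist())  # [500.0, 80.0]
```

The result has one value per household, in the same sorted order that `np.unique` produces for the household IDs.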


## Step 7: Show the Required Code Fix

Here is what the input-setting loop in `build_from_dataframe` should look like after the fix.

In [12]:
print("=== Required Fix for build_from_dataframe ===")
print()
print("""CURRENT CODE (buggy):
```python
# Set input values for each variable and time period
for column in df:
    variable, time_period = column.split("__")
    if variable not in self.tax_benefit_system.variables:
        continue
    self.set_input(variable, time_period, df[column])
```

FIXED CODE:
```python
# Set input values for each variable and time period
for column in df:
    variable, time_period = column.split("__")
    if variable not in self.tax_benefit_system.variables:
        continue
    
    # Get variable metadata and target population
    var_meta = self.tax_benefit_system.get_variable(variable)
    entity = var_meta.entity
    population = self.get_population(entity.plural)
    
    data = df[column].values
    
    # Check if aggregation is needed (data is person-level but variable is group-level)
    if len(data) != population.count:
        # Aggregate from person-level to entity-level
        data = population.value_from_first_person(data)
    
    self.set_input(variable, time_period, data)
```
""")

## Summary

In [13]:
print("="*70)
print("SUMMARY: BUG PROVEN")
print("="*70)
print("""
LOCATION:
  File: policyengine_uk/simulation.py
  Method: build_from_dataframe()
  Lines: 281-286

ROOT CAUSE:
  The method iterates through DataFrame columns and calls set_input()
  without checking if the data length matches the target entity count.
  
  - to_input_dataframe() exports ALL variables at PERSON level
  - build_from_ids() correctly creates entity structure (e.g., 4108 households)
  - BUT the loop then tries to set 8470 person-level values for 
    household-level variables that only have 4108 entities

THE FIX:
  Before calling set_input(), check if len(data) != population.count.
  If so, aggregate using population.value_from_first_person(data).

NOTE:
  This is the same aggregation logic that policyengine_core's
  build_from_dataset() method uses (simulation.py lines 406-414).
  The policyengine_uk version simply forgot to include it.
""")

