# PolicyEngine Microsimulation from Flat Files

Run PolicyEngine tax-benefit microsimulations on your own CSV datasets.

## How It Works

1. You provide a CSV with household/tax unit data
2. This notebook expands it to PolicyEngine's entity structure
3. PolicyEngine calculates taxes, credits, and benefits
4. Results are returned as a DataFrame (one row per tax unit)
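
The expansion in step 2 can be pictured with a small sketch. It uses plain dictionaries rather than the notebook's pandas code, and `expand_tax_unit` is a hypothetical helper name, not part of PolicyEngine. The defaults (head aged 40, a spouse only for joint filers, dependents aged 10) mirror the tax-unit parser defined in Setup.

```python
def expand_tax_unit(row: dict) -> list:
    """Expand one tax-unit record into person records (illustrative helper)."""
    base = {"tax_unit_id": row["tax_unit_id"], "year": row["year"]}
    age_head = row.get("age_head", 40)

    # The head always exists and carries the unit's income.
    people = [{**base, "role": "head", "age": age_head,
               "employment_income": row.get("employment_income", 0)}]

    # A spouse is created only for joint filers.
    if row.get("filing_status", "SINGLE") == "JOINT":
        people.append({**base, "role": "spouse",
                       "age": row.get("age_spouse", age_head),
                       "employment_income": 0})

    # One record per dependent, aged 10 by default.
    for _ in range(row.get("num_dependents", 0)):
        people.append({**base, "role": "dependent", "age": 10,
                       "employment_income": 0})
    return people


people = expand_tax_unit({"tax_unit_id": 1, "year": 2024, "filing_status": "JOINT",
                          "num_dependents": 2, "employment_income": 50000})
print([p["role"] for p in people])  # ['head', 'spouse', 'dependent', 'dependent']
```

One tax-unit row therefore becomes one to several person rows, and PolicyEngine regroups them into tax units and households when it calculates results.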

## Supported Input Formats

| Format | Use When |
|--------|----------|
| `tax_unit` | One row per tax filing unit (most common) |
| `person` | Each person in a tax unit has their own income |
| `household` | Multiple tax units share a household |

## Quick Start

1. Run all cells in the **Setup** section
2. Prepare your CSV (see **Input Format Reference** below)
3. Call `run_microsim()` with your file path

## Finding Variable Names

- **All variables (inputs and outputs):** [policyengine.org/us/api/variables](https://policyengine.org/us/api/variables)
- **Common inputs:** `employment_income`, `self_employment_income`, `social_security`
- **Common outputs:** `income_tax`, `state_income_tax`, `eitc`, `ctc`

---
## Setup

Run all cells in this section first.

In [None]:
# Install dependencies if needed (uncomment if running in Colab or fresh environment)
# !pip install policyengine-us pandas numpy tqdm

In [None]:
# Core dependencies
import pandas as pd
import numpy as np
import tempfile
from pathlib import Path
from typing import Dict, Set
from tqdm.auto import tqdm

# PolicyEngine
from policyengine_us import Microsimulation
from policyengine_core.data import Dataset

print("Imports successful!")

In [None]:
# =============================================================================
# Constants and Utilities
# =============================================================================

# Filing status: map string names to the integer codes used by the parsers below
# Codes: 1=SINGLE, 2=JOINT, 3=SEPARATE, 4=HEAD_OF_HOUSEHOLD, 5=WIDOW
FILING_STATUS_MAP = {
    "SINGLE": 1, "JOINT": 2, "SEPARATE": 3, "HEAD_OF_HOUSEHOLD": 4, "WIDOW": 5,
    1: 1, 2: 2, 3: 3, 4: 4, 5: 5,  # Also accept integers directly
}

# State codes to FIPS (PolicyEngine uses FIPS codes internally)
STATE_CODE_TO_FIPS = {
    'AL': 1, 'AK': 2, 'AZ': 4, 'AR': 5, 'CA': 6, 'CO': 8, 'CT': 9, 'DE': 10,
    'DC': 11, 'FL': 12, 'GA': 13, 'HI': 15, 'ID': 16, 'IL': 17, 'IN': 18,
    'IA': 19, 'KS': 20, 'KY': 21, 'LA': 22, 'ME': 23, 'MD': 24, 'MA': 25,
    'MI': 26, 'MN': 27, 'MS': 28, 'MO': 29, 'MT': 30, 'NE': 31, 'NV': 32,
    'NH': 33, 'NJ': 34, 'NM': 35, 'NY': 36, 'NC': 37, 'ND': 38, 'OH': 39,
    'OK': 40, 'OR': 41, 'PA': 42, 'RI': 44, 'SC': 45, 'SD': 46, 'TN': 47,
    'TX': 48, 'UT': 49, 'VT': 50, 'VA': 51, 'WA': 53, 'WV': 54, 'WI': 55, 'WY': 56
}

# Get PolicyEngine's tax-benefit system for auto-detecting variable entities
from policyengine_us import Simulation as _Sim
_TAX_BENEFIT_SYSTEM = _Sim.default_tax_benefit_system()


def get_variable_entity(var_name: str) -> str:
    """
    Get the entity type for a PolicyEngine variable.
    
    PolicyEngine variables belong to different entities:
    - person: individual-level (employment_income, age, etc.)
    - tax_unit: filing unit level (tax credits, filing status, etc.)
    - household: household level (in_nyc, state_fips, etc.)
    - family, spm_unit, marital_unit: other group entities
    
    Returns entity key or "person" if variable not found.
    """
    var = _TAX_BENEFIT_SYSTEM.variables.get(var_name)
    if var:
        return var.entity.key
    return "person"  # Default to person if unknown

In [None]:
# =============================================================================
# Input Parsers
# =============================================================================


def parse_person_format(df: pd.DataFrame) -> pd.DataFrame:
    """Parse person-level input (one row per person)."""
    required = ["person_id", "household_id", "tax_unit_id", "age", "year",
                "is_tax_unit_head", "is_tax_unit_spouse", "is_tax_unit_dependent"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    
    result = df.copy()
    if "state_code" in result.columns:
        result["state_code"] = result["state_code"].str.upper()
    else:
        result["state_code"] = "CA"
    for col in ["is_tax_unit_head", "is_tax_unit_spouse", "is_tax_unit_dependent"]:
        result[col] = result[col].astype(bool)
    return result


def parse_tax_unit_format(df: pd.DataFrame) -> pd.DataFrame:
    """
    Parse tax-unit-level input (one row per tax unit).
    Only tax_unit_id and year are required. Expands to person rows.
    """
    required = ["tax_unit_id", "year"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    
    structural = {"tax_unit_id", "household_id", "year", "state_code",
                  "filing_status", "age_head", "age_spouse", "num_dependents", "dependent_ages"}
    pe_vars = [c for c in df.columns if c not in structural]
    
    household_vars = {v for v in pe_vars if get_variable_entity(v) == "household"}
    tax_unit_vars = {v for v in pe_vars if get_variable_entity(v) == "tax_unit"}
    person_vars = {v for v in pe_vars if v not in household_vars and v not in tax_unit_vars}
    
    persons = []
    for idx, row in df.iterrows():
        tax_unit_id = row["tax_unit_id"]
        household_id = row.get("household_id", tax_unit_id)
        year = int(row["year"])
        state_code = str(row["state_code"]).upper() if "state_code" in row and pd.notna(row.get("state_code")) else "CA"
        
        # Filing status determines if spouse is created (JOINT=2)
        filing_status_val = row.get("filing_status")
        if pd.isna(filing_status_val):
            filing_status = 1  # SINGLE
        elif isinstance(filing_status_val, str):
            filing_status = FILING_STATUS_MAP.get(filing_status_val.upper(), 1)
        else:
            filing_status = int(filing_status_val)
        
        # Head age: use the provided value, otherwise default to 40
        age_head = int(row["age_head"]) if "age_head" in row and pd.notna(row.get("age_head")) else 40
        
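        # Dependents: listed ages are used first. If num_dependents is missing,
        # infer it from dependent_ages; pad missing ages with the default (10).
        num_deps = int(row["num_dependents"]) if "num_dependents" in row and pd.notna(row.get("num_dependents")) else None
        dep_ages_str = row.get("dependent_ages")
        if pd.notna(dep_ages_str) and str(dep_ages_str).strip():
            dep_ages = [int(a.strip()) for a in str(dep_ages_str).split(",") if a.strip().isdigit()]
        else:
            dep_ages = []
        if num_deps is None:
            num_deps = len(dep_ages)
        dep_ages = dep_ages + [10] * max(0, num_deps - len(dep_ages))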
        
        pid = idx * 100
        
        # Head
        head = {
            "person_id": pid, "household_id": household_id, "tax_unit_id": tax_unit_id,
            "year": year, "state_code": state_code, "age": age_head,
            "is_tax_unit_head": True, "is_tax_unit_spouse": False, "is_tax_unit_dependent": False
        }
        for var in pe_vars:
            if var in row and pd.notna(row[var]):
                head[var] = row[var]
        persons.append(head)
        
        # Spouse (only for JOINT)
        if filing_status == 2:
            spouse_age = int(row["age_spouse"]) if "age_spouse" in row and pd.notna(row.get("age_spouse")) else age_head
            spouse = {
                "person_id": pid + 1, "household_id": household_id, "tax_unit_id": tax_unit_id,
                "year": year, "state_code": state_code, "age": spouse_age,
                "is_tax_unit_head": False, "is_tax_unit_spouse": True, "is_tax_unit_dependent": False
            }
            for var in household_vars | tax_unit_vars:
                if var in row and pd.notna(row[var]):
                    spouse[var] = row[var]
            persons.append(spouse)
        
        # Dependents
        for i, dep_age in enumerate(dep_ages[:num_deps]):
            dep = {
                "person_id": pid + 2 + i, "household_id": household_id, "tax_unit_id": tax_unit_id,
                "year": year, "state_code": state_code, "age": dep_age,
                "is_tax_unit_head": False, "is_tax_unit_spouse": False, "is_tax_unit_dependent": True
            }
            for var in household_vars | tax_unit_vars:
                if var in row and pd.notna(row[var]):
                    dep[var] = row[var]
            persons.append(dep)
    
    result = pd.DataFrame(persons)
    for var in person_vars:
        if var in result.columns:
            result[var] = result[var].fillna(0)
    return result


def parse_household_format(df: pd.DataFrame) -> pd.DataFrame:
    """Parse household-level input (one row per household with multiple tax units)."""
    required = ["household_id", "year", "num_tax_units"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    
    persons = []
    for idx, row in df.iterrows():
        household_id = row["household_id"]
        year = int(row["year"])
        state_code = str(row["state_code"]).upper() if "state_code" in row and pd.notna(row.get("state_code")) else "CA"
        pid = idx * 1000
        
        for tu_num in range(1, int(row["num_tax_units"]) + 1):
            s = f"_{tu_num}"
            tax_unit_id = household_id * 100 + tu_num
            
            filing_status_val = row.get(f"filing_status{s}")
            if pd.isna(filing_status_val):
                filing_status = 1
            elif isinstance(filing_status_val, str):
                filing_status = FILING_STATUS_MAP.get(filing_status_val.upper(), 1)
            else:
                filing_status = int(filing_status_val)
            
            age_head = int(row[f"age_head{s}"]) if f"age_head{s}" in row and pd.notna(row.get(f"age_head{s}")) else 40
            age_spouse = int(row[f"age_spouse{s}"]) if f"age_spouse{s}" in row and pd.notna(row.get(f"age_spouse{s}")) else 0
            num_deps = int(row[f"num_dependents{s}"]) if f"num_dependents{s}" in row and pd.notna(row.get(f"num_dependents{s}")) else 0
            
            persons.append({
                "person_id": pid, "household_id": household_id, "tax_unit_id": tax_unit_id,
                "year": year, "state_code": state_code, "age": age_head,
                "is_tax_unit_head": True, "is_tax_unit_spouse": False, "is_tax_unit_dependent": False
            })
            pid += 1
            
            if filing_status == 2 and age_spouse > 0:
                persons.append({
                    "person_id": pid, "household_id": household_id, "tax_unit_id": tax_unit_id,
                    "year": year, "state_code": state_code, "age": age_spouse,
                    "is_tax_unit_head": False, "is_tax_unit_spouse": True, "is_tax_unit_dependent": False
                })
                pid += 1
            
            for _ in range(num_deps):
                persons.append({
                    "person_id": pid, "household_id": household_id, "tax_unit_id": tax_unit_id,
                    "year": year, "state_code": state_code, "age": 10,
                    "is_tax_unit_head": False, "is_tax_unit_spouse": False, "is_tax_unit_dependent": True
                })
                pid += 1
    
    return pd.DataFrame(persons)


def parse_input(df: pd.DataFrame, input_type: str) -> pd.DataFrame:
    """Route to appropriate parser based on input format type."""
    parsers = {"person": parse_person_format, "tax_unit": parse_tax_unit_format, "household": parse_household_format}
    if input_type not in parsers:
        raise ValueError(f"Unknown input type: {input_type}. Must be one of {list(parsers.keys())}")
    return parsers[input_type](df)

In [None]:
# =============================================================================
# PolicyEngine Dataset Class
# =============================================================================


class ResearcherDataset(Dataset):
    """Converts person-level DataFrame into PolicyEngine's TIME_PERIOD_ARRAYS format."""

    name = "researcher_dataset"
    label = "Researcher Flat File Dataset"
    data_format = Dataset.TIME_PERIOD_ARRAYS

    def __init__(self, person_df: pd.DataFrame):
        self.person_df = person_df.copy()
        self.tmp_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False)
        self.file_path = Path(self.tmp_file.name)
        super().__init__()

    def generate(self) -> None:
        data = {}
        years = sorted(self.person_df["year"].unique())
        
        # Identify PE variable columns
        structural = {"person_id", "household_id", "tax_unit_id", "year", "state_code", "age",
                      "is_tax_unit_head", "is_tax_unit_spouse", "is_tax_unit_dependent"}
        pe_vars = set(self.person_df.columns) - structural
        household_vars = {v for v in pe_vars if get_variable_entity(v) == "household"}
        tax_unit_vars = {v for v in pe_vars if get_variable_entity(v) == "tax_unit"}
        person_vars = pe_vars - household_vars - tax_unit_vars

        print(f"Generating dataset for {len(self.person_df)} persons across {len(years)} year(s)...")

        for year in tqdm(years, desc="Processing years"):
            year_int = int(year)
            year_df = self.person_df[self.person_df["year"] == year].copy()
            if len(year_df) == 0:
                continue

            n_persons = len(year_df)
            hh_map = {hid: i for i, hid in enumerate(year_df["household_id"].unique())}
            tu_map = {tuid: i for i, tuid in enumerate(year_df["tax_unit_id"].unique())}
            n_hh, n_tu = len(hh_map), len(tu_map)

            # Person-to-entity mappings
            person_hh = np.array([hh_map[h] for h in year_df["household_id"]])
            person_tu = np.array([tu_map[t] for t in year_df["tax_unit_id"]])
            
            data.setdefault("person_id", {})[year_int] = np.arange(n_persons)
            data.setdefault("person_household_id", {})[year_int] = person_hh
            data.setdefault("person_tax_unit_id", {})[year_int] = person_tu
            for entity in ["family", "spm_unit", "marital_unit"]:
                data.setdefault(f"person_{entity}_id", {})[year_int] = person_hh

            # Entity ID arrays
            data.setdefault("household_id", {})[year_int] = np.arange(n_hh)
            data.setdefault("tax_unit_id", {})[year_int] = np.arange(n_tu)
            for entity in ["family", "spm_unit", "marital_unit"]:
                data.setdefault(f"{entity}_id", {})[year_int] = np.arange(n_hh)

            # Person attributes
            data.setdefault("age", {})[year_int] = year_df["age"].values.astype(int)
            for role in ["is_tax_unit_head", "is_tax_unit_spouse", "is_tax_unit_dependent"]:
                data.setdefault(role, {})[year_int] = year_df[role].values.astype(bool)

            # State FIPS (household-level)
            hh_states = year_df.groupby("household_id")["state_code"].first()
            state_codes = [hh_states[h] for h in sorted(hh_map.keys(), key=lambda x: hh_map[x])]
            data.setdefault("state_fips", {})[year_int] = np.array([STATE_CODE_TO_FIPS.get(sc, 6) for sc in state_codes])

            # Person-level PE variables
            for var in person_vars:
                if var in year_df.columns:
                    data.setdefault(var, {})[year_int] = year_df[var].fillna(0).values.astype(float)
            
            # Household-level PE variables
            for var in household_vars:
                if var in year_df.columns:
                    hh_vals = year_df.groupby("household_id")[var].first()
                    vals = [hh_vals.get(h, False) for h in sorted(hh_map.keys(), key=lambda x: hh_map[x])]
                    pe_var = _TAX_BENEFIT_SYSTEM.variables.get(var)
                    dtype = bool if pe_var and pe_var.value_type == bool else float
                    data.setdefault(var, {})[year_int] = np.array(vals).astype(dtype)
            
            # Tax unit-level PE variables
            for var in tax_unit_vars:
                if var in year_df.columns:
                    tu_vals = year_df.groupby("tax_unit_id")[var].first()
                    vals = [tu_vals.get(t, 0) for t in sorted(tu_map.keys(), key=lambda x: tu_map[x])]
                    pe_var = _TAX_BENEFIT_SYSTEM.variables.get(var)
                    if pe_var and pe_var.value_type == bool:
                        dtype = bool
                    elif pe_var and pe_var.value_type == int:
                        dtype = int
                    else:
                        dtype = float
                    data.setdefault(var, {})[year_int] = np.array(vals).astype(dtype)

        self.save_dataset(data)
        print("Dataset generated successfully.")

    def cleanup(self) -> None:
        if hasattr(self, "file_path") and self.file_path.exists():
            try:
                self.file_path.unlink()
            except OSError:
                pass  # Best-effort cleanup of the temporary file

In [None]:
# =============================================================================
# Main Simulation Function
# =============================================================================


def run_microsim(
    input_file: str,
    input_type: str = "tax_unit",
    output_vars: list = None,
    output_file: str = None,
) -> pd.DataFrame:
    """
    Run PolicyEngine microsimulation on a flat-file dataset.
    
    This is the main entry point. It:
    1. Reads and parses your CSV into person-level format
    2. Creates a PolicyEngine Dataset
    3. Runs the simulation
    4. Extracts results per tax unit
    
    Args:
        input_file: Path to CSV file
        input_type: Format of input data:
            - "tax_unit": One row per tax unit (most common)
            - "person": One row per person (for split income)
            - "household": One row per household with multiple tax units
        output_vars: List of PolicyEngine variables to calculate
                    (default: ["income_tax", "state_income_tax"])
        output_file: Optional path to save results CSV
    
    Returns:
        DataFrame with one row per tax unit and requested output variables
    """
    if output_vars is None:
        output_vars = ["income_tax", "state_income_tax"]
    
    print(f"Input: {input_file}")
    print(f"Format: {input_type}")
    print(f"Outputs: {output_vars}")
    
    # Step 1: Read and parse input
    input_df = pd.read_csv(input_file)
    print(f"Read {len(input_df)} rows")
    
    person_df = parse_input(input_df, input_type)
    print(f"Expanded to {len(person_df)} persons")
    
    # Step 2: Create dataset and run simulation
    dataset = ResearcherDataset(person_df)
    
    try:
        dataset.generate()
        sim = Microsimulation(dataset=dataset)
        
        # Step 3: Extract results per tax unit
        results = []
        years = sorted(person_df["year"].unique())
        
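        for year in tqdm(years, desc="Extracting results"):
            year_int = int(year)
            year_str = str(year_int)
            year_df = person_df[person_df["year"] == year]
            # sort=False keeps tax units in first-appearance order, matching the
            # positional index that ResearcherDataset.generate() assigns via unique()
            tax_units = year_df.groupby("tax_unit_id", sort=False).first().reset_index()

            # Calculate each variable once per year, mapped to the tax-unit level
            calculated = {}
            for var in output_vars:
                try:
                    calculated[var] = np.asarray(
                        sim.calculate(var, period=year_str, map_to="tax_unit"), dtype=float
                    )
                except Exception as e:
                    print(f"Warning: Could not calculate {var}: {e}")
                    calculated[var] = None

            for idx, (_, tu_row) in enumerate(tax_units.iterrows()):
                result = {
                    "tax_unit_id": tu_row["tax_unit_id"],
                    "year": year_int,
                    "state_code": tu_row["state_code"],
                }
                for var in output_vars:
                    values = calculated[var]
                    result[var] = round(float(values[idx]), 2) if values is not None else 0.0
                results.append(result)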
        
        results_df = pd.DataFrame(results)
        
        # Step 4: Save if output file specified
        if output_file:
            results_df.to_csv(output_file, index=False)
            print(f"Results saved to {output_file}")
        
        print(f"Done! {len(results_df)} tax units processed.")
        return results_df
    
    finally:
        dataset.cleanup()

---
## Input Format Reference

### Tax Unit Format (`input_type="tax_unit"`) - Most Common

One row per tax filing unit. Minimal example:

```csv
tax_unit_id,year,employment_income
1,2024,50000
2,2024,75000
```

**Required:** `tax_unit_id`, `year`

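**Optional:** `state_code`, `filing_status`, `age_head`, `age_spouse`, `num_dependents`, `dependent_ages`, plus any PolicyEngine variable

**Defaults when omitted:** `state_code` is `CA`, `filing_status` is `SINGLE`, `age_head` is 40, `age_spouse` equals `age_head`, and dependents without listed ages are aged 10. `num_dependents` gives the number of dependents to create; supply it alongside `dependent_ages`. `filing_status` only decides whether a spouse is created (`JOINT`); the parser does not pass it to PolicyEngine as an input.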
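
Because `dependent_ages` is itself a comma-separated list, the field has to be quoted in the CSV. The sketch below uses only the standard library and made-up sample values. It writes such a file, then reads it back, filling blank fields with the same fallbacks as the tax-unit parser.

```python
import csv
import io

rows = [
    {"tax_unit_id": 1, "year": 2024, "state_code": "NY", "filing_status": "JOINT",
     "age_head": 42, "age_spouse": 40, "num_dependents": 2,
     "dependent_ages": "5,8", "employment_income": 90000},
    # Blank optional fields fall back to the parser defaults.
    {"tax_unit_id": 2, "year": 2024, "state_code": "", "filing_status": "",
     "age_head": "", "age_spouse": "", "num_dependents": "",
     "dependent_ages": "", "employment_income": 30000},
]

buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(rows[0]))
writer.writeheader()
writer.writerows(rows)  # the csv module quotes "5,8" automatically
csv_text = buffer.getvalue()
print(csv_text)

# Read back and resolve blanks the way the tax-unit parser does.
parsed = []
for rec in csv.DictReader(io.StringIO(csv_text)):
    parsed.append({
        "tax_unit_id": int(rec["tax_unit_id"]),
        "state_code": (rec["state_code"] or "CA").upper(),
        "filing_status": (rec["filing_status"] or "SINGLE").upper(),
        "age_head": int(rec["age_head"] or 40),
        "dependent_ages": [int(a) for a in rec["dependent_ages"].split(",") if a.strip()],
    })
print(parsed)
```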

---

### Person Format (`input_type="person"`) - Split Income

One row per person. Use this when income differs between members of a tax unit.

**Required:** `person_id`, `household_id`, `tax_unit_id`, `year`, `age`, `is_tax_unit_head`, `is_tax_unit_spouse`, `is_tax_unit_dependent`

**Optional:** `state_code` (defaults to `CA`), plus any PolicyEngine variable
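
As an illustration, a married couple with separate earnings and one child might be laid out as below. The values are invented, and the check simply mirrors the required-column validation in `parse_person_format`. It uses the standard library so it runs without the Setup cells.

```python
import csv
import io

REQUIRED = ["person_id", "household_id", "tax_unit_id", "year", "age",
            "is_tax_unit_head", "is_tax_unit_spouse", "is_tax_unit_dependent"]

person_csv = """\
person_id,household_id,tax_unit_id,year,age,is_tax_unit_head,is_tax_unit_spouse,is_tax_unit_dependent,employment_income
1,1,1,2024,41,True,False,False,60000
2,1,1,2024,39,False,True,False,35000
3,1,1,2024,7,False,False,True,0
"""

reader = csv.DictReader(io.StringIO(person_csv))
missing = [c for c in REQUIRED if c not in reader.fieldnames]
rows = list(reader)

# Each tax unit should have exactly one head.
heads = [r for r in rows if r["is_tax_unit_head"] == "True"]
total_income = sum(float(r["employment_income"]) for r in rows)
print(missing, len(heads), total_income)  # [] 1 95000.0
```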

---

### Household Format (`input_type="household"`) - Multiple Tax Units

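One row per household. Per-tax-unit columns carry a numeric suffix (`_1`, `_2`, ...) up to `num_tax_units`.

**Required:** `household_id`, `year`, `num_tax_units`

**Optional:** `state_code` (defaults to `CA`), and for each tax unit `filing_status_N`, `age_head_N`, `age_spouse_N`, `num_dependents_N`

`household_id` must be numeric, because tax unit IDs are derived as `household_id * 100 + N`. A spouse is added only when `filing_status_N` is `JOINT` and `age_spouse_N` is given; dependents are aged 10. As written, the parser builds household structure only and does not carry income or other PolicyEngine variable columns through.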
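
A sketch of a header and one row for two tax units sharing a household, using the suffixed column names that `parse_household_format` reads. The values are invented and `household_header` is a hypothetical helper, not part of the notebook.

```python
PER_UNIT_FIELDS = ["filing_status", "age_head", "age_spouse", "num_dependents"]


def household_header(max_tax_units: int) -> list:
    """Column names for a household file with up to max_tax_units units."""
    header = ["household_id", "year", "state_code", "num_tax_units"]
    for n in range(1, max_tax_units + 1):
        header += [f"{field}_{n}" for field in PER_UNIT_FIELDS]
    return header


header = household_header(2)
# A joint-filing couple with one child, plus a single adult in the same household.
row = [1, 2024, "TX", 2, "JOINT", 45, 44, 1, "SINGLE", 23, "", 0]
record = dict(zip(header, row))
print(header)
print(record["filing_status_2"], record["age_head_2"])  # SINGLE 23
```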

---

### Common Variables

**Inputs** (add as columns - entity is auto-detected):
- `employment_income`, `self_employment_income`, `social_security`
- `long_term_capital_gains`, `taxable_interest_income`, `rental_income`
- `in_nyc` (household-level, True/False)
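
The entity of each column decides where its value lands when a tax-unit row is expanded. In the tax-unit parser, person-level values go to the head only (other members get 0), while tax-unit and household values are copied to every member. Below is a minimal sketch of that rule. The hard-coded `ENTITY_OF` table stands in for the real tax-benefit-system lookup, and both it and the `place_value` helper are assumptions for illustration.

```python
# Hypothetical stand-in for the entity lookup done by get_variable_entity()
ENTITY_OF = {
    "employment_income": "person",
    "in_nyc": "household",
}


def place_value(var: str, value, roles: list) -> list:
    """Return the value each member receives for one input column."""
    entity = ENTITY_OF.get(var, "person")  # unknown names are treated as person-level
    if entity == "person":
        return [value if role == "head" else 0 for role in roles]
    return [value for _ in roles]


roles = ["head", "spouse", "dependent"]
print(place_value("employment_income", 50000, roles))  # [50000, 0, 0]
print(place_value("in_nyc", True, roles))              # [True, True, True]
```

If members of a unit need their own person-level values, the person format is the better fit.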

**Outputs** (request via `output_vars`):
- `income_tax`, `state_income_tax`, `eitc`, `ctc`
- `adjusted_gross_income`, `taxable_income`, `payroll_tax`

Full list: [policyengine.org/us/api/variables](https://policyengine.org/us/api/variables)

---
## Run

Replace `your_data.csv` with your file path:

In [None]:
# Run microsimulation on your CSV file
results = run_microsim(
    input_file="your_data.csv",
    output_vars=["income_tax", "state_income_tax", "eitc", "ctc"],
)
results