In [1]:
from policyengine_us import Simulation
from policyengine_core.reforms import Reform
from policyengine_core.charts import format_fig
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Plot palette (hex RGB strings) shared by all benefit charts:
# CHIP, ACA PTC, Medicaid and the total line respectively use these.
GRAY, BLUE_PRIMARY, TEAL_ACCENT, DARK_GRAY = (
    "#808080",
    "#2C6496",
    "#39C6C0",
    "#616161",
)


In [3]:
# Define the reform: replace the ACA premium tax credit phase-out rates with
# the schedule below and set PTC income-eligibility bracket [2] to True.
# NOTE(review): the rates 0/0/0/2%/4%/6%/8.5% look like the ARPA/IRA
# "enhanced" schedule, and bracket [2] of ptc_income_eligibility presumably
# covers income above 400% FPL — confirm against the parameter files.
#
# Every parameter uses the same effective period. The original notebook had
# bracket [1] starting in 2025 while all others started in 2026; a single
# constant keeps them consistent.
REFORM_PERIOD = "2026-01-01.2100-12-31"

# Phase-out rate for each bracket of gov.aca.ptc_phase_out_rate, in bracket order.
ENHANCED_PTC_PHASE_OUT_RATES = [0, 0, 0, 0.02, 0.04, 0.06, 0.085]

reform_parameters = {
    f"gov.aca.ptc_phase_out_rate[{bracket}].amount": {REFORM_PERIOD: rate}
    for bracket, rate in enumerate(ENHANCED_PTC_PHASE_OUT_RATES)
}
reform_parameters["gov.aca.ptc_income_eligibility[2].amount"] = {REFORM_PERIOD: True}

reform = Reform.from_dict(reform_parameters, country_id="us")

In [4]:
def create_household_situation(
    state,
    county_fips=None,
    has_child=True,
    year=2026,
    income_min=0,
    income_max=400000,
    income_count=800,
):
    """Create a PolicyEngine-US situation dict with an employment-income sweep.

    The household always has two adults ("you" and "your partner") and,
    when ``has_child`` is True, one 3-year-old dependent. An ``axes`` entry
    varies ``employment_income`` from ``income_min`` to ``income_max`` over
    ``income_count`` points, so a single simulation produces results for the
    whole income range.

    Args:
        state: Value written to the household's ``state_name`` input (e.g. "NY").
        county_fips: Optional county FIPS code; only added when truthy.
        has_child: If True, add the dependent and set both adults to age 30;
            otherwise the adults are aged 25 and 28.
        year: Period (as a year) used for every input and for the income axis.
        income_min: Lower bound of the employment-income sweep.
        income_max: Upper bound of the employment-income sweep.
        income_count: Number of points in the sweep.

    Returns:
        dict: A situation suitable for ``Simulation(situation=...)``.
    """
    period = str(year)

    people = {
        "you": {"age": {period: 30 if has_child else 25}},
        "your partner": {"age": {period: 30 if has_child else 28}},
    }
    adults = ["you", "your partner"]
    members = list(adults)

    # The adults share one marital unit; a dependent gets its own.
    marital_units = {"your marital unit": {"members": list(adults)}}

    if has_child:
        people["your first dependent"] = {"age": {period: 3}}
        members.append("your first dependent")
        marital_units["your first dependent's marital unit"] = {
            "members": ["your first dependent"],
            "marital_unit_id": {period: 1},
        }

    household = {"members": list(members), "state_name": {period: state}}
    if county_fips:
        household["county_fips"] = {period: county_fips}

    return {
        "people": people,
        # Each group entity gets its own copy of the member list so that
        # mutating one entity's members cannot silently change the others.
        "families": {"your family": {"members": list(members)}},
        "spm_units": {"your household": {"members": list(members)}},
        "tax_units": {"your tax unit": {"members": list(members)}},
        "households": {"your household": household},
        "marital_units": marital_units,
        "axes": [[{
            "name": "employment_income",
            "count": income_count,
            "min": income_min,
            "max": income_max,
            # Pin the sweep to the simulated year rather than relying on the
            # simulation's default period.
            "period": period,
        }]],
    }

In [5]:
def calculate_benefits(simulation, variables, period=2026):
    """Return a mapping of each variable name to its household-level values.

    For every name in ``variables``, calls ``simulation.calculate`` for
    ``period`` with ``map_to="household"`` and stores the result under that
    name, in iteration order.
    """
    results = {}
    for variable_name in variables:
        results[variable_name] = simulation.calculate(
            variable_name, map_to="household", period=period
        )
    return results

In [6]:
def create_benefit_traces(x_values, benefits_dict, label_suffix, line_style="solid"):
    """Build Plotly line traces for the health-benefit programs and their total.

    One trace is emitted for each of CHIP, ACA PTC and Medicaid that is
    present in ``benefits_dict``, in that order. When all three are present, a
    final "Total Benefits" trace holds their element-wise sum. Legend names
    are suffixed with ``(label_suffix)``. ``line_style == "solid"`` draws solid
    lines; any other value draws dotted lines.
    """
    # variable name -> (line colour, legend label), in emission order
    programs = {
        "per_capita_chip": (GRAY, "CHIP"),
        "aca_ptc": (BLUE_PRIMARY, "ACA PTC"),
        "medicaid_per_capita_cost": (TEAL_ACCENT, "Medicaid"),
    }
    dash = "dot" if line_style != "solid" else None

    def _line(y_values, label, color):
        return go.Scatter(
            x=x_values,
            y=y_values,
            mode="lines",
            name=f"{label} ({label_suffix})",
            line=dict(color=color, width=2, dash=dash),
        )

    traces = [
        _line(benefits_dict[key], label, color)
        for key, (color, label) in programs.items()
        if key in benefits_dict
    ]

    if all(key in benefits_dict for key in programs):
        columns = (benefits_dict[key] for key in programs)
        total_benefits = [sum(point) for point in zip(*columns)]
        traces.append(_line(total_benefits, "Total Benefits", DARK_GRAY))

    return traces

In [7]:
def create_household_benefit_plot(state_info, reform=None, year=2026):
    """Plot CHIP, ACA PTC, Medicaid and total benefits across household income.

    Args:
        state_info: Dict with required keys ``"name"`` (used in the title) and
            ``"abbrev"`` (passed as the state input), and optional keys
            ``"county_fips"`` and ``"has_child"`` (defaults to True).
        reform: Optional reform; when given, reform traces are drawn dotted
            on top of the solid baseline traces.
        year: Year used both for the situation inputs and for the calculation
            period (default 2026, matching the previous hard-wired behaviour).

    Returns:
        The Plotly figure after ``format_fig`` has been applied.
    """
    state_name = state_info["name"]
    has_child = state_info.get("has_child", True)

    situation = create_household_situation(
        state_info["abbrev"], state_info.get("county_fips"), has_child, year
    )

    variables = ["employment_income", "per_capita_chip", "aca_ptc", "medicaid_per_capita_cost"]

    # (legend suffix, simulation, line style) for each scenario to draw.
    scenarios = [("Baseline", Simulation(situation=situation), "solid")]
    if reform is not None:
        scenarios.append(("Reform", Simulation(situation=situation, reform=reform), "dotted"))

    fig = go.Figure()
    for label, simulation, line_style in scenarios:
        benefits = calculate_benefits(simulation, variables, period=year)
        # x-axis is household employment income from the same simulation.
        traces = create_benefit_traces(
            benefits["employment_income"], benefits, label, line_style=line_style
        )
        for trace in traces:
            fig.add_trace(trace)

    household_type = "Family of 3" if has_child else "Couple"
    fig.update_layout(
        title=f"{state_name} Household ({household_type}) - Program Benefits by Income Level",
        xaxis_title="Household Income",
        yaxis_title="Benefit Amount",
        legend_title="Programs",
        # Range matches the default employment-income sweep of the situation.
        xaxis=dict(tickformat="$,.0f", range=[0, 400000]),
        yaxis=dict(tickformat="$,.0f"),
        height=600,
        width=1000,
    )

    return format_fig(fig)

In [11]:
# Household scenarios to chart: a New York family with one child (no county
# given) and a Texas couple without children in county FIPS 48015.
ny_info = dict(name="New York", abbrev="NY", has_child=True)
tx_info = dict(name="Texas", abbrev="TX", county_fips="48015", has_child=False)

In [8]:


# Build the baseline-vs-reform chart for each household (NY first, then TX).
fig_ny, fig_tx = (
    create_household_benefit_plot(household_info, reform)
    for household_info in (ny_info, tx_info)
)

In [9]:
# Render the New York household chart.
fig_ny.show()


In [10]:
# Render the Texas household chart.
fig_tx.show()
