In [1]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from policyengine_core.charts import format_fig
from policyengine_us import Simulation
from policyengine_core.reforms import Reform

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration shared by every cell below
YEAR = 2026  # period used for all simulated inputs and outputs
MAX_INCOME = 800_000  # top of the total household earnings sweep (dollars)

In [3]:
# Family First Act reform. Each row is (parameter path, date range, value) and
# is expanded, in order, into the {path: {date_range: value}} mapping that
# Reform.from_dict receives.
# NOTE(review): most rows run 2026-2039, but a few run 2024-2100 or 2026-2100;
# all cover 2026 -- confirm the differing ranges are intended.
reforms = Reform.from_dict(
    {
        parameter_path: {date_range: value}
        for parameter_path, date_range, value in [
            # Family Security Act: head-of-household switch
            ("gov.contrib.congress.romney.family_security_act.remove_head_of_household", "2024-01-01.2100-12-31", True),
            # Family Security Act 2024: pregnant mothers credit
            ("gov.contrib.congress.romney.family_security_act_2024.pregnant_mothers_credit.amount[0].amount", "2026-01-01.2039-12-31", 2800),
            ("gov.contrib.congress.romney.family_security_act_2024.pregnant_mothers_credit.income_phase_in_end", "2026-01-01.2039-12-31", 10000),
            # Family Security Act 2.0: CTC structure
            ("gov.contrib.congress.romney.family_security_act_2_0.ctc.apply_ctc_structure", "2024-01-01.2100-12-31", True),
            ("gov.contrib.congress.romney.family_security_act_2_0.ctc.base[0].amount", "2026-01-01.2039-12-31", 4200),
            ("gov.contrib.congress.romney.family_security_act_2_0.ctc.base[1].amount", "2026-01-01.2039-12-31", 3000),
            ("gov.contrib.congress.romney.family_security_act_2_0.ctc.child_cap", "2026-01-01.2039-12-31", 6),
            ("gov.contrib.congress.romney.family_security_act_2_0.ctc.phase_in.income_phase_in_end", "2026-01-01.2039-12-31", 20000),
            # Family Security Act 2.0: EITC structure
            ("gov.contrib.congress.romney.family_security_act_2_0.eitc.amount.joint[0].amount", "2026-01-01.2039-12-31", 1400),
            ("gov.contrib.congress.romney.family_security_act_2_0.eitc.amount.joint[1].amount", "2026-01-01.2039-12-31", 5000),
            ("gov.contrib.congress.romney.family_security_act_2_0.eitc.amount.single[0].amount", "2026-01-01.2039-12-31", 700),
            ("gov.contrib.congress.romney.family_security_act_2_0.eitc.amount.single[1].amount", "2026-01-01.2039-12-31", 4300),
            ("gov.contrib.congress.romney.family_security_act_2_0.eitc.apply_eitc_structure", "2026-01-01.2039-12-31", True),
            # Dependent exemptions and CDCC child age
            ("gov.contrib.treasury.repeal_dependent_exemptions", "2026-01-01.2039-12-31", True),
            ("gov.irs.credits.cdcc.eligibility.child_age", "2026-01-01.2039-12-31", 0),
            # CTC phase-out thresholds by filing status, and refundability
            ("gov.irs.credits.ctc.phase_out.threshold.HEAD_OF_HOUSEHOLD", "2026-01-01.2039-12-31", 200000),
            ("gov.irs.credits.ctc.phase_out.threshold.JOINT", "2026-01-01.2039-12-31", 400000),
            ("gov.irs.credits.ctc.phase_out.threshold.SEPARATE", "2026-01-01.2039-12-31", 200000),
            ("gov.irs.credits.ctc.phase_out.threshold.SINGLE", "2026-01-01.2039-12-31", 200000),
            ("gov.irs.credits.ctc.phase_out.threshold.SURVIVING_SPOUSE", "2026-01-01.2039-12-31", 200000),
            ("gov.irs.credits.ctc.refundable.fully_refundable", "2024-01-01.2100-12-31", True),
            # EITC phase-in rates, joint bonus, phase-out rates and start points
            ("gov.irs.credits.eitc.phase_in_rate[2].amount", "2026-01-01.2039-12-31", 0.34),
            ("gov.irs.credits.eitc.phase_in_rate[3].amount", "2026-01-01.2039-12-31", 0.34),
            ("gov.irs.credits.eitc.phase_out.joint_bonus[0].amount", "2026-01-01.2039-12-31", 10000),
            ("gov.irs.credits.eitc.phase_out.joint_bonus[1].amount", "2026-01-01.2039-12-31", 10000),
            ("gov.irs.credits.eitc.phase_out.rate[0].amount", "2026-01-01.2039-12-31", 0.1),
            ("gov.irs.credits.eitc.phase_out.rate[1].amount", "2026-01-01.2039-12-31", 0.25),
            ("gov.irs.credits.eitc.phase_out.rate[2].amount", "2026-01-01.2039-12-31", 0.25),
            ("gov.irs.credits.eitc.phase_out.rate[3].amount", "2026-01-01.2039-12-31", 0.25),
            ("gov.irs.credits.eitc.phase_out.start[0].amount", "2026-01-01.2039-12-31", 10000),
            ("gov.irs.credits.eitc.phase_out.start[1].amount", "2026-01-01.2039-12-31", 33000),
            ("gov.irs.credits.eitc.phase_out.start[2].amount", "2026-01-01.2039-12-31", 33000),
            ("gov.irs.credits.eitc.phase_out.start[3].amount", "2026-01-01.2039-12-31", 33000),
            # SALT deduction cap by filing status
            ("gov.irs.deductions.itemized.salt_and_real_estate.cap.HEAD_OF_HOUSEHOLD", "2026-01-01.2100-12-31", 10000),
            ("gov.irs.deductions.itemized.salt_and_real_estate.cap.JOINT", "2026-01-01.2100-12-31", 10000),
            ("gov.irs.deductions.itemized.salt_and_real_estate.cap.SEPARATE", "2026-01-01.2100-12-31", 5000),
            ("gov.irs.deductions.itemized.salt_and_real_estate.cap.SINGLE", "2026-01-01.2100-12-31", 10000),
            ("gov.irs.deductions.itemized.salt_and_real_estate.cap.SURVIVING_SPOUSE", "2026-01-01.2100-12-31", 10000),
        ]
    },
    country_id="us",
)

In [4]:
def create_situation(
    filing_status,
    child_ages,
    with_childcare=False,
    *,
    childcare_expenses=10000,
    state_name="TX",
    year=None,
    max_income=None,
    count=201,
):
    """Build a PolicyEngine situation that sweeps household earnings.

    The household has one 40-year-old adult ("you"), plus a 40-year-old
    "your partner" when ``filing_status == "married"``, and one child
    ("child 1", "child 2", ...) per entry of ``child_ages``. Employment
    income is varied along a single axis group of ``count`` points. Each
    married partner sweeps 0..max_income/2 and both move together, so total
    household earnings span 0..max_income in every scenario.

    Args:
        filing_status: ``"married"`` adds a partner; any other value builds a
            one-adult household.
        child_ages: Iterable of child ages, in order.
        with_childcare: If True, set ``tax_unit_childcare_expenses`` on the
            tax unit.
        childcare_expenses: Annual childcare expenses used when
            ``with_childcare`` is True (default 10,000).
        state_name: Household state (default ``"TX"``).
        year: Period for every input; defaults to the notebook's ``YEAR``.
        max_income: Top of the total-earnings sweep; defaults to
            ``MAX_INCOME``.
        count: Number of points on the earnings axis. Must match the x-grid
            used when plotting results (201 in the graph functions).

    Returns:
        A situation dict with ``people``, ``axes``, ``families``,
        ``marital_units``, ``tax_units``, ``spm_units`` and ``households``.
    """
    year = YEAR if year is None else year
    max_income = MAX_INCOME if max_income is None else max_income
    married = filing_status == "married"

    adults = ["you", "your partner"] if married else ["you"]
    people = {
        adult: {"age": {year: 40}, "employment_income": {year: 0}} for adult in adults
    }

    # One parallel axis group with an entry per adult: index 0 is "you",
    # index 1 is "your partner" (their order in ``people``).
    per_adult_max = max_income / 2 if married else max_income
    situation_axes = [
        [
            {
                "name": "employment_income",
                "count": count,
                "min": 0,
                "max": per_adult_max,
                "period": year,
                "index": index,
            }
            for index in range(len(adults))
        ]
    ]

    children = []
    for i, age in enumerate(child_ages, 1):
        child_id = f"child {i}"
        people[child_id] = {"age": {year: age}, "employment_income": {year: 0}}
        children.append(child_id)
    members = adults + children

    # A marital unit is an unmarried person or a couple. Children therefore
    # must not be grouped with the adults; each child gets its own unit.
    marital_units = {"your marital unit": {"members": list(adults)}}
    for child_id in children:
        marital_units[f"{child_id}'s marital unit"] = {"members": [child_id]}

    tax_unit = {"members": list(members)}
    if with_childcare:
        tax_unit["tax_unit_childcare_expenses"] = {year: childcare_expenses}

    # Each entity gets its own copy of the member list so later edits to one
    # group cannot silently change the others.
    return {
        "people": people,
        "axes": situation_axes,
        "families": {"your family": {"members": list(members)}},
        "marital_units": marital_units,
        "tax_units": {"your tax unit": tax_unit},
        "spm_units": {"your spm_unit": {"members": list(members)}},
        "households": {
            "your household": {
                "members": list(members),
                "state_name": {year: state_name},
            }
        },
    }

In [5]:
def calculate_income(situation, reform=None):
    """Simulate ``situation`` and return its ``household_net_income`` in YEAR.

    A fresh Simulation is built on every call. It runs under ``reform`` when
    one is given, and with no reform applied when ``reform`` is None.
    """
    sim = Simulation(situation=situation, reform=reform)
    net_income = sim.calculate("household_net_income", YEAR)
    return net_income

In [6]:
def create_reform_graph(reform_name, reform):
    """Plot the reform's change in household net income against earnings.

    The figure is a 2x2 grid: rows are single / married, columns are without /
    with childcare expenses. Each panel draws a zero reference line and one
    curve per child configuration (no children, a newborn, newborn +
    5-year-old, newborn + 5- + 10-year-old). Each curve is reform minus
    baseline ``household_net_income`` at each point of the earnings sweep
    built by ``create_situation``.

    Args:
        reform_name: Label used in the figure title.
        reform: Reform passed to ``calculate_income`` for the reform run.

    Returns:
        The plotly figure after ``format_fig`` styling.
    """
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Single, No Childcare",
            "Single, With Childcare",
            "Married, No Childcare",
            "Married, With Childcare",
        ),
    )

    # (legend label, child ages, line color)
    child_scenarios = [
        ("No children", [], "#0066cc"),
        ("1 newborn", [0], "#4d94ff"),
        ("Newborn + 5 y.o.", [0, 5], "#99c2ff"),
        ("Newborn + 5 y.o. + 10 y.o.", [0, 5, 10], "#cce6ff"),
    ]
    # (filing status, with childcare) in row-major panel order
    panels = [
        ("single", False),
        ("single", True),
        ("married", False),
        ("married", True),
    ]

    # Total household earnings at each axis point; must match the 201-point
    # axis built by create_situation.
    earnings = np.linspace(0, MAX_INCOME, 201)

    for panel_index, (filing_status, childcare) in enumerate(panels):
        row, col = panel_index // 2 + 1, panel_index % 2 + 1

        # Zero reference line, consistent with create_reform_graph_10yo, so
        # gains and losses are easy to tell apart.
        fig.add_shape(
            type="line",
            x0=0,
            x1=MAX_INCOME,
            y0=0,
            y1=0,
            line=dict(color="black", width=0.5),
            row=row,
            col=col,
        )

        for label, child_ages, color in child_scenarios:
            situation = create_situation(filing_status, child_ages, childcare)
            baseline = calculate_income(situation)
            reform_result = calculate_income(situation, reform)

            fig.add_trace(
                go.Scatter(
                    x=earnings,
                    y=reform_result - baseline,
                    mode="lines",
                    name=label,
                    line=dict(color=color),
                    legendgroup=label,
                    # One legend entry per scenario, taken from the first panel.
                    showlegend=row == 1 and col == 1,
                ),
                row=row,
                col=col,
            )

    fig.update_layout(
        height=800,
        width=1200,
        title_text=f"{reform_name} Reform Comparison",
        legend_title="Household Type",
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=1.05),
    )

    # With no row/col given, these apply to every subplot's axes.
    fig.update_xaxes(title_text="Earnings", tickformat="$,.0f", range=[0, MAX_INCOME])
    fig.update_yaxes(title_text="Net Impact", tickformat="$,.0f")

    return format_fig(fig)

In [12]:
# Family First Act: change in household net income across earnings for each
# filing status / childcare combination and a newborn, 5- and 10-year-old mix.
fig = create_reform_graph(reform_name="Family First Act", reform=reforms)
fig.show()

In [9]:
def create_situation_10yo(filing_status, num_children, with_childcare=False):
    """Build an earnings-sweep situation in which every child is 10 years old.

    Delegates to ``create_situation`` (adults, earnings axes, entity groups
    and household) instead of duplicating it. There are two differences:

    * the household has ``num_children`` children, all aged 10, and
    * when ``with_childcare`` is True, childcare expenses are 10,000 per
      child, replacing the flat amount ``create_situation`` sets.

    Args:
        filing_status: ``"married"`` adds a partner; any other value builds a
            one-adult household.
        num_children: Number of 10-year-old children (0 or negative -> none).
        with_childcare: If True, set ``tax_unit_childcare_expenses``.

    Returns:
        The situation dict produced by ``create_situation``, with the
        childcare amount scaled by ``num_children`` when applicable.
    """
    situation = create_situation(filing_status, [10] * num_children, with_childcare)
    if with_childcare:
        # Scale childcare costs with family size.
        situation["tax_units"]["your tax unit"]["tax_unit_childcare_expenses"] = {
            YEAR: 10000 * num_children
        }
    return situation

In [10]:
def create_reform_graph_10yo(reform_name, reform):
    """Chart the reform's net-income effect for households whose children are all 10.

    Lays out four panels (single/married by without/with childcare). Each
    panel has a thin black zero line and one curve per family size (1, 2 or
    3 children). Each curve is reform-minus-baseline ``household_net_income``
    across the earnings sweep from ``create_situation_10yo``. Axis titles are
    blanked and every y-axis is fixed to -$10,000..$10,000, so larger effects
    may be cut off.

    Returns:
        The plotly figure after ``format_fig`` styling.
    """
    layout_panels = (
        ("Single, No Childcare", "single", False),
        ("Single, With Childcare", "single", True),
        ("Married, No Childcare", "married", False),
        ("Married, With Childcare", "married", True),
    )
    family_sizes = (
        ("1 child", 1, "#4d94ff"),
        ("2 children", 2, "#99c2ff"),
        ("3 children", 3, "#cce6ff"),
    )

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=tuple(panel[0] for panel in layout_panels),
    )
    earnings_grid = np.linspace(0, MAX_INCOME, 201)

    for position, (_, status, has_childcare) in enumerate(layout_panels):
        grid_row, grid_col = position // 2 + 1, position % 2 + 1

        fig.add_shape(
            type="line",
            x0=0,
            x1=MAX_INCOME,
            y0=0,
            y1=0,
            line=dict(color="black", width=0.5),
            row=grid_row,
            col=grid_col,
        )

        for label, n_kids, shade in family_sizes:
            household = create_situation_10yo(status, n_kids, has_childcare)
            baseline_income = calculate_income(household)
            reformed_income = calculate_income(household, reform)
            fig.add_trace(
                go.Scatter(
                    x=earnings_grid,
                    y=reformed_income - baseline_income,
                    mode="lines",
                    name=label,
                    line=dict(color=shade),
                    legendgroup=label,
                    showlegend=(grid_row, grid_col) == (1, 1),
                ),
                row=grid_row,
                col=grid_col,
            )

    fig.update_layout(
        height=800,
        width=1200,
        title_text=f"{reform_name} Reform Comparison (all children age 6-12)",
        legend_title="Household Type",
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=1.05),
    )

    for grid_row in (1, 2):
        for grid_col in (1, 2):
            fig.update_xaxes(
                tickformat="$,.0f",
                range=[0, MAX_INCOME],
                row=grid_row,
                col=grid_col,
                title_text="",
            )
            fig.update_yaxes(
                tickformat="$,.0f",
                range=[-10000, 10000],
                row=grid_row,
                col=grid_col,
                title_text="",
            )

    return format_fig(fig)

In [13]:
# Same reform, but every child is 10 years old (one to three children).
fig = create_reform_graph_10yo(reform_name="Family First Act", reform=reforms)
fig.show()