# Marriage Penalty/Bonus Analysis

Compare each household's combined federal and state income tax (`income_tax + state_income_tax`) between:
- **Factual**: actual marital status
- **Counterfactual**: flipped marital status

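Counterfactual transformations:
- **Married → Single**: the spouse is removed and their wages and self-employment income are added to the primary earner, so all of the couple's income is taxed on one single return
- **Single → Married**: a spouse with $0 income is added, isolating the benefit of the joint-filing brackets and standard deduction

`marriage_penalty` is defined as (tax when married) − (tax when single): positive values are penalties, negative values are bonuses. Both transformations keep all income with one filer, so neither reproduces the classic two-earner penalty; this likely explains why no household in this sample shows a penalty. All households are simulated as Minnesota residents for tax year 2025.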

In [1]:
import os

import numpy as np
import pandas as pd
from policyengine_us import Simulation

YEAR: int = 2025  # Tax year used for every PolicyEngine input and calculation

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load CPS data (TAXSIM-style household records).
# Set the CPS_HOUSEHOLDS_CSV environment variable to point at a local copy; the
# default is the original author's path, kept so existing runs still work.
CPS_HOUSEHOLDS_CSV = os.environ.get(
    'CPS_HOUSEHOLDS_CSV',
    '/Users/pavelmakarchuk/analysis-notebooks/cps_households.csv',
)
df = pd.read_csv(CPS_HOUSEHOLDS_CSV)
# Count a blank wage field as $0: otherwise total_wages becomes NaN and the
# household is silently dropped by the wage-band filter in the next cell.
df['total_wages'] = df['pwages'].fillna(0) + df['swages'].fillna(0)

print(f"Total households: {len(df):,}")
print(f"Marital status: {df['mstat'].value_counts().to_dict()}")

Total households: 112,502
Marital status: {1: 56604, 2: 55898}


In [3]:
# Filter to households with meaningful earnings: total wages strictly between
# WAGE_MIN and WAGE_MAX, keeping the first N_HOUSEHOLDS matches in file order.
# NOTE: head() takes the first matching rows, not a random sample; set
# N_HOUSEHOLDS = None to run on every household in the wage band.
WAGE_MIN = 30_000
WAGE_MAX = 150_000
N_HOUSEHOLDS = 50

in_wage_band = (df['total_wages'] > WAGE_MIN) & (df['total_wages'] < WAGE_MAX)
subset = df[in_wage_band]
if N_HOUSEHOLDS is not None:
    subset = subset.head(N_HOUSEHOLDS)
# Copy so later cells never operate on a view of df
subset = subset.copy()

print(f"Subset: {len(subset)} households with ${WAGE_MIN // 1000:,}k-${WAGE_MAX // 1000:,}k wages")
print(f"Married (mstat=2): {(subset['mstat']==2).sum()}")
print(f"Single (mstat=1): {(subset['mstat']==1).sum()}")

Subset: 50 households with $30k-$150k wages
Married (mstat=2): 32
Single (mstat=1): 18


In [4]:
def create_situation_from_row(row, flip_marital_status=False, state_name="MN"):
    """Create a PolicyEngine situation from a TAXSIM row.

    Parameters
    ----------
    row : pandas.Series or mapping
        One TAXSIM-style record. ``mstat == 2`` is treated as married; any
        other value is treated as single. Reads ``page``/``sage`` (ages),
        ``pwages``/``swages`` (wages), ``psemp``/``ssemp`` (self-employment),
        the shared income columns (``dividends``, ``intrec``, ``stcg``,
        ``ltcg``, ``pensions``, ``gssi``) and dependent ages ``age1``-``age11``.
        Missing, None or NaN income values are treated as 0.
    flip_marital_status : bool, default False
        If True, build the counterfactual household:

        - Married (mstat=2) becomes single: remove spouse, combine ALL incomes
          into the primary earner.
        - Single (mstat=1) becomes married: add spouse with $0 income.
    state_name : str, default "MN"
        State code placed on the household for ``YEAR``. The default is the
        value this notebook has always used.

    Returns
    -------
    dict
        PolicyEngine situation with ``people``, ``families``,
        ``marital_units``, ``tax_units`` and ``households``; every variable is
        keyed by ``YEAR``. All people share one family, tax unit and
        household; only the adult (and spouse, if any) form the marital unit.
    """
    is_married = row['mstat'] == 2

    # Determine effective marital status
    effective_married = (not is_married) if flip_marital_status else is_married

    def get_val(col):
        """Return row[col] as a float, treating missing/None/NaN/falsy values as 0."""
        value = row.get(col, 0)
        # NaN is truthy, so a plain `value or 0` would let NaN through into the
        # simulation; check for missing values explicitly first.
        if pd.isna(value) or not value:
            return 0.0
        return float(value)

    # Primary earner income
    primary_wages = get_val('pwages')
    primary_self_emp = get_val('psemp')

    # Spouse income (only if originally married)
    spouse_wages = get_val('swages') if is_married else 0
    spouse_self_emp = get_val('ssemp') if is_married else 0

    # These columns have no primary/spouse split, so they are always assigned to
    # the primary earner, in both the factual and counterfactual situations.
    dividends = get_val('dividends')
    interest = get_val('intrec')
    stcg = get_val('stcg')
    ltcg = get_val('ltcg')
    pensions = get_val('pensions')
    social_security = get_val('gssi')

    # Build people
    people = {
        "adult": {
            "age": {YEAR: int(row['page'])},
            "employment_income": {YEAR: primary_wages},
            "self_employment_income": {YEAR: primary_self_emp},
            "dividend_income": {YEAR: dividends},
            "interest_income": {YEAR: interest},
            "short_term_capital_gains": {YEAR: stcg},
            "long_term_capital_gains": {YEAR: ltcg},
            "pension_income": {YEAR: pensions},
            "social_security": {YEAR: social_security},
        }
    }
    members = ["adult"]
    marital_members = ["adult"]

    # Handle spouse
    if effective_married:
        if is_married and not flip_marital_status:
            # Original married, staying married - spouse keeps their income.
            # Fall back to the primary's age when sage is missing, NaN or <= 0.
            spouse_age = get_val('sage')
            people["spouse"] = {
                "age": {YEAR: int(spouse_age) if spouse_age > 0 else int(row['page'])},
                "employment_income": {YEAR: spouse_wages},
                "self_employment_income": {YEAR: spouse_self_emp},
            }
        else:
            # Single becoming married - add spouse with $0 income
            people["spouse"] = {
                "age": {YEAR: int(row['page'])},
                "employment_income": {YEAR: 0},
            }
        members.append("spouse")
        marital_members.append("spouse")
    elif is_married and flip_marital_status:
        # Married becoming single - combine ALL spouse income into adult
        people["adult"]["employment_income"][YEAR] += spouse_wages
        people["adult"]["self_employment_income"][YEAR] += spouse_self_emp

    # Add dependents: every present, non-missing, positive age1..age11 becomes a child
    dep_ages = [row[f'age{i}'] for i in range(1, 12) if f'age{i}' in row and pd.notna(row[f'age{i}']) and row[f'age{i}'] > 0]
    for i, age in enumerate(dep_ages):
        dep_name = f"child_{i}"
        people[dep_name] = {"age": {YEAR: int(age)}}
        members.append(dep_name)

    situation = {
        "people": people,
        "families": {"family": {"members": members}},
        "marital_units": {"marital_unit": {"members": marital_members}},
        "tax_units": {"tax_unit": {"members": members}},
        "households": {
            "household": {
                "members": members,
                "state_name": {YEAR: state_name},
            }
        },
    }

    return situation

In [5]:
def calculate_taxes(situation):
    """Return federal plus state income tax for ``situation`` in ``YEAR``.

    Builds a PolicyEngine ``Simulation`` from the situation dict, computes
    ``income_tax`` and then ``state_income_tax``, sums each returned array and
    adds the two totals, returning the result as a plain Python float.
    """
    sim = Simulation(situation=situation)
    # Generator unpacking evaluates in order: income_tax first, then state_income_tax
    federal, state = (sim.calculate(variable, YEAR) for variable in ("income_tax", "state_income_tax"))
    return float(federal.sum() + state.sum())

In [6]:
# Calculate factual and counterfactual taxes for each household.
# marriage_penalty is always (tax when married) - (tax when single): positive
# means a penalty, negative a bonus, regardless of which status is factual.
results = []
failed_rows = []

for idx, row in subset.iterrows():
    try:
        # Factual (actual marital status)
        factual_situation = create_situation_from_row(row, flip_marital_status=False)
        factual_tax = calculate_taxes(factual_situation)

        # Counterfactual (flipped marital status)
        counterfactual_situation = create_situation_from_row(row, flip_marital_status=True)
        counterfactual_tax = calculate_taxes(counterfactual_situation)

        # Use the same married test as create_situation_from_row (mstat == 2) so
        # the label and the penalty sign always match the situation simulated.
        is_married = row['mstat'] == 2
        if is_married:
            married_tax, single_tax = factual_tax, counterfactual_tax
        else:
            married_tax, single_tax = counterfactual_tax, factual_tax

        results.append({
            'taxsimid': row['taxsimid'],
            'original_mstat': 'married' if is_married else 'single',
            'pwages': row['pwages'],
            'swages': row['swages'],
            'depx': row['depx'],
            'factual_tax': factual_tax,
            'counterfactual_tax': counterfactual_tax,
            'marriage_penalty': married_tax - single_tax,
        })
    except Exception as e:
        # Best effort: keep processing the other households, but record the
        # failure so it is reported rather than silently shrinking the sample.
        failed_rows.append(idx)
        print(f"Error on row {idx}: {e}")

# Explicit columns keep downstream cells working even if every row failed
results_df = pd.DataFrame(results, columns=[
    'taxsimid', 'original_mstat', 'pwages', 'swages', 'depx',
    'factual_tax', 'counterfactual_tax', 'marriage_penalty',
])
print(f"Processed {len(results_df)} households")
if failed_rows:
    print(f"Failed on {len(failed_rows)} of {len(subset)} households (rows {failed_rows})")

Processed 50 households


In [7]:
# Preview the first 20 per-household results
results_df.iloc[:20]

Unnamed: 0,taxsimid,original_mstat,pwages,swages,depx,factual_tax,counterfactual_tax,marriage_penalty
0,2.0,married,0.0,41904.761905,0.0,1840.755859,4437.500977,-2596.745117
1,8.0,married,0.0,64668.47619,1.0,4775.831055,5551.355957,-775.524902
2,12.0,married,45047.619048,26190.47619,0.0,6691.18457,10651.116211,-3959.931641
3,17.0,single,83809.52381,0.0,0.0,14271.6875,8969.470703,-5302.216797
4,25.0,single,75428.571429,0.0,2.0,4332.105957,2867.181641,-1464.924316
5,26.0,married,55824.47619,0.0,0.0,4016.921875,6977.046387,-2960.124512
6,28.0,single,44000.0,0.0,0.0,4801.024902,2162.375,-2638.649902
7,29.0,single,31428.571429,0.0,1.0,-6190.179688,-8171.620117,-1981.44043
8,30.0,married,77523.809524,28809.52381,0.0,13203.947266,20758.546875,-7554.599609
9,36.0,married,41904.761905,18857.142857,0.0,7874.586426,13305.869141,-5431.282715


In [8]:
# Summary statistics. marriage_penalty = (tax when married) - (tax when single).
# Differences within NO_CHANGE_TOL dollars count as "no change" so floating-point
# noise in the simulated taxes is not classified as a penalty or a bonus.
NO_CHANGE_TOL = 0.01  # dollars
effect = results_df['marriage_penalty']

print("=== Marriage Penalty/Bonus Summary ===")
print("\nPositive = marriage PENALTY (pay MORE when married)")
print("Negative = marriage BONUS (pay LESS when married)")
print(f"\nMean effect: ${effect.mean():,.0f}")
print(f"Median effect: ${effect.median():,.0f}")
print(f"\nHouseholds with marriage penalty: {(effect > NO_CHANGE_TOL).sum()}")
print(f"Households with marriage bonus: {(effect < -NO_CHANGE_TOL).sum()}")
print(f"Households with no change: {(effect.abs() <= NO_CHANGE_TOL).sum()}")

# Summary table
print("\n=== Results Table ===")
display_cols = ['taxsimid', 'original_mstat', 'pwages', 'swages', 'depx', 'factual_tax', 'counterfactual_tax', 'marriage_penalty']
results_df[display_cols].head(20)

=== Marriage Penalty/Bonus Summary ===

Positive = marriage PENALTY (pay MORE when married)
Negative = marriage BONUS (pay LESS when married)

Mean effect: $-4,017
Median effect: $-3,534

Households with marriage penalty: 0
Households with marriage bonus: 50
Households with no change: 0

=== Results Table ===


Unnamed: 0,taxsimid,original_mstat,pwages,swages,depx,factual_tax,counterfactual_tax,marriage_penalty
0,2.0,married,0.0,41904.761905,0.0,1840.755859,4437.500977,-2596.745117
1,8.0,married,0.0,64668.47619,1.0,4775.831055,5551.355957,-775.524902
2,12.0,married,45047.619048,26190.47619,0.0,6691.18457,10651.116211,-3959.931641
3,17.0,single,83809.52381,0.0,0.0,14271.6875,8969.470703,-5302.216797
4,25.0,single,75428.571429,0.0,2.0,4332.105957,2867.181641,-1464.924316
5,26.0,married,55824.47619,0.0,0.0,4016.921875,6977.046387,-2960.124512
6,28.0,single,44000.0,0.0,0.0,4801.024902,2162.375,-2638.649902
7,29.0,single,31428.571429,0.0,1.0,-6190.179688,-8171.620117,-1981.44043
8,30.0,married,77523.809524,28809.52381,0.0,13203.947266,20758.546875,-7554.599609
9,36.0,married,41904.761905,18857.142857,0.0,7874.586426,13305.869141,-5431.282715


In [9]:
# Breakdown by original marital status.
# Group once, then report in a fixed order (single first) rather than
# groupby's alphabetical key order; statuses with no rows are skipped.
print("=== By Original Marital Status ===")
groups_by_status = {key: frame for key, frame in results_df.groupby('original_mstat')}
for status in ('single', 'married'):
    subset_status = groups_by_status.get(status)
    if subset_status is None or not len(subset_status) > 0:
        continue
    print(f"\n{status.upper()} households (n={len(subset_status)}):")
    print(f"  Mean marriage effect: ${subset_status['marriage_penalty'].mean():,.0f}")
    print(f"  Avg factual tax: ${subset_status['factual_tax'].mean():,.0f}")
    print(f"  Avg counterfactual tax: ${subset_status['counterfactual_tax'].mean():,.0f}")

=== By Original Marital Status ===

SINGLE households (n=18):
  Mean marriage effect: $-3,355
  Avg factual tax: $5,900
  Avg counterfactual tax: $2,545

MARRIED households (n=32):
  Mean marriage effect: $-4,389
  Avg factual tax: $7,781
  Avg counterfactual tax: $12,171


In [10]:
# Export results to CSV. The path is defined once so the file written and the
# confirmation message can never disagree.
RESULTS_CSV = 'marriage_penalty_results.csv'
results_df.to_csv(RESULTS_CSV, index=False)
print(f"Saved {len(results_df)} rows to {RESULTS_CSV}")

Saved 50 rows to marriage_penalty_results.csv
