# 08_flat_per_capita_rebate

## Part A — What we’re doing
We compute a **Flat Per-Capita** household rebate on the calibrated Step-01 panel.

- **Base**: `$1,700 × household_size` (configurable).
- **Phase-out (AGI basis)**:
  - Starts at **$50,000** for non-MFJ filers.
  - Starts at **$100,000** for MFJ filers.
  - Tapers at **$0.05 per $1** above the start.
- Floors at **0**.

**Outputs**
- `outputs/rebates/flat_per_capita/rebate_records_2024.csv` — household-level rebates  
- `outputs/rebates/flat_per_capita/summary_2024.csv` — totals & parameters  
- `outputs/rebates/flat_per_capita/by_decile_2024.csv` — decile totals (equivalized income)  
- `outputs/rebates/flat_per_capita/by_size_2024.csv` — weighted totals by household-size bucket (1–6, with bucket 7 meaning seven or more)  
- `outputs/rebates/flat_per_capita/by_status_2024.csv` — weighted totals by filing-status group (`MFJ` vs. `Single/Other`)  
- `outputs/rebates/flat_per_capita/plots/deciles_2024.png` — decile bar chart

---

## Part B — Inputs
- Step-01 panel: `intermediate/ca_panel_2024.(parquet|csv)` (falling back to `intermediate/ca_panel_2024_2025.(parquet|csv)` if the former is absent) with columns  
  `household_agi`, `household_size`, `filing_status`, `household_weight`.

- Policy module:
  - `policy/rebates/flat_per_capita.py`

- Config (if present):  
  `rebate.flat.amount` (default 1700),  
  `rebate.flat.phaseout.single_start` (50k),  
  `rebate.flat.phaseout.mfj_start` (100k),  
  `rebate.flat.phaseout.rate` (0.05).

---

## Part C — Methods
1) Load panel & parameters  
2) Compute **record-level** base and phased rebate  
3) Aggregate totals and breakdowns (deciles via AGI per capita; size buckets; filing status)  
4) Save CSVs + a simple decile plot

---

## Part D — Acceptance checks
- Rebates **≥ 0** for all households  
- **With-phase-out ≤ base** (overall and by groups)  
- Breakdown sums ≈ overall total (within rounding)

---

## Part E — Troubleshooting
- If panel missing columns, re-run Step-01.  
- If all zeros, check amount/thresholds/rate are as intended.  
- If deciles look odd, ensure AGI numeric and hh_size ≥ 1.

---


In [None]:
import json, time
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from policy.rebates.flat_per_capita import flat_per_capita_rebate

# Wall-clock start; used for the elapsed-time message printed at the very end.
START_TS = time.time()

# ---------- Paths ----------
# The Step-01 panel is read from INTERMEDIATE; every artefact of this step is
# written somewhere under OUT_DIR.
INTERMEDIATE = Path("intermediate")
OUT_DIR = Path("outputs/rebates/flat_per_capita")
# parents=True means creating the plots folder also creates OUT_DIR itself.
OUT_DIR.joinpath("plots").mkdir(exist_ok=True, parents=True)

# ---------- Config (optional) ----------
def _load_config():
    """Load the optional JSON config, returning ``{}`` when none is usable.

    ``config.json`` and then ``config/config.json`` (relative to the working
    directory) are tried in that order. The first file that can be read and
    parses to a JSON object is returned.

    A candidate that cannot be read, is not valid JSON, or does not contain a
    JSON object is skipped with a printed warning, so a broken config no
    longer silently turns into "use all defaults". Skipping non-object
    payloads also protects callers that chain ``.get(...)`` on the result.

    Returns:
        dict: the parsed configuration, or an empty dict if no candidate is
        usable.
    """
    for p in ("config.json", "config/config.json"):
        f = Path(p)
        if not f.exists():
            continue
        try:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(f, "r", encoding="utf-8") as fh:
                cfg = json.load(fh)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Warning: ignoring config {f}: {exc}")
            continue
        if isinstance(cfg, dict):
            return cfg
        print(f"Warning: ignoring config {f}: top-level JSON value is not an object")
    return {}

CFG = _load_config()

# Every parameter falls back to its documented default when the config file,
# or any intermediate key, is absent.
_flat_cfg = CFG.get("rebate", {}).get("flat", {})
PH = _flat_cfg.get("phaseout", {})

FLAT_AMOUNT = _flat_cfg.get("amount", 1700.0)
FLAT_SINGLE_START = PH.get("single_start", 50_000.0)
FLAT_MFJ_START = PH.get("mfj_start", 100_000.0)
FLAT_RATE = PH.get("rate", 0.05)

# Echo the effective parameters so a run can be audited from its log.
_effective_params = {
    "flat.amount": FLAT_AMOUNT,
    "flat.single_start": FLAT_SINGLE_START,
    "flat.mfj_start": FLAT_MFJ_START,
    "flat.rate": FLAT_RATE,
}
print("Parameters:", _effective_params)

# ---------- Load Step-01 panel ----------
# Candidate panel files in priority order; the first one that exists is used.
_panel_candidates = (
    "ca_panel_2024.parquet",
    "ca_panel_2024.csv",
    "ca_panel_2024_2025.parquet",
    "ca_panel_2024_2025.csv",
)
_panel_path = next(
    (INTERMEDIATE / name for name in _panel_candidates if (INTERMEDIATE / name).exists()),
    None,
)
if _panel_path is None:
    raise FileNotFoundError("Step 01 panel not found in intermediate/.")

# The reader is chosen from the file extension.
if _panel_path.suffix == ".parquet":
    panel = pd.read_parquet(_panel_path)
else:
    panel = pd.read_csv(_panel_path)

# Fail fast if the panel lacks any column this step consumes.
_required_cols = ("household_agi", "household_size", "household_weight", "filing_status")
_missing_cols = [col for col in _required_cols if col not in panel.columns]
if _missing_cols:
    raise KeyError(f"Panel missing columns: {_missing_cols}")

# Types & derived
# AGI and weight: unparseable entries become NaN and are then zeroed.
for _col in ("household_agi", "household_weight"):
    panel[_col] = pd.to_numeric(panel[_col], errors="coerce").fillna(0.0)

# Household size: unparseable entries default to 1; result is an integer >= 1.
_size = pd.to_numeric(panel["household_size"], errors="coerce").fillna(1)
panel["household_size"] = _size.astype(int).clip(lower=1)

panel["filing_status"] = panel["filing_status"].astype(str)

# Size bucket: sizes below 7 keep their value, everything else collapses to 7.
_sizes = panel["household_size"]
panel["size_bucket"] = np.where(_sizes < 7, _sizes, 7).astype(int)

# Status group: any status whose lower-cased text contains "mfj" is MFJ.
_is_mfj = panel["filing_status"].str.lower().str.contains("mfj")
panel["status_group"] = np.where(_is_mfj, "MFJ", "Single/Other")

# ---------- Compute rebate (record level) ----------
# Base (no phase-out): flat per-person amount times household size.
panel["rebate_flat_base"] = panel["household_size"].astype(float) * float(FLAT_AMOUNT)

# With phase-out: the policy function is called once per household with that
# household's AGI, size and filing status plus the configured parameters.
panel["rebate_flat"] = [
    flat_per_capita_rebate(agi, sz, fs, FLAT_AMOUNT, FLAT_SINGLE_START, FLAT_MFJ_START, FLAT_RATE)
    for agi, sz, fs in zip(panel["household_agi"], panel["household_size"], panel["filing_status"])
]

# Acceptance: non-negativity. Raised explicitly rather than via `assert` so
# the check still runs when Python is started with -O; the exception type and
# message are unchanged. A NaN rebate also fails this check (NaN >= 0 is False).
if not (panel["rebate_flat"] >= 0).all():
    raise AssertionError("Negative rebate found.")

# ---------- Aggregations ----------
def wsum(x, w):
    """Return the weighted sum of ``x`` with weights ``w`` as a Python float.

    Both arguments are cast to float and multiplied element-wise before
    summing.
    """
    values = x.astype(float)
    weights = w.astype(float)
    return float((values * weights).sum())

# Equivalized income deciles (AGI per capita), weighted by household_weight.
def _weighted_deciles(values, weights, n_groups=10):
    """Assign each value to a weighted quantile group numbered 1..n_groups.

    Cut points are the values at which the cumulative weight (taken in
    ascending order of ``values``) first reaches k/n_groups of the total
    weight. A value belongs to group k when it is greater than cut point k-1
    and less than or equal to cut point k.

    Labels come from ``np.searchsorted`` rather than ``pd.cut``. The two agree
    whenever all cut points are distinct, but ``pd.cut`` raises on repeated
    bin edges (for instance when a large share of the weight sits on one
    value), whereas here tied values simply share the lowest applicable group
    and some groups stay empty.

    Args:
        values: 1-D array-like of the ranking variable.
        weights: 1-D array-like of weights, same length as ``values``.
        n_groups: number of quantile groups (10 for deciles).

    Returns:
        Integer ndarray of group labels, all 1 when there are no records or
        the total weight is not positive.

    Raises:
        ValueError: if the cut points are not non-decreasing, which happens
            only when the cumulative weight is not monotone (negative weights).
    """
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    order = np.argsort(values)
    sorted_values = values[order]
    cum_weight = np.cumsum(weights[order])
    if cum_weight.size == 0 or not cum_weight[-1] > 0:
        return np.ones(values.shape, dtype=int)

    targets = cum_weight[-1] * np.arange(1, n_groups) / n_groups
    positions = np.searchsorted(cum_weight, targets, side="left")
    positions = np.clip(positions, 0, sorted_values.size - 1)
    inner_edges = sorted_values[positions]
    if np.any(np.diff(inner_edges) < 0):
        raise ValueError("Decile cut points are not monotonic; check for negative weights.")

    # Number of cut points strictly below a value, plus one, is its group.
    return np.searchsorted(inner_edges, values, side="left") + 1


inc_pc = panel["household_agi"].astype(float) / panel["household_size"].clip(lower=1).astype(float)
panel["decile"] = _weighted_deciles(inc_pc.to_numpy(), panel["household_weight"].to_numpy())

# Totals (weighted by household_weight).
total_with = wsum(panel["rebate_flat"], panel["household_weight"])
total_base = wsum(panel["rebate_flat_base"], panel["household_weight"])
# Acceptance: with-phase-out <= base overall. Raised explicitly so the check
# is not stripped under `python -O`; `not (a <= b)` also rejects NaN totals.
if not (total_with <= total_base + 1e-6):
    raise AssertionError("With-phaseout exceeds base.")

# Per-record weighted amounts, computed once and reused for every breakdown.
_hh_weight = panel["household_weight"].astype(float)
_weighted = panel[["size_bucket", "status_group", "decile"]].assign(
    weighted_total=panel["rebate_flat"].astype(float) * _hh_weight,
    weighted_base=panel["rebate_flat_base"].astype(float) * _hh_weight,
)


def _weighted_breakdown(key):
    """Return weighted rebate totals grouped by ``key``, after acceptance checks.

    The result has two columns, ``key`` and ``weighted_total``, ordered by
    ``key``. Before returning, the documented acceptance checks are applied:
    each group's with-phase-out total must not exceed its base total, and the
    group totals must add up to the overall total (within float tolerance).

    Raises:
        AssertionError: if either acceptance check fails.
    """
    sums = _weighted.groupby(key)[["weighted_total", "weighted_base"]].sum()
    slack = 1e-6 + 1e-12 * sums["weighted_base"].abs()
    if not (sums["weighted_total"] <= sums["weighted_base"] + slack).all():
        raise AssertionError(f"With-phaseout exceeds base within {key} groups.")
    if not np.isclose(sums["weighted_total"].sum(), total_with, rtol=1e-9, atol=1e-6):
        raise AssertionError(f"{key} breakdown does not sum to the overall total.")
    return sums[["weighted_total"]].reset_index()


# By size bucket
by_size = _weighted_breakdown("size_bucket")

# By filing status
by_status = _weighted_breakdown("status_group")

# By decile
by_dec = _weighted_breakdown("decile").sort_values("decile")

# ---------- Save outputs ----------
# Household-level records: the input columns plus both rebate measures.
_record_cols = [
    "household_agi",
    "household_size",
    "filing_status",
    "household_weight",
    "rebate_flat",
    "rebate_flat_base",
]
panel[_record_cols].to_csv(OUT_DIR / "rebate_records_2024.csv", index=False)

# One-row summary: the parameters used and the two weighted totals.
_summary_row = {
    "policy": "flat_per_capita",
    "amount": FLAT_AMOUNT,
    "single_start": FLAT_SINGLE_START,
    "mfj_start": FLAT_MFJ_START,
    "phaseout_rate": FLAT_RATE,
    "total_with_phaseout": total_with,
    "total_no_phaseout": total_base,
}
pd.DataFrame([_summary_row]).to_csv(OUT_DIR / "summary_2024.csv", index=False)

# Breakdown tables, each written without its index.
for _filename, _table in (
    ("by_decile_2024.csv", by_dec),
    ("by_size_2024.csv", by_size),
    ("by_status_2024.csv", by_status),
):
    _table.to_csv(OUT_DIR / _filename, index=False)

# ---------- Plot ----------
# Bar chart of the weighted rebate total in each decile.
fig, ax = plt.subplots()
ax.bar(by_dec["decile"].astype(str), by_dec["weighted_total"].astype(float))
ax.set_title("Flat per-capita rebate by equivalized-income decile (2024)")
ax.set_xlabel("Decile")
ax.set_ylabel("Weighted rebate total")
fig.tight_layout()
fig.savefig(OUT_DIR / "plots" / "deciles_2024.png", dpi=150)
# Close the figure so it is not kept open after the file has been written.
plt.close(fig)

elapsed = time.time() - START_TS
print(f"✅ 08_flat_per_capita_rebate complete in {elapsed:,.1f}s")
