In [None]:
from __future__ import annotations
import re
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# ===================== Path Configuration (Modify as Needed) =====================
# Figure 1: workbook mapping county_fips -> iso, zone
PATH_MAP = Path("./poverty/county_to_iso_or_city.xlsx")
# Sheet of the mapping workbook (positional index, or a sheet name)
MAP_SHEET = 0

# Figure 2: directory containing "AK AMI Counties 2022.csv" etc.
DIR_LEAD_2022 = Path("./poverty/LEAD/")

# Figure 4: income growth factors (the row where 2022 = 1)
PATH_INCOME_GROWTH = Path("./poverty/macroeconomic.xlsx")
# Sheet of the income workbook (index or name); the loader searches it for a "Real Dispos" row
INCOME_SHEET = 0

# Figures 5-6: several xlsx files (e.g., ERCOT_zone_prices.xlsx)
DIR_ISO_PRICES = Path("./rider/")

# Figure 7: one sheet per state, rows contain NG and DFO
PATH_GAS_OIL = Path("./poverty/fuel/state_fuel_price.xlsx")

# Output directory (one file per target year)
OUT_DIR = Path("./poverty/")

# Target years: 2025 through 2030 inclusive
YEARS = list(range(2025, 2031))
# =============== Utility: State FIPS -> State Abbreviation (For matching LEAD rows to states) ===============
# Keys are the two-digit state FIPS prefix of a county FIPS code; values are postal abbreviations.
STATE_FIPS_TO_ABBR = {
    "01": "AL", "02": "AK", "04": "AZ", "05": "AR", "06": "CA", "08": "CO",
    "09": "CT", "10": "DE", "11": "DC", "12": "FL", "13": "GA", "15": "HI",
    "16": "ID", "17": "IL", "18": "IN", "19": "IA", "20": "KS", "21": "KY",
    "22": "LA", "23": "ME", "24": "MD", "25": "MA", "26": "MI", "27": "MN",
    "28": "MS", "29": "MO", "30": "MT", "31": "NE", "32": "NV", "33": "NH",
    "34": "NJ", "35": "NM", "36": "NY", "37": "NC", "38": "ND", "39": "OH",
    "40": "OK", "41": "OR", "42": "PA", "44": "RI", "45": "SC", "46": "SD",
    "47": "TN", "48": "TX", "49": "UT", "50": "VT", "51": "VA", "53": "WA",
    "54": "WV", "55": "WI", "56": "WY", "72": "PR",
}
# ===================== General Cleaning Functions =====================
def clean_zone(z: str, iso: str | None = None) -> str:
    """Remove prefix + remove punctuation + uppercase. E.g.: 'CAISO_ZP-26' -> 'ZP26'; 'MISO_LRZ8_' -> 'LRZ8'"""
    if pd.isna(z) or z is None:
        return ""
    z0 = str(z).strip()
    # Remove leading ISO prefix (PJM_, CAISO-, ERCOT_ etc.)
    if iso and isinstance(iso, str) and len(iso) > 0:
        pattern = r"^\s*{0}[\s_\-:/]+".format(re.escape(iso))
        z0 = re.sub(pattern, "", z0, flags=re.IGNORECASE)
    # Remove all non-alphanumeric characters
    z0 = re.sub(r"[^A-Za-z0-9]+", "", z0)
    return z0.upper()

def clean_iso_name_from_filename(p: Path) -> str:
    """
    Derive the ISO name from a price workbook path, e.g.
    'PJM_zone_prices.xlsx' -> 'PJM'.

    The leading run of letters/hyphens of the file stem is taken (the whole
    stem when it does not start with one), hyphens are dropped and the result
    is uppercased.
    """
    stem = p.stem
    head = re.match(r"([A-Za-z\-]+)", stem)
    token = stem if head is None else head.group(1)
    return token.replace("-", "").upper()

# ===================== 1) Load Mapping Table =====================
def load_county_map(path: Path, sheet=0) -> pd.DataFrame:
    """
    Read the county -> ISO/zone mapping workbook.

    Args:
        path: Excel workbook holding the mapping.
        sheet: Sheet index or name passed to ``pd.read_excel``.

    Returns:
        A frame with exactly the columns ``county_fips`` (string, left-padded
        with zeros to 5 characters), ``iso`` (stripped, uppercased string) and
        ``zone`` (as read; zone labels are cleaned later, per county).

    Raises:
        ValueError: if any of the three required columns is absent.
    """
    required = ["county_fips", "iso", "zone"]
    table = pd.read_excel(path, sheet_name=sheet, dtype={"county_fips": str})

    absent = [name for name in required if name not in table.columns]
    if absent:
        raise ValueError(f"Mapping table missing columns: {absent}")

    normalised = table.assign(
        county_fips=lambda t: t["county_fips"].astype(str).str.zfill(5),
        iso=lambda t: t["iso"].astype(str).str.strip().str.upper(),
    )
    return normalised[required]

# ===================== 2) Aggregate LEAD 2022 Baseline =====================
def load_lead_2022(dir_path: Path) -> pd.DataFrame:
    """
    Concatenate every "<state> AMI Counties 2022.csv" file found in ``dir_path``.

    For each file the column names are stripped, the required value columns
    are checked, and a trimmed frame is kept with:

    * ``county_fips``: ``FIP`` as a 5-character zero-padded string,
    * ``state_abbr``: looked up from the first two FIPS digits,
    * ``UNITS``: taken from ``UNITS`` or, failing that, ``FREQUENCY``,
    * whichever segmentation columns exist in that file,
    * the four ``*UNITS`` value columns.

    ``UNITS`` and the value columns are coerced to numbers; anything that is
    not numeric becomes 0.0.

    Args:
        dir_path: Directory that holds the LEAD county CSV files.

    Returns:
        One frame with the rows of all files (fresh 0..n-1 index).

    Raises:
        KeyError: a file lacks ``FIP``, a value column, or both ``UNITS`` and
            ``FREQUENCY``.
        ValueError: no matching CSV file exists in ``dir_path``.
    """
    need_val = ["HINCP*UNITS", "ELEP*UNITS", "GASP*UNITS", "FULP*UNITS"]
    seg_candidates = ["AMI150", "TEN", "TEN-YBL6", "TEN-BLD", "TEN-HFL", "NAME"]

    records = []
    for csv in sorted(dir_path.glob("* AMI Counties 2022.csv")):
        try:
            df = pd.read_csv(csv, low_memory=False)
        except UnicodeDecodeError:
            # Some files are not UTF-8; latin1 can decode any byte sequence.
            df = pd.read_csv(csv, low_memory=False, encoding="latin1")

        df.columns = [str(c).strip() for c in df.columns]

        for c in need_val:
            if c not in df.columns:
                raise KeyError(f"{c} missing in {csv.name}")

        units_col = "UNITS" if "UNITS" in df.columns else ("FREQUENCY" if "FREQUENCY" in df.columns else None)
        if units_col is None:
            raise KeyError(f"Neither UNITS nor FREQUENCY found: {csv.name}")
        seg_cols = [c for c in seg_candidates if c in df.columns]

        if "FIP" not in df.columns:
            raise KeyError(f"FIP missing in {csv.name}")

        # A FIP column parsed as float stringifies as "2013.0"; drop that
        # trailing ".0" so padding yields a real 5-digit code ("02013").
        fips = (
            df["FIP"].astype(str).str.strip()
            .str.replace(r"\.0$", "", regex=True)
            .str.zfill(5)
        )
        df["county_fips"] = fips
        df["state_abbr"] = fips.str[:2].map(STATE_FIPS_TO_ABBR)

        cols = ["county_fips", "state_abbr", units_col] + seg_cols + need_val
        tmp = df[cols].copy()
        tmp = tmp.rename(columns={units_col: "UNITS"})
        tmp["UNITS"] = pd.to_numeric(tmp["UNITS"], errors="coerce").fillna(0.0)
        for c in need_val:
            tmp[c] = pd.to_numeric(tmp[c], errors="coerce").fillna(0.0)
        records.append(tmp)

    if not records:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError(f"No '* AMI Counties 2022.csv' files found in {dir_path}")

    return pd.concat(records, ignore_index=True)

# ===================== 3) Income Growth Factors (2022=1) =====================
def load_income_growth(path: Path, sheet=0, years: List[int] = YEARS) -> Dict[int, float]:
    """
    Read income growth factors, normalised so that 2022 equals 1.

    The row is the last one whose first-column text contains "Real Dispos"
    (case-insensitive); if there is none, the last row that is not entirely
    empty is used. Column labels that look like a year ("2025", 2025, 2025.0)
    are matched as integer years.

    Args:
        path: Excel workbook with one column per year.
        sheet: Sheet index or name passed to ``pd.read_excel``.
        years: Target years to return.

    Returns:
        ``{year: factor}`` for the target years only.

    Raises:
        KeyError: a target year or 2022 has no column.
        ValueError: the sheet has no usable row, or the 2022 value is
            zero / not finite (normalisation would be meaningless).
    """
    df = pd.read_excel(path, sheet_name=sheet)

    # Year headers stored as text or float would never equal the int years
    # looked up below, so bring every year-like label to int first.
    year_labels = {}
    for col in df.columns:
        hit = re.fullmatch(r"(\d{4})(?:\.0+)?", str(col).strip())
        if hit:
            year_labels[col] = int(hit.group(1))
    df = df.rename(columns=year_labels)

    row_idx = None
    if df.columns.size >= 2:
        first_col = df.columns[0]
        mask = df[first_col].astype(str).str.contains("Real Dispos", case=False, na=False)
        if mask.any():
            row_idx = df[mask].index[-1]
    if row_idx is None:
        non_empty = df.dropna(how="all", axis=0)
        if non_empty.empty:
            raise ValueError(f"Income growth table has no data rows: {path}")
        row_idx = non_empty.index[-1]
    row = df.loc[row_idx]

    growth = {}
    for y in list(years) + [2022]:
        if y not in df.columns:
            raise KeyError(f"Income growth table missing year column {y}")
        growth[y] = float(row[y])

    base = growth[2022]
    if not np.isfinite(base) or base == 0:
        raise ValueError(f"Income growth value for 2022 is not usable as a base: {base}")
    if abs(base - 1.0) > 1e-6:
        # Normalize to 2022=1 if not already
        growth = {k: v / base for k, v in growth.items()}

    # Return only target years
    return {y: growth[y] for y in years}

# ===================== 4) ISO Price Ratios (vs 2022) =====================
def load_iso_price_ratios(dir_path: Path, years: list[int]) -> dict:
    import re, numpy as np, pandas as pd, openpyxl

    def _clean_zone(z: str, iso: str | None = None) -> str:
        if z is None or (isinstance(z, float) and pd.isna(z)):
            return ""
        s = str(z).strip()
        if iso:
            s = re.sub(rf"^\s*{re.escape(iso)}[\s_\-:/]+", "", s, flags=re.IGNORECASE)
        s = re.sub(r"[^A-Za-z0-9]+", "", s)
        return s.upper()

    def _iso_from_filename(p: Path) -> str:
        return re.sub(r"[^A-Za-z]", "", p.stem.split("_")[0]).upper()

    def _coerce_year_cols_to_int(df: pd.DataFrame) -> pd.DataFrame:
        mapper = {}
        for c in df.columns:
            cs = str(c).strip()
            if re.fullmatch(r"\d{4}", cs):
                mapper[c] = int(cs)
        return df.rename(columns=mapper)

    def _read_sheet_pandas(xlsx: Path, sheet: str) -> pd.DataFrame | None:
        try:
            return pd.read_excel(xlsx, sheet_name=sheet, engine="openpyxl")
        except Exception:
            return None

    def _read_sheet_openpyxl_data_only(xlsx: Path, sheet: str) -> pd.DataFrame | None:
        try:
            wb = openpyxl.load_workbook(xlsx, data_only=True, read_only=True)
            if sheet not in wb.sheetnames:
                return None
            ws = wb[sheet]
            data = list(ws.values)
            if not data:
                return None
            return pd.DataFrame(data[1:], columns=data[0])
        except Exception:
            return None

    def _numeric_series(s: pd.Series) -> pd.Series:
        s = s.copy()
        new_idx = []
        for k in s.index:
            ks = str(k).strip()
            if re.fullmatch(r"\d{4}", ks):
                new_idx.append(int(ks))
            else:
                new_idx.append(k)
        s.index = new_idx
        for k in list(s.index):
            if isinstance(k, int) and 2000 <= k <= 2100:
                s.loc[k] = pd.to_numeric(s.loc[k], errors="coerce")
        return s

    def _pick_row(df_idxed: pd.DataFrame, names: list[str]) -> pd.Series | None:
        for nm in names:
            if nm in df_idxed.index:
                return df_idxed.loc[nm]
        for nm in names:
            m = df_idxed.index.to_series().str.contains(re.escape(nm), case=False, na=False)
            if m.any():
                return df_idxed.loc[m].iloc[0]
        return None

    def _get_year_val(s: pd.Series, y: int) -> float | None:
        v = s.get(y, s.get(str(y), np.nan))
        return float(v) if pd.notna(v) else None

    ratios: dict = {}

    for xlsx in sorted(dir_path.glob("*_zone_prices.xlsx")):
        iso = _iso_from_filename(xlsx)
        ratios.setdefault(iso, {})
        try:
            sheets = pd.ExcelFile(xlsx).sheet_names
        except Exception:
            continue

        for sheet in sheets:
            # --- Correction: Explicit None check, no longer using 'or' to chain DataFrames ---
            df = _read_sheet_pandas(xlsx, sheet)
            if df is None or df.empty:
                df = _read_sheet_openpyxl_data_only(xlsx, sheet)
            if df is None or df.empty:
                continue

            df = _coerce_year_cols_to_int(df)
            # Find item column; fallback to first column
            item_col = None
            for c in df.columns:
                if str(c).strip().lower() == "item":
                    item_col = c
                    break
            if item_col is None:
                item_col = df.columns[0]
            df[item_col] = df[item_col].astype(str).str.strip().str.lower()
            df = df.set_index(item_col)

            s_total  = _pick_row(df, ["total"])
            s_minus  = _pick_row(df, ["total_minus_dc", "total minus dc", "total_minus_dc "])
            s_gshare = _pick_row(df, ["gen_share (=gen/total)", "gen_share", "gen share"])
            s_dc     = _pick_row(df, ["dc_cumulate", "dc_cumulative", "dc cumulate"])
            if s_total is None:
                continue

            s_total = _numeric_series(s_total)
            if s_minus is not None:
                s_minus = _numeric_series(s_minus)
            if s_gshare is not None:
                s_gshare = _numeric_series(s_gshare)
            if s_dc is not None:
                s_dc = _numeric_series(s_dc)

            # Rebuild if total_minus_dc missing/NaN
            need_rebuild = (s_minus is None) or (s_minus.dropna().empty) or pd.isna(s_minus.iloc[0])
            if need_rebuild:
                df2 = _read_sheet_openpyxl_data_only(xlsx, sheet)
                if df2 is not None and not df2.empty:
                    df2 = _coerce_year_cols_to_int(df2)
                    if item_col in df2.columns:
                        df2[item_col] = df2[item_col].astype(str).str.strip().str.lower()
                        df2 = df2.set_index(item_col)
                        s2 = _pick_row(df2, ["total_minus_dc", "total minus dc", "total_minus_dc "])
                        if s2 is not None:
                            s_minus = _numeric_series(s2)

            if (s_minus is None) or (s_minus.dropna().empty) or pd.isna(s_minus.iloc[0]):
                if (s_gshare is not None) and (s_dc is not None) and (not s_total.dropna().empty):
                    s_minus = s_total - s_dc
                else:
                    continue

            # Base year (first column assumed 2022)
            try:
                base_total = float(s_total.iloc[0])
            except Exception:
                continue
            if base_total == 0 or not np.isfinite(base_total):
                continue

            zone_clean = _clean_zone(sheet, iso)
            if not zone_clean:
                continue

            for y in years:
                v_tot = _get_year_val(s_total, y)
                v_min = _get_year_val(s_minus, y)
                if v_tot is None or v_min is None or not np.isfinite(v_tot) or not np.isfinite(v_min):
                    continue
                r_with = float(v_tot) / base_total
                r_no   = float(v_min) / base_total
                ratios[iso].setdefault(zone_clean, {})
                ratios[iso][zone_clean][f"with_dc_{y}"] = r_with
                ratios[iso][zone_clean][f"no_dc_{y}"]   = r_no

    return ratios

# ===================== 5) State Fuel Price Ratios (vs 2022) =====================
def load_state_fuel_price_ratios(path: Path, years: List[int] = YEARS) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Read state fuel prices and express them as ratios to the 2022 price.

    Each sheet of the workbook is one state (sheet name, stripped and
    uppercased, is used as the state key). The first column labels the rows;
    the 'NG (Natural Gas)' and 'DFO (Distillate Fuel Oil)' rows are located by
    exact label first, then by case-insensitive substring. Column labels that
    look like a year ("2025", 2025.0) are matched as integer years.

    Sheets that are empty, lack either row, or have a zero 2022 price are skipped.

    Args:
        path: Excel workbook with one sheet per state.
        years: Target years; a year without a column is left out for that state.

    Returns:
        ``(gas_df, oil_df)``: both indexed by ``state_abbr`` with one column
        per available target year, holding price / 2022 price.

    Raises:
        KeyError: a sheet has both fuel rows but no 2022 column.
    """
    def _pick_row(frame: pd.DataFrame, cand: List[str]) -> pd.Series | None:
        for nm in cand:
            if nm in frame.index:
                hit = frame.loc[nm]
                # Duplicate labels make .loc return a frame; keep the first row.
                return hit.iloc[0] if isinstance(hit, pd.DataFrame) else hit
        # Fuzzy match fallback
        labels = frame.index.to_series()
        for nm in cand:
            m = labels.str.contains(re.escape(nm), case=False, na=False)
            if m.any():
                return frame.loc[m.to_numpy()].iloc[0]
        return None

    gas_ratios = {}
    oil_ratios = {}
    # One handle for all sheets, released when the block exits.
    with pd.ExcelFile(path) as xls:
        for sheet in xls.sheet_names:
            df = pd.read_excel(xls, sheet_name=sheet)
            if df.shape[1] == 0:
                continue  # completely empty sheet: no label column to index by

            # Row name in first column
            idx_col = df.columns[0]
            df[idx_col] = df[idx_col].astype(str).str.strip()
            df = df.set_index(idx_col)

            # Year headers stored as text/float would not match the int years below.
            year_labels = {}
            for col in df.columns:
                hit = re.fullmatch(r"(\d{4})(?:\.0+)?", str(col).strip())
                if hit:
                    year_labels[col] = int(hit.group(1))
            df = df.rename(columns=year_labels)

            s_ng  = _pick_row(df, ["NG (Natural Gas)", "Natural Gas"])
            s_dfo = _pick_row(df, ["DFO (Distillate Fuel Oil)", "Distillate Fuel Oil"])
            if s_ng is None or s_dfo is None:
                continue
            if 2022 not in df.columns:
                raise KeyError(f"Sheet {sheet!r} in {path} has no 2022 column")

            base_ng  = float(s_ng[2022])
            base_dfo = float(s_dfo[2022])
            if base_ng == 0 or base_dfo == 0:
                continue

            state = str(sheet).strip().upper()
            gas_ratios[state] = {y: float(s_ng[y]) / base_ng for y in years if y in s_ng.index}
            oil_ratios[state] = {y: float(s_dfo[y]) / base_dfo for y in years if y in s_dfo.index}

    gas_df = pd.DataFrame.from_dict(gas_ratios, orient="index").sort_index()
    oil_df = pd.DataFrame.from_dict(oil_ratios, orient="index").sort_index()
    gas_df.index.name = "state_abbr"
    oil_df.index.name = "state_abbr"
    return gas_df, oil_df

In [None]:
# Make sure the output folder exists before anything is written into it.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# 0. County -> ISO / zone mapping table
m = load_county_map(PATH_MAP, MAP_SHEET)

# ISOs for which a "<ISO>_zone_prices.xlsx" workbook is present
iso_price_files = [p for p in DIR_ISO_PRICES.glob("*_zone_prices.xlsx")]
iso_valid = {clean_iso_name_from_filename(p) for p in iso_price_files}

# Keep only the counties whose ISO has a price workbook (this also drops
# counties whose ISO value matches none of them).
m = m.loc[m["iso"].isin(iso_valid)].copy()

In [None]:
# LEAD 2022 baseline rows for every county file in the directory.
lead = load_lead_2022(dir_path=DIR_LEAD_2022)

# Attach iso/zone to each LEAD row; counties absent from the mapping are dropped.
# NOTE(review): if the mapping holds several rows for one county_fips, each LEAD
# row is repeated once per mapping row here - confirm the mapping is unique per county.
base = pd.merge(lead, m, how="inner", on="county_fips")

# Income growth factor per target year.
income_growth = load_income_growth(path=PATH_INCOME_GROWTH, sheet=INCOME_SHEET, years=YEARS)

In [None]:
# Electricity price ratios per ISO and zone.
price_ratios = load_iso_price_ratios(dir_path=DIR_ISO_PRICES, years=YEARS)

# Natural gas and fuel oil price ratios per state, one frame each.
(gas_ratio_df, oil_ratio_df) = load_state_fuel_price_ratios(path=PATH_GAS_OIL, years=YEARS)

In [None]:
def split_and_clean_zones(series_zones, iso):
    """
    Collect the distinct cleaned zone names found in a column of zone labels.

    Each entry may hold several zones separated by "|" (spaces around the bar
    are allowed: "A | B"). Every part is passed through ``clean_zone`` with
    ``iso``; missing entries, the literal text "nan" and parts that clean to
    an empty string are ignored.

    Args:
        series_zones: Iterable of raw zone labels (e.g. a pandas Series).
        iso: ISO name whose prefix ``clean_zone`` strips from each part.

    Returns:
        Sorted list of unique cleaned zone names.
    """
    tokens = set()
    for z in series_zones:
        # Skip real missing values here instead of relying on astype(str)
        # turning them into "nan": None would become "None", and newer pandas
        # keeps missing values missing through astype(str).
        if z is None or (pd.api.types.is_scalar(z) and pd.isna(z)):
            continue
        for part in re.split(r"\s*\|\s*", str(z)):
            if not part or part.lower() == "nan":
                continue
            cleaned = clean_zone(part, iso)
            if cleaned:
                tokens.add(cleaned)
    return sorted(tokens)

In [None]:
# One list of output rows (dicts) per target year; filled by the loop below.
results_by_year = {y: [] for y in YEARS}
n_counties = base["county_fips"].nunique()
rows_appended = 0
skipped_no_zone = 0
# County-years dropped because the state/year has no usable gas or oil ratio.
skipped_no_fuel = 0
skipped_examples = []

# Segmentation columns (keep whichever exist). Every group has base's columns,
# so this is resolved once instead of once per county.
seg_cols = [c for c in ["AMI150","TEN","TEN-YBL6","TEN-BLD","TEN-HFL","NAME"] if c in base.columns]

pbar = tqdm(base.groupby("county_fips", sort=False), total=n_counties, desc="Processing counties", leave=True)
# A later cell may rebind tqdm to a pass-through fallback, in which case pbar
# is a plain iterable without set_postfix; only call it when it exists.
set_postfix = getattr(pbar, "set_postfix", None)

for county, g in pbar:
    iso = g["iso"].iloc[0]
    state_abbr = g["state_abbr"].iloc[0]

    # Zones for this county (clean + deduplicate), then their price-ratio dicts.
    zones = split_and_clean_zones(g["zone"], iso)
    zone_ratios = [r for r in (price_ratios.get(iso, {}).get(z) for z in zones) if r]
    if not zone_ratios:
        skipped_no_zone += 1
        skipped_examples.append({
            "county_fips": county, "iso": iso, "zones_raw": list(g["zone"].astype(str)),
            "zones_clean": zones})
        if set_postfix is not None:
            set_postfix(rows=rows_appended, skipped_no_zone=skipped_no_zone)
        continue

    def avg_ratio(key: str, year: int):
        """Mean of f"{key}_{year}" over this county's zones; None if no zone has it."""
        vals = [r.get(f"{key}_{year}") for r in zone_ratios if r.get(f"{key}_{year}") is not None]
        return float(np.mean(vals)) if vals else None

    # Everything that depends only on (county, year) is computed once here
    # rather than for every LEAD row of the county.
    year_factors = {}
    for y in YEARS:
        r_with = avg_ratio("with_dc", y)
        r_no   = avg_ratio("no_dc", y)
        if r_with is None or r_no is None:
            continue
        growth = float(income_growth[y])
        try:
            gas_ratio = float(gas_ratio_df.loc[state_abbr, y])
            oil_ratio = float(oil_ratio_df.loc[state_abbr, y])
        except (KeyError, TypeError, ValueError):
            # State or year absent from the fuel tables (or not a single
            # number): this county-year is dropped, and counted.
            skipped_no_fuel += 1
            continue
        year_factors[y] = (r_with, r_no, growth, gas_ratio, oil_ratio)

    if year_factors:
        zone_list = ";".join(zones)
        for row in g.to_dict("records"):
            inc0  = float(row["HINCP*UNITS"])
            elec0 = float(row["ELEP*UNITS"])
            gas0  = float(row["GASP*UNITS"])
            fuel0 = float(row["FULP*UNITS"])
            units = float(row["UNITS"])

            for y, (r_with, r_no, growth, gas_ratio, oil_ratio) in year_factors.items():
                income    = inc0 * growth
                elec_with = elec0 * r_with
                elec_no   = elec0 * r_no
                gas_cost  = gas0  * gas_ratio
                fuel_cost = fuel0 * oil_ratio

                if income <= 0:
                    # No meaningful burden without positive income.
                    burden_with = np.nan
                    burden_no   = np.nan
                else:
                    burden_with = 100.0 * (elec_with + gas_cost + fuel_cost) / income
                    burden_no   = 100.0 * (elec_no   + gas_cost + fuel_cost) / income

                out_row = {
                    "county_fips": county,
                    "iso": iso,
                    "zone_list": zone_list,
                    "state": state_abbr,
                    "UNITS": units,
                    "income_total": income,
                    "elec_with_dc": elec_with,
                    "elec_no_dc":   elec_no,
                    "gas":  gas_cost,
                    "fuel": fuel_cost,
                    "energy_burden_with_dc_%": burden_with,
                    "energy_burden_no_dc_%":   burden_no,
                }
                for c in seg_cols:
                    out_row[c] = row[c]

                results_by_year[y].append(out_row)
                rows_appended += 1

    # Update pbar stats
    if set_postfix is not None:
        set_postfix(rows=rows_appended, skipped_no_zone=skipped_no_zone)

# Optional: Print summary
print(
    f"Done. counties={n_counties}, skipped_no_zone={skipped_no_zone}, "
    f"rows_appended={rows_appended}, skipped_no_fuel_county_years={skipped_no_fuel}"
)



In [None]:
from math import ceil
try:
    from tqdm.auto import tqdm
except Exception:
    def tqdm(x, **k): return x

# Largest number of rows handed to a single to_csv call (adjustable).
CSV_CHUNK_ROWS = 2_000_000

# 5) Output CSVs (one per year)
for y in YEARS:
    dfy = pd.DataFrame(results_by_year[y])
    if dfy.empty:
        # No rows were produced for this year, so no file is written.
        continue

    out_csv = OUT_DIR / f"energy_burden_{y}.csv"
    n = len(dfy)

    if n > CSV_CHUNK_ROWS:
        # Large table: write it slice by slice. The first slice creates the
        # file with the header, every following slice is appended without one.
        first = True
        it = tqdm(range(0, n, CSV_CHUNK_ROWS), desc=f"Writing CSV {y}")
        for s in it:
            e = min(n, s + CSV_CHUNK_ROWS)
            dfy.iloc[s:e].to_csv(
                out_csv,
                mode=("w" if first else "a"),
                header=first,
                index=False,
                encoding="utf-8-sig",
            )
            first = False
    else:
        # Small enough for a single write.
        dfy.to_csv(out_csv, index=False, encoding="utf-8-sig")

print(f"Done! CSVs written to: {OUT_DIR.resolve()}")