<a href="https://colab.research.google.com/github/Lanxuan-Luo/qm2-procrastinatepros/blob/main/Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import annotations

import argparse
import json
import os
import re
from dataclasses import dataclass
from io import StringIO
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

try:
    import requests
except Exception:
    requests = None

# Optional: plotly is only needed for the DATA2 choropleth map.
try:
    import plotly.express as px
except Exception:
    px = None


# Full state name -> two-letter postal abbreviation for the 50 states,
# followed by the District of Columbia (insertion order is preserved).
STATE_TO_ABBR: Dict[str, str] = {
    "Alabama": "AL",
    "Alaska": "AK",
    "Arizona": "AZ",
    "Arkansas": "AR",
    "California": "CA",
    "Colorado": "CO",
    "Connecticut": "CT",
    "Delaware": "DE",
    "Florida": "FL",
    "Georgia": "GA",
    "Hawaii": "HI",
    "Idaho": "ID",
    "Illinois": "IL",
    "Indiana": "IN",
    "Iowa": "IA",
    "Kansas": "KS",
    "Kentucky": "KY",
    "Louisiana": "LA",
    "Maine": "ME",
    "Maryland": "MD",
    "Massachusetts": "MA",
    "Michigan": "MI",
    "Minnesota": "MN",
    "Mississippi": "MS",
    "Missouri": "MO",
    "Montana": "MT",
    "Nebraska": "NE",
    "Nevada": "NV",
    "New Hampshire": "NH",
    "New Jersey": "NJ",
    "New Mexico": "NM",
    "New York": "NY",
    "North Carolina": "NC",
    "North Dakota": "ND",
    "Ohio": "OH",
    "Oklahoma": "OK",
    "Oregon": "OR",
    "Pennsylvania": "PA",
    "Rhode Island": "RI",
    "South Carolina": "SC",
    "South Dakota": "SD",
    "Tennessee": "TN",
    "Texas": "TX",
    "Utah": "UT",
    "Vermont": "VT",
    "Virginia": "VA",
    "Washington": "WA",
    "West Virginia": "WV",
    "Wisconsin": "WI",
    "Wyoming": "WY",
    "District of Columbia": "DC",
}
# Reverse lookup (abbreviation -> full name), in the same order as above.
ABBR_TO_STATE: Dict[str, str] = dict(zip(STATE_TO_ABBR.values(), STATE_TO_ABBR.keys()))

def ensure_outdir(path: str) -> str:
    """Make sure the directory *path* exists, creating parent folders as needed.

    Returns *path* unchanged so the call can be used inline.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
    return path


def safe_read_csv(path: str) -> pd.DataFrame:
    """Read the CSV at *path* into a DataFrame (``low_memory=False``).

    Raises:
        FileNotFoundError: naming *path* when it does not exist.
    """
    if os.path.exists(path):
        return pd.read_csv(path, low_memory=False)
    raise FileNotFoundError(f"File not found: {path}")


def download_csv(url: str, timeout: int = 60) -> pd.DataFrame:
    """Fetch *url* over HTTP and parse the response body as CSV.

    Args:
        url: Address of the CSV resource.
        timeout: Seconds passed to ``requests.get``.

    Raises:
        RuntimeError: if the optional ``requests`` package could not be imported.
        Whatever ``raise_for_status`` raises for an unsuccessful HTTP status.
    """
    if requests is None:
        raise RuntimeError("requests is not installed. Install it: python3 -m pip install requests")
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    body = StringIO(response.text)
    return pd.read_csv(body)



In [2]:

# DATA1
def make_data1_global_bar(outdir: str, year: int = 2022, top_n: int = 15, force_include: str = "United States of America") -> str:
    """Plot the countries with the highest adult obesity prevalence (OWID grapher, WHO series).

    Downloads Our World in Data's ``share-of-adults-defined-as-obese`` CSV,
    keeps the rows for ``year``, ranks entities with a three-letter ``Code``
    by prevalence and draws the ``top_n`` highest as a horizontal bar chart.
    ``force_include`` (an OWID ``Entity`` name) is appended when it is not
    already among the top ``top_n``; pass an empty string to disable this.

    Args:
        outdir: Directory for the PNG; created if missing.
        year: Year to plot.
        top_n: Number of highest-prevalence countries to show.
        force_include: Entity to always show if it has data for ``year``.

    Returns:
        Path of the saved PNG.

    Raises:
        ValueError: if no value column can be identified, or there is no data for ``year``.
    """
    outdir = ensure_outdir(outdir)
    owid_url = "https://ourworldindata.org/grapher/share-of-adults-defined-as-obese.csv"
    owid = download_csv(owid_url)

    key_cols = ("Entity", "Code", "Year")
    value_col = "share-of-adults-defined-as-obese"
    if value_col not in owid.columns:
        # The expected series name is absent: fall back to the first non-key
        # column, preferring one that actually contains numbers.
        candidates = [c for c in owid.columns if c not in key_cols]
        numeric = [c for c in candidates if pd.to_numeric(owid[c], errors="coerce").notna().any()]
        candidates = numeric or candidates
        if not candidates:
            raise ValueError(f"Could not find value column in OWID CSV. Columns: {list(owid.columns)}")
        value_col = candidates[0]

    year_rows = owid[owid["Year"] == year].copy()
    year_rows = year_rows.rename(columns={value_col: "obesity_pct"})
    year_rows["obesity_pct"] = pd.to_numeric(year_rows["obesity_pct"], errors="coerce")
    year_rows = year_rows.dropna(subset=["obesity_pct"])
    if year_rows.empty:
        raise ValueError(
            f"No obesity data for year {year} in OWID CSV (latest year present: {owid['Year'].max()})."
        )

    countries = year_rows
    if "Code" in year_rows.columns:
        # Rank only entities with a three-letter code so regional/income
        # aggregates cannot appear in a chart titled "by country".
        # NOTE(review): assumes OWID gives aggregates a blank or "OWID_..." code.
        is_country = year_rows["Code"].astype(str).str.fullmatch(r"[A-Z]{3}")
        countries = year_rows[is_country]

    top = countries.sort_values("obesity_pct", ascending=False).head(top_n)

    if force_include and force_include not in set(top["Entity"]):
        # Search all rows for the year, so a forced aggregate (e.g. "World") also works.
        extra = year_rows[year_rows["Entity"] == force_include]
        if not extra.empty:
            top = (
                pd.concat([top, extra], ignore_index=True)
                .drop_duplicates(subset=["Entity"], keep="first")
                .sort_values("obesity_pct", ascending=False)
            )

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # Reverse so the highest value is drawn at the top of the horizontal bars.
        ax.barh(top["Entity"][::-1], top["obesity_pct"][::-1])
        ax.set_xlabel("Adult obesity prevalence (%)")
        ax.set_title(f"Adult obesity prevalence by country (WHO, {year}, total)")
        fig.tight_layout()
        # With the default top_n=15 this is the historical file name.
        outpath = os.path.join(outdir, f"DATA1_global_obesity_top{top_n}_plus_USA.png")
        fig.savefig(outpath, dpi=200)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    return outpath


In [4]:
# DATA2 loaders & plots
@dataclass
class Data2StateObesity:
    """State-level obesity table for the DATA2 section, tagged with its survey year.

    The ``load_data2_*`` loaders in this file fill ``df`` with the columns
    ``state``, ``abbr``, ``obesity_pct`` (numeric, coerced) and ``is_state``
    (True for the 50 state codes, False for DC and unrecognized codes).
    """

    # Location table; presumably one row per state/DC as produced by the loaders.
    df: pd.DataFrame
    # Year label used in plot titles and output file names.
    year: int


def _extract_plotly_array(text: str, key: str) -> Optional[list]:
    """Return the first JSON array stored under ``"key":`` in *text*, or None.

    The value is decoded with a real JSON parser, so neither its contents nor
    the key order around it matter. Numeric arrays written as Plotly
    "typed arrays" (``{"dtype": "f8", "bdata": "<base64>"}``) are decoded
    into a plain list as well.
    """
    match = re.search(r'"' + re.escape(key) + r'"\s*:\s*', text)
    if match is None:
        return None
    try:
        value, _ = json.JSONDecoder().raw_decode(text[match.end():])
    except json.JSONDecodeError:
        return None
    if isinstance(value, list):
        return value
    if isinstance(value, dict) and "bdata" in value and "dtype" in value:
        import base64

        # NOTE(review): assumes the buffer uses the platform's native byte order (little-endian).
        raw = base64.b64decode(value["bdata"])
        return np.frombuffer(raw, dtype=np.dtype(value["dtype"])).tolist()
    return None


def load_data2_from_plotly_html(html_path: str, year: int = 2024) -> Data2StateObesity:
    """Recover state-level obesity values from a saved Plotly choropleth HTML.

    Reads the ``hovertext`` (state names), ``locations`` (state codes) and
    ``z`` (values) arrays of the first trace in the page.

    Args:
        html_path: Path to the HTML written by Plotly.
        year: Year label stored on the result.

    Returns:
        Data2StateObesity with columns state, abbr, obesity_pct, is_state.

    Raises:
        FileNotFoundError: if html_path does not exist.
        ValueError: if any array is missing or undecodable, or the lengths differ.
    """
    if not os.path.exists(html_path):
        raise FileNotFoundError(f"DATA2 html not found: {html_path}")

    with open(html_path, "r", encoding="utf-8", errors="ignore") as fh:
        text = fh.read()

    # The figure JSON is the argument of the last Plotly.newPlot(...) call.
    # Searching from there avoids stray matches in an inlined plotly.js bundle.
    anchor = text.rfind("Plotly.newPlot(")
    if anchor != -1:
        text = text[anchor:]

    states = _extract_plotly_array(text, "hovertext")
    abbrs = _extract_plotly_array(text, "locations")
    vals = _extract_plotly_array(text, "z")

    if states is None or abbrs is None or vals is None:
        raise ValueError("Could not parse hovertext/locations/z arrays from the Plotly HTML.")
    if not (len(states) == len(abbrs) == len(vals)):
        raise ValueError(
            f"Plotly arrays differ in length: hovertext={len(states)}, locations={len(abbrs)}, z={len(vals)}"
        )

    df = pd.DataFrame({"state": states, "abbr": abbrs, "obesity_pct": vals})
    df["obesity_pct"] = pd.to_numeric(df["obesity_pct"], errors="coerce")
    state_codes = [a for a in ABBR_TO_STATE if a != "DC"]
    df["is_state"] = df["abbr"].isin(state_codes)
    return Data2StateObesity(df=df, year=year)


def load_data2_from_csv(csv_path: str, year: int = 2024) -> Data2StateObesity:
    """Load state-level obesity values for DATA2 from a CSV file.

    Value column: ``obesity_pct``, or else the first of ``obesity_2024``,
    ``obesity`` and ``value`` that exists (renamed to ``obesity_pct``).
    Location columns: ``abbr`` (two-letter code) and/or ``state`` (full name);
    a missing one is derived from the other. Surrounding whitespace is
    stripped from names and codes, and codes are upper-cased, so " Texas" or
    "tx" still match.

    Args:
        csv_path: Path to the CSV.
        year: Year label stored on the result (used in plot titles/file names).

    Returns:
        Data2StateObesity with columns state, abbr, obesity_pct, is_state.

    Raises:
        FileNotFoundError: if csv_path does not exist.
        ValueError: if there is no value column, or neither 'abbr' nor 'state'.
    """
    df = safe_read_csv(csv_path)

    if "obesity_pct" not in df.columns:
        alias = next((c for c in ("obesity_2024", "obesity", "value") if c in df.columns), None)
        if alias is not None:
            df = df.rename(columns={alias: "obesity_pct"})
    if "obesity_pct" not in df.columns:
        raise ValueError(
            "CSV must contain obesity_pct (or obesity_2024, obesity, value). "
            f"Columns: {list(df.columns)}"
        )

    # Normalise text keys before any lookup; non-strings (e.g. NaN) are left untouched.
    if "state" in df.columns:
        df["state"] = df["state"].map(lambda v: v.strip() if isinstance(v, str) else v)

    if "abbr" in df.columns:
        df["abbr"] = df["abbr"].map(lambda v: v.strip().upper() if isinstance(v, str) else v)
    elif "state" in df.columns:
        df["abbr"] = df["state"].map(STATE_TO_ABBR)
    else:
        raise ValueError("CSV must contain either 'abbr' or 'state'")

    if "state" not in df.columns:
        df["state"] = df["abbr"].map(ABBR_TO_STATE)

    df["obesity_pct"] = pd.to_numeric(df["obesity_pct"], errors="coerce")
    state_codes = [a for a in ABBR_TO_STATE if a != "DC"]
    df["is_state"] = df["abbr"].isin(state_codes)
    return Data2StateObesity(df=df, year=year)


def plot_data2_top10_bar(d2: Data2StateObesity, outdir: str) -> str:
    """Save a horizontal bar chart of the ten states with the highest obesity prevalence.

    Only rows whose ``is_state`` flag is True and whose ``obesity_pct`` is
    present are ranked. The highest state is drawn at the top.

    Returns:
        Path of the PNG written inside *outdir*, which is created if needed.
    """
    target_dir = ensure_outdir(outdir)
    frame = d2.df.copy()
    eligible = frame[frame["is_state"]].dropna(subset=["obesity_pct"])
    leaders = eligible.sort_values("obesity_pct", ascending=False).head(10)
    # barh stacks bars bottom-up, so feed the ranking in reverse.
    bottom_up = leaders.iloc[::-1]

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.barh(bottom_up["state"], bottom_up["obesity_pct"])
    ax.set_xlabel("Obesity prevalence (%)")
    ax.set_title(f"Top 10 states by adult obesity prevalence (BRFSS {d2.year})")
    fig.tight_layout()

    png_path = os.path.join(target_dir, f"DATA2_top10_states_bar_{d2.year}.png")
    fig.savefig(png_path, dpi=200)
    plt.close(fig)
    return png_path


def plot_data2_choropleth_html(d2: Data2StateObesity, outdir: str) -> Optional[str]:
    """Write an interactive U.S. choropleth of state obesity prevalence as HTML.

    Rows whose ``abbr`` is not one of the 50 state codes (DC included) and
    rows without a value are left out. plotly.js is referenced from a CDN
    rather than embedded in the file.

    Returns:
        The HTML path, or None when plotly could not be imported (``px`` is None).
    """
    if px is not None:
        target_dir = ensure_outdir(outdir)
        states_only = [code for code in ABBR_TO_STATE.keys() if code != "DC"]
        frame = d2.df.copy()
        frame = frame[frame["abbr"].isin(states_only)].dropna(subset=["obesity_pct"])

        figure = px.choropleth(
            frame,
            locations="abbr",
            locationmode="USA-states",
            color="obesity_pct",
            scope="usa",
            hover_name="state",
            hover_data={"abbr": False, "obesity_pct": ":.1f"},
            title=f"USA adult obesity prevalence by state (BRFSS, {d2.year})",
        )

        html_path = os.path.join(target_dir, f"DATA2_US_obesity_state_map_{d2.year}.html")
        figure.write_html(html_path, include_plotlyjs="cdn")
        return html_path
    return None


def get_top5_states_from_data2(d2: Data2StateObesity) -> List[str]:
    """Return the full names of the (up to) five states with the highest ``obesity_pct``.

    Rows with ``is_state`` False (e.g. DC) and rows without a value are ignored.
    """
    table = d2.df.copy()
    eligible = table[table["is_state"]].dropna(subset=["obesity_pct"])
    ranked = eligible.sort_values("obesity_pct", ascending=False)
    return ranked.iloc[:5]["state"].tolist()


In [None]:

# DATA3
def scatter_with_fit(
    data: pd.DataFrame,
    x: str,
    y: str,
    xlabel: str,
    ylabel: str,
    title: str,
    outpath: str,
    min_points: int = 5,
) -> None:
    """Save a scatter plot of ``data[y]`` against ``data[x]`` with a least-squares line.

    Rows where either column is missing or infinite are dropped first. The
    title is extended with the Pearson correlation and the sample size.

    Degenerate inputs still produce a file:
      * fewer than ``min_points`` usable rows -> a placeholder figure stating n;
      * a constant x or y column -> the points only. Neither Pearson r nor a
        slope is defined, and ``np.polyfit`` warns (or may fail) on constant x.

    Args:
        data: Source table.
        x, y: Column names for the horizontal and vertical axes.
        xlabel, ylabel: Axis labels.
        title: Base plot title.
        outpath: PNG destination.
        min_points: Minimum usable rows required to draw the scatter.
    """
    d = data[[x, y]].replace([np.inf, -np.inf], np.nan).dropna()
    n = len(d)

    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        if n < min_points:
            ax.set_title(title)
            ax.text(
                0.5, 0.5, f"Not enough data points (n={n})",
                ha="center", va="center", transform=ax.transAxes,
            )
            ax.axis("off")
        else:
            xvals = d[x].to_numpy(dtype=float)
            yvals = d[y].to_numpy(dtype=float)
            ax.scatter(xvals, yvals, s=14, alpha=0.75)

            if xvals.max() == xvals.min() or yvals.max() == yvals.min():
                ax.set_title(f"{title}\nPearson r undefined: constant values (n={n})")
            else:
                r = float(np.corrcoef(xvals, yvals)[0, 1])
                slope, intercept = np.polyfit(xvals, yvals, 1)
                xx = np.linspace(xvals.min(), xvals.max(), 200)
                ax.plot(xx, slope * xx + intercept, linewidth=2)
                ax.set_title(f"{title}\nPearson r = {r:.3f} (n={n})")

            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
        fig.tight_layout()
        fig.savefig(outpath, dpi=200)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def _normalize_fips(values: pd.Series) -> pd.Series:
    """Return FIPS codes as 5-character zero-padded strings (unparseable -> <NA>).

    Converting through a number first turns 1001, "1001", "01001" and 1001.0
    into "01001". A plain ``astype(str)`` would turn a float-typed column
    (which pandas produces whenever an integer column has gaps) into "1001.0",
    and that never matches.
    """
    codes = pd.to_numeric(values, errors="coerce")
    # The integer cast below needs whole numbers; anything else is treated as missing.
    codes = codes.where(codes == codes.round())
    return codes.astype("Int64").astype("string").str.zfill(5)


def _as_percent(values: pd.Series) -> pd.Series:
    """Coerce to numeric and, if every observed value lies in [0, 1], scale by 100.

    Keeps the ``*_pct`` columns (and the "(%)" axis labels) in percent.
    NOTE(review): CHR appears to publish percentage measures as proportions;
    confirm for the release in use.
    """
    nums = pd.to_numeric(values, errors="coerce")
    observed = nums.dropna()
    if len(observed) > 0 and observed.between(0, 1).all():
        return nums * 100
    return nums


def _resolve_state_abbrs(state_fullnames: List[str], counties: pd.DataFrame) -> List[str]:
    """Map up to five full state names to codes, or pick the five highest-obesity states.

    Every supplied name is validated (unknown names raise ValueError) before
    the list is de-duplicated and truncated. If no usable name is supplied,
    states are ranked by their unweighted mean county ``obesity_pct``.
    """
    abbrs: List[str] = []
    for raw in state_fullnames or []:
        name = raw.strip()
        if not name:
            continue
        code = STATE_TO_ABBR.get(name)
        if code is None:
            raise ValueError(f"Unknown state name: '{name}'. Use full names like 'West Virginia'.")
        abbrs.append(code)
    if abbrs:
        return list(dict.fromkeys(abbrs))[:5]

    state_means = (
        counties.dropna(subset=["obesity_pct"])
        .groupby("State Abbreviation")["obesity_pct"]
        .mean()
        .sort_values(ascending=False)
    )
    return state_means.head(5).index.tolist()


def run_data3(
    outdir: str,
    chr_csv: str,
    election_csv: str,
    state_fullnames: List[str],
) -> Tuple[str, str, List[str]]:
    """Merge CHR county health data with 2020 county election results and plot correlations.

    Steps:
      1. Read both CSVs and check the required columns.
      2. Normalise both FIPS keys to zero-padded 5-digit strings and left-join
         the election ``per_gop`` onto the CHR rows. If ``per_gop`` is absent,
         it is computed as the two-party share votes_gop / (votes_gop + votes_dem).
         CHR obesity and some-college values are expressed in percent.
      3. Write the merged table to ``DATA3_merged_county_dataset.csv``.
      4. Choose up to five states: the given full names or, if none are usable,
         the five with the highest mean county obesity. Write them to
         ``DATA3_top5_states_used.csv``.
      5. For each state, save three county-level scatter plots (some college,
         median income, GOP share vs obesity). Summary rows whose FIPS county
         part is "000" are excluded.

    Returns:
        (merged CSV path, states CSV path, list of plot paths).

    Raises:
        FileNotFoundError: if either input CSV is missing.
        ValueError: on missing columns or an unknown state name.
    """
    outdir = ensure_outdir(outdir)

    chr_df = safe_read_csv(chr_csv)  # named to avoid shadowing the builtin chr()
    elec = safe_read_csv(election_csv)

    needed_chr = [
        "5-digit FIPS Code",
        "State Abbreviation",
        "Name",
        "Adult obesity raw value",
        "Some college raw value",
        "Median household income raw value",
    ]
    missing_chr = [c for c in needed_chr if c not in chr_df.columns]
    if missing_chr:
        raise ValueError(
            "CHR file does not have expected columns. "
            f"Missing: {missing_chr}. Available (first 30): {list(chr_df.columns)[:30]}"
        )

    chr_df = chr_df[needed_chr].copy()
    chr_df["fips"] = _normalize_fips(chr_df["5-digit FIPS Code"])
    chr_df["obesity_pct"] = _as_percent(chr_df["Adult obesity raw value"])
    chr_df["some_college_pct"] = _as_percent(chr_df["Some college raw value"])
    chr_df["median_income"] = pd.to_numeric(chr_df["Median household income raw value"], errors="coerce")

    if "county_fips" not in elec.columns:
        raise ValueError(f"Election file missing 'county_fips'. Columns (first 30): {list(elec.columns)[:30]}")
    if "per_gop" not in elec.columns:
        if "votes_gop" in elec.columns and "votes_dem" in elec.columns:
            gop = pd.to_numeric(elec["votes_gop"], errors="coerce")
            dem = pd.to_numeric(elec["votes_dem"], errors="coerce")
            elec["per_gop"] = gop / (gop + dem)
        else:
            raise ValueError("Election file missing 'per_gop' and cannot compute it (no votes_gop/votes_dem).")

    elec["fips"] = _normalize_fips(elec["county_fips"])
    elec["per_gop"] = pd.to_numeric(elec["per_gop"], errors="coerce")
    # Drop rows with missing keys, because pandas matches missing join keys to
    # each other. Drop duplicate counties, which would duplicate CHR rows in the join.
    elec_keys = elec.dropna(subset=["fips"]).drop_duplicates(subset=["fips"])[["fips", "per_gop"]]

    df = chr_df.merge(elec_keys, on="fips", how="left")

    merged_path = os.path.join(outdir, "DATA3_merged_county_dataset.csv")
    df.to_csv(merged_path, index=False)

    # No county FIPS ends in "000"; such rows are state/national summaries.
    # Rows without a usable FIPS are dropped too.
    counties = df[~df["fips"].str.endswith("000", na=True)]

    st_abbrs = _resolve_state_abbrs(state_fullnames, counties)

    top5_table = pd.DataFrame({"State Abbreviation": st_abbrs, "State": [ABBR_TO_STATE.get(a, a) for a in st_abbrs]})
    top5_path = os.path.join(outdir, "DATA3_top5_states_used.csv")
    top5_table.to_csv(top5_path, index=False)

    # (file-name tag, x column, x-axis label, title topic)
    plot_specs = (
        ("education_some_college", "some_college_pct", "Some college (%)",
         "Education (Some college %)"),
        ("economic_median_income", "median_income", "Median household income (USD)",
         "Economic (Median household income)"),
        ("political_gop_share_2020", "per_gop", "GOP vote share (2020, county-level)",
         "Political proxy (GOP share 2020)"),
    )

    plot_paths: List[str] = []
    for ab in st_abbrs:
        d_st = counties[counties["State Abbreviation"] == ab]
        st_name = ABBR_TO_STATE.get(ab, ab)
        for tag, xcol, xlabel, topic in plot_specs:
            path = os.path.join(outdir, f"DATA3_{ab}_correlation_{tag}_vs_obesity.png")
            scatter_with_fit(
                d_st,
                x=xcol,
                y="obesity_pct",
                xlabel=xlabel,
                ylabel="Adult obesity (%)",
                title=f"{st_name}: {topic} vs Adult obesity",
                outpath=path,
            )
            plot_paths.append(path)

    return merged_path, top5_path, plot_paths

# CLI
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the pipeline's command-line options.

    Args:
        argv: Arguments to parse, without the program name. ``None`` means
            ``sys.argv[1:]``, except inside a Jupyter/Colab kernel (detected by
            ``ipykernel`` being imported). There ``sys.argv`` belongs to the
            kernel launcher (e.g. ``-f <connection-file>``) and argparse would
            abort with SystemExit, so the defaults are used instead.

    Returns:
        The parsed namespace.
    """
    import sys

    if argv is None and "ipykernel" in sys.modules:
        argv = []

    p = argparse.ArgumentParser()
    p.add_argument("--outdir", default="outputs", help="Output folder (default: ./outputs)")
    p.add_argument("--only", choices=["all", "data1", "data2", "data3"], default="all",
                   help="Run only a subset (default: all)")

    p.add_argument("--data2_year", type=int, default=2024)
    p.add_argument("--data2_html", default="DATA2_US_obesity_state_map_2024.html",
                   help="Existing Plotly HTML to parse for DATA2 (recommended).")
    p.add_argument("--data2_csv", default="",
                   help="Optional CSV for DATA2 instead of HTML. Must have state/abbr + obesity_pct.")

    p.add_argument("--chr_csv", default="analytic_data2022.csv",
                   help="CHR analytic data CSV (county-level), e.g., analytic_data2022.csv")
    p.add_argument("--election_csv", default="2020_US_County_Level_Presidential_Results.csv",
                   help="2020 county election CSV with county_fips and per_gop")
    p.add_argument("--states", default="",
                   help="Comma-separated 5 states to use for DATA3 (full names). "
                        "Example: 'West Virginia,Mississippi,Louisiana,Arkansas,Alabama'")

    return p.parse_args(argv)


def main() -> None:
    """Entry point: run the requested DATA steps and report results on stdout.

    ``--only`` selects the steps. Each step runs in its own try/except, so a
    failure is printed as ``[DATAn] Failed: ...`` and later steps still run.
    The DATA2 table is also loaded for ``--only data3``, because DATA3 falls
    back to DATA2's five highest-obesity states when ``--states`` is empty.
    If DATA2 fails (loading or plotting), DATA3 runs without that fallback.
    """
    args = parse_args()
    outdir = ensure_outdir(args.outdir)
    step = args.only

    # DATA1: global country bar chart.
    if step in {"all", "data1"}:
        try:
            data1_png = make_data1_global_bar(
                outdir=outdir,
                year=2022,
                top_n=15,
                force_include="United States of America",
            )
            print("[DATA1] Saved:", data1_png)
        except Exception as err:
            print("[DATA1] Failed:", repr(err))

    # DATA2: load the state table; plot it unless only DATA3 was requested.
    d2_obj: Optional[Data2StateObesity] = None
    if step in {"all", "data2", "data3"}:
        try:
            d2_obj = (
                load_data2_from_csv(args.data2_csv, year=args.data2_year)
                if args.data2_csv
                else load_data2_from_plotly_html(args.data2_html, year=args.data2_year)
            )
            if step in {"all", "data2"}:
                print("[DATA2] Saved:", plot_data2_top10_bar(d2_obj, outdir=outdir))
                map_html = plot_data2_choropleth_html(d2_obj, outdir=outdir)
                if not map_html:
                    print("[DATA2] Plotly not installed; skipped map. Install: python3 -m pip install plotly")
                else:
                    print("[DATA2] Saved:", map_html)
        except Exception as err:
            print("[DATA2] Failed:", repr(err))
            d2_obj = None

    # DATA3: county-level correlations for up to five states.
    if step in {"all", "data3"}:
        try:
            chosen: List[str]
            if args.states.strip():
                chosen = [name.strip() for name in args.states.split(",") if name.strip()][:5]
            elif d2_obj is not None:
                chosen = get_top5_states_from_data2(d2_obj)
            else:
                chosen = []

            merged_path, top5_path, plot_paths = run_data3(
                outdir=outdir,
                chr_csv=args.chr_csv,
                election_csv=args.election_csv,
                state_fullnames=chosen,
            )
            print("[DATA3] Saved merged county dataset:", merged_path)
            print("[DATA3] Saved states used:", top5_path)
            print("[DATA3] Saved plots:", len(plot_paths))
        except Exception as err:
            print("[DATA3] Failed:", repr(err))


if __name__ == "__main__":
    main()
