## 🔗 Open This Notebook in Google Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DavidLangworthy/ds4s/blob/master/Day%203_%20Pollution%20and%20Public%20Health.ipynb)

# 🌫️ Day 3 – Visualizing Pollution and Public Health
### How Air Quality and Economic Development Intersect

Today you’ll blend two datasets, add diagnostics, and build an explorable scatter plot that connects air pollution with income.

### 🗂️ Data Card
| Field | Details |
| --- | --- |
| **Dataset** | World Bank – PM2.5 Exposure & GDP per Capita |
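| **Source & link** | World Bank DataBank — PM2.5 exposure: [EN.ATM.PM25.MC.M3](https://data.worldbank.org/indicator/EN.ATM.PM25.MC.M3); GDP per capita: [NY.GDP.PCAP.CD](https://data.worldbank.org/indicator/NY.GDP.PCAP.CD) |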
| **Temporal / spatial coverage** | Country-level, annual, 1990–2019 |
| **Key units** | PM2.5 exposure (µg/m³), GDP per capita (current USD) |
| **Method & caveats** | GDP data use current USD; PM2.5 exposure estimates are modelled from satellite retrievals. Missing data are common for small nations. |

### ⏱️ Learning Path for Today

Each loop takes about 10–15 minutes:

- [ ] Load both datasets and extract the target year.
- [ ] Run diagnostics (shape, nulls, ranges) before merging.
- [ ] Merge, clean, and engineer bins for interpretation.
- [ ] Build an interactive scatter with storytelling and accessibility.

> 👩‍🏫 **Teacher tip:** Use these checkpoints for quick formative assessment. Have students raise a colored card after each check cell to signal confidence or questions.

> ### 👩‍🏫 Teacher Sidebar
> **Suggested timing:** ~45 minutes including gallery walk.
>
> **Likely misconceptions:** Interpreting correlation as causation; forgetting to log-scale the income axis.
>
> **Fast finisher extension:** Segment by World Bank income groups or regions and compare slopes.

In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Mapping, Sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Plotly is optional at import time: when it is not installed, ``px`` is bound
# to None so the rest of this setup cell still runs.
try:
    import plotly.express as px  # noqa: F401 - imported for student use
except ModuleNotFoundError:  # pragma: no cover - Plotly installed in Colab
    px = None

# Show floats with two decimal places in pandas output.
pd.options.display.float_format = "{:.2f}".format
# Shared seaborn theme and matplotlib defaults for the static charts.
sns.set_theme(style="whitegrid", context="talk")
plt.rcParams.update(
    {
        "axes.titlesize": 18,
        "axes.titleweight": "bold",
        "axes.labelsize": 13,
        "axes.grid": True,
        "grid.alpha": 0.25,
        "figure.dpi": 120,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
)

# Keys every story mapping must fill in; story_fields_are_complete() checks
# each of them before a chart is annotated.
STORY_KEYS = (
    "title",
    "subtitle",
    "claim",
    "evidence",
    "visual",
    "takeaway",
    "source",
    "units",
    "annotation",
    "alt_text",
)


def load_csv(
    path: Path | str, *, description: str = "", **read_kwargs
) -> pd.DataFrame:
    """Read a CSV file into a DataFrame and print a one-line confirmation.

    Args:
        path: Location of the CSV file, as a ``pathlib.Path`` or a plain
            string path.
        description: Human-readable label used in the confirmation message.
            When empty, the file name is used instead.
        **read_kwargs: Extra keyword arguments forwarded to ``pd.read_csv``.

    Returns:
        The loaded DataFrame.
    """
    df = pd.read_csv(path, **read_kwargs)
    # Wrap in Path so a string path also yields a file-name label instead of
    # failing on the missing ``.name`` attribute.
    label = description or Path(path).name
    n_rows, n_cols = df.shape
    print(f"✅ Loaded {label} with shape {n_rows} rows × {n_cols} columns.")
    return df


def validate_columns(
    df: pd.DataFrame, required: Sequence[str], *, df_name: str = "DataFrame"
) -> None:
    """Confirm that *df* contains every column named in *required*.

    Prints a confirmation line on success.

    Raises:
        ValueError: If one or more required columns are absent; the message
            lists them in the order they were requested.
    """
    absent = [name for name in required if name not in df.columns]
    if not absent:
        print(f"✅ {df_name} includes required columns: {', '.join(required)}")
        return
    raise ValueError(f"{df_name} is missing columns: {absent}")


def expect_rows_between(
    df: pd.DataFrame,
    lower: int,
    upper: int,
    *,
    df_name: str = "DataFrame",
) -> None:
    """Check that the row count of *df* lies in the inclusive range [lower, upper].

    Prints a confirmation line on success.

    Raises:
        ValueError: If the row count is below *lower* or above *upper*.
    """
    n_rows = df.shape[0]
    if n_rows < lower or n_rows > upper:
        raise ValueError(
            f"{df_name} has {n_rows} rows; expected between {lower} and {upper}."
        )
    print(f"✅ {df_name} row count {n_rows} within [{lower}, {upper}].")


def quick_null_check(df: pd.DataFrame, *, df_name: str = "DataFrame") -> pd.Series:
    """Print and return the number of missing values in each column of *df*."""
    missing_per_column = df.isna().sum()
    header = f"{df_name} missing values per column:"
    print(f"{header}\n{missing_per_column}")
    return missing_per_column


def quick_preview(
    df: pd.DataFrame, *, n: int = 5, df_name: str = "DataFrame"
) -> pd.DataFrame:
    """Print a short banner and return the first *n* rows of *df*.

    The rows are returned rather than printed, so the caller decides how to
    render them.
    """
    banner = f"🔍 Previewing {df_name} (first {n} rows):"
    print(banner)
    first_rows = df.head(n)
    return first_rows


def numeric_sanity_check(
    series: pd.Series,
    *,
    minimum: float | None = None,
    maximum: float | None = None,
    name: str = "Series",
) -> None:
    """Verify that the values of *series* stay within optional bounds.

    Args:
        series: Values to check.
        minimum: Smallest acceptable value, or ``None`` to skip the check.
        maximum: Largest acceptable value, or ``None`` to skip the check.
        name: Label used in the printed and raised messages.

    Raises:
        ValueError: If the smallest value is below *minimum* or the largest
            value is above *maximum*.

    Note:
        pandas' ``min()``/``max()`` skip missing values by default, so NaNs
        are not flagged by this check.
    """
    if minimum is not None and series.min() < minimum:
        raise ValueError(
            f"{name} has values below the expected minimum of {minimum}."
        )
    if maximum is not None and series.max() > maximum:
        raise ValueError(
            f"{name} has values above the expected maximum of {maximum}."
        )

    # Build the bound description from the limits actually supplied, so a
    # maximum-only check no longer prints a dangling "and".
    bounds = []
    if minimum is not None:
        bounds.append(f"≥ {minimum}")
    if maximum is not None:
        bounds.append(f"≤ {maximum}")
    suffix = (" " + " and ".join(bounds)) if bounds else ""
    print(f"✅ {name} within expected range{suffix}.")


def story_fields_are_complete(story: Mapping[str, str]) -> None:
    """Ensure every key in ``STORY_KEYS`` has a non-blank value in *story*.

    A field counts as missing when the key is absent, its value is ``None``,
    or its string form is empty or whitespace only. Prints a confirmation
    line on success.

    Raises:
        ValueError: If any field is missing; the message lists the keys.
    """
    # Test for None explicitly: str(None) is the non-empty string "None" and
    # would otherwise be accepted as a completed field.
    missing = [
        key
        for key in STORY_KEYS
        if story.get(key) is None or not str(story[key]).strip()
    ]
    if missing:
        raise ValueError(
            "Please complete the storytelling scaffold before plotting: "
            + ", ".join(missing)
        )
    print(
        "✅ Story scaffold complete (title, subtitle, claim, evidence, visual,"
        " takeaway, source, units, annotation, alt text)."
    )


def print_story_scaffold(story: Mapping[str, str]) -> None:
    """Validate *story*, then print its claim, evidence, visual, takeaway and source.

    Raises:
        ValueError: Propagated from ``story_fields_are_complete`` when the
            scaffold has missing fields.
    """
    story_fields_are_complete(story)
    print("\n📖 Story Scaffold")
    # (printed label, story key) pairs, in display order.
    labelled_fields = (
        ("Claim", "claim"),
        ("Evidence", "evidence"),
        ("Visual focus", "visual"),
        ("Takeaway", "takeaway"),
    )
    for label, key in labelled_fields:
        print(f"{label}: {story[key]}")
    source, units = story["source"], story["units"]
    print(f"Source: {source} ({units})")


def apply_matplotlib_story(ax: plt.Axes, story: Mapping[str, str]) -> None:
    """Put the story's title and a claim/evidence/source footer on a chart.

    The title and subtitle become a left-aligned two-line axes title. The
    footer is drawn as figure-level text.

    Raises:
        ValueError: Propagated from ``story_fields_are_complete`` when the
            scaffold has missing fields.
    """
    story_fields_are_complete(story)

    heading = f"{story['title']}\n{story['subtitle']}"
    ax.set_title(heading, loc="left", pad=18)

    narrative = " | ".join(
        [
            f"Claim: {story['claim']}",
            f"Evidence: {story['evidence']}",
            f"Takeaway: {story['takeaway']}",
        ]
    )
    credit = f"Source: {story['source']} • Units: {story['units']}"
    # The negative y coordinate is passed through to figure.text unchanged.
    ax.figure.text(0.01, -0.08, f"{narrative}\n{credit}", ha="left", fontsize=10)


def annotate_callout(
    ax: plt.Axes,
    *,
    xy: tuple[float, float],
    xytext: tuple[float, float],
    text: str,
) -> None:
    """Add a boxed text callout with an arrow to *ax*.

    Args:
        ax: Axes that receives the annotation.
        xy: Point passed to ``ax.annotate`` as the annotated location.
        xytext: Point passed to ``ax.annotate`` as the text location.
        text: Callout text.
    """
    arrow_style = {"arrowstyle": "->", "color": "black", "lw": 1}
    box_style = {
        "boxstyle": "round,pad=0.3",
        "fc": "white",
        "ec": "black",
        "alpha": 0.8,
    }
    ax.annotate(text, xy=xy, xytext=xytext, arrowprops=arrow_style, bbox=box_style)


def record_alt_text(text: str) -> None:
    """Print a confirmation line that echoes the chart's alt text."""
    message = "📝 Alt text ready: {}".format(text)
    print(message)


def accessibility_checklist(
    *, palette: str, has_alt_text: bool, contrast_passed: bool = True
) -> None:
    """Print a three-item accessibility checklist for a chart.

    Args:
        palette: Description of the colour palette in use.
        has_alt_text: Whether alt text has been written for the chart.
        contrast_passed: Whether the colour contrast was judged acceptable.
    """
    if has_alt_text:
        alt_status = "yes"
    else:
        alt_status = "add alt text before sharing"
    if contrast_passed:
        contrast_status = "yes"
    else:
        contrast_status = "adjust colors"

    print(
        "♿ Accessibility checklist:",
        f" • Palette: {palette}",
        f" • Alt text provided: {alt_status}",
        f" • Contrast OK: {contrast_status}",
        sep="\n",
    )


def save_figure(fig: plt.Figure, filename: str) -> Path:
    """Save *fig* into a ``plots`` folder under the current working directory.

    The folder is created when it does not exist yet. The figure is written
    at 300 dpi with a tight bounding box.

    Returns:
        The path the figure was written to.
    """
    target_dir = Path.cwd().joinpath("plots")
    target_dir.mkdir(exist_ok=True, parents=True)

    output_path = target_dir.joinpath(filename)
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"💾 Saved figure to {output_path}")
    return output_path


def save_plotly_figure(fig, filename: str) -> Path:
    """Save a Plotly figure as interactive HTML plus a best-effort static image.

    Both files go into a ``plots`` folder under the current working
    directory, which is created when missing.

    Args:
        fig: Object exposing ``write_html`` and ``write_image`` (a Plotly
            figure).
        filename: Name of the static image, e.g. ``"chart.png"``. The HTML
            file uses the same stem with an ``.html`` suffix. When the name
            has no suffix or an HTML suffix, the static image falls back to
            ``.png`` so the two exports never share a path.

    Returns:
        Path of the HTML file.
    """
    plots_dir = Path.cwd() / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    requested = Path(filename)
    # Swap the real suffix instead of string-replacing ".png": that left other
    # extensions untouched and could rewrite ".png" in the middle of a name.
    html_path = plots_dir / requested.with_suffix(".html")
    fig.write_html(html_path)
    print(f"💾 Saved interactive figure to {html_path}")

    if requested.suffix.lower() in ("", ".html", ".htm"):
        static_path = plots_dir / requested.with_suffix(".png")
    else:
        static_path = plots_dir / requested
    try:
        fig.write_image(str(static_path))
        print(f"💾 Saved static image to {static_path}")
    # Deliberately broad: static export is optional and the failure type
    # depends on the image backend; the HTML export has already succeeded.
    except Exception as exc:  # pragma: no cover - depends on kaleido
        print(f"⚠️ Static export skipped: {exc}")
    return html_path

In [None]:
from pathlib import Path

# Input and output folders are resolved against the current working directory.
DATA_DIR = Path.cwd().joinpath("data")
PLOTS_DIR = Path.cwd().joinpath("plots")
# Only the output folder is created here; the data folder is left untouched.
PLOTS_DIR.mkdir(exist_ok=True, parents=True)

print("Data directory: {}".format(DATA_DIR))
print("Plots directory: {}".format(PLOTS_DIR))

## Loop 1 · Load & Slice 2019 Data
Work with a single recent year for a clear snapshot.

In [None]:
# Identifier columns plus the 2019 snapshot, required from both CSV files.
snapshot_columns = ["Country Name", "Country Code", "2019"]

pm_path = DATA_DIR.joinpath("pm25_exposure.csv")
gdp_path = DATA_DIR.joinpath("gdp_per_country.csv")

df_pm = load_csv(pm_path, description="World Bank PM2.5 exposure")
df_gdp = load_csv(gdp_path, description="World Bank GDP per capita")

# Stop with a clear error if either file lacks the columns used below.
validate_columns(df_pm, snapshot_columns, df_name="PM dataset")
validate_columns(df_gdp, snapshot_columns, df_name="GDP dataset")

# Keep only the snapshot columns and give the year column a descriptive name.
df_pm_2019 = df_pm.loc[:, snapshot_columns].rename(columns={"2019": "PM25"})
df_gdp_2019 = df_gdp.loc[:, snapshot_columns].rename(
    columns={"2019": "GDP_per_capita"}
)

In [None]:
# A notebook cell renders only the value of its final expression, so a bare
# first call would print its banner but never show its table. Pass each
# preview to display() (provided by the IPython/Colab kernel) to render both.
display(quick_preview(df_pm_2019, n=5, df_name="PM 2019"))  # noqa: F821
display(quick_preview(df_gdp_2019, n=5, df_name="GDP 2019"))  # noqa: F821

## Loop 2 · Merge & Diagnose
Catch mismatched records and nulls before plotting.

In [None]:
# Inner join: keep only countries present in both 2019 slices.
df_merged = pd.merge(
    df_pm_2019,
    df_gdp_2019,
    on=["Country Name", "Country Code"],
    how="inner",
)
# Coerce BOTH measures to numbers (unparseable entries become NaN) and drop
# incomplete rows in one pass. Previously only GDP was coerced, so a
# non-numeric PM2.5 entry would have reached the range check below as text.
df_merged = df_merged.assign(
    PM25=pd.to_numeric(df_merged["PM25"], errors="coerce"),
    GDP_per_capita=pd.to_numeric(df_merged["GDP_per_capita"], errors="coerce"),
).dropna()

# Diagnostics: each call raises ValueError when its expectation is violated.
expect_rows_between(df_merged, 140, 190, df_name="PM25+GDP merged")
numeric_sanity_check(df_merged["PM25"], minimum=1, maximum=120, name="PM2.5 (µg/m³)")
numeric_sanity_check(
    df_merged["GDP_per_capita"], minimum=300, maximum=120000, name="GDP per capita (USD)"
)

In [None]:
# Last expression of the cell, so the notebook renders the returned rows.
quick_preview(df_merged, n=5, df_name="merged snapshot")

## Loop 3 · Enrich for Storytelling
Classify countries into broad income groups for annotation.

In [None]:
# Cut points on GDP per capita that split countries into four income bands.
# NOTE(review): these look like course-specific thresholds rather than an
# official income classification — confirm before citing them as such.
bins = [0, 4000, 13000, 25000, 1000000]
labels = ["Low", "Lower-middle", "Upper-middle", "High"]
# include_lowest=True makes the first interval closed on the left, so a value
# of exactly 0 still lands in "Low".
df_merged["IncomeGroup"] = pd.cut(
    df_merged["GDP_per_capita"], bins=bins, labels=labels, include_lowest=True
)
# sort_index() orders the counts by category (Low → High), not by frequency.
income_counts = df_merged["IncomeGroup"].value_counts().sort_index()
print("Income group counts:\n", income_counts)

## Loop 4 · Build the Plotly Story
Apply the storytelling scaffold, annotation, and accessibility checks.

In [None]:
# Rows with the highest and lowest PM2.5 values; they feed the evidence and
# annotation text below.
worst_pm = df_merged.nlargest(1, "PM25").iloc[0]
cleanest = df_merged.nsmallest(1, "PM25").iloc[0]

# Storytelling scaffold: one entry for each key listed in STORY_KEYS.
story = {
    "title": "Air Pollution Drops as Economies Grow — With Stark Exceptions",
    "subtitle": "PM2.5 exposure vs. GDP per capita (2019)",
    "claim": "Higher-income countries generally breathe cleaner air, yet some lower-middle income nations face severe pollution.",
    "evidence": (
        f"{worst_pm['Country Name']} reports PM2.5 above {worst_pm['PM25']:.0f} µg/m³ while {cleanest['Country Name']} sits near {cleanest['PM25']:.0f} µg/m³."
    ),
    "visual": "Log-scale scatter with hover details and income-group color encoding.",
    "takeaway": "Economic growth helps but is not sufficient; targeted clean air policies are essential.",
    "source": "World Bank (2024 update)",
    "units": "PM2.5 (µg/m³), GDP per capita (USD)",
    "annotation": f"{worst_pm['Country Name']} faces the highest PM2.5 exposure",
    "alt_text": (
        "Scatter plot on log x-axis showing most wealthy countries clustered at low PM2.5 levels,"
        " while lower income nations span much higher pollution levels with wide variation."
    ),
}

# Validates the scaffold, then prints claim, evidence, visual focus, takeaway
# and source.
print_story_scaffold(story)

In [None]:
# The setup cell binds ``px`` to None when Plotly is not installed; fail here
# with an actionable message instead of an AttributeError on ``None``.
if px is None:
    raise ModuleNotFoundError(
        "Plotly is required for this chart. Install it with `pip install plotly`"
        " and re-run the setup cell."
    )

# Validate the scaffold before any of its fields are used in the figure.
story_fields_are_complete(story)

fig = px.scatter(
    df_merged,
    x="GDP_per_capita",
    y="PM25",
    color="IncomeGroup",
    hover_name="Country Name",
    size="PM25",
    size_max=18,
    log_x=True,
    template="plotly_white",
    labels={
        "GDP_per_capita": "GDP per capita (USD, log scale)",
        "PM25": "PM2.5 exposure (µg/m³)",
    },
)

# NOTE(review): the legend title says "World Bank income group", but the
# groups appear to come from custom GDP-per-capita bins — confirm the wording.
fig.update_layout(
    legend_title="World Bank income group",
    margin=dict(l=40, r=40, t=80, b=120),
)

# Callout on the most polluted country. Plotly positions annotations on a log
# axis in log10 units, so the raw GDP value must be converted; passing it
# unchanged would place the callout far outside the plotted range.
fig.add_annotation(
    x=float(np.log10(worst_pm["GDP_per_capita"])),
    y=worst_pm["PM25"],
    text=story["annotation"],
    showarrow=True,
    arrowcolor="#d1495b",
)

fig.update_layout(
    title=dict(
        text=f"<b>{story['title']}</b><br><sup>{story['subtitle']}</sup>",
        x=0,
        xanchor="left",
    )
)
# Footer in paper coordinates (independent of the data axes), placed below the
# plot area inside the enlarged bottom margin.
fig.add_annotation(
    xref="paper",
    yref="paper",
    x=0,
    y=-0.25,
    align="left",
    showarrow=False,
    text=(
        f"Claim: {story['claim']}<br>Evidence: {story['evidence']}<br>Takeaway: {story['takeaway']}<br>Source: {story['source']} • Units: {story['units']}"
    ),
)

record_alt_text(story["alt_text"])
accessibility_checklist(
    palette="Colorblind-safe Plotly qualitative", has_alt_text=True
)

fig.show()

In [None]:
# Writes the interactive HTML and attempts a static image under ./plots; the
# returned HTML path is the cell's final expression.
save_plotly_figure(fig, "day03_solution_plot.png")