04 — Ingest Retail Sales & Carbon Emissions
=============================================
Lumina Forecasting Hub

Part A: Retail Sales (monthly)
  - electricity/retail-sales → price, revenue, sales, customers by state & sector
  - Feeds Dashboard Page 3: Carbon & Cost Intelligence

Part B: Carbon Emissions (annual)
  - co2-emissions/co2-emissions-aggregates (with seds as fallback) → CO2 emissions by state, energy source, and sector
  - Feeds Dashboard Page 3: Carbon intensity map, scatter plots

Usage in Colab:
  1. Run 01_setup_bigquery_schema.py first
  2. Set your EIA_API_KEY and GCP_PROJECT_ID below
  3. Run all cells

In [None]:
from google.colab import auth
auth.authenticate_user()

import requests
import pandas as pd
import numpy as np
import time
from google.cloud import bigquery

# ── Config ──────────────────────────────────────────────────────────
# Credentials / project (placeholders must be replaced before running).
EIA_API_KEY = "YOUR_EIA_API_KEY"  # <-- UPDATE
GCP_PROJECT_ID = "YOUR_GCP_PROJECT_ID"  # <-- UPDATE

# Fixed endpoints and dataset name used by every cell below.
BQ_DATASET = "lumina"
EIA_BASE_URL = "https://api.eia.gov/v2"

# Two-letter codes kept by the state filters: the 50 states followed by DC.
US_STATES = (
    "AL AK AZ AR CA CO CT DE FL GA "
    "HI ID IL IN IA KS KY LA ME MD "
    "MA MI MN MS MO MT NE NV NH NJ "
    "NM NY NC ND OH OK OR PA RI SC "
    "SD TN TX UT VT VA WA WV WI WY DC"
).split()

# Sector ids requested from the retail-sales route, one API query each.
SECTORS = [
    "RES",  # Residential
    "COM",  # Commercial
    "IND",  # Industrial
    "TRA",  # Transportation
    "OTH",  # Other
]

# Earliest periods requested when running a full backfill.
BACKFILL_START_MONTHLY = "2019-01"
BACKFILL_START_ANNUAL = "2010"

client = bigquery.Client(project=GCP_PROJECT_ID)
print(f"Connected to BigQuery: {GCP_PROJECT_ID}")

# ── Retry helper for flaky EIA API ──────────────────────────────
def api_get_with_retry(url, params, max_retries=5):
    """GET ``url`` and retry transient failures with exponential backoff.

    Connection errors, timeouts, HTTP 429 and HTTP 5xx responses are retried,
    waiting 5s, 10s, 20s, ... between attempts. Any other 4xx response (bad
    request, rejected API key, unknown route) is raised immediately because
    repeating an identical request cannot succeed.

    Args:
        url: Full request URL.
        params: Query-string parameters forwarded to ``requests.get``.
        max_retries: Total number of attempts. Values below 1 are treated as 1
            so the function never falls through and returns ``None``.

    Returns:
        The ``requests.Response`` of the first successful attempt.

    Raises:
        requests.exceptions.HTTPError: On a non-retryable status, or when the
            last attempt still returns a retryable status.
        requests.exceptions.ConnectionError, requests.exceptions.Timeout:
            When the last attempt still fails at the transport level.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            resp = requests.get(url, params=params, timeout=60)
            resp.raise_for_status()
            return resp
        except requests.exceptions.HTTPError as e:
            status = e.response.status_code if e.response is not None else None
            transient = status is None or status == 429 or status >= 500
            if not transient or attempt == attempts - 1:
                raise
            reason = f"HTTP {status}"
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            if attempt == attempts - 1:
                raise
            reason = type(e).__name__

        # Report only the status / exception class: the exception text embeds
        # the request URL, whose query string carries the API key.
        wait = 2 ** attempt * 5
        print(f"  API error ({reason}), retrying in {wait}s (attempt {attempt+1}/{attempts})")
        time.sleep(wait)



In [None]:
def fetch_retail_sales(sector_code, start, end=None):
    """
    Fetch monthly retail electricity sales for one sector, paging through
    the EIA ``electricity/retail-sales`` route until all rows are read.

    Args:
        sector_code: Value sent as the ``sectorid`` facet (e.g. "RES").
        start: First period to request ("YYYY-MM"); omitted from the request
            when falsy.
        end: Optional last period to request; omitted when falsy.

    Returns:
        A DataFrame with one row per record returned by the API (requested
        data fields: revenue, sales, price, customers), or an empty DataFrame
        when nothing was returned.
        NOTE(review): units are whatever the API reports; later cells label
        them as million $, MWh and cents/kWh — confirm against EIA metadata.
    """
    route = f"{EIA_BASE_URL}/electricity/retail-sales/data/"
    page_size = 5000

    # Everything except the offset is identical for every page.
    params = {
        "api_key": EIA_API_KEY,
        "frequency": "monthly",
        "data[0]": "revenue",
        "data[1]": "sales",
        "data[2]": "price",
        "data[3]": "customers",
        "facets[sectorid][]": sector_code,
        "sort[0][column]": "period",
        "sort[0][direction]": "asc",
        "length": page_size,
    }
    if start:
        params["start"] = start
    if end:
        params["end"] = end

    all_records = []
    offset = 0

    while True:
        params["offset"] = offset
        resp = api_get_with_retry(route, params=params)
        payload = resp.json().get("response", {})

        data = payload.get("data", [])
        # `total` may be absent or null; treat both as 0 instead of crashing.
        total = int(payload.get("total") or 0)

        if not data:
            break

        all_records.extend(data)
        # Advance by the rows actually received so a short page (server-side
        # cap below `page_size`) cannot cause records to be skipped.
        offset += len(data)

        print(f"  [Retail/{sector_code}] {len(all_records)}/{total}", end="\r")

        if offset >= total:
            break

        time.sleep(0.25)  # brief pause between pages

    print(f"  [Retail/{sector_code}] {len(all_records)} records fetched")
    return pd.DataFrame(all_records)

In [None]:
def get_max_period_retail():
    """Return the latest month present in ``fact_retail_sales`` as "YYYY-MM".

    Falls back to ``BACKFILL_START_MONTHLY`` when the table is empty or the
    query fails (for example because the table does not exist yet). A failed
    query is reported rather than swallowed, because the caller appends data
    from the returned month onward and a silent fallback would hide why the
    full history is being fetched again.
    """
    query = f"""
    SELECT FORMAT_DATE('%Y-%m', MAX(period_month)) AS max_period
    FROM `{GCP_PROJECT_ID}.{BQ_DATASET}.fact_retail_sales`
    """
    try:
        result = client.query(query).to_dataframe()
        val = result["max_period"].iloc[0]
    except Exception as exc:
        # Deliberately best-effort: keep the fallback, but make it visible.
        print(f"  Could not read latest period from fact_retail_sales ({exc}); "
              f"falling back to {BACKFILL_START_MONTHLY}")
        return BACKFILL_START_MONTHLY

    # MAX() over an empty table yields NULL, which arrives here as a missing value.
    return BACKFILL_START_MONTHLY if pd.isna(val) else val

In [None]:
MODE = "backfill"  # Change to "incremental" after first run

# Incremental runs resume from the latest month already loaded (inclusive);
# backfill always starts from the configured beginning of history.
if MODE == "incremental":
    retail_start = get_max_period_retail()
else:
    retail_start = BACKFILL_START_MONTHLY

print(f"=== RETAIL SALES INGESTION (from {retail_start}) ===\n")

# One API query per sector; sectors that return nothing are skipped.
all_retail = []
for sector in SECTORS:
    df = fetch_retail_sales(sector, start=retail_start)
    if not df.empty:
        all_retail.append(df)

if all_retail:
    raw_retail = pd.concat(all_retail, ignore_index=True)
    print(f"\nRaw retail records: {len(raw_retail):,}")

    # ── Transform ────────────────────────────────────────────────────
    df = raw_retail.copy()

    # Filter to US states only
    if "stateid" in df.columns:
        df = df[df["stateid"].isin(US_STATES)].copy()

    # Parse period: "YYYY-MM" becomes the first day of that month;
    # unparseable values become NaT and are dropped below.
    df["period_month"] = pd.to_datetime(df["period"] + "-01", format="%Y-%m-%d", errors="coerce")

    # Rename API field names to the warehouse column names
    col_map = {
        "stateid": "state_code",
        "sectorid": "sector_code",
        "revenue": "revenue_musd",
        "sales": "sales_mwh",
        "price": "price_cents_kwh",
        "customers": "customers",
    }
    df = df.rename(columns={k: v for k, v in col_map.items() if k in df.columns})

    # Type casting: the measures are coerced to numbers, bad values become NaN
    for num_col in ["revenue_musd", "sales_mwh", "price_cents_kwh", "customers"]:
        if num_col in df.columns:
            df[num_col] = pd.to_numeric(df[num_col], errors="coerce")

    if "customers" in df.columns:
        df["customers"] = df["customers"].astype("Int64")  # Nullable integer

    # Final columns: create any column the API did not return so the load
    # always sees the same set of columns.
    final_cols = ["period_month", "state_code", "sector_code", "revenue_musd", "sales_mwh", "price_cents_kwh", "customers"]
    for col in final_cols:
        if col not in df.columns:
            df[col] = None

    result_retail = df[final_cols].dropna(subset=["period_month"]).copy()
    result_retail = result_retail.sort_values(["period_month", "state_code", "sector_code"]).reset_index(drop=True)

    print(f"Cleaned retail records: {len(result_retail):,}")
    print(f"Date range: {result_retail['period_month'].min()} → {result_retail['period_month'].max()}")

    # ── Load to BigQuery ─────────────────────────────────────────────
    table_ref = f"{GCP_PROJECT_ID}.{BQ_DATASET}.fact_retail_sales"
    write_mode = "WRITE_TRUNCATE" if MODE == "backfill" else "WRITE_APPEND"
    job_config = bigquery.LoadJobConfig(write_disposition=write_mode)

    if MODE == "incremental" and not result_retail.empty:
        # The fetch window starts AT the latest loaded month, so those rows
        # are already in the table. Remove the window being re-loaded first;
        # otherwise every incremental run appends a duplicate copy of it.
        from google.cloud.exceptions import NotFound

        window_start = result_retail["period_month"].min().date()
        delete_config = bigquery.QueryJobConfig(
            query_parameters=[
                bigquery.ScalarQueryParameter("window_start", "DATE", window_start),
            ]
        )
        try:
            client.query(
                f"DELETE FROM `{table_ref}` WHERE period_month >= @window_start",
                job_config=delete_config,
            ).result()
        except NotFound:
            # Table does not exist yet: nothing to de-duplicate, the load creates it.
            pass

    job = client.load_table_from_dataframe(result_retail, table_ref, job_config=job_config)
    job.result()
    print(f"Loaded {len(result_retail):,} rows to fact_retail_sales")
else:
    print("No retail sales data fetched.")

In [None]:
def fetch_seds_co2(start="2010"):
    """
    Fetch annual values for a fixed list of SEDS series from the EIA
    ``seds`` route, paging until all rows are read.

    Args:
        start: First year to request, as a string (e.g. "2010").

    Returns:
        A DataFrame with one row per record returned by the API, or an empty
        DataFrame when nothing was returned.

    NOTE(review): the series codes below are kept exactly as they were. Their
    inline labels call them CO2 series, while the naming note says the
    trailing "B" means "billion Btu basis" and that CO2 codes end in "CD".
    These statements conflict, so confirm against the SEDS series catalogue
    that the codes really are emissions series before relying on the values
    as million metric tons.
    """
    route = f"{EIA_BASE_URL}/seds/data/"

    # Series (MSN) codes requested from SEDS.
    # Naming note carried over from the original author:
    #   pattern XX TC B where XX = source, TC = total consumption,
    #   B = billion Btu basis; CO2 codes end in 'CD' for carbon dioxide.
    co2_msn_patterns = [
        "TETCB",  # labelled: total energy-related CO2 emissions
        "CLTCB",  # labelled: coal CO2
        "NNTCB",  # labelled: natural gas CO2
        "PATCB",  # labelled: petroleum CO2
    ]

    page_size = 5000
    params = {
        "api_key": EIA_API_KEY,
        "frequency": "annual",
        "data[0]": "value",
        "sort[0][column]": "period",
        "sort[0][direction]": "asc",
        "length": page_size,
        "start": start,
        # A list value is encoded by requests as one repeated query key per
        # item, so ALL series are requested. Assigning the key in a loop
        # overwrote it each time and only the last code was ever sent.
        "facets[seriesId][]": list(co2_msn_patterns),
    }

    all_records = []
    offset = 0

    while True:
        params["offset"] = offset
        resp = api_get_with_retry(route, params=params)
        payload = resp.json().get("response", {})

        data = payload.get("data", [])
        # `total` may be absent or null; treat both as 0 instead of crashing.
        total = int(payload.get("total") or 0)

        if not data:
            break

        all_records.extend(data)
        # Advance by the rows actually received so a short page cannot cause
        # records to be skipped.
        offset += len(data)

        print(f"  [SEDS/CO2] {len(all_records)}/{total}", end="\r")

        if offset >= total:
            break

        time.sleep(0.25)  # brief pause between pages

    print(f"  [SEDS/CO2] {len(all_records)} records fetched")
    return pd.DataFrame(all_records)


def fetch_seds_co2_by_source(start="2010"):
    """
    Fetch annual CO2 emissions from the dedicated
    ``/co2-emissions/co2-emissions-aggregates/data/`` route (despite the
    function name, this does not use the SEDS route).

    The request is best-effort: if the endpoint cannot be reached or answers
    with anything other than HTTP 200, an empty DataFrame is returned so the
    caller can fall back to another source. Rows collected from earlier pages
    are discarded in that case.

    Args:
        start: First year to request, as a string (e.g. "2010").

    Returns:
        A DataFrame with one row per record returned by the API, or an empty
        DataFrame when the endpoint failed or returned no rows.
    """
    route = f"{EIA_BASE_URL}/co2-emissions/co2-emissions-aggregates/data/"

    page_size = 5000
    params = {
        "api_key": EIA_API_KEY,
        "frequency": "annual",
        "data[0]": "value",
        "sort[0][column]": "period",
        "sort[0][direction]": "asc",
        "length": page_size,
        "start": start,
    }

    all_records = []
    offset = 0

    while True:
        params["offset"] = offset
        try:
            # Same timeout as the retry helper; without one a stalled
            # connection would block the notebook indefinitely.
            resp = requests.get(route, params=params, timeout=60)
        except requests.exceptions.RequestException as exc:
            # Transport failures take the same graceful fallback as a bad
            # status. Only the class name is printed because the exception
            # text can include the URL and therefore the API key.
            print(f"  CO2 aggregates endpoint unreachable ({type(exc).__name__}), trying SEDS route...")
            return pd.DataFrame()

        # This endpoint may be deprecated — fall back gracefully
        if resp.status_code != 200:
            print(f"  CO2 aggregates endpoint returned {resp.status_code}, trying SEDS route...")
            return pd.DataFrame()

        payload = resp.json().get("response", {})
        data = payload.get("data", [])
        # `total` may be absent or null; treat both as 0 instead of crashing.
        total = int(payload.get("total") or 0)

        if not data:
            break

        all_records.extend(data)
        # Advance by the rows actually received so a short page cannot cause
        # records to be skipped.
        offset += len(data)

        print(f"  [CO2 Aggregates] {len(all_records)}/{total}", end="\r")

        if offset >= total:
            break

        time.sleep(0.25)  # brief pause between pages

    print(f"  [CO2 Aggregates] {len(all_records)} records fetched")
    return pd.DataFrame(all_records)

In [None]:
print(f"\n=== CARBON EMISSIONS INGESTION (from {BACKFILL_START_ANNUAL}) ===\n")

# Preferred source is the dedicated CO2 route; an empty frame from it
# triggers the SEDS fallback.
raw_co2 = fetch_seds_co2_by_source(start=BACKFILL_START_ANNUAL)
if raw_co2.empty:
    print("Falling back to SEDS route for CO2 data...")
    raw_co2 = fetch_seds_co2(start=BACKFILL_START_ANNUAL)

if raw_co2.empty:
    print("No CO2 data fetched. You may need to check available SEDS series codes.")
    print("Try exploring: https://api.eia.gov/v2/seds/?api_key=YOUR_KEY")
    print("Or use the SEDS bulk download: https://www.eia.gov/opendata/bulk/SEDS.zip")
else:
    print(f"\nRaw CO2 records: {len(raw_co2):,}")
    print(f"Columns: {list(raw_co2.columns)}")

    # Both sources are mapped onto one set of column names. Expected inputs:
    #   CO2 aggregates: period, stateId, sectorId, fuelId, value
    #   SEDS:           period, stateId, seriesId, value
    col_map = {
        "stateId": "state_code",
        "stateid": "state_code",
        "sectorId": "sector_code",
        "sectorid": "sector_code",
        "fuelId": "source_code",
        "fuelid": "source_code",
        "seriesId": "source_code",
        "value": "emissions_mmt",
    }
    df = raw_co2.rename(columns=col_map)

    # Keep only rows whose state is in the configured US_STATES list
    if "state_code" in df.columns:
        keep = df["state_code"].isin(US_STATES)
        df = df.loc[keep].copy()

    # Year as nullable integer; measure as float (unparseable values → NaN)
    year = pd.to_numeric(df["period"], errors="coerce")
    df["period_year"] = year.astype("Int64")
    df["emissions_mmt"] = pd.to_numeric(df.get("emissions_mmt", pd.Series(dtype=float)), errors="coerce")

    # A dimension the source did not provide is filled with the marker "ALL"
    for dim in ("state_code", "source_code", "sector_code"):
        if dim not in df.columns:
            df[dim] = "ALL"

    # NOTE(review): if the source returns aggregate rows (all-fuel or
    # all-sector totals) next to the breakdown rows, any downstream SUM over
    # emissions_mmt will double count — confirm against the API facets.
    final_cols = ["period_year", "state_code", "source_code", "sector_code", "emissions_mmt"]
    result_co2 = df[final_cols].dropna(subset=["period_year"]).copy()
    result_co2 = result_co2.sort_values(["period_year", "state_code"])
    result_co2 = result_co2.reset_index(drop=True)

    print(f"Cleaned CO2 records: {len(result_co2):,}")
    print(f"Year range: {result_co2['period_year'].min()} → {result_co2['period_year'].max()}")

    # ── Load to BigQuery ─────────────────────────────────────────────
    # The annual table is always replaced in full.
    table_ref = f"{GCP_PROJECT_ID}.{BQ_DATASET}.fact_carbon_emissions"
    job_config = bigquery.LoadJobConfig(write_disposition="WRITE_TRUNCATE")

    job = client.load_table_from_dataframe(result_co2, table_ref, job_config=job_config)
    job.result()
    print(f"Loaded {len(result_co2):,} rows to fact_carbon_emissions")

In [None]:
# Post-load sanity checks: one summary query per fact table. A failing query
# is reported and does not stop the notebook.
print("\n=== RETAIL SALES QUALITY CHECK ===")
# `row_count` replaces the alias `rows`: ROWS is a reserved keyword in
# BigQuery GoogleSQL, so the unquoted alias made the query fail to parse.
# total_sales_twh divides by 1e6 because 1 TWh = 1e6 MWh, matching the unit
# in the column name `sales_mwh`.
# NOTE(review): if the upstream sales figure is actually reported in million
# kWh rather than MWh, the divisor should be 1e3 — confirm the source unit.
retail_quality = f"""
SELECT
    sector_code,
    COUNT(*) AS row_count,
    COUNT(DISTINCT state_code) AS states,
    MIN(period_month) AS earliest,
    MAX(period_month) AS latest,
    ROUND(AVG(price_cents_kwh), 2) AS avg_price_cents,
    ROUND(SUM(sales_mwh) / 1e6, 2) AS total_sales_twh,
    ROUND(SUM(revenue_musd) / 1e3, 2) AS total_revenue_busd
FROM `{GCP_PROJECT_ID}.{BQ_DATASET}.fact_retail_sales`
GROUP BY sector_code
ORDER BY total_sales_twh DESC
"""
try:
    df_rq = client.query(retail_quality).to_dataframe()
    print(df_rq.to_string(index=False))
except Exception as e:
    print(f"Retail quality check error: {e}")

print("\n=== CARBON EMISSIONS QUALITY CHECK ===")
# Ten most recent years: distinct states, summed emissions and row count.
co2_quality = f"""
SELECT
    period_year,
    COUNT(DISTINCT state_code) AS states,
    ROUND(SUM(emissions_mmt), 1) AS total_mmt,
    COUNT(*) AS row_count
FROM `{GCP_PROJECT_ID}.{BQ_DATASET}.fact_carbon_emissions`
GROUP BY period_year
ORDER BY period_year DESC
LIMIT 10
"""
try:
    df_cq = client.query(co2_quality).to_dataframe()
    print(df_cq.to_string(index=False))
except Exception as e:
    print(f"CO2 quality check error: {e}")

In [None]:
import matplotlib.pyplot as plt

# Per-state inputs for the scatter plot:
#   latest_price – mean residential price over the trailing 12 months
#   latest_co2   – summed emissions for the most recent year in the table
#   latest_gen   – summed generation over the trailing 12 months (sector '99')
# NOTE(review): emissions come from the latest annual year while generation
# covers the trailing 12 months, so the two periods may not line up — confirm
# this is acceptable for the intensity ratio.
scatter_query = f"""
WITH latest_price AS (
    SELECT
        state_code,
        AVG(price_cents_kwh) AS avg_price
    FROM `{GCP_PROJECT_ID}.{BQ_DATASET}.fact_retail_sales`
    WHERE sector_code = 'RES'
        AND period_month >= DATE_SUB(CURRENT_DATE(), INTERVAL 12 MONTH)
    GROUP BY state_code
),
latest_co2 AS (
    SELECT
        state_code,
        SUM(emissions_mmt) AS total_co2
    FROM `{GCP_PROJECT_ID}.{BQ_DATASET}.fact_carbon_emissions`
    WHERE period_year = (
        SELECT MAX(period_year) FROM `{GCP_PROJECT_ID}.{BQ_DATASET}.fact_carbon_emissions`
    )
    GROUP BY state_code
),
latest_gen AS (
    SELECT
        state_code,
        SUM(generation_mwh) / 1e3 AS total_gwh
    FROM `{GCP_PROJECT_ID}.{BQ_DATASET}.fact_monthly_generation`
    WHERE period_month >= DATE_SUB(CURRENT_DATE(), INTERVAL 12 MONTH)
        AND sector_code = '99'
    GROUP BY state_code
)
SELECT
    p.state_code,
    g.state_name,
    g.population,
    p.avg_price,
    SAFE_DIVIDE(c.total_co2 * 1e6, gen.total_gwh) AS co2_intensity_tons_gwh
FROM latest_price p
JOIN `{GCP_PROJECT_ID}.{BQ_DATASET}.dim_geography` g ON p.state_code = g.state_code
LEFT JOIN latest_co2 c ON p.state_code = c.state_code
LEFT JOIN latest_gen gen ON p.state_code = gen.state_code
WHERE gen.total_gwh > 0
"""

try:
    df_scatter = client.query(scatter_query).to_dataframe()

    # The LEFT JOIN on emissions can leave co2_intensity NULL, and nullable
    # result dtypes are not reliably accepted by matplotlib. Coerce the three
    # plotted measures to plain floats and drop rows that cannot be drawn.
    plot_cols = ["co2_intensity_tons_gwh", "avg_price", "population"]
    for col in plot_cols:
        df_scatter[col] = pd.to_numeric(df_scatter[col], errors="coerce").astype("float64")
    df_scatter = df_scatter.dropna(subset=plot_cols)

    if not df_scatter.empty:
        fig, ax = plt.subplots(figsize=(12, 8))

        sizes = df_scatter["population"] / 1e5  # Scale for visibility
        scatter = ax.scatter(
            df_scatter["co2_intensity_tons_gwh"],
            df_scatter["avg_price"],
            s=sizes,
            alpha=0.6,
            c=df_scatter["co2_intensity_tons_gwh"],
            cmap="RdYlGn_r",
            edgecolors="white",
            linewidth=0.5,
        )

        # Label the ten most populous states
        for _, row in df_scatter.nlargest(10, "population").iterrows():
            ax.annotate(
                row["state_code"],
                (row["co2_intensity_tons_gwh"], row["avg_price"]),
                fontsize=8, ha="center", va="bottom",
            )

        ax.set_xlabel("CO2 Intensity (tons/GWh)", fontsize=12)
        ax.set_ylabel("Avg Residential Price (¢/kWh)", fontsize=12)
        ax.set_title("State-Level: Carbon Intensity vs. Electricity Price\n(bubble size = population)", fontsize=14)
        ax.grid(True, alpha=0.2)
        # Attach the colorbar to this axes explicitly rather than to whichever
        # axes happens to be current.
        plt.colorbar(scatter, ax=ax, label="CO2 Intensity", shrink=0.8)
        plt.tight_layout()
        plt.show()
    else:
        print("No data for scatter plot.")
except Exception as e:
    print(f"Scatter plot error: {e}")
    print("This visualization requires all 3 fact tables to be populated.")