# In [6]:
"""
Generación de figuras para el análisis de correlaciones NO₂–tráfico y NO₂–meteorología.

Uso:
    python generacion_figuras_correlaciones.py --input 7_4_no2_with_traffic_and_1meteo_and_1trafic_id.parquet --outdir figs

Requisitos:
    Python ≥3.9 y los paquetes:
        pandas, numpy, matplotlib, scipy, statsmodels, pyarrow

Cada figura se guarda en formato PNG con 300 dpi en la carpeta indicada.
Los mapas de calor piden explícitamente "viridis" (el colormap por defecto de
matplotlib) y el gráfico de barras b.2 usa la paleta "Set2"; el resto de
figuras emplea los colores por defecto.
"""

from __future__ import annotations

import argparse
import os
import warnings
from functools import partial

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from statsmodels.nonparametric.smoothers_lowess import lowess

# Install a process-wide filter that hides every RuntimeWarning.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Column correlated against every predictor below.
TARGET = "no2_value"

# Predictor columns used by the traffic figures (a.*).
TRAFFIC_VARS = ["carga", "intensidad", "ocupacion"]

# Predictor columns used by the meteorology figures (b.*).
METEO_VARS = ["wind_speed", "t2m", "ssrd", "ssr", "sp", "d2m", "u10", "v10", "tp"]

# Display label -> pandas resample alias, in plotting order.
# None means "no resampling": the data is used at its native resolution.
SCALES = dict(Hourly=None, Daily="D", Weekly="W", Monthly="M")


def _ensure_datetime_index(df: pd.DataFrame) -> pd.DataFrame:
    if not isinstance(df.index, pd.DatetimeIndex):
        # Se asume presencia de columna "datetime"
        df = df.set_index(pd.to_datetime(df["fecha"]))
    return df.sort_index()


def _aggregate(df: pd.DataFrame, freq: str | None) -> pd.DataFrame:
    if freq is None:
        return df.copy()
    return df.resample(freq).mean(numeric_only=True)


def _corr_and_p(x: pd.Series, y: pd.Series) -> tuple[float, float]:
    mask = (~x.isna()) & (~y.isna())
    if mask.sum() < 3:
        return np.nan, np.nan
    r, p = pearsonr(x[mask], y[mask])
    return r, p


def correlation_matrix_across_scales(
    df: pd.DataFrame,
    variables: list[str],
) -> tuple[np.ndarray, np.ndarray]:
    """Correlate TARGET with each variable at every time scale in SCALES.

    Returns two arrays ``(r, p)`` of shape ``(len(variables), len(SCALES))``:
    rows follow *variables*, columns follow the iteration order of SCALES.
    """
    shape = (len(variables), len(SCALES))
    r_mat = np.full(shape, np.nan)
    p_mat = np.zeros(shape)
    # Resample once per scale, then reuse that frame for every variable.
    for col, freq in enumerate(SCALES.values()):
        resampled = _aggregate(df, freq)
        for row, name in enumerate(variables):
            r_mat[row, col], p_mat[row, col] = _corr_and_p(resampled[TARGET], resampled[name])
    return r_mat, p_mat


def plot_heatmap(r_mat: np.ndarray, p_mat: np.ndarray, variables: list[str], title: str, outpath: str) -> None:
    """Save a heatmap of correlation coefficients (variables x time scales).

    Args:
        r_mat: Pearson r values, one row per entry of *variables* and one
            column per entry of SCALES.
        p_mat: p-values with the same shape as *r_mat*.
        variables: row labels.
        title: figure title.
        outpath: destination file; written at 300 dpi.

    Cells with ``p >= 0.05`` are drawn at 30 % opacity, the rest fully opaque.
    Every cell is annotated with its r value to two decimals.
    """
    fig, ax = plt.subplots(figsize=(6, 3 + 0.3 * len(variables)))
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
    # matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap("viridis")
    im = ax.imshow(r_mat, aspect="auto", cmap=cmap, vmin=-1, vmax=1)

    # Fade the non-significant cells (a NaN p-value compares False, so it stays opaque).
    not_significant = p_mat >= 0.05
    im.set_alpha(np.where(not_significant, 0.3, 1.0))

    # Axis labels
    ax.set_xticks(np.arange(r_mat.shape[1]))
    ax.set_yticks(np.arange(r_mat.shape[0]))
    ax.set_xticklabels(list(SCALES.keys()))
    ax.set_yticklabels(variables)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate each cell with its coefficient; imshow puts column j on x and row i on y.
    for (i, j), r in np.ndenumerate(r_mat):
        ax.text(j, i, f"{r:.2f}", ha="center", va="center", color="white")

    ax.set_title(title)
    fig.colorbar(im, ax=ax, label="r (Pearson)")
    fig.tight_layout()
    fig.savefig(outpath, dpi=300)
    plt.close(fig)


def correlation_vs_lag(
    df: pd.DataFrame,
    variables: list[str],
    max_lag: int = 12,
    n_boot: int = 500,
    outpath: str = "fig_a2_cor_vs_lag.png",
    seed: int | None = None,
):
    """Plot Pearson r between TARGET and each lagged variable, with a bootstrap band.

    Args:
        df: frame containing TARGET and every column in *variables*.
        variables: columns to lag and correlate against TARGET.
        max_lag: largest lag evaluated; lags run from 0 to *max_lag* inclusive.
        n_boot: number of bootstrap resamples per (variable, lag).
        outpath: destination file; written at 300 dpi.
        seed: optional seed for the bootstrap resampling, to make the bands
            reproducible. ``None`` draws from NumPy's global random state,
            exactly as before this parameter existed.

    The band spans the 2.5th to 97.5th percentile of the bootstrap r values and
    is only computed when more than 30 paired observations are available;
    otherwise it is NaN for that lag.

    NOTE(review): the lag is applied with ``Series.shift``, i.e. in rows. The
    x-axis is labelled in hours, which presumably relies on a gap-free hourly
    index -- confirm against the input data.
    """
    # Both np.random and RandomState expose randint, so one name serves both paths.
    sampler = np.random if seed is None else np.random.RandomState(seed)

    lags = np.arange(0, max_lag + 1)
    fig, ax = plt.subplots(figsize=(6, 4))

    for var in variables:
        r_vals = []
        ci_low = []
        ci_high = []
        for lag in lags:
            shifted = df[var].shift(lag)
            r, _p = _corr_and_p(df[TARGET], shifted)
            r_vals.append(r)

            # Percentile bootstrap over the complete (lagged x, target) pairs.
            mask = (~shifted.isna()) & (~df[TARGET].isna())
            xy = np.column_stack([shifted[mask].values, df[TARGET][mask].values])
            n = len(xy)
            if n > 30:
                boot_rs = []
                for _draw in range(n_boot):
                    idx = sampler.randint(0, n, n)
                    boot_r, _boot_p = pearsonr(xy[idx, 0], xy[idx, 1])
                    boot_rs.append(boot_r)
                ci_low.append(np.percentile(boot_rs, 2.5))
                ci_high.append(np.percentile(boot_rs, 97.5))
            else:
                ci_low.append(np.nan)
                ci_high.append(np.nan)

        (line,) = ax.plot(lags, r_vals, label=var)
        # Tie the band to its line explicitly instead of relying on the line
        # and patch colour cycles happening to advance in step.
        ax.fill_between(lags, ci_low, ci_high, alpha=0.2, facecolor=line.get_color())

    ax.set_xlabel("Desfase (h)")
    ax.set_ylabel("r (Pearson)")
    ax.set_title("Figura a.2  Correlación frente a desfase")
    ax.legend()
    fig.tight_layout()
    fig.savefig(outpath, dpi=300)
    plt.close(fig)


def scatter_hexbin_reg(df: pd.DataFrame, var: str, lag: int, outpath: str):
    """Save a hexbin density plot of TARGET against a lagged variable.

    Args:
        df: frame containing TARGET and *var*.
        var: column plotted on the x-axis.
        lag: number of rows *var* is shifted by before pairing with TARGET.
        outpath: destination file; written at 300 dpi.

    Rows where either value is missing are dropped. When more than two pairs
    remain, a first-degree least-squares fit is overlaid and its equation is
    printed in the top-left corner.
    """
    shifted = df[var].shift(lag)
    mask = (~shifted.isna()) & (~df[TARGET].isna())
    x = shifted[mask].values
    y = df[TARGET][mask].values

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hexbin(x, y, gridsize=60, mincnt=1)
    ax.set_xlabel(f"{var} (lag={lag} h)")
    ax.set_ylabel("NO₂ (µg/m³)")

    # Linear regression overlay
    if len(x) > 2:
        coef = np.polyfit(x, y, 1)
        x_line = np.linspace(x.min(), x.max(), 100)
        y_line = np.polyval(coef, x_line)
        ax.plot(x_line, y_line, lw=2)
        # "+.2f" prints the intercept's own sign, so a negative intercept reads
        # "x-1.23" instead of "x+-1.23".
        ax.text(0.05, 0.95, f"y={coef[0]:.2f}x{coef[1]:+.2f}", transform=ax.transAxes, ha="left", va="top")

    # Build the title from the arguments; for var="carga", lag=4 this yields
    # the same text that used to be hard-coded.
    ax.set_title(f"Figura a.3  Dispersión NO₂ vs {var} (lag {lag} h)")
    fig.tight_layout()
    fig.savefig(outpath, dpi=300)
    plt.close(fig)


def bar_coefficients_evolution(coeff_dict: dict[str, list[float]], outpath: str):
    """Save figure b.2: grouped bars of correlation coefficients per time scale.

    Args:
        coeff_dict: maps a variable name to its coefficients, one per entry of
            SCALES and in the same order.
        outpath: destination file; written at 300 dpi with a tight bounding box.

    Each time scale is a group on the x-axis; within a group every variable
    gets one bar in its own "Set2" colour.
    """
    variables = list(coeff_dict.keys())
    scales = list(SCALES.keys())
    x = np.arange(len(scales))
    width = 0.12
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
    # matplotlib 3.7 and removed in 3.9; the second argument resamples the
    # palette to one entry per variable.
    cmap = plt.get_cmap("Set2", len(variables))

    fig, ax = plt.subplots(figsize=(7, 4.5))
    for i, var in enumerate(variables):
        ax.bar(
            x + i * width,  # offset each variable inside its group
            coeff_dict[var],
            width,
            label=var,
            color=cmap(i),
            edgecolor="black",
            alpha=0.85,
        )

    # Centre the tick under each group of bars.
    ax.set_xticks(x + width * (len(variables) - 1) / 2)
    ax.set_xticklabels(scales, rotation=45, ha="right")
    ax.set_ylabel("Coeficiente de correlación (r)")
    ax.set_title("Evolución de coeficientes meteorológicos")

    ax.axhline(0, color="grey", linewidth=0.8)  # zero baseline
    ax.grid(axis="y", ls="--", alpha=0.4)
    ax.legend(frameon=False, ncol=max(1, len(variables) // 2))

    fig.tight_layout()
    fig.savefig(outpath, dpi=300, bbox_inches="tight")
    plt.close(fig)



def scatter_lowess(df: pd.DataFrame, var: str, frac: float, outpath: str):
    """Save a scatter plot of TARGET against *var* with a LOWESS trend line.

    Rows where either value is missing are dropped. The smoothed curve is drawn
    only when more than 30 points remain; *frac* is forwarded to ``lowess``.
    The figure is written to *outpath* at 300 dpi.
    """
    predictor = df[var]
    response = df[TARGET]
    keep = predictor.notna() & response.notna()
    x, y = predictor[keep], response[keep]

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(x, y, s=10, alpha=0.3)

    if len(x) > 30:
        # return_sorted=True yields rows of (x, fitted y) ordered by x.
        smooth = lowess(y, x, frac=frac, return_sorted=True)
        ax.plot(smooth[:, 0], smooth[:, 1], lw=2)

    ax.set_title(f"Dispersión NO₂ vs {var}")
    ax.set_xlabel(var)
    ax.set_ylabel("NO₂ (µg/m³)")
    fig.tight_layout()
    fig.savefig(outpath, dpi=300)
    plt.close(fig)


def wind_rose(df: pd.DataFrame, outpath: str, meteorological: bool = True):
    """Save a polar histogram of wind direction, one layer per TARGET quantile bin.

    Args:
        df: frame containing TARGET and the wind components ``u10`` / ``v10``.
        outpath: destination file; written at 300 dpi.
        meteorological: when True (default) directions are the bearing the wind
            blows FROM, the usual wind-rose convention. Pass False to plot the
            bearing the wind blows TOWARD, which is what this function drew
            before the parameter existed.

    Prints a message and returns without writing anything when either wind
    component is missing. Directions are binned in twelve 30-degree sectors;
    the layers are overlaid (not stacked), with opacity increasing with the
    quantile bin.

    NOTE(review): the direction maths assumes ``u10`` is the eastward and
    ``v10`` the northward component -- confirm against the data source.
    """
    if not {"u10", "v10"}.issubset(df.columns):
        print("No se dispone de componentes u10/v10 para generar la rosa de viento.")
        return

    # arctan2(u, v) is the clockwise-from-north bearing the air moves toward;
    # adding 180 degrees turns it into the bearing it comes from.
    bearing_deg = np.degrees(np.arctan2(df["u10"], df["v10"]))
    if meteorological:
        bearing_deg = bearing_deg + 180.0
    dir_deg = bearing_deg % 360

    bins = np.arange(0, 360 + 30, 30)
    angles = np.deg2rad(bins[:-1] + 15)  # sector centres; invariant across layers
    quintiles = pd.qcut(df[TARGET], 5, labels=False, duplicates="drop")
    # duplicates="drop" can yield fewer than five bins; iterate over the bins
    # that actually exist so the legend has no empty entries.
    present = sorted(int(q) for q in quintiles.dropna().unique())

    fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(6, 6))
    for q in present:
        mask = quintiles == q
        counts, _edges = np.histogram(dir_deg[mask], bins=bins)
        ax.bar(angles, counts, width=np.deg2rad(30), bottom=0.0, alpha=0.2 + 0.15 * q, label=f"Q{q + 1}")

    # Compass layout: north at the top, angles increasing clockwise.
    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)
    ax.set_title("Rosa de viento ponderada por quintiles de NO₂")
    ax.legend(loc="upper right", bbox_to_anchor=(1.1, 1.1))
    fig.tight_layout()
    fig.savefig(outpath, dpi=300)
    plt.close(fig)


def main(input_path: str, outdir: str):
    """Read the parquet dataset and write every figure of the analysis.

    Args:
        input_path: parquet file read with the pyarrow engine.
        outdir: directory receiving the PNG files; created if missing.
    """
    os.makedirs(outdir, exist_ok=True)
    df = pd.read_parquet(input_path, engine="pyarrow")
    df = _ensure_datetime_index(df)

    # Figure a.1 -- traffic vs NO2
    r_traf, p_traf = correlation_matrix_across_scales(df, TRAFFIC_VARS)
    plot_heatmap(r_traf, p_traf, TRAFFIC_VARS, "Correlaciones tráfico - NO₂", os.path.join(outdir, "fig_a1_heatmap_traf.png"))

    # Figure a.2 -- correlation as a function of lag
    correlation_vs_lag(df, TRAFFIC_VARS, outpath=os.path.join(outdir, "fig_a2_cor_vs_lag.png"))

    # Figure a.3 -- hexbin scatter
    scatter_hexbin_reg(df, "carga", lag=4, outpath=os.path.join(outdir, "fig_a3_hexbin_carga_lag4.png"))

    # Figure b.1 -- meteorology vs NO2
    r_met, p_met = correlation_matrix_across_scales(df, METEO_VARS)
    plot_heatmap(r_met, p_met, METEO_VARS, "Figura b.1  Correlaciones meteorología–NO₂", os.path.join(outdir, "fig_b1_heatmap_meteo.png"))

    # Figure b.2 -- coefficient bars for the five variables with the largest
    # |r| at the monthly scale.
    coeff_dict = {var: r_met[i, :] for i, var in enumerate(METEO_VARS)}
    # Look the column up by label rather than assuming it is the last one.
    monthly_col = list(SCALES).index("Monthly")

    def _monthly_strength(var: str) -> float:
        # NaN keys make sorted() order-dependent; rank undefined correlations
        # below every real |r|, which is never negative.
        strength = abs(coeff_dict[var][monthly_col])
        return -1.0 if np.isnan(strength) else strength

    top5 = sorted(METEO_VARS, key=_monthly_strength, reverse=True)[:5]
    bar_coefficients_evolution({v: coeff_dict[v] for v in top5}, os.path.join(outdir, "fig_b2_barras_evolucion.png"))

    # Figure b.3 -- LOWESS scatter plots
    scatter_lowess(df, "wind_speed", 0.2, os.path.join(outdir, "fig_b3_wind_speed.png"))
    scatter_lowess(df, "t2m", 0.2, os.path.join(outdir, "fig_b3_t2m.png"))

    # Figure b.4 -- wind rose by NO2 quantile
    wind_rose(df, os.path.join(outdir, "fig_b4_wind_rose.png"))

if __name__ == "__main__":
    # Wire up the --input / --outdir options shown in the module docstring.
    # The defaults are the arguments of the former hard-coded call, so running
    # without options behaves as before; importing the module no longer
    # triggers the whole pipeline.
    _parser = argparse.ArgumentParser(allow_abbrev=False)
    _parser.add_argument(
        "--input",
        default="../data/super_processed/7_4_no2_with_traffic_and_1meteo_and_1trafic_id.parquet",
        help="Fichero parquet de entrada.",
    )
    _parser.add_argument(
        "--outdir",
        default=".",
        help="Carpeta donde se guardan las figuras.",
    )
    # parse_known_args ignores unrelated argv entries, such as the ones a
    # notebook kernel passes to the interpreter.
    _args, _unknown = _parser.parse_known_args()
    main(_args.input, _args.outdir)


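# --- Notebook cell output (stderr), kept for reference; not executable code. ---
# These appear to be the source lines echoed by warnings raised during the run:
#   return df.resample(freq).mean(numeric_only=True)
#   cmap = plt.cm.get_cmap("viridis")
#   return df.resample(freq).mean(numeric_only=True)
#   cmap = plt.cm.get_cmap("viridis")
#   cmap      = plt.cm.get_cmap("Set2", len(variables))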
