# In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from climakitae.core.data_interface import get_data, get_data_options

# Emission scenarios to plot, in legend order.
desired_ssps = [f"SSP {tag}" for tag in ("1-2.6", "2-4.5", "3-7.0", "5-8.5")]

# (start_year, end_year) windows for the historical baseline and the
# end-of-century projection period.
hist_years = (1970, 2014)
fut_years = (2070, 2100)

# Catalog selectors: only 3 km, monthly data is requested.
res_3km = "3 km"
timescale = "monthly"

# Downscaling methods to try, in order of preference: Statistical only;
# Dynamical is never requested.
methods_pref = ["Statistical"]


def _get_data_safe(**kw):
    """
    Call ``get_data`` with guards that keep every request 3 km and Statistical-only.

    Enforces:
      - ``resolution`` must equal ``res_3km``.
      - ``downscaling_method`` must be supplied, must mention "statistical" and
        must not mention "dynamical" (case-insensitive). An omitted method is
        rejected because ``get_data`` would otherwise pick its own default,
        which is not guaranteed to be Statistical.

    Returns:
        Whatever ``get_data`` returns, or None if ``get_data`` raises TypeError
        or ValueError. There is deliberately no retry with other methods.

    Raises:
        ValueError: if the resolution or downscaling method violates the rules
            above. (Explicit raises rather than ``assert`` so the checks still
            run under ``python -O``.)
    """
    res = str(kw.get("resolution", "")).strip()
    if res != res_3km:
        raise ValueError(f"Refusing non-3km resolution: {res}")

    dm = kw.get("downscaling_method", None)
    if dm is None:
        raise ValueError("downscaling_method must be given explicitly (Statistical)")
    dm_str = str(dm).lower()
    if "statistical" not in dm_str:
        raise ValueError(f"downscaling_method must be Statistical, got {dm}")
    if "dynamical" in dm_str:
        raise ValueError(f"Mixed Statistical+Dynamical not allowed: {dm}")

    try:
        return get_data(**kw)
    except (TypeError, ValueError) as err:
        # Treat a rejected/unavailable request as "no data", but report why
        # instead of swallowing it. We do NOT retry without downscaling_method,
        # since that could implicitly select Dynamical data.
        print(f"  get_data failed ({type(err).__name__}: {err}) -> treating as missing")
        return None


def _convert_units(da, var_key):
    if da is None:
        return None
    return convertDataUnits(da, var_key) if "convertDataUnits" in globals() else da


def _reduce_to_3d_for_rio(da):
    if da is None:
        return None
    for d in ["scenario", "simulation"]:
        if d in da.dims:
            da = da.isel({d: 0}) if da.sizes[d] > 1 else da.squeeze(d, drop=True)

    spatial = []
    try:
        if da.rio.y_dim in da.dims:
            spatial.append(da.rio.y_dim)
        if da.rio.x_dim in da.dims:
            spatial.append(da.rio.x_dim)
    except Exception:
        pass

    for cand in ["y", "x"]:
        if cand in da.dims and cand not in spatial:
            spatial.append(cand)

    extra = [d for d in da.dims if d not in spatial and d != "time"]
    for d in extra:
        da = da.isel({d: 0}) if da.sizes[d] > 1 else da.squeeze(d, drop=True)
    return da


def _clip_if_possible(da, boundary):
    if da is None or boundary is None:
        return da
    da3 = _reduce_to_3d_for_rio(da)
    if "reprojectAndClipDataArray" in globals():
        out = reprojectAndClipDataArray(da3, boundary)
        return out if out is not None else da3
    if "maskDataArrayToBoundary" in globals():
        out = maskDataArrayToBoundary(da3, boundary)
        return out if out is not None else da3
    return da3


def _spatial_avg(da):
    if da is None:
        return None
    if "diagnoseSpatialWeights" in globals() and "calculateSpatialAverage" in globals():
        try:
            w = diagnoseSpatialWeights(da)
            return calculateSpatialAverage(da, w)
        except Exception:
            pass

    dims = [d for d in ["y", "x"] if d in da.dims]
    out = da.mean(dims, skipna=True)
    out.load()
    return out


def available_monthly_3km(var_name):
    """
    Map each Statistical-only downscaling method to its available scenarios.

    Queries ``get_data_options`` for ``var_name`` and keeps only rows that:
      - have resolution ``res_3km``,
      - have timescale ``timescale`` (when a "timescale" column exists), and
      - have a downscaling method that mentions "statistical" but not
        "dynamical" (case-insensitive).

    Returns:
        dict mapping downscaling_method -> sorted list of scenario labels, in
        order of first appearance of each method. It is empty if the options
        are not a DataFrame, a required column is missing, or nothing
        survives the filters.
    """
    df = get_data_options(variable=var_name)
    if not isinstance(df, pd.DataFrame):
        return {}

    df = df.reset_index()

    # All three columns are needed to build the (method -> scenarios) map.
    if not {"resolution", "downscaling_method", "scenario"}.issubset(df.columns):
        return {}

    # Strictly 3 km only.
    mask = df["resolution"].astype(str).str.strip() == res_3km

    # Monthly only.
    if "timescale" in df.columns:
        mask &= df["timescale"] == timescale

    # Statistical-only methods: must mention "statistical", must not mention "dynamical".
    method = df["downscaling_method"].astype(str).str.lower()
    mask &= method.str.contains("statistical", regex=False)
    mask &= ~method.str.contains("dynamical", regex=False)

    avail = {}
    for m, group in df[mask].groupby("downscaling_method", sort=False):
        # dropna: a missing scenario label would make sorted() fail on mixed types.
        avail[m] = sorted(group["scenario"].dropna().unique().tolist())
    return avail


def _statistical_method_order(avail_map, methods=None):
    """Return the keys of avail_map to try: preferred methods first, then any other Statistical-only key."""
    preferred = methods_pref if methods is None else methods
    ordered = [m for m in preferred if m in avail_map]
    for m in avail_map:
        m_str = str(m).lower()
        if m not in ordered and "statistical" in m_str and "dynamical" not in m_str:
            ordered.append(m)
    return ordered


def pick_method_and_label(avail_map, ssp, *, methods=None):
    """
    Pick a (downscaling_method, scenario_label) pair for a given SSP.

    Methods are tried in ``methods`` order (default ``methods_pref``). After
    that, any other key of ``avail_map`` that mentions "statistical" but not
    "dynamical" is tried. This matches the case-insensitive filter used when
    the availability map is built, so such keys are not missed by an
    exact-name lookup.

    Returns the first (method, label) whose label contains ``ssp``, or
    (None, None) if none does.
    """
    for m in _statistical_method_order(avail_map, methods):
        for sc in avail_map[m]:
            if ssp in str(sc):
                return m, sc
    return None, None


def pick_historical(avail_map, *, methods=None):
    """
    Pick the historical (downscaling_method, scenario_label) pair.

    Uses the same method order as ``pick_method_and_label``. A label counts as
    historical when it contains "historical" (case-insensitive).

    Returns (None, None) if no such label is found.
    """
    for m in _statistical_method_order(avail_map, methods):
        for sc in avail_map[m]:
            if "historical" in str(sc).lower():
                return m, sc
    return None, None


def monthly_mean_std_over_years(da, start_year, end_year):
    """
    Compute the seasonal cycle of ``da`` as a per-month mean and spread across years.

    Steps:
      1. Keep ``start_year``-01-01 through ``end_year``-12-31.
      2. Remove every non-time dimension. Length-1 dims are squeezed away, and
         all longer ones are averaged together in one NaN-skipping mean.
      3. Average each year's values per calendar month, then take the mean and
         standard deviation (xarray's default ddof) of those values across
         years.

    Returns:
        (mean, std), each indexed by ``month`` and reindexed to months 1..12.
        A month with no data is NaN instead of shortening the result.
        Returns (None, None) when ``da`` is None.
    """
    if da is None:
        return None, None
    da = da.sel(time=slice(f"{start_year}-01-01", f"{end_year}-12-31"))

    non_time = [d for d in da.dims if d != "time"]
    to_squeeze = [d for d in non_time if da.sizes[d] <= 1]
    to_average = [d for d in non_time if da.sizes[d] > 1]
    if to_squeeze:
        da = da.squeeze(to_squeeze, drop=True)
    if to_average:
        # One joint mean: averaging dim after dim with skipna=True would weight
        # partially-masked rows/columns unevenly.
        da = da.mean(to_average, skipna=True)

    # GroupBy.map is xarray's preferred name for the older GroupBy.apply.
    per_year = da.groupby("time.year").map(
        lambda one_year: one_year.groupby("time.month").mean("time", skipna=True)
    )
    all_months = np.arange(1, 13)
    mean = per_year.mean("year", skipna=True).reindex(month=all_months)
    std = per_year.std("year", skipna=True).reindex(month=all_months)
    return mean, std


def fetch_region_var(region_name, boundary_gdf, lat_slice, lon_slice, var_key, var_name):
    """
    Fetch and process the historical and SSP series for one variable over one region.

    Only ``res_3km``, ``timescale``, Statistical-only data is requested. Each
    series is unit-converted, clipped to ``boundary_gdf`` (when given) and
    spatially averaged. Progress and missing data are printed.

    Returns:
        dict mapping label -> processed series. The key "Historical" is always
        present, plus one key per SSP in ``desired_ssps`` that could be
        fetched. Returns None when there is no 3 km Statistical metadata or no
        usable historical data.
    """
    avail = available_monthly_3km(var_name)
    print(f"\n[{region_name.upper()}] {var_key} @3km Statistical availability:")
    if not avail:
        print("  no 3km Statistical metadata -> skipping this var/region.")
        return None
    for m, scs in avail.items():
        print(f"  {m}/{timescale}: {scs}")

    def _load(method, label, years):
        # Shared fetch + post-processing pipeline for historical and SSP runs.
        # Returns None when get_data produced nothing with a non-empty time axis.
        da = _get_data_safe(
            variable=var_name,
            resolution=res_3km,
            timescale=timescale,
            scenario=[label],
            downscaling_method=method,  # Statistical-only, enforced in _get_data_safe
            time_slice=years,
            latitude=lat_slice,
            longitude=lon_slice,
        )
        if da is None or getattr(da, "time", None) is None or da.time.size == 0:
            return None
        da = _convert_units(da, var_key)
        da = _clip_if_possible(da, boundary_gdf)
        return _spatial_avg(da)

    # ---- HISTORICAL (Statistical only, 3km only) ----
    hist_method, hist_label = pick_historical(avail)
    if hist_method is None or hist_label is None:
        print("  no Historical Statistical @3km -> skip var/region.")
        return None

    da_hist = _load(hist_method, hist_label, hist_years)
    if da_hist is None:
        print("  historical Statistical 3km missing -> skip var/region.")
        return None
    out = {"Historical": da_hist}
    print("  historical Statistical 3km ok.")

    # ---- FUTURE (SSPs, Statistical-only, 3km-only) ----
    for ssp in desired_ssps:
        m, lab = pick_method_and_label(avail, ssp)
        if m is None or lab is None:
            print(f"  {ssp} missing as Statistical @3km -> keep as missing")
            continue

        da = _load(m, lab, fut_years)
        if da is None:
            print(f"  empty {ssp} Statistical 3km -> keep as missing")
            continue

        out[ssp] = da
        print(f"  {ssp} ok via {m} (Statistical 3km).")

    return out


def plot_seasonal(region_name, var_key, data_by_label):
    """
    Plot the monthly seasonal cycle (mean +/- 1 std across years) for each scenario.

    The historical curve is drawn in black over ``hist_years``. Each SSP in
    ``desired_ssps`` is drawn over ``fut_years``. An SSP without data still
    gets an empty legend entry marked as missing.

    Nothing is drawn, and no figure is created, when ``data_by_label`` is
    None, lacks "Historical", or the historical statistics cannot be computed.
    """
    if data_by_label is None or "Historical" not in data_by_label:
        return

    h_mean, h_std = monthly_mean_std_over_years(data_by_label["Historical"], *hist_years)
    if h_mean is None:
        return

    months = np.arange(1, 13)

    def _curve(mean, std):
        # Flatten the mean and its +/- 1 std envelope to 1-D arrays for plotting.
        return (
            np.squeeze(mean.values),
            np.squeeze((mean - std).values),
            np.squeeze((mean + std).values),
        )

    # Create the figure only once there is something to draw, so early exits
    # do not leave empty figures behind.
    plt.figure(figsize=(10, 6))

    h_vals, h_lo, h_hi = _curve(h_mean, h_std)
    plt.plot(
        months, h_vals, color="black", lw=2,
        label=f"Historical {hist_years[0]}-{hist_years[1]} (Statistical 3km)"
    )
    plt.fill_between(months, h_lo, h_hi, color="black", alpha=0.12)

    for ssp in desired_ssps:
        f_mean, f_std = (None, None)
        if ssp in data_by_label:
            f_mean, f_std = monthly_mean_std_over_years(data_by_label[ssp], *fut_years)
        if f_mean is None:
            # Empty line keeps the legend entry (and color-cycle slot) for the SSP.
            plt.plot([], [], label=f"{ssp} (missing Statistical 3km)")
            continue

        f_vals, f_lo, f_hi = _curve(f_mean, f_std)
        plt.plot(
            months, f_vals, lw=2,
            label=f"{ssp} ({fut_years[0]}-{fut_years[1]}) (Statistical 3km)"
        )
        plt.fill_between(months, f_lo, f_hi, alpha=0.12)

    units = data_by_label["Historical"].attrs.get("units", "")
    var_label = analysisVariablesDictionary.get(var_key, var_key)
    plt.title(f"{region_name} seasonal cycle @3km Statistical (monthly): {var_key}")
    plt.xlabel("month")
    plt.xticks(months, list("JFMAMJJASOND"))
    # Avoid a dangling "()" when the data carries no units attribute.
    plt.ylabel(f"{var_label} ({units})" if units else var_label)
    plt.grid(True, ls=":")
    plt.legend()
    plt.show()


def _slice_from_boundary(boundary):
    if boundary is None:
        return None, None
    minx, miny, maxx, maxy = boundary.to_crs("EPSG:4326").total_bounds
    return slice(miny, maxy), slice(minx, maxx)


# Each region is (name, boundary, latitude_slice, longitude_slice). Boundaries
# and slices are read from notebook globals when present; slices fall back to
# the boundary's extent.
regions = []

moj_b = globals().get("mojaveBoundaryWgs84", None)
moj_lat = globals().get("mojaveLatitudeSlice", None)
moj_lon = globals().get("mojaveLongitudeSlice", None)
if moj_lat is None or moj_lon is None:
    moj_lat, moj_lon = _slice_from_boundary(moj_b)
regions.append(("Mojave", moj_b, moj_lat, moj_lon))

jt_b = globals().get("jTreeBoundary_WGS84", globals().get("joshuaTreeBoundaryWgs84", None))
jt_lat = globals().get("joshuaLatitudeSlice", globals().get("jTreeLatitudeSlice", None))
jt_lon = globals().get("joshuaLongitudeSlice", globals().get("jTreeLongitudeSlice", None))
if jt_lat is None or jt_lon is None:
    jt_lat, jt_lon = _slice_from_boundary(jt_b)
regions.append(("Joshua Tree", jt_b, jt_lat, jt_lon))

vars_to_do = {k: v for k, v in analysisVariablesDictionary.items() if k in ["temp", "precip"]}
if not vars_to_do:
    print("No 'temp'/'precip' entries in analysisVariablesDictionary -> nothing to plot.")

for region_name, boundary, lat_sl, lon_sl in regions:
    # With neither a boundary nor lat/lon bounds, get_data would receive
    # latitude=None/longitude=None (no bounding box), and the unsubset result
    # would be plotted under this region's name.
    if boundary is None and (lat_sl is None or lon_sl is None):
        print(f"\n[{region_name.upper()}] no boundary or lat/lon bounds found -> skipping region.")
        continue
    for var_key, var_name in vars_to_do.items():
        data_dict = fetch_region_var(region_name, boundary, lat_sl, lon_sl, var_key, var_name)
        plot_seasonal(region_name, var_key, data_dict)