In [None]:
# This notebook contains the calculation of all climatic indices (precipitation- and temperature-based) for every dataset.

In [None]:
                 #########################                                           ##############################
                 #########################                  PRISM                    ##############################
                 #########################                                           ##############################

In [None]:
# Calculating the climatic indices for PRISM (Precipitation)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
# Daily station-vs-PRISM comparison table; the code below reads the columns
# 'time', 'station_name', 'obs' and 'prism_lwr8_val' from it.
csv_file       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\daily_loop3\prism_vs_stations_8Nearest_LWR_1991_2012.csv"
# Station metadata table (renamed below to station_name / lat / lon / elev).
physical_file  = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\Split_by_country\filtered_stations_US.csv"

# Shapefiles used only by the optional quick-look map at the end of this cell.
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# All index workbooks are written here; the folder is created if missing.
output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\ClimaticIndices-8Nearest-3"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, prism_lwr8_val) ...")
df_data = pd.read_csv(csv_file)
# Parse 'time' as datetime; unparseable entries become NaT (errors="coerce")
# instead of raising.
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")

# Normalise station_name (strip surrounding whitespace, upper-case) so it
# matches the identically normalised key in the station metadata table.
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

print(f"df_data shape = {df_data.shape}")
print("Columns:", df_data.columns.tolist())
print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. LOAD STATION METADATA (LAT / LON / ELEV)
#    The actual merge onto the index tables happens later, per index table.
###############################################################################
df_phys = pd.read_csv(physical_file)
# Rename the metadata columns to the short names used in the rest of the cell.
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# Same normalisation as for df_data so the join key matches.
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Return the largest single-day value in `series`, ignoring NaNs."""
    wettest_day = series.max(skipna=True)
    return wettest_day

def calc_rx5day(series):
    """Return the maximum of the 5-element running sum of `series`.

    The rolling window uses ``min_periods=1``, so the first entries are sums
    over fewer than five values rather than NaN.
    """
    return series.rolling(5, min_periods=1).sum().max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """Return the length of the longest run of values below `dry_threshold`.

    NaN entries compare as "not dry" and therefore end a run. Returns 0 when
    no value is below the threshold (or the series is empty).
    """
    dry = np.asarray(series < dry_threshold, dtype=bool)
    # Pad with False on both sides so every run has a start and an end edge.
    padded = np.concatenate(([False], dry, [False])).astype(int)
    edges = np.flatnonzero(np.diff(padded))
    run_lengths = edges[1::2] - edges[::2]
    return int(run_lengths.max()) if run_lengths.size else 0

def calc_cwd(series, wet_threshold=1.0):
    """Return the length of the longest run of values at or above `wet_threshold`.

    NaN entries compare as "not wet" and therefore end a run. Returns 0 when
    no value reaches the threshold (or the series is empty).
    """
    wet = np.asarray(series >= wet_threshold, dtype=bool)
    # Pad with False on both sides so every run has a start and an end edge.
    padded = np.concatenate(([False], wet, [False])).astype(int)
    edges = np.flatnonzero(np.diff(padded))
    run_lengths = edges[1::2] - edges[::2]
    return int(run_lengths.max()) if run_lengths.size else 0

def calc_r95p_r99p(series, percentile=(95,99)):
    """Return (r95_amt, r95_pct, r99_amt, r99_pct) for `series`.

    Percentile cut-offs are taken over wet days only (values >= 1.0). Each
    amount is the sum of wet-day values strictly above its cut-off; each
    percentage is that amount relative to the sum of the whole series (NaN
    when that sum is not positive). With fewer than 5 wet days all four
    results are NaN.
    """
    rainy = series[series >= 1.0]
    if len(rainy) < 5:
        return (np.nan,) * 4

    whole_sum = series.sum(skipna=True)

    def _exceedance(q):
        # Amount above the q-th wet-day percentile and its share of the total.
        cutoff = np.percentile(rainy, q)
        amount = rainy[rainy > cutoff].sum()
        share = (amount / whole_sum) * 100 if whole_sum > 0 else np.nan
        return amount, share

    return (*_exceedance(percentile[0]), *_exceedance(percentile[1]))

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """Return (count of values >= wet_thr, count of values < dry_thr)."""
    return (series >= wet_thr).sum(), (series < dry_thr).sum()

###############################################################################
# 5. COMPUTE INDICES FOR EACH STATION
###############################################################################
# One list of per-station result dicts per index family; each list becomes its
# own DataFrame / Excel workbook further down.
rx1_list, rx5_list = [], []
cdd_list, cwd_list = [], []
r95_list, r99_list = [], []
wet_list, dry_list = [], []

print("Computing indices for each station...")

grouped = df_data.groupby("station_name", as_index=False)
for st_name, grp in grouped:
    # Sort by time (just in case)
    grp = grp.sort_values("time")

    # Daily observed and PRISM series for this station, over the whole record.
    # NOTE(review): dropna() + reset_index() removes missing days independently
    # for each series, so (a) the two series may cover different days and
    # (b) run/rolling-window indices (Rx5day, CDD, CWD) treat the days on
    # either side of a gap as consecutive — confirm this is intended.
    obs_series = grp["obs"].dropna().reset_index(drop=True)
    prism_series = grp["prism_lwr8_val"].dropna().reset_index(drop=True)  # PRISM column compared against obs

    # A) Rx1day
    obs_rx1 = calc_rx1day(obs_series)
    prism_rx1 = calc_rx1day(prism_series)
    rx1_list.append({"station_name": st_name,
                     "obs_rx1day": obs_rx1,
                     "prism_rx1day": prism_rx1})

    # B) Rx5day
    obs_rx5 = calc_rx5day(obs_series)
    prism_rx5 = calc_rx5day(prism_series)
    rx5_list.append({"station_name": st_name,
                     "obs_rx5day": obs_rx5,
                     "prism_rx5day": prism_rx5})

    # C) CDD
    obs_cdd_val = calc_cdd(obs_series)
    prism_cdd_val = calc_cdd(prism_series)
    cdd_list.append({"station_name": st_name,
                     "obs_cdd": obs_cdd_val,
                     "prism_cdd": prism_cdd_val})

    # D) CWD
    obs_cwd_val = calc_cwd(obs_series)
    prism_cwd_val = calc_cwd(prism_series)
    cwd_list.append({"station_name": st_name,
                     "obs_cwd": obs_cwd_val,
                     "prism_cwd": prism_cwd_val})

    # E) R95 / R99  (the "o*" variables hold observed values, the "e*"
    #    variables hold the PRISM values)
    or95a, or95p, or99a, or99p = calc_r95p_r99p(obs_series)
    er95a, er95p, er99a, er99p = calc_r95p_r99p(prism_series)
    r95_list.append({
        "station_name": st_name,
        "obs_r95amt": or95a, "obs_r95pct": or95p,
        "prism_r95amt": er95a, "prism_r95pct": er95p
    })
    r99_list.append({
        "station_name": st_name,
        "obs_r99amt": or99a, "obs_r99pct": or99p,
        "prism_r99amt": er99a, "prism_r99pct": er99p
    })

    # F) wet/dry days (helper defaults: wet >= 5 mm, dry < 1 mm)
    obs_wet5, obs_dry = calc_wetdays_drydays(obs_series)
    prism_wet5, prism_dry = calc_wetdays_drydays(prism_series)
    wet_list.append({
        "station_name": st_name,
        "obs_wetdays5mm": obs_wet5, 
        "prism_wetdays5mm": prism_wet5
    })
    dry_list.append({
        "station_name": st_name,
        "obs_drydays": obs_dry,
        "prism_drydays": prism_dry
    })

print("Finished computing. Now merging lat/lon from physical file ...")

def attach_coords(df_in):
    """Left-join the lat, lon and elev columns of `df_phys` onto `df_in`.

    The join key is `station_name`; rows of `df_in` without a match in
    `df_phys` are kept (left join) and get NaN coordinates.
    """
    station_meta = df_phys[["station_name","lat","lon","elev"]]
    return df_in.merge(station_meta, on="station_name", how="left")

# Turn each list of per-station dicts into a table and add lat/lon/elev.
df_rx1 = attach_coords(pd.DataFrame(rx1_list))
df_rx5 = attach_coords(pd.DataFrame(rx5_list))
df_cdd = attach_coords(pd.DataFrame(cdd_list))
df_cwd = attach_coords(pd.DataFrame(cwd_list))
df_r95 = attach_coords(pd.DataFrame(r95_list))
df_r99 = attach_coords(pd.DataFrame(r99_list))
df_wet = attach_coords(pd.DataFrame(wet_list))
df_dry = attach_coords(pd.DataFrame(dry_list))

###############################################################################
# 6. SAVE OUTPUT
###############################################################################
# One workbook per index family, written into output_dir.
print("Saving index tables to Excel in:", output_dir)
df_rx1.to_excel(os.path.join(output_dir, "rx1day.xlsx"),  index=False)
df_rx5.to_excel(os.path.join(output_dir, "rx5day.xlsx"),  index=False)
df_cdd.to_excel(os.path.join(output_dir, "cdd.xlsx"),     index=False)
df_cwd.to_excel(os.path.join(output_dir, "cwd.xlsx"),     index=False)
df_r95.to_excel(os.path.join(output_dir, "r95p.xlsx"),    index=False)
df_r99.to_excel(os.path.join(output_dir, "r99p.xlsx"),    index=False)
df_wet.to_excel(os.path.join(output_dir, "wetdays.xlsx"), index=False)
df_dry.to_excel(os.path.join(output_dir, "drydays.xlsx"), index=False)

print("\nAll precipitation-based indices have been saved to Excel.")

###############################################################################
# (OPTIONAL) QUICK MAP EXAMPLE
###############################################################################
# Best-effort visual check: any failure in this section is reported with a
# printed message and does not stop the cell (the tables above are already
# saved at this point).
try:
    print("\nQuick map example for obs_rx5day ...")
    # Load shapefiles and reproject them to EPSG:4326 (lon/lat)
    gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
    gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

    # Station points built from the lon/lat columns of the Rx5day table.
    gdf_stations = gpd.GeoDataFrame(
        df_rx5,
        geometry=gpd.points_from_xy(df_rx5["lon"], df_rx5["lat"]),
        crs="EPSG:4326"
    )

    fig, ax = plt.subplots(figsize=(10,8), subplot_kw={"projection": ccrs.PlateCarree()})
    # Basin outline in blue, lake outlines in cyan (outlines only, no fill).
    for geom in gdf_basin.geometry:
        ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='blue', linewidth=1)
    for geom in gdf_lakes.geometry:
        ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='cyan', linewidth=1)

    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=':')

    # Stations coloured by the observed Rx5day value.
    sc = ax.scatter(gdf_stations.geometry.x, gdf_stations.geometry.y,
                    c=gdf_stations["obs_rx5day"], cmap="Reds", s=60,
                    transform=ccrs.PlateCarree(), edgecolor="k")
    plt.colorbar(sc, ax=ax, label="Obs Rx5day (mm)")
    ax.set_extent([-95.5, -72, 38.5, 52.5])  # approximate bounding box
    # Grid labels only on the left and bottom edges.
    gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
    gl.right_labels = False
    gl.top_labels   = False

    plt.title("Obs Rx5day (from CSV daily data)", fontsize=14)
    plt.show()

except Exception as e:
    print("Mapping step failed:", e)

print("\n✅ Done computing precipitation-based indices from 'prism_lwr8_val' column!")


In [None]:
# Calculating the climatic indices for PRISM (Temperature - Tmin/Tmax)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
# Daily station-vs-PRISM temperature table; the code below reads the columns
# 'time', 'station_name', 'var', 'obs' and 'prism_lwr5_val' from it.
# NOTE(review): the file name says "8Nearest" while the column used is
# 'prism_lwr5_val' — confirm file and column belong together.
csv_file = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Temperature\daily_loop\prism_vs_stations_8Nearest_LWR_1991_2012.csv"
# Station metadata table (renamed below to station_name / lat / lon / elev).
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation_Temperature.csv"

# NOTE(review): these two shapefile paths are not used in the visible part of
# this cell.
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# All outputs of this cell are written here; the folder is created if missing.
output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Temperature\ClimaticIndices-8Nearest"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, prism_lwr5_val) ...")
df_data = pd.read_csv(csv_file)
# Unparseable timestamps become NaT (errors="coerce") instead of raising.
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")
# Normalise the join key: strip surrounding whitespace and upper-case.
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

print(f"df_data shape = {df_data.shape}")
print("Columns:", df_data.columns.tolist())
print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. LOAD STATION METADATA (LAT / LON / ELEV)
#    The merge onto the index table happens after the indices are computed.
###############################################################################
df_phys = pd.read_csv(physical_file)
# Rename the metadata columns to the short names used in the rest of the cell.
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# Same normalisation as for df_data so the join key matches.
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS – FULL ETCCDI SET FOR Tmin / Tmax
###############################################################################
def absolute_extremes(series, kind):
    """Return the NaN-skipping maximum of `series` when `kind` is "max".

    Any other value of `kind` (normally "min") returns the NaN-skipping minimum.
    """
    return series.max(skipna=True) if kind == "max" else series.min(skipna=True)

# -------------------------------------------------------------------------
# percentile thresholds (5-day moving window, baseline = 1991-2012)
# -------------------------------------------------------------------------
BASE_START, BASE_END = "1991-01-01", "2012-12-31"

def _climatology_percentiles(s, p):
    """Return a Series (index = 1…366) of the p-th percentile per calendar day.

    Parameters
    ----------
    s : pandas.Series with a DatetimeIndex
        Daily values for the baseline period.
    p : float
        Percentile in [0, 100].

    Returns
    -------
    pandas.Series named ``f"p{p}"``. Entries 1…365 are the p-th percentile
    (NaNs ignored) of all values whose calendar day lies within ±2 days of
    that day, wrapping around the year ends; NaN when the window holds no
    samples. Entry 366 repeats entry 365 so that a lookup by raw
    ``dayofyear`` on 31 December of a leap year finds a threshold.
    """
    # Drop 29 Feb so every year has 365 days ...
    idx = s.index
    s = s[~((idx.month == 2) & (idx.day == 29))]

    # ... and renumber the remaining leap-year days: `dayofyear` still counts
    # the removed 29 Feb, so without this shift 1 March of a leap year would
    # be binned with 2 March of the other years, and 31 December (366) would
    # be left out of every window.
    idx = s.index
    after_feb_leap = np.asarray(idx.is_leap_year) & np.asarray(idx.month > 2)
    cal_doy = np.asarray(idx.dayofyear) - after_feb_leap.astype(int)

    vals = np.asarray(s, dtype=float)
    climo = []
    for d in range(1, 366):
        win = [(x - 1) % 365 + 1 for x in range(d - 2, d + 3)]   # ±2 days, wrapped
        sample = vals[np.isin(cal_doy, win)]
        climo.append(np.nanpercentile(sample, p) if len(sample) else np.nan)

    # Day 366 (31 Dec in leap years) shares the 31 Dec threshold.
    climo.append(climo[-1])
    return pd.Series(climo, index=range(1, 367), name=f"p{p}")

def percentile_flags(s, perc_series, side):
    """Flag the values of `s` that fall beyond their day-of-year threshold.

    `perc_series` is looked up by the day-of-year of each timestamp in `s`.
    With ``side == "low"`` the result is True where the value is below the
    threshold; any other `side` gives True where it is above.
    """
    thresholds = perc_series.reindex(s.index.dayofyear).values
    return (s < thresholds) if side == "low" else (s > thresholds)

def spell_length(bool_series, min_run=6):
    """Return the total number of days that belong to runs of at least
    `min_run` consecutive True values (missing entries count as False)."""
    flags = bool_series.fillna(False).values
    # Zero-pad both ends so each run produces one rising and one falling edge.
    padded = np.concatenate(([0], flags, [0]))
    edges = np.where(np.diff(padded))[0]
    starts, stops = edges[::2], edges[1::2]
    run_lengths = stops - starts
    long_enough = run_lengths >= min_run
    return run_lengths[long_enough].sum()

# -------------------------------------------------------------------------
# absolute-threshold counters
# -------------------------------------------------------------------------
def count_threshold(series, op, thr):
    """Count the values of `series` below `thr` when `op` is "<";
    for any other `op`, count the values above `thr`."""
    hits = (series < thr) if op == "<" else (series > thr)
    return hits.sum()

###############################################################################
# 5.  ANNUAL ETCCDI INDICES  –  FORMAT COMPATIBLE WITH THE SEASONAL SCRIPT
#     • keeps:   TXx  TNn  TX90p  TN10p  FD  WSDI  CSDI
#     • drops:   TR, ID, SU, TXn, TNx
###############################################################################
rows = []

# A calendar year is kept only if it has at least this many non-missing
# observed Tmax days (~55 % of a year).
MIN_VALID_DAYS = 200


def _year_slice(s, yr):
    """Return the part of the date-indexed Series `s` that falls in calendar year `yr`."""
    return s[s.index.year == yr]


def _pct_flagged(flag, values):
    """Percentage of the non-missing days in `values` on which `flag` is True.

    `flag` and `values` must share the same index. Returns NaN when `values`
    has no valid day.
    """
    valid = values.notna()
    return flag[valid].mean() * 100.0 if valid.any() else np.nan


def _ratio(o, e):
    """Return e / o, or NaN when o is zero or NaN (avoids division by zero)."""
    return np.nan if (o == 0 or np.isnan(o)) else e / o


print("→ computing *annual* indices …")
for st_name, st_grp in df_data.groupby("station_name"):

    st_grp = st_grp.set_index("time").sort_index()

    # Build the four daily series on a regular daily grid; asfreq("D") inserts
    # NaN for days that are absent from the table. Each series spans its own
    # first-to-last date, so their lengths may differ.
    obs_max = st_grp.loc[st_grp["var"] == "tmax", "obs"].asfreq("D")
    obs_min = st_grp.loc[st_grp["var"] == "tmin", "obs"].asfreq("D")
    prism_max = st_grp.loc[st_grp["var"] == "tmax", "prism_lwr5_val"].asfreq("D")
    prism_min = st_grp.loc[st_grp["var"] == "tmin", "prism_lwr5_val"].asfreq("D")
    if obs_max.empty:          # station has no observed Tmax at all
        continue

    # ── fixed-baseline climatology (BASE_START..BASE_END, 5-day window) ──────
    p90_TX_obs = _climatology_percentiles(obs_max[BASE_START:BASE_END], 90)
    p10_TN_obs = _climatology_percentiles(obs_min[BASE_START:BASE_END], 10)
    p90_TX_prism = _climatology_percentiles(prism_max[BASE_START:BASE_END], 90)
    p10_TN_prism = _climatology_percentiles(prism_min[BASE_START:BASE_END], 10)

    # Exceedance flags for the *full* record (computed once per station).
    # Comparisons against NaN yield False, so missing days are never flagged.
    flags = {
        "obs_TX90": obs_max >
                     p90_TX_obs.reindex(obs_max.index.dayofyear).values,
        "prism_TX90": prism_max >
                     p90_TX_prism.reindex(prism_max.index.dayofyear).values,
        "obs_TN10": obs_min <
                     p10_TN_obs.reindex(obs_min.index.dayofyear).values,
        "prism_TN10": prism_min <
                     p10_TN_prism.reindex(prism_min.index.dayofyear).values,
    }

    # ── iterate over calendar years ─────────────────────────────────────────
    for yr in np.unique(obs_max.index.year):
        obs_max_yr = _year_slice(obs_max, yr)
        # Count real observations, not calendar slots: after asfreq("D") every
        # interior year has 365/366 entries even if most of them are NaN.
        if obs_max_yr.notna().sum() < MIN_VALID_DAYS:
            continue

        obs_min_yr = _year_slice(obs_min, yr)
        prism_max_yr = _year_slice(prism_max, yr)
        prism_min_yr = _year_slice(prism_min, yr)

        # Slice every flag series by ITS OWN index. A boolean mask derived
        # from obs_max must not be reused here because the four series can
        # have different lengths.
        tx90_obs_yr = _year_slice(flags["obs_TX90"], yr)
        tx90_prism_yr = _year_slice(flags["prism_TX90"], yr)
        tn10_obs_yr = _year_slice(flags["obs_TN10"], yr)
        tn10_prism_yr = _year_slice(flags["prism_TN10"], yr)

        # intensity ....................................................................
        TXx_obs = obs_max_yr.max()
        TXx_prism = prism_max_yr.max()
        TNn_obs = obs_min_yr.min()
        TNn_prism = prism_min_yr.min()

        # percentile frequencies (percentage of days WITH DATA; missing days
        # are excluded from the denominator instead of counting as "no
        # exceedance")
        TX90p_obs = _pct_flagged(tx90_obs_yr, obs_max_yr)
        TX90p_prism = _pct_flagged(tx90_prism_yr, prism_max_yr)
        TN10p_obs = _pct_flagged(tn10_obs_yr, obs_min_yr)
        TN10p_prism = _pct_flagged(tn10_prism_yr, prism_min_yr)

        # spell duration (≥6 consecutive days, within the calendar year)
        WSDI_obs = spell_length(tx90_obs_yr, min_run=6)
        WSDI_prism = spell_length(tx90_prism_yr, min_run=6)
        CSDI_obs = spell_length(tn10_obs_yr, min_run=6)
        CSDI_prism = spell_length(tn10_prism_yr, min_run=6)

        # absolute-threshold count: days with Tmin below 0
        FD_obs = (obs_min_yr < 0).sum()
        FD_prism = (prism_min_yr < 0).sum()

        rows.append(dict(
            station_name=st_name, year=yr,
            TXx_obs=TXx_obs,   TXx_prism=TXx_prism,   TXx_ratio=_ratio(TXx_obs, TXx_prism),
            TNn_obs=TNn_obs,   TNn_prism=TNn_prism,   TNn_ratio=_ratio(TNn_obs, TNn_prism),
            TX90p_obs=TX90p_obs, TX90p_prism=TX90p_prism,
            TX90p_ratio=_ratio(TX90p_obs, TX90p_prism),
            TN10p_obs=TN10p_obs, TN10p_prism=TN10p_prism,
            TN10p_ratio=_ratio(TN10p_obs, TN10p_prism),
            FD_obs=FD_obs,     FD_prism=FD_prism,     FD_ratio=_ratio(FD_obs, FD_prism),
            WSDI_obs=WSDI_obs, WSDI_prism=WSDI_prism,
            WSDI_ratio=_ratio(WSDI_obs, WSDI_prism),
            CSDI_obs=CSDI_obs, CSDI_prism=CSDI_prism,
            CSDI_ratio=_ratio(CSDI_obs, CSDI_prism)
        ))

# One row per (station, year); coordinates are left-joined from the metadata.
df_yr = (pd.DataFrame(rows)
         .merge(df_phys[["station_name", "lat", "lon", "elev"]],
                on="station_name", how="left")
         .sort_values(["station_name", "year"]))

###############################################################################
# 6.  SAVE OUTPUT  – ONE PARQUET ( + XLSX )  +  OPTIONAL per-index sheets
###############################################################################
# The annual table is written twice: once as Parquet and once as Excel. The
# Excel path is derived from the Parquet path by swapping the extension.
annual_pq  = os.path.join(output_dir, "Indices_Annual.parquet")
annual_xls = annual_pq.replace(".parquet", ".xlsx")

df_yr.to_parquet(annual_pq, index=False)
df_yr.to_excel  (annual_xls, index=False)
print(f"✓ Annual indices saved → {annual_pq}")
print(f"✓ …and also saved as   → {annual_xls}")

# OPTIONAL: write one Excel file per index in a "wide" per-station format
# ---------------------------------------------------------------------------
# Column-name roots; each has matching <root>_obs, <root>_prism and
# <root>_ratio columns in df_yr.
index_roots = ["TXx", "TNn", "TX90p", "TN10p", "FD", "WSDI", "CSDI"]
def _wide(idx_root: str) -> pd.DataFrame:
    """Return one row per station for the index `idx_root`.

    The `<idx_root>_obs`, `<idx_root>_prism` and `<idx_root>_ratio` columns of
    `df_yr` are aggregated per station with `pivot_table`'s default
    aggregation, then lat/lon/elev are left-joined from `df_phys`.
    """
    value_cols = [f"{idx_root}_obs", f"{idx_root}_prism", f"{idx_root}_ratio"]
    per_station = df_yr.pivot_table(index="station_name", values=value_cols)
    per_station = per_station.reset_index()
    station_meta = df_phys[["station_name", "lat", "lon", "elev"]]
    return per_station.merge(station_meta, on="station_name", how="left")

print("\n(optional) individual workbooks …")
# One workbook per index root, named <root>.xlsx, holding the per-station
# table produced by _wide().
for idx in index_roots:
    w = _wide(idx)
    fp = os.path.join(output_dir, f"{idx}.xlsx")
    w.to_excel(fp, index=False)
    print("  •", os.path.basename(fp))

print("\n✅  Annual-index workflow finished.")

In [None]:
# Monthly and seasonal stratification of the PRISM precipitation climatic indices

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIGURATION
###############################################################################
# Daily station-vs-PRISM precipitation table; the code below reads the columns
# 'time', 'station_name', 'obs' and 'prism_lwr8_val' from it.
csv_file      = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\daily_loop3\prism_vs_stations_8Nearest_LWR_1991_2012.csv"
# Station metadata table (renamed below to station_name / lat / lon / elev).
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\Split_by_country\filtered_stations_US.csv"
# Monthly and seasonal workbooks are written here; created if missing.
output_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\ClimaticIndices-Seasonal-3"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY DATA & ADD TEMPORAL FIELDS
###############################################################################
print("Loading daily CSV data ...")
df_data = pd.read_csv(csv_file)
# Unparseable timestamps become NaT (errors="coerce") instead of raising.
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")

# Standardize station_name (strip whitespace, upper-case) so it matches the
# identically normalised key in the station metadata.
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

# Add month (1..12); the season label (DJF, MAM, JJA, SON) is derived from it
# below. Rows whose time is NaT get a NaN month.
df_data["month"] = df_data["time"].dt.month

def get_season(month):
    """Map a calendar month number to its meteorological season label.

    Parameters
    ----------
    month : int or float
        Month number 1..12 (floats such as 12.0 are accepted, which is what a
        month column containing NaN yields).

    Returns
    -------
    "DJF", "MAM", "JJA" or "SON" for months 1..12; NaN for anything else
    (e.g. the NaN month of a row whose timestamp failed to parse), so that
    such rows are left out of a season groupby instead of being counted as
    autumn.
    """
    season_by_month = {
        12: "DJF", 1: "DJF", 2: "DJF",
        3: "MAM", 4: "MAM", 5: "MAM",
        6: "JJA", 7: "JJA", 8: "JJA",
        9: "SON", 10: "SON", 11: "SON",
    }
    return season_by_month.get(month, np.nan)

# Season label for every daily row, derived from its month.
df_data["season"] = df_data["month"].apply(get_season)

print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. LOAD STATION METADATA (LAT / LON / ELEV)
#    The merge onto the index tables happens after the indices are computed.
###############################################################################
df_phys = pd.read_csv(physical_file)
# Rename the metadata columns to the short names used in the rest of the cell.
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# Same normalisation as for df_data so the join key matches.
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Return the largest single-day value in `series`, ignoring NaNs."""
    daily_peak = series.max(skipna=True)
    return daily_peak

def calc_rx5day(series):
    """Return the maximum of the 5-element running sum of `series`.

    The rolling window uses ``min_periods=1``, so the first entries are sums
    over fewer than five values rather than NaN.
    """
    five_day_totals = series.rolling(5, min_periods=1).sum()
    return five_day_totals.max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """Return the length of the longest run of values below `dry_threshold`.

    NaN entries compare as "not dry" and therefore end a run. Returns 0 when
    no value is below the threshold (or the series is empty).
    """
    longest = 0
    streak = 0
    for dry_today in (series < dry_threshold):
        # Extend the streak on a dry day, restart it otherwise.
        streak = streak + 1 if dry_today else 0
        if streak > longest:
            longest = streak
    return longest

def calc_cwd(series, wet_threshold=1.0):
    """Return the length of the longest run of values at or above `wet_threshold`.

    NaN entries compare as "not wet" and therefore end a run. Returns 0 when
    no value reaches the threshold (or the series is empty).
    """
    longest = 0
    streak = 0
    for wet_today in (series >= wet_threshold):
        # Extend the streak on a wet day, restart it otherwise.
        streak = streak + 1 if wet_today else 0
        if streak > longest:
            longest = streak
    return longest

def calc_r95p_r99p(series, percentiles=(95,99)):
    """
    Return (r95amt, r95pct, r99amt, r99pct) for `series`.

    - cut-offs are the requested percentiles of the wet days (values >= 1.0)
    - r95amt / r99amt = sum of wet-day values strictly above the cut-off
    - r95pct / r99pct = that amount as a percentage of the sum of the whole
      series (NaN when that sum is not positive)
    - all four results are NaN when there are fewer than 5 wet days
    """
    rainy = series[series >= 1.0]
    if len(rainy) < 5:
        return np.nan, np.nan, np.nan, np.nan

    cut_heavy = np.percentile(rainy, percentiles[0])
    cut_extreme = np.percentile(rainy, percentiles[1])
    heavy_sum = rainy[rainy > cut_heavy].sum()
    extreme_sum = rainy[rainy > cut_extreme].sum()

    grand_total = series.sum(skipna=True)
    if grand_total > 0:
        heavy_share = (heavy_sum / grand_total) * 100
        extreme_share = (extreme_sum / grand_total) * 100
    else:
        heavy_share = np.nan
        extreme_share = np.nan
    return heavy_sum, heavy_share, extreme_sum, extreme_share

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """
    Return (wetdays, drydays) where
    wetdays = number of values >= wet_thr
    drydays = number of values <  dry_thr
    """
    wet_count = (series >= wet_thr).sum()
    dry_count = (series < dry_thr).sum()
    return wet_count, dry_count

###############################################################################
# 5. FUNCTION TO COMPUTE INDICES FOR A GROUP (MONTHLY or SEASONAL)
###############################################################################
def compute_indices(df_group):
    """
    Compute the precipitation indices for one subset of the daily data
    (e.g. station+month or station+season), for Obs and for PRISM.

    Parameters
    ----------
    df_group : pandas.DataFrame
        Must contain the columns "obs" and "prism_lwr8_val".

    Returns
    -------
    dict or None
        None when either column has no non-missing value. Otherwise a dict
        with, for every index name X in rx1day, rx5day, cdd, cwd, r95amt,
        r95pct, r99amt, r99pct, wetdays, drydays, the keys "X_obs",
        "X_prism" and "X_ratio" (= prism / obs). "X_ratio" is always present
        and is NaN when the observed value is zero or NaN, so every returned
        dict has the same keys in the same order.

    NOTE(review): missing values are dropped independently from each column
    and the index is reset, so the two series can cover different days and
    run-based indices treat the days around a gap as consecutive — confirm
    this is intended.
    """
    obs_series = df_group["obs"].dropna().reset_index(drop=True)
    prism_series = df_group["prism_lwr8_val"].dropna().reset_index(drop=True)
    if len(obs_series) == 0 or len(prism_series) == 0:
        return None

    res = {}

    # Rx1day / Rx5day / CDD / CWD: one scalar per source.
    scalar_indices = (
        ("rx1day", calc_rx1day),
        ("rx5day", calc_rx5day),
        ("cdd", calc_cdd),
        ("cwd", calc_cwd),
    )
    for name, func in scalar_indices:
        res[f"{name}_obs"] = func(obs_series)
        res[f"{name}_prism"] = func(prism_series)

    # R95 / R99: the helper returns (r95amt, r95pct, r99amt, r99pct).
    pctl_obs = calc_r95p_r99p(obs_series)
    pctl_prism = calc_r95p_r99p(prism_series)
    res["r95amt_obs"] = pctl_obs[0]
    res["r95pct_obs"] = pctl_obs[1]
    res["r95amt_prism"] = pctl_prism[0]
    res["r95pct_prism"] = pctl_prism[1]
    res["r99amt_obs"] = pctl_obs[2]
    res["r99pct_obs"] = pctl_obs[3]
    res["r99amt_prism"] = pctl_prism[2]
    res["r99pct_prism"] = pctl_prism[3]

    # Wet / Dry days
    wet_obs, dry_obs = calc_wetdays_drydays(obs_series)
    wet_prism, dry_prism = calc_wetdays_drydays(prism_series)
    res["wetdays_obs"] = wet_obs
    res["wetdays_prism"] = wet_prism
    res["drydays_obs"] = dry_obs
    res["drydays_prism"] = dry_prism

    # Ratio columns: prism / obs. The key is always written (NaN when the
    # ratio is undefined) so the resulting table keeps a fixed set of columns
    # even if a ratio is undefined for every group.
    ratio_names = ("rx1day", "rx5day", "cdd", "cwd",
                   "r95amt", "r95pct", "r99amt", "r99pct",
                   "wetdays", "drydays")
    for name in ratio_names:
        obs_val = res[f"{name}_obs"]
        if pd.isna(obs_val) or obs_val == 0:
            res[f"{name}_ratio"] = np.nan
        else:
            res[f"{name}_ratio"] = res[f"{name}_prism"] / obs_val

    return res

###############################################################################
# 6. MONTHLY INDICES
###############################################################################
# One result dict per (station, calendar month); all years of the record are
# pooled within each month. Groups for which compute_indices() returns None
# (no obs or no PRISM values) are skipped.
monthly_results = []
group_month = df_data.groupby(["station_name", "month"])
for (st_name, mon), group in group_month:
    indices = compute_indices(group)
    if indices is None:
        continue
    indices["station_name"] = st_name
    indices["month"] = mon
    monthly_results.append(indices)

# Build the monthly table, left-join coordinates, sort and save.
df_monthly = pd.DataFrame(monthly_results)
df_monthly = pd.merge(
    df_monthly,
    df_phys[["station_name", "lat", "lon", "elev"]],
    on="station_name",
    how="left"
)
df_monthly = df_monthly.sort_values(["station_name", "month"])
monthly_out = os.path.join(output_dir, "Indices_Monthly.xlsx")
df_monthly.to_excel(monthly_out, index=False)
print("Monthly indices saved =>", monthly_out)

###############################################################################
# 7. SEASONAL INDICES
###############################################################################
# Same procedure as for months, grouped by (station, season label) instead.
seasonal_results = []
group_season = df_data.groupby(["station_name", "season"])
for (st_name, seas), group in group_season:
    indices = compute_indices(group)
    if indices is None:
        continue
    indices["station_name"] = st_name
    indices["season"] = seas
    seasonal_results.append(indices)

# Build the seasonal table, left-join coordinates, sort and save.
df_seasonal = pd.DataFrame(seasonal_results)
df_seasonal = pd.merge(
    df_seasonal,
    df_phys[["station_name", "lat", "lon", "elev"]],
    on="station_name",
    how="left"
)
df_seasonal = df_seasonal.sort_values(["station_name", "season"])
seasonal_out = os.path.join(output_dir, "Indices_Seasonal.xlsx")
df_seasonal.to_excel(seasonal_out, index=False)
print("Seasonal indices saved =>", seasonal_out)

###############################################################################
# 8. DONE
###############################################################################
print("\nAll monthly and seasonal indices have been saved. (No extreme-event stratification.)")


In [None]:
# DJF for PRISM

import os
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import Point
from scipy.stats import pearsonr
import seaborn as sns

###############################################################################
# 1. CONFIG & PATHS
###############################################################################
# Folder holding the seasonal index workbook read below; the DJF outputs go
# into its "AnalysisPlots_DJF" sub-folder (created if missing).
indices_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\ClimaticIndices-Seasonal-3"
seasonal_file = os.path.join(indices_dir, "Indices_Seasonal.xlsx")  # single file
output_plots  = os.path.join(indices_dir, "AnalysisPlots_DJF")
os.makedirs(output_plots, exist_ok=True)

shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# Indices in your seasonal file
index_list = ["rx1day","rx5day","cdd","cwd","r95p","r99p","wetdays","drydays"]

# For the summary stats: which column(s) hold the observed and the PRISM
# values of each index. An entry is (obs_column, prism_column); for r95p and
# r99p each side is itself a tuple of (amount column, percentage column).
index_columns = {
    "rx1day":  ("rx1day_obs",  "rx1day_prism"),
    "rx5day":  ("rx5day_obs",  "rx5day_prism"),
    "cdd":     ("cdd_obs",     "cdd_prism"),
    "cwd":     ("cwd_obs",     "cwd_prism"),
    "r95p":    (("r95amt_obs","r95pct_obs"), ("r95amt_prism","r95pct_prism")),
    "r99p":    (("r99amt_obs","r99pct_obs"), ("r99amt_prism","r99pct_prism")),
    "wetdays": ("wetdays_obs","wetdays_prism"),
    "drydays": ("drydays_obs","drydays_prism"),
}

###############################################################################
# 2. LOAD SEASONAL FILE & FILTER TO DJF
###############################################################################
df_season = pd.read_excel(seasonal_file)
print("Loaded =>", seasonal_file, "| shape =", df_season.shape)

# Filter to DJF (copy so later edits do not touch the full seasonal table)
df_season = df_season[df_season["season"]=="DJF"].copy()
df_season = df_season.dropna(subset=["lat","lon"])  # ensure lat/lon exist
print("After filtering to DJF => shape =", df_season.shape)

# Master table = the DJF rows with a clean 0..n-1 index; saved for reference
# and used as the input of the summary statistics below.
mdf = df_season.reset_index(drop=True)
master_xlsx = os.path.join(output_plots, "MasterTable_Seasonal_DJF.xlsx")
mdf.to_excel(master_xlsx, index=False)
print(f"\n(A) Master table (DJF) saved => {master_xlsx}")
print("Columns:", mdf.columns.tolist())

###############################################################################
# 3. SUMMARY TABLE (MBE, RMSE, STD, CC, d) for DJF
###############################################################################
def index_of_agreement(obs, model):
    """
    Index of agreement d between observations and model values.

    d = 1 - sum((model - obs)^2)
            / sum((|model - mean(obs)| + |obs - mean(obs)|)^2)

    Both inputs must support element-wise arithmetic (e.g. numpy arrays).
    Returns NaN when the denominator is exactly zero.
    """
    centre = np.mean(obs)
    squared_error = np.sum((model - obs) ** 2)
    spread = np.abs(model - centre) + np.abs(obs - centre)
    potential_error = np.sum(spread ** 2)
    if potential_error == 0:
        return np.nan
    return 1 - squared_error / potential_error

def rmse(a, b):
    """Root-mean-square of the element-wise difference a - b."""
    residual = a - b
    mean_square = np.mean(residual ** 2)
    return np.sqrt(mean_square)

def std_of_residuals(a, b):
    """Sample standard deviation (ddof=1) of the element-wise difference a - b."""
    residual = a - b
    return np.std(residual, ddof=1)

def mean_bias_error(a, b):
    """Mean of b - a; positive when `b` is larger than `a` on average."""
    bias = b - a
    return np.mean(bias)

# Build one summary row (MBE, RMSE, STDres, CC, d) per obs/PRISM column pair.
# Column order of the output table is fixed up front so that an empty result
# still yields a well-formed (header-only) table instead of a KeyError.
summary_cols = ["Index","Count","MBE","RMSE","STDres","CC","d"]
summary_rows = []
for idx_name in index_list:
    obs_cols, prism_cols = index_columns[idx_name]

    # Normalise to a list of (row label, obs column, PRISM column) so that
    # single-column and multi-column indices share one code path.
    if isinstance(obs_cols, tuple):
        pairs = [(f"{idx_name}_{oc.replace('_obs','')}", oc, ec)
                 for oc, ec in zip(obs_cols, prism_cols)]
    else:
        pairs = [(idx_name, obs_cols, prism_cols)]

    for idx_label, oc, ec in pairs:
        # Same policy as the map/distribution sections: skip, don't crash,
        # when the seasonal workbook lacks a column.
        if oc not in mdf.columns or ec not in mdf.columns:
            print(f"Skipping summary for {idx_label} - missing columns.")
            continue

        # Pairwise-complete rows only; at least 2 are needed for STDres / CC.
        valid = mdf[[oc, ec]].dropna()
        if len(valid) < 2:
            continue

        obs_vals = valid[oc].values
        prism_vals = valid[ec].values
        summary_rows.append({
            "Index": idx_label,
            "Count": len(valid),
            "MBE": mean_bias_error(obs_vals, prism_vals),
            "RMSE": rmse(obs_vals, prism_vals),
            "STDres": std_of_residuals(obs_vals, prism_vals),
            "CC": pearsonr(obs_vals, prism_vals)[0],
            "d": index_of_agreement(obs_vals, prism_vals),
        })

# `columns=` both orders the columns and keeps the header when there are no rows.
summary_df = pd.DataFrame(summary_rows, columns=summary_cols)
summary_xlsx = os.path.join(output_plots, "SummaryTable_Extremes_DJF.xlsx")
summary_df.to_excel(summary_xlsx, index=False)
print(f"(B) Summary Table (DJF) => {summary_xlsx}\n{summary_df}")

###############################################################################
# 4. MAPPING: Combine Observed, prism, Ratio in One Figure
###############################################################################
# Overlay geometries, reprojected to EPSG:4326 (lon/lat) to match the
# PlateCarree CRS passed to add_geometries() in add_basin_lakes().
gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

def add_basin_lakes(ax):
    """
    Draw the static map context on a cartopy axis: national borders (dotted),
    then every basin geometry outlined in black, then every lake geometry
    outlined in cyan. Reads the module-level gdf_basin / gdf_lakes.
    """
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    outline_layers = ((gdf_basin, 'black'), (gdf_lakes, 'cyan'))
    for layer, colour in outline_layers:
        for geom in layer.geometry:
            ax.add_geometries([geom], ccrs.PlateCarree(),
                              facecolor='none', edgecolor=colour, linewidth=1)

def plot_map_triptych(df, obs_col, prism_col, ratio_col, idx_name, out_png):
    """
    Save a single figure with 3 station maps side-by-side:
      1) Observed
      2) PRISM
      3) Ratio (PRISM/OBS)

    Parameters
    ----------
    df : DataFrame with "lon", "lat" and the three value columns.
    obs_col, prism_col, ratio_col : column names plotted in panels 1-3.
    idx_name : index name used in the panel titles.
    out_png : output path; the figure is written at 300 dpi and then closed.

    Each panel has its own colorbar and red rings around stations whose value
    is at or above that panel's 90th percentile.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6),
                             subplot_kw={"projection": ccrs.PlateCarree()})

    def scatter_map(ax, value_col, title):
        """Draw one panel: context layers, coloured stations, hotspot rings."""
        ax.set_extent([-95.5, -72, 38.5, 52.5])
        add_basin_lakes(ax)
        sc = ax.scatter(df["lon"], df["lat"], c=df[value_col], cmap="viridis",
                        s=60, transform=ccrs.PlateCarree(), edgecolor="k", zorder=10)
        cb = fig.colorbar(sc, ax=ax, shrink=0.8)
        cb.set_label(value_col)

        # Hotspots => top 10%
        vals = df[value_col].dropna().values
        has_hotspots = len(vals) > 0
        if has_hotspots:
            thr = np.percentile(vals, 90)
            is_hot = df[value_col] >= thr
            ax.scatter(df.loc[is_hot, "lon"], df.loc[is_hot, "lat"],
                       marker='o', facecolors='none', edgecolors='red', s=80,
                       transform=ccrs.PlateCarree(), zorder=11,
                       label=f"Hotspot >= {thr:.2f}")
        ax.set_title(title, fontsize=12)
        gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
        gl.right_labels = False
        gl.top_labels   = False
        # The hotspot ring is the only labelled artist; without it a legend
        # would be empty and matplotlib would warn.
        if has_hotspots:
            ax.legend(loc='upper right')

    scatter_map(axes[0], obs_col,   f"{idx_name} Observed (DJF)")
    scatter_map(axes[1], prism_col, f"{idx_name} PRISM (DJF)")
    scatter_map(axes[2], ratio_col, f"{idx_name} (PRISM/OBS) (DJF)")

    # Operate on this figure explicitly rather than on pyplot's "current" one.
    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved =>", out_png)

def get_map_cols(idx_name):
    """
    Return the (observed, PRISM, ratio) column names mapped for an index.

    Most indices use their own name as the column stem; r95p / r99p are mapped
    through their "amt" columns (r95amt_*, r99amt_*). Unknown names give
    (None, None, None).
    """
    own_name_stems = ("rx1day", "rx5day", "cdd", "cwd", "drydays", "wetdays")
    amount_stems = ("r95p", "r99p")

    if idx_name in own_name_stems:
        stem = idx_name
    elif idx_name in amount_stems:
        # "r95p" -> "r95amt", "r99p" -> "r99amt"
        stem = f"{idx_name[:-1]}amt"
    else:
        return None, None, None

    return f"{stem}_obs", f"{stem}_prism", f"{stem}_ratio"

# One three-panel map per index. Indices without a column mapping are skipped
# silently; indices whose columns are absent from the master table are reported.
for idx_name in index_list:
    obs_col, prism_col, ratio_col = get_map_cols(idx_name)
    if obs_col is not None:
        required = (obs_col, prism_col, ratio_col, "lat", "lon")
        if any(col not in mdf.columns for col in required):
            print(f"Skipping map for {idx_name} - missing columns.")
        else:
            subdf = mdf.dropna(subset=["lat","lon"]).copy()
            out_png = os.path.join(output_plots, f"DJF_{idx_name}_MAP_3panel.png")
            plot_map_triptych(subdf, obs_col, prism_col, ratio_col, idx_name, out_png)

###############################################################################
# 5. DISTRIBUTION & BOX/CDF/Scatter in One Figure
###############################################################################
def plot_distribution_triptych(df, obs_col, prism_col, label, out_png):
    """
    Save one figure comparing the observed and PRISM values of an index:
      left   - boxplot of both columns
      middle - empirical CDFs (needs >= 2 non-missing values in each column)
      right  - obs-vs-PRISM scatter with 1:1 line and Pearson correlation
               (needs >= 2 pairwise-complete rows)
    A panel without enough data only gets a "not enough data" title.
    The figure is written to `out_png` at 300 dpi and closed.
    """
    fig, (ax_box, ax_cdf, ax_scat) = plt.subplots(nrows=1, ncols=3, figsize=(18,6))

    # --- left: boxplot (long format, one row per value) ---------------------
    long_form = pd.DataFrame({"Obs": df[obs_col], "prism": df[prism_col]}).melt(
        var_name="Dataset", value_name=label
    )
    sns.boxplot(data=long_form, x="Dataset", y=label, ax=ax_box)
    ax_box.set_title(f"Boxplot: {label} (DJF)")

    # --- middle: empirical CDFs ---------------------------------------------
    obs_sample = df[obs_col].dropna()
    prism_sample = df[prism_col].dropna()
    if len(obs_sample) < 2 or len(prism_sample) < 2:
        ax_cdf.set_title(f"CDF: not enough data ({label})")
    else:
        for sample, series_name in ((obs_sample, "Obs"), (prism_sample, "PRISM")):
            ordered = np.sort(sample)
            cum_prob = np.arange(1, len(ordered)+1)/len(ordered)
            ax_cdf.plot(ordered, cum_prob, label=series_name)
        ax_cdf.set_title(f"CDF of {label} (DJF)")
        ax_cdf.set_xlabel(label)
        ax_cdf.set_ylabel("Probability")
        ax_cdf.legend()

    # --- right: scatter with 1:1 reference line -----------------------------
    paired = df[[obs_col, prism_col]].dropna()
    if len(paired) < 2:
        ax_scat.set_title(f"Scatter: not enough data ({label})")
    else:
        x, y = paired[obs_col], paired[prism_col]
        cc, _ = pearsonr(x, y)
        ax_scat.scatter(x, y, edgecolors='k', alpha=0.7)
        lo = np.nanmin([x.min(), y.min()])
        hi = np.nanmax([x.max(), y.max()])
        ax_scat.plot([lo, hi], [lo, hi], 'r--')
        ax_scat.set_xlabel(f"Obs {label} (DJF)")
        ax_scat.set_ylabel(f"PRISM {label} (DJF)")
        ax_scat.set_title(f"{label} (Corr={cc:.2f}, DJF)")

    plt.tight_layout()
    plt.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close()
    print("Saved =>", out_png)

# One boxplot/CDF/scatter figure per index.
for idx_name in index_list:
    # Resolve the obs / PRISM columns through get_map_cols() so the maps and
    # the distribution plots can never disagree on which columns an index uses
    # (the ratio column it also returns is not needed here).
    obs_col, prism_col = get_map_cols(idx_name)[:2]
    if obs_col is None:
        continue

    if obs_col not in mdf.columns or prism_col not in mdf.columns:
        print(f"Skipping distribution for {idx_name} - missing columns.")
        continue

    # Pairwise-complete rows only; fewer than 2 cannot support CDF/correlation.
    subdf = mdf[[obs_col, prism_col]].dropna()
    if len(subdf)<2:
        print(f"Skipping distribution for {idx_name} - not enough data.")
        continue

    out_3panel = os.path.join(output_plots, f"DJF_{idx_name}_Distribution_3panel.png")
    plot_distribution_triptych(subdf, obs_col, prism_col, idx_name, out_3panel)

###############################################################################
# 6. DONE
###############################################################################
print("\nAll DJF steps completed! See outputs in:", output_plots)

## For the other seasons, just change any DJF to JJA, MAM, or SON

In [None]:
                 #########################                                           ##############################
                 #########################                  EMDNA                    ##############################
                 #########################                                           ##############################

In [None]:
# Calculating the climatic indices for EMDNA (prcp)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG          
###############################################################################
root_dir = (r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder"
            r"\Ensemble files\EMDNA_GLB_Precipitation")

# the 10 ensemble sub-folders we need to process
ENSEMBLES = [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]

# static reference files (same for every ensemble)
physical_file  = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# ---------------------------------------------------------------------------
# start looping over ensembles – EVERYTHING that follows now belongs inside
# this for-loop
# ---------------------------------------------------------------------------
for ENS in ENSEMBLES:
    print("\n" + "="*84)
    print(f"⧉  Processing ensemble {ENS:03d}  ⧉")
    print("="*84)

    ens_dir   = os.path.join(root_dir, str(ENS))
    csv_file  = os.path.join(ens_dir, "daily_loop",
                             f"emdna_vs_stations_25km_LWR_1991_2012_prcp_{ENS:03d}.csv")
    output_dir = os.path.join(ens_dir, "ClimaticIndices-25KM")
    os.makedirs(output_dir, exist_ok=True)


###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
    # Skip (rather than fail) ensembles whose daily comparison CSV is missing.
    if not os.path.isfile(csv_file):
        print(f"   ⚠  Daily CSV not found, skipping ensemble {ENS:03d}")
        continue
    
    print("Loading daily CSV data …")
    df_data = pd.read_csv(csv_file)
    # unparseable timestamps become NaT instead of raising
    df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")
    
    # standardise station names
    # (the same strip + upper-case is applied to df_phys below so the later merge keys match)
    df_data["station_name"] = (
        df_data["station_name"].astype(str).str.strip().str.upper()
    )
    
    # keep only precipitation rows (defensive; file *should* contain just prcp)
    if "var" in df_data.columns:
        df_data = df_data[df_data["var"] == "prcp"].copy()
    
    # unify column name expected later in the script
    #df_data = df_data.rename(columns={"emdna_lwr25_val": "emdna_val"})
    
    print(f"df_data shape = {df_data.shape}")
    print("Columns:", df_data.columns.tolist())
    print("Time range:", df_data['time'].min(), "to", df_data['time'].max())
    
    ###############################################################################
    # 3. (OPTIONAL) MERGE WITH PHYSICAL FILE TO GET LAT/LON
    ###############################################################################
    # NOTE(review): physical_file is the same for every ensemble, so this read
    # is repeated unchanged on each loop iteration.
    df_phys = pd.read_csv(physical_file)
    df_phys = df_phys.rename(columns={
        "NAME": "station_name",
        "LATITUDE": "lat",
        "LONGITUDE": "lon",
        "Elevation": "elev"
    })
    # unify station_name
    df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()
    
    # We'll merge lat/lon AFTER computing the indices, so each final row has lat/lon.
    
    ###############################################################################
    # 4. HELPER FUNCTIONS FOR CLIMATE INDICES
    ###############################################################################
    def calc_rx1day(series):
        """Rx1day: the largest single value in `series`, ignoring NaN."""
        wettest_day = series.max(skipna=True)
        return wettest_day
    
    def calc_rx5day(series):
        """
        Rx5day: the largest running total over 5 consecutive entries.
        Windows at the start of the series may hold fewer than 5 entries
        (min_periods=1).
        """
        window_totals = series.rolling(window=5, min_periods=1).sum()
        return window_totals.max(skipna=True)
    
    def calc_cdd(series, dry_threshold=1.0):
        """CDD: length of the longest run of consecutive entries below `dry_threshold`."""
        longest = streak = 0
        for dry in series < dry_threshold:
            if not dry:
                # a wet (or NaN) entry ends the current run
                streak = 0
                continue
            streak += 1
            if streak > longest:
                longest = streak
        return longest
    
    def calc_cwd(series, wet_threshold=1.0):
        """CWD: length of the longest run of consecutive entries at or above `wet_threshold`."""
        longest = streak = 0
        for wet in series >= wet_threshold:
            if not wet:
                # a dry (or NaN) entry ends the current run
                streak = 0
                continue
            streak += 1
            if streak > longest:
                longest = streak
        return longest
    
    def calc_r95p_r99p(series, percentile=(95,99)):
        """
        Very-wet-day totals.

        Percentile thresholds are taken over wet entries only (>= 1.0). For each
        of the two percentiles the function returns the summed amount on wet
        entries strictly above the threshold and that amount as a percentage of
        the total of the whole series.

        Returns (r95_amt, r95_pct, r99_amt, r99_pct); four NaNs when there are
        fewer than 5 wet entries, NaN percentages when the total is not > 0.
        """
        wet_days = series[series >= 1.0]
        if len(wet_days) < 5:
            return (np.nan,) * 4

        thr_low = np.percentile(wet_days, percentile[0])
        thr_high = np.percentile(wet_days, percentile[1])
        amount_low = wet_days[wet_days > thr_low].sum()
        amount_high = wet_days[wet_days > thr_high].sum()

        series_total = series.sum(skipna=True)
        if series_total > 0:
            share_low = (amount_low / series_total) * 100
            share_high = (amount_high / series_total) * 100
        else:
            share_low = share_high = np.nan
        return amount_low, share_low, amount_high, share_high
    
    def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
        """Return (number of entries >= wet_thr, number of entries < dry_thr)."""
        n_wet = series.ge(wet_thr).sum()
        n_dry = series.lt(dry_thr).sum()
        return n_wet, n_dry
    
    ###############################################################################
    # 5. COMPUTE INDICES FOR EACH STATION
    ###############################################################################
    # We'll store results in lists of dicts, then convert to DataFrame => Excel.
    # One list of per-station dicts per output workbook.
    rx1_list, rx5_list = [], []
    cdd_list, cwd_list = [], []
    r95_list, r99_list = [], []
    wet_list, dry_list = [], []
    
    print("Computing indices for each station...")
    grouped = df_data.groupby("station_name", as_index=False)
    
    # Each index is computed once per station over that station's whole record
    # (no per-year split in this cell).
    for st_name, grp in grouped:
        # Sort by time (just in case)
        grp = grp.sort_values("time")
        # daily obs/emd
        # NOTE(review): dropna() + reset_index() removes missing days BEFORE the
        # run-length (CDD/CWD) and 5-entry rolling (Rx5day) calculations, so a run
        # or window can span a data gap. obs and emd are also filtered
        # independently, so the two series may cover different days - confirm
        # this is intended.
        obs_series = grp["obs"].dropna().reset_index(drop=True)
        emd_series = grp["emdna_lwr25_val"].dropna().reset_index(drop=True)
    
        # A) Rx1day
        obs_rx1 = calc_rx1day(obs_series)
        emd_rx1 = calc_rx1day(emd_series)
        rx1_list.append({"station_name": st_name,
                         "obs_rx1day": obs_rx1,
                         "emd_rx1day": emd_rx1})
    
        # B) Rx5day
        obs_rx5 = calc_rx5day(obs_series)
        emd_rx5 = calc_rx5day(emd_series)
        rx5_list.append({"station_name": st_name,
                         "obs_rx5day": obs_rx5,
                         "emd_rx5day": emd_rx5})
    
        # C) CDD
        obs_cdd_val = calc_cdd(obs_series)
        emd_cdd_val = calc_cdd(emd_series)
        cdd_list.append({"station_name": st_name,
                         "obs_cdd": obs_cdd_val,
                         "emd_cdd": emd_cdd_val})
    
        # D) CWD
        obs_cwd_val = calc_cwd(obs_series)
        emd_cwd_val = calc_cwd(emd_series)
        cwd_list.append({"station_name": st_name,
                         "obs_cwd": obs_cwd_val,
                         "emd_cwd": emd_cwd_val})
    
        # E) R95 / R99
        # one call returns both percentiles; they are split into two tables below
        or95a, or95p, or99a, or99p = calc_r95p_r99p(obs_series)
        er95a, er95p, er99a, er99p = calc_r95p_r99p(emd_series)
        r95_list.append({
            "station_name": st_name,
            "obs_r95amt": or95a, "obs_r95pct": or95p,
            "emd_r95amt": er95a, "emd_r95pct": er95p
        })
        r99_list.append({
            "station_name": st_name,
            "obs_r99amt": or99a, "obs_r99pct": or99p,
            "emd_r99amt": er99a, "emd_r99pct": er99p
        })
    
        # F) wet/dry days
        # default thresholds of calc_wetdays_drydays: wet >= 5.0, dry < 1.0
        obs_wet5, obs_dry = calc_wetdays_drydays(obs_series)
        emd_wet5, emd_dry = calc_wetdays_drydays(emd_series)
        wet_list.append({
            "station_name": st_name,
            "obs_wetdays5mm": obs_wet5, 
            "emd_wetdays5mm": emd_wet5
        })
        dry_list.append({
            "station_name": st_name,
            "obs_drydays": obs_dry,
            "emd_drydays": emd_dry
        })
    
    print("Finished computing. Now merging lat/lon from physical file ...")
    
    # Convert each list to DataFrame => merge lat,lon => save
    def attach_coords(df_in):
        """
        Left-join lat, lon and elev from df_phys onto `df_in` by station_name.
        Every row of `df_in` is kept; stations absent from df_phys get NaN
        coordinates.
        """
        coord_cols = ["station_name","lat","lon","elev"]
        return df_in.merge(df_phys[coord_cols], on="station_name", how="left")
    
    # One table per index, each with lat/lon/elev attached.
    # df_rx5 is reused by the quick-map example further down.
    df_rx1 = attach_coords(pd.DataFrame(rx1_list))
    df_rx5 = attach_coords(pd.DataFrame(rx5_list))
    df_cdd = attach_coords(pd.DataFrame(cdd_list))
    df_cwd = attach_coords(pd.DataFrame(cwd_list))
    df_r95 = attach_coords(pd.DataFrame(r95_list))
    df_r99 = attach_coords(pd.DataFrame(r99_list))
    df_wet = attach_coords(pd.DataFrame(wet_list))
    df_dry = attach_coords(pd.DataFrame(dry_list))
    
    ###############################################################################
    # 6. SAVE OUTPUT (SAME FILE NAMES AS BEFORE)
    ###############################################################################
    # Files go to this ensemble's output_dir; existing workbooks with the same
    # name are overwritten by to_excel.
    print("Saving index tables to Excel in:", output_dir)
    df_rx1.to_excel(os.path.join(output_dir, "rx1day.xlsx"),  index=False)
    df_rx5.to_excel(os.path.join(output_dir, "rx5day.xlsx"),  index=False)
    df_cdd.to_excel(os.path.join(output_dir, "cdd.xlsx"),     index=False)
    df_cwd.to_excel(os.path.join(output_dir, "cwd.xlsx"),     index=False)
    df_r95.to_excel(os.path.join(output_dir, "r95p.xlsx"),    index=False)
    df_r99.to_excel(os.path.join(output_dir, "r99p.xlsx"),    index=False)
    df_wet.to_excel(os.path.join(output_dir, "wetdays.xlsx"), index=False)
    df_dry.to_excel(os.path.join(output_dir, "drydays.xlsx"), index=False)
    
    print("\nAll precipitation-based indices have been saved to Excel with station_name (and lat/lon).")
    
    ###############################################################################
    # (OPTIONAL) QUICK MAP EXAMPLE (like obs_rx5day)
    ###############################################################################
    # Best-effort sanity map of observed Rx5day: any failure here (missing
    # shapefile, plotting error, ...) is reported and the ensemble loop goes on.
    try:
        print("\nQuick map example for obs_rx5day ...")
        # Load shapefiles
        # NOTE(review): re-read on every ensemble although the paths never change.
        gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
        gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)
    
        # station points from the lon/lat columns attached by attach_coords()
        gdf_stations = gpd.GeoDataFrame(
            df_rx5,
            geometry=gpd.points_from_xy(df_rx5["lon"], df_rx5["lat"]),
            crs="EPSG:4326"
        )
    
        fig, ax = plt.subplots(figsize=(10,8), subplot_kw={"projection": ccrs.PlateCarree()})
        for geom in gdf_basin.geometry:
            ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='blue', linewidth=1)
        for geom in gdf_lakes.geometry:
            ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='cyan', linewidth=1)
    
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle=':')
    
        sc = ax.scatter(gdf_stations.geometry.x, gdf_stations.geometry.y,
                        c=gdf_stations["obs_rx5day"], cmap="Reds", s=60,
                        transform=ccrs.PlateCarree(), edgecolor="k")
        plt.colorbar(sc, ax=ax, label="Obs Rx5day (mm)")
    
        ax.set_extent([-95.5, -72, 38.5, 52.5])  # approximate bounding
        gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
        gl.right_labels = False
        gl.top_labels   = False
    
        plt.title("Obs Rx5day (from CSV daily data)", fontsize=14)
        # shown, not saved; the figure is not explicitly closed afterwards
        plt.show()
    
    except Exception as e:
        print("Mapping step failed:", e)
    
    print("\n✅ Done computing precipitation-based indices from CSV, with station names included!")


In [None]:
# Calculating the climatic indices for EMDNA (temperature (tmin-tmax)) 

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 0.  WHICH ENSEMBLES TO PROCESS
###############################################################################
ENSEMBLES = [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]     

for ens in ENSEMBLES:                                   
    print(f"\n================  ENSEMBLE {ens}  ================\n")

    ###############################################################################
    # 1. FILE PATHS & CONFIG
    ###############################################################################
    # Per-ensemble input CSV and output folder; the remaining paths are the
    # same for every ensemble.
    # NOTE(review): unlike the precipitation cell there is no os.path.isfile()
    # check here, so a missing CSV raises in pd.read_csv and stops the loop.
    csv_file = rf"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\EMDNA_GLB_Temperature\{ens}\daily_loop\emdna_vs_stations_25km_LWR_1991_2012_tmin_tmax.csv"
    physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation_Temperature.csv"

    # NOTE(review): these two shapefile paths are not used anywhere in the
    # visible part of this cell.
    shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
    lakes_shp = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

    output_dir = rf"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\EMDNA_GLB_Temperature\{ens}\ClimaticIndices-25KM"
    os.makedirs(output_dir, exist_ok=True)

    ###############################################################################
    # 2. LOAD DAILY CSV DATA
    ###############################################################################
    print("Loading daily CSV data (obs, emdna_lwr25_val) ...")
    df_data = pd.read_csv(csv_file)
    # unparseable timestamps become NaT instead of raising
    df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")
    # same strip + upper-case normalisation as df_phys below, so merge keys match
    df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

    print(f"df_data shape = {df_data.shape}")
    print("Columns:", df_data.columns.tolist())
    print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

    ###############################################################################
    # 3. MERGE WITH PHYSICAL FILE TO GET LAT/LON
    ###############################################################################
    # Station metadata; the actual merge happens after the indices are computed.
    df_phys = pd.read_csv(physical_file)
    df_phys = df_phys.rename(columns={
        "NAME": "station_name",
        "LATITUDE": "lat",
        "LONGITUDE": "lon",
        "Elevation": "elev"
    })
    df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

    ###############################################################################
    # 4. HELPER FUNCTIONS – FULL ETCCDI SET FOR Tmin / Tmax
    ###############################################################################
    def absolute_extremes(series, kind):
        """
        Single-value extreme of `series`, ignoring NaN: the maximum when
        kind == "max", the minimum for any other `kind`.
        """
        reducer = series.max if kind == "max" else series.min
        return reducer(skipna=True)

    # -------------------------------------------------------------------------
    # percentile thresholds (5-day moving window, baseline = 1991-2012)
    # -------------------------------------------------------------------------
    BASE_START, BASE_END = "1991-01-01", "2012-12-31"

    def _climatology_percentiles(s, p):
        """Return a Series (index = 1…366) of the p-th percentile."""
        # drop 29 Feb so every year has 365 days
        s = s[~((s.index.month == 2) & (s.index.day == 29))]
        df = pd.DataFrame({"val": s, "doy": s.index.dayofyear})
        climo = []
        for d in range(1, 366):
            win = list(range(d-2, d+3))                       # ±2-day window
            win = [(x-1) % 365 + 1 for x in win]              # wrap around ends
            vals = df.loc[df["doy"].isin(win), "val"]
            climo.append(np.nanpercentile(vals, p) if len(vals) else np.nan)
        return pd.Series(climo, index=range(1, 366), name=f"p{p}")

    def percentile_flags(s, perc_series, side):
        """
        Boolean Series aligned with `s`: True where the value is strictly below
        (side == "low") or strictly above (any other `side`) the threshold that
        `perc_series` holds for that entry's day-of-year. Days without a
        threshold compare as False.
        """
        day_thresholds = perc_series.reindex(s.index.dayofyear).values
        return (s < day_thresholds) if side == "low" else (s > day_thresholds)

    def spell_length(bool_series, min_run=6):
        """
        Total number of entries that belong to runs of at least `min_run`
        consecutive True values (missing flags count as False).
        """
        flags = bool_series.fillna(False).values
        # Pad with a 0 on both sides so every run has a rising and a falling
        # edge; the edges then alternate start, stop, start, stop, ...
        padded = np.concatenate(([0], flags, [0]))
        edges = np.flatnonzero(np.diff(padded))
        run_starts, run_stops = edges[::2], edges[1::2]
        run_lengths = run_stops - run_starts
        return run_lengths[run_lengths >= min_run].sum()

    # -------------------------------------------------------------------------
    # absolute-threshold counters
    # -------------------------------------------------------------------------
    def count_threshold(series, op, thr):
        """
        Count entries strictly below `thr` when op == "<", otherwise entries
        strictly above `thr`.
        """
        hits = (series < thr) if op == "<" else (series > thr)
        return hits.sum()

    ###############################################################################
    # 5.  ANNUAL ETCCDI INDICES  –  FORMAT COMPATIBLE WITH THE SEASONAL SCRIPT
    #     • keeps:   TXx  TNn  TX90p  TN10p  FD  WSDI  CSDI
    #     • drops:   TR, ID, SU, TXn, TNx
    ###############################################################################
    rows = []

    def _doy365(idx):
        """
        Day-of-year of a DatetimeIndex on a 365-day calendar: leap-year dates
        after February are shifted back by one, so 31 Dec is always 365 and
        every date finds a threshold in the 1…365 climatology (29 Feb shares
        1 Mar's slot, as it did with the raw day-of-year).
        """
        late_leap = np.asarray(idx.is_leap_year) & np.asarray(idx.month > 2)
        return np.asarray(idx.dayofyear) - late_leap.astype(int)

    def ratio(o, e):
        """emdna / obs; NaN when obs is 0 or NaN (avoids division by zero)."""
        return np.nan if (o == 0 or np.isnan(o)) else e / o

    print("→ computing *annual* indices …")
    for st_name, st_grp in df_data.groupby("station_name"):

        st_grp = st_grp.set_index("time").sort_index()

        # build daily Series (regular daily frequency; missing days become NaN)
        obs_max   = st_grp.loc[st_grp["var"] == "tmax", "obs"].asfreq("D")
        obs_min   = st_grp.loc[st_grp["var"] == "tmin", "obs"].asfreq("D")
        emdna_max = st_grp.loc[st_grp["var"] == "tmax", "emdna_lwr25_val"].asfreq("D")
        emdna_min = st_grp.loc[st_grp["var"] == "tmin", "emdna_lwr25_val"].asfreq("D")
        if obs_max.empty:          # station has no data at all
            continue

        # ── fixed-period climatology (1991-2012, 5-day moving window) ───────────
        p90_TX_obs   = _climatology_percentiles(obs_max [BASE_START:BASE_END], 90)
        p10_TN_obs   = _climatology_percentiles(obs_min [BASE_START:BASE_END], 10)
        p90_TX_emdna = _climatology_percentiles(emdna_max[BASE_START:BASE_END], 90)
        p10_TN_emdna = _climatology_percentiles(emdna_min[BASE_START:BASE_END], 10)

        def _exceeds(series, thresholds, above):
            """Flag each day against its own calendar-day threshold."""
            thr = thresholds.reindex(_doy365(series.index)).values
            return (series > thr) if above else (series < thr)

        # flags for the *full* record (faster than recomputing year-by-year);
        # each flag Series keeps the index of the series it was built from
        flags = {
            "obs_TX90":   _exceeds(obs_max,   p90_TX_obs,   above=True),
            "emdna_TX90": _exceeds(emdna_max, p90_TX_emdna, above=True),
            "obs_TN10":   _exceeds(obs_min,   p10_TN_obs,   above=False),
            "emdna_TN10": _exceeds(emdna_min, p10_TN_emdna, above=False),
        }

        # ── iterate over calendar years (as present in the observed tmax record) ─
        years = np.unique(obs_max.index.year)
        for yr in years:

            def _sel(s):          # helper to slice one year on the series' OWN index
                return s[s.index.year == yr]

            if len(_sel(obs_max)) < 200:         # at least ~55 % of a year
                continue

            # intensity ....................................................................
            TXx_obs   = _sel(obs_max  ).max()
            TXx_emdna = _sel(emdna_max).max()
            TNn_obs   = _sel(obs_min  ).min()
            TNn_emdna = _sel(emdna_min).min()

            # Year slices are taken per series (not with one mask built from
            # obs_max), so a tmin record with a different date range cannot
            # raise or be misaligned.
            tx90_obs_yr   = _sel(flags["obs_TX90"  ])
            tx90_emdna_yr = _sel(flags["emdna_TX90"])
            tn10_obs_yr   = _sel(flags["obs_TN10"  ])
            tn10_emdna_yr = _sel(flags["emdna_TN10"])

            # percentile frequencies (percentage of days)
            TX90p_obs   = tx90_obs_yr  .mean() * 100.0
            TX90p_emdna = tx90_emdna_yr.mean() * 100.0
            TN10p_obs   = tn10_obs_yr  .mean() * 100.0
            TN10p_emdna = tn10_emdna_yr.mean() * 100.0

            # spell duration (≥6 consecutive days, counted within the year)
            WSDI_obs   = spell_length(tx90_obs_yr,   min_run=6)
            WSDI_emdna = spell_length(tx90_emdna_yr, min_run=6)
            CSDI_obs   = spell_length(tn10_obs_yr,   min_run=6)
            CSDI_emdna = spell_length(tn10_emdna_yr, min_run=6)

            # absolute‐threshold count (days with Tmin < 0)
            FD_obs   = (_sel(obs_min  ) < 0).sum()
            FD_emdna = (_sel(emdna_min) < 0).sum()

            rows.append(dict(
                station_name=st_name, year=yr,
                TXx_obs   = TXx_obs,   TXx_emdna   = TXx_emdna,   TXx_ratio   = ratio(TXx_obs, TXx_emdna),
                TNn_obs   = TNn_obs,   TNn_emdna   = TNn_emdna,   TNn_ratio   = ratio(TNn_obs, TNn_emdna),
                TX90p_obs = TX90p_obs, TX90p_emdna = TX90p_emdna, TX90p_ratio = ratio(TX90p_obs, TX90p_emdna),
                TN10p_obs = TN10p_obs, TN10p_emdna = TN10p_emdna, TN10p_ratio = ratio(TN10p_obs, TN10p_emdna),
                FD_obs    = FD_obs,    FD_emdna    = FD_emdna,    FD_ratio    = ratio(FD_obs,  FD_emdna),
                WSDI_obs  = WSDI_obs,  WSDI_emdna  = WSDI_emdna,  WSDI_ratio  = ratio(WSDI_obs, WSDI_emdna),
                CSDI_obs  = CSDI_obs,  CSDI_emdna  = CSDI_emdna,  CSDI_ratio  = ratio(CSDI_obs, CSDI_emdna)
            ))

    # one row per station-year, with station coordinates attached
    df_yr = (pd.DataFrame(rows)
             .merge(df_phys[["station_name", "lat", "lon", "elev"]],
                    on="station_name", how="left")
             .sort_values(["station_name", "year"]))

    ###############################################################################
    # 6.  SAVE OUTPUT  – ONE PARQUET ( + XLSX )  +  OPTIONAL per-index sheets
    ###############################################################################
    # The same station-year table is written twice: Parquet and Excel.
    annual_pq  = os.path.join(output_dir, "Indices_Annual.parquet")
    annual_xls = annual_pq.replace(".parquet", ".xlsx")

    df_yr.to_parquet(annual_pq, index=False)
    df_yr.to_excel  (annual_xls, index=False)
    print(f"✓ Annual indices saved → {annual_pq}")
    print(f"✓ …and also saved as   → {annual_xls}")

    # OPTIONAL: write one Excel file per index in the familiar “wide” format
    # ---------------------------------------------------------------------------
    # column-name roots; each has *_obs, *_emdna and *_ratio columns in df_yr
    index_roots = ["TXx", "TNn", "TX90p", "TN10p", "FD", "WSDI", "CSDI"]
    def _wide(idx_root: str) -> pd.DataFrame:
        """
        One row per station for a single index: the station's obs / emdna /
        ratio columns aggregated over all years by pivot_table's default
        aggregation (mean), with lat / lon / elev re-attached from df_phys.
        """
        value_cols = [f"{idx_root}_obs", f"{idx_root}_emdna", f"{idx_root}_ratio"]
        per_station = df_yr.pivot_table(index="station_name", values=value_cols).reset_index()
        station_coords = df_phys[["station_name", "lat", "lon", "elev"]]
        return per_station.merge(station_coords, on="station_name", how="left")

    print("\n(optional) individual workbooks …")
    # one "<index>.xlsx" per root in this ensemble's output_dir
    for idx in index_roots:
        w  = _wide(idx)
        fp = os.path.join(output_dir, f"{idx}.xlsx")
        w.to_excel(fp, index=False)
        print("  •", os.path.basename(fp))

    print("\n✅  Annual-index workflow finished for ensemble", ens)

In [None]:
# Temporal stratification of climatic indices for prcp EMDNA

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. CONFIGURATION & PATHS   ★ REWRITTEN FOR MULTI-ENSEMBLE ★
###############################################################################
root_dir = (r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder"
            r"\Ensemble files\EMDNA_GLB_Precipitation")

# the 10 ensemble sub-folders to process
ENSEMBLES = [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]

# static station-metadata file (same for every ensemble)
physical_file = (r"D:\PhD\GLB\Merged USA and CA\Entire GLB"
                 r"\filtered_stations_with_elevation.csv")


# ────────────────────────────────────────────────────────────────────────────
# BEGIN LOOP over ensembles – everything downstream is indented inside it
# ────────────────────────────────────────────────────────────────────────────
# One pass per ensemble member: locate its daily CSV, prepare its output
# folder, and load the daily table into `df_data`.
for ENS in ENSEMBLES:
    print("\n" + "="*84)
    print(f"⧉  Temporal-stratified indices for ensemble {ENS:03d}  ⧉")
    print("="*84)

    # NOTE: the sub-folder name is the plain number (str(ENS)), while the CSV
    # file name below uses the zero-padded three-digit form ({ENS:03d}).
    ens_dir = os.path.join(root_dir, str(ENS))

    # daily-loop CSV produced by the 25-km LWR script
    csv_file = os.path.join(
        ens_dir, "daily_loop",
        f"emdna_vs_stations_25km_LWR_1991_2012_prcp_{ENS:03d}.csv"
    )
    # A missing CSV is reported and the ensemble is skipped (no exception).
    if not os.path.isfile(csv_file):
        print(f"   ⚠  {os.path.basename(csv_file)} not found – skipping ensemble")
        continue

    # output folder for this ensemble
    output_dir = os.path.join(ens_dir, "ClimaticIndices-25KM", "TemporalStrat")
    os.makedirs(output_dir, exist_ok=True)

    ############################################################################
    # 2. LOAD DAILY DATA & ADD TEMPORAL FIELDS
    ############################################################################
    print("Loading daily CSV data …")
    df_data = pd.read_csv(csv_file)
    # errors="coerce": unparseable timestamps become NaT instead of raising.
    df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")

    # Standardize station_name (trim + upper-case) so it matches the metadata file
    df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()
    
    # Add month (1..12) and season (DJF, MAM, JJA, SON)
    # Rows whose time is NaT get a NaN month here.
    df_data["month"] = df_data["time"].dt.month
    
    def get_season(month):
        """Map a calendar month number (1-12) to its meteorological season.

        Returns "DJF", "MAM", "JJA" or "SON" for valid months.

        Any other value returns ``None``.  This matters for the NaN month
        that ``.dt.month`` yields for timestamps that failed to parse
        (``errors="coerce"``): previously the bare ``else`` branch labelled
        those rows "SON", silently mixing undated rows into the autumn
        statistics.  A missing season label is skipped by pandas' ``groupby``
        under its default ``dropna=True``.
        """
        if month in (12, 1, 2):
            return "DJF"
        if month in (3, 4, 5):
            return "MAM"
        if month in (6, 7, 8):
            return "JJA"
        if month in (9, 10, 11):
            return "SON"
        # NaN / out-of-range input: no season can be assigned.
        return None
    
    # Label every daily row with its season by mapping the month column
    # through get_season.
    df_data["season"] = df_data["month"].apply(get_season)
    
    print("Time range:", df_data["time"].min(), "to", df_data["time"].max())
    
    ###############################################################################
    # 3. LOAD PHYSICAL FILE (STATION COORDINATES / ELEVATION)
    ###############################################################################
    # The coordinates are only loaded and normalised here; they are joined onto
    # the index tables after the indices have been computed.
    df_phys = pd.read_csv(physical_file)
    df_phys = df_phys.rename(columns={
        "NAME": "station_name",
        "LATITUDE": "lat",
        "LONGITUDE": "lon",
        "Elevation": "elev"
    })
    # Same trim + upper-case normalisation as applied to df_data, so the join key matches.
    df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()
    
    ###############################################################################
    # 4. HELPER FUNCTIONS FOR CLIMATE INDICES
    ###############################################################################
    def calc_rx1day(series):
        """Rx1day: the largest single value in ``series`` (missing values are skipped)."""
        wettest_day = series.max(skipna=True)
        return wettest_day
    
    def calc_rx5day(series):
        """Rx5day: the largest total found in a sliding window of five entries.

        ``min_periods=1`` lets the first (shorter) windows contribute their
        partial sums, so a series with fewer than five entries still returns
        a value.
        """
        window_totals = series.rolling(window=5, min_periods=1).sum()
        return window_totals.max(skipna=True)
    
    def calc_cdd(series, dry_threshold=1.0):
        """CDD: length of the longest unbroken run of entries below ``dry_threshold``.

        Returns 0 for an empty series or when no entry is below the threshold.
        """
        longest = streak = 0
        for dry in (series < dry_threshold):
            # extend the running streak on a dry entry, restart it otherwise
            streak = streak + 1 if dry else 0
            if streak > longest:
                longest = streak
        return longest
    
    def calc_cwd(series, wet_threshold=1.0):
        """CWD: length of the longest unbroken run of entries at or above ``wet_threshold``.

        Returns 0 for an empty series or when no entry reaches the threshold.
        """
        longest = streak = 0
        for wet in (series >= wet_threshold):
            if not wet:
                streak = 0
                continue
            streak += 1
            longest = max(longest, streak)
        return longest
    
    def calc_r95p_r99p(series, percentiles=(95,99)):
        """
        Amount and share of precipitation falling on very wet entries.

        Wet entries are those >= 1.0.  For each of the two percentiles the
        threshold is taken from the wet entries of ``series`` itself, and:
        - amount = sum of wet entries strictly above that threshold
        - share  = amount / total of the whole series * 100 (NaN if total <= 0)

        Returns (r95amt, r95pct, r99amt, r99pct); four NaNs when fewer than
        five wet entries are available.
        """
        wet_days = series[series >= 1.0]
        if len(wet_days) < 5:
            return (np.nan,) * 4

        cut_95 = np.percentile(wet_days, percentiles[0])
        cut_99 = np.percentile(wet_days, percentiles[1])
        period_total = series.sum(skipna=True)

        def _amount_and_share(cutoff):
            amount = wet_days[wet_days > cutoff].sum()
            share = (amount / period_total) * 100 if period_total > 0 else np.nan
            return amount, share

        r95_amt, r95_pct = _amount_and_share(cut_95)
        r99_amt, r99_pct = _amount_and_share(cut_99)
        return r95_amt, r95_pct, r99_amt, r99_pct
    
    def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
        """
        Count wet and dry entries of ``series``.

        Returns (wetdays, drydays) where
        wetdays = number of entries >= wet_thr
        drydays = number of entries <  dry_thr
        """
        wet_count = (series >= wet_thr).sum()
        dry_count = (series < dry_thr).sum()
        return wet_count, dry_count
    
    ###############################################################################
    # 5. FUNCTION TO COMPUTE INDICES FOR A GROUP (MONTHLY or SEASONAL)
    ###############################################################################
    def compute_indices(df_group):
        """
        Compute the climate indices for one subset of daily rows
        (e.g. station+month, or station+season) for Obs vs EMD, plus ratio columns.

        Parameters
        ----------
        df_group : pandas.DataFrame
            Must contain the columns "obs" and "emdna_lwr25_val".  When a
            "time" column is present the rows are put in chronological order
            first, because Rx5day, CDD and CWD depend on row order.

        Returns
        -------
        dict or None
            ``<name>_obs``, ``<name>_emd`` and ``<name>_ratio`` entries for
            rx1day, rx5day, cdd, cwd, r95amt, r95pct, r99amt, r99pct, wetdays
            and drydays.  ``None`` when either column has no non-missing value.
            Every ``*_ratio`` key is always present (NaN when the observed
            value is zero or missing), so all records share one schema.

        NOTE(review): missing values are dropped from each column separately
        and the remainder is re-indexed, so the run-length and rolling indices
        treat the remaining entries as consecutive days – confirm that this is
        intended for groups that pool several years.
        """
        # Order-dependent indices need chronological rows; a stable sort keeps
        # the incoming order of any duplicated timestamps.
        if "time" in df_group.columns:
            df_group = df_group.sort_values("time", kind="mergesort")

        obs_series = df_group["obs"].dropna().reset_index(drop=True)
        emd_series = df_group["emdna_lwr25_val"].dropna().reset_index(drop=True)
        if len(obs_series) == 0 or len(emd_series) == 0:
            return None

        res = {}
        # Rx1day / Rx5day
        res["rx1day_obs"] = calc_rx1day(obs_series)
        res["rx1day_emd"] = calc_rx1day(emd_series)
        res["rx5day_obs"] = calc_rx5day(obs_series)
        res["rx5day_emd"] = calc_rx5day(emd_series)

        # CDD / CWD
        res["cdd_obs"] = calc_cdd(obs_series)
        res["cdd_emd"] = calc_cdd(emd_series)
        res["cwd_obs"] = calc_cwd(obs_series)
        res["cwd_emd"] = calc_cwd(emd_series)

        # R95 / R99  (helper returns amt95, pct95, amt99, pct99)
        r95_obs = calc_r95p_r99p(obs_series)
        r95_emd = calc_r95p_r99p(emd_series)
        res["r95amt_obs"] = r95_obs[0]
        res["r95pct_obs"] = r95_obs[1]
        res["r95amt_emd"] = r95_emd[0]
        res["r95pct_emd"] = r95_emd[1]
        res["r99amt_obs"] = r95_obs[2]
        res["r99pct_obs"] = r95_obs[3]
        res["r99amt_emd"] = r95_emd[2]
        res["r99pct_emd"] = r95_emd[3]

        # Wet / Dry days
        wet_obs, dry_obs = calc_wetdays_drydays(obs_series)
        wet_emd, dry_emd = calc_wetdays_drydays(emd_series)
        res["wetdays_obs"] = wet_obs
        res["wetdays_emd"] = wet_emd
        res["drydays_obs"] = dry_obs
        res["drydays_emd"] = dry_emd

        # Ratio columns: EMD/OBS.  The truthiness test keeps the original
        # semantics (a zero observation gives no ratio; a NaN observation is
        # truthy and propagates NaN through the division) but the key is now
        # always written, as NaN, instead of being left out.
        ratio_names = ("rx1day", "rx5day", "cdd", "cwd",
                       "r95amt", "r95pct", "r99amt", "r99pct",
                       "wetdays", "drydays")
        for name in ratio_names:
            obs_val = res[f"{name}_obs"]
            res[f"{name}_ratio"] = (res[f"{name}_emd"] / obs_val) if obs_val else np.nan

        return res
    
    ###############################################################################
    # 6. MONTHLY INDICES
    ###############################################################################
    def _stratified_indices(period_col, out_name):
        """Compute, save and return the index table for one stratification.

        Groups ``df_data`` by ("station_name", ``period_col``), runs
        ``compute_indices`` on every group (groups for which it returns None
        are skipped), attaches lat/lon/elev from ``df_phys`` with a left join,
        sorts by the grouping keys and writes the result to
        ``output_dir/out_name``.

        Returns (table, path of the written workbook).
        """
        keys = ["station_name", period_col]
        records = []
        for (st_name, period), group in df_data.groupby(keys):
            indices = compute_indices(group)
            if indices is None:
                continue
            indices["station_name"] = st_name
            indices[period_col] = period
            records.append(indices)

        if records:
            table = pd.DataFrame(records)
        else:
            # Without any record the frame would have no "station_name"
            # column and the merge below would raise KeyError; seed the key
            # columns so an (empty) workbook is still produced.
            print(f"   ⚠  no usable {period_col} groups – writing an empty table")
            table = pd.DataFrame(columns=keys)

        table = pd.merge(
            table,
            df_phys[["station_name", "lat", "lon", "elev"]],
            on="station_name",
            how="left"
        ).sort_values(keys)

        out_path = os.path.join(output_dir, out_name)
        table.to_excel(out_path, index=False)
        return table, out_path

    df_monthly, monthly_out = _stratified_indices("month", "Indices_Monthly.xlsx")
    print("Monthly indices saved =>", monthly_out)

    ###############################################################################
    # 7. SEASONAL INDICES
    ###############################################################################
    df_seasonal, seasonal_out = _stratified_indices("season", "Indices_Seasonal.xlsx")
    print("Seasonal indices saved =>", seasonal_out)

    ###############################################################################
    # 8. DONE
    ###############################################################################
    print(f"✅  Finished ensemble {ENS:03d}  →  outputs in  {output_dir}")


In [None]:
# Seasonal (DJF / MAM / JJA / SON) analysis for prcp

import os
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import Point
from scipy.stats import pearsonr
import seaborn as sns

###############################################################################
# 1. CONFIGURATION & PATHS  ★ MULTI-ENSEMBLE + MULTI-SEASON ★
###############################################################################
root_dir = (r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder"
            r"\Ensemble files\EMDNA_GLB_Precipitation")

# ensembles to analyse
ENSEMBLES = [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]

# seasons to analyse
SEASONS = ["DJF", "MAM", "JJA", "SON"]

# static files
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# index → (obs-cols , emd-cols) mapping (unchanged)
# Maps each index name to its (observed, EMDNA) column name(s) in the seasonal
# workbook.  r95p / r99p carry two column pairs each (amount and percentage),
# given as tuples in matching order; every other index has a single pair.
index_columns = {
    "rx1day":  ("rx1day_obs",  "rx1day_emd"),
    "rx5day":  ("rx5day_obs",  "rx5day_emd"),
    "cdd":     ("cdd_obs",     "cdd_emd"),
    "cwd":     ("cwd_obs",     "cwd_emd"),
    "r95p":    (("r95amt_obs", "r95pct_obs"),
                ("r95amt_emd", "r95pct_emd")),
    "r99p":    (("r99amt_obs", "r99pct_obs"),
                ("r99amt_emd", "r99pct_emd")),
    "wetdays": ("wetdays_obs", "wetdays_emd"),
    "drydays": ("drydays_obs", "drydays_emd"),
}
# Order in which the indices are processed; every entry is a key of index_columns.
index_list = [
    "rx1day", "rx5day", "cdd", "cwd",
    "r95p", "r99p", "wetdays", "drydays"
]
# ────────────────────────────────────────────────────────────────────────────
# OUTER LOOP → ensembles
# ────────────────────────────────────────────────────────────────────────────
for ENS in ENSEMBLES:
    # ❶ Preferred location of the seasonal workbook for this ensemble …
    indices_dir = os.path.join(root_dir, str(ENS),
                               "ClimaticIndices-25KM", "TemporalStrat")
    seasonal_file = os.path.join(indices_dir, "Indices_Seasonal.xlsx")
    
    # ❷ …but fall back to the legacy “ClimaticIndices2” folder if not found
    #    (plots are then written under that legacy folder as well, because
    #    indices_dir is redirected too).
    if not os.path.isfile(seasonal_file):
        legacy_dir   = os.path.join(root_dir, str(ENS), "ClimaticIndices2")
        legacy_file  = os.path.join(legacy_dir, "Indices_Seasonal.xlsx")
        if os.path.isfile(legacy_file):
            indices_dir   = legacy_dir          # use the old folder
            seasonal_file = legacy_file
        else:
            # Neither location has the workbook: report and skip this ensemble.
            print(f"⚠  Indices_Seasonal.xlsx not found for ensemble {ENS:03d}")
            continue

    # load once per ensemble – we'll slice by season later
    # Rows without coordinates are dropped up front (they cannot be mapped).
    df_all_seasons = pd.read_excel(seasonal_file).dropna(subset=["lat", "lon"])
    print(f"\n================  ENSEMBLE {ENS:03d}  =================")
    print("Seasonal file rows:", len(df_all_seasons))

    # ────────────────────────────────────────────────────────────────────────
    # INNER LOOP → seasons
    # ────────────────────────────────────────────────────────────────────────
    for SEAS in SEASONS:
        # Independent copy of this season's rows, used by everything below.
        mdf = df_all_seasons[df_all_seasons["season"] == SEAS].copy()
        if mdf.empty:
            print(f"  · No data for {SEAS} – skipped")
            continue

        # season-specific output folder
        output_plots = os.path.join(indices_dir, f"AnalysisPlots_{SEAS}")
        os.makedirs(output_plots, exist_ok=True)
        print(f"\n—— {SEAS} ——  ({len(mdf)} station-season rows)  →  {output_plots}")

        # keep a master-table copy for this season
        master_xlsx = os.path.join(output_plots, f"MasterTable_{SEAS}.xlsx")
        mdf.to_excel(master_xlsx, index=False)
        print(f"(A) Master table saved → {master_xlsx}")

        # ── BLOCK-2 : full per-season workflow ──────────────────────────
        # === 2-A  Summary statistics ===================================
        def index_of_agreement(obs, mod):
            """Index of agreement d = 1 - Σ(mod-obs)² / Σ(|mod-ō| + |obs-ō|)².

            ``ō`` is the mean of ``obs``.  Returns NaN when the denominator
            is zero (both inputs constant and equal to ``ō``).
            """
            obs_mean = np.mean(obs)
            squared_error = np.sum((mod - obs) ** 2)
            potential_error = np.sum((np.abs(mod - obs_mean) + np.abs(obs - obs_mean)) ** 2)
            if potential_error == 0:
                return np.nan
            return 1 - squared_error / potential_error

        def rmse(a, b):
            """Root-mean-square difference between ``a`` and ``b``."""
            return np.sqrt(np.mean((a - b) ** 2))

        def std_resid(a, b):
            """Sample standard deviation (ddof=1) of the residuals ``a - b``."""
            return np.std(a - b, ddof=1)

        def mean_bias(a, b):
            """Mean of ``b - a`` (second argument minus first)."""
            return np.mean(b - a)

        # Flatten index_columns into (label, obs column, model column) triples
        # so single-pair indices and the amt/pct indices share one code path.
        # Two-column indices are labelled "<index>_<column stem>", e.g.
        # "r95p_r95amt".
        metric_pairs = []
        for idx in index_list:
            obs_cols, emd_cols = index_columns[idx]
            if isinstance(obs_cols, tuple):
                metric_pairs += [(f"{idx}_{oc.replace('_obs','')}", oc, ec)
                                 for oc, ec in zip(obs_cols, emd_cols)]
            else:
                metric_pairs.append((idx, obs_cols, emd_cols))

        summary_rows = []
        for label, oc, ec in metric_pairs:
            # A workbook lacking one of the columns is skipped for that index
            # instead of aborting the whole run with a KeyError.
            if oc not in mdf.columns or ec not in mdf.columns:
                continue
            v = mdf[[oc, ec]].dropna()
            # at least two paired stations are needed for STDres and the correlation
            if len(v) < 2:
                continue
            o, e = v[oc].values, v[ec].values
            summary_rows.append({
                "Index":  label,
                "Count":  len(v),
                "MBE":    mean_bias(o, e),
                "RMSE":   rmse(o, e),
                "STDres": std_resid(o, e),
                "CC":     pearsonr(o, e)[0],
                "d":      index_of_agreement(o, e),
            })

        # Passing `columns=` fixes the column order and also yields a valid
        # (empty) table when no index qualified; selecting the columns from a
        # column-less frame would raise KeyError.
        summary_df = pd.DataFrame(
            summary_rows,
            columns=["Index", "Count", "MBE", "RMSE", "STDres", "CC", "d"]
        )
        summ_xlsx = os.path.join(output_plots,
                                 f"SummaryTable_Extremes_{SEAS}.xlsx")
        summary_df.to_excel(summ_xlsx, index=False)
        print(f"(B) Summary table → {summ_xlsx}")

        # === 2-B  Helpers for maps / plots =============================
        # Basin and lake outlines, reprojected to geographic coordinates (EPSG:4326).
        gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
        gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

        def add_outline(ax):
            """Draw country borders, then basin (black) and lake (cyan) outlines on ``ax``."""
            ax.add_feature(cfeature.BORDERS, linestyle=":")
            for layer, edge_colour in ((gdf_basin, "black"), (gdf_lakes, "cyan")):
                for shape in layer.geometry:
                    ax.add_geometries([shape], ccrs.PlateCarree(),
                                      facecolor="none", edgecolor=edge_colour)

        def map_triptych(df, oc, ec, rc, idx):
            """Save a three-panel station map (obs, EMD, ratio) for one index.

            Each panel colours the stations by the given column and rings in
            red the stations at or above that column's 90th percentile.  The
            figure is written to ``output_plots`` and then closed.
            """
            fig, axes = plt.subplots(1, 3, figsize=(18, 6),
                                     subplot_kw=dict(projection=ccrs.PlateCarree()))
            panels = (
                (oc, f"{idx} OBS ({SEAS})"),
                (ec, f"{idx} EMD ({SEAS})"),
                (rc, f"{idx} Ratio ({SEAS})"),
            )
            for ax, (col, ttl) in zip(axes, panels):
                ax.set_extent([-95.5, -72, 38.5, 52.5])
                add_outline(ax)
                sc = ax.scatter(df["lon"], df["lat"], c=df[col],
                                cmap="viridis", s=60, edgecolor="k",
                                transform=ccrs.PlateCarree())
                cbar = plt.colorbar(sc, ax=ax, shrink=0.8)
                cbar.set_label(col)

                # highlight the top decile of this column
                vals = df[col].dropna().values
                if len(vals):
                    thr = np.percentile(vals, 90)
                    hot = df[col] >= thr
                    ax.scatter(df.loc[hot, "lon"], df.loc[hot, "lat"],
                               facecolors="none", edgecolors="red", s=80,
                               transform=ccrs.PlateCarree())

                ax.set_title(ttl)
                gl = ax.gridlines(draw_labels=True, linestyle="--", color="gray")
                gl.right_labels = gl.top_labels = False

            out_png = os.path.join(output_plots,
                                   f"{SEAS}_{idx}_MAP_3panel.png")
            plt.tight_layout()
            plt.savefig(out_png, dpi=300)
            plt.close()
            print("Map →", out_png)

        def dist_triptych(df, oc, ec, idx):
            """Save a three-panel distribution figure (box plot, empirical CDF, scatter).

            ``oc`` / ``ec`` name the observed and EMD columns of ``df``.  The
            CDF and scatter panels are only filled when more than one value
            (respectively more than one complete pair) is available.
            """
            fig, (ax_box, ax_cdf, ax_scatter) = plt.subplots(1, 3, figsize=(18, 6))

            # panel 1 – box plot of both columns in long form
            long_form = pd.melt(df[[oc, ec]].rename(columns={oc: "Obs", ec: "EMD"}))
            sns.boxplot(data=long_form, x="variable", y="value", ax=ax_box)
            ax_box.set_title(f"Box: {idx} ({SEAS})")

            # panel 2 – empirical CDFs
            obs_vals, emd_vals = df[oc].dropna(), df[ec].dropna()
            if len(obs_vals) > 1 and len(emd_vals) > 1:
                for values, label in ((obs_vals, "Obs"), (emd_vals, "EMD")):
                    ordered = np.sort(values)
                    ax_cdf.plot(ordered,
                                np.arange(1, len(ordered)+1) / len(ordered),
                                label=label)
                ax_cdf.legend()
            ax_cdf.set_title(f"CDF: {idx} ({SEAS})")

            # panel 3 – obs vs EMD scatter with a 1:1 reference line
            paired = df[[oc, ec]].dropna()
            if len(paired) > 1:
                x, y = paired[oc], paired[ec]
                cc = pearsonr(x, y)[0]
                ax_scatter.scatter(x, y, edgecolors="k")
                lo, hi = min(x.min(), y.min()), max(x.max(), y.max())
                ax_scatter.plot([lo, hi], [lo, hi], "r--")
                ax_scatter.set_title(f"Scatter r={cc:.2f}")
            else:
                ax_scatter.set_title("Scatter: n/a")

            out_png = os.path.join(output_plots,
                                   f"{SEAS}_{idx}_Distribution_3panel.png")
            plt.tight_layout()
            plt.savefig(out_png, dpi=300)
            plt.close()
            print("Distribution →", out_png)

        # === 2-C  Generate outputs for every index =====================
        # Column-name stem used for the maps/plots of each index; the
        # percentile indices are plotted from their "amount" columns.
        column_stem = {
            "rx1day": "rx1day", "rx5day": "rx5day", "cdd": "cdd", "cwd": "cwd",
            "drydays": "drydays", "wetdays": "wetdays",
            "r95p": "r95amt", "r99p": "r99amt",
        }
        for idx in index_list:
            stem = column_stem.get(idx)
            if stem is None:
                continue
            oc, ec, rc = (f"{stem}_{suffix}" for suffix in ("obs", "emd", "ratio"))
            # skip indices whose columns are missing from this workbook
            if any(c not in mdf.columns for c in (oc, ec, rc)):
                continue
            ms = mdf.dropna(subset=["lat", "lon"])
            map_triptych(ms, oc, ec, rc, idx)
            dist_triptych(ms[[oc, ec]].dropna(), oc, ec, idx)

        print(f"✔ finished {SEAS}   (ensemble {ENS:03d})")
        # ── end BLOCK-2 ────────────────────────────────────────────────

        ###############################################################################
        # 3. DONE
        ###############################################################################
        print(f"\nAll {SEAS} steps completed!  See outputs in: {output_plots}")

In [None]:
                 #########################                                           ##############################
                 #########################                  RDRS                     ##############################
                 #########################                                           ##############################

In [None]:
# Calculating the climatic indices for RDRS v2.1 (prcp)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
csv_file       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\daily_loop\rdrs_vs_stations_25km_LWR_1991_2012.csv"
physical_file  = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"

shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\ClimaticIndices-25km"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, rdrs_val) ...")
df_data = pd.read_csv(csv_file)
# parse 'time' as datetime; unparseable values become NaT (errors="coerce")
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")

# unify station_name: remove leading/trailing spaces, uppercase
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

print(f"df_data shape = {df_data.shape}")
print("Columns:", df_data.columns.tolist())
print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. LOAD PHYSICAL FILE (STATION LAT/LON/ELEV) – MERGED LATER
###############################################################################
df_phys = pd.read_csv(physical_file)
# Rename the metadata columns to the short names used in the rest of this cell.
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# unify station_name (same normalisation as df_data so the join key matches)
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

# We'll merge lat/lon AFTER computing the indices, so each final row has lat/lon.

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Rx1day: the largest single value in ``series`` (missing values are skipped)."""
    peak = series.max(skipna=True)
    return peak

def calc_rx5day(series):
    """Rx5day: the largest sum over a sliding window of five entries.

    ``min_periods=1`` means the leading, shorter windows also contribute
    their partial sums.
    """
    return series.rolling(5, min_periods=1).sum().max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """CDD: length of the longest unbroken run of entries below ``dry_threshold``.

    Returns 0 for an empty series or when no entry is below the threshold.
    """
    best = run = 0
    for dry in (series < dry_threshold):
        if not dry:
            run = 0
            continue
        run += 1
        if run > best:
            best = run
    return best

def calc_cwd(series, wet_threshold=1.0):
    """CWD: length of the longest unbroken run of entries at or above ``wet_threshold``.

    Returns 0 for an empty series or when no entry reaches the threshold.
    """
    best = run = 0
    for wet in (series >= wet_threshold):
        # grow the current spell on a wet entry, otherwise start over
        run = run + 1 if wet else 0
        best = max(best, run)
    return best

def calc_r95p_r99p(series, percentile=(95,99)):
    """Amount and percentage of the total contributed by very wet entries.

    Thresholds are the two requested percentiles of the wet entries
    (>= 1.0) of ``series``; an amount is the sum of wet entries strictly
    above its threshold and the percentage is that amount relative to the
    sum of the whole series (NaN when that sum is not positive).

    Returns (r95_amt, r95_pct, r99_amt, r99_pct), or four NaNs when fewer
    than five wet entries exist.
    """
    # only wet entries (>= 1 mm) enter the percentile calculation
    wet_days = series[series >= 1.0]
    if len(wet_days) < 5:
        return (np.nan,) * 4

    cut_lo = np.percentile(wet_days, percentile[0])
    cut_hi = np.percentile(wet_days, percentile[1])
    amounts = [wet_days[wet_days > cut].sum() for cut in (cut_lo, cut_hi)]

    grand_total = series.sum(skipna=True)
    if grand_total > 0:
        shares = [(amount / grand_total) * 100 for amount in amounts]
    else:
        shares = [np.nan, np.nan]
    return amounts[0], shares[0], amounts[1], shares[1]

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """Return (number of entries >= wet_thr, number of entries < dry_thr)."""
    return (series >= wet_thr).sum(), (series < dry_thr).sum()

###############################################################################
# 5. COMPUTE INDICES FOR EACH STATION
###############################################################################
# We'll store results in lists of dicts, then convert to DataFrame => Excel.
# One list of per-station records per index; each becomes its own workbook later.
rx1_list, rx5_list = [], []
cdd_list, cwd_list = [], []
r95_list, r99_list = [], []
wet_list, dry_list = [], []

print("Computing indices for each station...")
grouped = df_data.groupby("station_name", as_index=False)

for st_name, grp in grouped:
    # Sort by time (just in case) – the run-length and rolling indices depend on row order
    grp = grp.sort_values("time")
    # daily obs / RDRS series; missing values are dropped from each series
    # separately, so the two series may differ in length
    obs_series = grp["obs"].dropna().reset_index(drop=True)
    rdrs_series = grp["rdrs_val"].dropna().reset_index(drop=True)

    # A) Rx1day
    obs_rx1 = calc_rx1day(obs_series)
    rdrs_rx1 = calc_rx1day(rdrs_series)
    rx1_list.append({"station_name": st_name,
                     "obs_rx1day": obs_rx1,
                     "rdrs_rx1day": rdrs_rx1})

    # B) Rx5day
    obs_rx5 = calc_rx5day(obs_series)
    rdrs_rx5 = calc_rx5day(rdrs_series)
    rx5_list.append({"station_name": st_name,
                     "obs_rx5day": obs_rx5,
                     "rdrs_rx5day": rdrs_rx5})

    # C) CDD
    obs_cdd_val = calc_cdd(obs_series)
    rdrs_cdd_val = calc_cdd(rdrs_series)
    cdd_list.append({"station_name": st_name,
                     "obs_cdd": obs_cdd_val,
                     "rdrs_cdd": rdrs_cdd_val})

    # D) CWD
    obs_cwd_val = calc_cwd(obs_series)
    rdrs_cwd_val = calc_cwd(rdrs_series)
    cwd_list.append({"station_name": st_name,
                     "obs_cwd": obs_cwd_val,
                     "rdrs_cwd": rdrs_cwd_val})

    # E) R95 / R99  (o* = observed, e* = RDRS estimate; a = amount, p = percent)
    or95a, or95p, or99a, or99p = calc_r95p_r99p(obs_series)
    er95a, er95p, er99a, er99p = calc_r95p_r99p(rdrs_series)
    r95_list.append({
        "station_name": st_name,
        "obs_r95amt": or95a, "obs_r95pct": or95p,
        "rdrs_r95amt": er95a, "rdrs_r95pct": er95p
    })
    r99_list.append({
        "station_name": st_name,
        "obs_r99amt": or99a, "obs_r99pct": or99p,
        "rdrs_r99amt": er99a, "rdrs_r99pct": er99p
    })

    # F) wet/dry days (helper defaults: wet >= 5 mm, dry < 1 mm)
    obs_wet5, obs_dry = calc_wetdays_drydays(obs_series)
    rdrs_wet5, rdrs_dry = calc_wetdays_drydays(rdrs_series)
    wet_list.append({
        "station_name": st_name,
        "obs_wetdays5mm": obs_wet5, 
        "rdrs_wetdays5mm": rdrs_wet5
    })
    dry_list.append({
        "station_name": st_name,
        "obs_drydays": obs_dry,
        "rdrs_drydays": rdrs_dry
    })

print("Finished computing. Now merging lat/lon from physical file ...")

# Convert each list to DataFrame => merge lat,lon => save
def attach_coords(df_in):
    """Left-join lat, lon and elev from ``df_phys`` onto ``df_in`` by ``station_name``.

    Rows of ``df_in`` without a match in ``df_phys`` are kept.
    """
    station_meta = df_phys[["station_name", "lat", "lon", "elev"]]
    return pd.merge(df_in, station_meta, on="station_name", how="left")

# One frame per index, each with station coordinates attached.
df_rx1, df_rx5, df_cdd, df_cwd, df_r95, df_r99, df_wet, df_dry = (
    attach_coords(pd.DataFrame(records))
    for records in (rx1_list, rx5_list, cdd_list, cwd_list,
                    r95_list, r99_list, wet_list, dry_list)
)

###############################################################################
# 6. SAVE OUTPUT (SAME FILE NAMES AS BEFORE)
###############################################################################
print("Saving index tables to Excel in:", output_dir)
for file_name, frame in (
    ("rx1day.xlsx",  df_rx1),
    ("rx5day.xlsx",  df_rx5),
    ("cdd.xlsx",     df_cdd),
    ("cwd.xlsx",     df_cwd),
    ("r95p.xlsx",    df_r95),
    ("r99p.xlsx",    df_r99),
    ("wetdays.xlsx", df_wet),
    ("drydays.xlsx", df_dry),
):
    frame.to_excel(os.path.join(output_dir, file_name), index=False)

print("\nAll precipitation-based indices have been saved to Excel with station_name (and lat/lon).")

###############################################################################
# (OPTIONAL) QUICK MAP EXAMPLE (like obs_rx5day)
###############################################################################
# Best-effort preview map: any failure in here is reported and the cell still
# finishes, since the index workbooks have already been written above.
try:
    print("\nQuick map example for obs_rx5day ...")
    # Load shapefiles (reprojected to EPSG:4326)
    gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
    gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

    # Station points built from the lon/lat columns attached to the Rx5day table.
    gdf_stations = gpd.GeoDataFrame(
        df_rx5,
        geometry=gpd.points_from_xy(df_rx5["lon"], df_rx5["lat"]),
        crs="EPSG:4326"
    )

    fig, ax = plt.subplots(figsize=(10,8), subplot_kw={"projection": ccrs.PlateCarree()})
    # Basin outline in blue, lake outlines in cyan.
    for geom in gdf_basin.geometry:
        ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='blue', linewidth=1)
    for geom in gdf_lakes.geometry:
        ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='cyan', linewidth=1)

    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=':')

    # Stations coloured by the observed Rx5day value.
    sc = ax.scatter(gdf_stations.geometry.x, gdf_stations.geometry.y,
                    c=gdf_stations["obs_rx5day"], cmap="Reds", s=60,
                    transform=ccrs.PlateCarree(), edgecolor="k")
    plt.colorbar(sc, ax=ax, label="Obs Rx5day (mm)")

    ax.set_extent([-95.5, -72, 38.5, 52.5])  # approximate bounding box
    gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
    gl.right_labels = False
    gl.top_labels   = False

    plt.title("Obs Rx5day (from CSV daily data)", fontsize=14)
    plt.show()

# NOTE(review): the broad catch looks deliberate (the map is optional) – confirm.
except Exception as e:
    print("Mapping step failed:", e)

print("\n✅ Done computing precipitation-based indices from CSV, with station names included!")


In [None]:
# Calculating the climatic indices for RDRS v2.1 (temperature - Tmin/Tmax)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
# Daily station-vs-RDRS table; the code below reads the columns "time",
# "station_name", "var" ("tmax"/"tmin"), "obs" and "rdrs_lwr25_val" from it.
csv_file = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Temperature\daily_loop\rdrs_vs_stations_25km_LWR_1991_2012_tmin_tmax.csv"
# Station metadata (NAME / LATITUDE / LONGITUDE / Elevation columns are renamed below).
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation_Temperature.csv"

# NOTE: the two shapefile paths are defined here but not referenced again in this cell.
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# All outputs of this cell (parquet + Excel files) are written into this folder.
output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Temperature\ClimaticIndices-8Nearest"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, rdrs_lwr25_val) ...")
df_data = pd.read_csv(csv_file)
# errors="coerce": unparseable timestamps become NaT instead of raising
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")
# normalise names (trim + upper-case) so they match the metadata table below
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

print(f"df_data shape = {df_data.shape}")
print("Columns:", df_data.columns.tolist())
print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. LOAD PHYSICAL FILE (LAT/LON/ELEV) – merged into the results in section 5
###############################################################################
df_phys = pd.read_csv(physical_file)
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# same normalisation as df_data so the later join on station_name matches
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS – FULL ETCCDI SET FOR Tmin / Tmax
###############################################################################
def absolute_extremes(series, kind):
    """Most extreme single value of `series`, ignoring NaNs.

    `kind == "max"` gives the highest value; any other `kind` gives the lowest.
    """
    reducer = series.max if kind == "max" else series.min
    return reducer(skipna=True)

# -------------------------------------------------------------------------
# percentile thresholds (5-day moving window, baseline = 1991-2012)
# -------------------------------------------------------------------------
BASE_START, BASE_END = "1991-01-01", "2012-12-31"

def _climatology_percentiles(s, p):
    """Return a Series (index = 1…366) of the p-th percentile."""
    # drop 29 Feb so every year has 365 days
    s = s[~((s.index.month == 2) & (s.index.day == 29))]
    df = pd.DataFrame({"val": s, "doy": s.index.dayofyear})
    climo = []
    for d in range(1, 366):
        win = list(range(d-2, d+3))                       # ±2-day window
        win = [(x-1) % 365 + 1 for x in win]              # wrap around ends
        vals = df.loc[df["doy"].isin(win), "val"]
        climo.append(np.nanpercentile(vals, p) if len(vals) else np.nan)
    return pd.Series(climo, index=range(1, 366), name=f"p{p}")

def percentile_flags(s, perc_series, side):
    """Flag the days on which `s` lies beyond its day-of-year threshold.

    The threshold for each day is looked up in `perc_series` by the day's
    ``dayofyear``.  With ``side == "low"`` a day is flagged when its value is
    strictly below the threshold; with any other `side`, when it is strictly
    above.  Returns a Boolean Series on the index of `s`.
    """
    thresholds = perc_series.reindex(s.index.dayofyear).to_numpy()
    return s.lt(thresholds) if side == "low" else s.gt(thresholds)

def spell_length(bool_series, min_run=6):
    """Count the days lying in runs of at least `min_run` consecutive True flags.

    Missing flags are treated as False.  The result is a NumPy integer
    (0 when no run is long enough).
    """
    flags = bool_series.fillna(False).values
    # Zero-pad both ends so each run produces one rising and one falling edge.
    padded = np.concatenate(([0], flags, [0]))
    edges = np.flatnonzero(np.diff(padded))
    run_starts, run_stops = edges[::2], edges[1::2]
    run_lengths = run_stops - run_starts
    return run_lengths[run_lengths >= min_run].sum()

# -------------------------------------------------------------------------
# absolute-threshold counters
# -------------------------------------------------------------------------
def count_threshold(series, op, thr):
    """Number of entries strictly below `thr` (`op == "<"`) or strictly above it (any other `op`)."""
    hits = (series < thr) if op == "<" else (series > thr)
    return hits.sum()

###############################################################################
# 5.  ANNUAL ETCCDI INDICES  –  FORMAT COMPATIBLE WITH THE SEASONAL SCRIPT
#     • keeps:   TXx  TNn  TX90p  TN10p  FD  WSDI  CSDI
#     • drops:   TR, ID, SU, TXn, TNx
###############################################################################
rows = []


def ratio(o, e):
    """RDRS / observation ratio; NaN when the observed value is 0 or missing."""
    return np.nan if (o == 0 or np.isnan(o)) else e / o


def _one_year(s, yr):
    """Restrict a daily Series to calendar year `yr` (uses the Series' own index)."""
    return s[s.index.year == yr]


def _pct_flagged(flag, values):
    """Percentage of flagged days among the days on which `values` is present."""
    has_data = values.notna()
    return flag[has_data].mean() * 100.0 if has_data.any() else np.nan


print("→ computing *annual* indices …")
for st_name, st_grp in df_data.groupby("station_name"):

    st_grp = st_grp.set_index("time").sort_index()

    # build daily Series; asfreq("D") inserts NaN for calendar days without a row
    tmax_rows = st_grp[st_grp["var"] == "tmax"]
    tmin_rows = st_grp[st_grp["var"] == "tmin"]
    obs_max = tmax_rows["obs"].asfreq("D")
    obs_min = tmin_rows["obs"].asfreq("D")
    rdrs_max = tmax_rows["rdrs_lwr25_val"].asfreq("D")
    rdrs_min = tmin_rows["rdrs_lwr25_val"].asfreq("D")
    if obs_max.empty:          # station has no tmax rows
        continue

    # ── fixed-period climatology (BASE_START…BASE_END, 5-day moving window) ──
    p90_TX_obs = _climatology_percentiles(obs_max[BASE_START:BASE_END], 90)
    p10_TN_obs = _climatology_percentiles(obs_min[BASE_START:BASE_END], 10)
    p90_TX_rdrs = _climatology_percentiles(rdrs_max[BASE_START:BASE_END], 90)
    p10_TN_rdrs = _climatology_percentiles(rdrs_min[BASE_START:BASE_END], 10)

    # flags for the *full* record (faster than recomputing year-by-year)
    flags = {
        "obs_TX90": percentile_flags(obs_max, p90_TX_obs, "high"),
        "rdrs_TX90": percentile_flags(rdrs_max, p90_TX_rdrs, "high"),
        "obs_TN10": percentile_flags(obs_min, p10_TN_obs, "low"),
        "rdrs_TN10": percentile_flags(rdrs_min, p10_TN_rdrs, "low"),
    }

    # ── iterate over calendar years ──────────────────────────────────────────
    for yr in np.unique(obs_max.index.year):
        # Every series is sliced through its OWN index: the tmin series can
        # cover a different date range than the tmax series, so a Boolean mask
        # built from obs_max must not be applied to them.
        tx_obs, tx_rdrs = _one_year(obs_max, yr), _one_year(rdrs_max, yr)
        tn_obs, tn_rdrs = _one_year(obs_min, yr), _one_year(rdrs_min, yr)

        # Require ≥200 days (~55 % of a year) of real tmax observations.
        # Counting index entries would be meaningless here because asfreq("D")
        # has already padded every interior year to 365/366 entries.
        if tx_obs.notna().sum() < 200:
            continue

        f_tx_obs = _one_year(flags["obs_TX90"], yr)
        f_tx_rdrs = _one_year(flags["rdrs_TX90"], yr)
        f_tn_obs = _one_year(flags["obs_TN10"], yr)
        f_tn_rdrs = _one_year(flags["rdrs_TN10"], yr)

        # intensity ....................................................................
        TXx_obs = tx_obs.max()
        TXx_rdrs = tx_rdrs.max()
        TNn_obs = tn_obs.min()
        TNn_rdrs = tn_rdrs.min()

        # percentile frequencies (percentage of the days that have data;
        # missing days compare as False and must not dilute the percentage)
        TX90p_obs = _pct_flagged(f_tx_obs, tx_obs)
        TX90p_rdrs = _pct_flagged(f_tx_rdrs, tx_rdrs)
        TN10p_obs = _pct_flagged(f_tn_obs, tn_obs)
        TN10p_rdrs = _pct_flagged(f_tn_rdrs, tn_rdrs)

        # spell duration (≥6 consecutive days, counted within the calendar year)
        WSDI_obs = spell_length(f_tx_obs, min_run=6)
        WSDI_rdrs = spell_length(f_tx_rdrs, min_run=6)
        CSDI_obs = spell_length(f_tn_obs, min_run=6)
        CSDI_rdrs = spell_length(f_tn_rdrs, min_run=6)

        # absolute‐threshold count (frost days: Tmin < 0)
        FD_obs = (tn_obs < 0).sum()
        FD_rdrs = (tn_rdrs < 0).sum()

        rows.append(dict(
            station_name=st_name, year=yr,
            TXx_obs=TXx_obs,   TXx_rdrs=TXx_rdrs,   TXx_ratio=ratio(TXx_obs, TXx_rdrs),
            TNn_obs=TNn_obs,   TNn_rdrs=TNn_rdrs,   TNn_ratio=ratio(TNn_obs, TNn_rdrs),
            TX90p_obs=TX90p_obs, TX90p_rdrs=TX90p_rdrs,
            TX90p_ratio=ratio(TX90p_obs, TX90p_rdrs),
            TN10p_obs=TN10p_obs, TN10p_rdrs=TN10p_rdrs,
            TN10p_ratio=ratio(TN10p_obs, TN10p_rdrs),
            FD_obs=FD_obs,     FD_rdrs=FD_rdrs,     FD_ratio=ratio(FD_obs, FD_rdrs),
            WSDI_obs=WSDI_obs, WSDI_rdrs=WSDI_rdrs,
            WSDI_ratio=ratio(WSDI_obs, WSDI_rdrs),
            CSDI_obs=CSDI_obs, CSDI_rdrs=CSDI_rdrs,
            CSDI_ratio=ratio(CSDI_obs, CSDI_rdrs)
        ))

# one row per station-year, with coordinates attached from the metadata table
df_yr = (pd.DataFrame(rows)
         .merge(df_phys[["station_name", "lat", "lon", "elev"]],
                on="station_name", how="left")
         .sort_values(["station_name", "year"]))

###############################################################################
# 6.  SAVE OUTPUT  – ONE PARQUET ( + XLSX )  +  OPTIONAL per-index sheets
###############################################################################
# Both files hold the same long table (one row per station and year).
annual_pq  = os.path.join(output_dir, "Indices_Annual.parquet")
# same path with the extension swapped
annual_xls = annual_pq.replace(".parquet", ".xlsx")

df_yr.to_parquet(annual_pq, index=False)
df_yr.to_excel  (annual_xls, index=False)
print(f"✓ Annual indices saved → {annual_pq}")
print(f"✓ …and also saved as   → {annual_xls}")

# OPTIONAL: write one Excel file per index in the familiar “wide” format
# ---------------------------------------------------------------------------
# Column-name stems: each index has <stem>_obs, <stem>_rdrs and <stem>_ratio in df_yr.
index_roots = ["TXx", "TNn", "TX90p", "TN10p", "FD", "WSDI", "CSDI"]
def _wide(idx_root: str) -> pd.DataFrame:
    """Per-station summary of one index.

    Aggregates the `<idx_root>_obs`, `<idx_root>_rdrs` and `<idx_root>_ratio`
    columns of `df_yr` per station (pivot_table's default aggregation, i.e.
    the mean over years) and attaches lat / lon / elev from `df_phys`.
    """
    value_cols = [f"{idx_root}_obs", f"{idx_root}_rdrs", f"{idx_root}_ratio"]
    per_station = df_yr.pivot_table(index="station_name", values=value_cols)
    per_station = per_station.reset_index()
    coords = df_phys[["station_name", "lat", "lon", "elev"]]
    return per_station.merge(coords, on="station_name", how="left")

print("\n(optional) individual workbooks …")
# One workbook per index stem, named <stem>.xlsx, written next to the annual files.
for idx in index_roots:
    w = _wide(idx)
    fp = os.path.join(output_dir, f"{idx}.xlsx")
    w.to_excel(fp, index=False)
    print("  •", os.path.basename(fp))

print("\n✅  Annual-index workflow finished.")

In [None]:
# Monthly and seasonal stratification of the RDRS v2.1 precipitation indices

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIGURATION
###############################################################################
# Daily station-vs-RDRS precipitation table; the code below reads the columns
# "time", "station_name", "obs" and "rdrs_val" from it.
csv_file      = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\daily_loop\rdrs_vs_stations_25km_LWR_1991_2012.csv"
# Station metadata (coordinates and elevation).
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
# Indices_Monthly.xlsx and Indices_Seasonal.xlsx are written here.
output_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\ClimaticIndices-Seasonal"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY DATA & ADD TEMPORAL FIELDS
###############################################################################
print("Loading daily CSV data ...")
df_data = pd.read_csv(csv_file)
# errors="coerce": unparseable timestamps become NaT instead of raising
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")

# Standardize station_name
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

# Add month (1..12) and season (DJF, MAM, JJA, SON)
# (rows whose time is NaT get a missing month)
df_data["month"] = df_data["time"].dt.month

def get_season(month):
    """Translate a month number into a season label.

    12, 1, 2 → "DJF"; 3–5 → "MAM"; 6–8 → "JJA"; every other value
    (9–11, but also anything unrecognised such as NaN) → "SON".
    """
    season_months = (
        ("DJF", (12, 1, 2)),
        ("MAM", (3, 4, 5)),
        ("JJA", (6, 7, 8)),
    )
    for label, months in season_months:
        if month in months:
            return label
    return "SON"

# Label every daily row with its season; get_season falls back to "SON" for
# any month it does not recognise, which includes the missing month of NaT rows.
df_data["season"] = df_data["month"].apply(get_season)

print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. LOAD PHYSICAL FILE & MERGE COORDINATES
###############################################################################
# Only loaded and renamed here; the join onto the index tables happens in
# sections 6 and 7.
df_phys = pd.read_csv(physical_file)
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# same normalisation as df_data so the join on station_name matches
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Rx1day: the largest single value in `series` (NaNs are ignored)."""
    wettest_day = series.max(skipna=True)
    return wettest_day

def calc_rx5day(series):
    """Rx5day: the largest running sum over 5 consecutive entries of `series`.

    The window is positional (rows, not calendar days); the first four rows
    are summed over the shorter windows available (``min_periods=1``).
    """
    running_totals = series.rolling(window=5, min_periods=1).sum()
    return running_totals.max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """CDD: length of the longest run of consecutive entries below `dry_threshold`.

    Entries that are NaN are not dry (the comparison is False) and therefore
    end a run.  Returns a plain int, 0 when there is no dry entry.
    """
    dry = np.asarray(series < dry_threshold, dtype=bool)
    if not dry.any():
        return 0
    # Zero-pad so each run has a rising and a falling edge, then measure runs.
    edges = np.flatnonzero(np.diff(np.concatenate(([0], dry.astype(int), [0]))))
    return int((edges[1::2] - edges[::2]).max())

def calc_cwd(series, wet_threshold=1.0):
    """CWD: length of the longest run of consecutive entries at or above `wet_threshold`.

    Entries that are NaN are not wet (the comparison is False) and therefore
    end a run.  Returns a plain int, 0 when there is no wet entry.
    """
    wet = np.asarray(series >= wet_threshold, dtype=bool)
    if not wet.any():
        return 0
    # Zero-pad so each run has a rising and a falling edge, then measure runs.
    edges = np.flatnonzero(np.diff(np.concatenate(([0], wet.astype(int), [0]))))
    return int((edges[1::2] - edges[::2]).max())

def calc_r95p_r99p(series, percentiles=(95,99)):
    """Very-wet / extremely-wet day precipitation.

    Wet days are the entries ≥ 1.0.  For each of the two percentiles the
    threshold is taken from the wet-day sample itself, the *amount* is the sum
    of wet-day values strictly above that threshold, and the *pct* is that
    amount as a percentage of the sum of ALL entries of `series` (days below
    1.0 included); pct is NaN when that sum is not positive.

    Returns ``(r95amt, r95pct, r99amt, r99pct)``; four NaNs when fewer than
    five wet days are available.
    """
    wet_days = series[series >= 1.0]
    if len(wet_days) < 5:
        return (np.nan,) * 4

    grand_total = series.sum(skipna=True)
    out = []
    for q in (percentiles[0], percentiles[1]):
        threshold = np.percentile(wet_days, q)
        amount = wet_days[wet_days > threshold].sum()
        share = (amount / grand_total) * 100 if grand_total > 0 else np.nan
        out += [amount, share]
    return tuple(out)

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """Return ``(wetdays, drydays)``.

    wetdays counts the entries ≥ `wet_thr`, drydays the entries < `dry_thr`;
    NaN entries are counted in neither.
    """
    n_wet = (series >= wet_thr).sum()
    n_dry = (series < dry_thr).sum()
    return n_wet, n_dry

###############################################################################
# 5. FUNCTION TO COMPUTE INDICES FOR A GROUP (MONTHLY or SEASONAL)
###############################################################################
def compute_indices(df_group):
    """Compute the precipitation indices for one group of daily rows.

    Parameters
    ----------
    df_group : pandas.DataFrame
        Daily rows of one group (e.g. station+month or station+season, pooled
        over all years) with the columns "obs" and "rdrs_val".  When a "time"
        column is present it is used to keep the run-based indices inside
        stretches of consecutive calendar days (see below).

    Returns
    -------
    dict or None
        ``None`` when either column has no non-missing value.  Otherwise a
        dict with ``<index>_obs``, ``<index>_rdrs`` and ``<index>_ratio``
        (RDRS / obs) for rx1day, rx5day, cdd, cwd, r95amt, r95pct, r99amt,
        r99pct, wetdays and drydays.  A ratio is NaN when the observed value
        is 0 or NaN; the key is always present so every row yields the same
        columns in the same order.

    Notes
    -----
    A group pools rows from many years (all Januaries, all DJF seasons, …) and
    missing values are dropped, so neighbouring rows are not necessarily
    neighbouring days.  Rx5day, CDD and CWD are therefore evaluated separately
    on each stretch of consecutive days and the maximum over the stretches is
    reported; without this a "5-day" sum or a dry spell could run from the end
    of one year's season straight into the next year's.  The order-independent
    indices (Rx1day, R95/R99, wet/dry-day counts) use all values pooled.
    """
    def _pooled_and_runs(col):
        """Return (all non-missing values, list of consecutive-day stretches)."""
        pooled = df_group[col].dropna().reset_index(drop=True)
        if "time" not in df_group.columns:
            # No dates available: fall back to treating the rows as one stretch.
            return pooled, [pooled]
        dated = df_group[["time", col]].dropna(subset=[col])
        dated = dated.assign(time=pd.to_datetime(dated["time"], errors="coerce"))
        dated = dated.sort_values("time")
        # A new stretch starts wherever the step from the previous row is not
        # exactly one day (first row, gaps, duplicates, unparseable dates).
        starts_new = dated["time"].diff().dt.days.ne(1)
        stretch_id = starts_new.cumsum().to_numpy()
        runs = [chunk.reset_index(drop=True)
                for _, chunk in dated[col].groupby(stretch_id)]
        return pooled, runs

    obs_series, obs_runs = _pooled_and_runs("obs")
    rdrs_series, rdrs_runs = _pooled_and_runs("rdrs_val")
    if len(obs_series) == 0 or len(rdrs_series) == 0:
        return None

    res = {}
    # Rx1day / Rx5day
    res["rx1day_obs"] = calc_rx1day(obs_series)
    res["rx1day_rdrs"] = calc_rx1day(rdrs_series)
    res["rx5day_obs"] = max(calc_rx5day(run) for run in obs_runs)
    res["rx5day_rdrs"] = max(calc_rx5day(run) for run in rdrs_runs)

    # CDD / CWD
    res["cdd_obs"] = max(calc_cdd(run) for run in obs_runs)
    res["cdd_rdrs"] = max(calc_cdd(run) for run in rdrs_runs)
    res["cwd_obs"] = max(calc_cwd(run) for run in obs_runs)
    res["cwd_rdrs"] = max(calc_cwd(run) for run in rdrs_runs)

    # R95 / R99  (each call returns amt95, pct95, amt99, pct99)
    r95_obs = calc_r95p_r99p(obs_series)
    r95_rdrs = calc_r95p_r99p(rdrs_series)
    res["r95amt_obs"] = r95_obs[0]
    res["r95pct_obs"] = r95_obs[1]
    res["r95amt_rdrs"] = r95_rdrs[0]
    res["r95pct_rdrs"] = r95_rdrs[1]
    res["r99amt_obs"] = r95_obs[2]
    res["r99pct_obs"] = r95_obs[3]
    res["r99amt_rdrs"] = r95_rdrs[2]
    res["r99pct_rdrs"] = r95_rdrs[3]

    # Wet / Dry days
    wet_obs, dry_obs = calc_wetdays_drydays(obs_series)
    wet_rdrs, dry_rdrs = calc_wetdays_drydays(rdrs_series)
    res["wetdays_obs"] = wet_obs
    res["wetdays_rdrs"] = wet_rdrs
    res["drydays_obs"] = dry_obs
    res["drydays_rdrs"] = dry_rdrs

    # Ratio columns: RDRS/OBS; NaN when obs is 0 (a NaN obs propagates to NaN)
    for stem in ("rx1day", "rx5day", "cdd", "cwd", "r95amt", "r95pct",
                 "r99amt", "r99pct", "wetdays", "drydays"):
        observed = res[f"{stem}_obs"]
        res[f"{stem}_ratio"] = res[f"{stem}_rdrs"] / observed if observed else np.nan

    return res

###############################################################################
# 6. MONTHLY INDICES
###############################################################################
# Groups are (station, calendar month): the rows of all years are pooled, so
# the result has one row per station and month, not per station-year-month.
monthly_results = []
group_month = df_data.groupby(["station_name", "month"])
for (st_name, mon), group in group_month:
    indices = compute_indices(group)
    if indices is None:      # nothing usable in obs or rdrs_val for this group
        continue
    indices["station_name"] = st_name
    indices["month"] = mon
    monthly_results.append(indices)

df_monthly = pd.DataFrame(monthly_results)
# attach coordinates/elevation; how="left" keeps stations missing from df_phys
df_monthly = pd.merge(
    df_monthly,
    df_phys[["station_name", "lat", "lon", "elev"]],
    on="station_name",
    how="left"
)
df_monthly = df_monthly.sort_values(["station_name", "month"])
monthly_out = os.path.join(output_dir, "Indices_Monthly.xlsx")
df_monthly.to_excel(monthly_out, index=False)
print("Monthly indices saved =>", monthly_out)

###############################################################################
# 7. SEASONAL INDICES
###############################################################################
# Same procedure with (station, season) groups; again pooled over all years.
seasonal_results = []
group_season = df_data.groupby(["station_name", "season"])
for (st_name, seas), group in group_season:
    indices = compute_indices(group)
    if indices is None:      # nothing usable in obs or rdrs_val for this group
        continue
    indices["station_name"] = st_name
    indices["season"] = seas
    seasonal_results.append(indices)

df_seasonal = pd.DataFrame(seasonal_results)
df_seasonal = pd.merge(
    df_seasonal,
    df_phys[["station_name", "lat", "lon", "elev"]],
    on="station_name",
    how="left"
)
# the season labels are strings, so this sort is alphabetical (DJF, JJA, MAM, SON)
df_seasonal = df_seasonal.sort_values(["station_name", "season"])
seasonal_out = os.path.join(output_dir, "Indices_Seasonal.xlsx")
df_seasonal.to_excel(seasonal_out, index=False)
print("Seasonal indices saved =>", seasonal_out)

###############################################################################
# 8. DONE
###############################################################################
print("\nAll monthly and seasonal indices have been saved. (No extreme-event stratification.)")


In [None]:
# DJF for RDRS v2.1

import os
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import Point
from scipy.stats import pearsonr
import seaborn as sns

###############################################################################
# 1. CONFIG & PATHS
###############################################################################
# Folder holding Indices_Seasonal.xlsx; the DJF outputs go into a sub-folder of it.
indices_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\ClimaticIndices-Seasonal"
seasonal_file = os.path.join(indices_dir, "Indices_Seasonal.xlsx")  # single file
output_plots  = os.path.join(indices_dir, "AnalysisPlots_DJF")
os.makedirs(output_plots, exist_ok=True)

shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# Indices in your seasonal file
# ("r95p" and "r99p" are group names: each stands for an amount and a percent column)
index_list = ["rx1day","rx5day","cdd","cwd","r95p","r99p","wetdays","drydays"]

# For summary stats, define how to find obs vs. rdrs columns
# Value = (obs column, rdrs column); for r95p / r99p each side is itself a
# tuple of two columns that are paired position by position.
index_columns = {
    "rx1day":  ("rx1day_obs",  "rx1day_rdrs"),
    "rx5day":  ("rx5day_obs",  "rx5day_rdrs"),
    "cdd":     ("cdd_obs",     "cdd_rdrs"),
    "cwd":     ("cwd_obs",     "cwd_rdrs"),
    "r95p":    (("r95amt_obs","r95pct_obs"), ("r95amt_rdrs","r95pct_rdrs")),
    "r99p":    (("r99amt_obs","r99pct_obs"), ("r99amt_rdrs","r99pct_rdrs")),
    "wetdays": ("wetdays_obs","wetdays_rdrs"),
    "drydays": ("drydays_obs","drydays_rdrs"),
}

###############################################################################
# 2. LOAD SEASONAL FILE & FILTER TO DJF
###############################################################################
df_season = pd.read_excel(seasonal_file)
print("Loaded =>", seasonal_file, "| shape =", df_season.shape)

# Filter to DJF
df_season = df_season[df_season["season"]=="DJF"].copy()
df_season = df_season.dropna(subset=["lat","lon"])  # ensure lat/lon exist
print("After filtering to DJF => shape =", df_season.shape)

# mdf ("master" table) is the frame every later step of this cell reads from.
mdf = df_season.reset_index(drop=True)
master_xlsx = os.path.join(output_plots, "MasterTable_Seasonal_DJF.xlsx")
mdf.to_excel(master_xlsx, index=False)
print(f"\n(A) Master table (DJF) saved => {master_xlsx}")
print("Columns:", mdf.columns.tolist())

###############################################################################
# 3. SUMMARY TABLE (MBE, RMSE, STD, CC, d) for DJF
###############################################################################
def index_of_agreement(obs, model):
    """Index of agreement d between `obs` and `model` (array-likes supporting arithmetic).

    d = 1 - Σ(model - obs)² / Σ(|model - mean(obs)| + |obs - mean(obs)|)²

    Returns NaN when the denominator is zero.
    """
    centre = np.mean(obs)
    squared_error = np.sum((model - obs)**2)
    potential_error = np.sum((abs(model - centre) + abs(obs - centre))**2)
    return np.nan if potential_error == 0 else 1 - squared_error/potential_error

def rmse(a, b):
    """Root-mean-square difference between the arrays `a` and `b`."""
    squared_diff = (a - b)**2
    return np.sqrt(np.mean(squared_diff))

def std_of_residuals(a, b):
    """Sample standard deviation (ddof=1) of the residuals `a - b`."""
    residuals = a - b
    return np.std(residuals, ddof=1)

def mean_bias_error(a, b):
    """Mean of `b - a`; positive when `b` is larger than `a` on average."""
    bias = b - a
    return np.mean(bias)

def _summary_row(frame, oc, ec, label):
    """Skill scores of column `ec` (RDRS) against column `oc` (obs).

    Rows with a missing value in either column are dropped first; returns
    None when fewer than two complete pairs remain.
    """
    valid = frame[[oc, ec]].dropna()
    if len(valid) < 2:
        return None
    obs_vals = valid[oc].values
    rdrs_vals = valid[ec].values
    return {
        "Index": label,
        "Count": len(valid),
        "MBE": mean_bias_error(obs_vals, rdrs_vals),
        "RMSE": rmse(obs_vals, rdrs_vals),
        "STDres": std_of_residuals(obs_vals, rdrs_vals),
        "CC": pearsonr(obs_vals, rdrs_vals)[0],
        "d": index_of_agreement(obs_vals, rdrs_vals),
    }


summary_rows = []
for idx_name in index_list:
    obs_cols, rdrs_cols = index_columns[idx_name]

    # Normalise both layouts of index_columns to a list of (obs, rdrs, label).
    if isinstance(obs_cols, tuple):
        # several column pairs per index: label each as "<index>_<column stem>"
        jobs = [(oc, ec, f"{idx_name}_{oc.replace('_obs','')}")
                for oc, ec in zip(obs_cols, rdrs_cols)]
    else:
        jobs = [(obs_cols, rdrs_cols, idx_name)]

    for oc, ec, idx_label in jobs:
        row = _summary_row(mdf, oc, ec, idx_label)
        if row is not None:
            summary_rows.append(row)

summary_cols = ["Index","Count","MBE","RMSE","STDres","CC","d"]
summary_df = pd.DataFrame(summary_rows)[summary_cols]
summary_xlsx = os.path.join(output_plots, "SummaryTable_Extremes_DJF.xlsx")
summary_df.to_excel(summary_xlsx, index=False)
print(f"(B) Summary Table (DJF) => {summary_xlsx}\n{summary_df}")

###############################################################################
# 4. MAPPING: Combine Observed, RDRS v2.1, Ratio in One Figure
###############################################################################
# Basin and lake outlines, both reprojected to EPSG:4326 (lon/lat).
gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

def add_basin_lakes(ax):
    """Draw the map background on `ax`: dotted country borders, the basin
    outline in black and the lake outlines in cyan (no fill)."""
    # (coastlines are intentionally not drawn)
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    for layer, outline_colour in ((gdf_basin, 'black'), (gdf_lakes, 'cyan')):
        for geom in layer.geometry:
            ax.add_geometries([geom], ccrs.PlateCarree(),
                              facecolor='none', edgecolor=outline_colour, linewidth=1)

def plot_map_triptych(df, obs_col, rdrs_col, ratio_col, idx_name, out_png, season="DJF"):
    """Save a three-panel station map: observed, RDRS v2.1 and their ratio.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per station with "lon", "lat" and the three value columns.
    obs_col, rdrs_col, ratio_col : str
        Columns shown in the left, middle and right panel.
    idx_name : str
        Index name used in the panel titles.
    out_png : str
        Path of the PNG file to write (300 dpi).
    season : str, default "DJF"
        Season label appended to the panel titles.

    Each panel colours the stations by value (viridis) with its own colorbar
    and rings the stations at or above the panel's 90th percentile in red.
    The figure is saved and closed; nothing is returned.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6),
                             subplot_kw={"projection": ccrs.PlateCarree()})

    def scatter_map(ax, value_col, title):
        """Draw one panel for `value_col` on `ax`."""
        ax.set_extent([-95.5, -72, 38.5, 52.5])
        add_basin_lakes(ax)
        sc = ax.scatter(df["lon"], df["lat"], c=df[value_col], cmap="viridis",
                        s=60, transform=ccrs.PlateCarree(), edgecolor="k", zorder=10)
        cb = fig.colorbar(sc, ax=ax, shrink=0.8)
        cb.set_label(value_col)

        # Hotspots => top 10%
        vals = df[value_col].dropna().values
        has_hotspots = len(vals) > 0
        if has_hotspots:
            thr = np.percentile(vals, 90)
            is_hot = df[value_col]>=thr
            ax.scatter(df.loc[is_hot,"lon"], df.loc[is_hot,"lat"],
                       marker='o', facecolors='none', edgecolors='red', s=80,
                       transform=ccrs.PlateCarree(), zorder=11,
                       label=f"Hotspot >= {thr:.2f}")
        ax.set_title(title, fontsize=12)
        gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
        gl.right_labels = False
        gl.top_labels   = False
        # The hotspot ring is the only labelled artist; without it a legend
        # would be empty and matplotlib would warn.
        if has_hotspots:
            ax.legend(loc='upper right')

    scatter_map(axes[0], obs_col,  f"{idx_name} Observed ({season})")
    scatter_map(axes[1], rdrs_col,  f"{idx_name} RDRS ({season})")
    scatter_map(axes[2], ratio_col,f"{idx_name} (RDRS/OBS) ({season})")

    # Act on this figure explicitly instead of pyplot's "current figure", so
    # other open figures cannot be saved or closed by mistake.
    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved =>", out_png)

def get_map_cols(idx_name):
    """Return the ``(obs, rdrs, ratio)`` column names to map for `idx_name`.

    "r95p" and "r99p" are mapped through their amount columns ("r95amt_*",
    "r99amt_*"); the other known indices use their own name as the stem.
    Unknown names give ``(None, None, None)``.
    """
    if idx_name in ("r95p", "r99p"):
        stem = idx_name[:-1] + "amt"          # r95p -> r95amt, r99p -> r99amt
    elif idx_name in ("rx1day", "rx5day", "cdd", "cwd", "drydays", "wetdays"):
        stem = idx_name
    else:
        return None, None, None
    return f"{stem}_obs", f"{stem}_rdrs", f"{stem}_ratio"

# One three-panel map per index; an index is skipped when get_map_cols does
# not know it or when one of its columns is absent from the master table.
for idx_name in index_list:
    obs_col, rdrs_col, ratio_col = get_map_cols(idx_name)
    if obs_col is None:
        continue

    needed_cols = [obs_col, rdrs_col, ratio_col, "lat", "lon"]
    if not all(c in mdf.columns for c in needed_cols):
        print(f"Skipping map for {idx_name} - missing columns.")
        continue

    # only coordinates are required here; stations with a missing index value stay in subdf
    subdf = mdf.dropna(subset=["lat","lon"]).copy()
    out_png = os.path.join(output_plots, f"DJF_{idx_name}_MAP_3panel.png")
    plot_map_triptych(subdf, obs_col, rdrs_col, ratio_col, idx_name, out_png)

###############################################################################
# 5. DISTRIBUTION & BOX/CDF/Scatter in One Figure
###############################################################################
def plot_distribution_triptych(df, obs_col, rdrs_col, label, out_png, season="DJF"):
    """Save a three-panel comparison of observed vs RDRS values.

    Panels, left to right: boxplot of both datasets, empirical CDFs, and a
    scatter of RDRS against obs with the 1:1 line and the Pearson correlation
    in the title.  The CDF and scatter panels only carry a "not enough data"
    title when fewer than two values (or complete pairs) are available.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing `obs_col` and `rdrs_col`.
    obs_col, rdrs_col : str
        Columns holding the observed and the RDRS values.
    label : str
        Index name used for axis labels and titles.
    out_png : str
        Path of the PNG file to write (300 dpi).
    season : str, default "DJF"
        Season label used in the titles and axis labels.

    The figure is saved and closed; nothing is returned.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18,6))

    # A) Boxplot (long format: one row per value, tagged with its dataset)
    ax_box = axes[0]
    data = pd.DataFrame({"Obs": df[obs_col], "RDRS": df[rdrs_col]}).melt(
        var_name="Dataset", value_name=label
    )
    sns.boxplot(data=data, x="Dataset", y=label, ax=ax_box)
    ax_box.set_title(f"Boxplot: {label} ({season})")

    # B) CDF
    ax_cdf = axes[1]
    obs_vals = df[obs_col].dropna()
    rdrs_vals = df[rdrs_col].dropna()

    def ecdf(x):
        """Sorted values and their cumulative probabilities i/n."""
        xs = np.sort(x)
        ys = np.arange(1, len(xs)+1)/len(xs)
        return xs, ys

    if len(obs_vals)>=2 and len(rdrs_vals)>=2:
        xs_o, ys_o = ecdf(obs_vals)
        xs_e, ys_e = ecdf(rdrs_vals)
        ax_cdf.plot(xs_o, ys_o, label="Obs")
        ax_cdf.plot(xs_e, ys_e, label="RDRS")
        ax_cdf.set_title(f"CDF of {label} ({season})")
        ax_cdf.set_xlabel(label)
        ax_cdf.set_ylabel("Probability")
        ax_cdf.legend()
    else:
        ax_cdf.set_title(f"CDF: not enough data ({label})")

    # C) Scatter (complete pairs only)
    ax_scat = axes[2]
    valid = df[[obs_col, rdrs_col]].dropna()
    if len(valid)>=2:
        x = valid[obs_col]
        y = valid[rdrs_col]
        cc, _ = pearsonr(x, y)
        ax_scat.scatter(x, y, edgecolors='k', alpha=0.7)
        # 1:1 reference line across the joint data range
        mn, mx = np.nanmin([x.min(), y.min()]), np.nanmax([x.max(), y.max()])
        ax_scat.plot([mn, mx],[mn, mx],'r--')
        ax_scat.set_xlabel(f"Obs {label} ({season})")
        ax_scat.set_ylabel(f"RDRS {label} ({season})")
        ax_scat.set_title(f"{label} (Corr={cc:.2f}, {season})")
    else:
        ax_scat.set_title(f"Scatter: not enough data ({label})")

    # Act on this figure explicitly instead of pyplot's "current figure".
    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved =>", out_png)

for idx_name in index_list:
    # Same index → column mapping as for the maps (r95p / r99p use their
    # *amt columns); only the obs and rdrs names are needed here.
    obs_col, rdrs_col = get_map_cols(idx_name)[:2]
    if obs_col is None:
        continue

    if not {obs_col, rdrs_col}.issubset(mdf.columns):
        print(f"Skipping distribution for {idx_name} - missing columns.")
        continue

    # keep only stations where both values are present
    subdf = mdf[[obs_col, rdrs_col]].dropna()
    if len(subdf)<2:
        print(f"Skipping distribution for {idx_name} - not enough data.")
        continue

    out_3panel = os.path.join(output_plots, f"DJF_{idx_name}_Distribution_3panel.png")
    plot_distribution_triptych(subdf, obs_col, rdrs_col, idx_name, out_3panel)

###############################################################################
# 6. DONE
###############################################################################
print("\nAll DJF steps completed! See outputs in:", output_plots)

## To analyse another season, replace every occurrence of DJF in this cell with MAM, JJA, or SON.

In [None]:
                 #########################                                           ##############################
                 #########################                  ERA5                     ##############################
                 #########################                                           ##############################

In [None]:
# Calculating the climatic indices for ERA5 (prcp)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
# Daily station-vs-ERA5 table. The code below reads the columns
# time, station_name, obs and era5_lwr8_val from it.
csv_file       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified\daily_loop\era5_vs_stations_8Nearest_LWR_1991_2012.csv"
# Station metadata (NAME, LATITUDE, LONGITUDE, Elevation).
physical_file  = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"

# Basin and lake outlines, used only by the optional quick map.
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# Destination folder for the per-index Excel tables (created if missing).
output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified\ClimaticIndices-8Nearest"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, era5_lwr8_val) ...")
df_data = pd.read_csv(csv_file)
# parse 'time' as datetime; unparseable entries become NaT (errors="coerce")
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")

# unify station_name: remove leading/trailing spaces, uppercase
# (the same normalisation is applied to df_phys so the later merge keys match)
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

# quick sanity report of what was loaded
print(f"df_data shape = {df_data.shape}")
print("Columns:", df_data.columns.tolist())
print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. (OPTIONAL) MERGE WITH PHYSICAL FILE TO GET LAT/LON
###############################################################################
# Only the metadata table is prepared here; the join itself happens later,
# once the per-station index tables exist.
df_phys = pd.read_csv(physical_file)
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# unify station_name (same rule as for df_data)
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()
###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Rx1day: the largest single daily value in ``series`` (NaNs ignored)."""
    wettest_day = series.max(skipna=True)
    return wettest_day

def calc_rx5day(series):
    """Rx5day: the largest sum over any window of up to five consecutive rows.

    ``min_periods=1`` lets the first rows (shorter windows) contribute too.
    """
    return series.rolling(window=5, min_periods=1).sum().max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """CDD: length of the longest unbroken run of rows below ``dry_threshold``."""
    longest = streak = 0
    for dry in (series < dry_threshold):
        # extend the current dry streak, or restart it on a non-dry row
        streak = streak + 1 if dry else 0
        if streak > longest:
            longest = streak
    return longest

def calc_cwd(series, wet_threshold=1.0):
    """CWD: length of the longest unbroken run of rows at or above ``wet_threshold``."""
    longest = streak = 0
    for wet in (series >= wet_threshold):
        # extend the current wet streak, or restart it on a non-wet row
        streak = streak + 1 if wet else 0
        if streak > longest:
            longest = streak
    return longest

def calc_r95p_r99p(series, percentile=(95,99)):
    """Very-wet-day totals and their share of total precipitation.

    Percentile thresholds come from wet days only (>= 1 mm). Returns
    ``(r95_amt, r95_pct, r99_amt, r99_pct)``: the amounts are sums of wet-day
    values strictly above each threshold, and the percentages relate those
    amounts to the sum of the whole series. With fewer than five wet days
    all four values are NaN.
    """
    wet_days = series[series >= 1.0]
    if len(wet_days) < 5:
        return (np.nan,) * 4

    thr_lo = np.percentile(wet_days, percentile[0])
    thr_hi = np.percentile(wet_days, percentile[1])
    amt_lo = wet_days[wet_days > thr_lo].sum()
    amt_hi = wet_days[wet_days > thr_hi].sum()

    grand_total = series.sum(skipna=True)
    if grand_total > 0:
        pct_lo = (amt_lo / grand_total) * 100
        pct_hi = (amt_hi / grand_total) * 100
    else:
        # no precipitation at all: a share is undefined
        pct_lo = pct_hi = np.nan
    return amt_lo, pct_lo, amt_hi, pct_hi

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """Return ``(n_wet, n_dry)``: rows >= ``wet_thr`` and rows < ``dry_thr``."""
    return (series >= wet_thr).sum(), (series < dry_thr).sum()

###############################################################################
# 5. COMPUTE INDICES FOR EACH STATION
###############################################################################
# One list of row-dicts per index; each becomes its own table / Excel file.
rx1_list, rx5_list = [], []
cdd_list, cwd_list = [], []
r95_list, r99_list = [], []
wet_list, dry_list = [], []

print("Computing indices for each station...")

grouped = df_data.groupby("station_name", as_index=False)
for st_name, grp in grouped:
    # Sort by time (just in case)
    grp = grp.sort_values("time")

    # daily obs/era5
    # Missing values are deliberately kept in place. Dropping them first would
    # splice non-adjacent days together, letting Rx5day windows and CDD/CWD
    # runs span data gaps. Every helper below already ignores NaN (NaN fails
    # both threshold comparisons and is skipped by max/sum), so the other
    # indices are unaffected.
    obs_series = grp["obs"].reset_index(drop=True)
    era5_series = grp["era5_lwr8_val"].reset_index(drop=True)

    # A) Rx1day
    obs_rx1 = calc_rx1day(obs_series)
    era5_rx1 = calc_rx1day(era5_series)
    rx1_list.append({"station_name": st_name,
                     "obs_rx1day": obs_rx1,
                     "era5_rx1day": era5_rx1})

    # B) Rx5day
    obs_rx5 = calc_rx5day(obs_series)
    era5_rx5 = calc_rx5day(era5_series)
    rx5_list.append({"station_name": st_name,
                     "obs_rx5day": obs_rx5,
                     "era5_rx5day": era5_rx5})

    # C) CDD
    obs_cdd_val = calc_cdd(obs_series)
    era5_cdd_val = calc_cdd(era5_series)
    cdd_list.append({"station_name": st_name,
                     "obs_cdd": obs_cdd_val,
                     "era5_cdd": era5_cdd_val})

    # D) CWD
    obs_cwd_val = calc_cwd(obs_series)
    era5_cwd_val = calc_cwd(era5_series)
    cwd_list.append({"station_name": st_name,
                     "obs_cwd": obs_cwd_val,
                     "era5_cwd": era5_cwd_val})

    # E) R95 / R99 (amount in mm and share of total, per dataset)
    or95a, or95p, or99a, or99p = calc_r95p_r99p(obs_series)
    er95a, er95p, er99a, er99p = calc_r95p_r99p(era5_series)
    r95_list.append({
        "station_name": st_name,
        "obs_r95amt": or95a, "obs_r95pct": or95p,
        "era5_r95amt": er95a, "era5_r95pct": er95p
    })
    r99_list.append({
        "station_name": st_name,
        "obs_r99amt": or99a, "obs_r99pct": or99p,
        "era5_r99amt": er99a, "era5_r99pct": er99p
    })

    # F) wet/dry days (helper defaults: wet >= 5 mm, dry < 1 mm)
    obs_wet5, obs_dry = calc_wetdays_drydays(obs_series)
    era5_wet5, era5_dry = calc_wetdays_drydays(era5_series)
    wet_list.append({
        "station_name": st_name,
        "obs_wetdays5mm": obs_wet5,
        "era5_wetdays5mm": era5_wet5
    })
    dry_list.append({
        "station_name": st_name,
        "obs_drydays": obs_dry,
        "era5_drydays": era5_dry
    })

print("Finished computing. Now merging lat/lon from physical file ...")

def attach_coords(df_in):
    """Left-join ``lat``, ``lon`` and ``elev`` from ``df_phys`` onto ``df_in``.

    The join key is ``station_name``; rows of ``df_in`` without a match in
    ``df_phys`` are kept with NaN coordinates.
    """
    coord_columns = df_phys[["station_name","lat","lon","elev"]]
    return df_in.merge(coord_columns, on="station_name", how="left")

# Turn each list of per-station rows into a table and append lat/lon/elev.
# attach_coords uses a left join, so stations absent from the physical file
# stay in the table with NaN coordinates.
df_rx1 = attach_coords(pd.DataFrame(rx1_list))
df_rx5 = attach_coords(pd.DataFrame(rx5_list))
df_cdd = attach_coords(pd.DataFrame(cdd_list))
df_cwd = attach_coords(pd.DataFrame(cwd_list))
df_r95 = attach_coords(pd.DataFrame(r95_list))
df_r99 = attach_coords(pd.DataFrame(r99_list))
df_wet = attach_coords(pd.DataFrame(wet_list))
df_dry = attach_coords(pd.DataFrame(dry_list))

###############################################################################
# 6. SAVE OUTPUT
###############################################################################
# One workbook per index, written without the DataFrame row index.
print("Saving index tables to Excel in:", output_dir)
df_rx1.to_excel(os.path.join(output_dir, "rx1day.xlsx"),  index=False)
df_rx5.to_excel(os.path.join(output_dir, "rx5day.xlsx"),  index=False)
df_cdd.to_excel(os.path.join(output_dir, "cdd.xlsx"),     index=False)
df_cwd.to_excel(os.path.join(output_dir, "cwd.xlsx"),     index=False)
df_r95.to_excel(os.path.join(output_dir, "r95p.xlsx"),    index=False)
df_r99.to_excel(os.path.join(output_dir, "r99p.xlsx"),    index=False)
df_wet.to_excel(os.path.join(output_dir, "wetdays.xlsx"), index=False)
df_dry.to_excel(os.path.join(output_dir, "drydays.xlsx"), index=False)

print("\nAll precipitation-based indices have been saved to Excel.")

###############################################################################
# (OPTIONAL) QUICK MAP EXAMPLE
###############################################################################
# The whole map is wrapped in try/except: it runs after the tables are saved,
# and any failure (missing shapefile, plotting error, ...) is only reported.
try:
    print("\nQuick map example for obs_rx5day ...")
    # Load shapefiles and reproject to geographic lon/lat (EPSG:4326)
    gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
    gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

    # Station points built from the lon/lat columns merged in above
    gdf_stations = gpd.GeoDataFrame(
        df_rx5,
        geometry=gpd.points_from_xy(df_rx5["lon"], df_rx5["lat"]),
        crs="EPSG:4326"
    )

    fig, ax = plt.subplots(figsize=(10,8), subplot_kw={"projection": ccrs.PlateCarree()})
    # Outline-only layers: basin in blue, lakes in cyan
    for geom in gdf_basin.geometry:
        ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='blue', linewidth=1)
    for geom in gdf_lakes.geometry:
        ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='cyan', linewidth=1)

    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=':')

    # Stations coloured by their observed Rx5day value
    sc = ax.scatter(gdf_stations.geometry.x, gdf_stations.geometry.y,
                    c=gdf_stations["obs_rx5day"], cmap="Reds", s=60,
                    transform=ccrs.PlateCarree(), edgecolor="k")
    plt.colorbar(sc, ax=ax, label="Obs Rx5day (mm)")
    ax.set_extent([-95.5, -72, 38.5, 52.5])  # approximate bounding box
    # Grid labels on the left and bottom edges only
    gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
    gl.right_labels = False
    gl.top_labels   = False

    plt.title("Obs Rx5day (from CSV daily data)", fontsize=14)
    plt.show()

except Exception as e:
    print("Mapping step failed:", e)

print("\n✅ Done computing precipitation-based indices from 'era5_lwr8_val' column!")


In [None]:
# Calculating the climatic indices for ERA5 (temperature (tmin-tmax))

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
# Daily long-format table. The code below reads the columns time,
# station_name, var ("tmax" / "tmin"), obs and era5_lwr8_val from it.
csv_file = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Temperature\Temperature ERA5\tmin-tmax\daily_loop\era5_vs_stations_8Nearest_LWR_1991_2012_tmin_tmax.csv"
# Station metadata (NAME, LATITUDE, LONGITUDE, Elevation).
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation_Temperature.csv"

# Basin and lake outlines, used only by the optional quick map.
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# Destination folder for the index workbooks (created if missing).
output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Temperature\Temperature ERA5\tmin-tmax\ClimaticIndices-8Nearest"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, era5_lwr8_val) ...")
df_data = pd.read_csv(csv_file)
# unparseable timestamps become NaT (errors="coerce")
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")
# normalise the join key: trimmed, upper-case station names
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

# quick sanity report of what was loaded
print(f"df_data shape = {df_data.shape}")
print("Columns:", df_data.columns.tolist())
print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. MERGE WITH PHYSICAL FILE TO GET LAT/LON
###############################################################################
# Only the metadata table is prepared here; it is joined onto the index
# tables after they have been computed.
df_phys = pd.read_csv(physical_file)
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# same station-name normalisation as df_data so the merge keys match
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS – FULL ETCCDI SET FOR Tmin / Tmax
###############################################################################
def absolute_extremes(series, kind):
    """Single-day extreme of ``series``, ignoring NaN.

    ``kind == "max"`` gives the maximum; any other value gives the minimum.
    """
    reducer = series.max if kind == "max" else series.min
    return reducer(skipna=True)

# -------------------------------------------------------------------------
# percentile thresholds (5-day moving window, baseline = 1991-2012)
# -------------------------------------------------------------------------
BASE_START, BASE_END = "1991-01-01", "2012-12-31"

def _climatology_percentiles(s, p):
    """Return a Series (index = 1…366) of the p-th percentile."""
    # drop 29 Feb so every year has 365 days
    s = s[~((s.index.month == 2) & (s.index.day == 29))]
    df = pd.DataFrame({"val": s, "doy": s.index.dayofyear})
    climo = []
    for d in range(1, 366):
        win = list(range(d-2, d+3))                       # ±2-day window
        win = [(x-1) % 365 + 1 for x in win]              # wrap around ends
        vals = df.loc[df["doy"].isin(win), "val"]
        climo.append(np.nanpercentile(vals, p) if len(vals) else np.nan)
    return pd.Series(climo, index=range(1, 366), name=f"p{p}")

def percentile_flags(s, perc_series, side):
    """Return Boolean Series: True where value is < or > percentile."""
    doy = s.index.dayofyear
    thr = perc_series.reindex(doy).values
    if side == "low":
        return s < thr
    else:
        return s > thr

def spell_length(bool_series, min_run=6):
    """Total number of days belonging to runs of at least ``min_run`` consecutive Trues.

    Missing flags are treated as False.
    """
    flags = bool_series.fillna(False).to_numpy()
    # Pad with a 0 at both ends so every run has a rising and a falling edge.
    padded = np.concatenate(([0], flags, [0]))
    edges = np.flatnonzero(np.diff(padded))
    run_starts, run_stops = edges[::2], edges[1::2]
    run_lengths = run_stops - run_starts
    return run_lengths[run_lengths >= min_run].sum()

# -------------------------------------------------------------------------
# absolute-threshold counters
# -------------------------------------------------------------------------
def count_threshold(series, op, thr):
    """Count values of ``series`` below ``thr`` (``op == "<"``) or above it (any other ``op``)."""
    hits = series < thr if op == "<" else series > thr
    return hits.sum()

###############################################################################
# 5. COMPUTE INDICES FOR EACH STATION (OBS & ERA5 – Tmin / Tmax)
###############################################################################
indices_rows = []

print("Computing climatic indices …")
for st_name, grp in df_data.groupby("station_name"):

    # Split the long table into four daily Series. asfreq("D") puts each one
    # on a gap-free daily calendar, filling absent dates with NaN.
    grp      = grp.set_index("time")
    obs_max  = grp.loc[grp["var"]=="tmax", "obs"].asfreq("D")
    obs_min  = grp.loc[grp["var"]=="tmin", "obs"].asfreq("D")
    era_max  = grp.loc[grp["var"]=="tmax", "era5_lwr8_val"].asfreq("D")
    era_min  = grp.loc[grp["var"]=="tmin", "era5_lwr8_val"].asfreq("D")

    for label, tmax, tmin in [("obs",  obs_max,  obs_min),
                              ("era5", era_max, era_min)]:

        if tmax.empty or tmin.empty:
            continue   # skip if station has no data

        # --- absolute extremes ----------------------------------------------------
        TXx = absolute_extremes(tmax, "max");  TNn = absolute_extremes(tmin, "min")
        TXn = absolute_extremes(tmax, "min");  TNx = absolute_extremes(tmin, "max")

        # --- percentile thresholds (baseline climatology) -------------------------
        base_max = tmax[BASE_START:BASE_END];  base_min = tmin[BASE_START:BASE_END]
        p90_TX   = _climatology_percentiles(base_max, 90)
        p10_TX   = _climatology_percentiles(base_max, 10)
        p90_TN   = _climatology_percentiles(base_min, 90)
        p10_TN   = _climatology_percentiles(base_min, 10)

        # flags for the whole record
        TX90p_flag = percentile_flags(tmax, p90_TX, "high")
        TX10p_flag = percentile_flags(tmax, p10_TX, "low")
        TN90p_flag = percentile_flags(tmin, p90_TN, "high")
        TN10p_flag = percentile_flags(tmin, p10_TN, "low")

        # % of days, taken over days that actually have data. The NaN days
        # added by asfreq("D") can never be flagged, so leaving them in the
        # denominator would bias every percentage low.
        tmax_valid = tmax.notna()
        tmin_valid = tmin.notna()
        TX90p = TX90p_flag[tmax_valid].mean() * 100.0
        TX10p = TX10p_flag[tmax_valid].mean() * 100.0
        TN90p = TN90p_flag[tmin_valid].mean() * 100.0
        TN10p = TN10p_flag[tmin_valid].mean() * 100.0

        # --- warm / cold spell duration indices -----------------------------------
        WSDI = spell_length(TX90p_flag, min_run=6)
        CSDI = spell_length(TN10p_flag, min_run=6)

        # --- absolute-threshold counters ------------------------------------------
        FD = count_threshold(tmin, "<", 0.0)       # Frost days
        ID = count_threshold(tmax, "<", 0.0)       # Ice days
        SU = count_threshold(tmax, ">", 25.0)      # Summer days
        TR = count_threshold(tmin, ">", 20.0)      # Tropical nights

        # --- aggregate into one row ----------------------------------------------
        indices_rows.append({
            "station_name": st_name,
            "dataset":      label,         # obs / era5
            # intensity
            "TXx": TXx, "TNn": TNn, "TXn": TXn, "TNx": TNx,
            # percentile frequency
            "TX90p_%": TX90p, "TX10p_%": TX10p,
            "TN90p_%": TN90p, "TN10p_%": TN10p,
            # spell duration
            "WSDI": WSDI, "CSDI": CSDI,
            # absolute threshold counts
            "FD": FD, "ID": ID, "SU": SU, "TR": TR
        })

# Long table: one row per (station, dataset).
df_idx = pd.DataFrame(indices_rows)
print(f"Computed indices table shape: {df_idx.shape}")

###############################################################################
# 6.  SAVE OUTPUT  – ONE FILE PER INDEX  (wide table: obs_*, era5_*)
###############################################################################
# attach coordinates once (they will be the same for all indices)
df_idx = pd.merge(
    df_idx,
    df_phys[["station_name", "lat", "lon", "elev"]],
    on="station_name",
    how="left"
)

# ------------------------------------------------------------------
# helper: build wide-format DataFrame for ONE index
# ------------------------------------------------------------------
def wide_table(index_name: str) -> pd.DataFrame:
    """
    Build the obs-vs-era5 comparison table for a single index.

    Pivots ``df_idx`` so each station is one row with the columns
        station_name, obs_<idx>, era5_<idx>, lat, lon, elev
    where the coordinates are left-joined from ``df_phys``.
    """
    pivoted = df_idx.pivot_table(index="station_name",
                                 columns="dataset",
                                 values=index_name)
    # dataset labels become index-specific column names
    renamed = pivoted.rename(columns={"obs": f"obs_{index_name}",
                                      "era5": f"era5_{index_name}"})
    wide = renamed.reset_index()

    # attach coords (looked up by station name)
    coord_columns = df_phys[["station_name", "lat", "lon", "elev"]]
    return wide.merge(coord_columns, on="station_name", how="left")

# list of all index columns we computed
# (these must match the keys written into each row of indices_rows)
index_cols = [
    "TXx","TNn","TXn","TNx",
    "TX90p_%","TX10p_%","TN90p_%","TN10p_%",
    "WSDI","CSDI","FD","ID","SU","TR"
]

# One workbook per index; the file name is the index name itself.
print(f"Writing one Excel file per index into\n{output_dir}")
for idx in index_cols:
    df_w = wide_table(idx)
    out_xlsx = os.path.join(output_dir, f"{idx}.xlsx")
    df_w.to_excel(out_xlsx, index=False)
    print("  •", os.path.basename(out_xlsx))

# ------------------------------------------------------------------
# OPTIONAL – also write a single Excel file containing *all* indices
#            (each index in wide format, side-by-side)
# ------------------------------------------------------------------
# Start from the station metadata and left-join every index onto it, so
# every station listed in df_phys appears even when it has no index values.
wide_all = df_phys[["station_name", "lat", "lon", "elev"]].copy()
for idx in index_cols:
    # drop the coordinates wide_table attached; wide_all already has them
    w = wide_table(idx).drop(columns=["lat","lon","elev"])
    wide_all = pd.merge(wide_all, w, on="station_name", how="left")

all_path = os.path.join(output_dir, "Temperature_Indices_ALL.xlsx")
wide_all.to_excel(all_path, index=False)
print("  •", os.path.basename(all_path), "(contains every index)")

print("✅  All index files written.")

###############################################################################
# (OPTIONAL) QUICK MAP EXAMPLE  –  OBS-TXx (hottest daytime temperature)
###############################################################################
# The map is best-effort: it runs after all files are written, and any
# failure is only reported.
try:
    print("\nQuick map example: OBS – TXx (°C)")
    
    # base layers, reprojected to geographic lon/lat (EPSG:4326)
    gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
    gdf_lakes = gpd.read_file(lakes_shp     ).to_crs(epsg=4326)

    # pick the OBS rows and build a GeoDataFrame
    # (df_idx already carries lat/lon from the merge done before export)
    obs_tx = df_idx[df_idx["dataset"] == "obs"].copy()
    gdf_stn = gpd.GeoDataFrame(
        obs_tx,
        geometry=gpd.points_from_xy(obs_tx["lon"], obs_tx["lat"]),
        crs="EPSG:4326"
    )

    fig, ax = plt.subplots(figsize=(10, 8),
                            subplot_kw=dict(projection=ccrs.PlateCarree()))
    
    # add basin + lakes outlines
    ax.add_geometries(gdf_basin.geometry, ccrs.PlateCarree(),
                      facecolor='none', edgecolor='black', linewidth=0.8)
    ax.add_geometries(gdf_lakes.geometry, ccrs.PlateCarree(),
                      facecolor='none', edgecolor='blue',  linewidth=0.8)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=':')

    # scatter the station values
    sc = ax.scatter(gdf_stn.geometry.x, gdf_stn.geometry.y,
                    c=gdf_stn["TXx"], cmap="hot_r", s=60,
                    edgecolor="k", transform=ccrs.PlateCarree())
    plt.colorbar(sc, ax=ax, label="TXx (°C)")

    ax.set_extent([-95.5, -72, 38.5, 52.5])   # Great Lakes frame
    ax.set_title("Station hottest-day temperature (TXx)\n1991-2012 – Observations",
                 fontsize=13)
    plt.show()

except Exception as e:
    print("Mapping step failed:", e)

print("\n✅  Finished: climatic indices computed & quick TXx map rendered.")



In [None]:
# Temporal stratification of climatic indices ERA5

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIGURATION
###############################################################################
# Daily station-vs-ERA5 table. The code below reads the columns time,
# station_name, obs and era5_val from it.
csv_file      = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified\daily_loop\era5_vs_stations_100km_LWR_1991_2012.csv"
# Station metadata (NAME, LATITUDE, LONGITUDE, Elevation).
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
# Destination folder for Indices_Monthly.xlsx / Indices_Seasonal.xlsx.
output_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified\ClimaticIndices2"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY DATA & ADD TEMPORAL FIELDS
###############################################################################
print("Loading daily CSV data ...")
df_data = pd.read_csv(csv_file)
# unparseable timestamps become NaT (errors="coerce")
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")

# Standardize station_name
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

# Add month (1..12) and season (DJF, MAM, JJA, SON)
# (month is NaN wherever time is NaT)
df_data["month"] = df_data["time"].dt.month

def get_season(month):
    """Map a calendar month number (1–12) to its meteorological season.

    Returns "DJF", "MAM", "JJA" or "SON". Anything that is not a valid month
    (for example the NaN produced by an unparseable timestamp) returns None
    instead of being silently labelled "SON", so such rows drop out of a
    groupby on the season column.
    """
    if month in (12, 1, 2):
        return "DJF"
    if month in (3, 4, 5):
        return "MAM"
    if month in (6, 7, 8):
        return "JJA"
    if month in (9, 10, 11):
        return "SON"
    return None

# Label every daily row with its season, derived from the month column.
df_data["season"] = df_data["month"].map(get_season)

print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. LOAD PHYSICAL FILE & MERGE COORDINATES
###############################################################################
# Read the station metadata and give its columns the short names used below.
df_phys = pd.read_csv(physical_file).rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# Same normalisation as df_data["station_name"] so the merge keys match.
df_phys["station_name"] = (
    df_phys["station_name"].astype(str).str.strip().str.upper()
)

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Rx1day: highest single daily value in ``series``, skipping NaN."""
    daily_peak = series.max(skipna=True)
    return daily_peak

def calc_rx5day(series):
    """Rx5day: highest running total over up to five consecutive rows."""
    # min_periods=1 keeps the shorter windows at the start of the series
    five_day_totals = series.rolling(5, min_periods=1).sum()
    return five_day_totals.max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """CDD: longest streak of consecutive rows with a value below ``dry_threshold``."""
    best = 0
    run = 0
    for is_dry in (series < dry_threshold).tolist():
        if not is_dry:
            # a non-dry row ends the streak
            run = 0
            continue
        run += 1
        best = max(best, run)
    return best

def calc_cwd(series, wet_threshold=1.0):
    """CWD: longest streak of consecutive rows with a value at or above ``wet_threshold``."""
    best = 0
    run = 0
    for is_wet in (series >= wet_threshold).tolist():
        if not is_wet:
            # a non-wet row ends the streak
            run = 0
            continue
        run += 1
        best = max(best, run)
    return best

def calc_r95p_r99p(series, percentiles=(95,99)):
    """
    Return ``(r95amt, r95pct, r99amt, r99pct)``.

    - thresholds are the requested percentiles of wet days (>= 1 mm) only
    - r95amt / r99amt = sum of wet-day values strictly above each threshold
    - r95pct / r99pct = that amount as a percentage of the series total
    - all four are NaN when there are fewer than five wet days
    """
    wet = series[series >= 1.0]
    if len(wet) < 5:
        return np.nan, np.nan, np.nan, np.nan

    def _amount_above(q):
        """Sum of wet-day values exceeding the q-th wet-day percentile."""
        return wet[wet > np.percentile(wet, q)].sum()

    r95_amt = _amount_above(percentiles[0])
    r99_amt = _amount_above(percentiles[1])

    total = series.sum(skipna=True)
    if total > 0:
        r95_pct = (r95_amt / total) * 100
        r99_pct = (r99_amt / total) * 100
    else:
        # a share of a zero total is undefined
        r95_pct = r99_pct = np.nan
    return r95_amt, r95_pct, r99_amt, r99_pct

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """
    Return ``(wetdays, drydays)``:
    wetdays = number of rows >= wet_thr
    drydays = number of rows <  dry_thr
    """
    n_wet = (series >= wet_thr).sum()
    n_dry = (series < dry_thr).sum()
    return n_wet, n_dry

###############################################################################
# 5. FUNCTION TO COMPUTE INDICES FOR A GROUP (MONTHLY or SEASONAL)
###############################################################################
def compute_indices(df_group):
    """
    Compute the precipitation indices for one subset of daily rows.

    For a subset of daily data (e.g. station+month, or station+season),
    compute the climate indices for Obs vs ERA5, plus ratio columns.

    Parameters
    ----------
    df_group : pandas.DataFrame
        Must contain the columns ``obs`` and ``era5_val``. When a ``time``
        column is present the rows are put in chronological order first.

    Returns
    -------
    dict or None
        ``<index>_obs``, ``<index>_era5`` and ``<index>_ratio`` (ERA5 / Obs)
        for every index, or None when either series has no valid value.
        A ratio is NaN when the observed value is 0 or NaN, so every result
        dict carries the same set of keys.
    """
    # Run-length and rolling indices depend on row order, so sort defensively.
    if "time" in df_group.columns:
        df_group = df_group.sort_values("time")

    # NOTE(review): dropping NaNs, and grouping one month/season across years,
    # makes non-adjacent days neighbours; CDD/CWD/Rx5day can span those joins.
    obs_series = df_group["obs"].dropna().reset_index(drop=True)
    era5_series = df_group["era5_val"].dropna().reset_index(drop=True)
    if len(obs_series) == 0 or len(era5_series) == 0:
        return None

    res = {}
    # Rx1day / Rx5day
    res["rx1day_obs"] = calc_rx1day(obs_series)
    res["rx1day_era5"] = calc_rx1day(era5_series)
    res["rx5day_obs"] = calc_rx5day(obs_series)
    res["rx5day_era5"] = calc_rx5day(era5_series)

    # CDD / CWD
    res["cdd_obs"] = calc_cdd(obs_series)
    res["cdd_era5"] = calc_cdd(era5_series)
    res["cwd_obs"] = calc_cwd(obs_series)
    res["cwd_era5"] = calc_cwd(era5_series)

    # R95 / R99: each helper call returns (r95amt, r95pct, r99amt, r99pct)
    r9x_obs = calc_r95p_r99p(obs_series)
    r9x_era5 = calc_r95p_r99p(era5_series)
    for pos, key in enumerate(("r95amt", "r95pct", "r99amt", "r99pct")):
        res[f"{key}_obs"] = r9x_obs[pos]
        res[f"{key}_era5"] = r9x_era5[pos]

    # Wet / Dry days
    wet_obs, dry_obs = calc_wetdays_drydays(obs_series)
    wet_era5, dry_era5 = calc_wetdays_drydays(era5_series)
    res["wetdays_obs"] = wet_obs
    res["wetdays_era5"] = wet_era5
    res["drydays_obs"] = dry_obs
    res["drydays_era5"] = dry_era5

    # Ratio columns: ERA5/OBS. The key is always written (NaN when obs is 0),
    # so the output table has the same columns whatever the data look like.
    for key in ("rx1day", "rx5day", "cdd", "cwd",
                "r95amt", "r95pct", "r99amt", "r99pct",
                "wetdays", "drydays"):
        obs_val = res[f"{key}_obs"]
        res[f"{key}_ratio"] = res[f"{key}_era5"] / obs_val if obs_val else np.nan

    return res

###############################################################################
# 6. MONTHLY INDICES
###############################################################################
# One result row per (station, calendar month). Each group pools that month
# across every year present in df_data.
monthly_results = []
group_month = df_data.groupby(["station_name", "month"])
for (st_name, mon), group in group_month:
    indices = compute_indices(group)
    if indices is None:
        # compute_indices found no valid obs or no valid ERA5 values
        continue
    indices["station_name"] = st_name
    indices["month"] = mon
    monthly_results.append(indices)

df_monthly = pd.DataFrame(monthly_results)
# left join: stations missing from the physical file keep NaN coordinates
df_monthly = pd.merge(
    df_monthly,
    df_phys[["station_name", "lat", "lon", "elev"]],
    on="station_name",
    how="left"
)
df_monthly = df_monthly.sort_values(["station_name", "month"])
monthly_out = os.path.join(output_dir, "Indices_Monthly.xlsx")
df_monthly.to_excel(monthly_out, index=False)
print("Monthly indices saved =>", monthly_out)

###############################################################################
# 7. SEASONAL INDICES
###############################################################################
# Same procedure as the monthly table, grouped by (station, season) instead.
seasonal_results = []
group_season = df_data.groupby(["station_name", "season"])
for (st_name, seas), group in group_season:
    indices = compute_indices(group)
    if indices is None:
        # compute_indices found no valid obs or no valid ERA5 values
        continue
    indices["station_name"] = st_name
    indices["season"] = seas
    seasonal_results.append(indices)

df_seasonal = pd.DataFrame(seasonal_results)
# left join: stations missing from the physical file keep NaN coordinates
df_seasonal = pd.merge(
    df_seasonal,
    df_phys[["station_name", "lat", "lon", "elev"]],
    on="station_name",
    how="left"
)
# note: the season labels are strings, so this sort is alphabetical
df_seasonal = df_seasonal.sort_values(["station_name", "season"])
seasonal_out = os.path.join(output_dir, "Indices_Seasonal.xlsx")
df_seasonal.to_excel(seasonal_out, index=False)
print("Seasonal indices saved =>", seasonal_out)

###############################################################################
# 8. DONE
###############################################################################
print("\nAll monthly and seasonal indices have been saved. (No extreme-event stratification.)")


In [None]:
# DJF for ERA5

import os
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import Point
from scipy.stats import pearsonr
import seaborn as sns

###############################################################################
# 1. CONFIG & PATHS
###############################################################################
indices_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified\ClimaticIndices2"
seasonal_file = os.path.join(indices_dir, "Indices_Seasonal.xlsx")  # one workbook holding all seasons
output_plots  = os.path.join(indices_dir, "AnalysisPlots_DJF")
os.makedirs(output_plots, exist_ok=True)

shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# Index families processed by the summary, map and distribution steps below
index_list = ["rx1day","rx5day","cdd","cwd","r95p","r99p","wetdays","drydays"]

# For the summary statistics: index name -> (observed column(s), ERA5 column(s)).
# r95p / r99p carry two columns each (amount and percentage), given as tuples.
index_columns = {
    "rx1day":  ("rx1day_obs",  "rx1day_era5"),
    "rx5day":  ("rx5day_obs",  "rx5day_era5"),
    "cdd":     ("cdd_obs",     "cdd_era5"),
    "cwd":     ("cwd_obs",     "cwd_era5"),
    "r95p":    (("r95amt_obs","r95pct_obs"), ("r95amt_era5","r95pct_era5")),
    "r99p":    (("r99amt_obs","r99pct_obs"), ("r99amt_era5","r99pct_era5")),
    "wetdays": ("wetdays_obs","wetdays_era5"),
    "drydays": ("drydays_obs","drydays_era5"),
}

###############################################################################
# 2. LOAD SEASONAL FILE & FILTER TO DJF
###############################################################################
df_season = pd.read_excel(seasonal_file)
print("Loaded =>", seasonal_file, "| shape =", df_season.shape)

# Keep only the DJF rows that also have usable coordinates
df_season = (
    df_season.loc[df_season["season"] == "DJF"]
    .copy()
    .dropna(subset=["lat", "lon"])
)
print("After filtering to DJF => shape =", df_season.shape)

# Master table: the filtered DJF rows with a clean 0..n-1 index
mdf = df_season.reset_index(drop=True)
master_xlsx = os.path.join(output_plots, "MasterTable_Seasonal_DJF.xlsx")
mdf.to_excel(master_xlsx, index=False)
print(f"\n(A) Master table (DJF) saved => {master_xlsx}")
print("Columns:", mdf.columns.tolist())

###############################################################################
# 3. SUMMARY TABLE (MBE, RMSE, STD, CC, d) for DJF
###############################################################################
def index_of_agreement(obs, model):
    """Index of agreement d = 1 - sum((model-obs)^2) / sum((|model-mean(obs)| + |obs-mean(obs)|)^2).

    Both inputs are expected to support element-wise arithmetic (e.g. numpy
    arrays of equal length). Returns NaN when the denominator is zero.
    """
    centre = np.mean(obs)
    squared_error = np.sum((model - obs) ** 2)
    potential_error = np.sum((abs(model - centre) + abs(obs - centre)) ** 2)
    return np.nan if potential_error == 0 else 1 - squared_error / potential_error

def rmse(a, b):
    """Root-mean-square difference between the element-wise paired inputs a and b."""
    squared_diff = (a - b) ** 2
    return np.sqrt(np.mean(squared_diff))

def std_of_residuals(a, b):
    """Sample standard deviation (ddof=1) of the element-wise residuals a - b."""
    residuals = a - b
    return np.std(residuals, ddof=1)

def mean_bias_error(a, b):
    """Mean of (b - a): positive when b is, on average, larger than a."""
    bias = b - a
    return np.mean(bias)

# Build one summary row per (observed column, ERA5 column) pair.
# Single-column indices produce one row labelled with the index name;
# two-column indices (amount / percentage) produce one row per column,
# labelled "<index>_<observed column without the _obs suffix>".
summary_rows = []
for idx_name in index_list:
    obs_cols, era5_cols = index_columns[idx_name]

    if isinstance(obs_cols, tuple):
        column_pairs = [
            (oc, ec, f"{idx_name}_{oc.replace('_obs','')}")
            for oc, ec in zip(obs_cols, era5_cols)
        ]
    else:
        column_pairs = [(obs_cols, era5_cols, idx_name)]

    for oc, ec, idx_label in column_pairs:
        # Only stations where both values are present take part
        valid = mdf[[oc, ec]].dropna()
        if len(valid) < 2:
            continue
        obs_vals = valid[oc].values
        era5_vals = valid[ec].values
        summary_rows.append({
            "Index": idx_label,
            "Count": len(valid),
            "MBE": mean_bias_error(obs_vals, era5_vals),
            "RMSE": rmse(obs_vals, era5_vals),
            "STDres": std_of_residuals(obs_vals, era5_vals),
            "CC": pearsonr(obs_vals, era5_vals)[0] if len(obs_vals) > 1 else np.nan,
            "d": index_of_agreement(obs_vals, era5_vals),
        })

summary_cols = ["Index","Count","MBE","RMSE","STDres","CC","d"]
summary_df = pd.DataFrame(summary_rows)[summary_cols]
summary_xlsx = os.path.join(output_plots, "SummaryTable_Extremes_DJF.xlsx")
summary_df.to_excel(summary_xlsx, index=False)
print(f"(B) Summary Table (DJF) => {summary_xlsx}\n{summary_df}")

###############################################################################
# 4. MAPPING: Combine Observed, ERA5, Ratio in One Figure
###############################################################################
# Basin and lake outlines, reprojected to geographic lon/lat (EPSG:4326)
gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

def add_basin_lakes(ax):
    """Draw country borders plus the basin (black) and lake (cyan) outlines on ax.

    Reads the module-level gdf_basin / gdf_lakes GeoDataFrames; coastlines are
    not drawn.
    """
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    outline_layers = ((gdf_basin, 'black'), (gdf_lakes, 'cyan'))
    for layer, colour in outline_layers:
        for geom in layer.geometry:
            ax.add_geometries([geom], ccrs.PlateCarree(),
                              facecolor='none', edgecolor=colour, linewidth=1)

def plot_map_triptych(df, obs_col, emd_col, ratio_col, idx_name, out_png):
    """
    Save a single figure with 3 side-by-side station maps:
      1) Observed            (column obs_col)
      2) ERA5                (column emd_col)
      3) Ratio (ERA5/OBS)    (column ratio_col)

    Parameters
    ----------
    df : DataFrame with "lon", "lat" and the three value columns.
    obs_col, emd_col, ratio_col : names of the columns coloured in each panel.
    idx_name : index name used in the panel titles.
    out_png : path of the PNG file to write.

    Each panel gets its own colorbar; stations at or above that panel's 90th
    percentile are ringed in red as "hotspots".
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6),
                             subplot_kw={"projection": ccrs.PlateCarree()})

    def scatter_map(ax, value_col, title):
        """Draw one panel: basemap, coloured stations, hotspot rings, gridlines."""
        ax.set_extent([-95.5, -72, 38.5, 52.5])
        add_basin_lakes(ax)
        sc = ax.scatter(df["lon"], df["lat"], c=df[value_col], cmap="viridis",
                        s=60, transform=ccrs.PlateCarree(), edgecolor="k", zorder=10)
        cb = fig.colorbar(sc, ax=ax, shrink=0.8)
        cb.set_label(value_col)

        # Hotspots => top 10% of the non-missing values in this panel
        vals = df[value_col].dropna().values
        if len(vals) > 0:
            thr = np.percentile(vals, 90)
            is_hot = df[value_col] >= thr
            ax.scatter(df.loc[is_hot, "lon"], df.loc[is_hot, "lat"],
                       marker='o', facecolors='none', edgecolors='red', s=80,
                       transform=ccrs.PlateCarree(), zorder=11,
                       label=f"Hotspot >= {thr:.2f}")
            # Only request a legend when a labelled artist exists; calling
            # legend() with nothing to show just emits a matplotlib warning.
            ax.legend(loc='upper right')
        ax.set_title(title, fontsize=12)
        gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
        gl.right_labels = False
        gl.top_labels   = False

    scatter_map(axes[0], obs_col,   f"{idx_name} Observed (DJF)")
    # Use the emd_col argument: the previous body read a module-level
    # `era5_col` here, so the panel depended on whatever the caller's loop
    # had last assigned to that global.
    scatter_map(axes[1], emd_col,   f"{idx_name} ERA5 (DJF)")
    scatter_map(axes[2], ratio_col, f"{idx_name} (ERA5/OBS) (DJF)")

    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)  # close this figure explicitly so repeated calls do not accumulate figures
    print("Saved =>", out_png)

def get_map_cols(idx_name):
    """Return the (observed, ERA5, ratio) column names mapped for idx_name.

    r95p / r99p are mapped through their *amount* columns (r95amt / r99amt);
    every other known index uses its own name as the column stem. Unknown
    names yield (None, None, None).
    """
    same_stem = ("rx1day", "rx5day", "cdd", "cwd", "drydays", "wetdays")
    if idx_name in same_stem:
        stem = idx_name
    elif idx_name == "r95p":
        stem = "r95amt"
    elif idx_name == "r99p":
        stem = "r99amt"
    else:
        return None, None, None
    return f"{stem}_obs", f"{stem}_era5", f"{stem}_ratio"

# One 3-panel map per index; indices without a column mapping or with
# columns missing from the master table are skipped.
for idx_name in index_list:
    obs_col, era5_col, ratio_col = get_map_cols(idx_name)
    if obs_col is None:
        continue

    needed_cols = [obs_col, era5_col, ratio_col, "lat", "lon"]
    if any(c not in mdf.columns for c in needed_cols):
        print(f"Skipping map for {idx_name} - missing columns.")
        continue

    out_png = os.path.join(output_plots, f"DJF_{idx_name}_MAP_3panel.png")
    subdf = mdf.dropna(subset=["lat","lon"]).copy()
    plot_map_triptych(subdf, obs_col, era5_col, ratio_col, idx_name, out_png)

###############################################################################
# 5. DISTRIBUTION & BOX/CDF/Scatter in One Figure
###############################################################################
def plot_distribution_triptych(df, obs_col, era5_col, label, out_png):
    """
    Save one figure comparing the obs_col and era5_col columns of df in
    three side-by-side panels:
      1) Boxplot of both columns
      2) Empirical CDF of each column
      3) Obs-vs-ERA5 scatter with a 1:1 line and the Pearson correlation

    `label` names the quantity on axes/titles; `out_png` is the output path.
    Panels 2 and 3 show a "not enough data" title when fewer than two
    values (per column / per complete pair) are available.
    """
    def ecdf(x):
        # sorted values against cumulative probability i/n
        xs = np.sort(x)
        ys = np.arange(1, len(xs)+1)/len(xs)
        return xs, ys

    fig, (ax_box, ax_cdf, ax_scat) = plt.subplots(nrows=1, ncols=3, figsize=(18,6))

    # Panel 1: boxplot on the long-form (Dataset, value) table
    long_form = pd.DataFrame({"Obs": df[obs_col], "ERA5": df[era5_col]}).melt(
        var_name="Dataset", value_name=label
    )
    sns.boxplot(data=long_form, x="Dataset", y=label, ax=ax_box)
    ax_box.set_title(f"Boxplot: {label} (DJF)")

    # Panel 2: empirical CDFs
    obs_vals = df[obs_col].dropna()
    era5_vals = df[era5_col].dropna()
    if len(obs_vals) < 2 or len(era5_vals) < 2:
        ax_cdf.set_title(f"CDF: not enough data ({label})")
    else:
        for sample, curve_name in ((obs_vals, "Obs"), (era5_vals, "ERA5")):
            xs, ys = ecdf(sample)
            ax_cdf.plot(xs, ys, label=curve_name)
        ax_cdf.set_title(f"CDF of {label} (DJF)")
        ax_cdf.set_xlabel(label)
        ax_cdf.set_ylabel("Probability")
        ax_cdf.legend()

    # Panel 3: scatter of complete (obs, ERA5) pairs
    paired = df[[obs_col, era5_col]].dropna()
    if len(paired) < 2:
        ax_scat.set_title(f"Scatter: not enough data ({label})")
    else:
        x = paired[obs_col]
        y = paired[era5_col]
        cc, _ = pearsonr(x, y)
        ax_scat.scatter(x, y, edgecolors='k', alpha=0.7)
        lo = np.nanmin([x.min(), y.min()])
        hi = np.nanmax([x.max(), y.max()])
        ax_scat.plot([lo, hi],[lo, hi],'r--')   # 1:1 reference line
        ax_scat.set_xlabel(f"Obs {label} (DJF)")
        ax_scat.set_ylabel(f"ERA5 {label} (DJF)")
        ax_scat.set_title(f"{label} (Corr={cc:.2f}, DJF)")

    plt.tight_layout()
    plt.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close()
    print("Saved =>", out_png)

for idx_name in index_list:
    # Observed / ERA5 columns for this index (same mapping as the maps);
    # names without a mapping come back as None and are skipped.
    obs_col, era5_col = get_map_cols(idx_name)[:2]
    if obs_col is None:
        continue

    if obs_col not in mdf.columns or era5_col not in mdf.columns:
        print(f"Skipping distribution for {idx_name} - missing columns.")
        continue

    # Keep only stations where both values are present
    subdf = mdf[[obs_col, era5_col]].dropna()
    if len(subdf)<2:
        print(f"Skipping distribution for {idx_name} - not enough data.")
        continue

    out_3panel = os.path.join(output_plots, f"DJF_{idx_name}_Distribution_3panel.png")
    plot_distribution_triptych(subdf, obs_col, era5_col, idx_name, out_3panel)

###############################################################################
# 6. DONE
###############################################################################
print("\nAll DJF steps completed! See outputs in:", output_plots)

# For the other seasons, just change any DJF to JJA, MAM, or SON

In [None]:
                 #########################                                           ##############################
                 #########################                  MERRA-2                  ##############################
                 #########################                                           ##############################

In [None]:
# Calculating the climatic indices for MERRA2 (prcp)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
csv_file       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\daily_loop\merra2_vs_stations_12Nearest_LWR_1991_2012.csv"
physical_file  = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"

shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\ClimaticIndices-12Nearest"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, merra2_lwr12_val) ...")
df_data = pd.read_csv(csv_file)
# Parse 'time' (unparseable entries become NaT) and normalise station names
# (trimmed, upper-case) so they match the station metadata below.
df_data = df_data.assign(
    time=pd.to_datetime(df_data["time"], errors="coerce"),
    station_name=df_data["station_name"].astype(str).str.strip().str.upper(),
)

print(f"df_data shape = {df_data.shape}")
print("Columns:", df_data.columns.tolist())
print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. (OPTIONAL) MERGE WITH PHYSICAL FILE TO GET LAT/LON
###############################################################################
df_phys = pd.read_csv(physical_file).rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# same station-name normalisation as for the daily data
df_phys = df_phys.assign(
    station_name=df_phys["station_name"].astype(str).str.strip().str.upper()
)

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Rx1day: the largest single value in the daily series (NaNs ignored)."""
    wettest_day = series.max(skipna=True)
    return wettest_day

def calc_rx5day(series):
    """Rx5day: the largest sum over 5 consecutive entries of the series.

    Windows at the start of the series are allowed to be shorter
    (min_periods=1), so a series with fewer than 5 entries still yields a value.
    """
    window_totals = series.rolling(window=5, min_periods=1).sum()
    return window_totals.max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """CDD: length of the longest run of consecutive entries < dry_threshold.

    Returns 0 when no entry is below the threshold (or the series is empty).
    """
    longest = streak = 0
    for dry in (series < dry_threshold):
        # extend the current run on a dry entry, otherwise start over
        streak = streak + 1 if dry else 0
        if streak > longest:
            longest = streak
    return longest

def calc_cwd(series, wet_threshold=1.0):
    """CWD: length of the longest run of consecutive entries >= wet_threshold.

    Returns 0 when no entry reaches the threshold (or the series is empty).
    """
    longest = streak = 0
    for wet in (series >= wet_threshold):
        # extend the current run on a wet entry, otherwise start over
        streak = streak + 1 if wet else 0
        if streak > longest:
            longest = streak
    return longest

def calc_r95p_r99p(series, percentile=(95,99)):
    """Very-wet-day totals and their share of the series total.

    Thresholds are the percentile[0]-th and percentile[1]-th percentiles of
    the wet days (entries >= 1 mm). For each threshold the function sums the
    wet-day values strictly above it and expresses that sum as a percentage
    of the sum of the whole series (NaN if that total is not positive).

    Returns (r95_amt, r95_pct, r99_amt, r99_pct); all NaN when there are
    fewer than 5 wet days.
    """
    wet_days = series[series >= 1.0]
    if len(wet_days) < 5:
        return np.nan, np.nan, np.nan, np.nan

    total = series.sum(skipna=True)

    def _tail(pctl):
        # amount above the pctl-th wet-day percentile, and its % of the total
        threshold = np.percentile(wet_days, pctl)
        amount = wet_days[wet_days > threshold].sum()
        share = (amount / total) * 100 if total > 0 else np.nan
        return amount, share

    r95_amt, r95_pct = _tail(percentile[0])
    r99_amt, r99_pct = _tail(percentile[1])
    return r95_amt, r95_pct, r99_amt, r99_pct

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """Return (number of entries >= wet_thr, number of entries < dry_thr)."""
    n_wet = (series >= wet_thr).sum()
    n_dry = (series < dry_thr).sum()
    return n_wet, n_dry

###############################################################################
# 5. COMPUTE INDICES FOR EACH STATION
###############################################################################
# One list of per-station records for each index table written later
rx1_list, rx5_list = [], []
cdd_list, cwd_list = [], []
r95_list, r99_list = [], []
wet_list, dry_list = [], []

print("Computing indices for each station...")

grouped = df_data.groupby("station_name", as_index=False)
for st_name, grp in grouped:
    # Chronological order (just in case the CSV is not sorted)
    grp = grp.sort_values("time")

    # daily obs / MERRA-2 values
    # NOTE(review): NaNs are dropped from each column independently and the
    # index is reset, so the two series can differ in length and the
    # run/rolling-window indices below treat entries on either side of a gap
    # as consecutive - confirm this is intended.
    obs_series = grp["obs"].dropna().reset_index(drop=True)
    merra2_series = grp["merra2_lwr12_val"].dropna().reset_index(drop=True)

    # A) Rx1day
    rx1_list.append({"station_name": st_name,
                     "obs_rx1day": calc_rx1day(obs_series),
                     "merra2_rx1day": calc_rx1day(merra2_series)})

    # B) Rx5day
    rx5_list.append({"station_name": st_name,
                     "obs_rx5day": calc_rx5day(obs_series),
                     "merra2_rx5day": calc_rx5day(merra2_series)})

    # C) CDD
    cdd_list.append({"station_name": st_name,
                     "obs_cdd": calc_cdd(obs_series),
                     "merra2_cdd": calc_cdd(merra2_series)})

    # D) CWD
    cwd_list.append({"station_name": st_name,
                     "obs_cwd": calc_cwd(obs_series),
                     "merra2_cwd": calc_cwd(merra2_series)})

    # E) R95 / R99 (amount and percentage for each)
    obs_r95a, obs_r95p, obs_r99a, obs_r99p = calc_r95p_r99p(obs_series)
    mer_r95a, mer_r95p, mer_r99a, mer_r99p = calc_r95p_r99p(merra2_series)
    r95_list.append({
        "station_name": st_name,
        "obs_r95amt": obs_r95a, "obs_r95pct": obs_r95p,
        "merra2_r95amt": mer_r95a, "merra2_r95pct": mer_r95p
    })
    r99_list.append({
        "station_name": st_name,
        "obs_r99amt": obs_r99a, "obs_r99pct": obs_r99p,
        "merra2_r99amt": mer_r99a, "merra2_r99pct": mer_r99p
    })

    # F) wet/dry day counts
    obs_wet5, obs_dry = calc_wetdays_drydays(obs_series)
    merra2_wet5, merra2_dry = calc_wetdays_drydays(merra2_series)
    wet_list.append({
        "station_name": st_name,
        "obs_wetdays5mm": obs_wet5,
        "merra2_wetdays5mm": merra2_wet5
    })
    dry_list.append({
        "station_name": st_name,
        "obs_drydays": obs_dry,
        "merra2_drydays": merra2_dry
    })

print("Finished computing. Now merging lat/lon from physical file ...")

def attach_coords(df_in):
    """Left-join lat, lon and elev from the module-level df_phys onto df_in by station_name."""
    station_meta = df_phys[["station_name","lat","lon","elev"]]
    return df_in.merge(station_meta, on="station_name", how="left")

# Turn each record list into a table with station coordinates attached
(df_rx1, df_rx5, df_cdd, df_cwd,
 df_r95, df_r99, df_wet, df_dry) = (
    attach_coords(pd.DataFrame(records))
    for records in (rx1_list, rx5_list, cdd_list, cwd_list,
                    r95_list, r99_list, wet_list, dry_list)
)

###############################################################################
# 6. SAVE OUTPUT
###############################################################################
print("Saving index tables to Excel in:", output_dir)
for file_name, table in (
    ("rx1day.xlsx",  df_rx1),
    ("rx5day.xlsx",  df_rx5),
    ("cdd.xlsx",     df_cdd),
    ("cwd.xlsx",     df_cwd),
    ("r95p.xlsx",    df_r95),
    ("r99p.xlsx",    df_r99),
    ("wetdays.xlsx", df_wet),
    ("drydays.xlsx", df_dry),
):
    table.to_excel(os.path.join(output_dir, file_name), index=False)

print("\nAll precipitation-based indices have been saved to Excel.")

###############################################################################
# (OPTIONAL) QUICK MAP EXAMPLE
###############################################################################
# Best-effort preview map: any failure is reported and the cell carries on.
try:
    print("\nQuick map example for obs_rx5day ...")
    # Basin and lake outlines in geographic lon/lat
    gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
    gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

    # Station points built from the Rx5day table's lon/lat columns
    station_points = gpd.points_from_xy(df_rx5["lon"], df_rx5["lat"])
    gdf_stations = gpd.GeoDataFrame(df_rx5, geometry=station_points, crs="EPSG:4326")

    fig, ax = plt.subplots(figsize=(10,8), subplot_kw={"projection": ccrs.PlateCarree()})
    for layer, colour in ((gdf_basin, 'blue'), (gdf_lakes, 'cyan')):
        for geom in layer.geometry:
            ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor=colour, linewidth=1)

    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=':')

    sc = ax.scatter(gdf_stations.geometry.x, gdf_stations.geometry.y,
                    c=gdf_stations["obs_rx5day"], cmap="Reds", s=60,
                    transform=ccrs.PlateCarree(), edgecolor="k")
    plt.colorbar(sc, ax=ax, label="Obs Rx5day (mm)")
    ax.set_extent([-95.5, -72, 38.5, 52.5])  # approximate bounding box
    gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
    gl.right_labels = False
    gl.top_labels   = False

    plt.title("Obs Rx5day (from CSV daily data)", fontsize=14)
    plt.show()

except Exception as e:
    print("Mapping step failed:", e)

print("\n✅ Done computing precipitation-based indices from 'merra2_lwr12_val' column!")


In [None]:
# Calculating the temperature (tmin-tmax) climatic indices for MERRA2

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
csv_file = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Temperature\daily_loop\merra2_vs_stations_12Nearest_LWR_1991_2012.csv"
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation_Temperature.csv"

shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Temperature\ClimaticIndices-12Nearest"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, merra2_lwr12_val) ...")
df_data = pd.read_csv(csv_file)
# Parse timestamps (unparseable entries become NaT) and normalise station
# names (trimmed, upper-case) so they match the station metadata below.
df_data = df_data.assign(
    time=pd.to_datetime(df_data["time"], errors="coerce"),
    station_name=df_data["station_name"].astype(str).str.strip().str.upper(),
)

print(f"df_data shape = {df_data.shape}")
print("Columns:", df_data.columns.tolist())
print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. MERGE WITH PHYSICAL FILE TO GET LAT/LON
###############################################################################
df_phys = pd.read_csv(physical_file).rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
df_phys = df_phys.assign(
    station_name=df_phys["station_name"].astype(str).str.strip().str.upper()
)

###############################################################################
# 4. HELPER FUNCTIONS – FULL ETCCDI SET FOR Tmin / Tmax
###############################################################################
def absolute_extremes(series, kind):
    """Return the series maximum when kind == "max", otherwise its minimum (NaNs ignored)."""
    reducer = series.max if kind == "max" else series.min
    return reducer(skipna=True)

# -------------------------------------------------------------------------
# percentile thresholds (5-day moving window, baseline = 1991-2012)
# -------------------------------------------------------------------------
BASE_START, BASE_END = "1991-01-01", "2012-12-31"

def _climatology_percentiles(s, p):
    """Return a Series (index = 1…366) of the p-th percentile."""
    # drop 29 Feb so every year has 365 days
    s = s[~((s.index.month == 2) & (s.index.day == 29))]
    df = pd.DataFrame({"val": s, "doy": s.index.dayofyear})
    climo = []
    for d in range(1, 366):
        win = list(range(d-2, d+3))                       # ±2-day window
        win = [(x-1) % 365 + 1 for x in win]              # wrap around ends
        vals = df.loc[df["doy"].isin(win), "val"]
        climo.append(np.nanpercentile(vals, p) if len(vals) else np.nan)
    return pd.Series(climo, index=range(1, 366), name=f"p{p}")

def percentile_flags(s, perc_series, side):
    """Boolean Series flagging where s is beyond its day-of-year threshold.

    The threshold for each timestamp is looked up in perc_series by the
    timestamp's day of year. side == "low" flags values strictly below the
    threshold; any other side flags values strictly above it.
    """
    thresholds = perc_series.reindex(s.index.dayofyear).values
    compare = s.lt if side == "low" else s.gt
    return compare(thresholds)

def spell_length(bool_series, min_run=6):
    """Total number of days that belong to runs of ≥ min_run consecutive Trues.

    Missing flags count as False.
    """
    flags = bool_series.fillna(False).values
    # Pad with a 0 on both sides so every run has a rising and a falling edge;
    # edges then alternate start, stop, start, stop, ...
    padded = np.concatenate(([0], flags, [0]))
    edges = np.flatnonzero(np.diff(padded))
    run_starts, run_stops = edges[::2], edges[1::2]
    run_lengths = run_stops - run_starts
    long_enough = run_lengths >= min_run
    return run_lengths[long_enough].sum()

# -------------------------------------------------------------------------
# absolute-threshold counters
# -------------------------------------------------------------------------
def count_threshold(series, op, thr):
    """Count entries strictly below thr when op == "<", otherwise strictly above thr."""
    hits = (series < thr) if op == "<" else (series > thr)
    return hits.sum()

###############################################################################
# 5.  ANNUAL ETCCDI INDICES  –  FORMAT COMPATIBLE WITH THE SEASONAL SCRIPT
#     • keeps:   TXx  TNn  TX90p  TN10p  FD  WSDI  CSDI
#     • drops:   TR, ID, SU, TXn, TNx
###############################################################################
rows = []

# helper for ratios (NaN when the observed value is 0 or missing, to avoid /0)
ratio = lambda o, e: np.nan if (o == 0 or np.isnan(o)) else e / o

print("→ computing *annual* indices …")
for st_name, st_grp in df_data.groupby("station_name"):

    st_grp = st_grp.set_index("time").sort_index()

    # build daily Series (gaps become NaN rows via asfreq) ...............................
    obs_max = st_grp.loc[st_grp["var"] == "tmax", "obs"].asfreq("D")
    obs_min = st_grp.loc[st_grp["var"] == "tmin", "obs"].asfreq("D")
    merra2_max = st_grp.loc[st_grp["var"] == "tmax", "merra2_lwr12_val"].asfreq("D")
    merra2_min = st_grp.loc[st_grp["var"] == "tmin", "merra2_lwr12_val"].asfreq("D")
    if obs_max.empty:          # station has no tmax rows at all
        continue

    # ── fixed-year climatology (1991-2012, 5-day moving window) ─────────────-
    p90_TX_obs = _climatology_percentiles(obs_max[BASE_START:BASE_END], 90)
    p10_TN_obs = _climatology_percentiles(obs_min[BASE_START:BASE_END], 10)
    p90_TX_merra2 = _climatology_percentiles(merra2_max[BASE_START:BASE_END], 90)
    p10_TN_merra2 = _climatology_percentiles(merra2_min[BASE_START:BASE_END], 10)

    # flags for the *full* record (faster than recomputing year-by-year);
    # days with a missing value or a missing threshold compare as False
    flags = {
        "obs_TX90": obs_max >
                     p90_TX_obs.reindex(obs_max.index.dayofyear).values,
        "merra2_TX90": merra2_max >
                     p90_TX_merra2.reindex(merra2_max.index.dayofyear).values,
        "obs_TN10": obs_min <
                     p10_TN_obs.reindex(obs_min.index.dayofyear).values,
        "merra2_TN10": merra2_min <
                     p10_TN_merra2.reindex(merra2_min.index.dayofyear).values,
    }

    # ── iterate over calendar years present in the tmax record ───────────────
    years = np.unique(obs_max.index.year)
    for yr in years:
        # NOTE(review): after asfreq("D") the index holds every calendar day
        # between the first and last record, so this counts calendar days in
        # the year, not non-missing observations - confirm that is intended.
        mask = obs_max.index.year == yr
        if mask.sum() < 200:         # at least ~55 % of a year
            continue

        def _sel(s, yr=yr):   # slice one calendar year using the series' OWN index
            return s[s.index.year == yr]

        # intensity ....................................................................
        TXx_obs = _sel(obs_max).max()
        TXx_merra2 = _sel(merra2_max).max()
        TNn_obs = _sel(obs_min).min()
        TNn_merra2 = _sel(merra2_min).min()

        # Year slices of the flag series. Each is sliced by its own index:
        # the tmin series can span a different date range than the tmax
        # series, so a boolean mask built from obs_max must not be applied
        # to the tmin-based flags.
        tx90_obs_yr    = _sel(flags["obs_TX90"])
        tx90_merra2_yr = _sel(flags["merra2_TX90"])
        tn10_obs_yr    = _sel(flags["obs_TN10"])
        tn10_merra2_yr = _sel(flags["merra2_TN10"])

        # percentile frequencies (percentage of days)
        TX90p_obs = tx90_obs_yr.mean() * 100.0
        TX90p_merra2 = tx90_merra2_yr.mean() * 100.0
        TN10p_obs = tn10_obs_yr.mean() * 100.0
        TN10p_merra2 = tn10_merra2_yr.mean() * 100.0

        # spell duration (≥6 consecutive days)
        WSDI_obs = spell_length(tx90_obs_yr, min_run=6)
        WSDI_merra2 = spell_length(tx90_merra2_yr, min_run=6)
        CSDI_obs = spell_length(tn10_obs_yr, min_run=6)
        CSDI_merra2 = spell_length(tn10_merra2_yr, min_run=6)

        # absolute‐threshold count: frost days (tmin < 0)
        FD_obs = (_sel(obs_min) < 0).sum()
        FD_merra2 = (_sel(merra2_min) < 0).sum()

        rows.append(dict(
            station_name=st_name, year=yr,
            TXx_obs=TXx_obs,   TXx_merra2=TXx_merra2,   TXx_ratio=ratio(TXx_obs, TXx_merra2),
            TNn_obs=TNn_obs,   TNn_merra2=TNn_merra2,   TNn_ratio=ratio(TNn_obs, TNn_merra2),
            TX90p_obs=TX90p_obs, TX90p_merra2=TX90p_merra2,
            TX90p_ratio=ratio(TX90p_obs, TX90p_merra2),
            TN10p_obs=TN10p_obs, TN10p_merra2=TN10p_merra2,
            TN10p_ratio=ratio(TN10p_obs, TN10p_merra2),
            FD_obs=FD_obs,     FD_merra2=FD_merra2,     FD_ratio=ratio(FD_obs, FD_merra2),
            WSDI_obs=WSDI_obs, WSDI_merra2=WSDI_merra2,
            WSDI_ratio=ratio(WSDI_obs, WSDI_merra2),
            CSDI_obs=CSDI_obs, CSDI_merra2=CSDI_merra2,
            CSDI_ratio=ratio(CSDI_obs, CSDI_merra2)
        ))

# Station-year table: attach coordinates/elevation, order by station then year
df_yr = pd.DataFrame(rows)
df_yr = df_yr.merge(df_phys[["station_name", "lat", "lon", "elev"]],
                    on="station_name", how="left")
df_yr = df_yr.sort_values(["station_name", "year"])

###############################################################################
# 6.  SAVE OUTPUT  – ONE PARQUET ( + XLSX )  +  OPTIONAL per-index sheets
###############################################################################
annual_pq  = os.path.join(output_dir, "Indices_Annual.parquet")
annual_xls = annual_pq.replace(".parquet", ".xlsx")   # same name, Excel extension

for writer, target in ((df_yr.to_parquet, annual_pq), (df_yr.to_excel, annual_xls)):
    writer(target, index=False)
print(f"✓ Annual indices saved → {annual_pq}")
print(f"✓ …and also saved as   → {annual_xls}")

# OPTIONAL: write one Excel file per index in the familiar “wide” format
# ---------------------------------------------------------------------------
index_roots = ["TXx", "TNn", "TX90p", "TN10p", "FD", "WSDI", "CSDI"]
def _wide(idx_root: str) -> pd.DataFrame:
    """One row per station for a single index, with coordinates attached.

    Aggregates the ``<idx_root>_obs`` / ``_merra2`` / ``_ratio`` columns of
    the module-level df_yr per station with pivot_table's default
    aggregation (the mean over years), then left-joins lat/lon/elev from
    df_phys.
    """
    value_cols = [f"{idx_root}_obs", f"{idx_root}_merra2", f"{idx_root}_ratio"]
    per_station = df_yr.pivot_table(index="station_name", values=value_cols)
    per_station = per_station.reset_index()
    station_meta = df_phys[["station_name", "lat", "lon", "elev"]]
    return per_station.merge(station_meta, on="station_name", how="left")

print("\n(optional) individual workbooks …")
# One workbook per index root, named after the index itself.
for idx in index_roots:
    workbook_name = f"{idx}.xlsx"
    fp = os.path.join(output_dir, workbook_name)
    _wide(idx).to_excel(fp, index=False)
    print("  •", workbook_name)

print("\n✅  Annual-index workflow finished.")

In [None]:
# Temporal stratification of the MERRA2 precipitation indices: monthly and seasonal (DJF/MAM/JJA/SON) values per station

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIGURATION
###############################################################################
csv_file = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\daily_loop\merra2_vs_stations_12Nearest_LWR_1991_2012.csv"
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\ClimaticIndices-Seasonal"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY DATA & ADD TEMPORAL FIELDS
###############################################################################
print("Loading daily CSV data ...")
df_data = pd.read_csv(csv_file).assign(
    # Unparseable timestamps are coerced to NaT rather than raising.
    time=lambda d: pd.to_datetime(d["time"], errors="coerce"),
    # Canonical station key: string, trimmed, upper-case.
    station_name=lambda d: d["station_name"].astype(str).str.strip().str.upper(),
    # Calendar month (1..12); the season label is derived from it further down.
    month=lambda d: d["time"].dt.month,
)

def get_season(month):
    """Map a calendar month number to its meteorological season label.

    Parameters
    ----------
    month : int or float
        Calendar month, 1..12. Float values such as ``1.0`` are accepted
        (the month column becomes float when it contains missing values).

    Returns
    -------
    str or None
        "DJF", "MAM", "JJA" or "SON" for a valid month; ``None`` for anything
        else (NaN, 0, 13, ...).

    Notes
    -----
    Timestamps that fail to parse are coerced to NaT when the data are loaded,
    which yields a NaN month. Such rows must not be counted as autumn, so the
    fall-through case returns ``None`` instead of "SON".
    """
    if month in (12, 1, 2):
        return "DJF"
    if month in (3, 4, 5):
        return "MAM"
    if month in (6, 7, 8):
        return "JJA"
    if month in (9, 10, 11):
        return "SON"
    # Not a valid month: leave the season undefined.
    return None

# Label every daily record with its season.
df_data["season"] = df_data["month"].map(get_season)

first_day, last_day = df_data["time"].min(), df_data["time"].max()
print("Time range:", first_day, "to", last_day)

###############################################################################
# 3. LOAD PHYSICAL FILE & MERGE COORDINATES
###############################################################################
# Station metadata, renamed to the short column names used in the outputs.
df_phys = pd.read_csv(physical_file).rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# Same key normalisation as the daily table so the later merges line up.
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Return the largest single-day value of ``series``, ignoring NaNs."""
    wettest_day = series.max(skipna=True)
    return wettest_day

def calc_rx5day(series):
    """Return the maximum of the 5-element running sum of ``series``.

    The window is positional (five consecutive entries of the series) and
    ``min_periods=1`` lets the leading, shorter windows take part as well.
    """
    window_totals = series.rolling(5, min_periods=1).sum()
    return window_totals.max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """Length of the longest unbroken run of entries below ``dry_threshold``.

    NaN entries compare as not-dry and therefore end a run. Returns 0 when no
    entry is dry.
    """
    longest = streak = 0
    for dry_day in (series < dry_threshold):
        # Extend the running streak on a dry day, otherwise start over.
        streak = streak + 1 if dry_day else 0
        if streak > longest:
            longest = streak
    return longest

def calc_cwd(series, wet_threshold=1.0):
    """Length of the longest unbroken run of entries at or above ``wet_threshold``.

    NaN entries compare as not-wet and therefore end a run. Returns 0 when no
    entry is wet.
    """
    longest = streak = 0
    for wet_day in (series >= wet_threshold):
        # Extend the running streak on a wet day, otherwise start over.
        streak = streak + 1 if wet_day else 0
        if streak > longest:
            longest = streak
    return longest

def calc_r95p_r99p(series, percentiles=(95,99)):
    """Precipitation falling on very wet days, as amount and as share of total.

    Wet days are the entries >= 1.0. The two thresholds are the
    ``percentiles[0]``-th and ``percentiles[1]``-th percentile of those wet
    days, taken from ``series`` itself.

    Returns
    -------
    tuple
        ``(r95amt, r95pct, r99amt, r99pct)`` where
        - ``r95amt`` / ``r99amt`` = sum of the wet-day values strictly above
          the respective threshold;
        - ``r95pct`` / ``r99pct`` = that amount as a percentage of the sum of
          the whole series (NaN when that sum is not positive).
        All four values are NaN when fewer than 5 wet days are available.
    """
    rainy = series[series >= 1.0]
    if len(rainy) < 5:
        return (np.nan,) * 4

    cut95 = np.percentile(rainy, percentiles[0])
    cut99 = np.percentile(rainy, percentiles[1])
    period_total = series.sum(skipna=True)

    def _share(amount):
        # Percentage of the period total; undefined for a non-positive total.
        return (amount / period_total) * 100 if period_total > 0 else np.nan

    amt95 = rainy[rainy > cut95].sum()
    amt99 = rainy[rainy > cut99].sum()
    return amt95, _share(amt95), amt99, _share(amt99)

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """Count the wet and the dry entries of ``series``.

    Returns ``(wetdays, drydays)`` with
    wetdays = number of entries >= ``wet_thr`` and
    drydays = number of entries <  ``dry_thr``.
    """
    n_wet = (series >= wet_thr).sum()
    n_dry = (series < dry_thr).sum()
    return n_wet, n_dry

###############################################################################
# 5. FUNCTION TO COMPUTE INDICES FOR A GROUP (MONTHLY or SEASONAL)
###############################################################################
def compute_indices(df_group):
    """
    Compute the precipitation indices for one subset of the daily table.

    Parameters
    ----------
    df_group : pd.DataFrame
        Daily rows of one group (e.g. station+month or station+season). Must
        contain the columns "obs" and "merra2_lwr12_val"; when a "time" column
        is present the rows are put in chronological order first, because the
        run-length (CDD/CWD) and rolling-sum (Rx5day) indices depend on it.

    Returns
    -------
    dict or None
        ``<index>_obs``, ``<index>_merra2`` and ``<index>_ratio`` entries for
        rx1day, rx5day, cdd, cwd, r95amt, r95pct, r99amt, r99pct, wetdays and
        drydays. ``None`` when either series has no non-missing value.
        Every ``_ratio`` key is always present; it is MERRA2/OBS, or NaN when
        the observed value is 0 (or NaN), so the resulting table always has
        the full set of ratio columns.

    Notes
    -----
    Missing values are dropped from each series independently, so the two
    series may differ in length.
    NOTE(review): because NaNs are dropped and a group may pool several years,
    neighbouring entries are not necessarily consecutive calendar days; CDD,
    CWD and Rx5day can therefore span gaps - confirm this is acceptable.
    """
    if "time" in df_group.columns:
        # Stable sort: rows sharing a timestamp keep their original order.
        df_group = df_group.sort_values("time", kind="mergesort")

    obs_series = df_group["obs"].dropna().reset_index(drop=True)
    merra2_series = df_group["merra2_lwr12_val"].dropna().reset_index(drop=True)
    if len(obs_series) == 0 or len(merra2_series) == 0:
        return None

    res = {}
    # Rx1day / Rx5day
    res["rx1day_obs"] = calc_rx1day(obs_series)
    res["rx1day_merra2"] = calc_rx1day(merra2_series)
    res["rx5day_obs"] = calc_rx5day(obs_series)
    res["rx5day_merra2"] = calc_rx5day(merra2_series)

    # CDD / CWD
    res["cdd_obs"] = calc_cdd(obs_series)
    res["cdd_merra2"] = calc_cdd(merra2_series)
    res["cwd_obs"] = calc_cwd(obs_series)
    res["cwd_merra2"] = calc_cwd(merra2_series)

    # R95 / R99 (each call returns amt95, pct95, amt99, pct99)
    r95_obs = calc_r95p_r99p(obs_series)
    r95_merra2 = calc_r95p_r99p(merra2_series)
    res["r95amt_obs"] = r95_obs[0]
    res["r95pct_obs"] = r95_obs[1]
    res["r95amt_merra2"] = r95_merra2[0]
    res["r95pct_merra2"] = r95_merra2[1]
    res["r99amt_obs"] = r95_obs[2]
    res["r99pct_obs"] = r95_obs[3]
    res["r99amt_merra2"] = r95_merra2[2]
    res["r99pct_merra2"] = r95_merra2[3]

    # Wet / Dry days
    wet_obs, dry_obs = calc_wetdays_drydays(obs_series)
    wet_merra2, dry_merra2 = calc_wetdays_drydays(merra2_series)
    res["wetdays_obs"] = wet_obs
    res["wetdays_merra2"] = wet_merra2
    res["drydays_obs"] = dry_obs
    res["drydays_merra2"] = dry_merra2

    # Ratio columns: MERRA2/OBS. A zero observed value gives NaN instead of a
    # missing key; a NaN observed value propagates to a NaN ratio.
    ratio_roots = ("rx1day", "rx5day", "cdd", "cwd",
                   "r95amt", "r95pct", "r99amt", "r99pct",
                   "wetdays", "drydays")
    for root in ratio_roots:
        obs_value = res[f"{root}_obs"]
        if obs_value:
            res[f"{root}_ratio"] = res[f"{root}_merra2"] / obs_value
        else:
            res[f"{root}_ratio"] = np.nan

    return res

###############################################################################
# 6. MONTHLY INDICES
###############################################################################
# One record per (station, calendar month); groups without usable data are skipped.
group_month = df_data.groupby(["station_name", "month"])
monthly_results = []
for (st_name, mon), group in group_month:
    indices = compute_indices(group)
    if indices is not None:
        # Group keys go last so the index columns lead the table.
        monthly_results.append({**indices, "station_name": st_name, "month": mon})

# Attach coordinates, order by station and month, and export.
df_monthly = (
    pd.DataFrame(monthly_results)
    .merge(df_phys[["station_name", "lat", "lon", "elev"]],
           on="station_name", how="left")
    .sort_values(["station_name", "month"])
)
monthly_out = os.path.join(output_dir, "Indices_Monthly.xlsx")
df_monthly.to_excel(monthly_out, index=False)
print("Monthly indices saved =>", monthly_out)

###############################################################################
# 7. SEASONAL INDICES
###############################################################################
# One record per (station, season); groups without usable data are skipped.
group_season = df_data.groupby(["station_name", "season"])
seasonal_results = []
for (st_name, seas), group in group_season:
    indices = compute_indices(group)
    if indices is not None:
        # Group keys go last so the index columns lead the table.
        seasonal_results.append({**indices, "station_name": st_name, "season": seas})

# Attach coordinates, order by station and season label, and export.
df_seasonal = (
    pd.DataFrame(seasonal_results)
    .merge(df_phys[["station_name", "lat", "lon", "elev"]],
           on="station_name", how="left")
    .sort_values(["station_name", "season"])
)
seasonal_out = os.path.join(output_dir, "Indices_Seasonal.xlsx")
df_seasonal.to_excel(seasonal_out, index=False)
print("Seasonal indices saved =>", seasonal_out)

###############################################################################
# 8. DONE
###############################################################################
print("\nAll monthly and seasonal indices have been saved. (No extreme-event stratification.)")


In [None]:
# DJF for MERRA2

import os
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import Point
from scipy.stats import pearsonr
import seaborn as sns

###############################################################################
# 1. CONFIG & PATHS
###############################################################################
# Folder holding the seasonal index workbook; the DJF outputs go to a sub-folder of it.
indices_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\ClimaticIndices-Seasonal"
seasonal_file = os.path.join(indices_dir, "Indices_Seasonal.xlsx")  # single file
output_plots  = os.path.join(indices_dir, "AnalysisPlots_DJF")
os.makedirs(output_plots, exist_ok=True)

# Basin outline and lake polygons drawn on every map.
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# Indices in your seasonal file (these names drive every loop in this cell)
index_list = ["rx1day","rx5day","cdd","cwd","r95p","r99p","wetdays","drydays"]

# For summary stats, define how to find the obs vs. MERRA2 columns.
# Each value is (obs, merra2); for r95p / r99p both members are themselves
# tuples (amount column, percentage column) that are paired position by position.
index_columns = {
    "rx1day":  ("rx1day_obs",  "rx1day_merra2"),
    "rx5day":  ("rx5day_obs",  "rx5day_merra2"),
    "cdd":     ("cdd_obs",     "cdd_merra2"),
    "cwd":     ("cwd_obs",     "cwd_merra2"),
    "r95p":    (("r95amt_obs","r95pct_obs"), ("r95amt_merra2","r95pct_merra2")),
    "r99p":    (("r99amt_obs","r99pct_obs"), ("r99amt_merra2","r99pct_merra2")),
    "wetdays": ("wetdays_obs","wetdays_merra2"),
    "drydays": ("drydays_obs","drydays_merra2"),
}

###############################################################################
# 2. LOAD SEASONAL FILE & FILTER TO DJF
###############################################################################
df_season = pd.read_excel(seasonal_file)
print("Loaded =>", seasonal_file, "| shape =", df_season.shape)

# Keep the DJF rows only, and only stations that have coordinates.
is_djf = df_season["season"].eq("DJF")
df_season = df_season.loc[is_djf].dropna(subset=["lat","lon"]).copy()
print("After filtering to DJF => shape =", df_season.shape)

# Master table used by all later steps; a copy is also written to disk.
mdf = df_season.reset_index(drop=True)
master_xlsx = os.path.join(output_plots, "MasterTable_Seasonal_DJF.xlsx")
mdf.to_excel(master_xlsx, index=False)
print(f"\n(A) Master table (DJF) saved => {master_xlsx}")
print("Columns:", list(mdf.columns))

###############################################################################
# 3. SUMMARY TABLE (MBE, RMSE, STD, CC, d) for DJF
###############################################################################
def index_of_agreement(obs, model):
    """Index of agreement d = 1 - sum((M-O)^2) / sum((|M-Obar| + |O-Obar|)^2).

    ``Obar`` is the mean of ``obs``. Returns NaN when the denominator is zero.
    """
    centre = np.mean(obs)
    potential_error = np.sum((abs(model - centre) + abs(obs - centre)) ** 2)
    if potential_error == 0:
        return np.nan
    squared_error = np.sum((model - obs) ** 2)
    return 1 - squared_error / potential_error

def rmse(a, b):
    """Root-mean-square difference between ``a`` and ``b``."""
    squared_diff = (a - b) ** 2
    return np.sqrt(np.mean(squared_diff))

def std_of_residuals(a, b):
    """Sample standard deviation (ddof=1) of the residuals ``a - b``."""
    residuals = a - b
    return np.std(residuals, ddof=1)

def mean_bias_error(a, b):
    """Mean of ``b - a`` (positive when ``b`` is larger than ``a`` on average)."""
    bias = b - a
    return np.mean(bias)

# Build one summary row (Count, MBE, RMSE, STDres, CC, d) per index column pair.
# The column list is fixed up front so the table has the same layout even
# when no index yields a row (previously that case raised a KeyError).
summary_cols = ["Index","Count","MBE","RMSE","STDres","CC","d"]
summary_rows = []
for idx_name in index_list:
    obs_cols, merra2_cols = index_columns[idx_name]

    # Normalise both layouts of index_columns to a list of
    # (row label, obs column, merra2 column).
    if isinstance(obs_cols, tuple):
        # multiple columns (amount + percentage): label e.g. "r95p_r95amt"
        col_pairs = [(f"{idx_name}_{oc.replace('_obs','')}", oc, ec)
                     for oc, ec in zip(obs_cols, merra2_cols)]
    else:
        col_pairs = [(idx_name, obs_cols, merra2_cols)]

    for idx_label, oc, ec in col_pairs:
        # Only stations where both values are present take part.
        valid = mdf[[oc, ec]].dropna()
        if len(valid) < 2:
            # Correlation and the sample std need at least two points.
            continue
        obs_vals = valid[oc].values
        merra2_vals = valid[ec].values
        summary_rows.append({
            "Index": idx_label,
            "Count": len(valid),
            "MBE": mean_bias_error(obs_vals, merra2_vals),
            "RMSE": rmse(obs_vals, merra2_vals),
            "STDres": std_of_residuals(obs_vals, merra2_vals),
            "CC": pearsonr(obs_vals, merra2_vals)[0],
            "d": index_of_agreement(obs_vals, merra2_vals),
        })

summary_df = pd.DataFrame(summary_rows, columns=summary_cols)
summary_xlsx = os.path.join(output_plots, "SummaryTable_Extremes_DJF.xlsx")
summary_df.to_excel(summary_xlsx, index=False)
print(f"(B) Summary Table (DJF) => {summary_xlsx}\n{summary_df}")

###############################################################################
# 4. MAPPING: Combine Observed, MERRA2, Ratio in One Figure
###############################################################################
# Basin outline and lake polygons, both reprojected to EPSG:4326.
gdf_basin = gpd.read_file(shapefile_path)
gdf_basin = gdf_basin.to_crs(epsg=4326)
gdf_lakes = gpd.read_file(lakes_shp)
gdf_lakes = gdf_lakes.to_crs(epsg=4326)

def add_basin_lakes(ax):
    """Draw national borders, the basin outline (black) and the lakes (cyan) on ``ax``.

    The coastline feature is deliberately not added. The module-level
    ``gdf_basin`` and ``gdf_lakes`` are read when the function is called.
    """
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    outline_layers = ((gdf_basin, 'black'), (gdf_lakes, 'cyan'))
    for layer, colour in outline_layers:
        for shape in layer.geometry:
            ax.add_geometries([shape], ccrs.PlateCarree(),
                              facecolor='none', edgecolor=colour, linewidth=1)

def plot_map_triptych(df, obs_col, merra2_col, ratio_col, idx_name, out_png):
    """
    Save one figure with 3 side-by-side station maps:
      1) Observed
      2) MERRA2
      3) Ratio (MERRA2/OBS)

    Parameters
    ----------
    df : pd.DataFrame
        Station table with "lon", "lat" and the three value columns.
    obs_col, merra2_col, ratio_col : str
        Columns coloured in panels 1, 2 and 3.
    idx_name : str
        Index name used in the panel titles (the titles are labelled "DJF").
    out_png : str
        Path of the PNG written at 300 dpi.

    Each panel has its own colorbar (so colour scales are not shared between
    panels) and red rings around the stations at or above the panel's own
    90th percentile. The legend is only drawn when such rings exist, which
    avoids an empty-legend warning for an all-NaN column.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6),
                             subplot_kw={"projection": ccrs.PlateCarree()})

    def scatter_map(ax, value_col, title):
        """Draw one panel: basemap, coloured stations, hotspot rings, grid."""
        ax.set_extent([-95.5, -72, 38.5, 52.5])
        add_basin_lakes(ax)
        sc = ax.scatter(df["lon"], df["lat"], c=df[value_col], cmap="viridis",
                        s=60, transform=ccrs.PlateCarree(), edgecolor="k", zorder=10)
        cb = fig.colorbar(sc, ax=ax, shrink=0.8)
        cb.set_label(value_col)

        # Hotspots => top 10% of the non-missing values of this panel
        vals = df[value_col].dropna().values
        has_hotspots = len(vals) > 0
        if has_hotspots:
            thr = np.percentile(vals, 90)
            is_hot = df[value_col] >= thr
            ax.scatter(df.loc[is_hot, "lon"], df.loc[is_hot, "lat"],
                       marker='o', facecolors='none', edgecolors='red', s=80,
                       transform=ccrs.PlateCarree(), zorder=11,
                       label=f"Hotspot >= {thr:.2f}")
        ax.set_title(title, fontsize=12)
        gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
        gl.right_labels = False
        gl.top_labels   = False
        if has_hotspots:
            # Only the hotspot rings carry a label.
            ax.legend(loc='upper right')

    scatter_map(axes[0], obs_col,  f"{idx_name} Observed (DJF)")
    scatter_map(axes[1], merra2_col,  f"{idx_name} MERRA2 (DJF)")
    scatter_map(axes[2], ratio_col,f"{idx_name} (MERRA2/OBS) (DJF)")

    # Act on this figure explicitly instead of pyplot's "current figure".
    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved =>", out_png)

def get_map_cols(idx_name):
    """Return the ``(obs, merra2, ratio)`` column names mapped for ``idx_name``.

    Most indices use their own name as column stem; "r95p" and "r99p" are
    mapped through their amount columns ("r95amt" / "r99amt"). An unknown
    index gives ``(None, None, None)``.
    """
    same_stem = ("rx1day", "rx5day", "cdd", "cwd", "drydays", "wetdays")
    if idx_name in same_stem:
        stem = idx_name
    elif idx_name == "r95p":
        stem = "r95amt"
    elif idx_name == "r99p":
        stem = "r99amt"
    else:
        return None, None, None
    return f"{stem}_obs", f"{stem}_merra2", f"{stem}_ratio"

# One 3-panel map per index, skipping indices whose columns are not in the table.
for idx_name in index_list:
    obs_col, merra2_col, ratio_col = get_map_cols(idx_name)
    if obs_col is None:
        continue

    needed_cols = [obs_col, merra2_col, ratio_col, "lat", "lon"]
    if any(col not in mdf.columns for col in needed_cols):
        print(f"Skipping map for {idx_name} - missing columns.")
        continue

    out_png = os.path.join(output_plots, f"DJF_{idx_name}_MAP_3panel.png")
    subdf = mdf.dropna(subset=["lat","lon"]).copy()
    plot_map_triptych(subdf, obs_col, merra2_col, ratio_col, idx_name, out_png)

###############################################################################
# 5. DISTRIBUTION & BOX/CDF/Scatter in One Figure
###############################################################################
def plot_distribution_triptych(df, obs_col, merra2_col, label, out_png):
    """
    Save one figure comparing the two columns in 3 side-by-side panels:
      1) Boxplot of both datasets
      2) Empirical CDF of both datasets
      3) Scatter of MERRA2 against Obs with the 1:1 line and the correlation

    ``label`` names the quantity on axes and titles (titles are labelled
    "DJF"); the figure is written to ``out_png`` at 300 dpi. The CDF and
    scatter panels only show a "not enough data" title when fewer than two
    values are available.
    """
    fig, (box_ax, cdf_ax, scatter_ax) = plt.subplots(nrows=1, ncols=3, figsize=(18,6))

    # A) Boxplot - long format: one row per (dataset, value)
    long_form = pd.melt(
        pd.DataFrame({"Obs": df[obs_col], "MERRA2": df[merra2_col]}),
        var_name="Dataset", value_name=label,
    )
    sns.boxplot(data=long_form, x="Dataset", y=label, ax=box_ax)
    box_ax.set_title(f"Boxplot: {label} (DJF)")

    # B) CDF
    def _empirical_cdf(sample):
        # Sorted values against their cumulative fraction k/n, k = 1..n.
        ordered = np.sort(sample)
        size = len(ordered)
        return ordered, np.arange(1, size + 1) / size

    obs_sample = df[obs_col].dropna()
    merra2_sample = df[merra2_col].dropna()
    if len(obs_sample) < 2 or len(merra2_sample) < 2:
        cdf_ax.set_title(f"CDF: not enough data ({label})")
    else:
        for sample, name in ((obs_sample, "Obs"), (merra2_sample, "MERRA2")):
            xs, ys = _empirical_cdf(sample)
            cdf_ax.plot(xs, ys, label=name)
        cdf_ax.set_title(f"CDF of {label} (DJF)")
        cdf_ax.set_xlabel(label)
        cdf_ax.set_ylabel("Probability")
        cdf_ax.legend()

    # C) Scatter - only rows where both values are present
    paired = df[[obs_col, merra2_col]].dropna()
    if len(paired) < 2:
        scatter_ax.set_title(f"Scatter: not enough data ({label})")
    else:
        x_obs = paired[obs_col]
        y_mod = paired[merra2_col]
        corr = pearsonr(x_obs, y_mod)[0]
        scatter_ax.scatter(x_obs, y_mod, edgecolors='k', alpha=0.7)
        # 1:1 reference line across the joint range of both axes
        lo = np.nanmin([x_obs.min(), y_mod.min()])
        hi = np.nanmax([x_obs.max(), y_mod.max()])
        scatter_ax.plot([lo, hi],[lo, hi],'r--')
        scatter_ax.set_xlabel(f"Obs {label} (DJF)")
        scatter_ax.set_ylabel(f"MERRA2 {label} (DJF)")
        scatter_ax.set_title(f"{label} (Corr={corr:.2f}, DJF)")

    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved =>", out_png)

for idx_name in index_list:
    # Same obs / MERRA2 column choice as for the maps (the ratio column is not needed here).
    obs_col, merra2_col = get_map_cols(idx_name)[:2]
    if obs_col is None:
        continue

    if not {obs_col, merra2_col}.issubset(mdf.columns):
        print(f"Skipping distribution for {idx_name} - missing columns.")
        continue

    # Stations with both values present.
    subdf = mdf[[obs_col, merra2_col]].dropna()
    if len(subdf) < 2:
        print(f"Skipping distribution for {idx_name} - not enough data.")
        continue

    out_3panel = os.path.join(output_plots, f"DJF_{idx_name}_Distribution_3panel.png")
    plot_distribution_triptych(subdf, obs_col, merra2_col, idx_name, out_3panel)

###############################################################################
# 6. DONE
###############################################################################
print("\nAll DJF steps completed! See outputs in:", output_plots)

# To process another season, replace every "DJF" in this cell with JJA, MAM, or SON.

In [None]:
                 #########################                                           ##############################
                 #########################                  CHIRPS                   ##############################
                 #########################                                           ##############################

In [None]:
# Calculating the climatic indices for CHIRPS 

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIG
###############################################################################
csv_file = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\CHIRPS_GLB_Precipitation\Masked to GLB\daily_loop\chirps_vs_stations_12Nearest_LWR_1991_2012.csv"
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"

shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\CHIRPS_GLB_Precipitation\Masked to GLB\ClimaticIndices-12Nearest"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY CSV DATA
###############################################################################
print("Loading daily CSV data (obs, chirps_lwr12_val) ...")
df_data = pd.read_csv(csv_file).assign(
    # Unparseable timestamps are coerced to NaT rather than raising.
    time=lambda d: pd.to_datetime(d["time"], errors="coerce"),
    # Canonical station key: string, trimmed, upper-case.
    station_name=lambda d: d["station_name"].astype(str).str.strip().str.upper(),
)

print(f"df_data shape = {df_data.shape}")
print("Columns:", list(df_data.columns))
first_day, last_day = df_data["time"].min(), df_data["time"].max()
print("Time range:", first_day, "to", last_day)

###############################################################################
# 3. (OPTIONAL) MERGE WITH PHYSICAL FILE TO GET LAT/LON
###############################################################################
# Station metadata, renamed to the short column names used in the outputs.
df_phys = pd.read_csv(physical_file).rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# Same key normalisation as the daily table so the later merges line up.
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Largest single-day value in ``series`` (NaNs are skipped)."""
    daily_peak = series.max(skipna=True)
    return daily_peak

def calc_rx5day(series):
    """Maximum of the 5-element running sum of ``series``.

    The window is positional and ``min_periods=1`` lets the leading, shorter
    windows take part as well.
    """
    return series.rolling(5, min_periods=1).sum().max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """Longest unbroken run of entries below ``dry_threshold`` (0 if none).

    NaN entries compare as not-dry and therefore break a run.
    """
    best = run = 0
    for dry_day in (series < dry_threshold):
        if not dry_day:
            # A non-dry entry ends the current spell.
            run = 0
            continue
        run += 1
        if run > best:
            best = run
    return best

def calc_cwd(series, wet_threshold=1.0):
    """Longest unbroken run of entries at or above ``wet_threshold`` (0 if none).

    NaN entries compare as not-wet and therefore break a run.
    """
    best = run = 0
    for wet_day in (series >= wet_threshold):
        if not wet_day:
            # A non-wet entry ends the current spell.
            run = 0
            continue
        run += 1
        if run > best:
            best = run
    return best

def calc_r95p_r99p(series, percentile=(95,99)):
    """Very-wet-day precipitation as amount (mm) and as share of the total.

    Wet days are the entries >= 1.0; the thresholds are the
    ``percentile[0]``-th and ``percentile[1]``-th percentile of those wet days.
    Returns ``(r95_amt, r95_pct, r99_amt, r99_pct)``: the sum of wet-day values
    strictly above each threshold and that sum as a percentage of the total
    of ``series`` (NaN for a non-positive total). All four are NaN when there
    are fewer than 5 wet days.
    """
    wet_days = series[series >= 1.0]
    if len(wet_days) < 5:
        return (np.nan,) * 4

    threshold_95 = np.percentile(wet_days, percentile[0])
    threshold_99 = np.percentile(wet_days, percentile[1])
    grand_total = series.sum(skipna=True)

    results = []
    for threshold in (threshold_95, threshold_99):
        amount = wet_days[wet_days > threshold].sum()
        share = (amount / grand_total) * 100 if grand_total > 0 else np.nan
        results.extend((amount, share))
    return tuple(results)

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """Return ``(wet, dry)``: entries >= ``wet_thr`` (default 5 mm) and entries < ``dry_thr`` (default 1 mm)."""
    return (series >= wet_thr).sum(), (series < dry_thr).sum()

###############################################################################
# 5. COMPUTE INDICES FOR EACH STATION
###############################################################################
# One list of per-station records for each index family.
rx1_list, rx5_list = [], []
cdd_list, cwd_list = [], []
r95_list, r99_list = [], []
wet_list, dry_list = [], []

print("Computing indices for each station...")

grouped = df_data.groupby("station_name", as_index=False)
for st_name, grp in grouped:
    # Chronological order matters for the run-length and rolling-sum indices.
    grp = grp.sort_values("time")

    # Observed and CHIRPS (LWR-interpolated) daily series; missing values are
    # dropped from each series independently.
    obs_series = grp["obs"].dropna().reset_index(drop=True)
    chirps_series = grp["chirps_lwr12_val"].dropna().reset_index(drop=True)

    # A) Rx1day
    rx1_list.append(dict(
        station_name=st_name,
        obs_rx1day=calc_rx1day(obs_series),
        chirps_rx1day=calc_rx1day(chirps_series),
    ))

    # B) Rx5day
    rx5_list.append(dict(
        station_name=st_name,
        obs_rx5day=calc_rx5day(obs_series),
        chirps_rx5day=calc_rx5day(chirps_series),
    ))

    # C) CDD
    cdd_list.append(dict(
        station_name=st_name,
        obs_cdd=calc_cdd(obs_series),
        chirps_cdd=calc_cdd(chirps_series),
    ))

    # D) CWD
    cwd_list.append(dict(
        station_name=st_name,
        obs_cwd=calc_cwd(obs_series),
        chirps_cwd=calc_cwd(chirps_series),
    ))

    # E) R95 / R99 - each call returns (amt95, pct95, amt99, pct99)
    obs_extremes = calc_r95p_r99p(obs_series)
    chirps_extremes = calc_r95p_r99p(chirps_series)
    r95_list.append(dict(
        station_name=st_name,
        obs_r95amt=obs_extremes[0],
        obs_r95pct=obs_extremes[1],
        chirps_r95amt=chirps_extremes[0],
        chirps_r95pct=chirps_extremes[1],
    ))
    r99_list.append(dict(
        station_name=st_name,
        obs_r99amt=obs_extremes[2],
        obs_r99pct=obs_extremes[3],
        chirps_r99amt=chirps_extremes[2],
        chirps_r99pct=chirps_extremes[3],
    ))

    # F) Wet/Dry days - each call returns (wet-day count, dry-day count)
    obs_counts = calc_wetdays_drydays(obs_series)
    chirps_counts = calc_wetdays_drydays(chirps_series)
    wet_list.append(dict(
        station_name=st_name,
        obs_wetdays5mm=obs_counts[0],
        chirps_wetdays5mm=chirps_counts[0],
    ))
    dry_list.append(dict(
        station_name=st_name,
        obs_drydays=obs_counts[1],
        chirps_drydays=chirps_counts[1],
    ))

print("Finished computing. Now merging lat/lon from physical file ...")

def attach_coords(df_in):
    """Left-join ``lat``, ``lon`` and ``elev`` from ``df_phys`` onto ``df_in`` by ``station_name``."""
    station_meta = df_phys[["station_name", "lat", "lon", "elev"]]
    return df_in.merge(station_meta, on="station_name", how="left")

# Turn every record list into a table with station coordinates attached.
(df_rx1, df_rx5, df_cdd, df_cwd,
 df_r95, df_r99, df_wet, df_dry) = (
    attach_coords(pd.DataFrame(records))
    for records in (rx1_list, rx5_list, cdd_list, cwd_list,
                    r95_list, r99_list, wet_list, dry_list)
)

###############################################################################
# 6. SAVE OUTPUT
###############################################################################
print("Saving index tables to Excel in:", output_dir)
# One workbook per index, written in this order.
workbooks = (
    ("rx1day.xlsx", df_rx1),
    ("rx5day.xlsx", df_rx5),
    ("cdd.xlsx", df_cdd),
    ("cwd.xlsx", df_cwd),
    ("r95p.xlsx", df_r95),
    ("r99p.xlsx", df_r99),
    ("wetdays.xlsx", df_wet),
    ("drydays.xlsx", df_dry),
)
for file_name, table in workbooks:
    table.to_excel(os.path.join(output_dir, file_name), index=False)

print("\nAll precipitation-based indices have been saved to Excel.")

###############################################################################
# (OPTIONAL) QUICK MAP EXAMPLE
###############################################################################
# Best-effort preview map of the observed Rx5day per station. Any failure in
# this block (shapefiles, projection, plotting) is reported and then ignored,
# because the index tables above have already been written.
try:
    print("\nQuick map example for obs_rx5day ...")
    # Load shapefiles (basin outline and lakes) and reproject to EPSG:4326
    gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
    gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

    # Station points built from the lon/lat columns of the Rx5day table
    gdf_stations = gpd.GeoDataFrame(
        df_rx5,
        geometry=gpd.points_from_xy(df_rx5["lon"], df_rx5["lat"]),
        crs="EPSG:4326"
    )

    fig, ax = plt.subplots(figsize=(10,8), subplot_kw={"projection": ccrs.PlateCarree()})
    # Basin outline in blue, lakes in cyan, both unfilled
    for geom in gdf_basin.geometry:
        ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='blue', linewidth=1)
    for geom in gdf_lakes.geometry:
        ax.add_geometries([geom], ccrs.PlateCarree(), facecolor='none', edgecolor='cyan', linewidth=1)

    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=':')

    # Stations coloured by their observed Rx5day value
    sc = ax.scatter(gdf_stations.geometry.x, gdf_stations.geometry.y,
                    c=gdf_stations["obs_rx5day"], cmap="Reds", s=60,
                    transform=ccrs.PlateCarree(), edgecolor="k")
    plt.colorbar(sc, ax=ax, label="Obs Rx5day (mm)")
    # Map window as [lon_min, lon_max, lat_min, lat_max]
    ax.set_extent([-95.5, -72, 38.5, 52.5])
    gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
    gl.right_labels = False
    gl.top_labels   = False

    # plt.title acts on the current axes, i.e. the map created above
    plt.title("Obs Rx5day (from CSV daily data)", fontsize=14)
    plt.show()

except Exception as e:
    # Deliberately broad: the preview must never stop the workflow.
    print("Mapping step failed:", e)

print("\n✅ Done computing precipitation-based indices from 'chirps_lwr12_val' column!")


In [None]:
# Temporal stratification of the CHIRPS precipitation indices: monthly and seasonal (DJF/MAM/JJA/SON) values per station

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

###############################################################################
# 1. FILE PATHS & CONFIGURATION
###############################################################################
csv_file      = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\CHIRPS_GLB_Precipitation\Masked to GLB\daily_loop\chirps_vs_stations_25km_LWR_1991_2012.csv"
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
output_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\CHIRPS_GLB_Precipitation\Masked to GLB\ClimaticIndices-Seasonal"
os.makedirs(output_dir, exist_ok=True)

###############################################################################
# 2. LOAD DAILY DATA & ADD TEMPORAL FIELDS
###############################################################################
print("Loading daily CSV data ...")
df_data = pd.read_csv(csv_file)
# errors="coerce" turns unparseable timestamps into NaT instead of raising,
# so such rows end up with a missing (NaN) month below.
df_data["time"] = pd.to_datetime(df_data["time"], errors="coerce")

# Standardize station_name
# (strip + upper-case; the physical file's names get the same treatment before merging)
df_data["station_name"] = df_data["station_name"].astype(str).str.strip().str.upper()

# Add month (1..12) and season (DJF, MAM, JJA, SON)
df_data["month"] = df_data["time"].dt.month

def get_season(month):
    """
    Map a calendar month number to its meteorological season label.

    Parameters
    ----------
    month : int or float
        Month number in 1..12. Float values such as 1.0 are accepted
        (a month column that contains NaN is stored as float).

    Returns
    -------
    str or float
        "DJF", "MAM", "JJA" or "SON". Returns np.nan when `month` is missing
        (NaN / None) or not a valid month number.
    """
    if month in (12, 1, 2):
        return "DJF"
    if month in (3, 4, 5):
        return "MAM"
    if month in (6, 7, 8):
        return "JJA"
    if month in (9, 10, 11):
        return "SON"
    # Rows whose timestamp failed to parse have month = NaN. They must not be
    # pooled into autumn by a catch-all branch; returning NaN lets a later
    # groupby on "season" drop them instead.
    return np.nan

df_data["season"] = df_data["month"].apply(get_season)

print("Time range:", df_data["time"].min(), "to", df_data["time"].max())

###############################################################################
# 3. LOAD PHYSICAL FILE & MERGE COORDINATES
###############################################################################
# Station metadata table; its columns are renamed to the short names
# ("station_name", "lat", "lon", "elev") that the merges further down select.
df_phys = pd.read_csv(physical_file)
df_phys = df_phys.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})
# Same normalisation as df_data["station_name"] so the merge keys match.
df_phys["station_name"] = df_phys["station_name"].astype(str).str.strip().str.upper()

###############################################################################
# 4. HELPER FUNCTIONS FOR CLIMATE INDICES
###############################################################################
def calc_rx1day(series):
    """Rx1day: the largest single-day precipitation value, NaNs ignored."""
    wettest_day = series.max(skipna=True)
    return wettest_day

def calc_rx5day(series):
    """
    Rx5day: the largest precipitation total over any window of up to
    5 consecutive rows (windows at the start of the series may be shorter).
    """
    window_totals = series.rolling(window=5, min_periods=1).sum()
    return window_totals.max(skipna=True)

def calc_cdd(series, dry_threshold=1.0):
    """
    CDD: length of the longest streak of consecutive rows whose value is
    strictly below `dry_threshold`. Returns 0 when no row qualifies.
    """
    longest = 0
    streak = 0
    for dry in (series < dry_threshold):
        # A non-dry row (including NaN, which compares False) resets the streak.
        streak = streak + 1 if dry else 0
        if streak > longest:
            longest = streak
    return longest

def calc_cwd(series, wet_threshold=1.0):
    """
    CWD: length of the longest streak of consecutive rows whose value is
    at or above `wet_threshold`. Returns 0 when no row qualifies.
    """
    longest = 0
    streak = 0
    for wet in (series >= wet_threshold):
        # A non-wet row (including NaN, which compares False) resets the streak.
        streak = streak + 1 if wet else 0
        if streak > longest:
            longest = streak
    return longest

def calc_r95p_r99p(series, percentiles=(95,99)):
    """
    Heavy-precipitation amounts and their share of the total.

    Wet days are rows with value >= 1.0. For each of the two percentiles
    (computed over the wet days only) the amount is the sum of wet-day values
    strictly above that percentile, and the percentage is that amount divided
    by the sum of the whole series, times 100.

    Returns
    -------
    tuple
        (r95amt, r95pct, r99amt, r99pct). All four are NaN when there are
        fewer than 5 wet days; the two percentages are NaN when the series
        total is not positive.
    """
    wet_days = series[series >= 1.0]
    if len(wet_days) < 5:
        return (np.nan,) * 4

    def _amount_above(q):
        # Sum of wet-day precipitation exceeding the q-th wet-day percentile.
        cutoff = np.percentile(wet_days, q)
        return wet_days[wet_days > cutoff].sum()

    amt_lo = _amount_above(percentiles[0])
    amt_hi = _amount_above(percentiles[1])

    grand_total = series.sum(skipna=True)
    if grand_total > 0:
        pct_lo = (amt_lo / grand_total) * 100
        pct_hi = (amt_hi / grand_total) * 100
    else:
        pct_lo = pct_hi = np.nan
    return amt_lo, pct_lo, amt_hi, pct_hi

def calc_wetdays_drydays(series, wet_thr=5.0, dry_thr=1.0):
    """
    Count rows at or above `wet_thr` and rows strictly below `dry_thr`.

    Returns
    -------
    tuple
        (wetdays, drydays)
    """
    n_wet = (series >= wet_thr).sum()
    n_dry = (series < dry_thr).sum()
    return n_wet, n_dry

###############################################################################
# 5. FUNCTION TO COMPUTE INDICES FOR A GROUP (MONTHLY or SEASONAL)
###############################################################################
def compute_indices(df_group):
    """
    Compute the precipitation indices for one subset of daily rows
    (e.g. one station+month or one station+season) for observations ("obs")
    and CHIRPS ("chirps_val"), plus CHIRPS/OBS ratio columns.

    Parameters
    ----------
    df_group : pandas.DataFrame
        Must contain the columns "obs" and "chirps_val". If a "time" column
        is present the rows are put in chronological order first.

    Returns
    -------
    dict or None
        None when either series has no non-missing value. Otherwise a dict
        with, for every index in (rx1day, rx5day, cdd, cwd, r95amt, r95pct,
        r99amt, r99pct, wetdays, drydays), the keys "<index>_obs",
        "<index>_chirps" and "<index>_ratio". A ratio is NaN when the
        observed value is 0 (ratio undefined) or when either side is NaN.

    Notes
    -----
    - "obs" and "chirps_val" are NaN-filtered independently, so the two
      series may cover different days.
    - The run-length (CDD/CWD) and rolling (Rx5day) indices treat the rows
      as consecutive; if the caller pools several years into one group,
      runs and windows can span the gaps between years.
    """
    # CDD / CWD / Rx5day depend on row order, so make the order explicit
    # instead of relying on how the CSV happened to be written.
    if "time" in df_group.columns:
        df_group = df_group.sort_values("time", kind="mergesort")

    obs_series = df_group["obs"].dropna().reset_index(drop=True)
    chirps_series = df_group["chirps_val"].dropna().reset_index(drop=True)
    if len(obs_series) == 0 or len(chirps_series) == 0:
        return None

    res = {}
    # Rx1day / Rx5day
    res["rx1day_obs"] = calc_rx1day(obs_series)
    res["rx1day_chirps"] = calc_rx1day(chirps_series)
    res["rx5day_obs"] = calc_rx5day(obs_series)
    res["rx5day_chirps"] = calc_rx5day(chirps_series)

    # CDD / CWD
    res["cdd_obs"] = calc_cdd(obs_series)
    res["cdd_chirps"] = calc_cdd(chirps_series)
    res["cwd_obs"] = calc_cwd(obs_series)
    res["cwd_chirps"] = calc_cwd(chirps_series)

    # R95 / R99: each call returns (r95amt, r95pct, r99amt, r99pct)
    heavy_obs = calc_r95p_r99p(obs_series)
    heavy_chirps = calc_r95p_r99p(chirps_series)
    res["r95amt_obs"] = heavy_obs[0]
    res["r95pct_obs"] = heavy_obs[1]
    res["r95amt_chirps"] = heavy_chirps[0]
    res["r95pct_chirps"] = heavy_chirps[1]
    res["r99amt_obs"] = heavy_obs[2]
    res["r99pct_obs"] = heavy_obs[3]
    res["r99amt_chirps"] = heavy_chirps[2]
    res["r99pct_chirps"] = heavy_chirps[3]

    # Wet / Dry days
    wet_obs, dry_obs = calc_wetdays_drydays(obs_series)
    wet_chirps, dry_chirps = calc_wetdays_drydays(chirps_series)
    res["wetdays_obs"] = wet_obs
    res["wetdays_chirps"] = wet_chirps
    res["drydays_obs"] = dry_obs
    res["drydays_chirps"] = dry_chirps

    # Ratio columns: CHIRPS/OBS.
    # The key is always written (NaN when obs == 0) so every record has the
    # same schema and no ratio column can silently disappear from the table.
    ratio_bases = (
        "rx1day", "rx5day", "cdd", "cwd",
        "r95amt", "r95pct", "r99amt", "r99pct",
        "wetdays", "drydays",
    )
    for base in ratio_bases:
        obs_val = res[f"{base}_obs"]
        if obs_val != 0:
            # NaN != 0 is True, so a NaN on either side propagates as NaN.
            res[f"{base}_ratio"] = res[f"{base}_chirps"] / obs_val
        else:
            res[f"{base}_ratio"] = np.nan

    return res

###############################################################################
# 6. MONTHLY INDICES
###############################################################################
# Helpers shared by the monthly (section 6) and seasonal (section 7) tables.
def _collect_group_indices(grouped, period_col):
    """
    Run compute_indices on every (station_name, period) group.

    `grouped` is a DataFrameGroupBy keyed by ["station_name", period_col].
    Groups for which compute_indices returns None are skipped. Returns the
    list of result dicts, each tagged with "station_name" and `period_col`.
    """
    records = []
    for (st_name, period), group in grouped:
        indices = compute_indices(group)
        if indices is None:
            continue
        indices["station_name"] = st_name
        indices[period_col] = period
        records.append(indices)
    return records


def _finalize_indices_table(records, period_col):
    """
    Build the output table from the result dicts: attach lat/lon/elev from
    df_phys (left join on station_name) and sort by station and period.
    """
    table = pd.DataFrame(records)
    if table.empty:
        # With no records the frame has no "station_name" column, so the merge
        # and sort below would raise KeyError; return the empty table instead.
        print(f"Warning: no indices could be computed per {period_col}; writing an empty table.")
        return table
    table = pd.merge(
        table,
        df_phys[["station_name", "lat", "lon", "elev"]],
        on="station_name",
        how="left"
    )
    return table.sort_values(["station_name", period_col])


group_month = df_data.groupby(["station_name", "month"])
monthly_results = _collect_group_indices(group_month, "month")
df_monthly = _finalize_indices_table(monthly_results, "month")
monthly_out = os.path.join(output_dir, "Indices_Monthly.xlsx")
df_monthly.to_excel(monthly_out, index=False)
print("Monthly indices saved =>", monthly_out)

###############################################################################
# 7. SEASONAL INDICES
###############################################################################
group_season = df_data.groupby(["station_name", "season"])
seasonal_results = _collect_group_indices(group_season, "season")
df_seasonal = _finalize_indices_table(seasonal_results, "season")
seasonal_out = os.path.join(output_dir, "Indices_Seasonal.xlsx")
df_seasonal.to_excel(seasonal_out, index=False)
print("Seasonal indices saved =>", seasonal_out)

###############################################################################
# 8. DONE
###############################################################################
print("\nAll monthly and seasonal indices have been saved. (No extreme-event stratification.)")


In [None]:
# DJF for CHIRPS

import os
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import Point
from scipy.stats import pearsonr
import seaborn as sns

###############################################################################
# 1. CONFIG & PATHS
###############################################################################
indices_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\CHIRPS_GLB_Precipitation\Masked to GLB\ClimaticIndices-Seasonal"
seasonal_file = os.path.join(indices_dir, "Indices_Seasonal.xlsx")  # single file
output_plots  = os.path.join(indices_dir, "AnalysisPlots_DJF")
os.makedirs(output_plots, exist_ok=True)

shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"

# Indices in your seasonal file
index_list = ["rx1day","rx5day","cdd","cwd","r95p","r99p","wetdays","drydays"]

# For summary stats, define how to find obs vs. chirps columns
index_columns = {
    "rx1day":  ("rx1day_obs",  "rx1day_chirps"),
    "rx5day":  ("rx5day_obs",  "rx5day_chirps"),
    "cdd":     ("cdd_obs",     "cdd_chirps"),
    "cwd":     ("cwd_obs",     "cwd_chirps"),
    "r95p":    (("r95amt_obs","r95pct_obs"), ("r95amt_chirps","r95pct_chirps")),
    "r99p":    (("r99amt_obs","r99pct_obs"), ("r99amt_chirps","r99pct_chirps")),
    "wetdays": ("wetdays_obs","wetdays_chirps"),
    "drydays": ("drydays_obs","drydays_chirps"),
}

###############################################################################
# 2. LOAD SEASONAL FILE & FILTER TO DJF
###############################################################################
# Read the per-station seasonal indices table from the seasonal Excel file.
df_season = pd.read_excel(seasonal_file)
print("Loaded =>", seasonal_file, "| shape =", df_season.shape)

# Filter to DJF
df_season = df_season[df_season["season"]=="DJF"].copy()
df_season = df_season.dropna(subset=["lat","lon"])  # ensure lat/lon exist
print("After filtering to DJF => shape =", df_season.shape)

# mdf is the master table used by every later step in this cell
# (summary statistics, maps and distribution plots); it is also saved as-is.
mdf = df_season.reset_index(drop=True)
master_xlsx = os.path.join(output_plots, "MasterTable_Seasonal_DJF.xlsx")
mdf.to_excel(master_xlsx, index=False)
print(f"\n(A) Master table (DJF) saved => {master_xlsx}")
print("Columns:", mdf.columns.tolist())

###############################################################################
# 3. SUMMARY TABLE (MBE, RMSE, STD, CC, d) for DJF
###############################################################################
def index_of_agreement(obs, model):
    """
    Index of agreement d = 1 - sum((model-obs)^2) / sum((|model-mean(obs)| + |obs-mean(obs)|)^2).

    Returns NaN when the denominator is zero.
    """
    center = np.mean(obs)
    squared_error = np.sum((model - obs)**2)
    potential_error = np.sum((abs(model - center) + abs(obs - center))**2)
    return np.nan if potential_error == 0 else 1 - squared_error/potential_error

def rmse(a, b):
    """Root-mean-square of the element-wise differences a - b."""
    residuals = a - b
    return np.sqrt(np.mean(residuals**2))

def std_of_residuals(a, b):
    """Sample standard deviation (ddof=1) of the element-wise differences a - b."""
    residuals = a - b
    return np.std(residuals, ddof=1)

def mean_bias_error(a, b):
    """Mean of b - a (positive when b is larger than a on average)."""
    bias = b - a
    return np.mean(bias)

def _summary_row(frame, obs_col, chirps_col, label):
    """
    Build one summary-table row comparing `chirps_col` against `obs_col`.

    Only rows where both columns are present are used. Returns None when
    fewer than two such rows exist (correlation and STDres need >= 2 points).
    """
    valid = frame[[obs_col, chirps_col]].dropna()
    if len(valid) < 2:
        return None
    obs_vals = valid[obs_col].values
    chirps_vals = valid[chirps_col].values
    return {
        "Index": label,
        "Count": len(valid),
        "MBE": mean_bias_error(obs_vals, chirps_vals),
        "RMSE": rmse(obs_vals, chirps_vals),
        "STDres": std_of_residuals(obs_vals, chirps_vals),
        "CC": pearsonr(obs_vals, chirps_vals)[0],
        "d": index_of_agreement(obs_vals, chirps_vals),
    }


summary_rows = []
for idx_name in index_list:
    obs_cols, chirps_cols = index_columns[idx_name]

    # Normalise both layouts of index_columns to a list of
    # (obs column, chirps column, row label) triples.
    if isinstance(obs_cols, tuple):
        # multiple columns (amount and percentage variants)
        column_pairs = [
            (oc, ec, f"{idx_name}_{oc.replace('_obs','')}")
            for oc, ec in zip(obs_cols, chirps_cols)
        ]
    else:
        column_pairs = [(obs_cols, chirps_cols, idx_name)]

    for oc, ec, idx_label in column_pairs:
        row = _summary_row(mdf, oc, ec, idx_label)
        if row is not None:
            summary_rows.append(row)

summary_cols = ["Index","Count","MBE","RMSE","STDres","CC","d"]
# Passing columns= keeps the column order and also yields a well-formed empty
# table when no index had enough data (selecting the columns afterwards would
# raise KeyError on an empty frame).
summary_df = pd.DataFrame(summary_rows, columns=summary_cols)
summary_xlsx = os.path.join(output_plots, "SummaryTable_Extremes_DJF.xlsx")
summary_df.to_excel(summary_xlsx, index=False)
print(f"(B) Summary Table (DJF) => {summary_xlsx}\n{summary_df}")

###############################################################################
# 4. MAPPING: Combine Observed, CHIRPS, Ratio in One Figure
###############################################################################
gdf_basin = gpd.read_file(shapefile_path).to_crs(epsg=4326)
gdf_lakes = gpd.read_file(lakes_shp).to_crs(epsg=4326)

def add_basin_lakes(ax):
    """
    Draw the map backdrop on `ax`: dotted country borders, then the basin
    polygons outlined in black and the lake polygons outlined in cyan.
    """
    #ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    outline_layers = ((gdf_basin, 'black'), (gdf_lakes, 'cyan'))
    for layer, outline_color in outline_layers:
        for geom in layer.geometry:
            ax.add_geometries([geom], ccrs.PlateCarree(),
                              facecolor='none', edgecolor=outline_color, linewidth=1)

def plot_map_triptych(df, obs_col, chirps_col, ratio_col, idx_name, out_png, season="DJF"):
    """
    Save one figure with three side-by-side station maps:
      1) Observed
      2) CHIRPS
      3) Ratio (CHIRPS/OBS)

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain "lon", "lat" and the three value columns.
    obs_col, chirps_col, ratio_col : str
        Column plotted in panel 1, 2 and 3 respectively.
    idx_name : str
        Index name used in the panel titles.
    out_png : str
        Path of the PNG file to write.
    season : str, optional
        Season label shown in the panel titles (default "DJF").

    Each panel has its own colorbar, and stations at or above the panel's
    90th percentile are ringed in red as "hotspots".
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6),
                             subplot_kw={"projection": ccrs.PlateCarree()})

    def scatter_map(ax, value_col, title):
        # Draw one panel: backdrop, coloured stations, hotspot rings, gridlines.
        ax.set_extent([-95.5, -72, 38.5, 52.5])
        add_basin_lakes(ax)
        sc = ax.scatter(df["lon"], df["lat"], c=df[value_col], cmap="viridis",
                        s=60, transform=ccrs.PlateCarree(), edgecolor="k", zorder=10)
        cb = fig.colorbar(sc, ax=ax, shrink=0.8)
        cb.set_label(value_col)

        # Hotspots => top 10%
        vals = df[value_col].dropna().values
        has_hotspots = len(vals) > 0
        if has_hotspots:
            thr = np.percentile(vals, 90)
            is_hot = df[value_col]>=thr
            ax.scatter(df.loc[is_hot,"lon"], df.loc[is_hot,"lat"],
                       marker='o', facecolors='none', edgecolors='red', s=80,
                       transform=ccrs.PlateCarree(), zorder=11,
                       label=f"Hotspot >= {thr:.2f}")
        ax.set_title(title, fontsize=12)
        gl = ax.gridlines(draw_labels=True, linestyle='--', color='gray')
        gl.right_labels = False
        gl.top_labels   = False
        # Only build a legend when the hotspot layer (the sole labelled artist)
        # was drawn; otherwise matplotlib warns about a legend with no entries.
        if has_hotspots:
            ax.legend(loc='upper right')

    scatter_map(axes[0], obs_col,  f"{idx_name} Observed ({season})")
    scatter_map(axes[1], chirps_col,  f"{idx_name} CHIRPS ({season})")
    scatter_map(axes[2], ratio_col,f"{idx_name} (CHIRPS/OBS) ({season})")

    # Operate on this figure explicitly rather than on pyplot's "current"
    # figure, and release it so looping over many indices does not pile up
    # open figures.
    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved =>", out_png)

def get_map_cols(idx_name):
    """
    Return the (obs, chirps, ratio) column names to map for `idx_name`.

    Most indices use their own name as the column stem; "r95p" and "r99p"
    map to the amount columns ("r95amt" / "r99amt"). Unknown names yield
    (None, None, None).
    """
    column_stems = (
        ("rx1day", "rx1day"),
        ("rx5day", "rx5day"),
        ("cdd", "cdd"),
        ("cwd", "cwd"),
        ("drydays", "drydays"),
        ("wetdays", "wetdays"),
        ("r95p", "r95amt"),
        ("r99p", "r99amt"),
    )
    for known_name, stem in column_stems:
        if idx_name == known_name:
            return f"{stem}_obs", f"{stem}_chirps", f"{stem}_ratio"
    return None, None, None

# One 3-panel map per index, skipping indices with no column mapping or
# whose columns are absent from the master table.
for idx_name in index_list:
    obs_col, chirps_col, ratio_col = get_map_cols(idx_name)
    if obs_col is not None:
        needed_cols = [obs_col, chirps_col, ratio_col, "lat", "lon"]
        columns_present = all(c in mdf.columns for c in needed_cols)
        if columns_present:
            subdf = mdf.dropna(subset=["lat","lon"]).copy()
            out_png = os.path.join(output_plots, f"DJF_{idx_name}_MAP_3panel.png")
            plot_map_triptych(subdf, obs_col, chirps_col, ratio_col, idx_name, out_png)
        else:
            print(f"Skipping map for {idx_name} - missing columns.")

###############################################################################
# 5. DISTRIBUTION & BOX/CDF/Scatter in One Figure
###############################################################################
def plot_distribution_triptych(df, obs_col, chirps_col, label, out_png, season="DJF"):
    """
    Save one figure with three side-by-side panels comparing Obs and CHIRPS:
      1) Boxplot
      2) CDF (empirical)
      3) Scatter with a 1:1 line and the Pearson correlation in the title

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain `obs_col` and `chirps_col`.
    obs_col, chirps_col : str
        Columns holding the observed and CHIRPS values.
    label : str
        Index name used for axis labels and titles.
    out_png : str
        Path of the PNG file to write.
    season : str, optional
        Season label shown in titles and axis labels (default "DJF").
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18,6))
    ax_box, ax_cdf, ax_scat = axes

    # A) Boxplot
    data = pd.DataFrame({"Obs": df[obs_col], "CHIRPS": df[chirps_col]}).melt(
        var_name="Dataset", value_name=label
    )
    sns.boxplot(data=data, x="Dataset", y=label, ax=ax_box)
    ax_box.set_title(f"Boxplot: {label} ({season})")

    # B) CDF
    obs_vals = df[obs_col].dropna()
    chirps_vals = df[chirps_col].dropna()

    def ecdf(x):
        # Sorted values against their cumulative fraction (1/n .. 1).
        xs = np.sort(x)
        ys = np.arange(1, len(xs)+1)/len(xs)
        return xs, ys

    if len(obs_vals)>=2 and len(chirps_vals)>=2:
        xs_o, ys_o = ecdf(obs_vals)
        xs_e, ys_e = ecdf(chirps_vals)
        ax_cdf.plot(xs_o, ys_o, label="Obs")
        ax_cdf.plot(xs_e, ys_e, label="CHIRPS")
        ax_cdf.set_title(f"CDF of {label} ({season})")
        ax_cdf.set_xlabel(label)
        ax_cdf.set_ylabel("Probability")
        ax_cdf.legend()
    else:
        ax_cdf.set_title(f"CDF: not enough data ({label})")

    # C) Scatter (paired rows only)
    valid = df[[obs_col, chirps_col]].dropna()
    if len(valid)>=2:
        x = valid[obs_col]
        y = valid[chirps_col]
        cc, _ = pearsonr(x, y)
        ax_scat.scatter(x, y, edgecolors='k', alpha=0.7)
        # 1:1 reference line spanning the joint range of both datasets.
        mn, mx = np.nanmin([x.min(), y.min()]), np.nanmax([x.max(), y.max()])
        ax_scat.plot([mn, mx],[mn, mx],'r--')
        ax_scat.set_xlabel(f"Obs {label} ({season})")
        ax_scat.set_ylabel(f"CHIRPS {label} ({season})")
        ax_scat.set_title(f"{label} (Corr={cc:.2f}, {season})")
    else:
        ax_scat.set_title(f"Scatter: not enough data ({label})")

    # Operate on this figure explicitly rather than on pyplot's "current"
    # figure, and release it so looping over many indices does not pile up
    # open figures.
    fig.tight_layout()
    fig.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved =>", out_png)

# One 3-panel distribution figure per index.
for idx_name in index_list:
    # figure out obs, chirps columns (same mapping as the maps; ratio unused here)
    obs_col, chirps_col, _ratio_col = get_map_cols(idx_name)
    if obs_col is None:
        continue

    columns_present = obs_col in mdf.columns and chirps_col in mdf.columns
    if not columns_present:
        print(f"Skipping distribution for {idx_name} - missing columns.")
        continue

    # Keep only stations where both values are available.
    subdf = mdf[[obs_col, chirps_col]].dropna()
    if len(subdf) >= 2:
        out_3panel = os.path.join(output_plots, f"DJF_{idx_name}_Distribution_3panel.png")
        plot_distribution_triptych(subdf, obs_col, chirps_col, idx_name, out_3panel)
    else:
        print(f"Skipping distribution for {idx_name} - not enough data.")

###############################################################################
# 6. DONE
###############################################################################
print("\nAll DJF steps completed! See outputs in:", output_plots)

# For the other seasons, just change any DJF to JJA, MAM, or SON