In [None]:
# Maps and Taylor diagram for PRISM (PRCP)

import os
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from shapely.geometry import Point
from matplotlib.lines import Line2D
import matplotlib.cm as cm
from matplotlib.colors import Normalize

##############################################################################
# 1. FILE PATHS
##############################################################################
# Every path below is an absolute, machine-specific Windows path.
#   nc_file        : PRISM-vs-station NetCDF, opened with xarray in section 2
#   metrics_csv    : per-station metrics table, read with pandas in section 2
#   physical_file  : station catalogue CSV, read and renamed in section 3
#   shapefile_path : Great Lakes shapefile, read in section 5
#   lakes_shp      : NOTE(review): not referenced in the visible code of this cell
#   target_crs     : CRS the shapefile is converted to (or tagged with) in section 5
#   output_dir     : folder that receives every PNG/CSV this cell writes
nc_file       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\daily_loop3\prism_vs_stations_8Nearest_LWR_1991_2012.nc"
metrics_csv   = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\metrics3\station_metrics_8Nearest_LWR_prism_1991_2012.csv"
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
shapefile_path= r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp     = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"
target_crs    = "EPSG:4326"  
output_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\Maps_and_Taylor3"

# Create the output folder up front so later savefig/to_csv calls cannot fail
# on a missing directory.
os.makedirs(output_dir, exist_ok=True)

##############################################################################
# 2. LOAD NETCDF (RENAME DIMENSION) & METRICS CSV
##############################################################################
print("Loading prism vs Station NetCDF ...")
# NOTE(review): the dataset is never closed explicitly in the visible code.
ds_raw = xr.open_dataset(nc_file)
print("Original dims in NetCDF:", ds_raw.dims)

# Rename dimension "variable" -> "station" so later code can index ds["station"].
# NOTE(review): this presumes the file really has a "variable" dimension — confirm.
ds = ds_raw.rename({"variable": "station"})
print("\nAfter renaming => ds.dims:", ds.dims)
print("Data variables in NetCDF:", list(ds.data_vars))

print("\nLoading station-level metrics from CSV ...")
df_metrics = pd.read_csv(metrics_csv)
print("Loaded metrics CSV with columns:", df_metrics.columns.tolist())

# Downstream code refers to the Index of Agreement as "d". If the CSV carries it
# as "Index_of_Agreement" (and has no "d" column already), rename it so those
# references resolve; an existing "d" column is never overwritten.
if "Index_of_Agreement" in df_metrics.columns and "d" not in df_metrics.columns:
    df_metrics = df_metrics.rename(columns={"Index_of_Agreement": "d"})

##############################################################################
# 3. LOAD & MERGE STATION COORDINATES
##############################################################################
print("\nLoading station location info ...")
df_physical = pd.read_csv(physical_file)
# Normalise the catalogue's column names to the short names used by the maps.
df_physical = df_physical.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})

# Merge on station_name (assuming the spelling matches in both tables).
common_col = "station_name"

# Fail fast with an explicit message: without the join key in BOTH tables the
# merge below cannot work, and pandas' own KeyError does not say which table
# is at fault.
tables_missing_key = [
    table_name
    for table_name, table in (("df_metrics", df_metrics), ("df_physical", df_physical))
    if common_col not in table.columns
]
if tables_missing_key:
    raise KeyError(
        f"Join column '{common_col}' is missing from: {', '.join(tables_missing_key)}. "
        "Cannot merge metrics with station coordinates."
    )

print("\nMerging metrics with physical station coords ...")
# Inner join: only stations present in both the metrics table and the catalogue
# survive, so every remaining row has coordinates.
df_merged = pd.merge(df_metrics, df_physical, how="inner", on=common_col)
print(f"Merged shape: {df_merged.shape}")
print("Columns:", df_merged.columns.tolist())

if df_merged.empty:
    # Not fatal here, but every later map/diagram would be empty.
    print(f"WARNING: merge on '{common_col}' produced 0 rows; "
          "check that station names match between the two files.")

##############################################################################
# 4. BASIC CHECKS ON NETCDF
##############################################################################
# Informational only: report how many time steps / stations the dataset holds
# and the first and last time stamp. Nothing below depends on these values.
time_dim = ds["time"].size
station_dim = ds["station"].size
print(f"\nNetCDF: # time steps = {time_dim}, # stations = {station_dim}")
time_min = ds["time"].values.min()
time_max = ds["time"].values.max()
print("Time range in NetCDF =>", str(time_min), "to", str(time_max))

##############################################################################
# 5. LOAD GREAT LAKES SHAPEFILE
##############################################################################
print("\nLoading Great Lakes shapefile ...")
gdf_lakes = gpd.read_file(shapefile_path)
if gdf_lakes.crs is not None:
    # A CRS is recorded: convert the geometries to the target CRS.
    gdf_lakes = gdf_lakes.to_crs(target_crs)
else:
    # No CRS recorded: tag the layer with the target CRS.
    # NOTE(review): this presumes the coordinates are already lon/lat — confirm.
    gdf_lakes.crs = target_crs

# total_bounds is unpacked as (minx, miny, maxx, maxy); these four values set
# the extent of every map drawn in section 6.
lon_min, lat_min, lon_max, lat_max = gdf_lakes.total_bounds
print("Great Lakes shapefile loaded. Bounds:", (lon_min, lat_min, lon_max, lat_max))

##############################################################################
# 6. MAP STATISTICAL METRICS (MBE,RMSE,STD,CC,d) WITH “HOTSPOTS”
##############################################################################
# One PNG per metric: every station is coloured by its metric value and the
# "hotspot" stations are ringed in red.
metrics_list = ["MBE","RMSE","STD","CC","d"]
titles_dict  = {
    "MBE":  "Mean Bias Error (MBE)",
    "RMSE": "Root Mean Square Error (RMSE)",
    "STD":  "Standard Deviation (STD)",
    "CC":   "Correlation Coefficient (CC)",
    "d":    "Index of Agreement (d)"
}

print("\nGenerating metric maps with hotspots ...")

for metric in metrics_list:
    if metric not in df_merged.columns:
        print(f"Metric {metric} not found in df_merged. Skipping.")
        continue

    # A percentile threshold needs at least one valid value; an all-NaN (or
    # empty) column would otherwise yield a NaN threshold and a blank map.
    vals = df_merged[metric].dropna().values
    if vals.size == 0:
        print(f"Metric {metric} has no valid values in df_merged. Skipping.")
        continue

    fig, ax = plt.subplots(figsize=(10, 8), subplot_kw=dict(projection=ccrs.PlateCarree()))
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

    # Base map
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.add_feature(cfeature.LAKES, alpha=0.4)
    ax.add_feature(cfeature.COASTLINE)

    # Plot the Great Lakes boundary (all geometries in a single call)
    ax.add_geometries(gdf_lakes.geometry, ccrs.PlateCarree(),
                      facecolor='none', edgecolor='blue', linewidth=1)

    # Plot station metrics
    sc = ax.scatter(df_merged["lon"], df_merged["lat"],
                    c=df_merged[metric], cmap="viridis", s=60, edgecolor="k",
                    transform=ccrs.PlateCarree())
    cb = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
    cb.set_label(metric, fontsize=12)

    # Title the map axes explicitly instead of relying on pyplot's notion of
    # the "current" axes after the colorbar has been added.
    ax.set_title(f"Spatial Distribution: {titles_dict.get(metric, metric)}", fontsize=14)

    # Identify hotspots
    # For MBE, RMSE, STD => values at or above the 90th percentile
    # For CC, d          => values at or below the 10th percentile
    # NaN rows compare False either way, so they are never flagged.
    if metric in ["MBE","RMSE","STD"]:
        thr = np.percentile(vals, 90)
        hotspot = df_merged[metric] >= thr
        label_txt = f"Hotspot >= {thr:.2f}"
    else:
        thr = np.percentile(vals, 10)
        hotspot = df_merged[metric] <= thr
        label_txt = f"Hotspot <= {thr:.2f}"

    ax.scatter(df_merged.loc[hotspot,"lon"], df_merged.loc[hotspot,"lat"],
               facecolors='none', edgecolors='red', s=90, linewidths=1.5,
               transform=ccrs.PlateCarree(), label=label_txt)

    # Add Lat/Lon gridlines (labels on the left and bottom edges only)
    gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.right_labels = False
    gl.top_labels   = False
    gl.xlabel_style = {'size':10}
    gl.ylabel_style = {'size':10}

    # Legend: the labelled hotspot rings plus a proxy handle for the stations
    # (the colour-mapped scatter itself carries no label).
    station_handle = Line2D([],[], marker='o', color='k', linestyle='None', markersize=7, label='Stations')
    handles, labels = ax.get_legend_handles_labels()
    handles.append(station_handle)
    labels.append('Stations')
    ax.legend(handles=handles, labels=labels, loc='upper right', fontsize=9)

    out_fn = os.path.join(output_dir, f"Map_{metric}.png")
    fig.savefig(out_fn, dpi=300, bbox_inches="tight")
    # Close exactly this figure so figures do not accumulate across the loop.
    plt.close(fig)
    print(f"Saved metric map => {out_fn}")

##############################################################################
# 7. IMPROVED TAYLOR DIAGRAM
##############################################################################
print("\nCreating the improved Taylor Diagram from metrics ...")

import mpl_toolkits.axisartist.grid_finder as gf
import mpl_toolkits.axisartist.floating_axes as fa
import matplotlib.projections as mp
from matplotlib.projections import PolarAxes
import matplotlib.patches as patches
import math

class TaylorDiagram(object):
    """Quarter-circle Taylor diagram with colour-coded axes.

    The angular coordinate encodes correlation (theta = arccos(corr)) and the
    radial coordinate encodes standard deviation. Uses
    ``apply_theta_transforms=False`` to avoid Matplotlib 3.9+ deprecation
    warnings.

    Attributes set by ``__init__``:
        STD          -- reference standard deviation passed by the caller
        smin, smax   -- radial extent of the diagram (0 .. 2 * STD)
        _ax          -- the floating (graphical) axes
        ax           -- auxiliary axes in polar coordinates; plot samples here
        samplePoints -- list of Line2D artists, starting with the reference star
    """

    def __init__(self, STD, fig=None, rect=111, label='_'):
        """Build the diagram axes and draw the reference marker.

        Parameters:
            STD   -- reference standard deviation (radius of the dashed arc)
            fig   -- figure to draw into; a new 8x8 figure is created if None
            rect  -- subplot specification passed to FloatingSubplot
            label -- legend label of the reference star
        """
        self.STD = STD

        # Use apply_theta_transforms=False to avoid deprecation warnings
        tr = PolarAxes.PolarTransform(apply_theta_transforms=False)

        # Correlation tick positions, labelled with two decimal places
        rlocs = np.concatenate(((np.arange(0,1.1,0.1)), [0.95, 0.99]))
        tlocs = np.arccos(rlocs)
        tf1 = gf.DictFormatter({t: f"{r:.2f}" for t, r in zip(tlocs, rlocs)})

        # STDev extent
        self.smin = 0
        self.smax = 2 * self.STD

        gh = fa.GridHelperCurveLinear(
            tr, extremes=(0, np.pi/2, self.smin, self.smax),
            grid_locator1=gf.FixedLocator(tlocs), tick_formatter1=tf1
        )

        if fig is None:
            fig = plt.figure(figsize=(8, 8))

        ax = fa.FloatingSubplot(fig, rect, grid_helper=gh)
        fig.add_subplot(ax)

        # Correlation (red, top)
        ax.axis['top'].set_axis_direction('bottom')
        ax.axis['top'].label.set_text("Correlation Coefficient")
        ax.axis['top'].label.set_color("red")
        ax.axis['top'].label.set_fontsize(14)
        ax.axis['top'].label.set_rotation(180)
        ax.axis['top'].label.set_pad(30)
        ax.axis['top'].toggle(ticklabels=True, label=True)
        ax.axis['top'].major_ticklabels.set_rotation(180)
        ax.axis['top'].major_ticklabels.set_pad(2)
        ax.axis['top'].major_ticklabels.set_color("red")
        ax.axis['top'].line.set_color("red")

        # Centered RMSE (blue, left) -- label only, no tick labels
        ax.axis['left'].set_axis_direction('bottom')
        ax.axis['left'].label.set_text("Centered RMSE")
        ax.axis['left'].label.set_color("blue")
        ax.axis['left'].label.set_fontsize(14)
        ax.axis['left'].label.set_pad(20)
        ax.axis['left'].toggle(ticklabels=False, label=True)

        # Standard Deviation (black, right)
        ax.axis['right'].set_axis_direction('top')
        ax.axis['right'].label.set_text("Standard Deviation")
        ax.axis['right'].label.set_fontsize(14)
        ax.axis['right'].toggle(ticklabels=True, label=True)

        # Hide bottom
        ax.axis['bottom'].set_visible(False)

        ax.grid()
        self._ax = ax            # Graphical axes
        self.ax = ax.get_aux_axes(tr)  # Polar coordinates

        # Reference star = Observed, at theta = 0 (correlation 1) and radius STD
        l, = self.ax.plot([0], self.STD, 'k*', ls='', ms=12, label=label)
        # Dashed arc of constant reference STD
        t = np.linspace(0, np.pi/2, 100)
        r = np.zeros_like(t)+self.STD
        self.ax.plot(t, r, 'k--', label='_')

        self.samplePoints = [l]

    def add_sample(self, stdev, corr, *args, **kwargs):
        """Add a point to the diagram and return its Line2D artist.

        ``stdev`` is the radial coordinate and ``corr`` the correlation that
        sets the angle. Extra positional/keyword arguments are forwarded to
        ``plot``.

        The correlation is clipped to [-1, 1] before ``arccos``: a value that
        overshoots 1 by floating-point rounding would otherwise become NaN and
        the point would silently vanish. NaN correlations stay NaN.
        """
        theta = np.arccos(np.clip(corr, -1.0, 1.0))
        l, = self.ax.plot(theta, stdev, *args, **kwargs)
        self.samplePoints.append(l)
        return l

    def add_contours(self, levels=5, **kwargs):
        """Add centered RMSE contours in blue and return the contour set.

        The field is evaluated on a 100x100 (theta, radius) grid with the law
        of cosines: rmse^2 = STD^2 + r^2 - 2*STD*r*cos(theta).
        ``levels`` and any keyword arguments are forwarded to ``contour``.
        """
        rs, ts = np.meshgrid(np.linspace(self.smin, self.smax, 100),
                             np.linspace(0, np.pi/2, 100))
        rmse = np.sqrt(self.STD**2 + rs**2 - 2*self.STD*rs*np.cos(ts))
        cont = self.ax.contour(ts, rs, rmse, levels, colors="blue", **kwargs)
        return cont

def short_station_name(full_name):
    """Return the first whitespace-separated word of a station name.

    Used to keep legend labels short. Never raises on bad input:
    ``None``, NaN, an empty string or a whitespace-only string give ``""``;
    any other non-string value is converted with ``str`` first.
    """
    if not isinstance(full_name, str):
        # ``x != x`` is true only for NaN, which is how pandas marks a missing name.
        if full_name is None or full_name != full_name:
            return ""
        full_name = str(full_name)
    words = full_name.split()
    return words[0] if words else ""

def create_taylor_diagram(df, ref_col="STD", std_col="STD", corr_col="CC", stn_name_col="station_name"):
    """Plot the Taylor diagram for the stations in ``df`` and save it as a PNG.

    Parameters:
        df           -- DataFrame holding one row per station
        ref_col      -- column whose mean is used as the reference STD
        std_col      -- column with each station's standard deviation (radius)
        corr_col     -- column with each station's correlation (angle)
        stn_name_col -- column with the station name, used for legend labels

    Rows with a missing value in ``ref_col``, ``std_col`` or ``corr_col`` are
    dropped. Stations whose correlation is at or above the 90th percentile are
    drawn larger, coloured, and listed in the legend by the first word of their
    name; all others are small grey markers without a legend entry.

    Returns None. Writes ``Improved_TaylorDiagram.png`` into the module-level
    ``output_dir``; prints a message and returns early if no usable rows remain.
    """
    # Drop rows missing needed columns
    df_td = df.dropna(subset=[ref_col, std_col, corr_col])
    if df_td.empty:
        print("Not enough data for Taylor Diagram.")
        return

    # Reference STD is the average of the reference column
    ref_std_val = df_td[ref_col].mean()
    # Threshold for the top 10% correlation
    thr_cc = np.percentile(df_td[corr_col].values, 90)

    # Prepare diagram
    fig = plt.figure(figsize=(8,8))
    dia = TaylorDiagram(ref_std_val, fig=fig, rect=111, label='OBS')
    # Add RMSE contours
    ctn = dia.add_contours(levels=6)
    plt.clabel(ctn, inline=1, fontsize=10)

    # One colour slot per station. pyplot.get_cmap is used because
    # matplotlib.cm.get_cmap no longer exists in Matplotlib 3.9+, the same
    # version range the TaylorDiagram class targets.
    color_map = plt.get_cmap("tab20", len(df_td))

    # Iterate over the three columns directly: this works for any column name,
    # whereas itertuples()/getattr only works for valid Python identifiers.
    station_rows = zip(df_td[std_col], df_td[corr_col], df_td[stn_name_col])
    for i, (stdev, corr, stn) in enumerate(station_rows):
        if corr >= thr_cc:
            # Best-correlated stations: coloured, larger, named in the legend
            mk_style = dict(marker='o', ms=6,
                            mec=color_map(i), mfc='none', mew=1.6,
                            label=short_station_name(stn))
        else:
            # Everyone else: small grey marker
            mk_style = dict(marker='o', ms=4,
                            mec='gray', mfc='none', mew=1,
                            label='_')  # underscore => not in legend

        dia.add_sample(stdev, corr, **mk_style)

    # Build the legend from the sample points whose label is not "_"
    # (this keeps the 'OBS' reference star and the best-correlated stations).
    final_pairs = [(p, p.get_label()) for p in dia.samplePoints if p.get_label() != "_"]
    if final_pairs:
        handles_ok, labels_ok = zip(*final_pairs)
        plt.legend(handles_ok, labels_ok, numpoints=1, prop=dict(size=6),
                   loc='upper right', title="Best CC Grids")

    # Title
    dia._ax.set_title("Taylor Diagram prism", fontsize=12, fontweight="bold")

    out_fn = os.path.join(output_dir, "Improved_TaylorDiagram.png")
    # Save and close this specific figure rather than pyplot's current one.
    fig.savefig(out_fn, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved improved Taylor Diagram => {out_fn}")

##############################################################################
# 8. CREATE IMPROVED TAYLOR DIAGRAM
##############################################################################
# 'STD' is passed as BOTH the reference column and the per-station std column,
# so the reference value is the mean of the same column that is plotted.
# 'CC' supplies the correlation.
# NOTE(review): if a separate observed-STD column exists, pass it as ref_col.
create_taylor_diagram(df_merged, ref_col="STD", std_col="STD", corr_col="CC", stn_name_col="station_name")

##############################################################################
# 9. OUTPUT A “METRICS EVALUATION” TABLE
##############################################################################
print("\nFinal Metrics Evaluation Table (All Stations):")
cols_for_eval = ["MBE","RMSE","STD","CC","d"]
# Only summarise the metric columns that actually exist after the merge.
avail_cols = [c for c in cols_for_eval if c in df_merged.columns]
# One row per statistic (count/mean/std/min/max), one column per metric.
eval_table = df_merged[avail_cols].agg(["count","mean","std","min","max"])
print(eval_table)

eval_table_out = os.path.join(output_dir,"Overall_Metrics_Evaluation.csv")
eval_table.to_csv(eval_table_out)
print(f"Saved overall metrics evaluation => {eval_table_out}")

print("\n✅ Done! Renamed dimension for NetCDF, renamed Index_of_Agreement to 'd', produced improved Taylor Diagram, and exported everything.")


In [None]:
# Maps and Taylor diagram of the statistical metrics for prcp EMDNA
# --------------------------------------------------------------------
# This version iterates over the 10 requested ensembles:
#   1, 11, 21, 31, 41, 51, 61, 71, 81, 91
# --------------------------------------------------------------------

import os
import math
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from shapely.geometry import Point
from matplotlib.lines import Line2D
import matplotlib.cm as cm
from matplotlib.colors import Normalize

import mpl_toolkits.axisartist.grid_finder  as gf
import mpl_toolkits.axisartist.floating_axes as fa
from   matplotlib.projections import PolarAxes

# ────────────────────────────────────────────────────────────────────────────
# 0.  CONSTANT FILES (shared across ensembles)
# ────────────────────────────────────────────────────────────────────────────
# Absolute, machine-specific paths. These inputs do not depend on the ensemble
# number, so they are loaded once here, before the ensemble loop.
physical_file  = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
# NOTE(review): lakes_shp is not referenced in the visible code of this cell.
lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"
target_crs     = "EPSG:4326"

print("▶ Loading static inputs (station catalogue & lakes)…")
# Station catalogue, with its columns renamed to the short names the maps use.
df_phys = (pd.read_csv(physical_file)
           .rename(columns={"NAME": "station_name",
                            "LATITUDE": "lat",
                            "LONGITUDE": "lon",
                            "Elevation": "elev"}))

gdf_lakes = gpd.read_file(shapefile_path)
if gdf_lakes.crs is not None:
    # A CRS is recorded: convert the geometries to the target CRS.
    gdf_lakes = gdf_lakes.to_crs(target_crs)
else:
    # No CRS recorded: tag the layer with the target CRS.
    # NOTE(review): this presumes the coordinates are already lon/lat — confirm.
    gdf_lakes.crs = target_crs
# Bounding box of the shapefile; used as the extent of every map in the loop.
lon_min, lat_min, lon_max, lat_max = gdf_lakes.total_bounds
print("   Great-Lakes bounds:", (lon_min, lat_min, lon_max, lat_max))

# ────────────────────────────────────────────────────────────────────────────
# 1.  LOOP OVER ENSEMBLES
# ────────────────────────────────────────────────────────────────────────────
# root folder that contains the 10 ensemble sub-directories
# (each sub-directory is named after its ensemble number, see base_dir below)
root_dir = (r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder"
            r"\Ensemble files\EMDNA_GLB_Precipitation")

# Ensemble members to process, in this order.
ENSEMBLES = [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]

for ens in ENSEMBLES:
    print("\n" + "=" * 79)
    print(f"⧉  Processing ensemble {ens}  ⧉")
    print("=" * 79)

    # -----------------------------------------------------------------------
    # 1-A.  FILE PATHS  (ensemble-specific)
    # -----------------------------------------------------------------------
    base_dir = os.path.join(root_dir, str(ens))

    nc_file = os.path.join(
        base_dir, "daily_loop",
        f"emdna_vs_stations_25km_LWR_1991_2012_prcp_{ens:03d}.nc"
    )
    metrics_csv = os.path.join(
        base_dir, "metrics",
        f"station_metrics_25km_LWR_EMDNA_1991_2012_prcp_{ens:03d}.csv"
    )

    output_dir = os.path.join(base_dir, "Maps_and_Taylor")
    os.makedirs(output_dir, exist_ok=True)

    # -----------------------------------------------------------------------
    # 2.  LOAD NETCDF (RENAME DIM) & METRICS CSV
    # -----------------------------------------------------------------------
    print("Loading EMDNA vs Station NetCDF …")
    if not os.path.isfile(nc_file):
        print(f"   ⚠ NetCDF missing for ensemble {ens}: {nc_file}")
        continue
    ds_raw = xr.open_dataset(nc_file)

    ds = ds_raw.rename({'station_index': 'station'})
    print("Dims after rename:", ds.dims)
    print("Data vars:", list(ds.data_vars))

    print("\nLoading station-level metrics CSV …")
    if not os.path.isfile(metrics_csv):
        print(f"   ⚠ Metrics CSV missing for ensemble {ens}: {metrics_csv}")
        continue
    
    df_metrics_all = pd.read_csv(metrics_csv)
    # standardise column name
    if "Index_of_Agreement" in df_metrics_all.columns and "d" not in df_metrics_all.columns:
        df_metrics_all = df_metrics_all.rename(columns={"Index_of_Agreement": "d"})
    
    # ── only one variable (“prcp”) in this workflow ──────────────────────────
    metrics_by_var = {"prcp": df_metrics_all.copy()}

    # -----------------------------------------------------------------------
    # 3.  MERGE STATION COORDINATES
    # -----------------------------------------------------------------------
    print("\nMerging station coordinates …")
    merged_by_var = {}
    for var, df_met in metrics_by_var.items():          # var will be "prcp"
        merged = pd.merge(df_met, df_phys, on="station_name", how="inner")
        merged_by_var[var] = merged
        print(f"   {var}: merged {len(merged)} rows")

    # -----------------------------------------------------------------------
    # 4.  BASIC CHECKS ON NETCDF
    # -----------------------------------------------------------------------
    time_dim    = ds["time"].size
    station_dim = ds["station"].size
    time_min    = ds["time"].values.min()
    time_max    = ds["time"].values.max()
    print(f"\nNetCDF summary → time steps: {time_dim:,}, stations: {station_dim}")
    print("Time span      →", str(time_min), "to", str(time_max))

    # -----------------------------------------------------------------------
    # 5.  MAP STATISTICAL METRICS WITH HOTSPOTS
    # -----------------------------------------------------------------------
    # One PNG per (variable, metric): stations coloured by metric value, with
    # the hotspot stations ringed in red.
    metrics_list = ["MBE", "RMSE", "STD", "CC", "d"]
    titles_dict  = { "MBE":"Mean Bias Error (MBE)",
                     "RMSE":"Root Mean Square Error (RMSE)",
                     "STD":"Standard Deviation (STD)",
                     "CC":"Correlation Coefficient (CC)",
                     "d":"Index of Agreement (d)" }

    print("\nGenerating metric maps with hotspots …")

    for var, df_merged in merged_by_var.items():        # only "prcp"
        for metric in metrics_list:
            if metric not in df_merged.columns:
                continue

            # A percentile threshold needs at least one valid value; an
            # all-NaN (or empty) column would give a NaN threshold and a
            # blank map.
            vals = df_merged[metric].dropna().values
            if vals.size == 0:
                print(f"      ↳ no valid {metric} values for {var}; map skipped")
                continue

            fig, ax = plt.subplots(figsize=(10, 8),
                                   subplot_kw=dict(projection=ccrs.PlateCarree()))
            ax.set_extent([lon_min, lon_max, lat_min, lat_max])

            ax.add_feature(cfeature.BORDERS,  linestyle=':')
            ax.add_feature(cfeature.LAKES,    alpha=0.4)
            ax.add_feature(cfeature.COASTLINE)

            # Great Lakes outline (all geometries in a single call)
            ax.add_geometries(gdf_lakes.geometry, ccrs.PlateCarree(),
                              facecolor='none', edgecolor='blue', linewidth=1)

            sc = ax.scatter(df_merged["lon"], df_merged["lat"],
                            c=df_merged[metric], cmap="viridis", s=60,
                            edgecolor="k", transform=ccrs.PlateCarree())
            cb = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
            cb.set_label(metric)

            # Title the map axes explicitly rather than pyplot's current axes.
            ax.set_title(f"{var.upper()} – {titles_dict.get(metric, metric)}")

            # Hotspots: >= 90th percentile for MBE/RMSE/STD,
            #           <= 10th percentile for CC/d.
            # NaN rows compare False either way, so they are never flagged.
            if metric in ["MBE", "RMSE", "STD"]:
                thr = np.percentile(vals, 90)
                hotspot = df_merged[metric] >= thr
            else:
                thr = np.percentile(vals, 10)
                hotspot = df_merged[metric] <= thr

            ax.scatter(df_merged.loc[hotspot, "lon"],
                       df_merged.loc[hotspot, "lat"],
                       facecolors='none', edgecolors='red', s=90, linewidths=1.5,
                       transform=ccrs.PlateCarree())

            # Lat/lon gridlines, labelled on the left and bottom edges only
            gl = ax.gridlines(draw_labels=True, linewidth=0.4,
                              linestyle='--', alpha=0.5, color='gray')
            gl.right_labels = gl.top_labels = False

            out_fn = os.path.join(output_dir, f"{var}_Map_{metric}.png")
            fig.savefig(out_fn, dpi=300, bbox_inches="tight")
            # Close exactly this figure so figures do not pile up across
            # metrics and ensembles.
            plt.close(fig)
            print(f"      ↳ saved {out_fn}")

    # -----------------------------------------------------------------------
    # 6.  IMPROVED TAYLOR DIAGRAM
    # -----------------------------------------------------------------------
    print("\nCreating the improved Taylor Diagram from metrics …")

    class TaylorDiagram(object):
        """Quarter-circle Taylor diagram with colour-coded axes.

        The angle encodes correlation (theta = arccos(corr)) and the radius
        encodes standard deviation.

        Attributes set by ``__init__``:
            ref_std    -- reference standard deviation passed by the caller
            smin, smax -- radial extent of the diagram (0 .. 1.6 * ref_std)
            _ax        -- the floating (graphical) axes
            ax         -- auxiliary axes in polar coordinates
            samples    -- Line2D artists added through ``add_sample``
        """

        def __init__(self, ref_std, fig=None, rect=111, label='OBS'):
            """Build the axes and draw the reference star and dashed arc.

            ``fig`` defaults to a new 8x8 figure; ``rect`` is the subplot
            specification; ``label`` is the legend label of the reference star.
            """
            self.ref_std = ref_std

            tr = PolarAxes.PolarTransform(apply_theta_transforms=False)

            # Correlation tick positions, labelled with two decimal places
            rlocs = np.concatenate((np.arange(0, 1.1, 0.1), [0.95, 0.99]))
            tlocs = np.arccos(rlocs)
            tf1   = gf.DictFormatter(dict(zip(tlocs, [f"{r:.2f}" for r in rlocs])))

            self.smin, self.smax = 0, 1.6 * ref_std
            gh = fa.GridHelperCurveLinear(
                tr,
                extremes=(0, np.pi/2, self.smin, self.smax),
                grid_locator1=gf.FixedLocator(tlocs),
                tick_formatter1=tf1
            )

            if fig is None:
                fig = plt.figure(figsize=(8, 8))

            ax = fa.FloatingSubplot(fig, rect, grid_helper=gh)
            fig.add_subplot(ax)

            # Correlation axis (red, top)
            ax.axis['top'   ].set_axis_direction('bottom')
            ax.axis['top'   ].label.set_text("Correlation Coefficient")
            ax.axis['top'   ].label.set_color("red")
            ax.axis['top'   ].label.set_fontsize(14)
            ax.axis['top'   ].label.set_rotation(180)
            ax.axis['top'   ].label.set_pad(30)
            ax.axis['top'   ].toggle(ticklabels=True, label=True)
            ax.axis['top'   ].major_ticklabels.set_rotation(180)
            ax.axis['top'   ].major_ticklabels.set_color("red")
            ax.axis['top'   ].line.set_color("red")

            # Centered RMSE axis (blue, left) -- label only, no tick labels
            ax.axis['left'  ].set_axis_direction('bottom')
            ax.axis['left'  ].label.set_text("Centered RMSE")
            ax.axis['left'  ].label.set_color("blue")
            ax.axis['left'  ].label.set_fontsize(14)
            ax.axis['left'  ].toggle(ticklabels=False, label=True)

            # Standard deviation axis (black, right)
            ax.axis['right' ].set_axis_direction('top')
            ax.axis['right' ].label.set_text("Standard Deviation")
            ax.axis['right' ].label.set_fontsize(14)
            ax.axis['right' ].toggle(ticklabels=True, label=True)
            ax.axis['right' ].major_ticklabels.set_pad(2)
            ax.axis['right' ].major_ticklabels.set_fontsize(10)
            ax.axis['right' ].major_ticklabels.set_color("black")

            ax.axis['bottom'].set_visible(False)
            ax.grid()

            self._ax = ax
            self.ax  = ax.get_aux_axes(tr)

            # Reference star at theta = 0 (correlation 1), radius ref_std,
            # plus a dashed arc of constant reference STD.
            self.ax.plot([0], [ref_std], 'k*', ms=12, label=label)
            t = np.linspace(0, np.pi/2, 100)
            self.ax.plot(t, np.full_like(t, ref_std), 'k--')

            self.samples = []

        def add_sample(self, std, corr, **kwargs):
            """Plot one sample and return its Line2D artist.

            ``std`` is the radius, ``corr`` the correlation that sets the
            angle; keyword arguments are forwarded to ``plot``.

            The correlation is clipped to [-1, 1] before ``arccos``: a value
            that overshoots 1 by floating-point rounding would otherwise turn
            into NaN and the point would silently vanish. NaN stays NaN.
            """
            theta = np.arccos(np.clip(corr, -1.0, 1.0))
            p, = self.ax.plot(theta, std, **kwargs)
            self.samples.append(p)
            return p

        def add_contours(self, levels=6, **kwargs):
            """Draw centered-RMSE contours in blue and return the contour set.

            The field is evaluated on a 100x100 (theta, radius) grid with the
            law of cosines: rmse^2 = ref^2 + r^2 - 2*ref*r*cos(theta).
            """
            rs, ts = np.meshgrid(np.linspace(self.smin, self.smax, 100),
                                 np.linspace(0, math.pi/2, 100))
            rmse = np.sqrt(self.ref_std**2 + rs**2
                           - 2*self.ref_std*rs*np.cos(ts))
            return self.ax.contour(ts, rs, rmse, levels,
                                   colors="blue", **kwargs)

    def short_name(full):
        """Return the first whitespace-separated word of a station name.

        Never raises on bad input: ``None``, NaN, an empty string or a
        whitespace-only string give ``""``; any other non-string value is
        converted with ``str`` first.
        """
        if not isinstance(full, str):
            # ``x != x`` is true only for NaN (pandas' marker for a missing name).
            if full is None or full != full:
                return ""
            full = str(full)
        words = full.split()
        return words[0] if words else ""

    def create_taylor(df, var):
        """Draw and save the Taylor diagram for one variable of this ensemble.

        Parameters:
            df  -- merged metrics table; uses columns "STD", "CC", "station_name"
            var -- variable name, used in messages and in the output file name

        Rows without "STD" or "CC" are dropped. The reference STD is the mean
        of the "STD" column. Stations with CC at or above the 90th percentile
        are drawn larger, coloured, and named in the legend; the rest are small
        grey markers. Writes ``Improved_TaylorDiagram_<var>.png`` to the
        current ``output_dir``; returns None (early, with a message, when no
        usable rows remain).
        """
        df_plot = df.dropna(subset=["STD", "CC"])
        if df_plot.empty:
            print(f"   – no valid data for {var}")
            return

        ref_std = df_plot["STD"].mean()
        # Create the figure here so it can be saved and closed explicitly.
        fig = plt.figure(figsize=(8, 8))
        td = TaylorDiagram(ref_std, fig=fig)

        cs = td.add_contours(levels=6)
        # NOTE(review): fmt="%.0f" rounds contour labels to whole numbers; if
        # the contour levels are not integers the labels will be inexact.
        plt.clabel(cs, inline=1, fontsize=9, fmt="%.0f")

        thr  = np.percentile(df_plot["CC"], 90)
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap no longer
        # exists in Matplotlib 3.9+, the version range this diagram targets.
        cmap = plt.get_cmap("tab20", len(df_plot))

        # Colour by ROW POSITION. The DataFrame index label is not a safe
        # colour index: after dropna (or with any non-default index) labels
        # can exceed the number of colour slots and collapse onto one colour.
        rows = zip(df_plot["STD"], df_plot["CC"], df_plot["station_name"])
        for pos, (std, cc, name) in enumerate(rows):
            if cc >= thr:
                kwargs = dict(marker='o', ms=6, mec=cmap(pos), mfc="none",
                              mew=1.6, label=short_name(name))
            else:
                # "_" keeps the marker out of the legend
                kwargs = dict(marker='o', ms=4, mec="gray", mfc="none",
                              mew=1, label="_")
            td.add_sample(std, cc, **kwargs)

        handles = [h for h in td.samples if h.get_label() != "_"]
        labels  = [h.get_label() for h in handles]
        if handles:
            plt.legend(handles, labels, numpoints=1, prop=dict(size=6),
                       loc="upper right", title="Best CC Grids")

        td._ax.set_title(f"Taylor Diagram – EMDNA PRCP  (ensemble {ens})",
                         fontsize=12, fontweight="bold")

        out_png = os.path.join(output_dir,
                   f"Improved_TaylorDiagram_{var}.png")
        fig.savefig(out_png, dpi=300, bbox_inches="tight")
        plt.close(fig)
        print(f"      ↳ saved {out_png}")

    # Single-variable workflow: draw the Taylor diagram for "prcp" only.
    for var in ["prcp"]:
        print(f"• Taylor for {var}")
        create_taylor(merged_by_var[var], var)

    # -----------------------------------------------------------------------
    # 7.  OUTPUT OVERALL METRICS TABLE
    # -----------------------------------------------------------------------
    print("\nFinal Metrics Evaluation Table (All Stations – PRCP):")
    # Work on a copy so the merged table itself is left untouched.
    combined_df = merged_by_var["prcp"].copy()

    cols_for_eval = ["MBE", "RMSE", "STD", "CC", "d"]
    # Only summarise the metric columns that actually exist in the table.
    avail_cols    = [c for c in cols_for_eval if c in combined_df.columns]

    # One row per statistic (count/mean/std/min/max), one column per metric.
    eval_table = combined_df[avail_cols].agg(
                    ["count", "mean", "std", "min", "max"])
    print(eval_table)

    # Written into this ensemble's own output folder.
    eval_table_out = os.path.join(output_dir,
                      "Overall_Metrics_Evaluation.csv")
    eval_table.to_csv(eval_table_out)
    print(f"      ↳ saved overall metrics evaluation => {eval_table_out}")


In [None]:
# Maps and Taylor diagram of the statistical metrics for ERA5 (prcp)

import os
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from shapely.geometry import Point
from matplotlib.lines import Line2D
import matplotlib.cm as cm
from matplotlib.colors import Normalize

##############################################################################
# 1. FILE PATHS
##############################################################################
# Inputs: ERA5-vs-station NetCDF, per-station metrics CSV, station metadata CSV
# and the Great Lakes basin shapefile. `lakes_shp` is defined for completeness
# but is not referenced anywhere else in this cell.
nc_file = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified for prcp\daily_loop\era5_vs_stations_8Nearest_LWR_1991_2012.nc"
metrics_csv = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified for prcp\metrics\station_metrics_8Nearest_LWR_ERA5_1991_2012.csv"
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"
target_crs = "EPSG:4326"

# Every figure and table produced by this cell is written here.
output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified for prcp\Maps_and_Taylor2"
os.makedirs(output_dir, exist_ok=True)

##############################################################################
# 2. LOAD NETCDF (RENAME DIMENSION) & METRICS CSV
##############################################################################
print("Loading ERA5 vs Station NetCDF ...")
ds_raw = xr.open_dataset(nc_file)
print("Original dims in NetCDF:", ds_raw.dims)

# The file names its station axis "variable"; expose it as "station" instead.
ds = ds_raw.rename(variable="station")
print("\nAfter renaming => ds.dims:", ds.dims)
print("Data variables in NetCDF:", list(ds.data_vars))

print("\nLoading station-level metrics from CSV ...")
df_metrics = pd.read_csv(metrics_csv)
print("Loaded metrics CSV with columns:", list(df_metrics.columns))

# Downstream code refers to the Index of Agreement as "d". Adopt that name when
# the CSV only carries the long form (an existing "d" column is left untouched).
if "d" not in df_metrics.columns and "Index_of_Agreement" in df_metrics.columns:
    df_metrics = df_metrics.rename(columns={"Index_of_Agreement": "d"})

##############################################################################
# 3. LOAD & MERGE STATION COORDINATES
##############################################################################
# Attach lat/lon/elevation from the station metadata file to every metrics row.
print("\nLoading station location info ...")
df_physical = pd.read_csv(physical_file)
df_physical = df_physical.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})

# Merge on station_name (assuming the names match between the two files).
common_col = "station_name"

# Stop here with a clear message if the key is missing from either table;
# otherwise pd.merge fails later with a much less helpful KeyError.
for _frame_name, _frame in (("df_metrics", df_metrics), ("df_physical", df_physical)):
    if common_col not in _frame.columns:
        raise KeyError(
            f"Merge key {common_col!r} not found in {_frame_name}; "
            f"available columns: {_frame.columns.tolist()}"
        )

print("\nMerging metrics with physical station coords ...")
df_merged = pd.merge(df_metrics, df_physical, how="inner", on=common_col)
print(f"Merged shape: {df_merged.shape}")
print("Columns:", df_merged.columns.tolist())

# An inner merge silently drops metrics rows without a coordinate match, so
# report how many stations will be missing from the maps.
_n_unmatched = int((~df_metrics[common_col].isin(df_physical[common_col])).sum())
if _n_unmatched:
    print(f"NOTE: {_n_unmatched} metrics row(s) have no matching station "
          f"coordinates and were dropped by the inner merge.")

##############################################################################
# 4. BASIC CHECKS ON NETCDF
##############################################################################
# Report the size of the time/station axes and the covered time span.
_time_axis = ds["time"]
time_dim = _time_axis.size
station_dim = ds["station"].size
print(f"\nNetCDF: # time steps = {time_dim}, # stations = {station_dim}")

_time_values = _time_axis.values
time_min, time_max = _time_values.min(), _time_values.max()
print("Time range in NetCDF =>", str(time_min), "to", str(time_max))

##############################################################################
# 5. LOAD GREAT LAKES SHAPEFILE
##############################################################################
print("\nLoading Great Lakes shapefile ...")
gdf_lakes = gpd.read_file(shapefile_path)
if gdf_lakes.crs is None:
    # No CRS recorded in the file: label the geometries with the target CRS.
    gdf_lakes.crs = target_crs
else:
    # CRS known: reproject into the target CRS.
    gdf_lakes = gdf_lakes.to_crs(target_crs)

# total_bounds is ordered (minx, miny, maxx, maxy); these set the map extent.
lon_min, lat_min, lon_max, lat_max = gdf_lakes.total_bounds
print("Great Lakes shapefile loaded. Bounds:", (lon_min, lat_min, lon_max, lat_max))

##############################################################################
# 6. MAP STATISTICAL METRICS (MBE,RMSE,STD,CC,d) WITH "HOTSPOTS"
##############################################################################
# One PNG per metric: stations coloured by the metric value, with the "hotspot"
# stations (worst 10% by percentile) ringed in red.
metrics_list = ["MBE","RMSE","STD","CC","d"]
titles_dict  = {
    "MBE":  "Mean Bias Error (MBE)",
    "RMSE": "Root Mean Square Error (RMSE)",
    "STD":  "Standard Deviation (STD)",
    "CC":   "Correlation Coefficient (CC)",
    "d":    "Index of Agreement (d)"
}

print("\nGenerating metric maps with hotspots ...")

for metric in metrics_list:
    if metric not in df_merged.columns:
        print(f"Metric {metric} not found in df_merged. Skipping.")
        continue

    # A column that exists but holds only NaN has no percentile to threshold
    # on, so skip it instead of failing inside np.percentile.
    vals = df_merged[metric].dropna().values
    if vals.size == 0:
        print(f"Metric {metric} has no valid values in df_merged. Skipping.")
        continue

    fig, ax = plt.subplots(figsize=(10, 8), subplot_kw=dict(projection=ccrs.PlateCarree()))
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

    # Base map
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.add_feature(cfeature.LAKES, alpha=0.4)
    ax.add_feature(cfeature.COASTLINE)

    # Great Lakes boundary from the shapefile
    ax.add_geometries(list(gdf_lakes.geometry), ccrs.PlateCarree(),
                      facecolor='none', edgecolor='blue', linewidth=1)

    # Stations coloured by the metric value
    sc = ax.scatter(df_merged["lon"], df_merged["lat"],
                    c=df_merged[metric], cmap="viridis", s=60, edgecolor="k",
                    transform=ccrs.PlateCarree())
    cb = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
    cb.set_label(metric, fontsize=12)

    ax.set_title(f"Spatial Distribution: {titles_dict.get(metric, metric)}", fontsize=14)

    # Hotspots:
    #   MBE, RMSE, STD => values at or above the 90th percentile
    #   CC, d          => values at or below the 10th percentile
    if metric in ["MBE","RMSE","STD"]:
        thr = np.percentile(vals, 90)
        hotspot = df_merged[metric] >= thr
        label_txt = f"Hotspot >= {thr:.2f}"
    else:
        thr = np.percentile(vals, 10)
        hotspot = df_merged[metric] <= thr
        label_txt = f"Hotspot <= {thr:.2f}"

    ax.scatter(df_merged.loc[hotspot,"lon"], df_merged.loc[hotspot,"lat"],
               facecolors='none', edgecolors='red', s=90, linewidths=1.5,
               transform=ccrs.PlateCarree(), label=label_txt)

    # Lat/lon gridlines, labelled on the left and bottom only
    gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.right_labels = False
    gl.top_labels   = False
    gl.xlabel_style = {'size':10}
    gl.ylabel_style = {'size':10}

    # Legend: the labelled hotspot ring plus a proxy entry for the stations
    station_handle = Line2D([],[], marker='o', color='k', linestyle='None', markersize=7, label='Stations')
    handles, labels = ax.get_legend_handles_labels()
    handles.append(station_handle)
    labels.append('Stations')
    ax.legend(handles=handles, labels=labels, loc='upper right', fontsize=9)

    # Save and close through the figure handle rather than pyplot's implicit
    # "current figure".
    out_fn = os.path.join(output_dir, f"Map_{metric}.png")
    fig.savefig(out_fn, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved metric map => {out_fn}")

##############################################################################
# 7. IMPROVED TAYLOR DIAGRAM
##############################################################################
print("\nCreating the improved Taylor Diagram from metrics ...")

# Imports used by the Taylor-diagram code that follows.
import math

import matplotlib.patches as patches
import matplotlib.projections as mp
import mpl_toolkits.axisartist.floating_axes as fa
import mpl_toolkits.axisartist.grid_finder as gf
from matplotlib.lines import Line2D
from matplotlib.projections import PolarAxes

# Colour-blind-safe stand-ins for pure red / pure blue on the diagram axes.
CB_RED, CB_BLUE = "#D55E00", "#0072B2"

class TaylorDiagram:
    """Quarter-circle Taylor diagram built on a floating polar axes.

    The angular coordinate is arccos(correlation) and the radial coordinate
    is the standard deviation. The axes cover angles 0..pi/2 and radii
    0..1.6 * STD. The reference is drawn as a black star at angle 0 and
    radius ``STD``, with a dashed arc at the same radius.

    Attributes:
        STD: reference standard deviation passed to the constructor.
        smin, smax: radial limits of the diagram.
        _ax: the floating (graphical) axes carrying the axis decorations.
        ax: auxiliary axes in polar coordinates; all data is plotted here.
        samplePoints: line artists, reference first, then one per sample.
    """

    def __init__(self, STD, fig=None, rect=111, label='_'):
        """Create the axes and draw the reference point and reference arc.

        Args:
            STD: reference standard deviation.
            fig: figure to draw into; a new 8x8 figure is created when None.
            rect: subplot specification passed to ``FloatingSubplot``.
            label: label given to the reference point artist.
        """
        self.STD = STD
        self.smin = 0
        self.smax = 1.6 * self.STD

        # apply_theta_transforms=False avoids the Matplotlib 3.9+ deprecation
        # warning raised for the default value.
        polar_tr = PolarAxes.PolarTransform(apply_theta_transforms=False)

        # Correlation tick positions, labelled with two decimals.
        corr_ticks = np.concatenate((np.arange(0, 1.1, 0.1), [0.95, 0.99]))
        tick_angles = np.arccos(corr_ticks)
        tick_text = {ang: f"{cc:.2f}" for ang, cc in zip(tick_angles, corr_ticks)}

        grid_helper = fa.GridHelperCurveLinear(
            polar_tr,
            extremes=(0, np.pi/2, self.smin, self.smax),
            grid_locator1=gf.FixedLocator(tick_angles),
            tick_formatter1=gf.DictFormatter(tick_text),
        )

        if fig is None:
            fig = plt.figure(figsize=(8, 8))

        ax = fa.FloatingSubplot(fig, rect, grid_helper=grid_helper)
        fig.add_subplot(ax)

        # Top axis: correlation coefficient
        top = ax.axis['top']
        top.set_axis_direction('bottom')
        top.label.set_text("Correlation Coefficient")
        top.label.set_color(CB_RED)
        top.label.set_fontsize(14)
        top.label.set_rotation(180)
        top.label.set_pad(30)
        top.toggle(ticklabels=True, label=True)
        top.major_ticklabels.set_rotation(180)
        top.major_ticklabels.set_pad(2)
        top.major_ticklabels.set_color(CB_RED)
        top.line.set_color(CB_RED)

        # Left axis: centered RMSE (label only, tick labels hidden)
        left = ax.axis['left']
        left.set_axis_direction('bottom')
        left.label.set_text("Centered RMSE")
        left.label.set_color(CB_BLUE)
        left.label.set_fontsize(14)
        left.label.set_pad(20)
        left.toggle(ticklabels=False, label=True)

        # Right axis: standard deviation (default colour)
        right = ax.axis['right']
        right.set_axis_direction('top')
        right.label.set_text("Standard Deviation")
        right.label.set_fontsize(14)
        right.toggle(ticklabels=True, label=True)

        # The bottom axis is not shown.
        ax.axis['bottom'].set_visible(False)

        ax.grid()
        self._ax = ax                         # graphical axes
        self.ax = ax.get_aux_axes(polar_tr)   # polar-coordinate axes

        # Reference (observed) point
        (ref_point,) = self.ax.plot([0], self.STD, 'k*', ls='', ms=12, label=label)

        # Dashed arc at the reference standard deviation
        arc_theta = np.linspace(0, np.pi/2, 100)
        self.ax.plot(arc_theta, np.full_like(arc_theta, self.STD), 'k--', label='_')

        self.samplePoints = [ref_point]

    def add_sample(self, stdev, corr, *args, **kwargs):
        """Plot one sample at angle arccos(corr) and radius stdev.

        Extra positional and keyword arguments are forwarded to ``plot``.
        Returns the created line artist, which is also appended to
        ``samplePoints``.
        """
        (point,) = self.ax.plot(np.arccos(corr), stdev, *args, **kwargs)
        self.samplePoints.append(point)
        return point

    def add_contours(self, levels=5, **kwargs):
        """Draw centered-RMSE contours about the reference point.

        The RMSE field is evaluated on a 100x100 (angle, radius) grid using
        the law of cosines. ``levels`` and any extra keyword arguments are
        forwarded to ``contour``. Returns the contour set.
        """
        radii = np.linspace(self.smin, self.smax, 100)
        angles = np.linspace(0, np.pi/2, 100)
        rs, ts = np.meshgrid(radii, angles)
        rmse = np.sqrt(self.STD**2 + rs**2 - 2*self.STD*rs*np.cos(ts))
        return self.ax.contour(ts, rs, rmse, levels, colors=CB_BLUE, **kwargs)

def short_station_name(full_name):
    """Return the first whitespace-delimited word of a station name.

    Args:
        full_name: station name. Strings are the expected input; other
            values are converted with ``str()`` first.

    Returns:
        The first word, or an empty string when ``full_name`` is None, NaN,
        empty, or contains only whitespace.
    """
    # Missing names read from a CSV arrive as float NaN; NaN is the only
    # float that is not equal to itself.
    if full_name is None or (isinstance(full_name, float) and full_name != full_name):
        return ""
    words = str(full_name).split()
    return words[0] if words else ""

def create_taylor_diagram(df, ref_col="STD", std_col="STD", corr_col="CC", stn_name_col="station_name"):
    """Plot the Taylor diagram for all stations and save it as a PNG.

    Samples whose correlation is at or above the 90th percentile (the top
    10%) are drawn as larger rings in the colour-blind-safe red; all other
    samples are small grey rings. The legend has one entry per group.

    Args:
        df: table with one row per station/grid point.
        ref_col: column whose mean is used as the reference standard deviation.
        std_col: column giving each sample's standard deviation (radius).
        corr_col: column giving each sample's correlation (angle).
        stn_name_col: accepted for interface compatibility; station names are
            not drawn on this diagram.

    Returns:
        None. Writes ``Improved_TaylorDiagram.png`` into ``output_dir``, or
        prints a message and returns when no complete rows are available.
    """
    # Only rows with all required values can be placed on the diagram.
    df_td = df.dropna(subset=[ref_col, std_col, corr_col])
    if df_td.empty:
        print("Not enough data for Taylor Diagram.")
        return

    # Reference STD is the average of the reference column.
    ref_std_val = df_td[ref_col].mean()
    # Correlations at or above this value form the top-10% group.
    thr_cc = np.percentile(df_td[corr_col].values, 90)

    fig = plt.figure(figsize=(8, 8))
    dia = TaylorDiagram(ref_std_val, fig=fig, rect=111, label='OBS')
    ctn = dia.add_contours(levels=6)
    plt.clabel(ctn, inline=1, fontsize=10)

    # Two fixed marker styles. The best group uses the same colour as its
    # legend handle below, so the legend matches what is drawn. Labels start
    # with "_" because the legend is built from proxy handles instead.
    best_style = dict(marker='o', ms=6, mec=CB_RED, mfc='none', mew=1.6, label='_best')
    grey_style = dict(marker='o', ms=4, mec='gray', mfc='none', mew=1, label='_')

    # Reading the columns directly (rather than itertuples + getattr) also
    # works for column names that are not valid Python identifiers.
    for stdev, corr in zip(df_td[std_col].to_numpy(), df_td[corr_col].to_numpy()):
        dia.add_sample(stdev, corr, **(best_style if corr >= thr_cc else grey_style))

    # Legend: two framed entries
    handle_all = Line2D([], [], marker='o', ms=7, mfc='none', mec='gray',
                        linestyle='None', label='Grid/Station points')
    handle_best = Line2D([], [], marker='o', ms=8, mfc='none', mec=CB_RED,
                         linestyle='None', mew=1.6, label='Top 10% CC Grids')

    plt.legend(handles=[handle_all, handle_best],
               loc='upper right', frameon=True, framealpha=1,
               edgecolor='black', fontsize=10)

    dia._ax.set_title("ERA5", fontsize=16, fontweight="bold")

    out_fn = os.path.join(output_dir, "Improved_TaylorDiagram.png")
    fig.savefig(out_fn, dpi=600, bbox_inches='tight', pad_inches=0.20)
    plt.close(fig)
    print(f"Saved improved Taylor Diagram => {out_fn}")

##############################################################################
# 8. CREATE IMPROVED TAYLOR DIAGRAM
##############################################################################
# 'STD' serves as both the reference column and the sample std column, and
# 'CC' as the correlation. Point ref_col at a separate observed-STD column if
# one becomes available.
create_taylor_diagram(
    df_merged,
    ref_col="STD",
    std_col="STD",
    corr_col="CC",
    stn_name_col="station_name",
)

##############################################################################
# 9. OUTPUT A "METRICS EVALUATION" TABLE
##############################################################################
# Summary statistics over all stations for whichever metrics are present.
print("\nFinal Metrics Evaluation Table (All Stations):")
cols_for_eval = ["MBE","RMSE","STD","CC","d"]
avail_cols = []
for col in cols_for_eval:
    if col in df_merged.columns:
        avail_cols.append(col)

summary_stats = ["count","mean","std","min","max"]
eval_table = df_merged[avail_cols].agg(summary_stats)
print(eval_table)

eval_table_out = os.path.join(output_dir,"Overall_Metrics_Evaluation.csv")
eval_table.to_csv(eval_table_out)
print(f"Saved overall metrics evaluation => {eval_table_out}")

print("\n✅ Done! Renamed dimension for NetCDF, renamed Index_of_Agreement to 'd', produced improved Taylor Diagram, and exported everything.")


In [None]:
# Maps and Taylor diagram of the statistical metrics for RDRS (prcp)

import os
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from shapely.geometry import Point
from matplotlib.lines import Line2D
import matplotlib.cm as cm
from matplotlib.colors import Normalize

##############################################################################
# 1. FILE PATHS
##############################################################################
# Inputs: RDRS-vs-station NetCDF, per-station metrics CSV, station metadata CSV
# and the Great Lakes basin shapefile. `lakes_shp` is defined for completeness
# but is not referenced anywhere else in this cell.
nc_file = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\daily_loop\rdrs_vs_stations_25km_LWR_1991_2012.nc"
metrics_csv = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\metrics\station_metrics_25km_LWR_1991_2012.csv"
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"
target_crs = "EPSG:4326"

# Every figure and table produced by this cell is written here.
output_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\Maps_and_Taylor"
os.makedirs(output_dir, exist_ok=True)

##############################################################################
# 2. LOAD NETCDF (RENAME DIMENSION) & METRICS CSV
##############################################################################
print("Loading RDRS vs Station NetCDF ...")
ds_raw = xr.open_dataset(nc_file)
print("Original dims in NetCDF:", ds_raw.dims)

# The file names its station axis "variable"; expose it as "station" instead.
ds = ds_raw.rename(variable="station")
print("\nAfter renaming => ds.dims:", ds.dims)
print("Data variables in NetCDF:", list(ds.data_vars))

print("\nLoading station-level metrics from CSV ...")
df_metrics = pd.read_csv(metrics_csv)
print("Loaded metrics CSV with columns:", list(df_metrics.columns))

# Downstream code refers to the Index of Agreement as "d". Adopt that name when
# the CSV only carries the long form (an existing "d" column is left untouched).
if "d" not in df_metrics.columns and "Index_of_Agreement" in df_metrics.columns:
    df_metrics = df_metrics.rename(columns={"Index_of_Agreement": "d"})

##############################################################################
# 3. LOAD & MERGE STATION COORDINATES
##############################################################################
# Attach lat/lon/elevation from the station metadata file to every metrics row.
print("\nLoading station location info ...")
df_physical = pd.read_csv(physical_file)
df_physical = df_physical.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})

# Merge on station_name (assuming the names match between the two files).
common_col = "station_name"

# Stop here with a clear message if the key is missing from either table;
# otherwise pd.merge fails later with a much less helpful KeyError.
for _frame_name, _frame in (("df_metrics", df_metrics), ("df_physical", df_physical)):
    if common_col not in _frame.columns:
        raise KeyError(
            f"Merge key {common_col!r} not found in {_frame_name}; "
            f"available columns: {_frame.columns.tolist()}"
        )

print("\nMerging metrics with physical station coords ...")
df_merged = pd.merge(df_metrics, df_physical, how="inner", on=common_col)
print(f"Merged shape: {df_merged.shape}")
print("Columns:", df_merged.columns.tolist())

# An inner merge silently drops metrics rows without a coordinate match, so
# report how many stations will be missing from the maps.
_n_unmatched = int((~df_metrics[common_col].isin(df_physical[common_col])).sum())
if _n_unmatched:
    print(f"NOTE: {_n_unmatched} metrics row(s) have no matching station "
          f"coordinates and were dropped by the inner merge.")

##############################################################################
# 4. BASIC CHECKS ON NETCDF
##############################################################################
# Report the size of the time/station axes and the covered time span.
_time_axis = ds["time"]
time_dim = _time_axis.size
station_dim = ds["station"].size
print(f"\nNetCDF: # time steps = {time_dim}, # stations = {station_dim}")

_time_values = _time_axis.values
time_min, time_max = _time_values.min(), _time_values.max()
print("Time range in NetCDF =>", str(time_min), "to", str(time_max))

##############################################################################
# 5. LOAD GREAT LAKES SHAPEFILE
##############################################################################
print("\nLoading Great Lakes shapefile ...")
gdf_lakes = gpd.read_file(shapefile_path)
if gdf_lakes.crs is None:
    # No CRS recorded in the file: label the geometries with the target CRS.
    gdf_lakes.crs = target_crs
else:
    # CRS known: reproject into the target CRS.
    gdf_lakes = gdf_lakes.to_crs(target_crs)

# total_bounds is ordered (minx, miny, maxx, maxy); these set the map extent.
lon_min, lat_min, lon_max, lat_max = gdf_lakes.total_bounds
print("Great Lakes shapefile loaded. Bounds:", (lon_min, lat_min, lon_max, lat_max))

##############################################################################
# 6. MAP STATISTICAL METRICS (MBE,RMSE,STD,CC,d) WITH "HOTSPOTS"
##############################################################################
# One PNG per metric: stations coloured by the metric value, with the "hotspot"
# stations (worst 10% by percentile) ringed in red.
metrics_list = ["MBE","RMSE","STD","CC","d"]
titles_dict  = {
    "MBE":  "Mean Bias Error (MBE)",
    "RMSE": "Root Mean Square Error (RMSE)",
    "STD":  "Standard Deviation (STD)",
    "CC":   "Correlation Coefficient (CC)",
    "d":    "Index of Agreement (d)"
}

print("\nGenerating metric maps with hotspots ...")

for metric in metrics_list:
    if metric not in df_merged.columns:
        print(f"Metric {metric} not found in df_merged. Skipping.")
        continue

    # A column that exists but holds only NaN has no percentile to threshold
    # on, so skip it instead of failing inside np.percentile.
    vals = df_merged[metric].dropna().values
    if vals.size == 0:
        print(f"Metric {metric} has no valid values in df_merged. Skipping.")
        continue

    fig, ax = plt.subplots(figsize=(10, 8), subplot_kw=dict(projection=ccrs.PlateCarree()))
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

    # Base map
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.add_feature(cfeature.LAKES, alpha=0.4)
    ax.add_feature(cfeature.COASTLINE)

    # Great Lakes boundary from the shapefile
    ax.add_geometries(list(gdf_lakes.geometry), ccrs.PlateCarree(),
                      facecolor='none', edgecolor='blue', linewidth=1)

    # Stations coloured by the metric value
    sc = ax.scatter(df_merged["lon"], df_merged["lat"],
                    c=df_merged[metric], cmap="viridis", s=60, edgecolor="k",
                    transform=ccrs.PlateCarree())
    cb = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
    cb.set_label(metric, fontsize=12)

    ax.set_title(f"Spatial Distribution: {titles_dict.get(metric, metric)}", fontsize=14)

    # Hotspots:
    #   MBE, RMSE, STD => values at or above the 90th percentile
    #   CC, d          => values at or below the 10th percentile
    if metric in ["MBE","RMSE","STD"]:
        thr = np.percentile(vals, 90)
        hotspot = df_merged[metric] >= thr
        label_txt = f"Hotspot >= {thr:.2f}"
    else:
        thr = np.percentile(vals, 10)
        hotspot = df_merged[metric] <= thr
        label_txt = f"Hotspot <= {thr:.2f}"

    ax.scatter(df_merged.loc[hotspot,"lon"], df_merged.loc[hotspot,"lat"],
               facecolors='none', edgecolors='red', s=90, linewidths=1.5,
               transform=ccrs.PlateCarree(), label=label_txt)

    # Lat/lon gridlines, labelled on the left and bottom only
    gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.right_labels = False
    gl.top_labels   = False
    gl.xlabel_style = {'size':10}
    gl.ylabel_style = {'size':10}

    # Legend: the labelled hotspot ring plus a proxy entry for the stations
    station_handle = Line2D([],[], marker='o', color='k', linestyle='None', markersize=7, label='Stations')
    handles, labels = ax.get_legend_handles_labels()
    handles.append(station_handle)
    labels.append('Stations')
    ax.legend(handles=handles, labels=labels, loc='upper right', fontsize=9)

    # Save and close through the figure handle rather than pyplot's implicit
    # "current figure".
    out_fn = os.path.join(output_dir, f"Map_{metric}.png")
    fig.savefig(out_fn, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved metric map => {out_fn}")

##############################################################################
# 7. IMPROVED TAYLOR DIAGRAM
##############################################################################
print("\nCreating the improved Taylor Diagram from metrics ...")

# Imports used by the Taylor-diagram code that follows.
import math

import matplotlib.patches as patches
import matplotlib.projections as mp
import mpl_toolkits.axisartist.floating_axes as fa
import mpl_toolkits.axisartist.grid_finder as gf
from matplotlib.projections import PolarAxes

class TaylorDiagram:
    """Quarter-circle Taylor diagram built on a floating polar axes.

    The angular coordinate is arccos(correlation) and the radial coordinate
    is the standard deviation. The axes cover angles 0..pi/2 and radii
    0..1.6 * STD. The reference is drawn as a black star at angle 0 and
    radius ``STD``, with a dashed arc at the same radius.

    Attributes:
        STD: reference standard deviation passed to the constructor.
        smin, smax: radial limits of the diagram.
        _ax: the floating (graphical) axes carrying the axis decorations.
        ax: auxiliary axes in polar coordinates; all data is plotted here.
        samplePoints: line artists, reference first, then one per sample.
    """

    def __init__(self, STD, fig=None, rect=111, label='_'):
        """Create the axes and draw the reference point and reference arc.

        Args:
            STD: reference standard deviation.
            fig: figure to draw into; a new 8x8 figure is created when None.
            rect: subplot specification passed to ``FloatingSubplot``.
            label: label given to the reference point artist.
        """
        self.STD = STD
        self.smin = 0
        self.smax = 1.6 * self.STD

        # apply_theta_transforms=False avoids the Matplotlib 3.9+ deprecation
        # warning raised for the default value.
        polar_tr = PolarAxes.PolarTransform(apply_theta_transforms=False)

        # Correlation tick positions, labelled with two decimals.
        corr_ticks = np.concatenate((np.arange(0, 1.1, 0.1), [0.95, 0.99]))
        tick_angles = np.arccos(corr_ticks)
        tick_text = {ang: f"{cc:.2f}" for ang, cc in zip(tick_angles, corr_ticks)}

        grid_helper = fa.GridHelperCurveLinear(
            polar_tr,
            extremes=(0, np.pi/2, self.smin, self.smax),
            grid_locator1=gf.FixedLocator(tick_angles),
            tick_formatter1=gf.DictFormatter(tick_text),
        )

        if fig is None:
            fig = plt.figure(figsize=(8, 8))

        ax = fa.FloatingSubplot(fig, rect, grid_helper=grid_helper)
        fig.add_subplot(ax)

        # Top axis: correlation coefficient (red)
        top = ax.axis['top']
        top.set_axis_direction('bottom')
        top.label.set_text("Correlation Coefficient")
        top.label.set_color("red")
        top.label.set_fontsize(14)
        top.label.set_rotation(180)
        top.label.set_pad(30)
        top.toggle(ticklabels=True, label=True)
        top.major_ticklabels.set_rotation(180)
        top.major_ticklabels.set_pad(2)
        top.major_ticklabels.set_color("red")
        top.line.set_color("red")

        # Left axis: centered RMSE (blue label only, tick labels hidden)
        left = ax.axis['left']
        left.set_axis_direction('bottom')
        left.label.set_text("Centered RMSE")
        left.label.set_color("blue")
        left.label.set_fontsize(14)
        left.label.set_pad(20)
        left.toggle(ticklabels=False, label=True)

        # Right axis: standard deviation (default colour)
        right = ax.axis['right']
        right.set_axis_direction('top')
        right.label.set_text("Standard Deviation")
        right.label.set_fontsize(14)
        right.toggle(ticklabels=True, label=True)

        # The bottom axis is not shown.
        ax.axis['bottom'].set_visible(False)

        ax.grid()
        self._ax = ax                         # graphical axes
        self.ax = ax.get_aux_axes(polar_tr)   # polar-coordinate axes

        # Reference (observed) point
        (ref_point,) = self.ax.plot([0], self.STD, 'k*', ls='', ms=12, label=label)

        # Dashed arc at the reference standard deviation
        arc_theta = np.linspace(0, np.pi/2, 100)
        self.ax.plot(arc_theta, np.full_like(arc_theta, self.STD), 'k--', label='_')

        self.samplePoints = [ref_point]

    def add_sample(self, stdev, corr, *args, **kwargs):
        """Plot one sample at angle arccos(corr) and radius stdev.

        Extra positional and keyword arguments are forwarded to ``plot``.
        Returns the created line artist, which is also appended to
        ``samplePoints``.
        """
        (point,) = self.ax.plot(np.arccos(corr), stdev, *args, **kwargs)
        self.samplePoints.append(point)
        return point

    def add_contours(self, levels=5, **kwargs):
        """Draw centered-RMSE contours (blue) about the reference point.

        The RMSE field is evaluated on a 100x100 (angle, radius) grid using
        the law of cosines. ``levels`` and any extra keyword arguments are
        forwarded to ``contour``. Returns the contour set.
        """
        radii = np.linspace(self.smin, self.smax, 100)
        angles = np.linspace(0, np.pi/2, 100)
        rs, ts = np.meshgrid(radii, angles)
        rmse = np.sqrt(self.STD**2 + rs**2 - 2*self.STD*rs*np.cos(ts))
        return self.ax.contour(ts, rs, rmse, levels, colors="blue", **kwargs)

def short_station_name(full_name):
    """Return the first whitespace-delimited word of a station name.

    Args:
        full_name: station name. Strings are the expected input; other
            values are converted with ``str()`` first.

    Returns:
        The first word, or an empty string when ``full_name`` is None, NaN,
        empty, or contains only whitespace.
    """
    # Missing names read from a CSV arrive as float NaN; NaN is the only
    # float that is not equal to itself.
    if full_name is None or (isinstance(full_name, float) and full_name != full_name):
        return ""
    words = str(full_name).split()
    return words[0] if words else ""

def create_taylor_diagram(df, ref_col="STD", std_col="STD", corr_col="CC", stn_name_col="station_name"):
    """Plot the Taylor diagram for all stations and save it as a PNG.

    Samples whose correlation is at or above the 90th percentile (the top
    10%) are drawn as coloured rings and listed in the legend under the first
    word of their station name; all other samples are small grey rings that
    are left out of the legend.

    Args:
        df: table with one row per station/grid point.
        ref_col: column whose mean is used as the reference standard deviation.
        std_col: column giving each sample's standard deviation (radius).
        corr_col: column giving each sample's correlation (angle).
        stn_name_col: column holding the station names used for legend labels.

    Returns:
        None. Writes ``Improved_TaylorDiagram.png`` into ``output_dir``, or
        prints a message and returns when no complete rows are available.
    """
    # Only rows with all required values can be placed on the diagram.
    df_td = df.dropna(subset=[ref_col, std_col, corr_col])
    if df_td.empty:
        print("Not enough data for Taylor Diagram.")
        return

    # Reference STD is the average of the reference column.
    ref_std_val = df_td[ref_col].mean()
    # Correlations at or above this value form the top-10% group.
    thr_cc = np.percentile(df_td[corr_col].values, 90)

    fig = plt.figure(figsize=(8, 8))
    dia = TaylorDiagram(ref_std_val, fig=fig, rect=111, label='OBS')
    ctn = dia.add_contours(levels=6)
    plt.clabel(ctn, inline=1, fontsize=10)

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # takes the same (name, lut) arguments and is still available.
    color_map = plt.get_cmap("tab20", len(df_td))

    # Reading the columns directly (rather than itertuples + getattr) also
    # works for column names that are not valid Python identifiers.
    samples = zip(df_td[std_col].to_numpy(),
                  df_td[corr_col].to_numpy(),
                  df_td[stn_name_col].to_numpy())
    for i, (stdev, corr, stn) in enumerate(samples):
        if corr >= thr_cc:
            # High-correlation sample: coloured ring, named in the legend.
            mk_style = dict(marker='o', ms=6,
                            mec=color_map(i), mfc='none', mew=1.6,
                            label=short_station_name(stn))
        else:
            # Default sample: grey ring; "_" keeps it out of the legend.
            mk_style = dict(marker='o', ms=4,
                            mec='gray', mfc='none', mew=1,
                            label='_')
        dia.add_sample(stdev, corr, **mk_style)

    # Legend: the reference point plus every sample with a usable label
    # (neither the "_" placeholder nor blank).
    legend_handles = [p for p in dia.samplePoints
                      if p.get_label() not in ("_", "")]
    if legend_handles:
        plt.legend(legend_handles, [p.get_label() for p in legend_handles],
                   numpoints=1, prop=dict(size=6),
                   loc='upper right', title="Best CC Grids")

    dia._ax.set_title("Taylor Diagram RDRS", fontsize=12, fontweight="bold")

    out_fn = os.path.join(output_dir, "Improved_TaylorDiagram.png")
    fig.savefig(out_fn, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved improved Taylor Diagram => {out_fn}")

##############################################################################
# 8. CREATE IMPROVED TAYLOR DIAGRAM
##############################################################################
# 'STD' serves as both the reference column and the sample std column, and
# 'CC' as the correlation. Point ref_col at a separate observed-STD column if
# one becomes available.
create_taylor_diagram(
    df_merged,
    ref_col="STD",
    std_col="STD",
    corr_col="CC",
    stn_name_col="station_name",
)

##############################################################################
# 9. OUTPUT A "METRICS EVALUATION" TABLE
##############################################################################
# Summary statistics over all stations for whichever metrics are present.
print("\nFinal Metrics Evaluation Table (All Stations):")
cols_for_eval = ["MBE","RMSE","STD","CC","d"]
avail_cols = []
for col in cols_for_eval:
    if col in df_merged.columns:
        avail_cols.append(col)

summary_stats = ["count","mean","std","min","max"]
eval_table = df_merged[avail_cols].agg(summary_stats)
print(eval_table)

eval_table_out = os.path.join(output_dir,"Overall_Metrics_Evaluation.csv")
eval_table.to_csv(eval_table_out)
print(f"Saved overall metrics evaluation => {eval_table_out}")

print("\n✅ Done! Renamed dimension for NetCDF, renamed Index_of_Agreement to 'd', produced improved Taylor Diagram, and exported everything.")


In [None]:
# Maps and Taylor diagram of the statistical metrics for MERRA-2 (prcp)

import os
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from shapely.geometry import Point
from matplotlib.lines import Line2D
import matplotlib.cm as cm
from matplotlib.colors import Normalize

##############################################################################
# 1. FILE PATHS
##############################################################################
# All inputs/outputs are absolute Windows paths; edit them here when the data
# moves. Every other section of this cell only refers to these names.

# NetCDF opened with xarray in section 2 (MERRA2 vs. station series).
nc_file       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\daily_loop\merra2_vs_stations_12Nearest_LWR_1991_2012.nc"
# Per-station metrics table read with pandas in section 2.
metrics_csv   = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\metrics\station_metrics_12Nearest_LWR_merra2_1991_2012.csv"
# Station file whose NAME/LATITUDE/LONGITUDE/Elevation columns are renamed in section 3.
physical_file = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
# Shapefile read in section 5; its bounds define the map extent.
shapefile_path= r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
# NOTE(review): lakes_shp is not referenced anywhere in the visible part of this cell.
lakes_shp     = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"
# CRS the shapefile is reprojected to (or tagged with, when it carries none).
target_crs    = "EPSG:4326"  
# Destination folder for the PNG maps, the Taylor diagram and the summary CSV.
output_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\Maps_and_Taylor"

# Create the output folder up front; exist_ok=True makes re-running the cell harmless.
os.makedirs(output_dir, exist_ok=True)

##############################################################################
# 2. LOAD NETCDF (RENAME DIMENSION) & METRICS CSV
##############################################################################
print("Loading MERRA2 vs Station NetCDF ...")
# NOTE(review): the dataset is never closed in the visible part of this cell.
ds_raw = xr.open_dataset(nc_file)
print("Original dims in NetCDF:", ds_raw.dims)

# Rename dimension "variable" -> "station" so later code can index by station.
# NOTE(review): presumably this fails when the file has no "variable"
# dimension (e.g. it was already renamed upstream) — confirm.
ds = ds_raw.rename({"variable": "station"})
print("\nAfter renaming => ds.dims:", ds.dims)
print("Data variables in NetCDF:", list(ds.data_vars))

print("\nLoading station-level metrics from CSV ...")
df_metrics = pd.read_csv(metrics_csv)
print("Loaded metrics CSV with columns:", df_metrics.columns.tolist())

# The rest of the cell refers to the Index of Agreement as "d". If the CSV
# names the column "Index_of_Agreement", rename it; an existing "d" column is
# never overwritten.
if "Index_of_Agreement" in df_metrics.columns and "d" not in df_metrics.columns:
    df_metrics = df_metrics.rename(columns={"Index_of_Agreement": "d"})

##############################################################################
# 3. LOAD & MERGE STATION COORDINATES
##############################################################################
print("\nLoading station location info ...")
df_physical = pd.read_csv(physical_file)
# Normalise the station file's headers to the names used by the rest of the cell.
df_physical = df_physical.rename(columns={
    "NAME": "station_name",
    "LATITUDE": "lat",
    "LONGITUDE": "lon",
    "Elevation": "elev"
})

# Merge key shared by the metrics table and the station file.
common_col = "station_name"
if common_col not in df_metrics.columns:
    # Fail here with an actionable message instead of printing a warning and
    # letting pd.merge raise a bare KeyError a few lines later.
    raise KeyError(
        f"'{common_col}' not in df_metrics; cannot merge metrics with station "
        f"coordinates. Available columns: {df_metrics.columns.tolist()}"
    )

print("\nMerging metrics with physical station coords ...")
# The inner merge silently drops metrics rows whose station has no coordinates,
# so report how many are lost before merging.
unmatched = ~df_metrics[common_col].isin(df_physical[common_col])
if unmatched.any():
    print(f"WARNING: {int(unmatched.sum())} metrics row(s) have no match in the "
          f"station file and are dropped by the inner merge.")

df_merged = pd.merge(df_metrics, df_physical, how="inner", on=common_col)
print(f"Merged shape: {df_merged.shape}")
print("Columns:", df_merged.columns.tolist())

##############################################################################
# 4. BASIC CHECKS ON NETCDF
##############################################################################
# Report the dataset's length along both dimensions, then its time coverage.
time_dim, station_dim = ds["time"].size, ds["station"].size
print(f"\nNetCDF: # time steps = {time_dim}, # stations = {station_dim}")

# Pull the time coordinate into memory once and take both extremes from it.
time_values = ds["time"].values
time_min, time_max = time_values.min(), time_values.max()
print("Time range in NetCDF =>", str(time_min), "to", str(time_max))

##############################################################################
# 5. LOAD GREAT LAKES SHAPEFILE
##############################################################################
print("\nLoading Great Lakes shapefile ...")
gdf_lakes = gpd.read_file(shapefile_path)
if gdf_lakes.crs is None:
    # No CRS recorded in the file: tag the coordinates with target_crs as they are.
    gdf_lakes.crs = target_crs
else:
    # A CRS is recorded: reproject the geometries into target_crs.
    gdf_lakes = gdf_lakes.to_crs(target_crs)

# total_bounds is ordered (minx, miny, maxx, maxy); these four values set the
# extent of every map drawn below.
lakes_bounds = gdf_lakes.total_bounds
lon_min, lat_min, lon_max, lat_max = lakes_bounds
print("Great Lakes shapefile loaded. Bounds:", (lon_min, lat_min, lon_max, lat_max))

##############################################################################
# 6. MAP STATISTICAL METRICS (MBE,RMSE,STD,CC,d) WITH “HOTSPOTS”
##############################################################################
# One PNG per metric: stations coloured by the metric value, with the worst
# ~10% of stations ("hotspots") ringed in red.
metrics_list = ["MBE","RMSE","STD","CC","d"]
titles_dict  = {
    "MBE":  "Mean Bias Error (MBE)",
    "RMSE": "Root Mean Square Error (RMSE)",
    "STD":  "Standard Deviation (STD)",
    "CC":   "Correlation Coefficient (CC)",
    "d":    "Index of Agreement (d)"
}

print("\nGenerating metric maps with hotspots ...")

for metric in metrics_list:
    if metric not in df_merged.columns:
        print(f"Metric {metric} not found in df_merged. Skipping.")
        continue

    # A percentile of an empty array is undefined, so bail out before any
    # figure is created when the column holds no valid value at all.
    vals = df_merged[metric].dropna().values
    if vals.size == 0:
        print(f"Metric {metric} has no valid values in df_merged. Skipping.")
        continue

    fig, ax = plt.subplots(figsize=(10, 8), subplot_kw=dict(projection=ccrs.PlateCarree()))
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

    # Base map
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.add_feature(cfeature.LAKES, alpha=0.4)
    ax.add_feature(cfeature.COASTLINE)

    # Plot the Great Lakes boundary (all geometries in a single call)
    ax.add_geometries(gdf_lakes.geometry, ccrs.PlateCarree(),
                      facecolor='none', edgecolor='blue', linewidth=1)

    # Plot station metrics
    sc = ax.scatter(df_merged["lon"], df_merged["lat"],
                    c=df_merged[metric], cmap="viridis", s=60, edgecolor="k",
                    transform=ccrs.PlateCarree())
    cb = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
    cb.set_label(metric, fontsize=12)

    # Title the map axes explicitly rather than whatever axes is "current".
    ax.set_title(f"Spatial Distribution: {titles_dict.get(metric, metric)}", fontsize=14)

    # Identify hotspots
    # For MBE, RMSE, STD => top 10%
    # For CC, d => bottom 10%
    # NOTE(review): MBE is signed, so the top-10% rule only flags the largest
    # positive biases; strongly negative biases are not ringed — confirm intent.
    if metric in ["MBE","RMSE","STD"]:
        thr = np.percentile(vals, 90)
        hotspot = df_merged[metric] >= thr
        label_txt = f"Hotspot >= {thr:.2f}"
    else:
        thr = np.percentile(vals, 10)
        hotspot = df_merged[metric] <= thr
        label_txt = f"Hotspot <= {thr:.2f}"

    ax.scatter(df_merged.loc[hotspot,"lon"], df_merged.loc[hotspot,"lat"],
               facecolors='none', edgecolors='red', s=90, linewidths=1.5,
               transform=ccrs.PlateCarree(), label=label_txt)

    # Add Lat/Lon gridlines
    gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.right_labels = False
    gl.top_labels   = False
    gl.xlabel_style = {'size':10}
    gl.ylabel_style = {'size':10}

    # Legend: the hotspot ring (labelled above) plus a proxy marker for stations
    station_handle = Line2D([],[], marker='o', color='k', linestyle='None', markersize=7, label='Stations')
    handles, labels = ax.get_legend_handles_labels()
    handles.append(station_handle)
    labels.append('Stations')
    ax.legend(handles=handles, labels=labels, loc='upper right', fontsize=9)

    # Save and close THIS figure so an unrelated open figure is never touched.
    out_fn = os.path.join(output_dir, f"Map_{metric}.png")
    fig.savefig(out_fn, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved metric map => {out_fn}")

##############################################################################
# 7. IMPROVED TAYLOR DIAGRAM
##############################################################################
print("\nCreating the improved Taylor Diagram from metrics ...")

# Imports needed by the Taylor-diagram code defined below.
# gf: FixedLocator / DictFormatter for the correlation tick marks.
import mpl_toolkits.axisartist.grid_finder as gf
# fa: GridHelperCurveLinear / FloatingSubplot for the curvilinear axes.
import mpl_toolkits.axisartist.floating_axes as fa
# NOTE(review): `mp` is not referenced in the visible remainder of this cell.
import matplotlib.projections as mp
# PolarAxes supplies the PolarTransform used to build the diagram's axes.
from matplotlib.projections import PolarAxes
# NOTE(review): `patches` is not referenced in the visible remainder of this cell.
import matplotlib.patches as patches
import math

class TaylorDiagram(object):
    """Quarter-circle Taylor diagram with colour-coded axes.

    The radial coordinate is the standard deviation and the angular
    coordinate is ``arccos(correlation)``. The correlation axis is drawn in
    red, the "Centered RMSE" label in blue and the standard-deviation axis in
    the default colour.

    Attributes:
        STD: reference standard deviation (radius of the reference marker).
        smin, smax: radial extent of the diagram (0 to 1.6 * STD).
        _ax: the host (graphical) axes; use it for titles and legends.
        ax: auxiliary axes in polar (theta, r) coordinates; samples and
            contours are drawn on it.
        samplePoints: line artists of all plotted points, the reference
            marker first.
    """

    @staticmethod
    def _polar_transform():
        """Build the polar transform across Matplotlib versions.

        ``apply_theta_transforms=False`` silences the Matplotlib 3.9+
        deprecation warning; releases that do not know the keyword raise
        TypeError, in which case the plain constructor is used.
        """
        try:
            return PolarAxes.PolarTransform(apply_theta_transforms=False)
        except TypeError:
            return PolarAxes.PolarTransform()

    def __init__(self, STD, fig=None, rect=111, label='_'):
        """Set up the axes and draw the reference marker.

        Args:
            STD: reference standard deviation.
            fig: figure to draw into; a new 8x8 figure is created when None.
            rect: subplot specification passed to FloatingSubplot.
            label: legend label of the reference marker.
        """
        self.STD = STD

        tr = self._polar_transform()

        # Correlation tick positions (as angles) with 2-decimal labels
        rlocs = np.concatenate(((np.arange(0,1.1,0.1)), [0.95, 0.99]))
        tlocs = np.arccos(rlocs)
        tf1 = gf.DictFormatter(dict(zip(tlocs, map(lambda x: f"{x:.2f}", rlocs))))

        # STDev extent
        self.smin = 0
        self.smax = 1.6 * self.STD

        # Angular range 0..pi/2 => only correlations in [0, 1] are on the chart
        gh = fa.GridHelperCurveLinear(
            tr, extremes=(0, np.pi/2, self.smin, self.smax),
            grid_locator1=gf.FixedLocator(tlocs), tick_formatter1=tf1
        )

        if fig is None:
            fig = plt.figure(figsize=(8, 8))

        ax = fa.FloatingSubplot(fig, rect, grid_helper=gh)
        fig.add_subplot(ax)

        # Correlation (red, top)
        ax.axis['top'].set_axis_direction('bottom')
        ax.axis['top'].label.set_text("Correlation Coefficient")
        ax.axis['top'].label.set_color("red")
        ax.axis['top'].label.set_fontsize(14)
        ax.axis['top'].label.set_rotation(180)
        ax.axis['top'].label.set_pad(30)
        ax.axis['top'].toggle(ticklabels=True, label=True)
        ax.axis['top'].major_ticklabels.set_rotation(180)
        ax.axis['top'].major_ticklabels.set_pad(2)
        ax.axis['top'].major_ticklabels.set_color("red")
        ax.axis['top'].line.set_color("red")

        # Centered RMSE (blue, left) - label only, no tick labels
        ax.axis['left'].set_axis_direction('bottom')
        ax.axis['left'].label.set_text("Centered RMSE")
        ax.axis['left'].label.set_color("blue")
        ax.axis['left'].label.set_fontsize(14)
        ax.axis['left'].label.set_pad(20)
        ax.axis['left'].toggle(ticklabels=False, label=True)

        # Standard Deviation (right)
        ax.axis['right'].set_axis_direction('top')
        ax.axis['right'].label.set_text("Standard Deviation")
        ax.axis['right'].label.set_fontsize(14)
        ax.axis['right'].toggle(ticklabels=True, label=True)

        # Hide bottom
        ax.axis['bottom'].set_visible(False)

        ax.grid()
        self._ax = ax            # Graphical axes
        self.ax = ax.get_aux_axes(tr)  # Polar coordinates

        # Reference star at (theta=0, r=STD)
        l, = self.ax.plot([0], self.STD, 'k*', ls='', ms=12, label=label)
        # Dashed arc at the reference standard deviation
        t = np.linspace(0, np.pi/2, 100)
        r = np.zeros_like(t)+self.STD
        self.ax.plot(t, r, 'k--', label='_')

        self.samplePoints = [l]

    def add_sample(self, stdev, corr, *args, **kwargs):
        """Add a point to the diagram and return its line artist.

        Args:
            stdev: standard deviation of the sample (radial coordinate).
            corr: correlation with the reference (angle = arccos(corr)).
            *args, **kwargs: forwarded to ``Axes.plot``.
        """
        # Clip so rounding noise such as 1.0000000002 maps to theta=0 instead
        # of NaN (which would silently drop the point).
        theta = np.arccos(np.clip(corr, -1.0, 1.0))
        l, = self.ax.plot(theta, stdev, *args, **kwargs)
        self.samplePoints.append(l)
        return l

    def add_contours(self, levels=5, **kwargs):
        """Add centered-RMSE contours (blue unless ``colors`` is given).

        Args:
            levels: contour levels passed to ``Axes.contour``.
            **kwargs: forwarded to ``Axes.contour``.

        Returns:
            The contour set, e.g. for labelling with ``clabel``.
        """
        rs, ts = np.meshgrid(np.linspace(self.smin, self.smax, 100),
                             np.linspace(0, np.pi/2, 100))
        # Law of cosines: distance from the reference point (STD, theta=0)
        rmse = np.sqrt(self.STD**2 + rs**2 - 2*self.STD*rs*np.cos(ts))
        # setdefault keeps blue as default but lets callers override it
        # without a duplicate-keyword TypeError.
        kwargs.setdefault("colors", "blue")
        cont = self.ax.contour(ts, rs, rmse, levels, **kwargs)
        return cont

def short_station_name(full_name):
    """Return the first whitespace-delimited word of a station name.

    Args:
        full_name: station name; non-string values are converted with str().

    Returns:
        The first word, or "" when the name is None, NaN, empty or blank
        (instead of raising IndexError / AttributeError).
    """
    # NaN is the only float that differs from itself; it marks a missing name.
    if full_name is None or (isinstance(full_name, float) and full_name != full_name):
        return ""
    words = str(full_name).split()
    return words[0] if words else ""

def create_taylor_diagram(df, ref_col="STD", std_col="STD", corr_col="CC", stn_name_col="station_name",
                          title="Taylor Diagram MERRA2", out_name="Improved_TaylorDiagram.png"):
    """Plot and save a Taylor diagram of per-station statistics.

    Every row with valid values becomes one sample. Samples whose correlation
    is in the top 10% are drawn larger, in an individual colour, and listed in
    the legend by the first word of their station name; the rest are small
    gray circles without a legend entry.

    Args:
        df: table holding the columns named below.
        ref_col: column whose mean is used as the reference standard deviation.
        std_col: column with each sample's standard deviation (radius).
        corr_col: column with each sample's correlation (angle).
        stn_name_col: column with the station names used for legend labels.
        title: title drawn above the diagram.
        out_name: file name of the PNG written into the global ``output_dir``.

    Returns:
        None. Prints a message and returns early when no row has valid values
        in ref_col, std_col and corr_col.
    """

    # Drop rows missing needed columns
    df_td = df.dropna(subset=[ref_col, std_col, corr_col])
    if df_td.empty:
        print("Not enough data for Taylor Diagram.")
        return

    # Reference STD is average of the reference col
    ref_std_val = df_td[ref_col].mean()
    # Threshold for the top 10% correlation
    thr_cc = np.percentile(df_td[corr_col].values, 90)

    # Prepare diagram
    fig = plt.figure(figsize=(8,8))
    dia = TaylorDiagram(ref_std_val, fig=fig, rect=111, label='OBS')
    # Add RMSE contours
    ctn = dia.add_contours(levels=6)
    plt.clabel(ctn, inline=1, fontsize=10)

    # One colour per row. pyplot.get_cmap is used because matplotlib.cm.get_cmap
    # no longer exists in Matplotlib 3.9+.
    color_map = plt.get_cmap("tab20", len(df_td))

    # Iterate over the column values directly: unlike itertuples()+getattr this
    # also works when a column name is not a valid Python identifier.
    # NOTE: samples with a negative correlation (theta > pi/2) or a standard
    # deviation above 1.6 * reference lie outside the diagram's extent.
    samples = zip(df_td[std_col], df_td[corr_col], df_td[stn_name_col])
    for i, (stdev, corr, stn) in enumerate(samples):
        if corr >= thr_cc:
            # A missing/blank station name must not abort the whole plot.
            name = "" if pd.isna(stn) else str(stn)
            short_label = short_station_name(name) if name.strip() else "unnamed"
            mk_style = dict(marker='o', ms=6,
                            mec=color_map(i), mfc='none', mew=1.6,
                            label=short_label)
        else:
            mk_style = dict(marker='o', ms=4,
                            mec='gray', mfc='none', mew=1,
                            label='_')  # underscore => not in legend

        dia.add_sample(stdev, corr, **mk_style)

    # Build legend from the sample points that have label != "_"
    final_pairs = [(p, p.get_label()) for p in dia.samplePoints if p.get_label() != "_"]
    if final_pairs:
        handles_ok, labels_ok = zip(*final_pairs)
        plt.legend(handles_ok, labels_ok, numpoints=1, prop=dict(size=6),
                   loc='upper right', title="Best CC Grids")

    # Title
    dia._ax.set_title(title, fontsize=12, fontweight="bold")

    # Save and close THIS figure explicitly.
    out_fn = os.path.join(output_dir, out_name)
    fig.savefig(out_fn, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved improved Taylor Diagram => {out_fn}")

##############################################################################
# 8. CREATE IMPROVED TAYLOR DIAGRAM
##############################################################################
# Draw the Taylor diagram for the merged station table.
# NOTE(review): 'STD' is passed as BOTH ref_col and std_col. Since
# create_taylor_diagram takes the mean of ref_col as the reference standard
# deviation, the reference marker sits at the mean of the plotted samples. If
# the metrics CSV carries a separate observed-STD column, pass it as ref_col.
create_taylor_diagram(df_merged, ref_col="STD", std_col="STD", corr_col="CC", stn_name_col="station_name")

##############################################################################
# 9. OUTPUT A “METRICS EVALUATION” TABLE
##############################################################################
# Summarise every available metric across all stations (count/mean/std/min/max).
print("\nFinal Metrics Evaluation Table (All Stations):")
cols_for_eval = ["MBE","RMSE","STD","CC","d"]
# Metrics that are absent from df_merged are left out instead of raising.
avail_cols = [c for c in cols_for_eval if c in df_merged.columns]
# Result: one row per statistic, one column per metric.
eval_table = df_merged[avail_cols].agg(["count","mean","std","min","max"])
print(eval_table)

# Persist the summary next to the figures in output_dir.
eval_table_out = os.path.join(output_dir,"Overall_Metrics_Evaluation.csv")
eval_table.to_csv(eval_table_out)
print(f"Saved overall metrics evaluation => {eval_table_out}")

print("\n✅ Done! Renamed dimension for NetCDF, renamed Index_of_Agreement to 'd', produced improved Taylor Diagram, and exported everything.")
