In [None]:
# ─────────────────────────────────────────────────────────────────────────────
#  PRISM precipitation · 8-nearest locally weighted regression to stations
# ─────────────────────────────────────────────────────────────────────────────

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
import warnings
import os

###############################################################################
# 1. CONFIGURATIONS
###############################################################################
# Input files
# Station observations (NetCDF) opened with xarray further down.
station_file   = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\prcp_data.nc'
# Gridded PRISM precipitation (NetCDF); an 'elevation' variable is used when present.
prism_file     = r'D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\PRISM_prcp_GLB_1991-2013_with_Elevation.nc'
# Station metadata table (CSV); the script reads NAME, LATITUDE, LONGITUDE and,
# when available, Elevation from it.
physical_file  = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv'
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"   # basin outline, plotting only
lakes_shp      = r'D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp' # lake polygons, plotting only
# CRS that every plotting layer is reprojected to.
target_crs     = "ESRI:102008"

# Output folders for the daily station table and the metrics table
# (both sit next to the PRISM file).
daily_loop_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\daily_loop3"
metrics_dir       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\PRISM_GLB_Precipitation\metrics3"

# Create the output folders up front so the later writes cannot fail on a
# missing directory.
os.makedirs(daily_loop_dir, exist_ok=True)
os.makedirs(metrics_dir,    exist_ok=True)

# Variable names inside the two NetCDF files
obs_var_name   = 'prcp'   # station variable
prism_var_name = 'prcp'   # PRISM variable

# LWR settings
NEAREST = 8               # number of nearest grid cells used per station

# Analysis period (inclusive); both datasets are sliced to this range.
start_date, end_date = '1991-01-01', '2012-12-31'

###############################################################################
# 2. DISTANCE & WEIGHT FUNCTIONS
###############################################################################
def haversine_distance(lat1, lon1, lat2, lon2):
    """
    Great-circle distance in kilometres between points given in degrees.

    Inputs may be scalars or NumPy arrays and are broadcast against each
    other, so one station can be compared with a whole flattened grid in a
    single call. A spherical Earth with a radius of 6371 km is used.
    """
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = (np.sin(dlat / 2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2)**2)
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would turn arcsin into NaN; clamp first.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    return 6371 * c

def tricube_weight(distances, d_max):
    """
    Tricube kernel weights for a set of distances.

        w = (1 - (d / d_max)**3)**3   for d <= d_max, otherwise 0

    Parameters
    ----------
    distances : scalar or array-like
        Distances in the same unit as ``d_max``. Lists and scalars are
        accepted; the input is never modified.
    d_max : float
        Kernel bandwidth. When it is 0, every weight is 1.

    Returns
    -------
    numpy.ndarray
        Floating-point weights with the same shape as ``distances``.
    """
    distances = np.asarray(distances)
    if distances.dtype.kind != 'f':
        # Integer input would otherwise give integer "ones" in the d_max == 0 branch.
        distances = distances.astype(float)

    if d_max == 0:
        return np.ones_like(distances)

    ratio = distances / d_max
    w = (1 - ratio**3)**3
    # np.where (instead of item assignment) also works for 0-d input.
    return np.where(distances > d_max, 0.0, w)

###############################################################################
# 3. OPTIONAL: FORCE prism TIME
###############################################################################
def force_prism_time(ds):
    """
    Return ``ds`` with its 'time' values passed through ``pd.to_datetime``.

    When 'time' exists only as a dimension (not yet as a coordinate) it is
    first registered as a coordinate so that it can be overwritten.
    """
    time_is_dim_only = 'time' in ds.dims and 'time' not in ds.coords
    if time_is_dim_only:
        ds = ds.assign_coords(time=ds['time'])

    parsed_times = pd.to_datetime(ds['time'].values)
    ds['time'] = parsed_times
    return ds

###############################################################################
# 4. LOAD DATASETS
###############################################################################
# Open both NetCDF inputs; the handles stay open until the arrays have been
# copied into memory further down.
print("Loading station observations (NetCDF) ...")
obs_ds = xr.open_dataset(station_file)

print("Loading prism reanalysis (NetCDF) ...")
prism_ds = xr.open_dataset(prism_file)

# Re-parse the station time coordinate with pandas (skipped when the
# dataset has no 'time' coordinate).
if 'time' in obs_ds.coords:
    obs_ds['time'] = pd.to_datetime(obs_ds['time'].values)

# Same normalisation for PRISM, via the helper defined above.
prism_ds = force_prism_time(prism_ds)

# Restrict both datasets to start_date .. end_date (label-based slice).
obs_ds   = obs_ds.sel(time=slice(start_date, end_date))
prism_ds = prism_ds.sel(time=slice(start_date, end_date))

# The daily loop later pairs the two datasets by position along time, so
# these two counts should agree.
print("After subsetting:")
print(f"obs_ds time steps = {obs_ds.sizes['time']}")
print(f"prism_ds time steps = {prism_ds.sizes['time']}")

###############################################################################
# 5. ENSURE EACH VARIABLE IS (time, lat, lon)
###############################################################################
# Downstream code slices prism_arr[t] and flattens it against a (lat, lon)
# mesh, so the variable has to be laid out exactly as (time, lat, lon).
d = prism_ds[prism_var_name].dims
expected_dims = ('time', 'lat', 'lon')

if d == expected_dims:
    # ── already correct (time,lat,lon) --------------------------------------
    pass
elif len(d) == len(expected_dims) and set(d) == set(expected_dims):
    # ── any other permutation of the same three dimensions ------------------
    # (covers native PRISM (lat,lon,time), (time,lon,lat) and the rest)
    prism_ds[prism_var_name] = prism_ds[prism_var_name].transpose(*expected_dims)
else:
    # ── different or extra dimension names: cannot be fixed by transposing --
    warnings.warn(f"Unexpected dim order for {prism_var_name}: {d}")

################################################################################
# 5-B. LOAD VARIABLES & OBSERVATIONS INTO MEMORY (for speed) + GRID COORDS
################################################################################
# Pull everything the daily loop needs out of the NetCDF handles so that no
# file access happens inside the loop.
print("Reading PRISM and station arrays into RAM …")
prism_arr = prism_ds[prism_var_name].values          # (time, lat, lon)
obs_arr   = obs_ds [obs_var_name ].values            # (time, station)

# --- grid coordinates --------------------------------------------------------
lats = prism_ds['lat'].values
lons = prism_ds['lon'].values
# meshgrid(lons, lats) yields arrays shaped (len(lats), len(lons)), matching
# one (lat, lon) slice of prism_arr.
lon2d, lat2d = np.meshgrid(lons, lats)

# Fall back to an all-zero elevation grid when the file has no 'elevation'.
# NOTE(review): assumes 'elevation' is stored as (lat, lon) so that ravel()
# lines up with the flattened coordinates below — confirm against the file.
grid_elev = (prism_ds['elevation'].values
             if 'elevation' in prism_ds else np.zeros_like(lon2d))

# Flattened views: index k refers to the same cell in all three arrays and in
# prism_arr[t].ravel().
grid_lat_flat  = lat2d.ravel()
grid_lon_flat  = lon2d.ravel()
grid_elev_flat = grid_elev.ravel()

# keep station names before closing the files
station_names_nc = obs_ds['station'].values

# Everything needed is now held in NumPy arrays; release the file handles.
prism_ds.close()
obs_ds.close()

###############################################################################
# 6. EXTRACT PRISM GRID
###############################################################################
#lats = prism_ds['lat'].values
#lons = prism_ds['lon'].values
#lon2d, lat2d = np.meshgrid(lons, lats)

#if 'elevation' in prism_ds:
#    grid_elev = prism_ds['elevation'].values
#else:
#    warnings.warn("No 'elevation' found in prism dataset => set elev=0.")
#    grid_elev = np.zeros_like(lat2d)

#grid_lat_flat  = lat2d.flatten()
#grid_lon_flat  = lon2d.flatten()
#grid_elev_flat = grid_elev.flatten()

###############################################################################
# 7. LOAD STATIONS & MATCH
###############################################################################
print("Loading station metadata (CSV) …")
stations_df = pd.read_csv(physical_file).dropna(axis=1, how='all')

station_index_map = {name: i for i, name in enumerate(station_names_nc)}
stations_df['netcdf_index'] = stations_df['NAME'].map(station_index_map)

stations_df = (stations_df
               .dropna(subset=['netcdf_index'])
               .reset_index(drop=True)
               .astype({'netcdf_index': int}))
print(f"Total matched stations: {len(stations_df)}")


###############################################################################
# 8. PRECOMPUTE THE 8 NEAREST GRIDS FOR EACH STATION
###############################################################################
# For every station, remember the flat grid indices of its NEAREST closest
# cells (closest first); the grid is static, so this is done once.
print(f"Pre-computing the {NEAREST} nearest PRISM grids for each station …")
station_neighbors = {}

for i, row in stations_df.iterrows():
    dist_to_cells = haversine_distance(row['LATITUDE'], row['LONGITUDE'],
                                       grid_lat_flat, grid_lon_flat)
    closest_first = np.argsort(dist_to_cells)
    station_neighbors[i] = closest_first[:NEAREST]

print("Neighbour pre-computation complete.")


###############################################################################
# 9. LOCAL-WEIGHTED REGRESSION FUNCTION  (robust to NaNs & rank-deficiency)
###############################################################################
def local_weighted_regression(
    station_lat, station_lon, station_elev,
    neighbor_indices,
    grid_lat, grid_lon, grid_elev, grid_val
):
    """
    Estimate one station value for one day from its neighbouring grid cells.

    A tricube-weighted least-squares fit of

        value ~ 1 + lat + lon + elev

    is made over the neighbour cells and evaluated at the station position.

    Parameters
    ----------
    station_lat, station_lon, station_elev :
        Station position (degrees) and elevation.
    neighbor_indices :
        Flat indices into the ``grid_*`` arrays, or None.
    grid_lat, grid_lon, grid_elev, grid_val :
        Flattened grid coordinates, elevation and the day's field values.

    Returns
    -------
    float
        * NaN when there are no neighbours, when fewer than 3 neighbours have
          a non-NaN value, or when every weight is zero.
        * Otherwise the regression prediction, clipped below at 0.
        * If ``np.linalg.lstsq`` raises ``LinAlgError``, the weighted mean of
          the neighbour values is used instead of the regression.

    Notes
    -----
    The kernel bandwidth is the distance to the farthest valid neighbour, so
    that neighbour itself receives a weight of exactly zero.
    """
    # ── sanity check ---------------------------------------------------------
    if neighbor_indices is None or len(neighbor_indices) == 0:
        return np.nan                           # no neighbours → cannot interpolate

    # ── slice neighbour data -------------------------------------------------
    lat_n   = grid_lat [neighbor_indices]
    lon_n   = grid_lon [neighbor_indices]
    elev_n  = grid_elev[neighbor_indices]
    val_n   = grid_val [neighbor_indices]

    # ── discard neighbours whose value is NaN --------------------------------
    # Only the field value is checked; NaNs in grid_elev are not filtered here.
    valid   = ~np.isnan(val_n)
    # NOTE(review): the model has 4 coefficients, so 3 points leave the system
    # underdetermined; lstsq presumably returns a minimum-norm solution in
    # that case rather than failing — confirm that 3 is the intended minimum.
    if valid.sum() < 3:
        return np.nan
    lat_n, lon_n, elev_n, val_n = (
        lat_n[valid], lon_n[valid], elev_n[valid], val_n[valid]
    )

    # ── tricube weights from great-circle distance ---------------------------
    dist    = haversine_distance(station_lat, station_lon, lat_n, lon_n)
    w       = tricube_weight(dist, dist.max())
    if np.all(w == 0.0):
        return np.nan

    # ── weighted multiple-linear regression ----------------------------------
    # Design matrix columns: intercept, latitude, longitude, elevation.
    X       = np.column_stack([np.ones_like(lat_n), lat_n, lon_n, elev_n])
    # Scaling rows by sqrt(w) turns ordinary least squares into weighted
    # least squares with weights w.
    sqrt_w  = np.sqrt(w)

    try:
        beta, *_ = np.linalg.lstsq(
            X * sqrt_w[:, None],                # scale each row (neighbour) by sqrt(w)
            val_n * sqrt_w,                     # scale the response the same way
            rcond=None
        )
        # Evaluate the fitted plane at the station itself.
        pred = np.dot([1, station_lat, station_lon, station_elev], beta)
    except np.linalg.LinAlgError:               # rare SVD failure
        pred = np.average(val_n, weights=w)     # fallback: weighted mean

    return max(pred, 0.0)                       # precipitation cannot be negative

###############################################################################
# 10. DAILY LOOP — PRISM → STATIONS (fast, RAM-based)
###############################################################################
results  = []
n_days   = prism_arr.shape[0]

# PRISM day t is paired with observation row t, so the observation array
# must cover at least as many days. Fail immediately with a clear message
# instead of hitting an IndexError part-way through the run.
n_obs_days = obs_arr.shape[0]
if n_obs_days < n_days:
    raise ValueError(
        f"Station observations have {n_obs_days} time steps but PRISM has "
        f"{n_days}; the two series cannot be paired day by day."
    )
if n_obs_days != n_days:
    warnings.warn(
        f"Station observations ({n_obs_days} time steps) and PRISM "
        f"({n_days} time steps) differ in length; rows are paired by "
        f"position, so check that the two time axes really line up."
    )

print(f"Processing {n_days} days …")

# Station metadata does not change from day to day: read it from the
# DataFrame once rather than calling iterrows() again for every day.
station_records = []
for i, row in stations_df.iterrows():
    station_records.append((
        i,
        row['NAME'],
        row['LATITUDE'],
        row['LONGITUDE'],
        row.get('Elevation', 0.0),      # 0.0 when the CSV has no 'Elevation' column
        row['netcdf_index'],
        station_neighbors[i],
    ))

first_day = pd.to_datetime(start_date)

for t_idx in range(n_days):
    # NOTE: the timestamp is derived from the loop counter, i.e. it assumes
    # a gap-free daily PRISM series that starts on start_date.
    curr_time = first_day + pd.Timedelta(days=int(t_idx))

    grid_flat = prism_arr[t_idx].ravel()   # (lat*lon,)
    obs_day   = obs_arr[t_idx]             # (station,)

    for st_idx, st_name, st_lat, st_lon, st_elev, nc_idx, neighbours in station_records:
        pred = local_weighted_regression(
            st_lat, st_lon, st_elev,
            neighbours,
            grid_lat_flat, grid_lon_flat, grid_elev_flat, grid_flat
        )

        results.append({
            'time'         : curr_time,
            'station_index': st_idx,
            'station_name' : st_name,
            'obs'          : obs_day[nc_idx],
            'prism_lwr8_val': pred
        })

    if (t_idx + 1) % 500 == 0 or t_idx == n_days-1:
        print(f"  {t_idx+1}/{n_days} days done")


###############################################################################
# 11. BUILD A DATAFRAME & SAVE
###############################################################################
# One row per (day, station); make sure 'time' is a datetime column.
results_df = pd.DataFrame(results).assign(
    time=lambda frame: pd.to_datetime(frame['time'])
)
sample_rows = results_df.head()
print("Sample of daily results:\n", sample_rows)

# Write the 8-nearest LWR table next to the PRISM data.
csv_name = "prism_vs_stations_8Nearest_LWR_1991_2012.csv"
daily_loop_csv = os.path.join(daily_loop_dir, csv_name)
results_df.to_csv(daily_loop_csv, index=False)
print(f"Daily interpolation results saved to {daily_loop_csv}")

###############################################################################
# 12. (OPTIONAL) MAPPING
###############################################################################
# Best-effort diagnostic map: any failure (missing shapefiles, projection
# problems, no display) is reported and the script carries on.
try:
    # The data plotted here is PRISM, so the message says PRISM.
    print("Creating a quick map of average PRISM precipitation at each station (8-nearest LWR)...")
    glb  = gpd.read_file(shapefile_path).to_crs(target_crs)
    lakes= gpd.read_file(lakes_shp).to_crs(target_crs)

    # Long-term mean per station, indexed by station_index (= stations_df index).
    station_mean = results_df.groupby('station_index')['prism_lwr8_val'].mean()
    merged = stations_df.copy()
    # Attach the means by station_index rather than by row position, so a
    # station without results gets NaN instead of shifting the other values.
    merged['mean_prism_lwr8'] = station_mean.reindex(merged.index).to_numpy()

    geometry = [Point(lon, lat) for lon, lat in zip(merged['LONGITUDE'], merged['LATITUDE'])]
    stations_gdf = gpd.GeoDataFrame(merged, geometry=geometry, crs="EPSG:4326").to_crs(target_crs)

    fig, ax = plt.subplots(figsize=(10, 8))
    glb.boundary.plot(ax=ax, edgecolor='black')
    lakes.plot(ax=ax, color='blue', alpha=0.5)
    stations_gdf.plot(column='mean_prism_lwr8', ax=ax, legend=True, cmap='viridis', markersize=50)
    ax.set_title("Mean prism Precip (8-Nearest LWR), 1991-2012")
    plt.show()
except Exception as e:
    print("Mapping step skipped due to error or missing shapefiles:", str(e))

###############################################################################
# 13. METRICS & SAVING
###############################################################################
def remove_nan_pairs(obs, pred):
    """
    Keep only the positions where both ``obs`` and ``pred`` are not NaN.

    Accepts any array-like input (lists, object-dtype arrays holding None,
    NumPy float arrays). Non-float input is converted to float first so that
    the NaN test and the boolean indexing always work.

    Returns
    -------
    tuple of numpy.ndarray
        ``(obs_kept, pred_kept)``, both filtered with the same mask.
    """
    obs = np.asarray(obs)
    if obs.dtype.kind != 'f':
        obs = obs.astype(float)
    pred = np.asarray(pred)
    if pred.dtype.kind != 'f':
        pred = pred.astype(float)

    mask = ~np.isnan(obs) & ~np.isnan(pred)
    return obs[mask], pred[mask]

def mean_bias_error(obs, pred):
    """Mean of (pred - obs); positive values mean over-prediction."""
    residual = pred - obs
    return np.mean(residual)

def root_mean_square_error(obs, pred):
    """Square root of the mean squared difference between pred and obs."""
    residual = pred - obs
    mean_squared = np.mean(residual**2)
    return np.sqrt(mean_squared)

def std_of_residuals(obs, pred):
    """Sample standard deviation (ddof=1) of the residuals pred - obs."""
    residual = pred - obs
    return np.std(residual, ddof=1)

def pearson_correlation(obs, pred):
    """
    Pearson correlation coefficient between obs and pred.

    Returns NaN when fewer than two pairs are supplied.
    """
    if len(obs) >= 2:
        corr_matrix = np.corrcoef(obs, pred)
        return corr_matrix[0, 1]
    return np.nan

def index_of_agreement(obs, pred):
    """
    Index of agreement:

        1 - sum((pred - obs)^2) / sum((|pred - mean(obs)| + |obs - mean(obs)|)^2)

    Returns NaN when the denominator is zero.
    """
    centre = np.mean(obs)
    squared_error = np.sum((pred - obs)**2)
    potential_error = np.sum((np.abs(pred - centre) + np.abs(obs - centre))**2)
    return np.nan if potential_error == 0 else 1 - squared_error / potential_error

# Pooled skill scores over every (day, station) pair where both values exist.
overall_metric_funcs = {
    'MBE': mean_bias_error,
    'RMSE': root_mean_square_error,
    'STD': std_of_residuals,
    'CC': pearson_correlation,
    'Index_of_Agreement': index_of_agreement,
}

obs_all, prism_all = remove_nan_pairs(
    results_df['obs'].values,
    results_df['prism_lwr8_val'].values,
)

if len(obs_all) == 0:
    print("\nNo valid (obs, prism_lwr8_val) pairs found for overall metrics.")
else:
    metrics_all = {
        label: metric(obs_all, prism_all)
        for label, metric in overall_metric_funcs.items()
    }
    print("\nOverall Metrics (All Stations, prism 8-Nearest LWR, 1991-2012):")
    for k, v in metrics_all.items():
        print(f"{k}: {v:.4f}")

# Station-level metrics
# Same skill scores, computed separately for every station and written to CSV.
station_metric_funcs = {
    'MBE': mean_bias_error,
    'RMSE': root_mean_square_error,
    'STD': std_of_residuals,
    'CC': pearson_correlation,
    'Index_of_Agreement': index_of_agreement,
}

station_groups = results_df.groupby('station_index')
per_station_metrics = []

for st_idx, grp in station_groups:
    obs_st, prism_st = remove_nan_pairs(grp['obs'].values,
                                        grp['prism_lwr8_val'].values)

    row_dict = {
        'station_index': st_idx,
        'station_name': grp['station_name'].iloc[0],
    }
    if len(obs_st) == 0:
        # No usable pair for this station: keep the row, fill scores with NaN.
        row_dict.update({label: np.nan for label in station_metric_funcs})
    else:
        row_dict.update({
            label: metric(obs_st, prism_st)
            for label, metric in station_metric_funcs.items()
        })
    per_station_metrics.append(row_dict)

metrics_df = pd.DataFrame(per_station_metrics)
print("\nSample of per-station metrics:")
print(metrics_df.head())

metrics_csv = os.path.join(metrics_dir, "station_metrics_8Nearest_LWR_prism_1991_2012.csv")
metrics_df.to_csv(metrics_csv, index=False)
print(f"Station metrics saved to {metrics_csv}")

###############################################################################
# 14. (OPTIONAL) SAVE FULL DAILY RESULTS AS NETCDF
###############################################################################
# Best-effort export of the daily table as NetCDF; any failure is printed and
# the script still finishes (the CSV written earlier holds the same data).
try:
    print("\nSaving daily station results to NetCDF file in daily_loop folder...")
    netcdf_file = os.path.join(daily_loop_dir, "prism_vs_stations_8Nearest_LWR_1991_2012.nc")

    # 1) Pivot so 'time' is the row index, 'station_index' are columns
    pivot_prism = results_df.pivot(index='time', columns='station_index', values='prism_lwr8_val')
    pivot_obs  = results_df.pivot(index='time', columns='station_index', values='obs')

    # 2) Ensure the pivoted index is recognized as real datetime
    #    (errors='coerce' turns unparseable entries into NaT instead of raising)
    pivot_prism.index = pd.to_datetime(pivot_prism.index, errors='coerce')
    pivot_obs.index  = pd.to_datetime(pivot_obs.index,  errors='coerce')

    # 3) Convert each pivoted DataFrame into an xarray Dataset
    ds_prism = pivot_prism.to_xarray()
    ds_obs  = pivot_obs.to_xarray()

    # Defensive renames: only applied if the conversion produced generic
    # 'index' / 'columns' dimension names.
    # NOTE(review): DataFrame.to_xarray() presumably keeps the index name
    # ('time') and makes one data variable per station column, in which case
    # none of these four renames ever fires — confirm.
    if 'index' in ds_prism.dims:
        ds_prism = ds_prism.rename({'index': 'time'})
    if 'columns' in ds_prism.dims:
        ds_prism = ds_prism.rename({'columns': 'station_index'})

    if 'index' in ds_obs.dims:
        ds_obs = ds_obs.rename({'index': 'time'})
    if 'columns' in ds_obs.dims:
        ds_obs = ds_obs.rename({'columns': 'station_index'})

    # 4) Stack the per-station data variables into a single DataArray each.
    # NOTE(review): to_array() likely names the stacked dimension 'variable'
    # rather than 'station_index'; check the written file if downstream code
    # expects a 'station_index' dimension.
    da_prism = ds_prism.to_array(name='prism_lwr8_val').squeeze()
    da_obs  = ds_obs.to_array(name='obs').squeeze()

    # 5) Combine predictions and observations into one Dataset and write it.
    ds_out = xr.Dataset({
        'prism_lwr8_val': da_prism,
        'obs':           da_obs
    })

    ds_out.to_netcdf(netcdf_file)
    print(f"NetCDF saved to: {netcdf_file}")

except Exception as e:
    print("Error saving NetCDF:", e)

print("\nDone. The 8-nearest LWR interpolation for prism dataset is complete!")

In [None]:
#EMDNA
# The approach of making the interpolated EMDNA data at the station location with LWR 25 km for prcp "10 ensembles"

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
import warnings
import os

###############################################################################
# 0. ENSEMBLE LIST  &  ROOT PATH
###############################################################################
# Ensemble member numbers to process; each one is handled by one pass of the
# loop below and selects its own sub-folder and input file name.
ENSEMBLES = [1, 11, 21, 31, 41, 51, 61, 71, 81, 91]

# Folder that holds one sub-directory per ensemble member (named after the
# member number); per-member outputs are written into those sub-directories.
root_dir = (
    r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder"
    r"\Ensemble files\EMDNA_GLB_Precipitation"
)

###############################################################################
# 1. CONFIGURATIONS
###############################################################################
for ENS in ENSEMBLES:
    print(f"\n================  ENSEMBLE {ENS}  ===========================")

    # ---- generic, ensemble-independent files --------------------------------
    station_file  = r"D:\PhD\GLB\Merged USA and CA\Entire GLB\prcp_data.nc"
    physical_file = (
        r"D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv"
    )
    shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
    lakes_shp      = r"D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp"
    target_crs     = "ESRI:102008"

    # ---- ensemble-specific paths & files ------------------------------------
    ens_dir     = os.path.join(root_dir, str(ENS))
    emdna_file  = os.path.join(
        ens_dir,
        f"EMDNA_{ENS:03d}_merged_prcp_1991_2013_with_Elevation.nc"
    )

    daily_loop_dir = os.path.join(ens_dir, "daily_loop")
    metrics_dir    = os.path.join(ens_dir, "metrics")
    os.makedirs(daily_loop_dir, exist_ok=True)
    os.makedirs(metrics_dir,    exist_ok=True)

    # ─── variables we will handle ───────────────────────────────────────────────
    VAR_LIST = ["prcp"]                 # ◄─ just precipitation this time
    
    # Local-Weighted Regression settings
    search_radius_km = 25
    min_points       = 5
    
    start_date, end_date = "1991-01-01", "2012-12-31"
    
    ###############################################################################
    # 2. DISTANCE & WEIGHT FUNCTIONS
    ###############################################################################
    def haversine_distance(lat1, lon1, lat2, lon2):
        """
        Great-circle distance in kilometres between points given in degrees.

        Inputs may be scalars or NumPy arrays and are broadcast against each
        other. A spherical Earth with a radius of 6371 km is used.
        """
        lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        a = (np.sin(dlat / 2)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(dlon / 2)**2)
        # Floating-point rounding can push `a` marginally outside [0, 1] for
        # (near-)antipodal points, which would turn arcsin into NaN; clamp first.
        a = np.clip(a, 0.0, 1.0)
        c = 2*np.arcsin(np.sqrt(a))
        return 6371 * c
    
    def tricube_weight(distances, d_max):
        """
        Tricube kernel weights.

            w_j = [1 - (dist_j / d_max)^3]^3  for dist_j <= d_max; else 0

        ``distances`` may be a scalar, a list or an array and is never
        modified. The result is a floating-point array of the same shape.
        When ``d_max`` is 0, every weight is 1.
        """
        distances = np.asarray(distances)
        if distances.dtype.kind != 'f':
            # Integer input would otherwise give integer "ones" when d_max == 0.
            distances = distances.astype(float)

        if d_max == 0:
            return np.ones_like(distances)

        ratio = distances / d_max
        w = (1 - ratio**3)**3
        # np.where (instead of item assignment) also works for 0-d input.
        return np.where(distances > d_max, 0.0, w)
    
    ###############################################################################
    # 3. FORCE EMDNA TIME (IF NEEDED)
    ###############################################################################
    def force_emdna_time(ds):
        """
        Return ``ds`` with a datetime 'time' coordinate.

        * 'time' present only as a dimension is first registered as a coordinate.
        * Integer time values are read as day offsets from 1991-01-01.
        * Finally the values are passed through ``pd.to_datetime``.
        """
        lacks_time_coord = 'time' in ds.dims and 'time' not in ds.coords
        if lacks_time_coord:
            ds = ds.assign_coords(time=ds['time'])

        time_is_integer = np.issubdtype(ds['time'].dtype, np.integer)
        if time_is_integer:
            origin = pd.to_datetime("1991-01-01")
            real_times = []
            for day_offset in ds['time'].values:
                real_times.append(origin + pd.Timedelta(days=int(day_offset)))
            ds = ds.assign_coords(time=("time", real_times))

        ds['time'] = pd.to_datetime(ds['time'].values)
        return ds
    
    print("Loading station observations (NetCDF) ...")
    obs_ds = xr.open_dataset(station_file)
    
    print("Loading EMDNA reanalysis (NetCDF) ...")
    emdna_ds = xr.open_dataset(emdna_file)
    
    # ─── fix / force time coordinates ───────────────────────────────────────────
    if 'time' in obs_ds.coords:
        obs_ds['time'] = pd.to_datetime(obs_ds['time'].values)
    emdna_ds = force_emdna_time(emdna_ds)
    
    # ─── convert rotated-pole coordinates if present ────────────────────────────
    if {'rlat', 'rlon'}.issubset(emdna_ds.coords):
        print("Renaming rlat/rlon ➜ lat/lon ...")
        emdna_ds = emdna_ds.rename({'rlat': 'lat', 'rlon': 'lon'})
    
    # ─── subset analysis period ─────────────────────────────────────────────────
    obs_ds  = obs_ds.sel(time=slice(start_date, end_date))
    emdna_ds = emdna_ds.sel(time=slice(start_date, end_date))
    
    print("After subsetting:")
    print(f"obs_ds time steps  = {obs_ds.sizes['time']}")
    print(f"emdna_ds time steps = {emdna_ds.sizes['time']}")
    
    # ─── ensure each variable is (time, lat, lon) ───────────────────────────────
    for v in VAR_LIST:
        dims_now = emdna_ds[v].dims
        print(f"Current dims for EMDNA {v}: {dims_now}")
        if dims_now == ('time', 'lon', 'lat'):
            print(f"Transposing {v} from (time,lon,lat) ➜ (time,lat,lon)")
            emdna_ds[v] = emdna_ds[v].transpose('time', 'lat', 'lon')
        elif dims_now != ('time', 'lat', 'lon'):
            warnings.warn(f"Unexpected dimension order for {v}: {dims_now}")
    
    ###############################################################################
    # 3-B. READ THE VARIABLE INTO MEMORY (avoids slow disk hit each day)
    ###############################################################################
    print("Loading EMDNA and station-obs arrays into memory …")
    emdna_arr = {v: emdna_ds[v].values for v in VAR_LIST}   # dict: v → (time,lat,lon)
    obs_arr   = {v:  obs_ds[v].values for v in VAR_LIST}    # dict: v → (time,station)
    
    # ─── build flattened coordinate arrays for neighbour search ────────────────
    lats = emdna_ds['lat'].values
    lons = emdna_ds['lon'].values
    lon2d, lat2d = np.meshgrid(lons, lats)
    
    if 'elevation' in emdna_ds:
        grid_elev = emdna_ds['elevation'].values
    else:
        warnings.warn("No 'elevation' variable in EMDNA file – using zeros.")
        grid_elev = np.zeros_like(lat2d)
    
    grid_lat_flat  = lat2d.ravel()
    grid_lon_flat  = lon2d.ravel()
    grid_elev_flat = grid_elev.ravel()
    
    # ── keep station names *before* closing the NetCDF handles ────────────────
    station_names_nc = obs_ds['station'].values
    
    # all arrays & coords are now in RAM – close the files
    emdna_ds.close()
    obs_ds.close()
    
    ###############################################################################
    # 4. LOAD STATIONS & MATCH
    ###############################################################################
    print("Loading station metadata (CSV) ...")
    stations_df = pd.read_csv(physical_file).dropna(axis=1, how='all')
    print(stations_df.head())
    
    print("Matching station names to NetCDF index ...")
    station_index_map = {name: i for i, name in enumerate(station_names_nc)}
    stations_df['netcdf_index'] = stations_df['NAME'].map(station_index_map)
    stations_df = stations_df.dropna(subset=['netcdf_index']).reset_index(drop=True)
    stations_df['netcdf_index'] = stations_df['netcdf_index'].astype(int)
    print("Total matched stations:", len(stations_df))
    
    ###############################################################################
    # 5. PRECOMPUTE NEIGHBORS FOR EACH STATION
    ###############################################################################
    station_neighbors = {}
    print(f"Precomputing neighbors within {search_radius_km} km...")
    
    for i, row in stations_df.iterrows():
        st_lat = row['LATITUDE']
        st_lon = row['LONGITUDE']
        dist_all = haversine_distance(st_lat, st_lon, grid_lat_flat, grid_lon_flat)
        neighbor_idx = np.where(dist_all <= search_radius_km)[0]
        if len(neighbor_idx) < min_points:
            station_neighbors[i] = None
        else:
            station_neighbors[i] = neighbor_idx
    
    print("Neighbor precomputation complete.")
    
    ###############################################################################
    # 6. LOCAL WEIGHTED REGRESSION FUNCTION
    ###############################################################################
    def local_weighted_regression(
        station_lat, station_lon, station_elev,
        neighbor_indices,
        grid_lat, grid_lon, grid_elev, grid_val
    ):
        """
        Predict one station value for one day from the grid cells in its
        search radius, using a tricube-weighted least-squares fit of

            prcp ~ 1 + lat + lon + elev

        evaluated at the station position.

        Parameters
        ----------
        station_lat, station_lon, station_elev :
            Station position (degrees) and elevation.
        neighbor_indices :
            Flat indices into the ``grid_*`` arrays, or None when the station
            had too few cells within the search radius.
        grid_lat, grid_lon, grid_elev, grid_val :
            Flattened grid coordinates, elevation and the day's field values.

        Returns
        -------
        float
            NaN when fewer than ``min_points`` neighbours exist or have a
            non-NaN value, or when every weight is zero; otherwise the
            prediction clipped below at 0.0. If ``np.linalg.lstsq`` raises
            ``LinAlgError``, the weighted mean of the neighbour values is
            used instead of the regression.

        Notes
        -----
        ``min_points`` is read from the enclosing scope. The kernel bandwidth
        is the distance to the farthest valid neighbour, so that neighbour
        itself receives a weight of exactly zero.
        """
        if neighbor_indices is None or len(neighbor_indices) < min_points:
            return np.nan
    
        # Neighbour coordinates, elevation and today's values.
        lat_n  = grid_lat [neighbor_indices]
        lon_n  = grid_lon [neighbor_indices]
        elev_n = grid_elev[neighbor_indices]
        val_n  = grid_val [neighbor_indices]
    
        # Drop neighbours without a value for this day (only the field value
        # is checked; NaNs in grid_elev are not filtered here).
        mask = ~np.isnan(val_n)
        if mask.sum() < min_points:
            return np.nan
        lat_n, lon_n, elev_n, val_n = lat_n[mask], lon_n[mask], elev_n[mask], val_n[mask]
    
        # Tricube weights from the great-circle distance to the station.
        dist = haversine_distance(station_lat, station_lon, lat_n, lon_n)
        w    = tricube_weight(dist, dist.max())
        if np.all(w == 0):
            return np.nan
    
        # Design matrix columns: intercept, latitude, longitude, elevation.
        X       = np.column_stack([np.ones_like(lat_n), lat_n, lon_n, elev_n])
        # Scaling rows and response by sqrt(w) turns ordinary least squares
        # into weighted least squares with weights w.
        sqrt_w  = np.sqrt(w)
        try:
            beta, *_ = np.linalg.lstsq(X * sqrt_w[:, None], val_n * sqrt_w, rcond=None)
            # Evaluate the fitted plane at the station itself.
            pred = np.dot([1, station_lat, station_lon, station_elev], beta)
        except np.linalg.LinAlgError:
            # Solver failure: fall back to the weighted mean of the neighbours.
            pred = np.average(val_n, weights=w)
    
        return max(pred, 0.0)           # precipitation cannot be negative
    ###############################################################################
    # 7. DAILY LOOP  – EMDNA ➜ STATIONS  (RAM-based, no disk reads in loop)
    ###############################################################################
    results  = []
    n_days   = emdna_arr[VAR_LIST[0]].shape[0]

    # EMDNA day t is paired with observation row t, so the observation array
    # must cover at least as many days. Fail immediately with a clear message
    # instead of hitting an IndexError part-way through the run.
    for var in VAR_LIST:
        n_obs_days = obs_arr[var].shape[0]
        if n_obs_days < n_days:
            raise ValueError(
                f"Station observations for {var} have {n_obs_days} time steps "
                f"but EMDNA has {n_days}; the two series cannot be paired day by day."
            )
        if n_obs_days != n_days:
            warnings.warn(
                f"Station observations for {var} ({n_obs_days} time steps) and "
                f"EMDNA ({n_days} time steps) differ in length; rows are paired "
                f"by position, so check that the two time axes really line up."
            )

    print(f"Processing {n_days} days …")

    # Station metadata does not change from day to day: read it from the
    # DataFrame once rather than calling iterrows() again for every day.
    station_records = []
    for i, row in stations_df.iterrows():
        station_records.append((
            i,
            row["NAME"],
            row["LATITUDE"],
            row["LONGITUDE"],
            row.get("Elevation", 0.0),   # 0.0 when the CSV has no 'Elevation' column
            row["netcdf_index"],
            station_neighbors[i],
        ))

    first_day = pd.to_datetime(start_date)

    for t_index in range(n_days):
        # NOTE: the timestamp is derived from the loop counter, i.e. it assumes
        # a gap-free daily EMDNA series that starts on start_date.
        current_time = first_day + pd.Timedelta(days=t_index)

        for var in VAR_LIST:
            # (lat,lon) slice for this day already in memory
            day_data_2d       = emdna_arr[var][t_index, :, :]
            grid_val_flat_day = day_data_2d.ravel()

            obs_day = obs_arr[var][t_index, :]   # shape (station,)
            n_obs_stations = len(obs_day)

            for st_idx, st_name, st_lat, st_lon, st_elev, netcdf_idx, neighbours in station_records:
                # Guard against a station position beyond the observation axis.
                st_obs = obs_day[netcdf_idx] if netcdf_idx < n_obs_stations else np.nan

                st_emdna = local_weighted_regression(
                    station_lat  = st_lat,
                    station_lon  = st_lon,
                    station_elev = st_elev,
                    neighbor_indices = neighbours,
                    grid_lat  = grid_lat_flat,
                    grid_lon  = grid_lon_flat,
                    grid_elev = grid_elev_flat,
                    grid_val  = grid_val_flat_day
                )

                results.append({
                    "time"         : current_time,
                    "var"          : var,
                    "station_index": st_idx,
                    "station_name" : st_name,
                    "obs"          : st_obs,
                    "emdna_lwr25_val": st_emdna
                })

        if (t_index + 1) % 100 == 0:
            print(f"Processed {t_index+1} / {n_days} days")
    
    
    ###############################################################################
    # 8. BUILD A DATAFRAME & SAVE
    ###############################################################################
    results_df            = pd.DataFrame(results)
    results_df['time']    = pd.to_datetime(results_df['time'])
    print("Sample of daily results:\n", results_df.head())
    
    daily_loop_csv = os.path.join(
        daily_loop_dir,
        f"emdna_vs_stations_25km_LWR_1991_2012_prcp_{ENS:03d}.csv"
    )
    results_df.to_csv(daily_loop_csv, index=False)
    print(f"Daily interpolation results saved to {daily_loop_csv}")
    
    ###############################################################################
    # 9. (OPTIONAL) MAPPING  – one map per variable
    ###############################################################################
    # Best-effort diagnostic maps (one per variable): any failure is reported
    # and processing continues.
    try:
        glb   = gpd.read_file(shapefile_path).to_crs(target_crs)
        lakes = gpd.read_file(lakes_shp).to_crs(target_crs)

        for var in VAR_LIST:
            print(f"Creating map for {var} …")
            sub = results_df[results_df["var"] == var]

            # mean value at each station_index
            station_mean = sub.groupby("station_index")["emdna_lwr25_val"].mean()

            # 'station_index' in the results is the row label of stations_df,
            # not a column of it, so a merge on="station_index" raises and the
            # whole map is skipped. Look the means up by the index instead;
            # stations without results get NaN.
            merged = stations_df.copy()
            merged["mean_emdna"] = station_mean.reindex(merged.index).to_numpy()

            geometry = [
                Point(lon, lat)
                for lon, lat in zip(merged["LONGITUDE"], merged["LATITUDE"])
            ]
            stations_gdf = (
                gpd.GeoDataFrame(merged, geometry=geometry, crs="EPSG:4326")
                .to_crs(target_crs)
            )

            fig, ax = plt.subplots(figsize=(10, 8))
            glb.boundary.plot(ax=ax, edgecolor="black")
            lakes.plot(ax=ax, color="blue", alpha=0.5)
            stations_gdf.plot(column="mean_emdna", ax=ax,
                              cmap="viridis", legend=True, markersize=50)
            ax.set_title(f"Mean EMDNA {var.upper()} (LWR 25 km), 1991-2012")
            plt.show()

    except Exception as e:
        print("Mapping step skipped:", e)
    
    ###############################################################################
    # 10. METRICS & SAVING
    ###############################################################################
    def remove_nan_pairs(obs, pred):
        """Return (obs, pred) restricted to positions where neither is NaN."""
        either_missing = np.isnan(obs) | np.isnan(pred)
        keep = ~either_missing
        return obs[keep], pred[keep]
    
    # ─── overall metrics per variable ───────────────────────────────────────────
    for var, grp in results_df.groupby('var'):
        obs_all, emdna_all = remove_nan_pairs(
            grp['obs'].values, grp['emdna_lwr25_val'].values
        )
        if len(obs_all) == 0:
            print(f"\nNo valid pairs for {var}.")
            continue

        # Everything below is derived from the prediction-minus-observation
        # error and the observed mean.
        err = emdna_all - obs_all
        obs_mean = obs_all.mean()
        spread = np.abs(emdna_all - obs_mean) + np.abs(obs_all - obs_mean)

        mbe = np.mean(err)
        rmse = np.sqrt(np.mean(err ** 2))
        std = np.std(err, ddof=1)
        cc = np.corrcoef(obs_all, emdna_all)[0, 1]
        ia = 1 - np.sum(err ** 2) / np.sum(spread ** 2)

        print(f"\nOverall metrics ({var}, EMDNA 25-km LWR, 1991-2012):")
        overall = {
            'MBE': mbe,
            'RMSE': rmse,
            'STD': std,
            'CC': cc,
            'Index_of_Agreement': ia,
        }
        for k, v in overall.items():
            print(f"{k}: {v:.4f}")
    
    # ─── per-station metrics (still per variable) ───────────────────────────────
    station_metrics = []
    for (var, st_idx), g in results_df.groupby(['var', 'station_index']):
        o, p = remove_nan_pairs(g['obs'].values, g['emdna_lwr25_val'].values)

        # Identification columns are the same whether or not data are available.
        record = {
            'var': var,
            'station_index': st_idx,
            'station_name': g['station_name'].iloc[0],
        }
        if len(o) == 0:
            # No valid pairs: keep the station in the table with NaN scores.
            record.update(MBE=np.nan, RMSE=np.nan, STD=np.nan,
                          CC=np.nan, Index_of_Agreement=np.nan)
        else:
            resid = p - o
            o_bar = o.mean()
            spread = np.abs(p - o_bar) + np.abs(o - o_bar)
            record.update(
                MBE=np.mean(resid),
                RMSE=np.sqrt(np.mean(resid ** 2)),
                STD=np.std(resid, ddof=1),
                CC=np.corrcoef(o, p)[0, 1],
                Index_of_Agreement=1 - np.sum(resid ** 2) / np.sum(spread ** 2),
            )
        station_metrics.append(record)

    metrics_df = pd.DataFrame(station_metrics)

    # The metrics file name carries ENS as a zero-padded 3-digit number.
    metrics_name = f"station_metrics_25km_LWR_EMDNA_1991_2012_prcp_{ENS:03d}.csv"
    metrics_csv = os.path.join(metrics_dir, metrics_name)
    metrics_df.to_csv(metrics_csv, index=False)
    print(f"Station metrics saved to {metrics_csv}")
    
    # (Optional) NetCDF
    ###############################################################################
    # 11. (OPTIONAL) SAVE FULL DAILY RESULTS AS NETCDF
    ###############################################################################
    try:
        print("\nSaving daily station results to NetCDF …")
        nc_name = f"emdna_vs_stations_25km_LWR_1991_2012_prcp_{ENS:03d}.nc"
        netcdf_file = os.path.join(daily_loop_dir, nc_name)

        ds_vars = {}
        for var in VAR_LIST:
            var_rows = results_df[results_df["var"] == var]

            # time × station tables for the interpolated and observed values.
            pivot_emdna = var_rows.pivot(
                index="time", columns="station_index", values="emdna_lwr25_val"
            ).sort_index()
            pivot_obs = var_rows.pivot(
                index="time", columns="station_index", values="obs"
            ).sort_index()

            # Keep only the stations present in both tables.
            common_cols = pivot_emdna.columns.intersection(pivot_obs.columns)
            pivot_emdna = pivot_emdna[common_cols]
            pivot_obs = pivot_obs[common_cols]

            # Wrap each table as a (time, station_index) DataArray.
            named_tables = (
                (f"emdna_lwr25_{var}", pivot_emdna),
                (f"obs_{var}", pivot_obs),
            )
            for label, table in named_tables:
                ds_vars[label] = xr.DataArray(
                    data=table.values,
                    dims=("time", "station_index"),
                    coords={"time": table.index, "station_index": common_cols},
                    name=label,
                )

        xr.Dataset(ds_vars).to_netcdf(netcdf_file)
        print(f"NetCDF saved to: {netcdf_file}")

    except Exception as e:
        # Optional export: report the failure without aborting the run.
        print("Error saving NetCDF:", e)




In [None]:
#ERA5 Precipitation
# This script performs an 8-nearest-grid LWR interpolation using ERA5 data.

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
import warnings
import os

###############################################################################
# 1. CONFIGURATIONS
###############################################################################
# Input files
# station_file / era5_file are opened with xarray further down; physical_file
# is the station metadata CSV; the two shapefiles are only used for the map.
station_file   = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\prcp_data.nc'
era5_file      = r'D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified\ERA5_GLB_prcp_daily_1991_2013_with_Elevation_mm.nc'
physical_file  = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv'
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r'D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp'
# CRS that every map layer is reprojected to before plotting.
target_crs     = "ESRI:102008"

# Output directories
# Adjust to match your desired folders
# This example uses a path near the ERA5 file location

# daily_loop_dir receives the daily CSV/NetCDF results, metrics_dir the
# per-station metrics CSV.
daily_loop_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified\daily_loop"
metrics_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Total_Precipitaion\Modified\metrics"

# Create the output folders up front; no error if they already exist.
os.makedirs(daily_loop_dir, exist_ok=True)
os.makedirs(metrics_dir,    exist_ok=True)

# Variable names
obs_var_name  = 'prcp'  # Station daily precipitation var
era5_var_name = 'prcp'  # ERA5 daily precipitation var (change if needed)

# Time range: 1991–2012
# Both datasets are sliced to [start_date, end_date] after loading.
start_date = '1991-01-01'
end_date   = '2012-12-31'

###############################################################################
# 2. DISTANCE & WEIGHT FUNCTIONS
###############################################################################
def haversine_distance(lat1, lon1, lat2, lon2):
    """
    Great-circle distance between two points given in degrees.

    Inputs may be scalars or NumPy arrays (they are broadcast together).
    The result is 6371 times the central angle in radians, i.e. kilometres
    on a sphere of radius 6371 km.
    """
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = (np.sin(dlat / 2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2)**2)
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    return 6371 * c

def tricube_weight(distances, d_max):
    """
    Tricube kernel weights: (1 - (d / d_max)**3)**3, and 0 where d > d_max.

    `distances` may be a scalar, a list or an array; it is converted to a
    float array and never modified. When d_max is 0 every weight is 1.
    A point exactly at d_max gets weight 0.
    """
    distances = np.asarray(distances, dtype=float)
    if d_max == 0:
        return np.ones_like(distances)
    ratio = distances / d_max
    # np.where instead of item assignment so 0-d (scalar) input also works.
    return np.where(distances > d_max, 0.0, (1 - ratio**3)**3)

###############################################################################
# 3. OPTIONAL: FORCE ERA5 TIME
###############################################################################
def force_era5_time(ds):
    """
    Make sure ds['time'] holds pandas datetimes and return the dataset.

    If 'time' is a dimension without a coordinate, a coordinate is attached
    first. The values are then passed through pd.to_datetime and written back
    to ds['time'] (on the object passed in, unless it was replaced by the
    coordinate assignment above).
    """
    time_is_bare_dim = 'time' in ds.dims and 'time' not in ds.coords
    if time_is_bare_dim:
        ds = ds.assign_coords(time=ds['time'])
    stamps = pd.to_datetime(ds['time'].values)
    ds['time'] = stamps
    return ds

###############################################################################
# 4. LOAD DATASETS
###############################################################################
# Open both inputs.
print("Loading station observations (NetCDF) ...")
obs_ds = xr.open_dataset(station_file)

print("Loading ERA5 reanalysis (NetCDF) ...")
era5_ds = xr.open_dataset(era5_file)

# Put both time axes into pandas datetime form.
if 'time' in obs_ds.coords:
    obs_ds['time'] = pd.to_datetime(obs_ds['time'].values)
era5_ds = force_era5_time(era5_ds)

# Restrict both datasets to the configured study period.
study_period = slice(start_date, end_date)
obs_ds = obs_ds.sel(time=study_period)
era5_ds = era5_ds.sel(time=study_period)

print("After subsetting:")
print(f"obs_ds time steps = {obs_ds.sizes['time']}")
print(f"era5_ds time steps = {era5_ds.sizes['time']}")

###############################################################################
# 5. IF dims are (time, lon, lat), transpose to (time, lat, lon)
###############################################################################
actual_dims = era5_ds[era5_var_name].dims
print("Current dims for ERA5 prcp:", actual_dims)
if actual_dims != ("time", "lon", "lat"):
    print("No transpose needed or already (time,lat,lon).")
else:
    # Later code flattens each day as (lat, lon), so reorder the axes here.
    print("Transposing prcp from (time,lon,lat) -> (time,lat,lon).")
    era5_ds[era5_var_name] = era5_ds[era5_var_name].transpose("time", "lat", "lon")
    print("New dims:", era5_ds[era5_var_name].dims)

###############################################################################
# 6. EXTRACT ERA5 GRID
###############################################################################
lats = era5_ds['lat'].values
lons = era5_ds['lon'].values
# meshgrid(lons, lats) yields arrays of shape (n_lat, n_lon).
lon2d, lat2d = np.meshgrid(lons, lats)

if 'elevation' in era5_ds:
    elev_da = era5_ds['elevation']
    # Only the precipitation variable is transposed above, so a 2-D elevation
    # field stored as (lon, lat) would be flattened in a different order than
    # lat2d/lon2d. Force (lat, lon) whenever those are its two dimensions.
    if set(elev_da.dims) == {'lat', 'lon'}:
        elev_da = elev_da.transpose('lat', 'lon')
    grid_elev = elev_da.values
    if grid_elev.size != lat2d.size:
        warnings.warn(
            f"'elevation' has shape {grid_elev.shape}, which does not match "
            f"the (lat, lon) grid {lat2d.shape}; neighbour elevations may be wrong."
        )
else:
    warnings.warn("No 'elevation' found in ERA5 dataset => set elev=0.")
    grid_elev = np.zeros_like(lat2d)

# Flat views share one index space: flat index k is the same cell in all three.
grid_lat_flat  = lat2d.flatten()
grid_lon_flat  = lon2d.flatten()
grid_elev_flat = grid_elev.flatten()

###############################################################################
# 7. LOAD STATIONS & MATCH
###############################################################################
print("Loading station metadata (CSV) ...")
stations_df = pd.read_csv(physical_file).dropna(axis=1, how='all')
print(stations_df.head())

# Map each station name in the NetCDF file to its position on the 'station' axis.
station_names_nc = obs_ds['station'].values
station_index_map = dict(zip(station_names_nc, range(len(station_names_nc))))

# Keep only the CSV stations whose NAME also appears in the NetCDF file.
matched_position = stations_df['NAME'].map(station_index_map)
stations_df = (
    stations_df.assign(netcdf_index=matched_position)
    .dropna(subset=['netcdf_index'])
    .reset_index(drop=True)
)
stations_df['netcdf_index'] = stations_df['netcdf_index'].astype(int)
print("Total matched stations:", len(stations_df))

###############################################################################
# 8. PRECOMPUTE THE 8 NEAREST GRIDS FOR EACH STATION
###############################################################################
station_neighbors = {}
print("Precomputing the 8 nearest ERA5 grids for each station ...")

# The neighbour set depends only on geometry, so it is computed once per station.
for i, row in stations_df.iterrows():
    dist_all = haversine_distance(
        row['LATITUDE'], row['LONGITUDE'], grid_lat_flat, grid_lon_flat
    )
    # Flat grid indices of the 8 closest cells, nearest first.
    station_neighbors[i] = np.argsort(dist_all)[:8]

print("Neighbor precomputation complete.")

###############################################################################
# 9. LOCAL WEIGHTED REGRESSION FUNCTION
###############################################################################
def local_weighted_regression(
    station_lat, station_lon, station_elev,
    neighbor_indices,
    grid_lat, grid_lon, grid_elev, grid_val
):
    """
    Estimate the grid variable at a station by weighted linear regression.

    The neighbouring cells selected by `neighbor_indices` are fitted with
    grid_val ~ b0 + b1*lat + b2*lon + b3*elev, using tricube weights based on
    each cell's haversine distance to the station (the farthest neighbour
    defines the kernel radius). The fit is evaluated at the station.

    Returns NaN when there are no neighbours, when all weights are zero or
    when the least-squares solve raises LinAlgError. Negative predictions are
    clipped to 0.0 because the variable is precipitation.
    """
    if neighbor_indices is None or len(neighbor_indices) < 1:
        return np.nan

    # Pull the neighbour cells out of the flattened grid arrays.
    nb_lat = grid_lat[neighbor_indices]
    nb_lon = grid_lon[neighbor_indices]
    nb_elev = grid_elev[neighbor_indices]
    nb_val = grid_val[neighbor_indices]

    nb_dist = haversine_distance(station_lat, station_lon, nb_lat, nb_lon)
    weights = tricube_weight(nb_dist, nb_dist.max())
    if not np.any(weights != 0):
        return np.nan

    # Weighted least squares: scale rows of the design matrix and the targets
    # by sqrt(weight), then solve an ordinary least-squares problem.
    design = np.column_stack([np.ones_like(nb_lat), nb_lat, nb_lon, nb_elev])
    root_w = np.sqrt(weights)
    weighted_design = design * root_w[:, None]
    weighted_target = nb_val * root_w

    try:
        beta = np.linalg.lstsq(weighted_design, weighted_target, rcond=None)[0]
    except np.linalg.LinAlgError:
        return np.nan

    station_row = np.array([1, station_lat, station_lon, station_elev])
    pred = station_row @ beta

    # Precipitation cannot be negative.
    return 0.0 if pred < 0 else pred

###############################################################################
# 10. DAILY LOOP: ERA5 -> STATIONS
###############################################################################
results = []
era5_times = era5_ds['time'].values
print(f"Processing ERA5 days: {len(era5_times)}")

# Pair observations with ERA5 fields by calendar date, not by position.
# Positional pairing silently misaligns the two series as soon as one of them
# has a missing or extra day. Dates without an observation become NaN.
era5_days = pd.DatetimeIndex(era5_times).normalize()
obs_days = pd.DatetimeIndex(obs_ds['time'].values).normalize()
obs_aligned = (
    obs_ds[obs_var_name]
    .assign_coords(time=obs_days)
    .reindex(time=era5_days)
)

# Station attributes are the same every day, so read them out of the
# DataFrame once instead of running iterrows() inside the day loop.
has_elevation = 'Elevation' in stations_df.columns
station_records = [
    (
        i,
        row['NAME'],
        row['LATITUDE'],
        row['LONGITUDE'],
        row['Elevation'] if has_elevation else 0.0,
        row['netcdf_index'],
    )
    for i, row in stations_df.iterrows()
]

for t_index, t_val in enumerate(era5_times):
    current_time = pd.to_datetime(t_val)
    # One day of ERA5, flattened in the same (lat, lon) order as the grid arrays.
    day_data_2d = era5_ds[era5_var_name].isel(time=t_index).values
    grid_val_flat_day = day_data_2d.flatten()

    # Observations for the same calendar date, one value per NetCDF station.
    obs_day = obs_aligned.isel(time=t_index).values
    n_obs = len(obs_day)

    for i, st_name, st_lat, st_lon, st_elev, netcdf_idx in station_records:
        st_obs = obs_day[netcdf_idx] if netcdf_idx < n_obs else np.nan

        st_era5 = local_weighted_regression(
            station_lat=st_lat,
            station_lon=st_lon,
            station_elev=st_elev,
            neighbor_indices=station_neighbors[i],
            grid_lat=grid_lat_flat,
            grid_lon=grid_lon_flat,
            grid_elev=grid_elev_flat,
            grid_val=grid_val_flat_day
        )

        results.append({
            'time': current_time,
            'station_index': i,
            'station_name': st_name,
            'obs': st_obs,
            'era5_lwr8_val': st_era5
        })

    if (t_index + 1) % 100 == 0:
        print(f"Processed day {t_index+1} / {len(era5_times)}")

###############################################################################
# 11. BUILD A DATAFRAME & SAVE
###############################################################################
# One row per (day, station); make sure 'time' is a datetime column.
results_df = pd.DataFrame(results).assign(
    time=lambda frame: pd.to_datetime(frame['time'])
)
print("Sample of daily results:\n", results_df.head())

# CSV output with the 8-nearest LWR data
csv_name = "era5_vs_stations_8Nearest_LWR_1991_2012.csv"
daily_loop_csv = os.path.join(daily_loop_dir, csv_name)
results_df.to_csv(daily_loop_csv, index=False)
print(f"Daily interpolation results saved to {daily_loop_csv}")

###############################################################################
# 12. (OPTIONAL) MAPPING
###############################################################################
try:
    print("Creating a quick map of average ERA5 precipitation at each station (8-nearest LWR)...")
    glb  = gpd.read_file(shapefile_path).to_crs(target_crs)
    lakes= gpd.read_file(lakes_shp).to_crs(target_crs)

    # Time-mean of the interpolated value for every station.
    station_mean = results_df.groupby('station_index')['era5_lwr8_val'].mean().reset_index()
    merged = stations_df.copy()
    # 'station_index' in results_df is the stations_df row label, so join on
    # that key. Plain column assignment would align on station_mean's 0..n-1
    # RangeIndex and only be correct while every station appears, in order.
    mean_by_station = station_mean.set_index('station_index')['era5_lwr8_val']
    merged['mean_era5_lwr8'] = mean_by_station.reindex(merged.index).to_numpy()

    geometry = [Point(lon, lat) for lon, lat in zip(merged['LONGITUDE'], merged['LATITUDE'])]
    stations_gdf = gpd.GeoDataFrame(merged, geometry=geometry, crs="EPSG:4326").to_crs(target_crs)

    fig, ax = plt.subplots(figsize=(10, 8))
    glb.boundary.plot(ax=ax, edgecolor='black')
    lakes.plot(ax=ax, color='blue', alpha=0.5)
    stations_gdf.plot(column='mean_era5_lwr8', ax=ax, legend=True, cmap='viridis', markersize=50)
    ax.set_title("Mean ERA5 Precip (8-Nearest LWR), 1991-2012")
    plt.show()
except Exception as e:
    # Mapping is optional: report the problem and continue with the metrics.
    print("Mapping step skipped due to error or missing shapefiles:", str(e))

###############################################################################
# 13. METRICS & SAVING
###############################################################################
def remove_nan_pairs(obs, pred):
    """Return (obs, pred) restricted to positions where neither is NaN."""
    either_missing = np.isnan(obs) | np.isnan(pred)
    keep = ~either_missing
    return obs[keep], pred[keep]

def mean_bias_error(obs, pred):
    """Mean of (pred - obs); positive when predictions exceed observations."""
    bias = pred - obs
    return np.mean(bias)

def root_mean_square_error(obs, pred):
    """Square root of the mean squared difference between pred and obs."""
    squared_error = (pred - obs)**2
    return np.sqrt(np.mean(squared_error))

def std_of_residuals(obs, pred):
    """Sample standard deviation (ddof=1) of the residuals pred - obs."""
    residuals = pred - obs
    return np.std(residuals, ddof=1)

def pearson_correlation(obs, pred):
    """Pearson correlation of obs and pred; NaN with fewer than two samples."""
    if len(obs) >= 2:
        corr_matrix = np.corrcoef(obs, pred)
        return corr_matrix[0, 1]
    return np.nan

def index_of_agreement(obs, pred):
    """
    Index of agreement: 1 - sum((pred - obs)^2) / sum((|pred - m| + |obs - m|)^2),
    where m is the mean of obs. Returns NaN when the denominator is zero.
    """
    centre = np.mean(obs)
    potential = np.abs(pred - centre) + np.abs(obs - centre)
    denominator = np.sum(potential**2)
    if denominator == 0:
        return np.nan
    squared_error = np.sum((pred - obs)**2)
    return 1 - squared_error / denominator

# Pool every station and day, dropping pairs where either value is missing.
obs_all, era5_all = remove_nan_pairs(
    results_df['obs'].values, results_df['era5_lwr8_val'].values
)

if len(obs_all) == 0:
    print("\nNo valid (obs, era5_lwr8_val) pairs found for overall metrics.")
else:
    metric_functions = {
        'MBE': mean_bias_error,
        'RMSE': root_mean_square_error,
        'STD': std_of_residuals,
        'CC': pearson_correlation,
        'Index_of_Agreement': index_of_agreement,
    }
    metrics_all = {
        label: score(obs_all, era5_all)
        for label, score in metric_functions.items()
    }
    print("\nOverall Metrics (All Stations, ERA5 8-Nearest LWR, 1991-2012):")
    for k, v in metrics_all.items():
        print(f"{k}: {v:.4f}")

# Station-level metrics
station_groups = results_df.groupby('station_index')
per_station_metrics = []

for st_idx, grp in station_groups:
    obs_st, era5_st = remove_nan_pairs(
        grp['obs'].values, grp['era5_lwr8_val'].values
    )

    # Identification columns come first; scores are appended in a fixed order.
    row_dict = {
        'station_index': st_idx,
        'station_name': grp['station_name'].iloc[0],
    }
    if len(obs_st) == 0:
        # No valid pairs: keep the station in the table with NaN scores.
        row_dict.update(MBE=np.nan, RMSE=np.nan, STD=np.nan,
                        CC=np.nan, Index_of_Agreement=np.nan)
    else:
        row_dict.update(
            MBE=mean_bias_error(obs_st, era5_st),
            RMSE=root_mean_square_error(obs_st, era5_st),
            STD=std_of_residuals(obs_st, era5_st),
            CC=pearson_correlation(obs_st, era5_st),
            Index_of_Agreement=index_of_agreement(obs_st, era5_st),
        )
    per_station_metrics.append(row_dict)

metrics_df = pd.DataFrame(per_station_metrics)
print("\nSample of per-station metrics:")
print(metrics_df.head())

metrics_name = "station_metrics_8Nearest_LWR_ERA5_1991_2012.csv"
metrics_csv = os.path.join(metrics_dir, metrics_name)
metrics_df.to_csv(metrics_csv, index=False)
print(f"Station metrics saved to {metrics_csv}")

###############################################################################
# 14. (OPTIONAL) SAVE FULL DAILY RESULTS AS NETCDF
###############################################################################
try:
    print("\nSaving daily station results to NetCDF file in daily_loop folder...")
    netcdf_file = os.path.join(daily_loop_dir, "era5_vs_stations_8Nearest_LWR_1991_2012.nc")

    # 1) Pivot so 'time' is the row index and 'station_index' the columns.
    pivot_era5 = results_df.pivot(
        index='time', columns='station_index', values='era5_lwr8_val'
    ).sort_index()
    pivot_obs = results_df.pivot(
        index='time', columns='station_index', values='obs'
    ).reindex(index=pivot_era5.index, columns=pivot_era5.columns)

    # 2) Build the dataset with explicit (time, station_index) dimensions.
    #    The earlier to_xarray().to_array().squeeze() route turned every
    #    station column into a separate variable, so the station axis was
    #    written as a dimension named 'variable', and squeeze() removed any
    #    axis of length one.
    dims = ('time', 'station_index')
    ds_out = xr.Dataset(
        {
            'era5_lwr8_val': (dims, pivot_era5.to_numpy(dtype=float)),
            'obs':           (dims, pivot_obs.to_numpy(dtype=float)),
        },
        coords={
            'time': pd.to_datetime(pivot_era5.index),
            'station_index': pivot_era5.columns.to_numpy(),
        },
    )

    ds_out.to_netcdf(netcdf_file)
    print(f"NetCDF saved to: {netcdf_file}")

except Exception as e:
    # Optional export: report the failure without aborting the script.
    print("Error saving NetCDF:", e)

print("\nDone. The 8-nearest LWR interpolation for ERA5 dataset is complete!")

In [None]:
#ERA5 Temperature
# This script performs an 8-nearest-grid LWR interpolation using ERA5 data for tmin-max

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
import warnings
import os

###############################################################################
# 1. CONFIGURATIONS
###############################################################################
# Input files
# station_file / era5_file are opened with xarray further down; physical_file
# is the station metadata CSV; the two shapefiles are only used for the map.
station_file   = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\tmin_tmax_data.nc'
era5_file      = r'D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Temperature\Temperature ERA5\tmin-tmax\ERA5_GLB_daily_tmin_tmax_1991-2013_with_Elevation.nc'
physical_file  = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation_Temperature.csv'
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r'D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp'
# CRS that every map layer is reprojected to before plotting.
target_crs     = "ESRI:102008"

# Output directories
# Adjust to match your desired folders
# This example uses a path near the ERA5 file location

# daily_loop_dir receives the daily results CSV (and presumably the NetCDF
# and metrics outputs written past the visible part of this cell — confirm).
daily_loop_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Temperature\Temperature ERA5\tmin-tmax\daily_loop"
metrics_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\ERA5_GLB_Temperature\Temperature ERA5\tmin-tmax\metrics"

# Create the output folders up front; no error if they already exist.
os.makedirs(daily_loop_dir, exist_ok=True)
os.makedirs(metrics_dir,    exist_ok=True)

# Variable names
# Each entry is used as the variable name in BOTH the ERA5 and the station
# datasets, and the daily loop processes them one after another.
VAR_LIST = ['tmin', 'tmax']       # <-- new

# Time range: 1991–2012
# end_date was '1991-12-31', so only one year was processed while the output
# file names and plot titles all say 1991-2012. Shorten it again only for a
# quick test run.
start_date = '1991-01-01'
end_date   = '2012-12-31'

###############################################################################
# 2. DISTANCE & WEIGHT FUNCTIONS
###############################################################################
def haversine_distance(lat1, lon1, lat2, lon2):
    """
    Great-circle distance between two points given in degrees.

    Inputs may be scalars or NumPy arrays (they are broadcast together).
    The result is 6371 times the central angle in radians, i.e. kilometres
    on a sphere of radius 6371 km.
    """
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = (np.sin(dlat / 2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2)**2)
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    return 6371 * c

def tricube_weight(distances, d_max):
    """
    Tricube kernel weights: (1 - (d / d_max)**3)**3, and 0 where d > d_max.

    `distances` may be a scalar, a list or an array; it is converted to a
    float array and never modified. When d_max is 0 every weight is 1.
    A point exactly at d_max gets weight 0.
    """
    distances = np.asarray(distances, dtype=float)
    if d_max == 0:
        return np.ones_like(distances)
    ratio = distances / d_max
    # np.where instead of item assignment so 0-d (scalar) input also works.
    return np.where(distances > d_max, 0.0, (1 - ratio**3)**3)

###############################################################################
# 3. OPTIONAL: FORCE ERA5 TIME
###############################################################################
def force_era5_time(ds):
    """
    Make sure ds['time'] holds pandas datetimes and return the dataset.

    If 'time' is a dimension without a coordinate, a coordinate is attached
    first. The values are then passed through pd.to_datetime and written back
    to ds['time'] (on the object passed in, unless it was replaced by the
    coordinate assignment above).
    """
    time_is_bare_dim = 'time' in ds.dims and 'time' not in ds.coords
    if time_is_bare_dim:
        ds = ds.assign_coords(time=ds['time'])
    stamps = pd.to_datetime(ds['time'].values)
    ds['time'] = stamps
    return ds

###############################################################################
# 4. LOAD DATASETS
###############################################################################
# Open both inputs.
print("Loading station observations (NetCDF) ...")
obs_ds = xr.open_dataset(station_file)

print("Loading ERA5 reanalysis (NetCDF) ...")
era5_ds = xr.open_dataset(era5_file)

# Put both time axes into pandas datetime form.
if 'time' in obs_ds.coords:
    obs_ds['time'] = pd.to_datetime(obs_ds['time'].values)
era5_ds = force_era5_time(era5_ds)

# Restrict both datasets to the configured study period.
study_period = slice(start_date, end_date)
obs_ds = obs_ds.sel(time=study_period)
era5_ds = era5_ds.sel(time=study_period)

print("After subsetting:")
print(f"obs_ds time steps = {obs_ds.sizes['time']}")
print(f"era5_ds time steps = {era5_ds.sizes['time']}")

###############################################################################
# 5. IF dims are (time, lon, lat), transpose to (time, lat, lon)
###############################################################################
for v in VAR_LIST:
    actual_dims = era5_ds[v].dims
    print(f"Current dims for ERA5 {v}:", actual_dims)
    if actual_dims != ("time", "lon", "lat"):
        print(f"No transpose needed for {v}.")
        continue
    # Later code flattens each day as (lat, lon), so reorder the axes here.
    print(f"Transposing {v} from (time,lon,lat) -> (time,lat,lon).")
    era5_ds[v] = era5_ds[v].transpose("time", "lat", "lon")
    print("New dims:", era5_ds[v].dims)

###############################################################################
# 6. EXTRACT ERA5 GRID
###############################################################################
lats = era5_ds['lat'].values
lons = era5_ds['lon'].values
# meshgrid(lons, lats) yields arrays of shape (n_lat, n_lon).
lon2d, lat2d = np.meshgrid(lons, lats)

if 'elevation' in era5_ds:
    elev_da = era5_ds['elevation']
    # Only the temperature variables are transposed above, so a 2-D elevation
    # field stored as (lon, lat) would be flattened in a different order than
    # lat2d/lon2d. Force (lat, lon) whenever those are its two dimensions.
    if set(elev_da.dims) == {'lat', 'lon'}:
        elev_da = elev_da.transpose('lat', 'lon')
    grid_elev = elev_da.values
    if grid_elev.size != lat2d.size:
        warnings.warn(
            f"'elevation' has shape {grid_elev.shape}, which does not match "
            f"the (lat, lon) grid {lat2d.shape}; neighbour elevations may be wrong."
        )
else:
    warnings.warn("No 'elevation' found in ERA5 dataset => set elev=0.")
    grid_elev = np.zeros_like(lat2d)

# Flat views share one index space: flat index k is the same cell in all three.
grid_lat_flat  = lat2d.flatten()
grid_lon_flat  = lon2d.flatten()
grid_elev_flat = grid_elev.flatten()

###############################################################################
# 7. LOAD STATIONS & MATCH
###############################################################################
print("Loading station metadata (CSV) ...")
stations_df = pd.read_csv(physical_file).dropna(axis=1, how='all')
print(stations_df.head())

# Map each station name in the NetCDF file to its position on the 'station' axis.
station_names_nc = obs_ds['station'].values
station_index_map = dict(zip(station_names_nc, range(len(station_names_nc))))

# Keep only the CSV stations whose NAME also appears in the NetCDF file.
matched_position = stations_df['NAME'].map(station_index_map)
stations_df = (
    stations_df.assign(netcdf_index=matched_position)
    .dropna(subset=['netcdf_index'])
    .reset_index(drop=True)
)
stations_df['netcdf_index'] = stations_df['netcdf_index'].astype(int)
print("Total matched stations:", len(stations_df))

###############################################################################
# 8. PRECOMPUTE THE 8 NEAREST GRIDS FOR EACH STATION
###############################################################################
station_neighbors = {}
print("Precomputing the 8 nearest ERA5 grids for each station ...")

# The neighbour set depends only on geometry, so it is computed once per station.
for i, row in stations_df.iterrows():
    dist_all = haversine_distance(
        row['LATITUDE'], row['LONGITUDE'], grid_lat_flat, grid_lon_flat
    )
    # Flat grid indices of the 8 closest cells, nearest first.
    station_neighbors[i] = np.argsort(dist_all)[:8]

print("Neighbor precomputation complete.")

###############################################################################
# 9. LOCAL WEIGHTED REGRESSION FUNCTION  (updated for temperature)
###############################################################################
def local_weighted_regression(
    station_lat, station_lon, station_elev,
    neighbor_indices,
    grid_lat, grid_lon, grid_elev, grid_val
):
    """
    Estimate the grid variable at a station by weighted linear regression.

    The neighbouring cells selected by `neighbor_indices` are fitted with
    grid_val ~ b0 + b1*lat + b2*lon + b3*elev, using tricube weights based on
    each cell's haversine distance to the station (the farthest neighbour
    defines the kernel radius). The fit is evaluated at the station.

    The prediction is returned unclipped, since negative temperatures are
    valid. Returns NaN when there are no neighbours, when all weights are
    zero or when the least-squares solve raises LinAlgError.
    """
    if neighbor_indices is None or len(neighbor_indices) == 0:
        return np.nan

    # Pull the neighbour cells out of the flattened grid arrays.
    nb_lat = grid_lat[neighbor_indices]
    nb_lon = grid_lon[neighbor_indices]
    nb_elev = grid_elev[neighbor_indices]
    nb_val = grid_val[neighbor_indices]

    nb_dist = haversine_distance(station_lat, station_lon, nb_lat, nb_lon)
    weights = tricube_weight(nb_dist, nb_dist.max())
    if not np.any(weights != 0):
        return np.nan

    # Weighted least squares: scale rows of the design matrix and the targets
    # by sqrt(weight), then solve an ordinary least-squares problem.
    design = np.column_stack([np.ones_like(nb_lat), nb_lat, nb_lon, nb_elev])
    root_w = np.sqrt(weights)
    weighted_design = design * root_w[:, None]
    weighted_target = nb_val * root_w
    try:
        beta = np.linalg.lstsq(weighted_design, weighted_target, rcond=None)[0]
    except np.linalg.LinAlgError:
        return np.nan

    station_row = np.array([1, station_lat, station_lon, station_elev])
    return station_row @ beta


###############################################################################
# 10. DAILY LOOP: ERA5 -> STATIONS
###############################################################################
results = []
era5_times = era5_ds['time'].values
print(f"Processing ERA5 days: {len(era5_times)}")

# Pair observations with ERA5 fields by calendar date, not by position.
# Positional pairing silently misaligns the two series as soon as one of them
# has a missing or extra day. Dates without an observation become NaN.
era5_days = pd.DatetimeIndex(era5_times).normalize()
obs_days = pd.DatetimeIndex(obs_ds['time'].values).normalize()
obs_aligned = {
    name: obs_ds[name].assign_coords(time=obs_days).reindex(time=era5_days)
    for name in VAR_LIST
}

# Station attributes are the same every day, so read them out of the
# DataFrame once instead of running iterrows() inside the day/variable loops.
has_elevation = 'Elevation' in stations_df.columns
station_records = [
    (
        i,
        row['NAME'],
        row['LATITUDE'],
        row['LONGITUDE'],
        row['Elevation'] if has_elevation else 0.0,
        row['netcdf_index'],
    )
    for i, row in stations_df.iterrows()
]

for t_index, t_val in enumerate(era5_times):
    current_time = pd.to_datetime(t_val)

    for var in VAR_LIST:                      # tmin, then tmax
        # ERA5 grid for this day & variable, flattened in (lat, lon) order
        day_data_2d       = era5_ds[var].isel(time=t_index).values
        grid_val_flat_day = day_data_2d.flatten()

        # Observed station values for the same calendar date & variable
        obs_day = obs_aligned[var].isel(time=t_index).values
        n_obs = len(obs_day)

        for i, st_name, st_lat, st_lon, st_elev, netcdf_idx in station_records:
            st_obs = obs_day[netcdf_idx] if netcdf_idx < n_obs else np.nan

            st_era5 = local_weighted_regression(
                station_lat=st_lat, station_lon=st_lon, station_elev=st_elev,
                neighbor_indices=station_neighbors[i],
                grid_lat=grid_lat_flat, grid_lon=grid_lon_flat,
                grid_elev=grid_elev_flat, grid_val=grid_val_flat_day
            )

            results.append({
                'time': current_time,
                'var':  var,                   # identify tmin / tmax
                'station_index': i,
                'station_name':  st_name,
                'obs':  st_obs,
                'era5_lwr8_val': st_era5
            })

    if (t_index + 1) % 100 == 0:
        print(f"Processed day {t_index+1} / {len(era5_times)}")

###############################################################################
# 11. BUILD A DATAFRAME & SAVE
###############################################################################
# One row per (day, variable, station); make sure 'time' is a datetime column.
results_df = pd.DataFrame(results).assign(
    time=lambda frame: pd.to_datetime(frame['time'])
)
print("Sample of daily results:\n", results_df.head())

csv_name = "era5_vs_stations_8Nearest_LWR_1991_2012_tmin_tmax.csv"
daily_loop_csv = os.path.join(daily_loop_dir, csv_name)
results_df.to_csv(daily_loop_csv, index=False)
print(f"Daily interpolation results saved to {daily_loop_csv}")

###############################################################################
# 12. (OPTIONAL) MAPPING
###############################################################################
try:
    print("Creating a quick map of average ERA5 temperature at each station (8-nearest LWR)...")
    glb  = gpd.read_file(shapefile_path).to_crs(target_crs)
    lakes= gpd.read_file(lakes_shp).to_crs(target_crs)

    # Time-mean per station. The grouping ignores 'var', so the tmin and tmax
    # rows of a station are averaged together into a single value.
    station_mean = results_df.groupby('station_index')['era5_lwr8_val'].mean().reset_index()
    merged = stations_df.copy()
    # 'station_index' in results_df is the stations_df row label, so join on
    # that key. Plain column assignment would align on station_mean's 0..n-1
    # RangeIndex and only be correct while every station appears, in order.
    mean_by_station = station_mean.set_index('station_index')['era5_lwr8_val']
    merged['mean_era5_lwr8'] = mean_by_station.reindex(merged.index).to_numpy()

    geometry = [Point(lon, lat) for lon, lat in zip(merged['LONGITUDE'], merged['LATITUDE'])]
    stations_gdf = gpd.GeoDataFrame(merged, geometry=geometry, crs="EPSG:4326").to_crs(target_crs)

    fig, ax = plt.subplots(figsize=(10, 8))
    glb.boundary.plot(ax=ax, edgecolor='black')
    lakes.plot(ax=ax, color='blue', alpha=0.5)
    stations_gdf.plot(column='mean_era5_lwr8', ax=ax, legend=True, cmap='viridis', markersize=50)
    ax.set_title("Mean ERA5 temperature (8-Nearest LWR), 1991-2012")
    plt.show()
except Exception as e:
    # Mapping is optional: report the problem and continue with the metrics.
    print("Mapping step skipped due to error or missing shapefiles:", str(e))

###############################################################################
# 13. METRICS & SAVING
###############################################################################
def remove_nan_pairs(obs, pred):
    """
    Keep only the positions where both `obs` and `pred` are non-NaN.

    Returns the two filtered arrays as a tuple (obs, pred).
    """
    has_nan = np.isnan(obs) | np.isnan(pred)
    keep = np.logical_not(has_nan)
    return obs[keep], pred[keep]

# Pooled skill scores (all stations and days together), one set per variable
for var, sub in results_df.groupby('var'):
    obs_all, era5_all = remove_nan_pairs(sub['obs'].values, sub['era5_lwr8_val'].values)
    if len(obs_all) == 0:
        print(f"\nNo valid pairs for {var}.")
        continue

    # Residuals (prediction minus observation) and the observed mean are
    # shared by several scores, so compute them once.
    err_all = era5_all - obs_all
    obs_mean_all = obs_all.mean()
    potential_err = np.sum(
        (np.abs(era5_all - obs_mean_all) + np.abs(obs_all - obs_mean_all))**2
    )

    metrics_all = {
        'MBE':  np.mean(err_all),
        'RMSE': np.sqrt(np.mean(err_all**2)),
        'STD':  np.std(err_all, ddof=1),
        'CC':   np.corrcoef(obs_all, era5_all)[0,1],
        'Index_of_Agreement': 1 - np.sum(err_all**2) / potential_err,
    }

    print(f"\nOverall metrics ({var}, ERA5 8-Nearest LWR, 1991-2012):")
    for k, v in metrics_all.items():
        print(f"{k}: {v:.4f}")


# Per-station skill scores, computed separately for each variable
station_metrics = []
for (var, st_idx), grp in results_df.groupby(['var','station_index']):
    obs_st, era5_st = remove_nan_pairs(grp['obs'].values, grp['era5_lwr8_val'].values)

    # Identification columns come first so the CSV column order is stable.
    record = {
        'var': var,
        'station_index': st_idx,
        'station_name': grp['station_name'].iloc[0],
    }

    if len(obs_st) == 0:
        # No valid (obs, prediction) pair: keep the station with NaN scores.
        record.update({'MBE': np.nan, 'RMSE': np.nan, 'STD': np.nan,
                       'CC': np.nan, 'Index_of_Agreement': np.nan})
    else:
        err_st = era5_st - obs_st
        obs_mean_st = obs_st.mean()
        potential_err_st = np.sum(
            (np.abs(era5_st - obs_mean_st) + np.abs(obs_st - obs_mean_st))**2
        )
        record.update({
            'MBE':  np.mean(err_st),
            'RMSE': np.sqrt(np.mean(err_st**2)),
            'STD':  np.std(err_st, ddof=1),
            'CC':   np.corrcoef(obs_st, era5_st)[0,1],
            'Index_of_Agreement': 1 - np.sum(err_st**2) / potential_err_st,
        })

    station_metrics.append(record)

# One row per (variable, station), saved alongside the other metrics.
metrics_df = pd.DataFrame(station_metrics)
metrics_csv = os.path.join(metrics_dir,
    "station_metrics_8Nearest_LWR_ERA5_1991_2012_tmin_tmax.csv")
metrics_df.to_csv(metrics_csv, index=False)
print(f"Station metrics saved to {metrics_csv}")

###############################################################################
# 14. (OPTIONAL) SAVE FULL DAILY RESULTS AS NETCDF
###############################################################################
# Best-effort export: any failure is printed by the except clause instead of
# aborting the script.
try:
    print("\nSaving daily station results to NetCDF …")

    netcdf_file = os.path.join(
        daily_loop_dir,
        "era5_vs_stations_8Nearest_LWR_1991_2012_tmin_tmax.nc"
    )

    # ---- hold one DataArray per variable/dataset ----
    ds_vars = {}

    # VAR_LIST is defined earlier in this cell (outside this excerpt).
    for var in VAR_LIST:
        # Pivot to (time × station_index) tables
        pivot_era5 = (
            results_df[results_df["var"] == var]
            .pivot(index="time", columns="station_index", values="era5_lwr8_val")
            .sort_index()
        )
        pivot_obs = (
            results_df[results_df["var"] == var]
            .pivot(index="time", columns="station_index", values="obs")
            .sort_index()
        )

        # Make sure both have identical columns (stations)
        common_cols = pivot_era5.columns.intersection(pivot_obs.columns)
        pivot_era5 = pivot_era5[common_cols]
        pivot_obs  = pivot_obs [common_cols]

        # Build DataArrays
        da_era5 = xr.DataArray(
            data   = pivot_era5.values,
            dims   = ("time", "station_index"),
            coords = {"time": pivot_era5.index,
                      "station_index": common_cols},
            name   = f"era5_lwr8_{var}"
        )
        da_obs  = xr.DataArray(
            data   = pivot_obs.values,
            dims   = ("time", "station_index"),
            coords = {"time": pivot_obs.index,
                      "station_index": common_cols},
            name   = f"obs_{var}"
        )

        # Keyed by name, so the output gets one "era5_lwr8_<var>" and one
        # "obs_<var>" variable per entry of VAR_LIST.
        ds_vars[da_era5.name] = da_era5
        ds_vars[da_obs.name]  = da_obs

    # Assemble and save
    ds_out = xr.Dataset(ds_vars)
    ds_out.to_netcdf(netcdf_file)
    print(f"NetCDF saved to: {netcdf_file}")

except Exception as e:
    print("Error saving NetCDF:", e)

In [None]:
#MERRA-2
# This script performs a 12-nearest-grid-cell LWR interpolation of MERRA-2 data to station locations.

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
import warnings
import os

###############################################################################
# 1. CONFIGURATIONS
###############################################################################
# Input files
station_file   = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\prcp_data.nc'
merra2_file      = r'D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\MERRA2_GLB_prcp_1991_2013_with_Elevation.nc'
physical_file  = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv'
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r'D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp'
target_crs     = "ESRI:102008"

# Output directories
# Adjust to match your desired folders
# This example writes next to the MERRA-2 input file

daily_loop_dir = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\daily_loop"
metrics_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\MERRA2_GLB_Precipitation\metrics"

# Create the output folders up front; exist_ok avoids failing on re-runs.
os.makedirs(daily_loop_dir, exist_ok=True)
os.makedirs(metrics_dir,    exist_ok=True)

# Variable names
obs_var_name  = 'prcp'  # Station daily precipitation var
merra2_var_name = 'prcp'  # MERRA-2 daily precipitation var (change if needed)

# Time range: 1991–2012
start_date = '1991-01-01'
end_date   = '2012-12-31'

###############################################################################
# 2. DISTANCE & WEIGHT FUNCTIONS
###############################################################################
def haversine_distance(lat1, lon1, lat2, lon2):
    """
    Great-circle distance between two points given in decimal degrees.

    Inputs may be scalars or NumPy arrays; they are combined with normal
    NumPy broadcasting, so one point can be compared against a whole grid.

    Returns the distance on a sphere of radius 6371 (the mean Earth radius
    in km), i.e. the result is in kilometres.
    """
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = (np.sin(dlat / 2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2)**2)
    # Rounding can push `a` a hair outside [0, 1] for (near-)antipodal or
    # coincident points, which would turn sqrt/arcsin into NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    return 6371 * c

def tricube_weight(distances, d_max):
    """
    Tricube kernel weights  w = (1 - (d / d_max)^3)^3  for d <= d_max, else 0.

    Parameters
    ----------
    distances : array-like or scalar
        Distances to weight. Lists and scalars are accepted as well as
        NumPy arrays; the result is always a float ndarray of the same shape.
    d_max : float
        Bandwidth. A point exactly at d_max gets weight 0. When d_max is 0
        every point gets weight 1.
    """
    distances = np.asarray(distances, dtype=float)
    if d_max == 0:
        return np.ones_like(distances)
    ratio = distances / d_max
    # np.where instead of in-place masking so 0-d (scalar) input works too.
    return np.where(distances > d_max, 0.0, (1 - ratio**3)**3)

###############################################################################
# 3. OPTIONAL: FORCE MERRA-2 TIME
###############################################################################
def force_merra2_time(ds):
    """
    Return `ds` with its 'time' values converted by pd.to_datetime.

    If 'time' is a dimension without a coordinate, a 'time' coordinate is
    attached first. Note that the final conversion assigns into the dataset
    object that is returned.
    """
    time_is_bare_dim = 'time' in ds.dims and 'time' not in ds.coords
    if time_is_bare_dim:
        ds = ds.assign_coords(time=ds['time'])

    converted = pd.to_datetime(ds['time'].values)
    ds['time'] = converted
    return ds

###############################################################################
# 4. LOAD DATASETS
###############################################################################
# Open the station observations and the MERRA-2 grid, normalise their time
# coordinates and restrict both to the study period.
print("Loading station observations (NetCDF) ...")
obs_ds = xr.open_dataset(station_file)

print("Loading merra2 reanalysis (NetCDF) ...")
# Bind the dataset to `merra2_ds`: every later statement in this cell uses
# that name (it was previously assigned to `era5_ds`, leaving `merra2_ds`
# undefined or stale).
merra2_ds = xr.open_dataset(merra2_file)

# Fix station time coords if needed
if 'time' in obs_ds.coords:
    obs_ds['time'] = pd.to_datetime(obs_ds['time'].values)

# Force MERRA-2 time if needed
merra2_ds = force_merra2_time(merra2_ds)

# Subset to 1991–2012
obs_ds   = obs_ds.sel(time=slice(start_date, end_date))
merra2_ds = merra2_ds.sel(time=slice(start_date, end_date))

# The two counts should match; a difference means the time axes are not
# identical day-for-day.
print("After subsetting:")
print(f"obs_ds time steps = {obs_ds.sizes['time']}")
print(f"merra2_ds time steps = {merra2_ds.sizes['time']}")

###############################################################################
# 5. IF dims are (time, lon, lat), transpose to (time, lat, lon)
###############################################################################
# The daily loop flattens each day's field and pairs it with grids built by
# np.meshgrid(lons, lats), so the variable must be ordered (time, lat, lon).
actual_dims = merra2_ds[merra2_var_name].dims
print("Current dims for merra2 prcp:", actual_dims)
# Only the exact ordering ("time","lon","lat") is handled; any other layout
# falls through to the else branch unchanged.
if actual_dims == ("time","lon","lat"):
    print("Transposing prcp from (time,lon,lat) -> (time,lat,lon).")
    merra2_ds[merra2_var_name] = merra2_ds[merra2_var_name].transpose("time","lat","lon")
    print("New dims:", merra2_ds[merra2_var_name].dims)
else:
    print("No transpose needed or already (time,lat,lon).")

###############################################################################
# 6. EXTRACT MERRA-2 GRID
###############################################################################
# Build 2-D coordinate grids from the 'lat' and 'lon' coordinates.
# assumes 'lat' and 'lon' are 1-D coordinate vectors — TODO confirm
lats = merra2_ds['lat'].values
lons = merra2_ds['lon'].values
# meshgrid(lons, lats) yields arrays shaped (len(lats), len(lons)).
lon2d, lat2d = np.meshgrid(lons, lats)

# Grid-cell elevation is one of the regression predictors; fall back to a
# flat zero surface (with a warning) when the file does not provide it.
# NOTE(review): the transpose in section 5 only touches the precipitation
# variable; presumably 'elevation' is already stored as (lat, lon) — confirm,
# otherwise its flattened order will not match lat2d/lon2d.
if 'elevation' in merra2_ds:
    grid_elev = merra2_ds['elevation'].values
else:
    warnings.warn("No 'elevation' found in merra2 dataset => set elev=0.")
    grid_elev = np.zeros_like(lat2d)

# 1-D views of the grid; the neighbour indices computed later index into these.
grid_lat_flat  = lat2d.flatten()
grid_lon_flat  = lon2d.flatten()
grid_elev_flat = grid_elev.flatten()

###############################################################################
# 7. LOAD STATIONS & MATCH
###############################################################################
print("Loading station metadata (CSV) ...")
# Columns that are entirely empty are dropped right after reading.
stations_df = pd.read_csv(physical_file).dropna(axis=1, how='all')
print(stations_df.head())

# If station_file has dimension station, match by name
station_names_nc = obs_ds['station'].values
# name -> position along the NetCDF 'station' axis
station_index_map = {name: i for i, name in enumerate(station_names_nc)}

# Stations whose NAME is absent from the NetCDF map to NaN and are removed;
# the index is then reset so it runs 0..n-1 over the matched stations.
stations_df['netcdf_index'] = stations_df['NAME'].map(station_index_map)
stations_df = stations_df.dropna(subset=['netcdf_index']).reset_index(drop=True)
# The NaNs introduced by .map() force a float dtype; cast back to int so the
# value can be used as an array index.
stations_df['netcdf_index'] = stations_df['netcdf_index'].astype(int)
print("Total matched stations:", len(stations_df))

###############################################################################
# 8. PRECOMPUTE THE 12 NEAREST GRID CELLS FOR EACH STATION
###############################################################################
# Number of nearest grid cells used for each station's local regression.
# Kept in one place so the log message and the selection cannot drift apart.
N_NEAREST = 12

# station row index -> indices (into the flattened grid) of its nearest cells
station_neighbors = {}
print(f"Precomputing the {N_NEAREST} nearest MERRA2 grids for each station ...")

for i, row in stations_df.iterrows():
    st_lat = row['LATITUDE']
    st_lon = row['LONGITUDE']
    # Distance from this station to every grid cell (km).
    dist_all = haversine_distance(st_lat, st_lon, grid_lat_flat, grid_lon_flat)

    # Sort by distance and keep the indices of the N_NEAREST closest cells.
    # The slice simply returns fewer indices if the grid is smaller than that.
    sorted_idx = np.argsort(dist_all)
    neighbor_idx = sorted_idx[:N_NEAREST]

    station_neighbors[i] = neighbor_idx

print("Neighbor precomputation complete.")

###############################################################################
# 9. LOCAL WEIGHTED REGRESSION FUNCTION
###############################################################################
def local_weighted_regression(
    station_lat, station_lon, station_elev,
    neighbor_indices,
    grid_lat, grid_lon, grid_elev, grid_val
):
    """
    Estimate the gridded field at a station by locally weighted regression.

    Fits a weighted least-squares model  grid_val ~ [1, lat, lon, elev]
    over the selected neighbouring grid cells and evaluates it at the
    station's latitude, longitude and elevation.

    Parameters
    ----------
    station_lat, station_lon, station_elev : float
        Station coordinates (degrees) and elevation.
    neighbor_indices : array of int or None
        Indices into the flattened grid arrays of the cells to use.
    grid_lat, grid_lon, grid_elev, grid_val : 1-D arrays
        Flattened grid coordinates, elevation and the field for one time step.

    Returns
    -------
    float
        The prediction, clipped at zero (precipitation cannot be negative),
        or NaN when there are no neighbours, no neighbour has a finite
        value and elevation, all weights are zero, or the solver fails.
    """
    if (neighbor_indices is None) or (len(neighbor_indices) < 1):
        return np.nan

    lat_n  = grid_lat[neighbor_indices]
    lon_n  = grid_lon[neighbor_indices]
    elev_n = grid_elev[neighbor_indices]
    val_n  = grid_val[neighbor_indices]

    # A single missing neighbour value/elevation would otherwise turn the
    # whole fit into NaN (or make the solver fail); fit on the usable cells.
    usable = np.isfinite(val_n) & np.isfinite(elev_n)
    if not usable.any():
        return np.nan
    if not usable.all():
        lat_n, lon_n, elev_n, val_n = (
            arr[usable] for arr in (lat_n, lon_n, elev_n, val_n)
        )

    # Tricube weights with the bandwidth set to the farthest neighbour used,
    # so that neighbour itself receives weight 0.
    dist_n = haversine_distance(station_lat, station_lon, lat_n, lon_n)
    d_max  = dist_n.max()
    w = tricube_weight(dist_n, d_max)
    if np.all(w == 0):
        return np.nan

    # Build design matrix X: [1, lat, lon, elev]
    X = np.column_stack([np.ones_like(lat_n), lat_n, lon_n, elev_n])

    # Weighted least squares == ordinary least squares on rows scaled by sqrt(w).
    sqrt_w = np.sqrt(w)
    X_w = X * sqrt_w[:, None]
    y_w = val_n * sqrt_w

    try:
        beta, _residuals, _rank, _sv = np.linalg.lstsq(X_w, y_w, rcond=None)
    except np.linalg.LinAlgError:
        return np.nan

    # Predict at station location
    X_station = np.array([1, station_lat, station_lon, station_elev])
    pred = X_station @ beta

    # Clip negative precipitation to zero (a NaN prediction passes through).
    if pred < 0:
        pred = 0.0

    return pred

###############################################################################
# 10. DAILY LOOP: merra2 -> STATIONS
###############################################################################
# For every MERRA-2 day, interpolate the gridded field to each station and
# pair it with that station's observation for the same calendar day.
results = []
merra2_times = merra2_ds['time'].values
print(f"Processing merra2 days: {len(merra2_times)}")

# Look observations up by calendar date rather than by the MERRA-2 loop
# position: positional indexing silently pairs the wrong days (or raises
# IndexError) as soon as the two time axes are not identical.
obs_pos_by_date = {}
for obs_pos, obs_date in enumerate(pd.DatetimeIndex(obs_ds['time'].values).normalize()):
    # setdefault keeps the first record should a date appear twice
    obs_pos_by_date.setdefault(obs_date, obs_pos)

days_without_obs = 0

for t_index, t_val in enumerate(merra2_times):
    current_time = pd.to_datetime(t_val)
    # shape => (lat, lon)
    day_data_2d = merra2_ds[merra2_var_name].isel(time=t_index).values
    grid_val_flat_day = day_data_2d.flatten()

    # Observation slice for this calendar day (None when obs has no such day).
    obs_pos = obs_pos_by_date.get(current_time.normalize())
    if obs_pos is None:
        obs_day = None
        days_without_obs += 1
    else:
        obs_day = obs_ds[obs_var_name].isel(time=obs_pos).values

    for i, row in stations_df.iterrows():
        st_name    = row['NAME']
        st_lat     = row['LATITUDE']
        st_lon     = row['LONGITUDE']
        st_elev    = row['Elevation'] if 'Elevation' in row else 0.0
        netcdf_idx = row['netcdf_index']

        if obs_day is not None and netcdf_idx < len(obs_day):
            st_obs = obs_day[netcdf_idx]
        else:
            st_obs = np.nan

        neigh_idx = station_neighbors[i]
        st_merra2 = local_weighted_regression(
            station_lat=st_lat,
            station_lon=st_lon,
            station_elev=st_elev,
            neighbor_indices=neigh_idx,
            grid_lat=grid_lat_flat,
            grid_lon=grid_lon_flat,
            grid_elev=grid_elev_flat,
            grid_val=grid_val_flat_day
        )

        results.append({
            'time': current_time,
            'station_index': i,
            'station_name': st_name,
            'obs': st_obs,
            'merra2_lwr12_val': st_merra2
        })

    if (t_index + 1) % 100 == 0:
        print(f"Processed day {t_index+1} / {len(merra2_times)}")

if days_without_obs:
    print(f"Note: {days_without_obs} merra2 day(s) had no matching observation date; obs set to NaN.")

###############################################################################
# 11. BUILD A DATAFRAME & SAVE
###############################################################################
# One row per (day, station) from the daily loop.
results_df = pd.DataFrame(results)
# Make sure the 'time' column has a datetime dtype before saving/pivoting.
results_df['time'] = pd.to_datetime(results_df['time'])
print("Sample of daily results:\n", results_df.head())

# CSV output with the 12-nearest LWR data
daily_loop_csv = os.path.join(daily_loop_dir, "merra2_vs_stations_12Nearest_LWR_1991_2012.csv")
results_df.to_csv(daily_loop_csv, index=False)
print(f"Daily interpolation results saved to {daily_loop_csv}")

###############################################################################
# 12. (OPTIONAL) MAPPING
###############################################################################
# Best-effort diagnostic map: any failure (e.g. unreadable shapefiles) is
# reported by the except clause and the script carries on.
try:
    print("Creating a quick map of average merra2 precipitation at each station (12-nearest LWR)...")
    # Basin outline and lake polygons, reprojected to the plotting CRS.
    glb  = gpd.read_file(shapefile_path).to_crs(target_crs)
    lakes= gpd.read_file(lakes_shp).to_crs(target_crs)

    # Mean interpolated value per station over all days.
    station_mean = results_df.groupby('station_index')['merra2_lwr12_val'].mean().reset_index()
    merged = stations_df.copy()
    # Column assignment aligns on index labels: station_mean's 0..k-1 index
    # (after reset_index) against stations_df's index, which was reset to
    # 0..n-1 when the stations were matched.
    merged['mean_merra2_lwr12'] = station_mean['merra2_lwr12_val']

    # Station points are built from lon/lat (EPSG:4326) and reprojected.
    geometry = [Point(lon, lat) for lon, lat in zip(merged['LONGITUDE'], merged['LATITUDE'])]
    stations_gdf = gpd.GeoDataFrame(merged, geometry=geometry, crs="EPSG:4326").to_crs(target_crs)

    # Layers: basin boundary, lakes, then stations coloured by their mean value.
    fig, ax = plt.subplots(figsize=(10, 8))
    glb.boundary.plot(ax=ax, edgecolor='black')
    lakes.plot(ax=ax, color='blue', alpha=0.5)
    stations_gdf.plot(column='mean_merra2_lwr12', ax=ax, legend=True, cmap='viridis', markersize=50)
    ax.set_title("Mean merra2 Precip (12-Nearest LWR), 1991-2012")
    plt.show()
except Exception as e:
    print("Mapping step skipped due to error or missing shapefiles:", str(e))

###############################################################################
# 13. METRICS & SAVING
###############################################################################
def remove_nan_pairs(obs, pred):
    """
    Keep only the positions where both `obs` and `pred` are non-NaN.

    Returns the two filtered arrays as a tuple (obs, pred).
    """
    either_missing = np.isnan(obs) | np.isnan(pred)
    both_present = np.logical_not(either_missing)
    return obs[both_present], pred[both_present]

def mean_bias_error(obs, pred):
    """Mean of (pred - obs); positive values mean over-prediction."""
    bias = pred - obs
    return np.mean(bias)

def root_mean_square_error(obs, pred):
    """Square root of the mean squared difference between pred and obs."""
    squared_err = (pred - obs)**2
    mean_squared_err = np.mean(squared_err)
    return np.sqrt(mean_squared_err)

def std_of_residuals(obs, pred):
    """Sample standard deviation (ddof=1) of the residuals pred - obs."""
    residuals = pred - obs
    return np.std(residuals, ddof=1)

def pearson_correlation(obs, pred):
    """
    Pearson correlation coefficient between obs and pred.

    Returns NaN when fewer than two samples are supplied.
    """
    if len(obs) >= 2:
        corr_matrix = np.corrcoef(obs, pred)
        return corr_matrix[0, 1]
    return np.nan

def index_of_agreement(obs, pred):
    """
    Index of agreement:

        1 - sum((pred - obs)^2) / sum((|pred - mean(obs)| + |obs - mean(obs)|)^2)

    Returns NaN when the denominator is zero.
    """
    center = np.mean(obs)
    potential_error = np.sum((np.abs(pred - center) + np.abs(obs - center))**2)
    if potential_error == 0:
        return np.nan
    squared_error = np.sum((pred - obs)**2)
    return 1 - squared_error / potential_error

# Pooled skill scores over every (day, station) pair in which both the
# observation and the interpolated value are present.
obs_all = results_df['obs'].values
merra2_all = results_df['merra2_lwr12_val'].values
obs_all, merra2_all = remove_nan_pairs(obs_all, merra2_all)

if len(obs_all) > 0:
    metrics_all = {
        'MBE': mean_bias_error(obs_all, merra2_all),
        'RMSE': root_mean_square_error(obs_all, merra2_all),
        'STD': std_of_residuals(obs_all, merra2_all),
        'CC': pearson_correlation(obs_all, merra2_all),
        'Index_of_Agreement': index_of_agreement(obs_all, merra2_all)
    }
    print("\nOverall Metrics (All Stations, merra2 12-Nearest LWR, 1991-2012):")
    for k, v in metrics_all.items():
        print(f"{k}: {v:.4f}")
else:
    print("\nNo valid (obs, merra2_lwr12_val) pairs found for overall metrics.")

# Station-level metrics: the same five scores, computed per station.
station_groups = results_df.groupby('station_index')
per_station_metrics = []

for st_idx, grp in station_groups:
    obs_st = grp['obs'].values
    merra2_st = grp['merra2_lwr12_val'].values
    obs_st, merra2_st = remove_nan_pairs(obs_st, merra2_st)
    # A station without any valid pair still gets a row (all scores NaN), so
    # the output has exactly one row per station.
    if len(obs_st) == 0:
        per_station_metrics.append({
            'station_index': st_idx,
            'station_name': grp['station_name'].iloc[0],
            'MBE': np.nan,
            'RMSE': np.nan,
            'STD': np.nan,
            'CC': np.nan,
            'Index_of_Agreement': np.nan
        })
        continue

    row_dict = {
        'station_index': st_idx,
        'station_name': grp['station_name'].iloc[0],
        'MBE':  mean_bias_error(obs_st, merra2_st),
        'RMSE': root_mean_square_error(obs_st, merra2_st),
        'STD':  std_of_residuals(obs_st, merra2_st),
        'CC':   pearson_correlation(obs_st, merra2_st),
        'Index_of_Agreement': index_of_agreement(obs_st, merra2_st)
    }
    per_station_metrics.append(row_dict)

metrics_df = pd.DataFrame(per_station_metrics)
print("\nSample of per-station metrics:")
print(metrics_df.head())

# Written to metrics_dir; index=False because the row index is meaningless.
metrics_csv = os.path.join(metrics_dir, "station_metrics_12Nearest_LWR_merra2_1991_2012.csv")
metrics_df.to_csv(metrics_csv, index=False)
print(f"Station metrics saved to {metrics_csv}")

###############################################################################
# 14. (OPTIONAL) SAVE FULL DAILY RESULTS AS NETCDF
###############################################################################
# Best-effort export of the daily results as a (time, station_index) NetCDF.
# Any failure is printed by the except clause instead of aborting the script.
try:
    print("\nSaving daily station results to NetCDF file in daily_loop folder...")
    netcdf_file = os.path.join(daily_loop_dir, "merra2_vs_stations_12Nearest_LWR_1991_2012.nc")

    # 1) Pivot so 'time' is the row index, 'station_index' are columns
    pivot_merra2 = results_df.pivot(index='time', columns='station_index', values='merra2_lwr12_val').sort_index()
    pivot_obs  = results_df.pivot(index='time', columns='station_index', values='obs').sort_index()

    # Both pivots come from the same rows, but align explicitly so the two
    # variables are guaranteed to share one set of coordinates.
    pivot_obs = pivot_obs.reindex(index=pivot_merra2.index, columns=pivot_merra2.columns)

    # 2) Ensure the pivoted index is recognized as real datetime
    time_index = pd.to_datetime(pivot_merra2.index, errors='coerce')
    station_ids = pivot_merra2.columns.values

    # 3) Build the Dataset with explicit dimensions. Going through
    #    DataFrame.to_xarray().to_array().squeeze() instead would put the
    #    stations on a dimension named 'variable' and drop that dimension
    #    altogether when there is only one station.
    ds_out = xr.Dataset(
        {
            'merra2_lwr12_val': (("time", "station_index"), pivot_merra2.values),
            'obs':              (("time", "station_index"), pivot_obs.values),
        },
        coords={"time": time_index, "station_index": station_ids},
    )

    ds_out.to_netcdf(netcdf_file)
    print(f"NetCDF saved to: {netcdf_file}")

except Exception as e:
    print("Error saving NetCDF:", e)

print("\nDone. The 12-nearest LWR interpolation for merra2 dataset is complete!")

In [None]:
#RDRS v2.1
# Interpolates RDRS data to station locations using locally weighted regression (LWR) over grid cells within a 25 km radius.

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
import warnings
import time
import os

###############################################################################
# 1. CONFIGURATIONS
###############################################################################
# Input files
station_file   = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\prcp_data.nc'
rdrs_file      = r'D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\RDRS_v2.1_merged_prcp_1988_2015_with_Elevation.nc'
physical_file  = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv'
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"
lakes_shp      = r'D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp'
target_crs     = "ESRI:102008"

# Paths for daily loop & metrics (same directory as the RDRS file)
daily_loop_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\daily_loop"
metrics_dir       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\RDRS v2.1_GLB_Precipitation\metrics"

# Create the output folders up front; exist_ok avoids failing on re-runs.
os.makedirs(daily_loop_dir, exist_ok=True)
os.makedirs(metrics_dir,    exist_ok=True)

# Variable names
obs_var_name = 'prcp'   # variable read from the station file
rdrs_var     = 'prcp'   # variable read from the RDRS file

# Local Weighted Regression settings
search_radius_km = 25   # grid cells within this distance of a station are its neighbours
min_points = 5          # stations with fewer neighbours than this get NaN

# Time range: 1991–2012
start_date = '1991-01-01'
end_date   = '2012-12-31'

###############################################################################
# 2. DISTANCE & WEIGHT FUNCTIONS
###############################################################################
def haversine_distance(lat1, lon1, lat2, lon2):
    """
    Great-circle distance between points given in decimal degrees.

    Scalars and NumPy arrays are both accepted (normal broadcasting applies).
    The result is scaled by a sphere radius of 6371, i.e. kilometres.
    """
    phi1, lam1, phi2, lam2 = (np.radians(v) for v in (lat1, lon1, lat2, lon2))
    half_dphi = (phi2 - phi1) / 2
    half_dlam = (lam2 - lam1) / 2
    hav = np.sin(half_dphi)**2 + np.cos(phi1)*np.cos(phi2)*np.sin(half_dlam)**2
    central_angle = 2*np.arcsin(np.sqrt(hav))
    return 6371 * central_angle

def tricube_weight(distances, d_max):
    """
    Tricube kernel:  w_j = [1 - (dist_j / d_max)^3]^3  for dist_j <= d_max,
    and 0 beyond d_max. With d_max == 0 every weight is 1.
    """
    if d_max != 0:
        beyond = distances > d_max
        scaled_cubed = (distances / d_max)**3
        weights = (1 - scaled_cubed)**3
        weights[beyond] = 0.0
        return weights
    return np.ones_like(distances)

###############################################################################
# 3. FORCE RDRS TIME (IF NEEDED)
###############################################################################
def force_rdrs_time(ds):
    """
    Return `ds` with a datetime 'time' coordinate.

    - A 'time' dimension without a coordinate gets one attached.
    - Integer time values are read as day offsets from 1991-01-01.
    - The values are finally passed through pd.to_datetime.
    """
    if 'time' in ds.dims and 'time' not in ds.coords:
        ds = ds.assign_coords(time=ds['time'])

    time_is_integer = np.issubdtype(ds['time'].dtype, np.integer)
    if time_is_integer:
        # NOTE(review): the origin is hard-coded to 1991-01-01; confirm this
        # matches the file's own time reference before relying on this branch.
        origin = pd.to_datetime("1991-01-01")
        real_times = [origin + pd.Timedelta(days=int(offset))
                      for offset in ds['time'].values]
        ds = ds.assign_coords(time=("time", real_times))

    ds['time'] = pd.to_datetime(ds['time'].values)
    return ds

# Open the station observations and the RDRS grid, normalise their time
# coordinates and restrict both to the study period.
print("Loading station observations (NetCDF) ...")
obs_ds = xr.open_dataset(station_file)

print("Loading RDRS reanalysis (NetCDF) ...")
rdrs_ds = xr.open_dataset(rdrs_file)

# Fix time coords if needed
if 'time' in obs_ds.coords:
    obs_ds['time'] = pd.to_datetime(obs_ds['time'].values)
rdrs_ds = force_rdrs_time(rdrs_ds)

# Subset to 1991–2012
obs_ds   = obs_ds.sel(time=slice(start_date, end_date))
rdrs_ds  = rdrs_ds.sel(time=slice(start_date, end_date))

# The daily loop indexes obs_ds with the RDRS loop position, so these two
# counts need to match (and the axes need to cover the same days).
print("After subsetting:")
print(f"obs_ds time steps = {obs_ds.sizes['time']}")
print(f"rdrs_ds time steps = {rdrs_ds.sizes['time']}")

# If dims are (time, lon, lat), transpose to (time, lat, lon)
# (each day's field is later flattened and paired with grids built by
# np.meshgrid(lons, lats), so the orders must agree)
print("Checking dimension order of the RDRS variable ...")
actual_dims = rdrs_ds[rdrs_var].dims
print("Current dims for prcp:", actual_dims)
# Only the exact ordering ("time", "lon", "lat") is handled; any other layout
# falls through to the else branch unchanged.
if actual_dims == ("time", "lon", "lat"):
    print("Transposing prcp from (time,lon,lat) -> (time,lat,lon).")
    rdrs_ds[rdrs_var] = rdrs_ds[rdrs_var].transpose("time", "lat", "lon")
    print("New dims:", rdrs_ds[rdrs_var].dims)
else:
    print("No transpose needed or dims are already (time,lat,lon).")

# Build 2-D coordinate grids.
# assumes 'lat' and 'lon' are 1-D coordinate vectors — TODO confirm
lats = rdrs_ds['lat'].values
lons = rdrs_ds['lon'].values
# meshgrid(lons, lats) yields arrays shaped (len(lats), len(lons)).
lon2d, lat2d = np.meshgrid(lons, lats)

# Grid-cell elevation is one of the regression predictors; fall back to a
# flat zero surface (with a warning) when the file does not provide it.
# NOTE(review): the transpose above only touches the precipitation variable;
# presumably 'elevation' is already stored as (lat, lon) — confirm.
if 'elevation' in rdrs_ds:
    grid_elev = rdrs_ds['elevation'].values
else:
    warnings.warn("No 'elevation' found in RDRS dataset => set elev=0.")
    grid_elev = np.zeros_like(lat2d)

# 1-D views of the grid; the neighbour indices computed later index into these.
grid_lat_flat  = lat2d.flatten()
grid_lon_flat  = lon2d.flatten()
grid_elev_flat = grid_elev.flatten()

###############################################################################
# 4. LOAD STATIONS & MATCH
###############################################################################
print("Loading station metadata (CSV) ...")
# Columns that are entirely empty are dropped right after reading.
stations_df = pd.read_csv(physical_file).dropna(axis=1, how='all')
print(stations_df.head())

print("Matching station names to NetCDF index ...")
station_names_nc = obs_ds['station'].values
# name -> position along the NetCDF 'station' axis
station_index_map = {name: i for i, name in enumerate(station_names_nc)}
# Stations whose NAME is absent from the NetCDF map to NaN and are removed;
# the index is then reset so it runs 0..n-1 over the matched stations.
stations_df['netcdf_index'] = stations_df['NAME'].map(station_index_map)
stations_df = stations_df.dropna(subset=['netcdf_index']).reset_index(drop=True)
# The NaNs introduced by .map() force a float dtype; cast back to int so the
# value can be used as an array index.
stations_df['netcdf_index'] = stations_df['netcdf_index'].astype(int)
print("Total matched stations:", len(stations_df))

###############################################################################
# 5. PRECOMPUTE NEIGHBORS FOR EACH STATION
###############################################################################
# station row index -> indices (into the flattened grid) of the cells within
# search_radius_km, or None when fewer than min_points cells qualify
station_neighbors = {}
print(f"Precomputing neighbors within {search_radius_km} km...")

for i, row in stations_df.iterrows():
    st_lat, st_lon = row['LATITUDE'], row['LONGITUDE']
    dist_all = haversine_distance(st_lat, st_lon, grid_lat_flat, grid_lon_flat)
    within_radius = dist_all <= search_radius_km
    neighbor_idx = np.nonzero(within_radius)[0]
    station_neighbors[i] = neighbor_idx if len(neighbor_idx) >= min_points else None

print("Neighbor precomputation complete.")

###############################################################################
# 6. LOCAL WEIGHTED REGRESSION FUNCTION
###############################################################################
def local_weighted_regression(
    station_lat, station_lon, station_elev,
    neighbor_indices,
    grid_lat, grid_lon, grid_elev, grid_val
):
    """
    Estimate the gridded field at a station by locally weighted regression.

    Fits a weighted least-squares model  grid_val ~ [1, lat, lon, elev]
    over the selected neighbouring grid cells, with tricube weights based on
    the station->grid distance, and evaluates it at the station.

    Parameters
    ----------
    station_lat, station_lon, station_elev : float
        Station coordinates (degrees) and elevation.
    neighbor_indices : array of int or None
        Indices into the flattened grid arrays of the cells to use.
    grid_lat, grid_lon, grid_elev, grid_val : 1-D arrays
        Flattened grid coordinates, elevation and the field for one time step.

    Returns
    -------
    float
        The prediction, clipped at zero (precipitation cannot be negative),
        or NaN when fewer than `min_points` neighbours have a finite value
        and elevation, all weights are zero, or the solver fails.
    """
    if neighbor_indices is None or len(neighbor_indices) < min_points:
        return np.nan

    lat_n  = grid_lat[neighbor_indices]
    lon_n  = grid_lon[neighbor_indices]
    elev_n = grid_elev[neighbor_indices]
    val_n  = grid_val[neighbor_indices]

    # A single missing neighbour value/elevation would otherwise turn the
    # whole fit into NaN (or make the solver fail); fit on the usable cells
    # and apply the min_points requirement to what is left.
    usable = np.isfinite(val_n) & np.isfinite(elev_n)
    if usable.sum() < max(min_points, 1):
        return np.nan
    if not usable.all():
        lat_n, lon_n, elev_n, val_n = (
            arr[usable] for arr in (lat_n, lon_n, elev_n, val_n)
        )

    # Tricube weights with the bandwidth set to the farthest neighbour used,
    # so that neighbour itself receives weight 0.
    dist_n = haversine_distance(station_lat, station_lon, lat_n, lon_n)
    d_max  = dist_n.max()
    w = tricube_weight(dist_n, d_max)
    if np.all(w == 0):
        return np.nan

    # Build design matrix X: [1, lat, lon, elev]
    X = np.column_stack([np.ones_like(lat_n), lat_n, lon_n, elev_n])
    # Weighted least squares == ordinary least squares on rows scaled by sqrt(w).
    sqrt_w = np.sqrt(w)
    X_w = X * sqrt_w[:, None]
    y_w = val_n * sqrt_w

    # Solve Weighted LS
    try:
        beta, _residuals, _rank, _sv = np.linalg.lstsq(X_w, y_w, rcond=None)
    except np.linalg.LinAlgError:
        return np.nan

    # Predict at station location
    X_station = np.array([1, station_lat, station_lon, station_elev])
    pred = X_station @ beta

    # Clip negative precipitation to zero (a NaN prediction passes through).
    if pred < 0:
        pred = 0.0

    return pred

###############################################################################
# 7. DAILY LOOP: RDRS -> STATIONS
###############################################################################
# For every RDRS day, interpolate the gridded field to each station and pair
# it with that station's observation for the same calendar day.
results = []
rdrs_times = rdrs_ds['time'].values
print(f"Processing RDRS days: {len(rdrs_times)}")

# Look observations up by calendar date rather than by the RDRS loop
# position: positional indexing silently pairs the wrong days (or raises
# IndexError) as soon as the two time axes are not identical.
obs_pos_by_date = {}
for obs_pos, obs_date in enumerate(pd.DatetimeIndex(obs_ds['time'].values).normalize()):
    # setdefault keeps the first record should a date appear twice
    obs_pos_by_date.setdefault(obs_date, obs_pos)

days_without_obs = 0

for t_index, t_val in enumerate(rdrs_times):
    current_time = pd.to_datetime(t_val)
    day_data_2d = rdrs_ds[rdrs_var].isel(time=t_index).values  # shape => (lat, lon)
    grid_val_flat = day_data_2d.flatten()

    # Observation slice for this calendar day (None when obs has no such day).
    obs_pos = obs_pos_by_date.get(current_time.normalize())
    if obs_pos is None:
        obs_day = None
        days_without_obs += 1
    else:
        obs_day = obs_ds[obs_var_name].isel(time=obs_pos).values

    for i, row in stations_df.iterrows():
        st_name   = row['NAME']
        st_lat    = row['LATITUDE']
        st_lon    = row['LONGITUDE']
        # Use the station elevation when the column exists, else 0
        st_elev   = row['Elevation'] if 'Elevation' in row else 0.0
        netcdf_idx = row['netcdf_index']

        if obs_day is not None and netcdf_idx < len(obs_day):
            st_obs = obs_day[netcdf_idx]
        else:
            st_obs = np.nan

        neigh_idx = station_neighbors[i]
        st_rdrs = local_weighted_regression(
            station_lat=st_lat,
            station_lon=st_lon,
            station_elev=st_elev,
            neighbor_indices=neigh_idx,
            grid_lat=grid_lat_flat,
            grid_lon=grid_lon_flat,
            grid_elev=grid_elev_flat,
            grid_val=grid_val_flat
        )

        results.append({
            'time': current_time,
            'station_index': i,
            'station_name': st_name,
            'obs': st_obs,
            'rdrs_val': st_rdrs
        })

    if (t_index+1) % 100 == 0:
        print(f"Processed day {t_index+1} / {len(rdrs_times)}")

if days_without_obs:
    print(f"Note: {days_without_obs} RDRS day(s) had no matching observation date; obs set to NaN.")

###############################################################################
# 8. BUILD A DATAFRAME & SAVE
###############################################################################
# One row per (day, station) from the daily loop.
results_df = pd.DataFrame(results)
# Make sure the 'time' column has a datetime dtype before saving.
results_df['time'] = pd.to_datetime(results_df['time'])
print("Sample of daily results:\n", results_df.head())

# Save the daily loop CSV (index=False: the row index carries no information)
daily_loop_csv = os.path.join(daily_loop_dir, "rdrs_vs_stations_25km_LWR_1991_2012.csv")
results_df.to_csv(daily_loop_csv, index=False)
print(f"Daily interpolation results saved to {daily_loop_csv}")

###############################################################################
# 9. (OPTIONAL) MAPPING
###############################################################################
# Best-effort diagnostic map: any failure (e.g. unreadable shapefiles) is
# reported by the except clause and the script carries on.
try:
    print("Creating a quick map of average RDRS precipitation at each station ...")
    # Basin outline and lake polygons, reprojected to the plotting CRS.
    glb  = gpd.read_file(shapefile_path).to_crs(target_crs)
    lakes= gpd.read_file(lakes_shp).to_crs(target_crs)

    # Mean interpolated value per station over all days (NaN days are skipped
    # by pandas' mean; a station with only NaN values stays NaN).
    station_mean = results_df.groupby('station_index')['rdrs_val'].mean().reset_index()
    merged = stations_df.copy()
    # Column assignment aligns on index labels: station_mean's 0..k-1 index
    # (after reset_index) against stations_df's index, which was reset to
    # 0..n-1 when the stations were matched.
    merged['mean_rdrs'] = station_mean['rdrs_val']

    # Station points are built from lon/lat (EPSG:4326) and reprojected.
    geometry = [Point(lon, lat) for lon, lat in zip(merged['LONGITUDE'], merged['LATITUDE'])]
    stations_gdf = gpd.GeoDataFrame(merged, geometry=geometry, crs="EPSG:4326")
    stations_gdf = stations_gdf.to_crs(target_crs)

    # Layers: basin boundary, lakes, then stations coloured by their mean value.
    fig, ax = plt.subplots(figsize=(10, 8))
    glb.boundary.plot(ax=ax, edgecolor='black')
    lakes.plot(ax=ax, color='blue', alpha=0.5)
    stations_gdf.plot(column='mean_rdrs', ax=ax, legend=True,
                      cmap='viridis', markersize=50)
    ax.set_title("Mean RDRS Precip (LWR, 25km radius), 1991-2012")
    plt.show()
except Exception as e:
    print("Mapping step skipped due to error or missing shapefiles:", str(e))

###############################################################################
# 10. METRICS & SAVING
###############################################################################
def remove_nan_pairs(obs, pred):
    """
    Keep only the positions where both `obs` and `pred` are non-NaN.

    Returns the two filtered arrays as a tuple (obs, pred).
    """
    obs_ok = ~np.isnan(obs)
    pred_ok = ~np.isnan(pred)
    valid = np.logical_and(obs_ok, pred_ok)
    return obs[valid], pred[valid]

def mean_bias_error(obs, pred):
    """Mean of (pred - obs); positive values mean over-prediction."""
    signed_error = pred - obs
    return np.mean(signed_error)

def root_mean_square_error(obs, pred):
    """Square root of the mean squared difference between pred and obs."""
    squared_diff = (pred - obs)**2
    return np.sqrt(np.mean(squared_diff))

def std_of_residuals(obs, pred):
    """Sample standard deviation (ddof=1) of the residuals pred - obs."""
    resid = pred - obs
    return np.std(resid, ddof=1)

def pearson_correlation(obs, pred):
    """
    Pearson correlation coefficient between obs and pred.

    Returns NaN when the coefficient is undefined: fewer than two pairs, or
    either series is constant (zero variance). The constant case is checked
    explicitly so np.corrcoef does not emit a divide-by-zero RuntimeWarning.
    """
    obs = np.asarray(obs, dtype=float)
    pred = np.asarray(pred, dtype=float)
    if obs.size < 2:
        return np.nan
    if obs.max() == obs.min() or pred.max() == pred.min():
        return np.nan
    return np.corrcoef(obs, pred)[0, 1]

def index_of_agreement(obs, pred):
    """
    Index of agreement: 1 - sum((pred - obs)^2) / sum((|pred - m| + |obs - m|)^2),
    where m is the mean of obs. Returns NaN when the denominator is zero.
    """
    centre = np.mean(obs)
    potential_error = np.sum((abs(pred - centre) + abs(obs - centre))**2)
    if potential_error == 0:
        return np.nan
    squared_error = np.sum((pred - obs)**2)
    return 1 - squared_error / potential_error

# Pool every station-day, keeping only pairs where both values are present
obs_all, rdrs_all = remove_nan_pairs(
    results_df['obs'].values,
    results_df['rdrs_val'].values,
)

if len(obs_all) == 0:
    print("\nNo valid (obs, rdrs_val) pairs found for overall metrics.")
else:
    metrics_all = {
        'MBE': mean_bias_error(obs_all, rdrs_all),
        'RMSE': root_mean_square_error(obs_all, rdrs_all),
        'STD': std_of_residuals(obs_all, rdrs_all),
        'CC': pearson_correlation(obs_all, rdrs_all),
        'Index_of_Agreement': index_of_agreement(obs_all, rdrs_all),
    }
    print("\nOverall Metrics (All Stations, LWR, 1991-2012):")
    for k, v in metrics_all.items():
        print(f"{k}: {v:.4f}")

# Station-level metrics: one row per station, NaN metrics when a station
# has no day on which both the observation and the estimate are present
station_groups = results_df.groupby('station_index')
per_station_metrics = []
for st_idx, grp in station_groups:
    obs_st, rdrs_st = remove_nan_pairs(grp['obs'].values, grp['rdrs_val'].values)
    has_pairs = len(obs_st) > 0
    per_station_metrics.append({
        'station_index': st_idx,
        'station_name': grp['station_name'].iloc[0],
        'MBE':  mean_bias_error(obs_st, rdrs_st) if has_pairs else np.nan,
        'RMSE': root_mean_square_error(obs_st, rdrs_st) if has_pairs else np.nan,
        'STD':  std_of_residuals(obs_st, rdrs_st) if has_pairs else np.nan,
        'CC':   pearson_correlation(obs_st, rdrs_st) if has_pairs else np.nan,
        'Index_of_Agreement': index_of_agreement(obs_st, rdrs_st) if has_pairs else np.nan,
    })

metrics_df = pd.DataFrame(per_station_metrics)
print("\nSample of per-station metrics:")
print(metrics_df.head())

# Persist the per-station table in the metrics folder
metrics_csv = os.path.join(metrics_dir, "station_metrics_25km_LWR_1991_2012.csv")
metrics_df.to_csv(metrics_csv, index=False)
print(f"Station metrics saved to {metrics_csv}")

# (Optional) NetCDF
###############################################################################
# 11. (OPTIONAL) SAVE FULL DAILY RESULTS AS NETCDF
###############################################################################
# Best-effort export of the daily station table to NetCDF; any failure is
# reported and swallowed so the rest of the cell's outputs are unaffected.
try:
    print("\nSaving daily station results to NetCDF file in daily_loop folder...")
    netcdf_file = os.path.join(daily_loop_dir, "rdrs_vs_stations_25km_LWR_1991_2012.nc")

    # 1) Pivot to wide form: one row per 'time', one column per 'station_index'.
    #    pivot() raises if a (time, station_index) pair occurs more than once.
    pivot_rdrs = results_df.pivot(index='time', columns='station_index', values='rdrs_val')
    pivot_obs  = results_df.pivot(index='time', columns='station_index', values='obs')

    # 2) Ensure the pivoted index is recognized as real datetime
    #    (errors='coerce' turns unparseable entries into NaT instead of raising)
    pivot_rdrs.index = pd.to_datetime(pivot_rdrs.index, errors='coerce')
    pivot_obs.index  = pd.to_datetime(pivot_obs.index, errors='coerce')

    # 3) Convert each pivoted DataFrame into an xarray Dataset
    ds_rdrs = pivot_rdrs.to_xarray()   # NOTE(review): per pandas docs each column becomes its own data variable over 'time' - confirm
    ds_obs  = pivot_obs.to_xarray()

    #    Defensive renames in case the conversion yields generic dim names
    #    ('index' / 'columns'); they are no-ops when those names are absent.
    if 'index' in ds_rdrs.dims:
        ds_rdrs = ds_rdrs.rename({'index': 'time'})
    if 'columns' in ds_rdrs.dims:
        ds_rdrs = ds_rdrs.rename({'columns': 'station_index'})
    if 'index' in ds_obs.dims:
        ds_obs = ds_obs.rename({'index': 'time'})
    if 'columns' in ds_obs.dims:
        ds_obs = ds_obs.rename({'columns': 'station_index'})

    # 4) Turn each into a single DataArray:
    #    NOTE(review): per xarray docs, Dataset.to_array() stacks the data
    #    variables along a new dimension named 'variable' by default, so the
    #    station axis is probably written as 'variable' (not 'station_index')
    #    and squeeze() only has an effect when there is a single station.
    #    Inspect the written file before relying on the dimension names.
    da_rdrs = ds_rdrs.to_array(name='rdrs_val').squeeze()  # squeeze drops length-1 dims only
    da_obs  = ds_obs.to_array(name='obs').squeeze()

    # 5) Combine into a single Dataset with two DataArrays
    ds_out = xr.Dataset({
        'rdrs_val': da_rdrs,
        'obs':      da_obs
    })

    ds_out.to_netcdf(netcdf_file)
    print(f"NetCDF saved to: {netcdf_file}")

except Exception as e:
    print("Error saving NetCDF:", e)




In [None]:
# CHIRPS v2.0
# Interpolate gridded CHIRPS precipitation to each station location using
# locally weighted regression (LWR) over the grid cells within a 25 km radius

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
from shapely.geometry import Point
import warnings
import time
import os

###############################################################################
# 1. CONFIGURATIONS
###############################################################################
# Input files
station_file   = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\prcp_data.nc'  # station observations (opened with xarray)
chirps_file    = r'D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\CHIRPS_GLB_Precipitation\Masked to GLB\CHIRPS_GLB_1991_2013_with_Elevation.nc'  # gridded CHIRPS (opened with xarray)
physical_file  = r'D:\PhD\GLB\Merged USA and CA\Entire GLB\filtered_stations_with_elevation.csv'  # station metadata table (read with pandas)
shapefile_path = r"D:\PhD\GLB\greatlakes_subbasins\New folder\Great_Lakes.shp"  # basin outline, used only by the optional map
lakes_shp      = r'D:\PhD\GLB\greatlakes_subbasins\GLB_Water_Bodies\Main_Lakes_GLB.shp'  # lake polygons, used only by the optional map
target_crs     = "ESRI:102008"  # CRS every map layer is reprojected to

# Paths for daily loop & metrics (using CHIRPS directory)
daily_loop_dir    = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\CHIRPS_GLB_Precipitation\Masked to GLB\daily_loop"
metrics_dir       = r"D:\PhD\GLB\EMDNA(Historical data)\Ensembles\New folder\Ensemble files\CHIRPS_GLB_Precipitation\Masked to GLB\metrics"

# Create the output folders up front; exist_ok avoids failing on re-runs
os.makedirs(daily_loop_dir, exist_ok=True)
os.makedirs(metrics_dir,    exist_ok=True)

# Variable names (both observed and gridded CHIRPS use 'prcp')
obs_var_name   = 'prcp'
chirps_var     = 'prcp'

# Local Weighted Regression settings
search_radius_km = 25   # Reduced from 50 km to 25 km; compared against haversine distances in km
min_points = 5          # fewer grid cells than this within the radius => no regression (NaN)

# Time range: 1991–2012 (adjust if needed); both datasets are sliced to it
start_date = '1991-01-01'
end_date   = '2012-12-31'

###############################################################################
# 2. DISTANCE & WEIGHT FUNCTIONS
###############################################################################
def haversine_distance(lat1, lon1, lat2, lon2, radius_km=6371):
    """
    Great-circle distance between points given in decimal degrees.

    Inputs may be scalars or numpy arrays and follow numpy broadcasting, so
    one station can be compared against a whole vector of grid cells.

    Parameters
    ----------
    lat1, lon1, lat2, lon2 : float or array-like
        Latitudes / longitudes in degrees.
    radius_km : float, optional
        Sphere radius in kilometres (default 6371, the value used so far).

    Returns
    -------
    float or np.ndarray
        Distance(s) in the same unit as radius_km.
    """
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = (np.sin(dlat / 2)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(dlon / 2)**2)
    # Rounding can push `a` marginally outside [0, 1] (near-antipodal or
    # coincident points), which would make sqrt/arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    return radius_km * c

def tricube_weight(distances, d_max):
    """
    Tricube kernel weights.

    w_j = [1 - (dist_j / d_max)^3]^3  for dist_j <= d_max; else 0

    Parameters
    ----------
    distances : float or array-like
        Distance(s) from the target point; scalars and lists are accepted.
    d_max : float
        Kernel bandwidth. When it is 0 every point gets weight 1.

    Returns
    -------
    np.ndarray
        Float weights with the same shape as `distances` (0-d for a scalar).
    """
    distances = np.asarray(distances, dtype=float)
    if d_max == 0:
        return np.ones_like(distances)
    ratio = distances / d_max
    # np.where instead of in-place masking so scalar input works as well
    return np.where(distances > d_max, 0.0, (1 - ratio**3)**3)

###############################################################################
# 3. FORCE CHIRPS TIME (IF NEEDED)
###############################################################################
def force_chirps_time(ds):
    """
    Return `ds` with a datetime 'time' coordinate.

    - If 'time' is a dimension without a coordinate, a coordinate is attached.
    - Integer time values are interpreted as day offsets from 1991-01-01.
    - Whatever is stored afterwards is passed through pd.to_datetime.
    """
    has_time_dim = 'time' in ds.dims
    if has_time_dim and 'time' not in ds.coords:
        ds = ds.assign_coords(time=ds['time'])

    time_is_integer = np.issubdtype(ds['time'].dtype, np.integer)
    if time_is_integer:
        origin = pd.to_datetime("1991-01-01")
        stamps = []
        for offset in ds['time'].values:
            stamps.append(origin + pd.Timedelta(days=int(offset)))
        ds = ds.assign_coords(time=("time", stamps))

    ds['time'] = pd.to_datetime(ds['time'].values)
    return ds

print("Loading station observations (NetCDF) ...")
obs_ds = xr.open_dataset(station_file)

print("Loading CHIRPS gridded dataset (NetCDF) ...")
chirps_ds = xr.open_dataset(chirps_file)

# Fix time coordinates if needed
if 'time' in obs_ds.coords:
    obs_ds['time'] = pd.to_datetime(obs_ds['time'].values)
chirps_ds = force_chirps_time(chirps_ds)

# Subset to 1991–2012
obs_ds   = obs_ds.sel(time=slice(start_date, end_date))
chirps_ds  = chirps_ds.sel(time=slice(start_date, end_date))

print("After subsetting:")
print(f"obs_ds time steps = {obs_ds.sizes['time']}")
print(f"chirps_ds time steps = {chirps_ds.sizes['time']}")

# Bring the precipitation variable to (time, lat, lon) whatever order it was
# stored in, because the daily slices are flattened and indexed with the
# same flat indices as the lat/lon mesh built below.
print("Checking dimension order of the CHIRPS variable ...")
actual_dims = chirps_ds[chirps_var].dims
print("Current dims for prcp:", actual_dims)
expected_dims = ("time", "lat", "lon")
if actual_dims != expected_dims and set(actual_dims) == set(expected_dims):
    print(f"Transposing prcp from ({','.join(actual_dims)}) -> (time,lat,lon).")
    chirps_ds[chirps_var] = chirps_ds[chirps_var].transpose("time", "lat", "lon")
    print("New dims:", chirps_ds[chirps_var].dims)
else:
    print("No transpose needed or dims are already (time,lat,lon).")

lats = chirps_ds['lat'].values
lons = chirps_ds['lon'].values
lon2d, lat2d = np.meshgrid(lons, lats)   # both shaped (lat, lon)

if 'elevation' in chirps_ds:
    elev_da = chirps_ds['elevation']
    # Elevation is static; if it carries a time axis keep the first step only
    if 'time' in elev_da.dims:
        elev_da = elev_da.isel(time=0)
    # Align with the (lat, lon) mesh so the flattened arrays share indices;
    # a (lon, lat) elevation would otherwise be silently scrambled.
    if set(elev_da.dims) == {"lat", "lon"}:
        elev_da = elev_da.transpose("lat", "lon")
    grid_elev = elev_da.values
    if grid_elev.shape != lat2d.shape:
        warnings.warn(
            f"'elevation' has shape {grid_elev.shape}, expected {lat2d.shape}; "
            "elevation may be misaligned with the lat/lon grid."
        )
else:
    warnings.warn("No 'elevation' found in CHIRPS dataset => setting elev=0.")
    grid_elev = np.zeros_like(lat2d)

# Flat views used for neighbour lookup and regression
grid_lat_flat  = lat2d.flatten()
grid_lon_flat  = lon2d.flatten()
grid_elev_flat = grid_elev.flatten()

###############################################################################
# 4. LOAD STATIONS & MATCH
###############################################################################
print("Loading station metadata (CSV) ...")
# Columns that are empty for every station are discarded on load
stations_df = pd.read_csv(physical_file).dropna(axis=1, how='all')
print(stations_df.head())

print("Matching station names to NetCDF index ...")
station_names_nc = obs_ds['station'].values
# Name -> position along the NetCDF 'station' axis
station_index_map = dict(zip(station_names_nc, range(len(station_names_nc))))
# Stations whose NAME is not in the NetCDF file are dropped, then the
# frame is re-indexed 0..n-1
stations_df = (
    stations_df
    .assign(netcdf_index=stations_df['NAME'].map(station_index_map))
    .dropna(subset=['netcdf_index'])
    .reset_index(drop=True)
)
stations_df['netcdf_index'] = stations_df['netcdf_index'].astype(int)
print("Total matched stations:", len(stations_df))

###############################################################################
# 5. PRECOMPUTE NEIGHBORS FOR EACH STATION
###############################################################################
# station row label -> flat grid indices within the search radius,
# or None when fewer than min_points cells qualify
station_neighbors = {}
print(f"Precomputing neighbors within {search_radius_km} km...")

for i, row in stations_df.iterrows():
    within_radius = np.flatnonzero(
        haversine_distance(row['LATITUDE'], row['LONGITUDE'],
                           grid_lat_flat, grid_lon_flat) <= search_radius_km
    )
    station_neighbors[i] = within_radius if len(within_radius) >= min_points else None

print("Neighbor precomputation complete.")

###############################################################################
# 6. LOCAL WEIGHTED REGRESSION FUNCTION
###############################################################################
def local_weighted_regression(
    station_lat, station_lon, station_elev,
    neighbor_indices,
    grid_lat, grid_lon, grid_elev, grid_val
):
    """
    Weighted linear regression of grid_val ~ [1, lat, lon, elev],
    with tricube weighting based on station->grid distance.
    Negative precipitation predictions are clipped to zero.

    Parameters
    ----------
    station_lat, station_lon, station_elev : float
        Location and elevation at which the fitted plane is evaluated.
    neighbor_indices : array of int or None
        Flat indices into the grid_* arrays; None means "no neighbours".
    grid_lat, grid_lon, grid_elev, grid_val : 1-D np.ndarray
        Flattened grid coordinates, elevations and the day's values.

    Returns
    -------
    float
        Predicted value, or NaN when fewer than `min_points` usable
        neighbours exist, all weights are zero, or the solve fails.

    Notes
    -----
    Neighbours whose value or elevation is not finite (e.g. masked grid
    cells) are dropped before fitting. Without this, a single NaN cell
    makes the whole least-squares solution NaN even though enough valid
    cells surround the station.
    """
    if neighbor_indices is None or len(neighbor_indices) < min_points:
        return np.nan

    lat_n  = grid_lat[neighbor_indices]
    lon_n  = grid_lon[neighbor_indices]
    elev_n = grid_elev[neighbor_indices]
    val_n  = grid_val[neighbor_indices]

    # Keep only neighbours that can take part in the fit
    usable = np.isfinite(val_n) & np.isfinite(elev_n)
    if usable.sum() < min_points:
        return np.nan
    lat_n, lon_n, elev_n, val_n = lat_n[usable], lon_n[usable], elev_n[usable], val_n[usable]

    # Compute distances and weights; the bandwidth is the farthest usable
    # neighbour, so that neighbour itself receives weight 0
    dist_n = haversine_distance(station_lat, station_lon, lat_n, lon_n)
    d_max  = dist_n.max()
    w = tricube_weight(dist_n, d_max)
    if np.all(w == 0):
        return np.nan

    # Build design matrix X: [1, lat, lon, elev]
    X = np.column_stack([np.ones_like(lat_n), lat_n, lon_n, elev_n])
    # Scaling rows by sqrt(w) turns ordinary least squares into weighted
    sqrt_w = np.sqrt(w)
    X_w = X * sqrt_w[:, None]
    y_w = val_n * sqrt_w

    # Solve weighted least squares
    try:
        beta = np.linalg.lstsq(X_w, y_w, rcond=None)[0]
    except np.linalg.LinAlgError:
        return np.nan

    # Predict at station location
    X_station = np.array([1, station_lat, station_lon, station_elev])
    pred = X_station @ beta

    # Clip negative precipitation to zero
    if pred < 0:
        pred = 0.0

    return pred

###############################################################################
# 7. DAILY LOOP: CHIRPS -> STATIONS
###############################################################################
results = []
chirps_times = chirps_ds['time'].values
print(f"Processing CHIRPS days: {len(chirps_times)}")

# Look observations up by calendar date instead of reusing the CHIRPS
# positional index: if the two time axes differ (gaps, different start),
# positional indexing would pair a grid day with the wrong observation day
# or run past the end of the station record.
obs_dates = pd.DatetimeIndex(obs_ds['time'].values).normalize()
obs_pos_by_date = {}
for obs_pos, obs_date in enumerate(obs_dates):
    obs_pos_by_date.setdefault(obs_date, obs_pos)   # first occurrence wins

# Station attributes do not change from day to day, so read them once
# rather than re-iterating the DataFrame for every time step.
# No fallback is applied to 'Elevation'; the value is passed through as stored.
station_records = [
    (i, row['NAME'], row['LATITUDE'], row['LONGITUDE'], row['Elevation'],
     int(row['netcdf_index']))
    for i, row in stations_df.iterrows()
]

for t_index, t_val in enumerate(chirps_times):
    current_time = pd.to_datetime(t_val)
    day_data_2d = chirps_ds[chirps_var].isel(time=t_index).values  # shape => (lat, lon)
    grid_val_flat = day_data_2d.flatten()

    # Observations for the same calendar day (None when the day is absent)
    obs_pos = obs_pos_by_date.get(current_time.normalize())
    obs_day = None if obs_pos is None else obs_ds[obs_var_name].isel(time=obs_pos).values

    for i, st_name, st_lat, st_lon, st_elev, netcdf_idx in station_records:
        if obs_day is not None and netcdf_idx < len(obs_day):
            st_obs = obs_day[netcdf_idx]
        else:
            st_obs = np.nan

        neigh_idx = station_neighbors[i]
        st_chirps = local_weighted_regression(
            station_lat=st_lat,
            station_lon=st_lon,
            station_elev=st_elev,
            neighbor_indices=neigh_idx,
            grid_lat=grid_lat_flat,
            grid_lon=grid_lon_flat,
            grid_elev=grid_elev_flat,
            grid_val=grid_val_flat
        )

        results.append({
            'time': current_time,
            'station_index': i,
            'station_name': st_name,
            'obs': st_obs,
            'chirps_val': st_chirps
        })

    if (t_index + 1) % 100 == 0:
        print(f"Processed day {t_index + 1} / {len(chirps_times)}")

###############################################################################
# 8. BUILD A DATAFRAME & SAVE
###############################################################################
# One row per (day, station); 'time' is forced to datetime64
results_df = pd.DataFrame(results).assign(
    time=lambda frame: pd.to_datetime(frame['time'])
)
print("Sample of daily results:\n", results_df.head())

# Save the daily loop CSV (the filename reflects the 25 km search radius)
daily_loop_csv = os.path.join(
    daily_loop_dir,
    "chirps_vs_stations_25km_LWR_1991_2012.csv",
)
results_df.to_csv(path_or_buf=daily_loop_csv, index=False)
print(f"Daily interpolation results saved to {daily_loop_csv}")

###############################################################################
# 9. (OPTIONAL) MAPPING
###############################################################################
try:
    print("Creating a quick map of average CHIRPS precipitation at each station ...")
    glb  = gpd.read_file(shapefile_path).to_crs(target_crs)
    lakes = gpd.read_file(lakes_shp).to_crs(target_crs)

    # Long-term mean of the interpolated value for every station
    station_mean = results_df.groupby('station_index')['chirps_val'].mean().reset_index()

    # 'station_index' stores the stations_df row label, so attach the means
    # by that label rather than by row position: a station missing from
    # results_df then gets NaN instead of shifting every later station
    # onto the wrong row.
    mean_by_station = station_mean.set_index('station_index')['chirps_val']
    merged = stations_df.copy()
    merged['mean_chirps'] = merged.index.map(mean_by_station)

    # Station points are given as lon/lat (EPSG:4326), then reprojected
    geometry = [Point(lon, lat) for lon, lat in zip(merged['LONGITUDE'], merged['LATITUDE'])]
    stations_gdf = gpd.GeoDataFrame(merged, geometry=geometry, crs="EPSG:4326")
    stations_gdf = stations_gdf.to_crs(target_crs)

    fig, ax = plt.subplots(figsize=(10, 8))
    glb.boundary.plot(ax=ax, edgecolor='black')
    lakes.plot(ax=ax, color='blue', alpha=0.5)
    stations_gdf.plot(column='mean_chirps', ax=ax, legend=True,
                      cmap='viridis', markersize=50)
    ax.set_title("Mean CHIRPS Precip (LWR, 25km radius), 1991-2012")
    plt.show()
except Exception as e:
    # Mapping is optional: report the problem and let the metrics steps run
    print("Mapping step skipped due to error or missing shapefiles:", str(e))

###############################################################################
# 10. METRICS & SAVING
###############################################################################
def remove_nan_pairs(obs, pred):
    """
    Drop every (obs, pred) pair in which either member is NaN.

    Parameters
    ----------
    obs, pred : array-like
        Equal-shaped sequences of observed and predicted values. Lists and
        object-dtype arrays are accepted; they are converted to float arrays.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The surviving observed and predicted values, in the original order.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    obs = np.asarray(obs, dtype=float)
    pred = np.asarray(pred, dtype=float)
    if obs.shape != pred.shape:
        raise ValueError(
            f"obs and pred must have the same shape, got {obs.shape} and {pred.shape}"
        )
    keep = ~(np.isnan(obs) | np.isnan(pred))
    return obs[keep], pred[keep]

def mean_bias_error(obs, pred):
    """Mean signed error: the average of (pred - obs)."""
    signed_error = pred - obs
    return np.mean(signed_error)

def root_mean_square_error(obs, pred):
    """Square root of the mean squared difference between pred and obs."""
    squared_error = (pred - obs)**2
    mean_squared_error = np.mean(squared_error)
    return np.sqrt(mean_squared_error)

def std_of_residuals(obs, pred):
    """
    Sample standard deviation (ddof=1) of the residuals pred - obs.

    Returns NaN when fewer than two pairs are supplied, because the sample
    standard deviation is undefined there; the explicit guard avoids the
    "degrees of freedom <= 0" RuntimeWarning numpy would otherwise emit.
    """
    residuals = np.asarray(pred, dtype=float) - np.asarray(obs, dtype=float)
    if residuals.size < 2:
        return np.nan
    return np.std(residuals, ddof=1)

def pearson_correlation(obs, pred):
    """
    Pearson correlation coefficient between obs and pred.

    Returns NaN when the coefficient is undefined: fewer than two pairs, or
    either series is constant (zero variance). The constant case is checked
    explicitly so np.corrcoef does not emit a divide-by-zero RuntimeWarning.
    """
    obs = np.asarray(obs, dtype=float)
    pred = np.asarray(pred, dtype=float)
    if obs.size < 2:
        return np.nan
    if obs.max() == obs.min() or pred.max() == pred.min():
        return np.nan
    return np.corrcoef(obs, pred)[0, 1]

def index_of_agreement(obs, pred):
    """
    Index of agreement: 1 - sum((pred - obs)^2) / sum((|pred - m| + |obs - m|)^2),
    where m is the mean of obs. Returns NaN when the denominator is zero.
    """
    centre = np.mean(obs)
    spread_pred = np.abs(pred - centre)
    spread_obs = np.abs(obs - centre)
    potential_error = np.sum((spread_pred + spread_obs)**2)
    if potential_error == 0:
        return np.nan
    squared_error = np.sum((pred - obs)**2)
    return 1 - squared_error / potential_error

# Pool every station-day, keeping only pairs where both values are present
obs_all, chirps_all = remove_nan_pairs(
    results_df['obs'].values,
    results_df['chirps_val'].values,
)

if len(obs_all) == 0:
    print("\nNo valid (obs, chirps_val) pairs found for overall metrics.")
else:
    metrics_all = {
        'MBE': mean_bias_error(obs_all, chirps_all),
        'RMSE': root_mean_square_error(obs_all, chirps_all),
        'STD': std_of_residuals(obs_all, chirps_all),
        'CC': pearson_correlation(obs_all, chirps_all),
        'Index_of_Agreement': index_of_agreement(obs_all, chirps_all),
    }
    print("\nOverall Metrics (All Stations, LWR, 1991-2012):")
    for k, v in metrics_all.items():
        print(f"{k}: {v:.4f}")

# Station-level metrics: one row per station, NaN metrics when a station
# has no day on which both the observation and the estimate are present
station_groups = results_df.groupby('station_index')
per_station_metrics = []
for st_idx, grp in station_groups:
    obs_st, chirps_st = remove_nan_pairs(grp['obs'].values, grp['chirps_val'].values)
    has_pairs = len(obs_st) > 0
    per_station_metrics.append({
        'station_index': st_idx,
        'station_name': grp['station_name'].iloc[0],
        'MBE':  mean_bias_error(obs_st, chirps_st) if has_pairs else np.nan,
        'RMSE': root_mean_square_error(obs_st, chirps_st) if has_pairs else np.nan,
        'STD':  std_of_residuals(obs_st, chirps_st) if has_pairs else np.nan,
        'CC':   pearson_correlation(obs_st, chirps_st) if has_pairs else np.nan,
        'Index_of_Agreement': index_of_agreement(obs_st, chirps_st) if has_pairs else np.nan,
    })

metrics_df = pd.DataFrame(per_station_metrics)
print("\nSample of per-station metrics:")
print(metrics_df.head())

# Persist the per-station table in the metrics folder
metrics_csv = os.path.join(metrics_dir, "station_metrics_25km_LWR_1991_2012.csv")
metrics_df.to_csv(metrics_csv, index=False)
print(f"Station metrics saved to {metrics_csv}")

###############################################################################
# 11. (OPTIONAL) SAVE FULL DAILY RESULTS AS NETCDF
###############################################################################
# Best-effort export of the daily station table to NetCDF; any failure is
# reported and swallowed so the rest of the cell's outputs are unaffected.
try:
    print("\nSaving daily station results to NetCDF file in daily_loop folder...")
    netcdf_file = os.path.join(daily_loop_dir, "chirps_vs_stations_25km_LWR_1991_2012.nc")

    # Pivot so that 'time' is rows and 'station_index' are columns.
    # pivot() raises if a (time, station_index) pair occurs more than once.
    pivot_chirps = results_df.pivot(index='time', columns='station_index', values='chirps_val')
    pivot_obs    = results_df.pivot(index='time', columns='station_index', values='obs')

    # Ensure pivoted index is in datetime format
    # (errors='coerce' turns unparseable entries into NaT instead of raising)
    pivot_chirps.index = pd.to_datetime(pivot_chirps.index, errors='coerce')
    pivot_obs.index    = pd.to_datetime(pivot_obs.index, errors='coerce')

    # Convert each pivoted DataFrame into an xarray Dataset
    ds_chirps = pivot_chirps.to_xarray()   # NOTE(review): per pandas docs each column becomes its own data variable over 'time' - confirm
    ds_obs    = pivot_obs.to_xarray()

    # Defensive renames in case the conversion yields generic dim names
    # ('index' / 'columns'); they are no-ops when those names are absent.
    if 'index' in ds_chirps.dims:
        ds_chirps = ds_chirps.rename({'index': 'time'})
    if 'columns' in ds_chirps.dims:
        ds_chirps = ds_chirps.rename({'columns': 'station_index'})
    if 'index' in ds_obs.dims:
        ds_obs = ds_obs.rename({'index': 'time'})
    if 'columns' in ds_obs.dims:
        ds_obs = ds_obs.rename({'columns': 'station_index'})

    # Convert each into a DataArray and combine into one Dataset.
    # NOTE(review): per xarray docs, Dataset.to_array() stacks the data
    # variables along a new dimension named 'variable' by default, so the
    # station axis is probably written as 'variable' (not 'station_index')
    # and squeeze() only has an effect when there is a single station.
    # Inspect the written file before relying on the dimension names.
    da_chirps = ds_chirps.to_array(name='chirps_val').squeeze()
    da_obs    = ds_obs.to_array(name='obs').squeeze()

    ds_out = xr.Dataset({
        'chirps_val': da_chirps,
        'obs':        da_obs
    })

    ds_out.to_netcdf(netcdf_file)
    print(f"NetCDF saved to: {netcdf_file}")

except Exception as e:
    print("Error saving NetCDF:", e)
