
# ETCCDI Multi-Year Mean + Time-Series Stats + Linear Trend

**How to use:** Run Step 1 to list the available files, edit only the **Parameters** cell (Step 2), then run all cells (Kernel → Restart & Run All).
This notebook:
1) Loads an ETCCDI NetCDF
2) Computes the **multi-year mean** and plots a **map** + **spatial histogram** + **spatial stats**
3) Computes **annual regional values** (area-weighted mean) and plots a **time series** + **histogram across years**
4) Fits a **linear trend** to the annual values, plots it, and reports the slope (per year and per decade) and R²


In [None]:
# Step 1 — List available NetCDF files
import os
import glob

# Directory containing your ETCCDI NetCDF outputs
# (or point it at user-calculated indices, e.g. "./etccdi_out/")
DATA_DIR = "/home/jovyan/shared/NEX-GDDP-CMIP6/ETCCDI/"

# Bare file names (directory stripped) of every *.nc file in DATA_DIR, sorted alphabetically.
NC_FILES = sorted(map(os.path.basename, glob.glob(os.path.join(DATA_DIR, "*.nc"))))

print(f"Found {len(NC_FILES)} .nc files in {DATA_DIR}")
for idx, fname in enumerate(NC_FILES, start=1):
    print(f"{idx:2d}. {fname}")

In [None]:
# Step 2 — Parameters (EDIT the values below, then run the remaining cells)
#
# Required — pick one file from the list printed in Step 1:
FILE_NAME = "ETCCDI_pr_GISS-E2-1-G_historical_annual_1985-2014_Rx5day.nc"
#

# os.path.join works whether or not DATA_DIR ends with a path separator
# (plain string concatenation silently produced a wrong path when it didn't).
FILE_PATH = os.path.join(DATA_DIR, FILE_NAME)

REGION_METHOD = "country"           # "country" (uses COUNTRY_NAME + SHAPEFILE_PATH) or "bbox" (uses BBOX)
COUNTRY_NAME = "Thailand"

BBOX = (100.0, 105.0, 10.0, 15.0)   # (minimum longitude, maximum longitude, minimum latitude, maximum latitude)
SHAPEFILE_PATH = "./shapefile"      # folder containing a Natural Earth *admin_0_countries*.shp, or a .shp path
BORDERS = True
CMAP = "viridis"

# --- Outputs ---
SAVE_OUTPUTS = True                # True to write figures/tables
OUT_DIR = "./outputs"               # where to save outputs


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

# === Helpers (imports and functions) =========================================
import os, glob, math, re
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# Optional vectorized shapely helper (non-fatal if missing)
try:
    from shapely import vectorized as shp_vec
except Exception:
    shp_vec = None

# Geopandas and shapely ops for boundaries/masks
try:
    import geopandas as gpd
    from shapely.ops import unary_union
except Exception as e:
    gpd = None
    unary_union = None

def _standardize_latlon(da: xr.DataArray) -> xr.DataArray:
    """Return *da* with its latitude/longitude coordinates named ``lat``/``lon``.

    Aliases are considered in this order: ``latitude``, ``Latitude``, ``y``
    (-> ``lat``) and ``longitude``, ``Longitude``, ``x`` (-> ``lon``). An alias
    is renamed only when it is a coordinate of *da* and the standard name is
    not already a coordinate.

    Raises:
        ValueError: if ``lat`` or ``lon`` is still not a coordinate afterwards.
    """
    aliases = [
        ("latitude", "lat"), ("Latitude", "lat"), ("y", "lat"),
        ("longitude", "lon"), ("Longitude", "lon"), ("x", "lon"),
    ]
    # Membership is tested against the ORIGINAL coords (rename happens after).
    renames = {
        alias: std
        for alias, std in aliases
        if alias in da.coords and std not in da.coords
    }
    if renames:
        da = da.rename(renames)
    missing = [name for name in ("lat", "lon") if name not in da.coords]
    if missing:
        raise ValueError("Data must have 'lat' and 'lon' coordinates.")
    return da

def ensure_lon_neg180_180(da: xr.DataArray) -> xr.DataArray:
    """Wrap a 0..360 ``lon`` dimension into -180..180 and re-sort along it.

    Nothing happens (``da`` is returned as-is) when ``lon`` is not a dimension
    or when its largest value does not exceed 180.
    """
    if "lon" in da.dims and np.nanmax(da["lon"].values) > 180.0:
        wrapped = (da["lon"].values + 180.0) % 360.0 - 180.0
        return da.assign_coords(lon=wrapped).sortby("lon")
    return da

def load_dataarray(nc_path: str) -> xr.DataArray:
    """Open *nc_path* and return its gridded data variable with ``lat``/``lon`` coords.

    The first data variable whose dimensions include a latitude dimension
    (``lat``/``latitude``/``Latitude``/``y``) AND a longitude dimension
    (``lon``/``longitude``/``Longitude``/``x``) is selected. Recognising the
    aliases here matters: previously only ``lat``/``lon`` matched, so files
    using ``latitude``/``longitude`` fell back to the first variable (which
    may be an auxiliary one such as a bounds variable) and then failed in
    ``_standardize_latlon``. If nothing qualifies, the first data variable is
    used, as before.

    Raises:
        FileNotFoundError: if *nc_path* does not exist.
        ValueError: if the file has no data variables, or the chosen variable
            has no recognisable lat/lon coordinates.
    """
    if not os.path.exists(nc_path):
        raise FileNotFoundError(f"NetCDF not found: {nc_path}")
    ds = xr.open_dataset(nc_path)
    data_vars = list(ds.data_vars)
    if not data_vars:
        raise ValueError(f"No data variables found in {nc_path}")

    lat_dims = {"lat", "latitude", "Latitude", "y"}
    lon_dims = {"lon", "longitude", "Longitude", "x"}
    var = next(
        (
            v for v in data_vars
            if lat_dims.intersection(ds[v].dims) and lon_dims.intersection(ds[v].dims)
        ),
        data_vars[0],
    )
    da = ds[var]
    da.name = var
    return _standardize_latlon(da)

def _resolve_shp(shapefile_path: str) -> str:
    if not shapefile_path:
        return None
    if shapefile_path.lower().endswith(".shp") and os.path.exists(shapefile_path):
        return shapefile_path
    if os.path.isdir(shapefile_path):
        cands = sorted(glob.glob(os.path.join(shapefile_path, "*admin_0_countries*.shp")))
        return cands[0] if cands else None
    return shapefile_path if os.path.exists(shapefile_path) else None

def load_country_polygon(country: str, shapefile_path: str):
    """Return the union of the shapefile polygon(s) matching *country*, or None.

    Reads a Natural Earth admin_0 countries layer (a folder is resolved via
    ``_resolve_shp``; a ``.shp`` path is used directly), reprojects it to
    EPSG:4326 when it carries another CRS, and matches *country*
    case-insensitively:

      1. exact match in ``NAME``, then in ``ADMIN`` (whichever exist);
      2. otherwise a literal substring match in ``NAME``, then ``ADMIN``.

    The substring match is literal (``regex=False``): names such as
    "Korea (Rep.)" contain regex metacharacters that previously broke or
    distorted the match.

    Returns None (best effort, never raises for data problems) when
    *country* is empty, geopandas/shapely are unavailable, the shapefile
    cannot be found or read, no name column exists, or nothing matches.
    """
    if not country:
        return None
    if gpd is None or unary_union is None:
        return None

    shp = _resolve_shp(shapefile_path)
    if not shp:
        return None

    try:
        gdf = gpd.read_file(shp)
    except Exception:
        try:
            gdf = gpd.read_file(shp, engine="pyogrio")
        except Exception:
            return None

    # Ensure WGS84 so polygon coordinates are comparable with lon/lat grids.
    try:
        if gdf.crs is not None and gdf.crs.to_epsg() != 4326:
            gdf = gdf.to_crs(4326)
    except Exception:
        pass

    name_cols = [c for c in ("NAME", "ADMIN") if c in gdf.columns]
    if not name_cols:
        return None

    target = str(country).lower().strip()
    normalized = {c: gdf[c].astype(str).str.lower().str.strip() for c in name_cols}

    # Exact matches in every available column take precedence over substrings.
    for col in name_cols:
        sel = gdf[normalized[col] == target]
        if len(sel):
            return unary_union(sel.geometry.values)
    for col in name_cols:
        sel = gdf[normalized[col].str.contains(target, regex=False, na=False)]
        if len(sel):
            return unary_union(sel.geometry.values)
    return None

def mask_da_by_polygon(da: xr.DataArray, polygon):
    """Return *da* with cells whose (lon, lat) centre is not inside *polygon* set to NaN.

    The boolean mask is built on ``np.meshgrid(lon, lat)`` (shape
    ``(nlat, nlon)``) and applied with ``da.where``, so *da* is assumed to be
    laid out with (lat, lon) as its trailing dims — TODO confirm for unusual
    files. "Inside" uses shapely's *contains* predicate, so points exactly on
    the boundary are treated as outside.

    Strategies, tried in order:
      1. ``shapely.contains_xy`` (shapely >= 2.0, vectorised);
      2. legacy ``shapely.vectorized.contains`` (deprecated in recent shapely);
      3. per-point test with a prepared geometry, restricted to the polygon's
         bounding box.
    If all of them fail, a warning is issued and *da* is returned unmasked
    (this used to happen silently).

    Args:
        da: gridded field with 1-D ``lon`` and ``lat`` coordinates.
        polygon: shapely (Multi)Polygon in lon/lat degrees, or None.

    Returns:
        The masked DataArray, or *da* unchanged when *polygon* is None.
    """
    if polygon is None:
        return da
    lon2d, lat2d = np.meshgrid(da["lon"].values, da["lat"].values)

    # 1) shapely 2.x vectorised predicate; buffer(0) repairs invalid geometries.
    try:
        import shapely
        contains_xy = getattr(shapely, "contains_xy", None)
        if contains_xy is not None:
            return da.where(contains_xy(polygon.buffer(0), lon2d, lat2d))
    except Exception:
        pass

    # 2) Legacy shapely.vectorized helper (module-level optional import).
    try:
        if shp_vec is not None and hasattr(shp_vec, "contains"):
            return da.where(shp_vec.contains(polygon.buffer(0), lon2d, lat2d))
    except Exception:
        pass

    # 3) Slow path: only test grid points inside the polygon's bounding box.
    try:
        from shapely import prepared as shapely_prepared
        from shapely.geometry import Point
        minx, miny, maxx, maxy = polygon.bounds
        bbox_mask = (lon2d >= minx) & (lon2d <= maxx) & (lat2d >= miny) & (lat2d <= maxy)
        mask = np.zeros_like(lon2d, dtype=bool)
        if bbox_mask.any():
            prep = shapely_prepared.prep(polygon)
            for (i, j) in np.argwhere(bbox_mask):
                if prep.contains(Point(float(lon2d[i, j]), float(lat2d[i, j]))):
                    mask[i, j] = True
        return da.where(mask)
    except Exception as err:
        import warnings
        warnings.warn(f"Polygon masking failed ({err!r}); returning unmasked data.")
        return da

def mask_da_by_bbox(da: xr.DataArray, bbox):
    """Return *da* with grid cells outside *bbox* set to NaN (bounds inclusive).

    Args:
        da: gridded field with 1-D ``lon`` and ``lat`` coordinates; the mask is
            built as (nlat, nlon), so (lat, lon) are assumed to be the
            trailing dims.
        bbox: ``(min_lon, max_lon, min_lat, max_lat)`` in degrees — the order
            documented for the notebook's ``BBOX`` parameter. (The previous
            unpacking read it as (min_lon, min_lat, max_lon, max_lat), which
            turned e.g. (100, 105, 10, 15) into an empty region.)

    Returns:
        The masked DataArray.
    """
    minlon, maxlon, minlat, maxlat = bbox
    lon2d, lat2d = np.meshgrid(da["lon"].values, da["lat"].values)
    inside = (lon2d >= minlon) & (lon2d <= maxlon) & (lat2d >= minlat) & (lat2d <= maxlat)
    return da.where(inside)

def plot_map(da_masked: xr.DataArray, title: str, xlim=None, ylim=None, boundary_gdf=None, cmap=None, lw_boundary=0.8, boundary_color='k', zorder_boundary=10):
    """Show *da_masked* as a plain lon/lat pcolormesh (no map projection).

    Args:
        da_masked: field with 1-D ``lon``/``lat`` coordinates; its values are
            drawn on ``np.meshgrid(lon, lat)``, so they are assumed to be
            shaped (lat, lon).
        title: figure title.
        xlim, ylim: optional axis limits.
        boundary_gdf: optional GeoDataFrame overlay. Polygon rows are drawn as
            outlines; other rows (e.g. linework from ``to_linework_gdf``) are
            drawn as-is — taking ``.boundary`` of a line would only yield its
            endpoints.
        cmap: colormap name; "viridis" when falsy.
        lw_boundary, boundary_color, zorder_boundary: overlay line width,
            colour and z-order (these were previously accepted but ignored).
    """
    X, Y = np.meshgrid(da_masked["lon"].values, da_masked["lat"].values)
    plt.figure(figsize=(9,4.5))
    ax = plt.gca()  # grab the map axes before the colorbar adds its own
    m = ax.pcolormesh(X, Y, da_masked.values, shading='auto', cmap=(cmap or "viridis"))
    cb = plt.colorbar(m, ax=ax)
    units = da_masked.attrs.get("units","")
    if units:
        cb.set_label(units)
    # overlay boundaries if provided (cosmetic: never fail the map over it)
    if boundary_gdf is not None:
        style = dict(linewidth=lw_boundary, color=boundary_color, zorder=zorder_boundary)
        try:
            is_poly = boundary_gdf.geometry.geom_type.isin(['Polygon', 'MultiPolygon'])
            if is_poly.any():
                boundary_gdf.loc[is_poly].boundary.plot(ax=ax, **style)
            if (~is_poly).any():
                boundary_gdf.loc[~is_poly].plot(ax=ax, **style)
        except Exception:
            pass
    ax.set_xlabel("Longitude"); ax.set_ylabel("Latitude"); ax.set_title(title)
    if xlim is not None: ax.set_xlim(xlim)
    if ylim is not None: ax.set_ylim(ylim)
    plt.tight_layout(); plt.show()

def plot_histogram(da_masked: xr.DataArray, title: str):
    """Show a 40-bin histogram of the finite grid-cell values of *da_masked*.

    The x-axis label is the variable name, followed by the ``units`` attribute
    in parentheses only when one is set (an empty units attribute used to
    leave a dangling "()" because ``str.strip`` does not remove it).
    """
    values = da_masked.values
    flat = values[np.isfinite(values)]  # NaN cells are outside the region mask
    plt.figure(figsize=(6,4))
    plt.hist(flat, bins=40)
    units = da_masked.attrs.get("units","")
    plt.xlabel(f"{da_masked.name} ({units})" if units else f"{da_masked.name}")
    plt.ylabel("Grid-cell count")
    plt.title(title)
    plt.tight_layout(); plt.show()

def regional_stats(da_masked: xr.DataArray):
    """Summarise the finite grid-cell values of *da_masked*.

    Returns the pandas ``describe()`` Series (count, mean, std, min,
    quartiles, max) with the quartile rows relabelled ``q25``, ``median`` and
    ``q75``. Non-finite cells (e.g. masked-out NaNs) are excluded.
    """
    values = da_masked.values
    finite_values = values[np.isfinite(values)]
    summary = pd.Series(finite_values).describe()
    return summary.rename(index={"25%": "q25", "50%": "median", "75%": "q75"})

def to_linework_gdf(gdf, drop_points=True):
    """Reduce a GeoDataFrame to line geometries suitable for border overlays.

    Polygon/MultiPolygon rows are replaced by their boundaries; rows that are
    already LineString/MultiLineString are kept as they are (previously they
    were discarded whenever the layer also contained polygons). Whatever is
    not a line afterwards is dropped.

    Args:
        gdf: a GeoDataFrame, or None.
        drop_points: drop Point/MultiPoint rows (kept for API compatibility;
            the final line-only filter removes them in any case).

    Returns:
        The line-only GeoDataFrame, or None when *gdf* is None, nothing is
        left, or a geopandas operation fails (best effort).
    """
    if gdf is None:
        return None
    try:
        geoms = gdf.geometry
        poly_types = ('Polygon', 'MultiPolygon')
        is_poly = geoms.geom_type.isin(poly_types)
        if is_poly.any():
            linework = [
                geom.boundary if kind in poly_types else geom
                for geom, kind in zip(geoms, geoms.geom_type)
            ]
            # type(geoms) is the GeoSeries class; keep index, CRS and column name.
            gdf = gdf.set_geometry(
                type(geoms)(linework, index=gdf.index, crs=gdf.crs, name=geoms.name)
            )
        if drop_points:
            gdf = gdf.loc[~gdf.geometry.geom_type.isin(['Point', 'MultiPoint'])]
        gdf = gdf.loc[gdf.geometry.geom_type.isin(['LineString', 'MultiLineString'])]
        return gdf if len(gdf) else None
    except Exception:
        return None

def build_lowres_borders(shapefile_path: str):
    """Read the shapefile resolved from *shapefile_path* and return its linework.

    The layer is read with geopandas (retrying with the ``pyogrio`` engine if
    the default read fails), reprojected to EPSG:4326 when it carries another
    CRS, and passed through ``to_linework_gdf``. Returns None when geopandas is
    unavailable, no shapefile is found, or it cannot be read.
    """
    if gpd is None:
        return None
    shp_full = _resolve_shp(shapefile_path)
    if not (shp_full and os.path.exists(shp_full)):
        return None

    gdf = None
    for read_kwargs in ({}, {"engine": "pyogrio"}):
        try:
            gdf = gpd.read_file(shp_full, **read_kwargs)
            break
        except Exception:
            continue
    if gdf is None:
        return None

    try:
        needs_reprojection = gdf.crs is not None and gdf.crs.to_epsg() != 4326
        if needs_reprojection:
            gdf = gdf.to_crs(4326)
    except Exception:
        pass
    return to_linework_gdf(gdf, drop_points=True)

def plot_map_cartopy(da_masked: xr.DataArray, title: str, xlim=None, ylim=None, cmap='viridis', coast_res='50m', border_res='10m'):
    """Show *da_masked* on a PlateCarree map with coastlines, borders and a colorbar.

    Args:
        da_masked: field with 1-D ``lon``/``lat`` coordinates.
        title: figure title.
        xlim, ylim: (min, max) longitude / latitude. The map extent is set only
            when BOTH are given; each is also used for its own axis ticks.
        cmap: matplotlib colormap name.
        coast_res, border_res: Natural Earth scales ('10m', '50m' or '110m')
            for coastlines and national borders. The defaults reproduce the
            previous hard-coded choice — note '10m' is the FINEST scale, not a
            low-resolution one as formerly documented.

    The colorbar is always drawn and labelled with the ``units`` attribute
    when present. Tick placement is best effort and silently skipped if it
    fails.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature
    from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

    lons = da_masked['lon'].values
    lats = da_masked['lat'].values
    data = da_masked.values

    proj = ccrs.PlateCarree()
    fig, ax = plt.subplots(figsize=(9,4.5), subplot_kw=dict(projection=proj))
    m = ax.pcolormesh(lons, lats, data, transform=proj, shading='auto', cmap=cmap)
    cb = plt.colorbar(m, ax=ax)
    units = da_masked.attrs.get('units','')
    if units:
        cb.set_label(units)

    # Cartopy borders & coastlines at the requested Natural Earth scales
    ax.coastlines(resolution=coast_res, linewidth=0.6, color='0.2')
    ax.add_feature(cfeature.BORDERS.with_scale(border_res), linewidth=0.6, edgecolor='k')

    # Set extent if provided
    if (xlim is not None) and (ylim is not None):
        ax.set_extent([xlim[0], xlim[1], ylim[0], ylim[1]], crs=proj)

    # ticks/labels: 5 evenly spaced ticks over the limits, or the data range
    try:
        xticks = np.linspace(xlim[0], xlim[1], 5) if xlim else np.linspace(float(lons.min()), float(lons.max()), 5)
        yticks = np.linspace(ylim[0], ylim[1], 5) if ylim else np.linspace(float(lats.min()), float(lats.max()), 5)
        ax.set_xticks(xticks, crs=proj)
        ax.set_yticks(yticks, crs=proj)
        ax.xaxis.set_major_formatter(LONGITUDE_FORMATTER)
        ax.yaxis.set_major_formatter(LATITUDE_FORMATTER)
    except Exception:
        pass

    ax.set_title(title)
    plt.tight_layout()
    plt.show()


In [None]:
# === Load data and compute multi-year mean ===================================
if SAVE_OUTPUTS:
    os.makedirs(OUT_DIR, exist_ok=True)
nc_path = FILE_PATH
da = load_dataarray(nc_path)
da = ensure_lon_neg180_180(da)
# keep_attrs=True: xarray reductions drop attributes by default, which would
# lose "units" for every colorbar, title and axis label derived from mean_da.
if "time" in da.dims:
    mean_da = da.mean("time", skipna=True, keep_attrs=True)
else:
    mean_da = da
mean_da = mean_da.astype("float32")


In [None]:
# === Apply region mask and set axis limits ===================================
region_label = "global"
masked = mean_da
xlim = None; ylim = None
boundary_gdf = None

if REGION_METHOD == "country":
    poly = load_country_polygon(COUNTRY_NAME, SHAPEFILE_PATH)
    if poly is None:
        # Make the fallback visible instead of silently plotting the globe.
        print(f"WARNING: no polygon found for {COUNTRY_NAME!r} in {SHAPEFILE_PATH!r}; "
              "showing the unmasked (global) field.")
    else:
        masked = mask_da_by_polygon(mean_da, poly)
        region_label = COUNTRY_NAME
        try:
            minx, miny, maxx, maxy = poly.bounds
            # Pad by 5% of the extent, at least 0.5 degrees.
            pad_x = max(0.5, 0.05*(maxx-minx))
            pad_y = max(0.5, 0.05*(maxy-miny))
            xlim = (minx - pad_x, maxx + pad_x)
            ylim = (miny - pad_y, maxy + pad_y)
        except Exception:
            xlim = None; ylim = None
elif REGION_METHOD == "bbox":
    masked = mask_da_by_bbox(mean_da, BBOX)
    region_label = "bbox"
    try:
        # BBOX is (min lon, max lon, min lat, max lat), as documented in the
        # Parameters cell (it was previously unpacked as minlon, minlat, ...).
        minlon, maxlon, minlat, maxlat = BBOX
        xlim = (minlon, maxlon); ylim = (minlat, maxlat)
    except (TypeError, ValueError):
        xlim = None; ylim = None
else:
    print(f"WARNING: unknown REGION_METHOD {REGION_METHOD!r} (expected 'country' or 'bbox'); "
          "showing the unmasked (global) field.")




In [None]:
# === Map: multi-year mean with colorbar, units, and boundaries ===============
# Only add "(units)" to the title when the variable actually has a units attribute.
_map_units = masked.attrs.get('units', '')
title = f"{masked.name} multi-year mean" + (f" ({_map_units})" if _map_units else "") + f" — {region_label}"
plot_map_cartopy(masked, title, xlim=xlim, ylim=ylim, cmap=CMAP)


In [None]:
# === Histogram of spatial distribution (multi-year mean) =====================
# Distribution of the masked multi-year-mean grid cells for the selected region.
plot_histogram(
    da_masked=masked,
    title=f"{masked.name} distribution — {region_label}",
)


In [None]:
# === Spatial statistics table (multi-year mean) ==============================
# count / mean / std / min / q25 / median / q75 / max of the masked grid cells.
stats_df = regional_stats(da_masked=masked)
display(stats_df)


In [None]:
# === Annual regional series and linear trend =================================
if "time" in da.dims:
    def area_weighted_mean_numpy(da2d):
        """Return the cos(latitude)-weighted mean of the finite cells of *da2d*.

        Assumes *da2d* is 2-D with dims ordered (lat, lon): the weight grid is
        built as (nlat, nlon). Returns NaN when no cell is finite.
        """
        lat = da2d["lat"].values
        arr = np.asarray(da2d.values, dtype="float64")
        w = np.cos(np.deg2rad(lat))
        W = np.tile(w[:, None], (1, da2d.sizes["lon"]))
        valid = np.isfinite(arr) & np.isfinite(W)
        if np.any(valid):
            return float(np.nansum(arr[valid] * W[valid]) / np.nansum(W[valid]))
        return np.nan

    # Resolve the country polygon ONCE; it used to be re-read from the
    # shapefile for every single year.
    region_poly = (load_country_polygon(COUNTRY_NAME, SHAPEFILE_PATH)
                   if REGION_METHOD == "country" else None)

    def apply_region_mask_2d(da2d):
        """Apply the same region selection as the map (country polygon or BBOX)."""
        if REGION_METHOD == "country":
            return mask_da_by_polygon(da2d, region_poly) if region_poly is not None else da2d
        if REGION_METHOD == "bbox":
            return mask_da_by_bbox(da2d, BBOX)
        return da2d

    # da was already standardised (lat/lon names, -180..180) when loaded.
    years = da["time"].dt.year.values
    year_labels = [int(yr) for yr in np.unique(years)]
    annual_vals = []
    for yr in year_labels:
        # Positional selection: no date-string parsing (str(985) is not a
        # valid date) and works for numpy and cftime calendars alike.
        da2d = da.isel(time=np.flatnonzero(years == yr)).mean("time", skipna=True)
        annual_vals.append(area_weighted_mean_numpy(apply_region_mask_2d(da2d)))

    ts = pd.Series(annual_vals, index=year_labels, name=f"{da.name} regional mean")
    display(ts.describe())

    ts_units = da.attrs.get('units', '')
    ts_label = f"{da.name} ({ts_units})" if ts_units else f"{da.name}"

    x = np.asarray(ts.index, dtype=float)
    y = np.asarray(ts.values, dtype=float)
    finite = np.isfinite(x) & np.isfinite(y)
    x = x[finite]; y = y[finite]
    if x.size >= 3:
        # Ordinary least squares: slope = Sxy / Sxx, R^2 = 1 - SSE / SST.
        xbar = np.mean(x); ybar = np.mean(y)
        Sxx = np.sum((x - xbar)**2)
        Sxy = np.sum((x - xbar)*(y - ybar))
        slope = float(Sxy / Sxx) if Sxx != 0 else np.nan
        intercept = float(ybar - slope*xbar)
        yhat = slope*x + intercept
        SSE = np.sum((y - yhat)**2)
        SST = np.sum((y - ybar)**2)
        r2 = float(1.0 - SSE/SST) if SST != 0 else np.nan

        print(f"Slope per year: {slope}")
        print(f"Slope per decade: {slope*10.0}")
        print(f"R^2: {r2}")

        plt.figure(figsize=(9,3.2))
        plt.plot(ts.index, ts.values, marker="o", label="annual regional mean")
        # The fit exists only for finite years, so plot it against x (not
        # ts.index): a single NaN year used to raise a length-mismatch error.
        plt.plot(x, yhat, lw=2, label=f"linear fit: {slope*10.0:.3g} per decade, R² = {r2:.2f}")
        plt.xlabel("Year"); plt.ylabel(ts_label)
        plt.title(f"{da.name} regional mean — linear fit")
        plt.legend()
        plt.tight_layout(); plt.show()

        plt.figure(figsize=(6,4))
        plt.hist(ts.values[np.isfinite(ts.values)], bins=15)
        plt.xlabel(ts_label)
        plt.ylabel("Year count")
        plt.title(f"{da.name} distribution across years")
        plt.tight_layout(); plt.show()
    else:
        print("Not enough points for a trend (need >= 3).")
else:
    print("No time dimension found; skipping annual series/trend.")


In [None]:

# === Save map, histogram, and stats to files (if enabled) ====================
if SAVE_OUTPUTS:
    import cartopy.crs as ccrs
    import cartopy.feature as cfeature

    # Also create OUT_DIR here in case SAVE_OUTPUTS was enabled after loading.
    os.makedirs(OUT_DIR, exist_ok=True)
    base = os.path.splitext(os.path.basename(FILE_PATH))[0]
    map_png = os.path.join(OUT_DIR, f"{base}_{region_label}_map.png")
    hist_png = os.path.join(OUT_DIR, f"{base}_{region_label}_hist.png")
    csv_path = os.path.join(OUT_DIR, f"{base}_{region_label}_stats.csv")

    units = masked.attrs.get('units', '')
    unit_suffix = f" ({units})" if units else ""  # avoid a dangling "()"

    # Map: same Natural Earth scales and line widths as the on-screen map
    # (50m coastlines, 10m borders) plus its title, so the PNG matches it.
    proj = ccrs.PlateCarree()
    fig, ax = plt.subplots(figsize=(9,4.5), subplot_kw=dict(projection=proj))
    m = ax.pcolormesh(masked['lon'].values, masked['lat'].values, masked.values,
                      transform=proj, shading='auto', cmap=CMAP)
    cb = fig.colorbar(m, ax=ax)
    if units:
        cb.set_label(units)
    ax.coastlines(resolution='50m', linewidth=0.6, color='0.2')
    ax.add_feature(cfeature.BORDERS.with_scale('10m'), linewidth=0.6, edgecolor='k')
    if (xlim is not None) and (ylim is not None):
        ax.set_extent([xlim[0], xlim[1], ylim[0], ylim[1]], crs=proj)
    ax.set_title(f"{masked.name} multi-year mean{unit_suffix} — {region_label}")
    fig.tight_layout(); fig.savefig(map_png, dpi=150); plt.close(fig)

    # Histogram
    flat = masked.values[np.isfinite(masked.values)]
    fig, ax = plt.subplots(figsize=(6,4))
    ax.hist(flat, bins=40)
    ax.set_xlabel(f"{masked.name}{unit_suffix}"); ax.set_ylabel("Grid-cell count")
    ax.set_title(f"{masked.name} distribution — {region_label}")
    fig.tight_layout(); fig.savefig(hist_png, dpi=150); plt.close(fig)

    stats_df.to_csv(csv_path, index=True)
    print("Saved:"); print(map_png); print(hist_png); print(csv_path)
