In [None]:
# === USER INPUTS ===
# Configuration cell: every name below is a module-level global that the
# later cells read directly. Edit values here only.

# Project/crop tags
# `crop` and `region_tag` are only used to build the output CSV file names.
# NOTE: the yield column in the output is always called "yield_maize",
# whatever `crop` is set to.
crop          = "maize"                 # e.g., "maize", "wheat_winter", "soybean", ...
region_tag    = "ITnorth_core42"        # used in filenames (keep as-is if using same 42-core cells)

# Years and season
start_year    = 1982
end_year      = 2016                    # inclusive
season_months = [5, 6, 7, 8, 9]         # calendar months (1-12) of the season; MJJAS for Po Valley maize

# Paths (edit to your structure)
# NOTE: these are raw strings with backslash separators, i.e. Windows-style
# relative paths. On POSIX systems pathlib treats "\" as an ordinary
# character, so rewrite them with "/" there.
BASE_DIR      = r"..\italy_core_data"   # not referenced by the code in this part of the notebook
DERIVED_DIR   = r"..\italy_core_data\derived"    # outputs will be written here
MASK_CSV      = r"..\italy_core_data\derived\mask_core_42_on_gdhy.csv"  # needs lat/lon columns; ilat/ilon are optional

# GDHY yield files (per-year .nc4)
GDHY_DIR      = r"..\data\maize"       # folder containing yield_YYYY.nc4 for the selected crop
GDHY_FILE_TMPL= "yield_{year}.nc4"     # change if your files differ
GDHY_VARNAME  = None                   # None -> auto-detect first data_var, or set e.g. "var"

# ERA5-Land monthly averaged grib files (per-year)
ERA5_DIR      = r"..\data\climate_monthly_full"
ERA5_FILE_TMPL= "era5_land_monthly_{year}.grib"  # change if your files differ

# Optional: when True, extra solar-radiation columns divided by 1e6 (suffix
# "_MJm2") are added next to the original columns; the originals are kept
# either way. When False only the original columns are written.
SSR_TO_MJ     = False

# === end user inputs ===


In [None]:
from pathlib import Path
import xarray as xr
import pandas as pd
import numpy as np

# Resolve paths
# parents=True so a fresh checkout with no intermediate folders still works;
# exist_ok=True keeps re-running the cell harmless.
DERIVED = Path(DERIVED_DIR); DERIVED.mkdir(parents=True, exist_ok=True)
MASK_CSV = Path(MASK_CSV)
GDHY_DIR = Path(GDHY_DIR)
ERA5_DIR = Path(ERA5_DIR)

# Inclusive list of processing years.
years = list(range(int(start_year), int(end_year) + 1))
month_names = {1:"Jan",2:"Feb",3:"Mar",4:"Apr",5:"May",6:"Jun",7:"Jul",8:"Aug",9:"Sep",10:"Oct",11:"Nov",12:"Dec"}
# Column order of the monthly wide tables. Derived from `season_months` so the
# two can never disagree: a hard-coded May..Sep list would silently drop months
# and add all-NaN columns as soon as the season is edited in the inputs cell.
# For the default MJJAS season this is ["May","Jun","Jul","Aug","Sep"].
mon_order = [month_names[m] for m in season_months]

# ---------- mask & bbox ----------
# Mask table: one row per core cell, sorted so that cell order is stable.
m42 = (pd.read_csv(MASK_CSV).sort_values(["lat","lon"]).reset_index(drop=True))
latc, lonc = m42["lat"].to_numpy(), m42["lon"].to_numpy()
# Optional pre-computed grid indices; None means "look them up by coordinate".
ilat = m42["ilat"].to_numpy() if "ilat" in m42.columns else None
ilon = m42["ilon"].to_numpy() if "ilon" in m42.columns else None

# Small margin so binning catches edge cells
BBOX = dict(
    lat_min=float(latc.min() - 0.5),
    lat_max=float(latc.max() + 0.5),
    lon_min=float(lonc.min() - 0.5),
    lon_max=float(lonc.max() + 0.5),
)

# ---------- file locators ----------
def gdhyds_path(y:int)->Path:
    """Return the path of the GDHY yield file for year ``y``.

    The file named by ``GDHY_FILE_TMPL`` inside ``GDHY_DIR`` is preferred.
    If it does not exist, any file in ``GDHY_DIR`` whose name contains the
    year and has an extension starting with ``.nc`` is accepted.

    Raises:
        FileNotFoundError: if neither the templated file nor a fallback exists.
    """
    p = GDHY_DIR / GDHY_FILE_TMPL.format(year=y)
    if p.exists():
        return p
    # fallback: any file containing year. glob() yields entries in arbitrary,
    # filesystem-dependent order, so sort to make the pick reproducible when
    # several files match.
    cands = sorted(GDHY_DIR.glob(f"*{y}*.nc*"))
    if not cands:
        raise FileNotFoundError(f"GDHY file not found for {y} in {GDHY_DIR}")
    return cands[0]

def era5_path(y:int)->Path:
    """Return the path of the ERA5 GRIB file for year ``y``.

    The file named by ``ERA5_FILE_TMPL`` inside ``ERA5_DIR`` is preferred.
    If it does not exist, any ``*.grib`` file in ``ERA5_DIR`` whose name
    contains the year is accepted.

    Raises:
        FileNotFoundError: if neither the templated file nor a fallback exists.
    """
    p = ERA5_DIR / ERA5_FILE_TMPL.format(year=y)
    if p.exists():
        return p
    # glob() order is filesystem-dependent; sort so that the same file is
    # chosen on every run when more than one candidate matches.
    cands = sorted(ERA5_DIR.glob(f"*{y}*.grib"))
    if not cands:
        raise FileNotFoundError(f"ERA5 file not found for {y} in {ERA5_DIR}")
    return cands[0]

# ---------- ERA5 open + crop ----------
def open_era5_var_year(year:int, short:str, xr_name:str)->xr.DataArray:
    """Load one ERA5 variable for one year, cropped to ``BBOX``.

    Args:
        year: year whose GRIB file is located through ``era5_path``.
        short: GRIB ``shortName`` used to filter the messages that are read.
        xr_name: name of the variable inside the dataset that is opened.

    Returns:
        The cropped variable with coordinates renamed to ``lat``/``lon`` and
        latitude ascending. The data are loaded into memory so the array stays
        valid after the file is closed.
    """
    f = era5_path(year)
    # Open inside a context manager: the original left every yearly file open
    # (one handle per year and variable), which leaks file descriptors.
    with xr.open_dataset(
        f,
        engine="cfgrib",
        backend_kwargs=dict(indexpath="", filter_by_keys={"shortName": short}),
        decode_timedelta=True,
    ) as raw:
        ds = raw.rename({"latitude":"lat","longitude":"lon"})
        # ascending lat, so that the slice below can go from min to max
        if ds.lat[0] > ds.lat[-1]:
            ds = ds.sortby("lat")
        # wrap lon to 0..360
        # NOTE(review): BBOX longitudes come from the mask file; this assumes
        # the mask uses the same longitude convention after wrapping - confirm
        # for regions west of Greenwich.
        if ds.lon.min() < 0:
            ds = ds.assign_coords(lon=((ds.lon % 360 + 360) % 360)).sortby("lon")
        # crop to bbox
        ds = ds.sel(lat=slice(BBOX["lat_min"], BBOX["lat_max"]),
                    lon=slice(BBOX["lon_min"], BBOX["lon_max"]))
        da = ds[xr_name]
        # When an "expver" dimension is present keep only its last entry.
        if "expver" in da.dims:
            da = da.isel(expver=-1)
        # Materialise the (small, already cropped) array before the file is
        # closed on leaving the with-block.
        return da.load()

# ---------- 0.1° -> 0.5° block-average to GDHY centers (offset 0.25°) ----------
def bin_to_half_degree(da: xr.DataArray, step=0.5, offset=0.25) -> xr.DataArray:
    """Block-average ``da`` onto a coarser grid with centres at ``offset + k*step``.

    Every grid point is assigned to the block ``[centre - step/2, centre + step/2)``
    and all points of a block are averaged, first along ``lat`` and then along
    ``lon``.

    Args:
        da: array with ``lat`` and ``lon`` dimensions and coordinates.
        step: spacing of the target grid, in the units of the coordinates.
        offset: position of one target cell centre (centres are ``offset + k*step``).

    Returns:
        The averaged array; its ``lat``/``lon`` coordinates are the block centres.
    """
    def _block_centres(values):
        # The original used np.round, which rounds exact halves to the nearest
        # EVEN integer. A point sitting on a block edge (e.g. 45.0 and 45.5 for
        # the 45.25 block) was therefore sent up or down alternately, giving
        # blocks of 6 and 4 rows instead of 5 and 5. floor(q + 0.5) always sends
        # an edge point to the upper block. The 1e-9 guard (in units of `step`)
        # keeps a coordinate that is a few ulps below an edge in the same block
        # as an exact one.
        q = (np.asarray(values, dtype=np.float64) - offset) / step
        return offset + step * np.floor(q + 0.5 + 1e-9)

    lat_bins = _block_centres(da["lat"].values)
    lon_bins = _block_centres(da["lon"].values)
    da = da.assign_coords(lat_bin=("lat", lat_bins), lon_bin=("lon", lon_bins))
    da_c = da.groupby("lat_bin").mean("lat").groupby("lon_bin").mean("lon")
    return da_c.rename({"lat_bin":"lat","lon_bin":"lon"})

# ---------- monthly reductions ----------
def monthly_means(da_m05: xr.DataArray, months:list)->xr.DataArray:
    """Keep only the time steps of ``da_m05`` whose calendar month is in ``months``.

    No averaging happens here: the values are passed through unchanged and the
    other time steps are dropped.
    """
    in_season = da_m05["time"].dt.month.isin(months)
    return da_m05.where(in_season, drop=True)

def monthly_totals_from_daily_means(da_m05: xr.DataArray, months:list)->xr.DataArray:
    """Scale each selected time step by the number of days in its month.

    Time steps whose calendar month is not in ``months`` are dropped; every
    remaining value is multiplied by ``days_in_month`` of its time stamp.
    """
    in_season = da_m05["time"].dt.month.isin(months)
    kept = da_m05.where(in_season, drop=True)
    stamps = kept["time"]
    # Day counts as a float64 array aligned on the time axis.
    n_days = xr.DataArray(
        stamps.dt.days_in_month,
        coords={"time": stamps},
        dims=["time"],
    ).astype(np.float64)
    return kept * n_days

# ---------- seasonal reductions ----------
def seasonal_mean(da_m05: xr.DataArray, months:list)->xr.DataArray:
    """Average the time steps falling in ``months``, separately for each year.

    The result carries a ``year`` dimension in place of ``time``. Every
    selected time step has the same weight in the average.
    """
    in_season = monthly_means(da_m05, months)
    per_year = in_season.groupby("time.year").mean("time")
    # Maps "year" onto itself, so the dimension name stays "year".
    return per_year.rename({"year":"year"})

def seasonal_total_from_daily_means(da_m05: xr.DataArray, months:list)->xr.DataArray:
    """Sum the monthly totals of the months in ``months``, separately for each year.

    Monthly totals come from ``monthly_totals_from_daily_means``. The result
    carries a ``year`` dimension in place of ``time``.

    A cell with a missing value in any selected month gets a NaN total.
    """
    monthly = monthly_totals_from_daily_means(da_m05, months)
    # skipna=False: the default sum skips NaN, which turns an all-missing cell
    # into a total of 0.0 and a partly missing one into a partial total. Both
    # look like valid data and slip past the NaN assertions run after merging.
    # (The original's rename({"year": "year"}) was a no-op and is dropped.)
    return monthly.groupby("time.year").sum("time", skipna=False)

# ---------- select exact 42 cells by lat/lon value (nearest) ----------
def select_core_cells(da: xr.DataArray)->xr.DataArray:
    """Pick the grid point nearest to each mask cell, as one ``cell`` dimension.

    ``latc[i]``/``lonc[i]`` are looked up pairwise, so the ``lat`` and ``lon``
    dimensions of ``da`` are replaced by a single ``cell`` dimension numbered
    from 0. The ``lat``/``lon`` coordinates of the result are overwritten with
    the mask values, so the output never shows how far the matched grid point
    actually was.
    """
    n_cells = len(latc)
    lat_pts = xr.DataArray(latc, dims="cell")
    lon_pts = xr.DataArray(lonc, dims="cell")
    picked = da.sel(lat=lat_pts, lon=lon_pts, method="nearest")
    return picked.assign_coords(
        cell=("cell", np.arange(n_cells)),
        lat=("cell", latc),
        lon=("cell", lonc),
    )

# ---------- build yield panel ----------
def build_yield_panel()->pd.DataFrame:
    """Extract the yield of every mask cell for every year in ``years``.

    For each year the GDHY file is opened, the variable ``GDHY_VARNAME`` (or
    the first data variable when that is None) is read as a 2-D lat/lon grid,
    and the values at the mask cells are taken. Cells are addressed by the
    ``ilat``/``ilon`` indices from the mask when available, otherwise by
    matching ``latc``/``lonc`` against the file's coordinates.

    Returns:
        DataFrame with columns ``year``, ``lat``, ``lon``, ``yield_maize``
        (float32), sorted by lat, lon, year. The column is named
        ``yield_maize`` regardless of the configured crop.

    Raises:
        FileNotFoundError: if a yearly file is missing (from ``gdhyds_path``).
        ValueError: if a file has no data variable, lacks lat/lon dimensions,
            or a mask coordinate has no matching grid coordinate.
    """
    def _nearest_index(val, axis_vals, axis_name, src):
        # Index of the grid coordinate closest to `val`. A real exception is
        # raised on mismatch because `assert` is stripped under `python -O`.
        i = int(np.argmin(np.abs(axis_vals - val)))
        if not np.isclose(axis_vals[i], val, atol=5e-4):
            raise ValueError(
                f"{src}: no {axis_name} coordinate matches mask value {val} "
                f"(closest is {axis_vals[i]})"
            )
        return i

    rows = []
    for y in years:
        f = gdhyds_path(y)
        # Context manager: the original never closed the yearly datasets.
        with xr.open_dataset(f) as ds:
            varname = GDHY_VARNAME or next(iter(ds.data_vars), None)
            if varname is None:
                raise ValueError(f"{f} contains no data variables")
            da = ds[varname].squeeze()
            if not {"lat","lon"}.issubset(da.dims):
                raise ValueError(f"{f} missing lat/lon dims; got {da.dims}")
            da = da.transpose("lat","lon")
            # Pull everything needed into numpy before the file is closed.
            arr = da.values
            lat_vals = da["lat"].values
            lon_vals = da["lon"].values
        if ilat is None or ilon is None:
            # Fallback: map coords to indices
            i_lat = np.array([_nearest_index(v, lat_vals, "lat", f) for v in latc])
            i_lon = np.array([_nearest_index(v, lon_vals, "lon", f) for v in lonc])
        else:
            i_lat, i_lon = ilat, ilon
        # Paired fancy indexing: one value per (lat, lon) mask cell.
        vals = arr[i_lat, i_lon].astype("float32")
        rows.append(pd.DataFrame({"year": y, "lat": latc, "lon": lonc, "yield_maize": vals}))
    out = pd.concat(rows, ignore_index=True).sort_values(["lat","lon","year"]).reset_index(drop=True)
    return out

# ---------- process a single ERA5 variable to seasonal + monthly ----------
def process_era5_var(short:str, xr_name:str, out_col:str, vtype:str):
    """
    Turn one ERA5 variable into per-cell seasonal and monthly tables.

    short   : GRIB shortName passed to the file reader
    xr_name : variable name inside the opened dataset
    out_col : base name of the output column(s); also selects the unit
              conversion ("temperature": minus 273.15, "precipitation":
              times 1000, "potential_evaporation": sign flip and times 1000,
              anything else: unchanged)
    vtype   : 'state' (values are averaged) or 'flux' (values are multiplied
              by days-in-month and summed)

    Returns (seasonal_df, monthly_wide_df):
      seasonal_df     - one row per lat/lon/year with column `out_col`
      monthly_wide_df - one row per lat/lon/year with one column per entry of
                        `mon_order`, named "<out_col>_<month>"
    When out_col is "solar_radiation" and SSR_TO_MJ is set, both tables also
    get copies divided by 1e6 with an "_MJm2" suffix.
    Raises ValueError for any other vtype.
    """
    key_cols = ["lat", "lon", "year"]
    per_year_monthly = []
    per_year_seasonal = []

    for yr in years:
        field = open_era5_var_year(yr, short, xr_name)   # regional, native grid

        # Unit conversion, chosen by output column name.
        if out_col == "temperature":
            field = field - 273.15
        elif out_col == "precipitation":
            field = field * 1000.0       # m -> mm (daily mean)
        elif out_col == "potential_evaporation":
            field = -field * 1000.0      # flip sign, m -> mm (daily mean)

        coarse = bin_to_half_degree(field)

        if vtype not in ("state", "flux"):
            raise ValueError("vtype must be 'state' or 'flux'")
        is_state = vtype == "state"

        # ---- monthly values at the core cells ----
        if is_state:
            monthly_da = monthly_means(coarse, season_months)
        else:
            monthly_da = monthly_totals_from_daily_means(coarse, season_months)
        mdf = select_core_cells(monthly_da).to_dataframe(name=out_col).reset_index()
        stamps = pd.to_datetime(mdf["time"])
        mdf["year"] = stamps.dt.year
        mdf["month"] = stamps.dt.month
        mdf["month_name"] = mdf["month"].map(month_names)
        in_season = mdf["month"].isin(season_months)
        per_year_monthly.append(
            mdf.loc[in_season, ["lat","lon","year","month","month_name",out_col]]
        )

        # ---- seasonal value at the core cells ----
        if is_state:
            seasonal_da = seasonal_mean(coarse, season_months)
        else:
            seasonal_da = seasonal_total_from_daily_means(coarse, season_months)
        sdf = select_core_cells(seasonal_da).to_dataframe(name=out_col).reset_index()
        per_year_seasonal.append(sdf[key_cols + [out_col]])

        # Progress message every fifth year, counted from the first one.
        if (yr - years[0]) % 5 == 0:
            print(f"{out_col}: processed {yr}", flush=True)

    # Stack the yearly pieces.
    monthly_long = pd.concat(per_year_monthly, ignore_index=True)
    seasonal_df = pd.concat(per_year_seasonal, ignore_index=True)
    seasonal_df = seasonal_df.sort_values(key_cols).reset_index(drop=True)

    # Long -> wide: one column per month name, ordered as in mon_order.
    wide = monthly_long.pivot_table(
        index=key_cols, columns="month_name", values=out_col, aggfunc="first"
    ).reset_index()
    for label in mon_order:
        if label not in wide.columns:
            wide[label] = np.nan         # month absent from the data
    wide = wide[key_cols + mon_order]
    wide = wide.rename(columns={label: f"{out_col}_{label}" for label in mon_order})

    # Optional extra solar-radiation columns scaled by 1e-6; originals are kept.
    if out_col == "solar_radiation" and SSR_TO_MJ:
        for label in mon_order:
            wide[f"solar_radiation_{label}_MJm2"] = wide[f"solar_radiation_{label}"] / 1e6
        seasonal_df["solar_radiation_MJJAS_MJm2"] = seasonal_df["solar_radiation"] / 1e6

    return seasonal_df, wide


In [None]:
# Driver cell: builds the yield panel, derives seasonal and monthly ERA5
# tables for every configured variable, merges them on (lat, lon, year) and
# writes two CSV files into DERIVED.

# ---------- 1) Build yield panel ----------
yield_df = build_yield_panel()
print("Yield panel:", yield_df.shape, "| years", yield_df["year"].min(), "–", yield_df["year"].max())

# ---------- 2) Process ERA5 variables ----------
# Map of variables to process (shortName, xr_name, out_col, type)
# `type` is "state" (averaged over time) or "flux" (scaled by days-in-month
# and summed); `out_col` also selects the unit conversion inside
# process_era5_var.
vars_cfg = [
    ("2t",   "t2m",   "temperature",           "state"),
    ("tp",   "tp",    "precipitation",         "flux"),
    ("swvl1","swvl1", "soil_water",            "state"),
    ("ssr",  "ssr",   "solar_radiation",       "flux"),
    ("pev",  "pev",   "potential_evaporation", "flux"),
]

# One seasonal and one monthly-wide table per variable, in vars_cfg order.
seasonal_tables = []
monthly_tables  = []

for short, xr_name, out_col, vtype in vars_cfg:
    seas, mon = process_era5_var(short, xr_name, out_col, vtype)
    seasonal_tables.append(seas)
    monthly_tables.append(mon)

# ---------- 3) Merge seasonal into base ----------
# Both sides take lat/lon from the same mask arrays (latc/lonc), so the float
# keys compare exactly equal. The inner join would silently drop unmatched
# rows; the assertions below catch that.
base = yield_df.copy()
for tbl in seasonal_tables:
    base = base.merge(tbl, on=["lat","lon","year"], how="inner")

# Basic checks
# NOTE: these are plain `assert` statements and do not run under `python -O`.
assert len(base) == len(yield_df), "Row count changed unexpectedly after seasonal merge"
assert base.isna().sum().sum() == 0, "NaNs found after seasonal merge"

# ---------- 4) Save seasonal CSV ----------
base_name = f"{crop}_{region_tag}_{start_year}_{end_year}_allstressors.csv"
base_csv  = DERIVED / base_name
base.to_csv(base_csv, index=False)
print("Saved seasonal dataset →", base_csv)

# ---------- 5) Merge monthly wide tables ----------
# Adds the "<out_col>_<month>" columns of every variable to the seasonal table.
enriched = base.copy()
for mon in monthly_tables:
    enriched = enriched.merge(mon, on=["lat","lon","year"], how="inner")

assert len(enriched) == len(base), "Row count changed unexpectedly after monthly merge"
assert enriched.isna().sum().sum() == 0, "NaNs found after monthly merge"

# ---------- 6) Save enriched CSV ----------
enriched_name = f"{crop}_{region_tag}_{start_year}_{end_year}_allstressors_with_monthly.csv"
enriched_csv  = DERIVED / enriched_name
enriched.to_csv(enriched_csv, index=False)
print("Saved enriched dataset with monthly columns →", enriched_csv)

# Preview
print("Columns (first 16):", list(enriched.columns)[:16], "…")
print("Rows:", len(enriched))


In [None]:
# Quick sanity checks on ranges (prints only)
# Reads the merged `enriched` table; nothing is modified or written here.
cols_check = ["yield_maize","temperature","precipitation","soil_water","solar_radiation","potential_evaporation"]
print("Seasonal ranges:")
for c in cols_check:
    # Min and max over all cells and years of the seasonal column.
    lo, hi = enriched[c].min(), enriched[c].max()
    print(f"  {c:>22s}: min={lo:.3f}  max={hi:.3f}")

# Confirm May–Sep columns exist
# Only temperature and precipitation are checked; an empty list printed below
# means every expected monthly column is present for those two variables.
must_have = [f"temperature_{m}" for m in mon_order] + [f"precipitation_{m}" for m in mon_order]
missing = [c for c in must_have if c not in enriched.columns]
print("Monthly cols missing:", missing)
