In [5]:
# ==== USER INPUTS (edit) ====
# Every downstream cell reads these names; re-run this cell after editing.

# Tag used in output filenames
# NOTE: the tag is not derived from the number of cells actually kept by balance_by_yield;
# update it by hand if the kept-cell count changes.
region_tag = "ITnorth_core41"      # if balancing drops a cell, feel free to rename to core41 manually

# Years (inclusive)
# These are harvest years; for a wrapping season the calendar year before start_year is also read.
start_year = 1982
end_year   = 2016
years      = list(range(start_year, end_year + 1))

# Seasons (list months in SEASON ORDER, start→end)
# A list that is not ascending is treated as wrapping over Dec→Jan: months >= the first listed
# month are credited to the following (harvest) year.
winter_months = [11,12,1,2,3,4,5,6]  # Nov–Jun (wraps over Dec→Jan)
spring_months = [3,4,5,6,7]          # Mar–Jul

# Paths
# NOTE(review): backslash-separated relative paths only resolve as directories on Windows;
# they are relative to the notebook's working directory.
DERIVED     = r"..\italy_core_data\derived"
MASK_CSV    = r"..\italy_core_data\derived\mask_core_42_on_gdhy.csv"   # needs 'lat' and 'lon' columns

# GDHY yield directories (per-year yield_YYYY.nc4 files)
GDHY_WINTER_DIR = r"..\data\wheat_winter"
GDHY_SPRING_DIR = r"..\data\wheat_spring"
GDHY_FILE_TMPL  = "yield_{year}.nc4"     # change if filenames differ
GDHY_VARNAME    = None                   # None → auto-detect (first data_var)

# ERA5 monthly averaged GRIBs (one file per year)
ERA5_DIR        = r"..\data\climate_monthly_full"
ERA5_FILE_TMPL  = "era5_land_monthly_{year}.grib"

# Units convenience
# Only the per-month solar columns get an extra *_MJm2 copy (value / 1e6); the seasonal
# solar_radiation column keeps the source units (presumably J m-2 — confirm).
SSR_TO_MJ = False   # if True, also add solar_radiation_*_MJm2 columns

# Month labels
# Used as column suffixes in the monthly-wide tables, e.g. temperature_Jan.
month_names = {1:"Jan",2:"Feb",3:"Mar",4:"Apr",5:"May",6:"Jun",7:"Jul",8:"Aug",9:"Sep",10:"Oct",11:"Nov",12:"Dec"}


In [6]:
from pathlib import Path
import xarray as xr
import pandas as pd
import numpy as np

# Make sure the output folder exists; parents=True so a fresh checkout without
# the intermediate directories does not fail here.
DERIVED = Path(DERIVED)
DERIVED.mkdir(parents=True, exist_ok=True)

# --- mask & bbox from your 42-core cells ---
# Cells are sorted by (lat, lon) so every downstream table shares the same cell order.
m42 = pd.read_csv(MASK_CSV).sort_values(["lat","lon"]).reset_index(drop=True)
# A repeated cell would silently duplicate rows in the yield panel and the ERA5 tables.
assert not m42.duplicated(["lat","lon"]).any(), f"Duplicate (lat, lon) cells in {MASK_CSV}"
LATC_ALL = m42["lat"].to_numpy()
LONC_ALL = m42["lon"].to_numpy()

def make_bbox(latc, lonc, pad=0.5):
    """Return a lat/lon bounding box that encloses all cell centres plus ``pad`` degrees.

    ``latc`` / ``lonc`` are array-likes with ``.min()``/``.max()`` (numpy arrays here).
    The result is a dict with float values under the keys
    ``lat_min``, ``lat_max``, ``lon_min``, ``lon_max``.
    """
    south, north = latc.min(), latc.max()
    west, east = lonc.min(), lonc.max()
    box = {}
    box["lat_min"] = float(south - pad)
    box["lat_max"] = float(north + pad)
    box["lon_min"] = float(west - pad)
    box["lon_max"] = float(east + pad)
    return box

# --- file finders ---
def yield_path(gdhy_dir: Path, y: int, file_tmpl: "str | None" = None) -> Path:
    """Locate the GDHY yield file for year ``y`` inside ``gdhy_dir``.

    First tries the exact name ``file_tmpl.format(year=y)`` (default: ``GDHY_FILE_TMPL``).
    If that file does not exist, falls back to any ``*{y}*.nc*`` file in the directory.
    The candidates are sorted so the choice is deterministic; raw ``glob`` order depends
    on the filesystem. A warning is issued when more than one file matches.

    Parameters
    ----------
    gdhy_dir : directory holding per-year yield files (Path or str).
    y : harvest year to look up.
    file_tmpl : optional filename template overriding ``GDHY_FILE_TMPL``.

    Returns
    -------
    Path of the chosen file.

    Raises
    ------
    FileNotFoundError
        If neither the exact name nor any fallback candidate exists.
    """
    gdhy_dir = Path(gdhy_dir)
    tmpl = GDHY_FILE_TMPL if file_tmpl is None else file_tmpl
    p = gdhy_dir / tmpl.format(year=y)
    if p.exists():
        return p
    cands = sorted(gdhy_dir.glob(f"*{y}*.nc*"))
    if not cands:
        raise FileNotFoundError(f"No GDHY yield for {y} in {gdhy_dir}")
    if len(cands) > 1:
        import warnings
        warnings.warn(
            f"{len(cands)} GDHY files match year {y} in {gdhy_dir}; using {cands[0].name}",
            stacklevel=2,
        )
    return cands[0]

def era5_path(y: int, era5_dir=None, file_tmpl=None) -> Path:
    """Locate the ERA5 monthly GRIB file for calendar year ``y``.

    Tries ``file_tmpl.format(year=y)`` (default ``ERA5_FILE_TMPL``) inside ``era5_dir``
    (default ``ERA5_DIR``). If that file does not exist, falls back to any ``*{y}*.grib``
    file, picking the first in sorted order for a deterministic result. A warning is
    issued when several files match.

    Raises
    ------
    FileNotFoundError
        If no matching GRIB exists.
    """
    dir_label = ERA5_DIR if era5_dir is None else era5_dir
    tmpl = ERA5_FILE_TMPL if file_tmpl is None else file_tmpl
    root = Path(dir_label)
    p = root / tmpl.format(year=y)
    if p.exists():
        return p
    cands = sorted(root.glob(f"*{y}*.grib"))
    if not cands:
        raise FileNotFoundError(f"No ERA5 GRIB for {y} in {dir_label}")
    if len(cands) > 1:
        import warnings
        warnings.warn(
            f"{len(cands)} ERA5 GRIBs match year {y} in {dir_label}; using {cands[0].name}",
            stacklevel=2,
        )
    return cands[0]

# --- ERA5 opener with cropping; returns DataArray (time, lat, lon) ---
def open_era5_var_year(y: int, short: str, xr_name: str, bbox: dict) -> xr.DataArray:
    """Read one ERA5 variable for calendar year ``y``, cropped to ``bbox``, into memory.

    Steps: open the year's GRIB with cfgrib filtered to ``shortName == short``; rename
    ``latitude``/``longitude`` to ``lat``/``lon``; make latitude ascending; wrap negative
    longitudes to 0..360; crop with inclusive ``bbox`` slices; keep the last ``expver``
    if that dimension exists.

    The cropped array is ``.load()``-ed inside a ``with`` block so the GRIB file is closed
    before returning. No file handle outlives the call.

    Parameters
    ----------
    y : calendar year (file located via ``era5_path``).
    short : GRIB shortName used to filter messages (e.g. "2t").
    xr_name : name of the resulting xarray variable (e.g. "t2m").
    bbox : dict with lat_min/lat_max/lon_min/lon_max (as built by ``make_bbox``).

    Returns
    -------
    xr.DataArray held in memory, with ``lat``/``lon`` dimensions.

    Raises
    ------
    FileNotFoundError
        From ``era5_path`` when the year's file is missing.
    ValueError
        If the bbox crop leaves no lat or lon points.
    """
    f = era5_path(y)
    with xr.open_dataset(
        f, engine="cfgrib",
        backend_kwargs=dict(indexpath="", filter_by_keys={"shortName": short}),
        decode_timedelta=True,
    ) as raw:
        ds = raw.rename({"latitude": "lat", "longitude": "lon"})
        # slice(lat_min, lat_max) only selects on an ascending axis
        if ds.lat[0] > ds.lat[-1]:
            ds = ds.sortby("lat")
        # Wrap to 0..360 — assumes bbox longitudes use the same 0..360 convention (TODO confirm
        # for masks west of Greenwich; irrelevant for positive-longitude regions).
        if ds.lon.min() < 0:
            ds = ds.assign_coords(lon=((ds.lon % 360 + 360) % 360)).sortby("lon")
        ds = ds.sel(lat=slice(bbox["lat_min"], bbox["lat_max"]),
                    lon=slice(bbox["lon_min"], bbox["lon_max"]))
        da = ds[xr_name]
        if "expver" in da.dims:
            da = da.isel(expver=-1)
        # materialise while the file is still open
        da = da.load()
    if da.sizes.get("lat", 0) == 0 or da.sizes.get("lon", 0) == 0:
        raise ValueError(f"ERA5 '{short}' for {y}: bbox {bbox} selects no grid points in {f}")
    return da

# --- 0.1° → 0.5° block-average to GDHY centers (offset 0.25°) ---
def bin_to_half_degree(da: xr.DataArray, step=0.5, offset=0.25) -> xr.DataArray:
    """Block-average a fine regular lat/lon grid onto coarse cell centres.

    Coarse centres sit at ``offset + k*step``. The defaults give 0.5° cells centred on
    x.25/x.75, the GDHY centres. Each fine point is assigned to the cell whose half-open
    interval ``[centre - step/2, centre + step/2)`` contains it. Values are then averaged
    over lat and lon; xarray's default ``mean`` skips NaNs.

    ``np.round`` is deliberately not used to pick the cell. It rounds exact ties to the
    nearest even integer, so points lying exactly on cell edges (x.0 and x.5 on a 0.1°
    grid) would all be pulled into every other cell, giving alternating 6- and 4-point
    bins. With the half-open rule, a 0.1° grid aligned to multiples of 0.1° puts 5 points
    per axis in every cell.

    Returns
    -------
    xr.DataArray with ``lat``/``lon`` dims holding the coarse centres (ascending).
    """
    def _cell_centres(coord: np.ndarray) -> np.ndarray:
        # Cell index counted from the lower edge of cell 0. Rounding to 6 decimals absorbs
        # float noise (e.g. 45.49999999 → 45.5) before flooring.
        k = np.floor(np.round((coord - offset) / step + 0.5, 6))
        return (offset + step * k).astype(np.float64)

    da = da.assign_coords(lat_bin=("lat", _cell_centres(da["lat"].values)),
                          lon_bin=("lon", _cell_centres(da["lon"].values)))
    da_c = da.groupby("lat_bin").mean("lat").groupby("lon_bin").mean("lon")
    return da_c.rename({"lat_bin": "lat", "lon_bin": "lon"})

# --- select a given set of cells by lat/lon value (nearest) ---
def select_cells(da: xr.DataArray, latc: np.ndarray, lonc: np.ndarray) -> xr.DataArray:
    """Pick one grid point per requested (lat, lon) pair using nearest-neighbour lookup.

    ``latc``/``lonc`` are paired element-wise. The result gains a ``cell`` dimension
    numbered 0..n-1. Its ``lat``/``lon`` coordinates are overwritten with the requested
    centres rather than the matched grid values. No distance tolerance is applied.
    """
    pointwise = {
        "lat": xr.DataArray(latc, dims="cell"),
        "lon": xr.DataArray(lonc, dims="cell"),
    }
    picked = da.sel(pointwise, method="nearest")
    cell_ids = np.arange(len(latc))
    return picked.assign_coords(
        cell=("cell", cell_ids),
        lat=("cell", latc),
        lon=("cell", lonc),
    )

# --- wrap-around season helpers (Dec→Jan) ---
def season_wraps(months: list[int]) -> bool:
    """Report whether a season crosses the Dec→Jan boundary.

    ``months`` must be listed in season order. Any list that is not already ascending
    (e.g. ``[11, 12, 1, 2]``) is treated as wrapping.
    """
    ascending = sorted(months)
    return ascending != months

def crop_year_for_series(month: pd.Series, year: pd.Series, pivot_month: int) -> pd.Series:
    """Map calendar (month, year) pairs to harvest year.

    Months at or after ``pivot_month`` (e.g. Nov/Dec with pivot 11) are credited to the
    next year; all other months keep their calendar year.
    """
    shift = month.ge(pivot_month).astype(int)
    return year + shift


In [7]:
# ---- monthly reducers ----
def monthly_means(da_m05: xr.DataArray, months: list[int]) -> xr.DataArray:
    """Return only the time steps whose calendar month is in ``months``.

    Values are not modified. The name reflects that ``da_m05`` presumably already
    holds monthly-mean fields (verify against the source data).
    """
    in_months = da_m05["time"].dt.month.isin(months)
    return da_m05.where(in_months, drop=True)

def monthly_totals_from_daily_means(da_m05: xr.DataArray, months: list[int]) -> xr.DataArray:
    """Convert per-day monthly means into monthly totals for the selected months.

    Keeps time steps whose month is in ``months``. Each value is then multiplied by the
    number of days in its month (as float64).
    """
    kept = da_m05.where(da_m05["time"].dt.month.isin(months), drop=True)
    stamps = kept["time"]
    n_days = xr.DataArray(stamps.dt.days_in_month, coords={"time": stamps}, dims=["time"])
    return kept * n_days.astype(np.float64)

# ---- seasonal reducers (crop-year aware) ----
def seasonal_reduce(da_m05: xr.DataArray, months: list[int], kind: str) -> xr.DataArray:
    """Collapse monthly fields to one value per (crop) year.

    kind='sum'  : sum of monthly totals (per-day monthly mean × days in month).
    kind='mean' : days-weighted mean of the monthly means, sum(x·days) / sum(days) over
                  months where x is not NaN. This is the same weighting that
                  ``process_era5_var_cropyear`` uses for state variables (an unweighted
                  mean would over-weight short months such as February).

    For a wrapping season (``months`` not ascending), months >= ``months[0]`` are credited
    to the following year; otherwise the calendar year is used. The result has a
    ``year`` dimension.

    NOTE: partial seasons at the edges of the data (e.g. a first crop year lacking
    Nov/Dec) are not dropped here.

    Raises
    ------
    ValueError
        If ``kind`` is not 'mean' or 'sum'.
    """
    if kind not in ("mean", "sum"):
        raise ValueError("kind must be 'mean' or 'sum'")

    sel = monthly_means(da_m05, months)
    days = xr.DataArray(sel["time"].dt.days_in_month,
                        coords={"time": sel["time"]}, dims=["time"]).astype(np.float64)

    if season_wraps(months):
        pivot = months[0]
        group_key = xr.DataArray(sel["time"].dt.year + (sel["time"].dt.month >= pivot).astype(int),
                                 coords={"time": sel["time"]}, dims=["time"], name="year")
    else:
        group_key = "time.year"

    weighted = sel * days  # monthly totals for 'sum'; numerator for the weighted mean
    if kind == "sum":
        return weighted.groupby(group_key).sum("time")

    # Only count the days of months that actually have a value.
    valid_days = days.where(sel.notnull())
    return weighted.groupby(group_key).sum("time") / valid_days.groupby(group_key).sum("time")

# ---- process one ERA5 variable (returns seasonal_df, monthly_wide_df) ----
def _seasonal_from_monthly(m: pd.DataFrame, keys: list, out_col: str, vtype: str,
                           n_season: int) -> pd.DataFrame:
    """Collapse long monthly rows to one seasonal value per ``keys`` group.

    flux  → sum of monthly totals; state → days-weighted mean of monthly means.
    A group is set to NaN unless all ``n_season`` months have a non-NaN value. Without
    this, missing/NaN months would yield a silently truncated total, a biased mean
    (their days stay in the denominator), or 0 (pandas' sum of an all-NaN group).
    """
    has_val = m[out_col].notna()
    work = m.assign(_wx=m[out_col] * m["days"], _wd=m["days"].where(has_val))
    agg = work.groupby(keys, as_index=False).agg(
        _total=(out_col, "sum"),
        _num=("_wx", "sum"),
        _den=("_wd", "sum"),
        _n_ok=(out_col, "count"),
    )
    if vtype == "flux":
        agg[out_col] = agg["_total"]
    else:
        agg[out_col] = agg["_num"] / agg["_den"]
    agg.loc[agg["_n_ok"] < n_season, out_col] = np.nan
    return agg[keys + [out_col]].sort_values(keys).reset_index(drop=True)


def _monthly_wide(m: pd.DataFrame, keys: list, out_col: str, season_months: list[int],
                  month_names: dict, ssr_to_mj: bool) -> pd.DataFrame:
    """Pivot long monthly rows to one column per season month, e.g. ``temperature_Jan``."""
    mon_order = [month_names[mn] for mn in season_months]
    wide = (m.pivot_table(index=keys, columns="month_name", values=out_col, aggfunc="first")
              .reset_index())
    for mn in mon_order:
        if mn not in wide.columns:
            wide[mn] = np.nan
    wide = wide[keys + mon_order]
    wide = wide.rename(columns={mn: f"{out_col}_{mn}" for mn in mon_order})

    # Optional SSR copies in MJ/m² (value / 1e6), monthly columns only
    if out_col == "solar_radiation" and ssr_to_mj:
        for mn in mon_order:
            col = f"solar_radiation_{mn}"
            if col in wide.columns:
                wide[f"{col}_MJm2"] = wide[col] / 1e6
    return wide


def process_era5_var_cropyear(short: str, xr_name: str, out_col: str, vtype: str,
                              season_months: list[int], latc: np.ndarray, lonc: np.ndarray,
                              bbox: dict, years: list[int], month_names: dict, ssr_to_mj: bool=False):
    """
    Crop-year–aware processor for one ERA5 variable (safe for wrapping seasons like Nov→Jun).

    For every calendar year needed, the field is read (``open_era5_var_year``), unit-converted
    by ``out_col``, block-averaged to 0.5° (``bin_to_half_degree``), reduced to monthly
    values for ``season_months`` and sampled at the (latc, lonc) cells. The long monthly
    table is then aggregated to one seasonal value per cell and harvest year.

    Unit conversions (keyed on ``out_col``):
      temperature           : − 273.15 (presumably K → °C)
      precipitation         : × 1000 (m → mm, per-day mean)
      potential_evaporation : × −1000 (sign flipped, m → mm, per-day mean)

    vtype: 'state' → seasonal days-weighted mean of monthly means;
           'flux'  → seasonal sum of monthly totals (per-day mean × days in month).
    A cell-year whose season is missing any month (or has a NaN month) gets NaN rather than
    a partial value, so downstream NaN checks catch it.

    Returns: (seasonal_df, monthly_wide_df), keyed by lat, lon and harvest-year 'year'.

    Raises: ValueError if ``vtype`` is not 'state' or 'flux' (checked before any file I/O).
    """
    if vtype not in ("state", "flux"):
        raise ValueError("vtype must be 'state' or 'flux'")

    wraps = season_wraps(season_months)
    pivot = season_months[0] if wraps else None

    # Calendar years to open:
    # - non-wrap: exactly the requested years
    # - wrap: include the *year before* start so Nov–Dec for the first harvest year exist
    y0 = years[0] - 1 if wraps else years[0]
    cal_years = list(range(y0, years[-1] + 1))

    monthly_rows = []
    for y in cal_years:
        da = open_era5_var_year(y, short, xr_name, bbox)

        # unit conversions on monthly fields
        if out_col == "temperature":
            da = da - 273.15
        if out_col == "precipitation":
            da = da * 1000.0          # m -> mm (daily mean)
        if out_col == "potential_evaporation":
            da = -da * 1000.0         # flip sign, m -> mm (daily mean)

        da05 = bin_to_half_degree(da)

        # Monthly values for the season months (state: means, flux: totals)
        if vtype == "state":
            dam = monthly_means(da05, season_months)
        else:
            dam = monthly_totals_from_daily_means(da05, season_months)

        # Select analysis cells and go to a long table
        df = select_cells(dam, latc, lonc).to_dataframe(name=out_col).reset_index()

        t = pd.to_datetime(df["time"])
        df["month"] = t.dt.month
        df["month_name"] = df["month"].map(month_names)
        df["days"] = t.dt.days_in_month.astype(float)

        # Harvest (crop) year: months >= pivot belong to the next year in a wrapping season
        if wraps:
            df["year"] = t.dt.year + (df["month"] >= pivot).astype(int)
        else:
            df["year"] = t.dt.year

        df = df[df["month"].isin(season_months)]
        monthly_rows.append(df[["lat","lon","year","month","month_name","days",out_col]].copy())

        if (y - cal_years[0]) % 5 == 0:
            print(f"{out_col}: processed {y}", flush=True)

    # All monthly rows, restricted to harvest years in the requested range
    m = pd.concat(monthly_rows, ignore_index=True)
    m = m[(m["year"] >= years[0]) & (m["year"] <= years[-1])].copy()

    keys = ["lat","lon","year"]
    seasonal_df = _seasonal_from_monthly(m, keys, out_col, vtype, len(set(season_months)))
    wide = _monthly_wide(m, keys, out_col, season_months, month_names, ssr_to_mj)
    return seasonal_df, wide

# ---- build yield panel for a given crop dir and cell list ----
def build_yield_panel(gdhy_dir: str, crop_name: str, latc: np.ndarray, lonc: np.ndarray) -> pd.DataFrame:
    """Build a long (year, lat, lon, yield) table from per-year GDHY files.

    For every year in the module-level ``years`` list:
      1. the file is located with ``yield_path``;
      2. the variable ``GDHY_VARNAME`` (or the first data variable) is squeezed and
         transposed to (lat, lon);
      3. the value at each requested cell centre is read.
    Each NetCDF file is closed (``with`` block) before the next year is opened.

    Parameters
    ----------
    gdhy_dir : directory with the per-year yield files (str or Path).
    crop_name : used to name the output column ``yield_<crop_name>``.
    latc, lonc : cell-centre coordinates, paired element-wise.

    Returns
    -------
    DataFrame with columns year, lat, lon, yield_<crop_name> (float32),
    sorted by lat, lon, year.

    Raises
    ------
    ValueError
        If a file lacks lat/lon dims, or a requested centre does not match a grid
        coordinate (``np.isclose`` with atol=5e-4).
    """
    gdhy_dir = Path(gdhy_dir)
    ycol = f"yield_{crop_name}"

    def idx(val, axis_vals):
        # Nearest grid index, accepted only if it actually matches the requested centre.
        i = int(np.argmin(np.abs(axis_vals - val)))
        if not np.isclose(axis_vals[i], val, atol=5e-4):
            raise ValueError("Grid mismatch locating indices.")
        return i

    rows = []
    for y in years:
        f = yield_path(gdhy_dir, y)
        with xr.open_dataset(f) as ds:
            varname = GDHY_VARNAME or list(ds.data_vars)[0]
            da = ds[varname].squeeze()
            if not {"lat", "lon"}.issubset(da.dims):
                raise ValueError(f"{f} missing lat/lon dims; got {da.dims}")
            da = da.transpose("lat", "lon")
            lat_vals = da["lat"].values
            lon_vals = da["lon"].values
            ilat = np.array([idx(v, lat_vals) for v in latc])
            ilon = np.array([idx(v, lon_vals) for v in lonc])
            # paired fancy indexing → one value per (latc[k], lonc[k]); read before closing
            vals = da.values[ilat, ilon].astype("float32")
        rows.append(pd.DataFrame({"year": y, "lat": latc, "lon": lonc, ycol: vals}))
    return pd.concat(rows, ignore_index=True).sort_values(["lat","lon","year"]).reset_index(drop=True)

# ---- enforce balanced panel by yield (keeps only cells with all years present) ----
def balance_by_yield(yield_df: pd.DataFrame, yield_col: str) -> tuple[pd.DataFrame, np.ndarray, np.ndarray]:
    """Keep only cells that have a non-missing yield in every analysis year.

    A cell qualifies when its count of non-NaN ``yield_col`` values equals ``len(years)``
    (module-level ``years``).

    Returns
    -------
    (balanced_df, latc, lonc)
        balanced_df is sorted by lat, lon, year. latc/lonc are the kept cell centres in
        (lat, lon) sort order, one entry per cell.
    """
    n_years = len(years)
    has_yield = yield_df[yield_col].notna()
    years_with_yield = has_yield.groupby([yield_df["lat"], yield_df["lon"]]).sum()
    complete_cells = years_with_yield[years_with_yield == n_years].index.to_list()
    keep = pd.DataFrame(complete_cells, columns=["lat", "lon"])

    balanced = yield_df.merge(keep, on=["lat", "lon"], how="inner")
    balanced = balanced.sort_values(["lat", "lon", "year"]).reset_index(drop=True)
    return balanced, keep["lat"].to_numpy(), keep["lon"].to_numpy()


In [8]:
# ===== WINTER WHEAT =====
crop_name = "wheat_winter"
YIELD_COL = f"yield_{crop_name}"
latc = LATC_ALL.copy(); lonc = LONC_ALL.copy()

# 1) Yield panel on all mask cells, then keep only cells with a yield in every year.
yield_df = build_yield_panel(GDHY_WINTER_DIR, crop_name, latc, lonc)
yield_df, latc, lonc = balance_by_yield(yield_df, YIELD_COL)
# The ERA5 crop box is built from the cells that survived balancing.
bbox = make_bbox(latc, lonc)
# balance_by_yield returns one (lat, lon) pair per cell, so len(latc) is the cell count.
print(f"Winter wheat: kept {len(latc)} cells after balancing.")

# 2) ERA5 variables: (GRIB shortName, xarray variable, output column, state|flux)
vars_cfg = [
    ("2t","t2m","temperature","state"),
    ("tp","tp","precipitation","flux"),
    ("swvl1","swvl1","soil_water","state"),
    ("ssr","ssr","solar_radiation","flux"),
    ("pev","pev","potential_evaporation","flux"),
]
seasonal_tables = []; monthly_tables = []
for short, xr_name, out_col, vtype in vars_cfg:
    seas, mon = process_era5_var_cropyear(short, xr_name, out_col, vtype,
                                          winter_months, latc, lonc, bbox, years, month_names, SSR_TO_MJ)
    seasonal_tables.append(seas); monthly_tables.append(mon)

# 3) Merge seasonal into base & save.
# validate="one_to_one" raises instead of silently duplicating rows if a table repeats a key.
base = yield_df.copy()
for tbl in seasonal_tables:
    base = base.merge(tbl, on=["lat","lon","year"], how="inner", validate="one_to_one")
assert len(base) == len(yield_df), "seasonal merge dropped rows (missing ERA5 cell-years?)"
assert base.isna().sum().sum() == 0, "NaNs in seasonal table (incomplete season or missing data)"

out_seasonal = DERIVED / f"{crop_name}_{region_tag}_{start_year}_{end_year}_allstressors.csv"
base.to_csv(out_seasonal, index=False)
print("Saved →", out_seasonal)

# 4) Add monthly & save
enriched = base.copy()
for mon in monthly_tables:
    enriched = enriched.merge(mon, on=["lat","lon","year"], how="inner", validate="one_to_one")
assert len(enriched) == len(base), "monthly merge dropped rows"
assert enriched.isna().sum().sum() == 0, "NaNs in monthly columns"

out_monthly = DERIVED / f"{crop_name}_{region_tag}_{start_year}_{end_year}_allstressors_with_monthly.csv"
enriched.to_csv(out_monthly, index=False)
print("Saved →", out_monthly)


Winter wheat: kept 41 cells after balancing.
temperature: processed 1981
temperature: processed 1986
temperature: processed 1991
temperature: processed 1996
temperature: processed 2001
temperature: processed 2006
temperature: processed 2011
temperature: processed 2016
precipitation: processed 1981
precipitation: processed 1986
precipitation: processed 1991
precipitation: processed 1996
precipitation: processed 2001
precipitation: processed 2006
precipitation: processed 2011
precipitation: processed 2016
soil_water: processed 1981
soil_water: processed 1986
soil_water: processed 1991
soil_water: processed 1996
soil_water: processed 2001
soil_water: processed 2006
soil_water: processed 2011
soil_water: processed 2016
solar_radiation: processed 1981
solar_radiation: processed 1986
solar_radiation: processed 1991
solar_radiation: processed 1996
solar_radiation: processed 2001
solar_radiation: processed 2006
solar_radiation: processed 2011
solar_radiation: processed 2016
potential_evaporati

In [9]:
# ===== SPRING WHEAT =====
crop_name = "wheat_spring"
YIELD_COL = f"yield_{crop_name}"
latc = LATC_ALL.copy(); lonc = LONC_ALL.copy()

# 1) Yield panel on all mask cells, then keep only cells with a yield in every year.
yield_df = build_yield_panel(GDHY_SPRING_DIR, crop_name, latc, lonc)
yield_df, latc, lonc = balance_by_yield(yield_df, YIELD_COL)
# The ERA5 crop box is built from the cells that survived balancing.
bbox = make_bbox(latc, lonc)
# balance_by_yield returns one (lat, lon) pair per cell, so len(latc) is the cell count.
print(f"Spring wheat: kept {len(latc)} cells after balancing.")

# 2) ERA5 variables: (GRIB shortName, xarray variable, output column, state|flux)
vars_cfg = [
    ("2t","t2m","temperature","state"),
    ("tp","tp","precipitation","flux"),
    ("swvl1","swvl1","soil_water","state"),
    ("ssr","ssr","solar_radiation","flux"),
    ("pev","pev","potential_evaporation","flux"),
]
seasonal_tables = []; monthly_tables = []
for short, xr_name, out_col, vtype in vars_cfg:
    seas, mon = process_era5_var_cropyear(short, xr_name, out_col, vtype,
                                          spring_months, latc, lonc, bbox, years, month_names, SSR_TO_MJ)
    seasonal_tables.append(seas); monthly_tables.append(mon)

# 3) Merge seasonal into base & save.
# validate="one_to_one" raises instead of silently duplicating rows if a table repeats a key.
base = yield_df.copy()
for tbl in seasonal_tables:
    base = base.merge(tbl, on=["lat","lon","year"], how="inner", validate="one_to_one")
assert len(base) == len(yield_df), "seasonal merge dropped rows (missing ERA5 cell-years?)"
assert base.isna().sum().sum() == 0, "NaNs in seasonal table (incomplete season or missing data)"

out_seasonal = DERIVED / f"{crop_name}_{region_tag}_{start_year}_{end_year}_allstressors.csv"
base.to_csv(out_seasonal, index=False)
print("Saved →", out_seasonal)

# 4) Add monthly & save
enriched = base.copy()
for mon in monthly_tables:
    enriched = enriched.merge(mon, on=["lat","lon","year"], how="inner", validate="one_to_one")
assert len(enriched) == len(base), "monthly merge dropped rows"
assert enriched.isna().sum().sum() == 0, "NaNs in monthly columns"

out_monthly = DERIVED / f"{crop_name}_{region_tag}_{start_year}_{end_year}_allstressors_with_monthly.csv"
enriched.to_csv(out_monthly, index=False)
print("Saved →", out_monthly)


Spring wheat: kept 41 cells after balancing.
temperature: processed 1982
temperature: processed 1987
temperature: processed 1992
temperature: processed 1997
temperature: processed 2002
temperature: processed 2007
temperature: processed 2012
precipitation: processed 1982
precipitation: processed 1987
precipitation: processed 1992
precipitation: processed 1997
precipitation: processed 2002
precipitation: processed 2007
precipitation: processed 2012
soil_water: processed 1982
soil_water: processed 1987
soil_water: processed 1992
soil_water: processed 1997
soil_water: processed 2002
soil_water: processed 2007
soil_water: processed 2012
solar_radiation: processed 1982
solar_radiation: processed 1987
solar_radiation: processed 1992
solar_radiation: processed 1997
solar_radiation: processed 2002
solar_radiation: processed 2007
solar_radiation: processed 2012
potential_evaporation: processed 1982
potential_evaporation: processed 1987
potential_evaporation: processed 1992
potential_evaporation:

In [11]:
# === STACK WINTER + SPRING WHEAT INTO ONE LONG FILE (TYPE-LABELED) ===
from pathlib import Path
import pandas as pd
import numpy as np

# ---- EDIT THESE 3 LINES ----
# NOTE(review): these re-declare DERIVED / region_tag / start_year / end_year and override the
# values from the first config cell. Keep them in sync, or the input filenames below will not
# match the per-type exports.
DERIVED    = Path(r"..\italy_core_data\derived")
region_tag = "ITnorth_core41"   # use the same tag you used for the per-type exports
start_year, end_year = 1982, 2016
# ----------------------------

# Inputs: the two per-type "with_monthly" exports; output: one stacked long table.
winter_csv = DERIVED / f"wheat_winter_{region_tag}_{start_year}_{end_year}_allstressors_with_monthly.csv"
spring_csv = DERIVED / f"wheat_spring_{region_tag}_{start_year}_{end_year}_allstressors_with_monthly.csv"
out_csv    = DERIVED / f"wheat_both_{region_tag}_{start_year}_{end_year}_allstressors_with_monthly_long.csv"

# ---- helpers ----
def load_and_tag(p: Path, wheat_type: str) -> pd.DataFrame:
    """Read a per-type CSV, add a ``wheat_type`` label and unify the yield column name.

    If the file has no ``yield_wheat`` column, the first ``yield_*`` column is copied to
    ``yield_wheat`` (appended last) and all original ``yield_*`` columns are dropped.
    If ``yield_wheat`` already exists, the frame is returned with only the label added.

    Raises
    ------
    ValueError
        If the file contains no ``yield_*`` column at all.
    """
    frame = pd.read_csv(p)
    frame["wheat_type"] = wheat_type
    yield_like = [col for col in frame.columns if col.startswith("yield_")]
    if len(yield_like) == 0:
        raise ValueError(f"No yield_* column found in {p.name}")
    if "yield_wheat" in frame.columns:
        return frame
    frame["yield_wheat"] = frame[yield_like[0]]
    return frame.drop(columns=[col for col in yield_like if col != "yield_wheat"])

# Load both per-type exports; each gets a wheat_type label and a unified yield_wheat column.
dw = load_and_tag(winter_csv, "winter")
ds = load_and_tag(spring_csv, "spring")

# Row identity, plus the seasonal columns both inputs must carry.
key_cols = ["lat","lon","year","wheat_type"]
seasonal_cols = ["yield_wheat","temperature","precipitation","soil_water","solar_radiation","potential_evaporation"]
# Report the first gap, scanning column by column (winter before spring).
first_missing = next(
    (
        (req, name)
        for req in seasonal_cols
        for d, name in [(dw, winter_csv.name), (ds, spring_csv.name)]
        if req not in d.columns
    ),
    None,
)
if first_missing is not None:
    raise ValueError(f"Missing seasonal column '{first_missing[0]}' in {first_missing[1]}")

# ---- robust monthly union (no string parsing) ----
# Calendar-ordered month suffixes and the ERA5 stressor bases used in monthly column names.
mon_names = ["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]
base_vars = ["temperature","precipitation","soil_water","solar_radiation","potential_evaporation"]

def find_monthly_cols(df: pd.DataFrame):
    """List the ``<base>_<Mon>`` columns present in ``df``.

    Ordered by base variable, then calendar month. Other columns are ignored, including
    any ``*_MJm2`` copies.
    """
    expected = (f"{b}_{m}" for b in base_vars for m in mon_names)
    return [col for col in expected if col in df.columns]

# ---- union of monthly columns across both inputs, in base-variable × calendar-month order ----
cols_all = set(find_monthly_cols(dw)).union(find_monthly_cols(ds))
monthly_union = [
    col
    for col in (f"{b}_{m}" for b in base_vars for m in mon_names)
    if col in cols_all
]

# Give both frames the full monthly schema; months a wheat type does not cover become NaN.
for frame in (dw, ds):
    for col in monthly_union:
        if col not in frame.columns:
            frame[col] = np.nan

# Identical column order in both frames; rows sorted by cell then year.
ordered_cols = key_cols + seasonal_cols + monthly_union
dw, ds = [
    frame[ordered_cols].sort_values(["lat","lon","year"]).reset_index(drop=True)
    for frame in (dw, ds)
]

# Stack into one long table (one row per cell, year and wheat type).
stacked = pd.concat([dw, ds], ignore_index=True)
combined = stacked.sort_values(["lat","lon","year","wheat_type"]).reset_index(drop=True)

# sanity checks
dups = combined.duplicated(["lat","lon","year","wheat_type"]).sum()
assert dups == 0, f"Found duplicated keys: {dups}"
assert "yield_wheat" in combined.columns

# save
combined.to_csv(out_csv, index=False)
print("Saved stacked long wheat dataset →", out_csv)
print("Shape:", combined.shape)
print("Columns (first 20):", combined.columns[:20].tolist())


Saved stacked long wheat dataset → ..\italy_core_data\derived\wheat_both_ITnorth_core41_1982_2016_allstressors_with_monthly_long.csv
Shape: (2870, 55)
Columns (first 20): ['lat', 'lon', 'year', 'wheat_type', 'yield_wheat', 'temperature', 'precipitation', 'soil_water', 'solar_radiation', 'potential_evaporation', 'temperature_Jan', 'temperature_Feb', 'temperature_Mar', 'temperature_Apr', 'temperature_May', 'temperature_Jun', 'temperature_Jul', 'temperature_Nov', 'temperature_Dec', 'precipitation_Jan']
