# ETCCDI **pr** Indices Driver

Run cells **top-to-bottom**. Steps:
1. Parameters
2. Helpers (Step 2.5 is an optional hotfix; run it only if Step 4 fails)
3. Locate & load data
4. Compute indices
5. Save to NetCDF (one file per index)


In [None]:

# === Step 1: Parameters ===
# User-editable run configuration. The later cells read these names directly.
# ROOT is the directory tree searched for input NetCDF files.
ROOT = "/home/jovyan/shared/NEX-GDDP-CMIP6/"
# MODEL, SCENARIO and MEMBER are used both as path components
# (ROOT/MODEL/SCENARIO/MEMBER/<var>) and as tokens matched inside file names.
MODEL = "GISS-E2-1-G"
SCENARIO = "historical"   # "historical","ssp126","ssp245","ssp370","ssp585"
MEMBER = "r1i1p1f2"
# Inclusive year range; validated as 1950-2014 for "historical" and
# 2015-2100 for every other scenario.
START_YEAR = 1985
END_YEAR   = 2014
SEASON = "annual"         # "annual","MAM","JJA","SON","DJF"
# Output directory; created in Step 5 if it does not exist.
OUTDIR = "./etccdi_out"


In [2]:

# === Step 2: Helpers ===
from pathlib import Path
import re, warnings
import xarray as xr
import numpy as np

try:
    from tqdm.auto import tqdm
except Exception:
    def tqdm(x, **kwargs):
        """Stand-in used when tqdm cannot be imported: return *x* unchanged.

        Keyword arguments (``desc``, ``unit``, ...) are accepted and ignored so
        call sites do not need to care whether tqdm is available.
        """
        return x

# Accepted values for the season setting: the whole year or one three-month season.
SEASONS = ("annual","MAM","JJA","SON","DJF")

class Args:
    """Plain container for one run's configuration.

    Attributes mirror the constructor arguments: ``model``, ``scenario``,
    ``member``, ``start_year``, ``end_year`` and ``season``.
    """

    def __init__(self, model, scenario, member, start_year, end_year, season):
        # What to process.
        self.model, self.scenario, self.member = model, scenario, member
        # Time window and season selection.
        self.start_year, self.end_year = start_year, end_year
        self.season = season

def _validate_args(args, available_scenarios=("historical","ssp126","ssp245","ssp370","ssp585")):
    """Validate a run configuration, raising ``ValueError`` on the first problem.

    Args:
        args: Object exposing ``scenario``, ``start_year``, ``end_year`` and
            ``season`` attributes.
        available_scenarios: Scenario names accepted for ``args.scenario``.

    Raises:
        ValueError: If the scenario is not in ``available_scenarios``, the year
            range falls outside the scenario's span (1950-2014 for
            "historical", 2015-2100 otherwise), the season is not in
            ``SEASONS``, or ``end_year`` precedes ``start_year``.
    """
    # Reject unknown scenarios up front; otherwise a typo such as "ssp254"
    # would silently be treated as a valid future scenario below.
    if args.scenario not in available_scenarios:
        raise ValueError(f"scenario must be one of {tuple(available_scenarios)}")
    if args.scenario == "historical":
        if args.start_year < 1950 or args.end_year > 2014:
            raise ValueError("Historical is only valid for 1950–2014.")
    else:
        if args.start_year < 2015 or args.end_year > 2100:
            raise ValueError("Future SSP scenarios are only valid for 2015–2100.")
    if args.season not in SEASONS:
        raise ValueError(f"season must be one of {SEASONS}")
    if args.end_year < args.start_year:
        raise ValueError("end_year must be >= start_year")

def _suppress_nan_warnings():
    """Install "ignore" filters for two RuntimeWarning messages.

    The filters are added to the process-wide warning filter list and match
    warnings whose text starts with "All-NaN slice encountered" or
    "Mean of empty slice".
    """
    for text in ("All-NaN slice encountered", "Mean of empty slice"):
        warnings.filterwarnings("ignore", message=text, category=RuntimeWarning)

def _year_from_name(name: str):
    """Return the four-digit year found between ``_`` and ``_v`` in *name*.

    Returns ``None`` when *name* contains no ``_YYYY_v`` fragment.
    NOTE(review): file names that end in ``_YYYY.nc`` without a version suffix
    are therefore not recognised - confirm this matches the data on disk.
    """
    found = re.search(r'_(\d{4})_v', name)
    if found is None:
        return None
    return int(found.group(1))

def _select_years(files, start_year, end_year):
    """Return the sorted ``Path`` objects whose file-name year lies in the range.

    Both bounds are inclusive. Entries whose name yields no year (see
    ``_year_from_name``) are dropped.
    """
    def _in_range(path):
        year = _year_from_name(path.name)
        return year is not None and start_year <= year <= end_year

    candidates = (Path(entry) for entry in files)
    return sorted(filter(_in_range, candidates))

def _scan_standard(root: str, model: str, scenario: str, member: str, varname: str):
    """List ``*.nc`` files directly inside ``root/model/scenario/member/varname``, sorted."""
    variable_dir = Path(root).joinpath(model, scenario, member, varname)
    matches = list(variable_dir.glob("*.nc"))
    matches.sort()
    return matches

def _scan_recursive(root: str, model: str, scenario: str, member: str, varname: str):
    """Search *root* recursively for ``*.nc`` files belonging to one run.

    A file qualifies when its name starts with ``<varname>_day_`` and contains
    ``_<model>_``, ``_<scenario>_`` and ``_<member>_``. Returns a sorted list
    of paths.
    """
    prefix = f"{varname}_day_"
    required = (f"_{model}_", f"_{scenario}_", f"_{member}_")
    hits = [
        path
        for path in Path(root).rglob("*.nc")
        if path.name.startswith(prefix) and all(token in path.name for token in required)
    ]
    return sorted(hits)

def _build_file_list_flexible(root: str, model: str, scenario: str, member: str, varname: str,
                              start_year: int, end_year: int):
    """Locate the input files for one run, trying progressively wider searches.

    Search order:
      1. ``root/model/scenario/member/varname`` (non-recursive),
      2. a recursive scan of ``root/model/scenario``,
      3. a recursive scan of ``root``.
    The first search that yields at least one file in the year range wins.

    Returns:
        ``(files, info)`` where ``files`` is a sorted list of paths and ``info``
        is a dict with keys ``"method"`` ("standard" or "recursive") and
        ``"looked_in"`` (the directory that produced the hits).

    Raises:
        FileNotFoundError: If none of the searches finds a matching file.
    """
    def _recursive_hits(search_root):
        found = _scan_recursive(search_root, model, scenario, member, varname)
        return _select_years(found, start_year, end_year)

    scenario_dir = Path(root) / model / scenario
    variable_dir = scenario_dir / member / varname

    # 1) Conventional directory layout.
    if variable_dir.exists():
        hits = _select_years(sorted(variable_dir.glob("*.nc")), start_year, end_year)
        if hits:
            return hits, {"method": "standard", "looked_in": str(variable_dir)}

    # 2) Anywhere below the scenario directory.
    if scenario_dir.exists():
        hits = _recursive_hits(scenario_dir)
        if hits:
            return hits, {"method": "recursive", "looked_in": str(scenario_dir)}

    # 3) Last resort: anywhere below the root.
    hits = _recursive_hits(root)
    if hits:
        return hits, {"method": "recursive", "looked_in": str(root)}

    raise FileNotFoundError(f"No files found for {model}/{scenario}/{member}/{varname} {start_year}-{end_year}")

def _open_dataset(files, varname: str):
    """Open *files* as one multi-file dataset and return one variable, sorted by time.

    If *varname* is not present under that exact name, the first data variable
    whose name matches case-insensitively is used instead.

    Raises:
        KeyError: If no data variable matches *varname*.
    """
    paths = [str(f) for f in files]
    ds = xr.open_mfdataset(paths, combine="by_coords", decode_times=True)

    if varname in ds:
        chosen = varname
    else:
        wanted = varname.lower()
        matches = [k for k in ds.data_vars if k.lower()==wanted]
        if not matches:
            raise KeyError(f"Variable '{varname}' not in dataset. Available: {list(ds.data_vars)}")
        chosen = matches[0]

    return ds[chosen].sortby("time")

def _subset_season(da: xr.DataArray, season: str) -> xr.DataArray:
    """Keep only the time steps that fall in *season*.

    "annual" returns *da* untouched. For "DJF" a ``season_year`` coordinate is
    attached along ``time`` in which December is labelled with the following
    calendar year. Any other unknown *season* raises ``KeyError``.
    """
    if season == "annual":
        return da

    season_months = {"MAM":[3,4,5],"JJA":[6,7,8],"SON":[9,10,11],"DJF":[12,1,2]}[season]
    keep = da["time"].dt.month.isin(season_months)
    subset = da.sel(time=keep)
    if season != "DJF":
        return subset

    # December belongs to the winter that ends in the next calendar year.
    when = subset["time"].dt
    winter_year = xr.where(when.month==12, when.year+1, when.year)
    return subset.assign_coords(season_year=("time", winter_year.values))

def _maybe_mm_per_day(pr: xr.DataArray) -> xr.DataArray:
    """Convert precipitation to ``mm day-1`` when its ``units`` attribute is recognised.

    * Flux units ("kg m-2 s-1", "kg/m^2/s", "kg m**-2 s**-1") are multiplied
      by 86400.
    * Metres per day ("m day-1", "m/day") are multiplied by 1000.
    * Anything else - including data already in "mm day-1" / "mm/day" - is
      returned unchanged.

    The input object is never modified; converted results are new objects
    whose ``units`` attribute is set to "mm day-1".
    """
    units = str(pr.attrs.get("units", "")).strip().lower()

    if "kg m-2 s-1" in units or units == "kg/m^2/s" or "kg m**-2 s**-1" in units:
        factor = 86400.0
    # The look-behind stops "mm day-1", "mm/day" or "cm day-1" from matching:
    # a plain substring test would rescale data that is already in mm/day.
    elif re.search(r"(?<![a-z])m(?: day-1|/day)", units):
        factor = 1000.0
    else:
        return pr

    converted = pr * factor
    converted.attrs["units"] = "mm day-1"
    return converted


In [3]:

# === Step 2.5 (optional): Hotfix for assign_coords(doy=...) error in ETCCDI_pr_indices ===
# Run this only if Step 4 errors with "Using a DataArray object to construct a variable is ambiguous"
# WARNING: this cell edits the source file of the imported ETCCDI_pr_indices
# module on disk (not just the in-memory copy) and then reloads the module.
try:
    import importlib, re, pathlib, ETCCDI_pr_indices as _et
    # Locate and read the module's own source file.
    _p = pathlib.Path(_et.__file__)
    _s = _p.read_text()
    # Rewrite `assign_coords(doy=("time", doy))` so the coordinate is built
    # from `doy.values` instead of from `doy` itself.
    _new = re.sub(r'assign_coords\(\s*doy=\(\s*"time"\s*,\s*doy\s*\)\s*\)',
                  'assign_coords(doy=(\"time\", doy.values))', _s)
    if _new != _s:
        # Something was replaced: save the patched source and reload the module.
        _p.write_text(_new); importlib.reload(_et)
        print("Patched ETCCDI_pr_indices for assign_coords(doy=...)")
    else:
        print("Hotfix not needed.")
except Exception as _e:
    # Best effort: any failure (module missing, file not writable, ...) is
    # reported and the notebook carries on.
    print("Hotfix check skipped:", _e)


Hotfix not needed.


In [4]:

# === Step 3: Locate & load (pr) ===
# Bundle and sanity-check the Step 1 parameters before touching the disk.
args = Args(
    model=MODEL,
    scenario=SCENARIO,
    member=MEMBER,
    start_year=START_YEAR,
    end_year=END_YEAR,
    season=SEASON,
)
_validate_args(args)

# Find the yearly pr files; `info` records which search strategy succeeded.
files, info = _build_file_list_flexible(
    ROOT, MODEL, SCENARIO, MEMBER, "pr", START_YEAR, END_YEAR
)
print("Found {} files via {} in {}".format(len(files), info["method"], info["looked_in"]))
print("Example:", files[0].name if files else "None")

# Open everything as one DataArray, then restrict it to the requested season.
da = _open_dataset(files, "pr")
da = _subset_season(da, SEASON)
print(f"Loaded pr shape: {tuple(da.shape)} time span: {str(da['time'].values[0])[:10]} -> {str(da['time'].values[-1])[:10]}")


FileNotFoundError: No files found for GISS-E2-1-G/historical/r1i1p1f2/pr 1985-2014

In [None]:

# === Step 4: Compute ETCCDI pr indices ===
_suppress_nan_warnings()
import ETCCDI_pr_indices as et
import xarray as xr

# Auto-load if Step 3 wasn't run
try:
    da
except NameError:
    args = Args(MODEL, SCENARIO, MEMBER, START_YEAR, END_YEAR, SEASON)
    _validate_args(args)
    files, info = _build_file_list_flexible(ROOT, MODEL, SCENARIO, MEMBER, "pr", START_YEAR, END_YEAR)
    print(f"Auto-loaded {len(files)} files via {info['method']} in {info['looked_in']}")
    da = _open_dataset(files, "pr")
    da = _subset_season(da, SEASON)

# Work in mm/day from here on (no-op if the units are not recognised).
pr = _maybe_mm_per_day(da)

# One (name, thunk) pair per index, in output order. Each thunk looks up the
# function of the same name in ETCCDI_pr_indices when it is called.
# R95pTOT and R99pTOT are deliberately left out for now; add their names to the
# tuple below to re-enable them.
# NOTE(review): every index is computed with period="annual" regardless of
# SEASON; the input was already season-subset in Step 3, and the DJF
# `season_year` coordinate is not passed on here - confirm this is intended.
tasks = [
    (_idx, lambda _idx=_idx: getattr(et, _idx)(pr, period="annual"))
    for _idx in (
        "Rx1day",
        "Rx5day",
        "SDII",
        "R10mm",
        "R20mm",
        "PRCPTOT",
        "CDD",
        "CWD",
    )
]

out = {}
for name, fn in tqdm(tasks, desc="ETCCDI pr", unit="idx"):
    print(f"[pr] computing {name} ...")
    out[name] = fn()

# Align time axes
aligned = xr.align(*out.values(), join='outer')
out = dict(zip(out, aligned))
ds = xr.Dataset(out)

# Mask PRCPTOT/R10mm/R20mm at locations with all-NaN input over time
obs_mask = da.where(np.isfinite(da)).count('time') > 0
for key in ['PRCPTOT','R10mm','R20mm']:
    if key in ds.data_vars:
        ds[key] = ds[key].where(obs_mask)



In [None]:

# === Step 5: Save to NetCDF (simple) ===
from pathlib import Path as _P

# Base output name; make sure its directory exists.
outfile = _P(OUTDIR).joinpath(
    f"ETCCDI_pr_{MODEL}_{SCENARIO}_{SEASON}_{START_YEAR}-{END_YEAR}.nc"
)
_P(OUTDIR).mkdir(parents=True, exist_ok=True)

# Give every index a long_name, defaulting to its own variable name.
for v in ds.data_vars:
    if "long_name" not in ds[v].attrs:
        ds[v].attrs["long_name"] = v

DISCLAIMER_TEXT = " ".join([
    "ALL INDICES PRODUCED BY THESE SCRIPTS/MODULES ARE EXPERIMENTAL.",
    "USERS MUST REVIEW THE CODE AND VALIDATE RESULTS AGAINST KNOWN REFERENCES",
    "BEFORE ANY OPERATIONAL OR DECISION SUPPORT USE. THERE MAY BE ERRORS.",
])

# Record the run configuration as global attributes.
ds.attrs.update(
    model=MODEL,
    scenario=SCENARIO,
    member=MEMBER,
    period=f"{START_YEAR}-{END_YEAR}",
    season=SEASON,
    disclaimer=DISCLAIMER_TEXT,
)

import warnings, gc, numpy as np
import netCDF4 as nc

# Grab the three coordinate arrays once; they are reused for every output file.
_time, _lat, _lon = (ds.coords[_axis] for _axis in ("time", "lat", "lon"))

def _encode_time_days(time_index):
    """Return ``(values, units, calendar)`` encoding *time_index* as numbers.

    Three strategies are tried in order:

    1. Values convertible to ``datetime64[ns]`` become fractional days since
       1970-01-01 with calendar "standard".
    2. Values that carry a ``calendar`` attribute (cftime-style dates, which
       the ``datetime64`` cast rejects) are encoded with ``netCDF4.date2num``
       in their own calendar, so the dates are kept rather than discarded.
    3. Otherwise the positions ``0..size-1`` are returned with units "index"
       and calendar "none".

    *time_index* must expose ``.values`` and ``.size``. ``values`` is always a
    float64 array.
    """
    units = "days since 1970-01-01 00:00:00"

    try:
        t64 = np.array(time_index.values).astype("datetime64[ns]")
        base = np.datetime64("1970-01-01T00:00:00")
        days = (t64 - base) / np.timedelta64(1, "D")
        return np.asarray(days, dtype="float64"), units, "standard"
    except Exception:
        # Not datetime64-compatible; fall through to the calendar-aware path.
        pass

    try:
        stamps = list(np.asarray(time_index.values).ravel())
        calendar = stamps[0].calendar
        days = nc.date2num(stamps, units=units, calendar=calendar)
        return np.asarray(days, dtype="float64"), units, str(calendar)
    except Exception:
        # Fallback: index units
        return np.arange(time_index.size, dtype="float64"), "index", "none"

# One file is written per index, named <outfile stem>_<var>.nc next to
# `outfile`; `outfile` itself is never written.
_base = outfile

# The time axis is shared by every variable, so encode it once up front.
tvals, tunits, tcal = _encode_time_days(_time)

# The whole write loop runs inside catch_warnings so that the RuntimeWarning
# filter actually covers the casts and writes and is removed again afterwards.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    for _v in ds.data_vars:
        _da = ds[_v].astype("float32")
        # enforce canonical ordering for safety
        try:
            _da = _da.transpose("time", "lat", "lon")
        except Exception as _err:
            # Best effort: carry on with the existing order, but say so, since
            # the file layout below assumes (time, lat, lon).
            print(f"Could not reorder {_v} to (time, lat, lon): {_err}")
        _dims = tuple(_da.dims)
        _shape = tuple(_da.shape)

        _out_path = _base.with_name(_base.stem + f"_{_v}.nc")

        # Write with netCDF4 (no compression, very fast)
        with nc.Dataset(str(_out_path), "w", format="NETCDF4") as _nc:
            # Dimensions
            _nc.createDimension("time", size=_da.sizes.get("time", len(tvals)))
            _nc.createDimension("lat", size=_da.sizes.get("lat", len(_lat)))
            _nc.createDimension("lon", size=_da.sizes.get("lon", len(_lon)))

            # Coordinate variables
            tv = _nc.createVariable("time", "f8", ("time",))
            yv = _nc.createVariable("lat", "f4", ("lat",))
            xv = _nc.createVariable("lon", "f4", ("lon",))
            tv[:] = tvals
            yv[:] = np.asarray(_lat.values, dtype="float32")
            xv[:] = np.asarray(_lon.values, dtype="float32")
            tv.units = tunits
            tv.calendar = tcal
            # Reuse the source units when present, else the conventional defaults.
            yv.units = _lat.attrs.get("units", "degrees_north")
            xv.units = _lon.attrs.get("units", "degrees_east")

            # Data variable
            var = _nc.createVariable(_v, "f4", ("time", "lat", "lon"), zlib=False)
            var[:] = np.asarray(_da.values, dtype="float32")
            var.long_name = ds[_v].attrs.get("long_name", _v)
            if "units" in ds[_v].attrs:
                var.units = ds[_v].attrs["units"]

            # Global attributes (best effort: one that cannot be stored is skipped)
            for _k, _val in ds.attrs.items():
                try:
                    setattr(_nc, _k, str(_val))
                except Exception:
                    pass

        print(f"Wrote {_out_path.name} | dims={_dims} shape={_shape}")
        # Release the float32 copy before the next variable is materialised.
        del _da
        gc.collect()
