In [None]:
# JUPYTER-ONLY FEATURE BUILDER (no argparse, no CLI)
# It won't ask for --start/--end. Just run this cell.

import os, re, glob, warnings, gc
import numpy as np
import xarray as xr

# Directory that holds the monthly input files (era_YYYY_MM_*.nc).
DATA_DIR = "./data"
# Directory that receives one features_YYYY_MM.nc file per processed month.
OUT_DIR = "./features_safe"
# Create the output directory up front; a no-op when it already exists.
os.makedirs(OUT_DIR, exist_ok=True)

# Chunk sizes passed to xarray when an input file is opened.
CHUNKS = dict(time=1, latitude=200, longitude=200)

def _open_nc(path):
    """Open the NetCDF file at *path* with the module-wide chunking.

    The dataset is opened with ``chunks=CHUNKS`` and ``cache=False``; the
    caller is responsible for closing it.
    """
    open_kwargs = {'chunks': CHUNKS, 'cache': False}
    return xr.open_dataset(path, **open_kwargs)

def _infer_level_from_name(path):
    m = re.search(r'_(\d{3})\.nc$', os.path.basename(path))
    return int(m.group(1)) if m else None

def _normalize_vars(ds):
    """Rename long variable names in *ds* to their short forms.

    Only long names that are actually present among the data variables are
    renamed. When none are present the dataset is returned unchanged.
    """
    long_to_short = (
        ('u_component_of_wind', 'u'),
        ('v_component_of_wind', 'v'),
        ('temperature', 't'),
        ('relative_humidity', 'r'),
        ('vorticity', 'vo'),
    )
    present = ds.data_vars
    renames = {
        long_name: short_name
        for long_name, short_name in long_to_short
        if long_name in present
    }
    return ds.rename(renames) if renames else ds

def _ensure_level_dim(ds, lvl_from_name):
    """Return *ds* with a ``level`` coordinate whenever one can be derived.

    Resolution order:
      1. ``level`` already present -> returned unchanged;
      2. ``pressure_level`` or ``isobaricInhPa`` present -> renamed to ``level``;
      3. *lvl_from_name* given -> a length-1 ``level`` dimension is added;
      4. otherwise a warning is issued and *ds* is returned unchanged.
    """
    coords = ds.coords
    if 'level' in coords:
        return ds

    alias = next(
        (name for name in ('pressure_level', 'isobaricInhPa') if name in coords),
        None,
    )
    if alias is not None:
        return ds.rename({alias: 'level'})

    if lvl_from_name is None:
        warnings.warn("Could not infer 'level'; keeping file as-is (may be skipped).")
        return ds

    return ds.expand_dims({'level': [lvl_from_name]})

def _pick_one_var(ds, varname, level):
    if varname not in ds:
        return None
    da = ds[varname]
    if 'level' in da.dims:
        if level in da['level']:
            da = da.sel(level=level)
        else:
            return None
    return da

def shear_mag(u200, v200, u850, v850):
    """Return the magnitude of the vector difference (u200, v200) - (u850, v850)."""
    delta_u = u200 - u850
    delta_v = v200 - v850
    squared = delta_u**2 + delta_v**2
    return np.sqrt(squared)

def _find_file(year, month, key):
    """Locate ``era_<year>_<MM>_<key>.nc`` inside ``DATA_DIR``.

    Returns the first path reported by ``glob`` for that pattern, or ``None``
    when nothing matches.
    """
    filename = f"era_{year}_{month:02d}_{key}.nc"
    hits = glob.glob(os.path.join(DATA_DIR, filename))
    if not hits:
        return None
    return hits[0]

def _load_fields(path, wanted, opened):
    """Open one input file and pick the requested (variable, level) fields.

    Parameters
    ----------
    path : str or None
        File to read. A falsy value means the file was not found.
    wanted : sequence of (str, int)
        Short variable name (as produced by ``_normalize_vars``) and the
        level to select for it.
    opened : list
        The dataset returned by ``_open_nc`` is appended here so the caller
        can close it once the picked arrays are no longer needed.

    Returns
    -------
    list
        One entry per item of *wanted*: the selected array, or ``None`` when
        the file, the variable or the level is unavailable.
    """
    if not path:
        return [None] * len(wanted)
    raw = _open_nc(path)
    # Register the handle returned by the opener itself (before any
    # rename/expand_dims), so the caller closes exactly what was opened.
    opened.append(raw)
    ds = _normalize_vars(_ensure_level_dim(raw, _infer_level_from_name(path)))
    return [_pick_one_var(ds, name, level) for name, level in wanted]


def compute_features_for_month_safe(year, month):
    """Build ``features_<year>_<MM>.nc`` in ``OUT_DIR`` from that month's inputs.

    Whatever can be derived from the input files that exist is written:

    * ``vo850``            - vorticity at 850
    * ``wshear_200_850``   - ``shear_mag`` of the 200 and 850 winds
    * ``rh600``            - relative humidity at 600
    * ``t600_minus_t850``  - temperature difference
    * ``t200_minus_t850``  - temperature difference

    Nothing is written (and a message is printed) when no input file exists
    or when none of the features above can be derived.

    The output is first written to ``<name>.tmp`` and then moved into place
    with ``os.replace``, so an interrupted write never leaves a truncated
    ``features_*.nc`` that a later run would mistake for a finished file.

    Parameters
    ----------
    year : int
    month : int
        Calendar month, formatted as two digits in file names.

    Returns
    -------
    None
    """
    f_uv200 = _find_file(year, month, "u_component_of_wind-v_component_of_wind_200")
    f_uv850 = _find_file(year, month, "u_component_of_wind-v_component_of_wind_850")
    f_tr600 = _find_file(year, month, "temperature-relative_humidity_600")
    f_t850  = _find_file(year, month, "temperature-relative_humidity_850")
    f_t200  = _find_file(year, month, "temperature-relative_humidity_200")
    f_vo850 = _find_file(year, month, "vorticity_850")

    if not any([f_uv200, f_uv850, f_tr600, f_t850, f_t200, f_vo850]):
        print(f"[{year}-{month:02d}] no files found. Skipping.")
        return

    out_nc = os.path.join(OUT_DIR, f"features_{year}_{month:02d}.nc")
    tmp_nc = out_nc + ".tmp"

    # Inputs are opened with chunks, so the picked arrays are only read when
    # the output is written. Keep every input open until then and close them
    # all in one place, including on errors.
    opened = []
    try:
        (vo850,)    = _load_fields(f_vo850, [('vo', 850)], opened)
        u200, v200  = _load_fields(f_uv200, [('u', 200), ('v', 200)], opened)
        u850, v850  = _load_fields(f_uv850, [('u', 850), ('v', 850)], opened)
        rh600, t600 = _load_fields(f_tr600, [('r', 600), ('t', 600)], opened)
        (t850,)     = _load_fields(f_t850, [('t', 850)], opened)
        (t200,)     = _load_fields(f_t200, [('t', 200)], opened)

        # Seed the output coordinates from the first field that is available.
        feat = xr.Dataset()
        for candidate in [vo850, u200, v200, u850, v850, rh600, t600, t850, t200]:
            if candidate is not None:
                for c in ['time', 'latitude', 'longitude']:
                    if c in candidate.coords and c not in feat.coords:
                        feat = feat.assign_coords({c: candidate[c]})
                break

        if vo850 is not None:
            feat['vo850'] = vo850

        if all(da is not None for da in [u200, v200, u850, v850]):
            ws = shear_mag(u200, v200, u850, v850)
            ws.name = 'wshear_200_850'
            # En dash written as an escape so the attribute cannot be
            # corrupted by a source-file encoding round trip.
            ws.attrs.update(units='m s^-1',
                            long_name='Vertical wind shear magnitude (200\u2013850 hPa)')
            feat['wshear_200_850'] = ws

        if rh600 is not None:
            feat['rh600'] = rh600
            feat['rh600'].attrs.update(units='%', long_name='Relative humidity at 600 hPa')

        if (t600 is not None) and (t850 is not None):
            d1 = (t600 - t850).rename('t600_minus_t850')
            d1.attrs.update(units='K', long_name='T600 - T850')
            feat['t600_minus_t850'] = d1

        if (t200 is not None) and (t850 is not None):
            d2 = (t200 - t850).rename('t200_minus_t850')
            d2.attrs.update(units='K', long_name='T200 - T850')
            feat['t200_minus_t850'] = d2

        # Inputs existed but none of them yielded a feature (e.g. only one
        # half of a difference was available): do not write an empty file.
        if not feat.data_vars:
            print(f"[{year}-{month:02d}] no features could be derived. Skipping.")
            return

        encoding = {v: {'zlib': True, 'complevel': 4, 'dtype': 'float32'}
                    for v in feat.data_vars}
        try:
            with xr.set_options(file_cache_maxsize=1):
                feat.to_netcdf(tmp_nc, encoding=encoding, engine='netcdf4', compute=True)
            os.replace(tmp_nc, out_nc)
        finally:
            # After a successful replace the temp file is gone; anything still
            # present is a partial write and must not be left behind.
            if os.path.exists(tmp_nc):
                try:
                    os.remove(tmp_nc)
                except OSError as err:
                    warnings.warn(f"[{year}-{month:02d}] could not remove {tmp_nc}: {err}")
        del feat
    finally:
        for ds in opened:
            ds.close()
        gc.collect()

    print(f"[{year}-{month:02d}] wrote {out_nc}")

def features_for_span(start_year: int, end_year: int, months=range(1,13)):
    """Build the monthly feature files for every requested month of a year span.

    For each ``(year, month)`` the output ``features_<year>_<MM>.nc`` in
    ``OUT_DIR`` is built with ``compute_features_for_month_safe`` unless it
    already exists. A failure in one month is reported as a warning and does
    not stop the remaining months.

    Parameters
    ----------
    start_year, end_year : int
        Inclusive span of years; ``start_year`` must not exceed ``end_year``.
    months : iterable of int, optional
        Months to process (1..12). Duplicates are removed and the months are
        processed in ascending order.

    Raises
    ------
    AssertionError
        If the year span is reversed or a month is outside 1..12.
    """
    start_year = int(start_year)
    end_year   = int(end_year)
    assert start_year <= end_year, "start_year must be <= end_year"
    months = sorted({int(m) for m in months})
    assert all(1 <= m <= 12 for m in months), "months must be 1..12"

    total_built = 0      # files created by this call
    total_existing = 0   # files that were already present
    total_missing = 0    # months skipped (no usable inputs) or failed
    for y in range(start_year, end_year + 1):
        print(f"\n=== FEATURES {y} ===")
        for m in months:
            out_nc = os.path.join(OUT_DIR, f"features_{y}_{m:02d}.nc")
            if os.path.exists(out_nc):
                print(f"[{y}-{m:02d}] exists, skip.")
                total_existing += 1
                continue
            try:
                compute_features_for_month_safe(y, m)
            except Exception as e:
                warnings.warn(f"[{y}-{m:02d}] failed: {e}")
                # The file did not exist before this attempt, so anything at
                # out_nc now is a partial write. Remove it, otherwise the next
                # run would treat the month as done.
                if os.path.exists(out_nc):
                    try:
                        os.remove(out_nc)
                    except OSError as err:
                        warnings.warn(f"[{y}-{m:02d}] could not remove partial file: {err}")
                total_missing += 1
                continue
            # The builder returns normally both when it wrote a file and when
            # it skipped the month, so count only what is actually on disk.
            if os.path.exists(out_nc):
                total_built += 1
            else:
                total_missing += 1
    print(f"\nDone. Wrote {total_built} monthly feature files; "
          f"{total_existing} already existed; {total_missing} skipped or failed.")

# ----------------------------------------------------------------------
# Notebook entry point: running this cell builds the feature files for
# every month of 1994 through 2024 (both years inclusive). Months whose
# output file already exists are skipped. Edit the span here if needed.
features_for_span(1994, 2024, months=range(1,13))
# ----------------------------------------------------------------------
