In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Annual-batched daily OL/DA pipeline (streaming, resume-aware, simplified & fast)

Pass A (resume-aware):
  - Build a global keep_tile mask that keeps a tile only if its fraction of valid (unfrozen, snow-free) days is ≥ MIN_VALID_FRAC in BOTH OL and DA.
  - Skips if keep_tile.nc already exists.

Pass B (resume-aware):
  - Build global OL (CNTL) daily climatology (DOY=1..366) over kept tiles, then apply cyclic smoothing.
  - Skips if OLv8_climatology_DOY_smooth_kept.nc already exists.

Pass C (resume-aware):
  - Write per-year OL/DA anomaly files using the global climatology and keep_tile subset.
  - Skips any year whose outputs already exist.

Outputs (in --outdir):
  - keep_tile.nc
  - OLv8_climatology_DOY_smooth_kept.nc
  - OLv8_daily_anomalies_kept_YYYY.nc
  - DAv8_daily_anomalies_kept_YYYY.nc
"""

import os, re, glob, time, argparse, logging
import numpy as np
import xarray as xr
import dask
from dask.diagnostics import ProgressBar
from dask.utils import SerializableLock

# -----------------------------
# Basic config / tunables
# -----------------------------
# NetCDF backends used when the CLI does not override them.
DEFAULT_READ_ENGINE = "netcdf4"
DEFAULT_WRITE_ENGINE = "h5netcdf"
# Fallback chunk sizes (dask reads and on-disk HDF5 chunks) when no chunk spec is given.
_DEFAULT_CHUNKS = {"time": 31, "tile": 16384}

# Default mask / keep thresholds (each can be overridden on the command line).
TEMP_THRESH_K = 275.15  # 2 °C in kelvin; colder soil counts as frozen
SNOW_EPS = 1e-2  # 1% snow cover
MIN_VALID_FRAC = 0.7  # minimum fraction of valid days for a tile to be kept
# Daily basenames end in "...YYYYMMDD_1200z.nc4"; group(1) captures the YYYYMMDD stamp.
_TS_RE = re.compile(r"\.(\d{8})_1200z\.nc4$")

# A single lock shared by all file opens in this module.
_H5_LOCK = SerializableLock()
# Threaded scheduler: threads play nicer with HDF5 than processes.
_DASK_SETTINGS = {
    "scheduler": "threads",
    "array.slicing.split_large_chunks": True,
    "optimization.fuse.active": True,
}
dask.config.set(_DASK_SETTINGS)
# Turn off HDF5 file locking unless the environment already configures it.
os.environ.setdefault("HDF5_USE_FILE_LOCKING", "FALSE")

# -----------------------------
# Logging
# -----------------------------
def setup_logger(verbosity: int = 1) -> logging.Logger:
    """Configure root logging for the pipeline and return its named logger.

    verbosity == 1 selects INFO, verbosity >= 2 selects DEBUG, and any other
    value (0, negatives) selects WARNING. ``force=True`` replaces any handlers
    that were already attached to the root logger.
    """
    if verbosity >= 2:
        chosen = logging.DEBUG
    elif verbosity == 1:
        chosen = logging.INFO
    else:
        chosen = logging.WARNING
    logging.basicConfig(
        level=chosen,
        format="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%H:%M:%S",
        force=True,
    )
    return logging.getLogger("daily-da-yearly")

def stamp(log, msg):
    """Emit ``msg`` at INFO level on the logger ``log`` (progress-message helper)."""
    log.info(msg)

# -----------------------------
# Helpers: discovery & opening
# -----------------------------
def _parse_ts(bname: str):
    """Return the 12:00 timestamp encoded in a ``...YYYYMMDD_1200z.nc4`` basename.

    Returns a minute-precision ``np.datetime64`` at 12:00 on that date, or
    ``None`` when the basename does not match ``_TS_RE``.
    """
    match = _TS_RE.search(bname)
    if match is None:
        return None
    digits = match.group(1)
    year, month, day = digits[:4], digits[4:6], digits[6:]
    return np.datetime64(f"{year}-{month}-{day}T12:00")

def collect_daily_files(root_dir: str, file_prefix: str, start_date, end_date, log=None):
    """Find daily ``<file_prefix>.tavg24_1d_lnd_Nt.*_1200z.nc4`` files under ``root_dir``.

    The search is recursive. A file is kept when the YYYYMMDD stamp in its name
    falls within [start_date, end_date]. The bounds are compared at whole-day
    resolution, so both dates are inclusive. A file stamped on ``end_date`` (valid
    at 12:00) is therefore kept, even though ``end_date`` alone parses as midnight.

    Args:
        root_dir: directory searched recursively.
        file_prefix: experiment prefix at the start of each basename.
        start_date, end_date: anything whose ``str()`` NumPy can parse as a date/time.
        log: optional logger for summary messages.

    Returns:
        (files, times): file paths sorted by time, and their timestamps as datetime64[ns].

    Raises:
        FileNotFoundError: if nothing matches the pattern, or nothing falls in the date range.
    """
    pattern = os.path.join(root_dir, "**", f"{file_prefix}.tavg24_1d_lnd_Nt.*_1200z.nc4")
    # glob order is filesystem-dependent; sorting makes tie-breaking between equal dates deterministic.
    hits = sorted(glob.glob(pattern, recursive=True))
    if not hits:
        raise FileNotFoundError(f"No *_1200z.nc4 under {root_dir} for {file_prefix}")

    # Truncate both bounds to days: "2024-12-31" at ns precision is midnight, which
    # would exclude that day's 12:00 file from an inclusive end date.
    start_day = np.datetime64(str(start_date)).astype("datetime64[D]")
    end_day = np.datetime64(str(end_date)).astype("datetime64[D]")

    files, times = [], []
    for p in hits:
        ts = _parse_ts(os.path.basename(p))
        if ts is None:
            continue
        if start_day <= ts.astype("datetime64[D]") <= end_day:
            files.append(p)
            times.append(ts)
    if not files:
        raise FileNotFoundError(f"No daily files for {file_prefix} within [{start_date}..{end_date}]")

    times = np.asarray(times, dtype="datetime64[ns]")
    order = np.argsort(times, kind="stable")
    files = [files[i] for i in order]
    times = times[order]
    if log:
        stamp(log, f"[{file_prefix}] in-range files: {len(files)} / hits: {len(hits)}")
        stamp(log, f"[{file_prefix}] time span: {str(times[0])} … {str(times[-1])}")
    return files, times

def split_by_year(files, times):
    """Group the parallel ``files`` / ``times`` sequences by calendar year.

    Returns a dict ``{year: (files_in_year, times_in_year)}``. Years are NumPy
    integers in ascending order. Each file list keeps the input order. The times
    are ``times`` indexed with the same positions, so ``times`` must be a NumPy array.
    """
    labels = np.array([int(str(stamp_value)[:4]) for stamp_value in times])
    index_sets = ((yr, np.flatnonzero(labels == yr)) for yr in np.unique(labels))
    return {yr: ([files[k] for k in sel], times[sel]) for yr, sel in index_sets}

def speed_open_mfdataset(files, varnames, engine="netcdf4", chunks=None, log=None):
    """Lazily open many daily files concatenated along ``time``, keeping only ``varnames``.

    Files are combined in the given order ("nested" along ``time``). They are opened
    serially (``parallel=False``) with the module's shared lock. CF decoding is off:
    no masking/scaling, no time decoding, no coordinate decoding.

    Names in ``varnames`` that a file lacks are dropped silently. ``KeyError`` is
    raised only when a file contains none of them.

    NOTE(review): with ``mask_and_scale=False``, ``_FillValue``/``missing_value``
    entries stay as raw numbers instead of NaN. The NaN-based masking downstream
    will not catch them. Confirm the requested variables contain no fill values.
    """
    chunk_spec = _DEFAULT_CHUNKS if chunks is None else chunks
    wanted = set(varnames)

    def _select_wanted(dataset):
        present = [name for name in dataset.variables if name in wanted]
        if not present:
            raise KeyError(f"Requested vars missing. Asked: {varnames}")
        return dataset[present]

    open_kwargs = dict(
        combine="nested",
        concat_dim="time",
        preprocess=_select_wanted,
        engine=engine,
        parallel=False,
        lock=_H5_LOCK,
        chunks=chunk_spec,
        data_vars="minimal",
        coords="minimal",
        compat="override",
        mask_and_scale=False,
        decode_times=False,
        decode_coords=False,
        use_cftime=False,
    )
    if log:
        stamp(log, f"[open] {len(files)} files …")
    started = time.perf_counter()
    combined = xr.open_mfdataset(files, **open_kwargs)
    if log:
        stamp(log, f"[open] done in {time.perf_counter()-started:.1f}s")
    return combined

def _open_first_for_latlon(path: str, engine: str = "netcdf4"):
    """Read the ``lat`` and ``lon`` variables of one file into NumPy arrays.

    Returns a ``(lat, lon)`` tuple. The file is closed before returning.
    """
    with xr.open_dataset(path, engine=engine, chunks={}) as first:
        coords = tuple(first[name].values for name in ("lat", "lon"))
    return coords

def batched(seq, n):
    """Yield ``(chunk, slice)`` pairs covering ``seq`` in consecutive runs of at most ``n`` items.

    ``slice`` gives the chunk's position in ``seq``, so a parallel array can be cut
    the same way (``other[slc]``).
    """
    total = len(seq)
    for start in range(0, total, n):
        stop = min(start + n, total)
        yield seq[start:stop], slice(start, stop)

# -----------------------------
# Masking & DOY
# -----------------------------
def apply_frozen_snow_mask(sm_da, tsoil, frsnow, temp_thresh=TEMP_THRESH_K, snow_eps=SNOW_EPS):
    """Set ``sm_da`` to NaN wherever the soil is frozen or snow-covered.

    A point is masked when ``tsoil < temp_thresh`` or ``frsnow > snow_eps``.
    Comparisons involving NaN are False, so a NaN in ``tsoil``/``frsnow`` leaves
    that point unmasked. The negated-OR form is deliberate: rewriting it as
    ``(tsoil >= t) & (frsnow <= e)`` would mask those NaN points instead.
    """
    frozen = tsoil < temp_thresh
    snowy = frsnow > snow_eps
    return sm_da.where(~(frozen | snowy))

def mask_vars(ds: xr.Dataset, anom_vars, temp_k, snow_eps) -> xr.Dataset:
    """Return a Dataset of ``anom_vars`` from ``ds`` with the frozen/snow mask applied.

    The mask reads ``TSOIL1`` and ``FRLANDSNO`` from ``ds``. If either is absent,
    an all-NaN stand-in shaped like the variable is used, which disables that
    criterion. The result carries ``ds``'s coordinates and only the listed variables.
    """
    have_tsoil = "TSOIL1" in ds
    have_snow = "FRLANDSNO" in ds
    masked = xr.Dataset(coords=ds.coords)
    for name in anom_vars:
        field = ds[name]
        soil_t = ds["TSOIL1"] if have_tsoil else xr.full_like(field, np.nan)
        snow_frac = ds["FRLANDSNO"] if have_snow else xr.full_like(field, np.nan)
        masked[name] = apply_frozen_snow_mask(field, soil_t, snow_frac, temp_k, snow_eps)
    return masked

def doy_index(da_time):
    """Map timestamps to day-of-year labels 1..365, folding DOY 366 into 365.

    DOY 366 only occurs on 31 December of a leap year, so that day shares label
    365 with 30 December. NOTE: no other leap-year adjustment is made. After
    28 February, leap-year dates carry a DOY one higher than the same calendar
    date in a common year.
    """
    day = da_time.dt.dayofyear
    return xr.where(day == 366, 365, day)

# -----------------------------
# Write util (handles time/tile and dayofyear/tile)
# -----------------------------
def write_nc(ds: xr.Dataset, path: str, engine: str, chunks=None, log=None):
    """Write ``ds`` to NetCDF at ``path`` atomically, compressing 2-D tile variables.

    Every data variable whose dims are ``(<dim>, "tile")`` gets zlib level-4
    compression and explicit on-disk chunk sizes. The sizes come from ``chunks``,
    keyed by the first dim (falling back to the ``"time"`` hint, then 31) and by
    ``"tile"`` (falling back to 16384). They are clipped to the dimension length.
    A non-positive hint (e.g. dask's ``-1``) means one chunk spanning the whole
    dimension.

    The data are written to ``path + ".tmp"`` and renamed onto ``path`` only after
    the dask computation succeeds. An interrupted run therefore never leaves a
    truncated file at ``path`` that an ``os.path.exists`` resume check would
    mistake for a finished output.

    Args:
        ds: dataset to write (dask-backed or in-memory).
        path: final output path.
        engine: xarray NetCDF engine name.
        chunks: optional ``{dim: size}`` chunk hints (defaults to ``_DEFAULT_CHUNKS``).
        log: optional logger for progress messages.
    """
    if chunks is None:
        chunks = _DEFAULT_CHUNKS

    def _chunk_len(dim_len, hint):
        # HDF5 rejects chunks larger than the dimension or non-positive sizes.
        return dim_len if hint <= 0 else min(dim_len, hint)

    comp = dict(zlib=True, complevel=4)
    encoding = {}
    for v in ds.data_vars:
        dims = ds[v].dims
        if len(dims) == 2 and dims[-1] == "tile":
            d0 = dims[0]
            d0_len = int(ds.sizes.get(d0, 1))
            tile_len = int(ds.sizes.get("tile", 1))
            d0_chunk = _chunk_len(d0_len, int(chunks.get(d0, chunks.get("time", 31))))
            tile_chunk = _chunk_len(tile_len, int(chunks.get("tile", 16384)))
            encoding[v] = {**comp, "chunksizes": (d0_chunk, tile_chunk)}

    tmp_path = f"{path}.tmp"
    if log:
        stamp(log, f"→ Writing {path}")
    try:
        delayed = ds.to_netcdf(tmp_path, engine=engine, encoding=encoding, compute=False)
        with ProgressBar():
            dask.compute(delayed)
        # Atomic on POSIX when tmp and final live on the same filesystem (same directory here).
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best effort: the temp file may not exist yet
        raise
    if log:
        stamp(log, f"✓ Wrote {path}")

# -----------------------------
# Arg parsing
# -----------------------------
def parse_chunk_flag(s: str) -> dict:
    """Parse a chunk spec such as ``"time:31,tile:8192"`` into ``{"time": 31, "tile": 8192}``.

    An empty string or ``None`` yields ``{}``. Blank items (e.g. a trailing comma)
    are skipped. Whitespace around names and sizes is ignored. Sizes may be any
    integer, including dask's ``-1`` for "whole dimension".

    Raises:
        ValueError: if an item is not ``NAME:INT``; the message quotes the offending item.
    """
    out = {}
    if not s:
        return out
    for raw_item in s.split(","):
        item = raw_item.strip()
        if not item:
            continue
        name, sep, value = item.partition(":")
        name = name.strip()
        if not sep or not name:
            raise ValueError(f"Bad chunk spec {item!r}: expected NAME:INT")
        try:
            out[name] = int(value.strip())
        except ValueError:
            raise ValueError(f"Bad chunk spec {item!r}: size must be an integer") from None
    return out

def parse_args(argv=None):
    """Define the pipeline's command-line interface and parse ``argv``.

    ``argv=None`` makes argparse read ``sys.argv[1:]``. Pass a list (``[]`` for
    all defaults) to parse explicit tokens. Returns the argparse ``Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Annual-batched OL/DA anomalies with global keep_tile and climatology (streaming, resume-aware)."
    )
    add = parser.add_argument

    # Input directories and filename prefixes of the OL and DA experiments
    add("--ol-root", default="/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_OLv8_M36/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg")
    add("--ol-prefix", default="LS_OLv8_M36")
    add("--da-root", default="/discover/nobackup/projects/land_da/M21C_land_sweeper/LS_DAv8_M36_v2/LS_DAv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg")
    add("--da-prefix", default="LS_DAv8_M36")

    # Variables to read, and the subset that is masked and turned into anomalies
    add("--vars", nargs="+", default=["SFMC", "RZMC", "PRECTOTCORRLAND", "FRLANDSNO", "TSOIL1"])
    add("--anom-vars", nargs="+", default=["SFMC", "RZMC"])

    # Date range, NetCDF engines, dask chunk spec and output directory
    add("--start-date", default="2000-01-01")
    add("--end-date", default="2024-12-31")
    for flag, engine_default in (("--read-engine", DEFAULT_READ_ENGINE), ("--write-engine", DEFAULT_WRITE_ENGINE)):
        add(flag, choices=["h5netcdf", "netcdf4"], default=engine_default)
    add("--chunks", default="time:31,tile:8192")
    add("--outdir", default="./yearly_outputs")

    # Frozen/snow mask thresholds, keep-tile threshold, climatology smoothing window
    for flag, threshold in (("--temp-K", TEMP_THRESH_K), ("--snow-eps", SNOW_EPS), ("--min-valid-frac", MIN_VALID_FRAC)):
        add(flag, type=float, default=threshold)
    add("--clim-window", type=int, default=31)

    # Streaming batch size and log verbosity
    add("--batch-days", type=int, default=60, help="Days per file batch in streaming steps (45–90 good).")
    add("--verbose", type=int, default=1)
    return parser.parse_args(argv)

# -----------------------------
# Main
# -----------------------------
def _cyclic_smooth_same_len(arr2d, window):
    """Centred cyclic moving mean along axis 0 that skips NaNs and keeps the input shape.

    Each output row is the mean of the finite values in the ``window`` rows around
    it, wrapping around the ends. The result is NaN only where that whole window is
    NaN. With no NaNs this equals a plain ``window``-point boxcar mean.
    """
    w = max(1, int(window))
    if w == 1:
        return arr2d
    left, right = w // 2, (w - 1) // 2
    valid = np.isfinite(arr2d)

    def _window_sum(a):
        # Wrap-pad, then take running sums so every window sum is one subtraction.
        padded = np.pad(a, ((left, right), (0, 0)), mode="wrap")
        csum = np.cumsum(padded, axis=0)
        csum = np.concatenate([np.zeros((1,) + a.shape[1:], dtype=csum.dtype), csum], axis=0)
        return csum[w:] - csum[:-w]

    num = _window_sum(np.where(valid, arr2d, 0.0))
    den = _window_sum(valid.astype(np.float64))
    with np.errstate(invalid="ignore", divide="ignore"):
        out = num / den
    out[den == 0] = np.nan
    return out


def _attach_coords(ds, times, lat, lon):
    """Set the parsed file timestamps as ``time`` and attach per-tile ``lat``/``lon``."""
    return ds.assign_coords({"time": ("time", times), "lat": ("tile", lat), "lon": ("tile", lon)})


def _pass_a_keep_mask(args, log, years, yearmap_ol, yearmap_da, read_vars, chunks, lat, lon, keep_path):
    """PASS A: load, or build and write, the per-tile keep mask (bool array, one entry per tile)."""
    ntiles = int(lat.size)
    if os.path.exists(keep_path):
        stamp(log, f"[A] Found existing keep mask → {keep_path} (skipping PASS A)")
        keep_tile = xr.load_dataset(keep_path)["keep_tile"].values.astype(bool)
        if keep_tile.shape != (ntiles,):
            raise ValueError(f"{keep_path} has {keep_tile.size} tiles but the inputs have {ntiles}; "
                             "delete it to rebuild PASS A")
        stamp(log, f"[A] keep tiles: {int(keep_tile.sum())}/{len(keep_tile)} (loaded)")
        return keep_tile

    batch = max(1, int(args.batch_days))

    def _count_valid(files, times):
        # Per-tile count of days on which every anomaly variable survives the frozen/snow mask.
        with speed_open_mfdataset(files, read_vars, engine=args.read_engine, chunks=chunks, log=log) as raw:
            masked = mask_vars(_attach_coords(raw, times, lat, lon), args.anom_vars,
                               args.temp_K, args.snow_eps)
            valid = None
            for v in args.anom_vars:
                ok = masked[v].notnull()
                valid = ok if valid is None else (valid & ok)
            cnt = valid.astype("int8").sum("time", dtype="int64")
            with ProgressBar():
                return cnt.compute().values, int(valid.sizes["time"])

    valid_days = {"OL": np.zeros(ntiles, dtype="int64"), "DA": np.zeros(ntiles, dtype="int64")}
    total_days = {"OL": 0, "DA": 0}
    for y in years:
        for label, yearmap in (("OL", yearmap_ol), ("DA", yearmap_da)):
            files_y, times_y = yearmap[y]
            for fbatch, sl in batched(files_y, batch):
                cnt, ndays = _count_valid(fbatch, times_y[sl])
                valid_days[label] += cnt
                total_days[label] += ndays
        stamp(log, f"[A] {y} done.")

    frac_ol = valid_days["OL"] / max(1, total_days["OL"])
    frac_da = valid_days["DA"] / max(1, total_days["DA"])
    keep_tile = (frac_ol >= args.min_valid_frac) & (frac_da >= args.min_valid_frac)
    stamp(log, f"[A] keep tiles: {int(keep_tile.sum())}/{ntiles}")

    keep_ds = xr.Dataset(dict(keep_tile=xr.DataArray(
        keep_tile, dims=("tile",),
        coords={"tile": np.arange(ntiles), "lat": ("tile", lat), "lon": ("tile", lon)},
    )))
    write_nc(keep_ds, keep_path, args.write_engine, {"tile": _DEFAULT_CHUNKS["tile"]}, log)
    return keep_tile


def _pass_b_climatology(args, log, years, yearmap_ol, read_vars, chunks, lat, lon, keep_idx, clim_path):
    """PASS B: load, or build and write, the smoothed OL day-of-year climatology on kept tiles."""
    if os.path.exists(clim_path):
        stamp(log, f"[B] Found existing climatology → {clim_path} (skipping PASS B)")
        return xr.load_dataset(clim_path)

    ndoy = 366
    doy_axis = np.arange(1, ndoy + 1)
    batch = max(1, int(args.batch_days))
    sums = {v: np.zeros((ndoy, keep_idx.size), dtype=np.float64) for v in args.anom_vars}
    cnts = {v: np.zeros((ndoy, keep_idx.size), dtype=np.int64) for v in args.anom_vars}

    def _accumulate(files, times):
        # Add this batch's masked per-DOY sums and valid-day counts into sums/cnts.
        with speed_open_mfdataset(files, read_vars, engine=args.read_engine, chunks=chunks, log=log) as raw:
            olm = mask_vars(_attach_coords(raw, times, lat, lon), args.anom_vars,
                            args.temp_K, args.snow_eps).isel(tile=keep_idx)
            # doy_index folds DOY 366 (31 Dec of leap years) into 365, so the 366 bin stays empty here.
            doy = doy_index(olm.time)
            lazy = []
            for v in args.anom_vars:
                grouped = olm[v].groupby(doy)
                lazy.extend((grouped.sum("time", skipna=True), grouped.count("time")))
            # One compute for every sum and count so each file batch is read once, not twice per variable.
            with ProgressBar():
                done = dask.compute(*lazy)
        for k, v in enumerate(args.anom_vars):
            s, c = done[2 * k], done[2 * k + 1]
            sums[v] += s.reindex(dayofyear=doy_axis, fill_value=0.0).transpose("dayofyear", "tile").values
            cnts[v] += c.reindex(dayofyear=doy_axis, fill_value=0).transpose("dayofyear", "tile").values

    for y in years:
        files_y, times_y = yearmap_ol[y]
        stamp(log, f"[B] {y}")
        for fbatch, sl in batched(files_y, batch):
            _accumulate(fbatch, times_y[sl])

    clim_vars = {}
    for v in args.anom_vars:
        summ, count = sums[v], cnts[v]
        # Where DOY 366 was never observed, mirror DOY 365 so the cycle stays 366 long and continuous.
        empty_366 = count[-1, :] == 0
        summ[-1, empty_366] = summ[-2, empty_366]
        count[-1, empty_366] = count[-2, empty_366]
        # DOYs with no valid observation (e.g. always frozen) are NaN, not 0, so the
        # smoothing below ignores them instead of dragging neighbouring days toward zero.
        daily_mean = np.where(count > 0, summ / np.maximum(count, 1), np.nan)
        clim_vars[f"{v}_clim"] = _cyclic_smooth_same_len(daily_mean, args.clim_window)

    coords = dict(dayofyear=doy_axis,
                  tile=("tile", keep_idx),
                  lat=("tile", lat[keep_idx]),
                  lon=("tile", lon[keep_idx]))
    clim_ds = xr.Dataset({k: (("dayofyear", "tile"), v) for k, v in clim_vars.items()}, coords=coords)
    clim_ds.attrs.update(baseline="CNTL masked DOY mean (global, smoothed)",
                         smoothing=f"cyclic {args.clim_window}-day",
                         smoothing_missing="DOYs without valid data are skipped in the window; NaN if none")
    write_nc(clim_ds, clim_path, args.write_engine, {"dayofyear": ndoy, "tile": _DEFAULT_CHUNKS["tile"]}, log)
    return clim_ds


def _pass_c_anomalies(args, log, years, yearmap_ol, yearmap_da, read_vars, chunks, lat, lon, clim_ds):
    """PASS C: write per-year OL and DA anomaly files on kept tiles, skipping files that exist."""
    tiles = clim_ds.tile.values
    for y in years:
        jobs = [
            (yearmap_ol[y], os.path.join(args.outdir, f"OLv8_daily_anomalies_kept_{y}.nc")),
            (yearmap_da[y], os.path.join(args.outdir, f"DAv8_daily_anomalies_kept_{y}.nc")),
        ]
        todo = [(files_times, out_path) for files_times, out_path in jobs if not os.path.exists(out_path)]
        if not todo:
            stamp(log, f"[C] {y}: outputs exist → skipping")
            continue

        stamp(log, f"[C] {y}")
        for (files_y, times_y), out_path in todo:
            with speed_open_mfdataset(files_y, read_vars, engine=args.read_engine, chunks=chunks, log=log) as raw:
                ds = _attach_coords(raw, times_y, lat, lon).isel(tile=tiles)
                masked = mask_vars(ds, args.anom_vars, args.temp_K, args.snow_eps)
                # Index the climatology with this experiment's own dates; OL and DA files are
                # collected independently, so their time axes need not be identical.
                doy = doy_index(masked.time)
                anom = xr.Dataset(coords=dict(time=masked.time, tile=clim_ds.tile,
                                              lat=clim_ds.lat, lon=clim_ds.lon))
                for v in args.anom_vars:
                    base = clim_ds[f"{v}_clim"].sel(dayofyear=doy).transpose("time", "tile")
                    anom[v] = masked[v] - base
                # Write inside the `with` so the lazy reads finish before the files are closed.
                write_nc(anom, out_path, args.write_engine, chunks, log)


def main(argv=None):
    """Run the three resume-aware passes of the pipeline.

    PASS A builds (or loads) the keep_tile mask. PASS B builds (or loads) the
    smoothed OL day-of-year climatology on kept tiles. PASS C writes per-year OL
    and DA anomaly files. Outputs that already exist in ``--outdir`` are reused.

    Args:
        argv: CLI tokens passed to ``parse_args`` (``None`` -> ``sys.argv[1:]``).

    Raises:
        ValueError: if ``--anom-vars`` is empty, or an existing keep mask does not match the inputs.
        FileNotFoundError: if no year has both OL and DA files in the date range.
        RuntimeError: if no tile passes the keep threshold.
    """
    args = parse_args(argv)
    log = setup_logger(args.verbose)
    chunks = parse_chunk_flag(args.chunks)
    if not args.anom_vars:
        raise ValueError("--anom-vars must list at least one variable")
    os.makedirs(args.outdir, exist_ok=True)

    keep_path = os.path.join(args.outdir, "keep_tile.nc")
    clim_path = os.path.join(args.outdir, "OLv8_climatology_DOY_smooth_kept.nc")
    # Every pass reads the anomaly fields plus both masking fields, so the frozen/snow
    # mask is applied identically in PASS A, B and C (order-preserving de-duplication).
    read_vars = list(dict.fromkeys([*args.vars, *args.anom_vars, "TSOIL1", "FRLANDSNO"]))

    stamp(log, "=== Pipeline start ===")
    files_ol, times_ol = collect_daily_files(args.ol_root, args.ol_prefix, args.start_date, args.end_date, log)
    files_da, times_da = collect_daily_files(args.da_root, args.da_prefix, args.start_date, args.end_date, log)
    yearmap_ol = split_by_year(files_ol, times_ol)
    yearmap_da = split_by_year(files_da, times_da)
    years = sorted(set(yearmap_ol) & set(yearmap_da))
    if not years:
        raise FileNotFoundError("No year has daily files for both OL and DA within the requested dates")
    unpaired = sorted(int(y) for y in set(yearmap_ol) ^ set(yearmap_da))
    if unpaired:
        log.warning("Skipping years with files from only one of OL/DA: %s", unpaired)
    stamp(log, f"Years with both OL & DA: {years[0]}–{years[-1]} ({len(years)} years)")

    # lat/lon are per-tile vectors, so their length doubles as the tile count.
    lat, lon = _open_first_for_latlon(yearmap_ol[years[0]][0][0], args.read_engine)

    stamp(log, "=== PASS A: valid-day keep mask ===")
    keep_tile = _pass_a_keep_mask(args, log, years, yearmap_ol, yearmap_da, read_vars, chunks,
                                  lat, lon, keep_path)
    keep_idx = np.flatnonzero(keep_tile)
    if keep_idx.size == 0:
        raise RuntimeError(f"No tile reached --min-valid-frac={args.min_valid_frac}; nothing to do")

    stamp(log, "=== PASS B: OL climatology (DOY) ===")
    clim_ds = _pass_b_climatology(args, log, years, yearmap_ol, read_vars, chunks,
                                  lat, lon, keep_idx, clim_path)

    stamp(log, "=== PASS C: yearly anomalies (OL & DA) ===")
    _pass_c_anomalies(args, log, years, yearmap_ol, yearmap_da, read_vars, chunks, lat, lon, clim_ds)

    stamp(log, "=== Done ===")
    stamp(log, f" keep_tile: {keep_path}")
    stamp(log, f" climatology: {clim_path}")
    stamp(log, f" anomalies: {args.outdir}/OLv8_daily_anomalies_kept_YYYY.nc & DAv8_daily_anomalies_kept_YYYY.nc")

if __name__ == "__main__":
    import sys

    # Inside a Jupyter kernel, sys.argv holds the kernel launcher's own arguments,
    # which argparse would reject, so fall back to the defaults there. From a shell,
    # honour the real command-line arguments instead of silently discarding them.
    main([] if "ipykernel" in sys.modules else sys.argv[1:])
