# In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Variance diagnostics from daily anomalies (OL vs DA)

Products written to --outdir:
  - variance_daily_fullperiod.nc
  - variance_withinmonth_daily.nc
  - variance_monthlymean_IAV.nc
  - variance_annualmean_IAV.nc

Optionally, if you *already* have monthly-mean files, you can pass
--ol-monthly and --da-monthly to compute the monthly-mean IAV from those
instead of resampling the daily anomalies.

Example (from your notebook shell):
  %run variance_diagnostics_from_anoms.py \
      --outdir ./yearly_outputs \
      --vars SFMC RZMC \
      --chunks time:365,tile:4096 \
      --decode-times \
      --use-monthly-from-daily

Or, with pre-made monthly means:
  %run variance_diagnostics_from_anoms.py \
      --outdir ./yearly_outputs \
      --vars SFMC RZMC \
      --ol-monthly /path/to/OL_monthly_means_*.nc \
      --da-monthly /path/to/DA_monthly_means_*.nc
"""

import os, glob, argparse, logging
import numpy as np
import xarray as xr
import dask
from dask.diagnostics import ProgressBar

# -----------------------------
# Dask / logging
# -----------------------------
# Process-wide dask settings, applied as soon as this module is executed.
dask.config.set({
    # Execute task graphs with dask's local threaded scheduler.
    "scheduler": "threads",
    # Presumably lets dask split oversized chunks created by slicing/indexing
    # instead of only warning about them — confirm against the dask config docs.
    "array.slicing.split_large_chunks": True,
    # Enable task fusion during graph optimization.
    "optimization.fuse.active": True
})

def setup_logger(verbosity: int = 1) -> logging.Logger:
    """Configure root logging and return the script's named logger.

    Args:
        verbosity: 1 selects INFO, 2 or more selects DEBUG, any other value
            (0 or negative) selects WARNING.

    Returns:
        The ``"variance-diags"`` logger.
    """
    if verbosity >= 2:
        chosen_level = logging.DEBUG
    elif verbosity == 1:
        chosen_level = logging.INFO
    else:
        chosen_level = logging.WARNING

    # force=True replaces handlers already attached to the root logger, so
    # calling this again (e.g. re-running a notebook cell) reconfigures
    # logging instead of being ignored.
    logging.basicConfig(
        level=chosen_level,
        format="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%H:%M:%S",
        force=True,
    )
    return logging.getLogger("variance-diags")

# -----------------------------
# Helpers
# -----------------------------
def parse_chunk_flag(s: str) -> dict:
    """Parse a ``dim:size,dim:size`` string into a ``{dim: size}`` dict.

    Whitespace around names and sizes is ignored, and empty entries (for
    example from a trailing comma) are skipped. An empty or ``None`` input
    yields an empty dict.

    Args:
        s: Chunk specification such as ``"time:365,tile:4096"``.

    Returns:
        Mapping from dimension name to integer chunk size.

    Raises:
        ValueError: If an entry is not exactly ``name:integer``.
    """
    sizes = {}
    if not s:
        return sizes
    for raw in s.split(","):
        entry = raw.strip()
        if not entry:
            continue  # tolerate "a:1,,b:2" and trailing commas
        name, sep, size = entry.partition(":")
        name, size = name.strip(), size.strip()
        # Reject a missing colon, an empty dimension name, or a second colon
        # with a message that points at the offending entry.
        if not sep or not name or ":" in size:
            raise ValueError(
                f"Bad chunk spec {entry!r}: expected 'dim:size' (e.g. 'time:365')"
            )
        try:
            sizes[name] = int(size)
        except ValueError:
            raise ValueError(
                f"Bad chunk size in {entry!r}: {size!r} is not an integer"
            ) from None
    return sizes

def ensure_monotonic_unique_time(ds: "xr.Dataset") -> "xr.Dataset":
    """Return ``ds`` sorted by ``time`` with duplicate timestamps dropped.

    For every repeated timestamp, the record that comes first in the input
    order is the one kept. That guarantee needs a stable sort: NumPy's
    default ``argsort`` algorithm is not stable, so equal timestamps could
    be reordered and a later duplicate kept instead.

    If the time axis is already strictly increasing, ``ds`` is returned
    unchanged, which avoids an unnecessary fancy-index along ``time``.

    Args:
        ds: Dataset with a ``time`` variable.

    Returns:
        Dataset whose ``time`` values are sorted and unique.
    """
    t = ds["time"].values
    # Fast path: nothing to reorder or drop.
    if t.shape[0] < 2 or bool(np.all(t[1:] > t[:-1])):
        return ds

    # Stable sort keeps equal timestamps in their original relative order.
    order = np.argsort(t, kind="stable")
    t_sorted = t[order]
    # Keep an element only when it differs from its sorted predecessor,
    # i.e. the first occurrence of each distinct timestamp.
    keep = np.ones(t_sorted.shape[0], dtype=bool)
    keep[1:] = t_sorted[1:] != t_sorted[:-1]
    return ds.isel(time=order[keep])

def open_mf(pattern, chunks, decode_times=True, log=None, label=""):
    """Open all files matching ``pattern`` as one dataset concatenated on ``time``.

    Args:
        pattern: Glob pattern for the input files; matches are opened in
            sorted (lexicographic) order.
        chunks: Chunk mapping forwarded to ``xr.open_mfdataset``.
        decode_times: Whether CF time decoding is applied at open time; the
            same value is also used for ``decode_coords``.
        log: Optional logger used to report the number of matched files.
        label: Short tag used in the log message.

    Returns:
        The dataset returned by ``xr.open_mfdataset``.

    Raises:
        FileNotFoundError: If nothing matches ``pattern``.
    """
    files = sorted(glob.glob(pattern))
    if len(files) == 0:
        raise FileNotFoundError(f"No files match pattern: {pattern}")
    if log:
        log.info(f"[open] {label}: {len(files)} files")

    open_kwargs = dict(
        # Files are simply stacked along time in sorted-name order.
        combine="nested",
        concat_dim="time",
        engine="netcdf4",
        parallel=False,
        chunks=chunks,
        data_vars="minimal",
        coords="minimal",
        compat="override",
        # NOTE(review): masking/scaling is disabled here, so _FillValue and
        # scale/offset attributes are not applied at open time — confirm the
        # inputs already store missing data as NaN.
        mask_and_scale=False,
        decode_times=decode_times,
        decode_coords=decode_times,
        use_cftime=False,
    )
    return xr.open_mfdataset(files, **open_kwargs)

def write_nc(ds: xr.Dataset, path: str, chunks, log=None):
    """Write ``ds`` to ``path`` as compressed netCDF with a dask progress bar.

    The data are first written to ``<path>.partial`` and moved onto ``path``
    only after the write has finished. A run that is interrupted therefore
    never leaves a truncated file at ``path``; callers that skip work when
    ``path`` already exists would otherwise treat such a file as complete.

    Every data variable is zlib-compressed (level 4). Variables with dims
    ``(<lead>, tile)`` / ``(tile, <lead>)`` or ``(tile,)`` also get on-disk
    chunk hints derived from ``chunks``.

    Args:
        ds: Dataset to write.
        path: Final output file path.
        chunks: Mapping of dimension name to requested chunk length. A lead
            dimension without its own entry falls back to the ``time`` entry
            (default 365); ``tile`` defaults to 4096.
        log: Optional logger for start/finish messages.
    """
    def _chunk_len(dim, fallback):
        # Clamp the request to the dimension length. A non-positive request
        # (e.g. the dask-style -1 for "whole dimension") is not a valid
        # on-disk chunk size, so it is mapped to the full length.
        size = int(ds.sizes.get(dim, 1))
        requested = int(chunks.get(dim, fallback))
        return size if requested <= 0 else min(size, requested)

    encoding = {}
    for v, da in ds.data_vars.items():
        enc = dict(zlib=True, complevel=4)
        dims = da.dims
        if len(dims) == 2 and "tile" in dims:
            # chunksizes must follow the variable's own dimension order.
            enc["chunksizes"] = tuple(
                _chunk_len(d, 4096) if d == "tile"
                else _chunk_len(d, chunks.get("time", 365))
                for d in dims
            )
        elif dims == ("tile",):
            enc["chunksizes"] = (_chunk_len("tile", 4096),)
        encoding[v] = enc

    tmp_path = f"{path}.partial"
    if log: log.info(f"→ Writing {path}")
    try:
        delayed = ds.to_netcdf(tmp_path, engine="h5netcdf", encoding=encoding, compute=False)
        with ProgressBar():
            dask.compute(delayed)
        # Publish the finished file under its final name in one step.
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the incomplete file; the original error is
        # what the caller needs to see, so cleanup failures are ignored.
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass
        raise
    if log: log.info(f"✓ Wrote {path}")

# -----------------------------
# Core computations
# -----------------------------
def daily_variance(ds_ol, ds_da, vars_, eps, log):
    """Variance over the whole ``time`` axis for each variable, OL vs DA.

    For every name in ``vars_`` four float32 fields are produced: the OL
    variance, the DA variance, their difference (DA - OL) and that
    difference as a percentage of the OL variance. The percentage is NaN
    wherever the OL variance does not exceed ``eps``.

    Args:
        ds_ol: Open-loop dataset with a ``time`` dimension.
        ds_da: Data-assimilation dataset with a ``time`` dimension.
        vars_: Variable names to process.
        eps: Lower bound the OL variance must exceed for a percentage.
        log: Logger (may be falsy to disable the message).

    Returns:
        Merged dataset of all fields with ``lat``/``lon`` taken from ``ds_ol``.
    """
    def _products_for(v):
        var_ol = ds_ol[v].var("time", skipna=True)
        var_da = ds_da[v].var("time", skipna=True)
        change = var_da - var_ol
        percent = xr.where(var_ol > eps, 100.0 * change / var_ol, np.nan)
        fields = {
            f"{v}_daily_var_OL": var_ol,
            f"{v}_daily_var_DA": var_da,
            f"{v}_daily_var_delta": change,
            f"{v}_daily_var_pct": percent,
        }
        return xr.Dataset({key: arr.astype("float32") for key, arr in fields.items()})

    merged = xr.merge([_products_for(v) for v in vars_])
    merged = merged.assign_coords(lat=ds_ol.lat, lon=ds_ol.lon)
    if log:
        log.info("[daily] computed.")
    return merged

def within_month_daily_variance(ds_ol, ds_da, vars_, eps, log):
    """Variance of the daily values grouped by calendar month, OL vs DA.

    Time steps are grouped by ``time.dt.month`` (all years pooled) and the
    variance is taken within each group. For every name in ``vars_`` the
    result holds float32 OL variance, DA variance, DA - OL, and the percent
    change relative to OL (NaN where the OL variance does not exceed
    ``eps``), each carrying a ``month`` dimension.

    Args:
        ds_ol: Open-loop dataset; CF-decoded here if ``time`` is not datetime64.
        ds_da: Data-assimilation dataset; decoded the same way.
        vars_: Variable names to process.
        eps: Lower bound the OL variance must exceed for a percentage.
        log: Logger (may be falsy to disable the message).

    Returns:
        Merged dataset of all fields with ``lat``/``lon`` taken from ``ds_ol``.
    """
    def _with_datetime_time(ds):
        # Month grouping needs a datetime64 time axis.
        if np.issubdtype(ds.time.dtype, np.datetime64):
            return ds
        return xr.decode_cf(ds)

    def _variance_by_month(da, months):
        grouped = da.groupby(months).var("time", skipna=True)
        if "month" not in grouped.dims:
            # The grouped dimension may carry another name; it is the one
            # dimension that is not "tile".
            group_dim = next(d for d in grouped.dims if d != "tile")
            grouped = grouped.rename({group_dim: "month"})
        return grouped

    ds_ol = _with_datetime_time(ds_ol)
    ds_da = _with_datetime_time(ds_da)
    months_ol = ds_ol["time"].dt.month
    months_da = ds_da["time"].dt.month

    pieces = []
    for v in vars_:
        var_ol = _variance_by_month(ds_ol[v], months_ol)
        var_da = _variance_by_month(ds_da[v], months_da)
        change = var_da - var_ol
        percent = xr.where(var_ol > eps, 100.0 * change / var_ol, np.nan)
        fields = {
            f"{v}_month_dailyvar_OL": var_ol,
            f"{v}_month_dailyvar_DA": var_da,
            f"{v}_month_dailyvar_delta": change,
            f"{v}_month_dailyvar_pct": percent,
        }
        pieces.append(
            xr.Dataset({key: arr.astype("float32") for key, arr in fields.items()})
        )

    merged = xr.merge(pieces)
    merged = merged.assign_coords(lat=ds_ol.lat, lon=ds_ol.lon)
    if log:
        log.info("[within-month] computed.")
    return merged

def monthlymean_variance_from_daily(ds_ol, ds_da, vars_, eps, log):
    """Variance across monthly means obtained by resampling the daily data.

    Each variable is averaged into month-start (``"1MS"``) bins and the
    variance of those monthly means is taken over ``time``. Outputs per
    variable are float32 OL variance, DA variance, DA - OL, and percent
    change relative to OL (NaN where the OL variance does not exceed
    ``eps``).

    Args:
        ds_ol: Open-loop daily dataset with a datetime ``time`` axis.
        ds_da: Data-assimilation daily dataset with a datetime ``time`` axis.
        vars_: Variable names to process.
        eps: Lower bound the OL variance must exceed for a percentage.
        log: Logger (may be falsy to disable the message).

    Returns:
        Merged dataset of all fields with ``lat``/``lon`` taken from ``ds_ol``.
    """
    def _variance_of_monthly_means(da):
        monthly = da.resample(time="1MS").mean(keep_attrs=True)
        return monthly.var("time", skipna=True)

    pieces = []
    for v in vars_:
        var_ol = _variance_of_monthly_means(ds_ol[v])
        var_da = _variance_of_monthly_means(ds_da[v])
        change = var_da - var_ol
        percent = xr.where(var_ol > eps, 100.0 * change / var_ol, np.nan)
        fields = {
            f"{v}_monthlymean_var_OL": var_ol,
            f"{v}_monthlymean_var_DA": var_da,
            f"{v}_monthlymean_var_delta": change,
            f"{v}_monthlymean_var_pct": percent,
        }
        pieces.append(
            xr.Dataset({key: arr.astype("float32") for key, arr in fields.items()})
        )

    merged = xr.merge(pieces).assign_coords(lat=ds_ol.lat, lon=ds_ol.lon)
    if log:
        log.info("[monthly-mean IAV] computed (from daily).")
    return merged

def monthlymean_variance_from_monthfiles(ol_pat, da_pat, vars_, chunks, eps, log):
    """Variance across monthly means read from pre-made monthly-mean files.

    Both file sets are opened with decoded times and must share identical
    indexes (``join="exact"``). Outputs per variable are float32 OL
    variance, DA variance, DA - OL, and percent change relative to OL (NaN
    where the OL variance does not exceed ``eps``).

    The monthly means are used as-is, i.e. they are assumed to already be
    anomalies consistent with the daily baseline. If they are raw monthly
    means, subtract a monthly climatology first (for example one built with
    ``groupby("time.month")`` on the OL data) before taking the variance.

    Args:
        ol_pat: Glob pattern for the OL monthly-mean files.
        da_pat: Glob pattern for the DA monthly-mean files.
        vars_: Variable names to process.
        chunks: Chunk mapping forwarded to ``open_mf``.
        eps: Lower bound the OL variance must exceed for a percentage.
        log: Logger (may be falsy to disable the message).

    Returns:
        Merged dataset of all fields; ``lat``/``lon`` are attached only when
        both are present in the OL dataset.
    """
    ds_ol = open_mf(ol_pat, chunks, decode_times=True, log=log, label="OL monthly")
    ds_da = open_mf(da_pat, chunks, decode_times=True, log=log, label="DA monthly")
    ds_ol, ds_da = xr.align(ds_ol, ds_da, join="exact")

    pieces = []
    for v in vars_:
        var_ol = ds_ol[v].var("time", skipna=True)
        var_da = ds_da[v].var("time", skipna=True)
        change = var_da - var_ol
        percent = xr.where(var_ol > eps, 100.0 * change / var_ol, np.nan)
        fields = {
            f"{v}_monthlymean_var_OL": var_ol,
            f"{v}_monthlymean_var_DA": var_da,
            f"{v}_monthlymean_var_delta": change,
            f"{v}_monthlymean_var_pct": percent,
        }
        pieces.append(
            xr.Dataset({key: arr.astype("float32") for key, arr in fields.items()})
        )

    merged = xr.merge(pieces)
    has_latlon = "lat" in ds_ol and "lon" in ds_ol
    if has_latlon:
        merged = merged.assign_coords(lat=ds_ol.lat, lon=ds_ol.lon)
    if log:
        log.info("[monthly-mean IAV] computed (from monthly files).")
    return merged

def annualmean_variance(ds_ol, ds_da, vars_, eps, log):
    """Variance across annual means obtained by resampling the daily data.

    Each variable is averaged into calendar-year bins and the variance of
    those annual means is taken over ``time``. Outputs per variable are
    float32 OL variance, DA variance, DA - OL, and percent change relative
    to OL (NaN where the OL variance does not exceed ``eps``).

    The bins use the year-start alias ``"YS"``. The year-end alias ``"Y"``
    is deprecated in recent pandas/xarray releases; both define the same
    calendar-year bins and differ only in the bin label, which is discarded
    when the variance over ``time`` is taken.

    Args:
        ds_ol: Open-loop daily dataset with a datetime ``time`` axis.
        ds_da: Data-assimilation daily dataset with a datetime ``time`` axis.
        vars_: Variable names to process.
        eps: Lower bound the OL variance must exceed for a percentage.
        log: Logger (may be falsy to disable the message).

    Returns:
        Merged dataset of all fields with ``lat``/``lon`` taken from ``ds_ol``.
    """
    out = []
    for v in vars_:
        a_ol = ds_ol[v].resample(time="YS").mean(keep_attrs=True)
        a_da = ds_da[v].resample(time="YS").mean(keep_attrs=True)
        s2_ol = a_ol.var("time", skipna=True)
        s2_da = a_da.var("time", skipna=True)
        d     = s2_da - s2_ol
        # Guard the ratio: no percentage where the OL variance is ~zero.
        pct   = xr.where(s2_ol > eps, 100.0 * d / s2_ol, np.nan)
        out.append(xr.Dataset({
            f"{v}_annualmean_var_OL": s2_ol.astype("float32"),
            f"{v}_annualmean_var_DA": s2_da.astype("float32"),
            f"{v}_annualmean_var_delta": d.astype("float32"),
            f"{v}_annualmean_var_pct": pct.astype("float32"),
        }))
    ds = xr.merge(out).assign_coords(lat=ds_ol.lat, lon=ds_ol.lon)
    if log: log.info("[annual-mean IAV] computed.")
    return ds

# -----------------------------
# Args / main
# -----------------------------
def parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: List of argument strings. ``None`` makes argparse read the
            process command line; an empty list yields all defaults.

    Returns:
        ``argparse.Namespace`` with attributes ``outdir``, ``vars``,
        ``chunks``, ``decode_times``, ``eps``, ``ol_monthly``,
        ``da_monthly``, ``use_monthly_from_daily`` and ``verbose``.
    """
    parser = argparse.ArgumentParser(description="Variance diagnostics from daily anomalies (OL vs DA).")

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--outdir", {"default": "./yearly_outputs"}),
        ("--vars", {"nargs": "+", "default": ["SFMC","RZMC"]}),
        ("--chunks", {"default": "time:365,tile:4096"}),
        ("--decode-times", {
            "action": "store_true",
            "help": "Decode CF times for dailies (recommended for groupby/resample).",
        }),
        ("--eps", {"type": float, "default": 1e-10, "help": "Guard for percentage changes."}),
        # Optional pre-made monthly-mean inputs
        ("--ol-monthly", {"default": "", "help": "Glob for OL monthly-mean files (optional)."}),
        ("--da-monthly", {"default": "", "help": "Glob for DA monthly-mean files (optional)."}),
        ("--use-monthly-from-daily", {
            "action": "store_true",
            "help": "Compute monthly-mean IAV by resampling daily anomalies (default).",
        }),
        ("--verbose", {"type": int, "default": 1}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args(argv)

def main(argv=None):
    """Compute the four variance products and write them to ``--outdir``.

    Daily OL/DA anomaly files are read from ``--outdir`` itself (patterns
    ``OLv8_daily_anomalies_kept_*.nc`` and ``DAv8_daily_anomalies_kept_*.nc``).
    A product whose output file already exists is skipped.

    The time axis is decoded, sorted and de-duplicated once, before any
    product is computed, so all four products are based on the same set of
    days.

    Monthly-mean IAV is computed from pre-made monthly files only when both
    ``--ol-monthly`` and ``--da-monthly`` are given and
    ``--use-monthly-from-daily`` is not; otherwise the daily anomalies are
    resampled.

    Args:
        argv: Argument list forwarded to ``parse_args``.
    """
    args = parse_args(argv)
    log  = setup_logger(args.verbose)
    chunks = parse_chunk_flag(args.chunks)
    os.makedirs(args.outdir, exist_ok=True)

    # Input daily anomaly patterns
    ol_daily_pat = os.path.join(args.outdir, "OLv8_daily_anomalies_kept_*.nc")
    da_daily_pat = os.path.join(args.outdir, "DAv8_daily_anomalies_kept_*.nc")

    # Output files
    f_daily   = os.path.join(args.outdir, "variance_daily_fullperiod.nc")
    f_inmonth = os.path.join(args.outdir, "variance_withinmonth_daily.nc")
    f_mmiav   = os.path.join(args.outdir, "variance_monthlymean_IAV.nc")
    f_aiav    = os.path.join(args.outdir, "variance_annualmean_IAV.nc")

    # Open daily anomalies
    log.info("Opening daily anomalies …")
    ds_ol = open_mf(ol_daily_pat, chunks, decode_times=args.decode_times, log=log, label="OL daily anoms")
    ds_da = open_mf(da_daily_pat, chunks, decode_times=args.decode_times, log=log, label="DA daily anoms")
    ds_ol, ds_da = xr.align(ds_ol, ds_da, join="exact")

    # Month/year grouping needs datetime64 time. Decode once here; every
    # later step then works on the decoded datasets.
    if not np.issubdtype(ds_ol.time.dtype, np.datetime64):
        log.info("Decoding CF time …")
        ds_ol = xr.decode_cf(ds_ol)
        ds_da = xr.decode_cf(ds_da)

    # Make time strictly increasing and unique on both, then restrict both
    # to the common time axis. Done before product A so that duplicated
    # days cannot weight the full-period variance differently from B-D.
    ds_ol = ensure_monotonic_unique_time(ds_ol)
    ds_da = ensure_monotonic_unique_time(ds_da)
    common = np.intersect1d(ds_ol.time.values, ds_da.time.values)
    ds_ol = ds_ol.sel(time=common)
    ds_da = ds_da.sel(time=common)

    # A) Daily variability (full-period)
    if not os.path.exists(f_daily):
        daily_ds = daily_variance(ds_ol, ds_da, args.vars, args.eps, log)
        write_nc(daily_ds, f_daily, chunks, log)
    else:
        log.info(f"[skip] {f_daily} exists")

    # B) Within-month daily variance (seasonal)
    if not os.path.exists(f_inmonth):
        inmonth_ds = within_month_daily_variance(ds_ol, ds_da, args.vars, args.eps, log)
        write_nc(inmonth_ds, f_inmonth, chunks, log)
    else:
        log.info(f"[skip] {f_inmonth} exists")

    # C) Monthly-mean IAV
    if not os.path.exists(f_mmiav):
        have_monthly = bool(args.ol_monthly and args.da_monthly)
        if bool(args.ol_monthly) != bool(args.da_monthly):
            # One glob alone cannot be used; say so instead of silently
            # falling back.
            log.warning("Only one of --ol-monthly/--da-monthly was given; "
                        "resampling daily anomalies instead.")
        if have_monthly and args.use_monthly_from_daily:
            log.warning("--use-monthly-from-daily given; ignoring "
                        "--ol-monthly/--da-monthly.")

        if have_monthly and not args.use_monthly_from_daily:
            mmiav_ds = monthlymean_variance_from_monthfiles(args.ol_monthly, args.da_monthly,
                                                            args.vars, chunks, args.eps, log)
            # Try to attach lat/lon if missing
            if "lat" not in mmiav_ds.coords or "lon" not in mmiav_ds.coords:
                mmiav_ds = mmiav_ds.assign_coords(lat=ds_ol.lat, lon=ds_ol.lon)
        else:
            mmiav_ds = monthlymean_variance_from_daily(ds_ol, ds_da, args.vars, args.eps, log)
        write_nc(mmiav_ds, f_mmiav, chunks, log)
    else:
        log.info(f"[skip] {f_mmiav} exists")

    # D) Annual-mean IAV
    if not os.path.exists(f_aiav):
        aiav_ds = annualmean_variance(ds_ol, ds_da, args.vars, args.eps, log)
        write_nc(aiav_ds, f_aiav, chunks, log)
    else:
        log.info(f"[skip] {f_aiav} exists")

    log.info("All variance products done.")

if __name__ == "__main__":
    # NOTE(review): an explicit empty argument list is passed, so argparse
    # never reads the process command line and every option keeps its
    # default, including options given on a `%run ... --flag` line.
    # Presumably this is deliberate so the code can run inside a notebook
    # cell, where sys.argv holds kernel arguments; confirm before relying on
    # command-line flags.
    main([])
