In [None]:
#!/usr/bin/env python3
import os
import numpy as np
import scipy.io as sio
from netCDF4 import Dataset

# ============================================================
# Build Stokes-drift-bands + wind stress forcing NetCDF
#   - Input structure: 5 m/s wind case, initial mixed-layer depth 10 m
#       LES_DIR/
#          021/LESout.mat
#          031/LESout.mat
#          ...
#   - Output: one NetCDF per case
#
# CHANGE in this version:
#   - Output time axis is REGULAR every 10 minutes, starting from 0 minutes
#   - All fields are interpolated onto that 10-min grid
# ============================================================

# -----------------------
# USER SETTINGS
# -----------------------
# Root directory holding one sub-directory per LES case: LES_DIR/<case_id>/LESout.mat
LES_DIR = "/archive/bgr/Datasets/LES/MoreHurr/TS05_ML10/LT/"
# Destination for the generated forcing files; created here (at execution time) if missing.
OUT_DIR = "/archive/Qian.Xiao/Qian.Xiao/MOM6_kappa_ePBL/ice/10m_Forcing/"
os.makedirs(OUT_DIR, exist_ok=True)

# LES case identifiers; main() processes each one independently.
CASE_LIST = [
    "021","031","036","038","040","042","044","046","048","050",
    "051","053","055","057","059","061","063","065","071","081",
]

# Read the wavenumbers (recommended: reuse your existing 14-band template forcing file).
# get_k() uses TEMPLATE_NC when it is set and exists, otherwise WAVENUMBERS.
TEMPLATE_NC = None  # e.g. "/path/to/template_14bands.nc"
WAVENUMBERS = [0.006, 0.01 , 0.02 , 0.04 , 0.06 , 0.08 , 0.1  , 0.2  , 0.4  , 0.6  ,
               0.8  , 1.   , 2.   , 4. ]  # length 14, units 1/m (as stated in get_k's error message)

# Lat/Lon must be -10, 10 (two-point axes written to every output file).
LAT_2 = np.array([-10.0, 10.0], dtype="f8")
LON_2 = np.array([-10.0, 10.0], dtype="f8")

# Fit controls
ZMAX_FIT  = 100.0       # meters (top 100m): only depths <= this enter the band fit
RIDGE_LAM = 1e-8        # small ridge (Tikhonov) regularization strength for the band fit
FILL_US   = 1.0e32      # fillvalue for Usx/Usy

# Metadata for the output "Time" variable.
TIME_UNITS = "minutes since 2011-04-01 00:00:00"
TIME_CAL   = "gregorian"

# Output time grid control
DT_MIN_OUT = 10  # minutes, fixed output interval


# -----------------------
# helpers
# -----------------------
def load_les_mat(case_id: str, root_dir: str):
    """Load the fields needed for forcing from ``root_dir/<case_id>/LESout.mat``.

    Returns
    -------
    t_sec : (Nt,) float64
        LES output times in seconds.
    z : (Nz,) float64
        Depths in meters, positive down.
    Us, Vs : (Nt, Nz) float64
        Stokes drift components.
    tau13l, tau23l : (Nt, ...) float64
        Stress arrays; callers use column 0.

    A 2-D field stored time-last, ``(Nz, Nt)``, is transposed to time-first.
    The transpose happens only when the layout is unambiguous: the first axis
    is not Nt and the second axis is Nt.

    Raises
    ------
    FileNotFoundError
        If the .mat file does not exist.
    """
    f = os.path.join(root_dir, case_id, "LESout.mat")
    if not os.path.exists(f):
        raise FileNotFoundError(f)
    m = sio.loadmat(f)

    t_sec = np.asarray(m["t"]).squeeze().astype("f8")   # seconds
    z     = np.asarray(m["z"]).T.squeeze().astype("f8") # meters, positive down
    nt = t_sec.size

    def _time_first(name):
        # Transpose only for an unambiguous time-last layout, so callers can
        # always index [time, level].
        arr = np.asarray(m[name]).astype("f8")
        if arr.ndim == 2 and arr.shape[0] != nt and arr.shape[1] == nt:
            arr = arr.T
        return arr

    Us = _time_first("Us")
    Vs = _time_first("Vs")
    tau13l = _time_first("tau13l")
    tau23l = _time_first("tau23l")

    return t_sec, z, Us, Vs, tau13l, tau23l


def get_k(template_nc, wavenumbers):
    """Return the 14 Stokes-band wavenumbers as a float64 array.

    Source priority:
      1. ``template_nc``: the ``wavenumber`` variable of that NetCDF file.
      2. ``wavenumbers``: an explicit sequence (1/m).

    If ``template_nc`` is given but the file does not exist, a warning is
    emitted before falling back to ``wavenumbers``. Previously this fallback
    was silent, so a mistyped template path was easy to miss.

    Raises
    ------
    ValueError
        If no source is available, the size is not 14, or any wavenumber is
        not finite and positive (the exp(-2*k*z) band model needs k > 0).
    """
    import warnings

    k = None
    if template_nc is not None:
        if os.path.exists(template_nc):
            with Dataset(template_nc, "r") as nc:
                k = np.asarray(nc.variables["wavenumber"][:], dtype="f8")
            if k.size != 14:
                raise ValueError(f"Template wavenumber size != 14: {k.size}")
        else:
            warnings.warn(
                f"TEMPLATE_NC not found: {template_nc!r}; falling back to WAVENUMBERS.",
                stacklevel=2,
            )

    if k is None:
        if wavenumbers is None:
            raise ValueError("TEMPLATE_NC is None and WAVENUMBERS is None. Provide 14 wavenumbers (1/m).")
        k = np.asarray(wavenumbers, dtype="f8")
        if k.size != 14:
            raise ValueError(f"WAVENUMBERS length must be 14, got {k.size}")

    # Non-positive k gives non-decaying bands (and NaN in a layer-mean basis).
    if not np.all(np.isfinite(k) & (k > 0)):
        raise ValueError(f"Wavenumbers must be finite and > 0, got {k}")
    return k


def fit_exp_modes(z, prof_tz, k, zmax=100.0, ridge_lam=0.0):
    """
    Fit prof(z,t) ≈ sum_i c_i(t) * exp(-2*k_i*z) by ridge-regularized least
    squares, independently at each time, using only depths z <= zmax.

    Parameters
    ----------
    z : (Nz,) depths, positive down (m).
    prof_tz : (Nt, Nz) profile (e.g. Us or Vs).
    k : (Nb,) band wavenumbers; Nb is 14 in this script, but any Nb >= 1 works.
    zmax : deepest depth (m) used in the fit.
    ridge_lam : Tikhonov weight; the fit minimizes
        ||A c - y||^2 + ridge_lam * ||c||^2.

    Returns
    -------
    coeff : (Nt, Nb) band amplitudes. A row is NaN when that time step has
        fewer than Nb finite samples in the fit range.

    Raises
    ------
    ValueError
        Bad shapes, or fewer than Nb finite depths <= zmax.
    """
    z = np.asarray(z, dtype="f8").ravel()
    prof = np.asarray(prof_tz, dtype="f8")
    k = np.asarray(k, dtype="f8").ravel()
    nb = k.size

    if prof.ndim != 2:
        raise ValueError(f"profile must be 2D (Nt,Nz). Got {prof.shape}")
    Nt, Nz = prof.shape
    if Nz != z.size:
        raise ValueError(f"z size {z.size} != profile Nz {Nz}")

    m = np.isfinite(z) & (z <= zmax)
    zfit = z[m]
    if zfit.size < nb:
        raise ValueError(f"Too few z points <= {zmax}m: {zfit.size} (<{nb})")

    A = np.exp(-2.0 * zfit[:, None] * k[None, :])  # (Nzfit, Nb)

    # Solve the ridge problem as ordinary least squares on the augmented system
    # [A; sqrt(lam) I] c = [y; 0]. It has the same minimizer as
    # (A^T A + lam I) c = A^T y, but it does not square the (large) condition
    # number of the exponential basis. With lam == 0, lstsq returns the
    # minimum-norm solution instead of raising on a singular system.
    reg_rows = np.sqrt(ridge_lam) * np.eye(nb) if ridge_lam > 0 else None
    reg_rhs = np.zeros(nb)

    coeff = np.full((Nt, nb), np.nan, dtype="f8")
    for it in range(Nt):
        y = prof[it, m]
        good = np.isfinite(y)
        if good.sum() < nb:
            continue
        Ag = A[good, :]
        yg = y[good]
        if reg_rows is not None:
            Ag = np.vstack([Ag, reg_rows])
            yg = np.concatenate([yg, reg_rhs])
        coeff[it, :] = np.linalg.lstsq(Ag, yg, rcond=None)[0]

    return coeff


def interp_1d(x_old, y_old, x_new):
    """
    Linearly interpolate the samples (x_old, y_old) onto x_new, ignoring NaNs.

    Pairs where either coordinate is non-finite are dropped, and the rest are
    sorted by x before np.interp is applied. Outside the sampled x range,
    np.interp holds the end values constant (no extrapolation). If fewer than
    two valid pairs remain, the result is all-NaN, shaped like x_new.
    x_old and x_new must use the same units (minutes here).
    """
    xs = np.asarray(x_old, dtype="f8")
    ys = np.asarray(y_old, dtype="f8")
    xq = np.asarray(x_new, dtype="f8")

    valid = np.isfinite(xs) & np.isfinite(ys)
    if np.count_nonzero(valid) < 2:
        return np.full_like(xq, np.nan, dtype="f8")

    xs, ys = xs[valid], ys[valid]
    order = np.argsort(xs)
    return np.interp(xq, xs[order], ys[order])


def write_forcing_nc(out_path, t_min_i8, k,
                     taux_t, tauy_t, usx_t14, usy_t14,
                     description="Stokes drift for location XXX"):
    """
    Write one forcing NetCDF file (overwriting ``out_path``).

    Dimensions and coordinates:
      * Time: unlimited, int64, with TIME_UNITS / TIME_CAL attributes.
      * wavenumber: Nb = len(k).
      * Lat, Lon: coordinates LAT_2 / LON_2.

    Taux/Tauy and the band amplitudes Usx1..UsxNb / Usy1..UsyNb are spatially
    uniform: each time value is replicated to every Lat/Lon point.

    Parameters
    ----------
    out_path : output file path.
    t_min_i8 : (Nt,) times in TIME_UNITS.
    k : (Nb,) wavenumbers; Nb sets the number of Usx*/Usy* variables (14 here).
    taux_t, tauy_t : (Nt,) wind-stress series.
    usx_t14, usy_t14 : (Nt, Nb) band amplitudes.
    description : global ``description`` attribute.

    Raises
    ------
    ValueError
        If the array shapes are inconsistent with t_min_i8 / k.
    """
    t = np.asarray(t_min_i8)
    k = np.asarray(k, dtype="f8").ravel()
    taux_t = np.asarray(taux_t, dtype="f8").ravel()
    tauy_t = np.asarray(tauy_t, dtype="f8").ravel()
    usx = np.asarray(usx_t14, dtype="f8")
    usy = np.asarray(usy_t14, dtype="f8")

    nt, nb = t.size, k.size
    nlat, nlon = LAT_2.size, LON_2.size
    if taux_t.shape != (nt,) or tauy_t.shape != (nt,):
        raise ValueError(f"Taux/Tauy must have shape ({nt},), got {taux_t.shape} / {tauy_t.shape}")
    if usx.shape != (nt, nb) or usy.shape != (nt, nb):
        raise ValueError(f"Usx/Usy must have shape ({nt}, {nb}), got {usx.shape} / {usy.shape}")

    def _spread(series):
        # Replicate a (Nt,) series over the Lat/Lon grid explicitly (contiguous
        # copy) instead of relying on netCDF4's implicit broadcasting when
        # writing into an unlimited dimension.
        return np.broadcast_to(series[:, None, None], (nt, nlat, nlon)).copy()

    with Dataset(out_path, "w", format="NETCDF4") as nc:
        nc.createDimension("Time", None)
        nc.createDimension("Lat", nlat)
        nc.createDimension("Lon", nlon)
        nc.createDimension("wavenumber", nb)

        vtime = nc.createVariable("Time", "i8", ("Time",))
        vtime.units = TIME_UNITS
        vtime.calendar = TIME_CAL
        vtime[:] = t

        vwn = nc.createVariable("wavenumber", "f8", ("wavenumber",), fill_value=np.nan)
        vwn[:] = k

        vlat = nc.createVariable("Lat", "f8", ("Lat",), fill_value=np.nan)
        vlon = nc.createVariable("Lon", "f8", ("Lon",), fill_value=np.nan)
        vlat[:] = LAT_2
        vlon[:] = LON_2

        vtaux = nc.createVariable("Taux", "f8", ("Time", "Lat", "Lon"), fill_value=np.nan)
        vtauy = nc.createVariable("Tauy", "f8", ("Time", "Lat", "Lon"), fill_value=np.nan)
        vtaux[:, :, :] = _spread(taux_t)
        vtauy[:, :, :] = _spread(tauy_t)

        for i in range(nb):
            vx = nc.createVariable(f"Usx{i+1}", "f8", ("Time", "Lat", "Lon"), fill_value=FILL_US)
            vy = nc.createVariable(f"Usy{i+1}", "f8", ("Time", "Lat", "Lon"), fill_value=FILL_US)
            vx[:, :, :] = _spread(usx[:, i])
            vy[:, :, :] = _spread(usy[:, i])

        nc.description = description


def main():
    """
    Build one forcing NetCDF per case in CASE_LIST.

    For each case:
      * load LESout.mat;
      * take the wind stress from column 0 of tau13l/tau23l, times 1000
        (the project's convention);
      * fit the 14 Stokes bands at the LES output times;
      * linearly interpolate stress and band amplitudes onto a regular
        DT_MIN_OUT-minute grid spanning 0..3 days;
      * write OUT_DIR/05_<case_id>.nc.

    Raises RuntimeError if DT_MIN_OUT is not a positive divisor of the run length.
    """
    k = get_k(TEMPLATE_NC, WAVENUMBERS)

    # The output grid is identical for every case, so build and validate it once.
    dur_min = 3 * 24 * 60  # 3 days = 4320 minutes
    if DT_MIN_OUT <= 0 or dur_min % DT_MIN_OUT != 0:
        raise RuntimeError(f"DT_MIN_OUT={DT_MIN_OUT} must be a positive divisor of {dur_min} minutes.")
    t_min_i8 = np.arange(0, dur_min + DT_MIN_OUT, DT_MIN_OUT, dtype="i8")  # 0..dur_min inclusive
    expected_nt = dur_min // DT_MIN_OUT + 1  # 433 for 10-minute output
    if t_min_i8.size != expected_nt:
        raise RuntimeError(f"Expected {expected_nt} time steps, got {t_min_i8.size}. Check DT_MIN_OUT.")
    t_min_out = t_min_i8.astype("f8")

    for case_id in CASE_LIST:
        print(f"\n=== {case_id} ===")
        t_sec, z, Us, Vs, tau13l, tau23l = load_les_mat(case_id, LES_DIR)

        # LES original time in minutes (since 2011-04-01 00:00:00)
        t_min_les = t_sec / 60.0

        # interp_1d holds end values constant outside the LES time span; report
        # it rather than silently padding the forcing.
        t_lo, t_hi = float(np.nanmin(t_min_les)), float(np.nanmax(t_min_les))
        if t_lo > t_min_out[0] or t_hi < t_min_out[-1]:
            print(f"  WARNING: LES spans {t_lo:.1f}-{t_hi:.1f} min, output spans "
                  f"{t_min_out[0]:.0f}-{t_min_out[-1]:.0f} min; end values are held constant.")

        # stresses (project convention: LES tau * 1000)
        taux_les = tau13l[:, 0] * 1000.0
        tauy_les = tau23l[:, 0] * 1000.0

        # band fit on LES native times
        usx_les = fit_exp_modes(z, Us, k, zmax=ZMAX_FIT, ridge_lam=RIDGE_LAM)
        usy_les = fit_exp_modes(z, Vs, k, zmax=ZMAX_FIT, ridge_lam=RIDGE_LAM)

        # interpolate stress and every band amplitude onto the output grid
        taux = interp_1d(t_min_les, taux_les, t_min_out)
        tauy = interp_1d(t_min_les, tauy_les, t_min_out)
        usx14 = np.column_stack([interp_1d(t_min_les, col, t_min_out) for col in usx_les.T])
        usy14 = np.column_stack([interp_1d(t_min_les, col, t_min_out) for col in usy_les.T])

        out_nc = os.path.join(OUT_DIR, f"05_{case_id}.nc")
        write_forcing_nc(
            out_nc, t_min_i8, k,
            taux, tauy, usx14, usy14,
            description=f"Stokes drift + wind stress forcing from LES {case_id}"
        )
        print("Wrote:", out_nc)


# Entry point: build every forcing file when this cell/file runs as the top-level script.
if __name__ == "__main__":
    main()


In [None]:
#!/usr/bin/env python3
import os
import numpy as np
import scipy.io as sio
from netCDF4 import Dataset

# ============================================================
# Bulk (layer-mean) LES fit
# Build Stokes-drift-bands + wind stress forcing NetCDF
#   - Input structure: 10mps, mixed layer depth: 10m
#       LES_DIR/
#          021/LESout.mat
#          031/LESout.mat
#          ...
#   - Output: one NetCDF per case
#
# CHANGE in this version:
#   - Output time axis is REGULAR every 10 minutes, starting from 0 minutes
#   - All fields are interpolated onto that 10-min grid
#
# IMPORTANT CHANGE (your request):
#   - Fit uses LAYER-MEAN basis consistent with LES z being cell centers (bulk)
#     instead of point-value exp(-2*k*z_center).
# ============================================================

# -----------------------
# USER SETTINGS
# -----------------------
# Root directory holding one sub-directory per LES case: LES_DIR/<case_id>/LESout.mat
LES_DIR = "/archive/bgr/Datasets/LES/MoreHurr/TS10_ML10/LT/"
# Destination for the generated forcing files; created here (at execution time) if missing.
OUT_DIR = "/archive/Qian.Xiao/Qian.Xiao/MOM6_kappa_ePBL/ice/10m_Forcing/"
os.makedirs(OUT_DIR, exist_ok=True)

# LES case identifiers; main() processes each one independently.
CASE_LIST = [
    "021","031","036","038","040","042","044","046","048","050",
    "051","053","055","057","059","061","063","065","071","081",
]

# Read the wavenumbers (recommended: reuse your existing 14-band template forcing file).
# get_k() uses TEMPLATE_NC when it is set and exists, otherwise WAVENUMBERS.
TEMPLATE_NC = None  # e.g. "/path/to/template_14bands.nc"
WAVENUMBERS = [0.006, 0.01 , 0.02 , 0.04 , 0.06 , 0.08 , 0.1  , 0.2  , 0.4  , 0.6  ,
               0.8  , 1.   , 2.   , 4. ]  # length 14, units 1/m (as stated in get_k's error message)

# Lat/Lon must be -10, 10 (two-point axes written to every output file).
LAT_2 = np.array([-10.0, 10.0], dtype="f8")
LON_2 = np.array([-10.0, 10.0], dtype="f8")

# Fit controls
ZMAX_FIT  = 100.0       # meters (top 100m): only layers whose center is <= this enter the fit
RIDGE_LAM = 1e-8        # small ridge (Tikhonov) regularization strength for the band fit
FILL_US   = 1.0e32      # fillvalue for Usx/Usy

# Metadata for the output "Time" variable.
TIME_UNITS = "minutes since 2011-04-01 00:00:00"
TIME_CAL   = "gregorian"

# Output time grid control
DT_MIN_OUT = 10  # minutes, fixed output interval


# -----------------------
# helpers
# -----------------------
def load_les_mat(case_id: str, root_dir: str):
    """Load the fields needed for forcing from ``root_dir/<case_id>/LESout.mat``.

    Returns
    -------
    t_sec : (Nt,) float64
        LES output times in seconds.
    z : (Nz,) float64
        Depths in meters, positive down (cell centers).
    Us, Vs : (Nt, Nz) float64
        Layer-mean Stokes drift at the z centers.
    tau13l, tau23l : (Nt, ...) float64
        Stress arrays; callers use column 0.

    A 2-D field stored time-last, ``(Nz, Nt)``, is transposed to time-first.
    The transpose happens only when the layout is unambiguous: the first axis
    is not Nt and the second axis is Nt.

    Raises
    ------
    FileNotFoundError
        If the .mat file does not exist.
    """
    f = os.path.join(root_dir, case_id, "LESout.mat")
    if not os.path.exists(f):
        raise FileNotFoundError(f)
    m = sio.loadmat(f)

    t_sec = np.asarray(m["t"]).squeeze().astype("f8")    # seconds
    z     = np.asarray(m["z"]).T.squeeze().astype("f8")  # meters, positive down (CELL CENTERS)
    nt = t_sec.size

    def _time_first(name):
        # Transpose only for an unambiguous time-last layout, so callers can
        # always index [time, level].
        arr = np.asarray(m[name]).astype("f8")
        if arr.ndim == 2 and arr.shape[0] != nt and arr.shape[1] == nt:
            arr = arr.T
        return arr

    Us = _time_first("Us")
    Vs = _time_first("Vs")
    tau13l = _time_first("tau13l")
    tau23l = _time_first("tau23l")

    return t_sec, z, Us, Vs, tau13l, tau23l


def get_k(template_nc, wavenumbers):
    """Return the 14 Stokes-band wavenumbers as a float64 array.

    Source priority:
      1. ``template_nc``: the ``wavenumber`` variable of that NetCDF file.
      2. ``wavenumbers``: an explicit sequence (1/m).

    If ``template_nc`` is given but the file does not exist, a warning is
    emitted before falling back to ``wavenumbers``. Previously this fallback
    was silent, so a mistyped template path was easy to miss.

    Raises
    ------
    ValueError
        If no source is available, the size is not 14, or any wavenumber is
        not finite and positive (the layer-mean exp(-2*k*z) basis needs k > 0).
    """
    import warnings

    k = None
    if template_nc is not None:
        if os.path.exists(template_nc):
            with Dataset(template_nc, "r") as nc:
                k = np.asarray(nc.variables["wavenumber"][:], dtype="f8")
            if k.size != 14:
                raise ValueError(f"Template wavenumber size != 14: {k.size}")
        else:
            warnings.warn(
                f"TEMPLATE_NC not found: {template_nc!r}; falling back to WAVENUMBERS.",
                stacklevel=2,
            )

    if k is None:
        if wavenumbers is None:
            raise ValueError("TEMPLATE_NC is None and WAVENUMBERS is None. Provide 14 wavenumbers (1/m).")
        k = np.asarray(wavenumbers, dtype="f8")
        if k.size != 14:
            raise ValueError(f"WAVENUMBERS length must be 14, got {k.size}")

    # Non-positive k gives non-decaying bands (and NaN in the layer-mean basis).
    if not np.all(np.isfinite(k) & (k > 0)):
        raise ValueError(f"Wavenumbers must be finite and > 0, got {k}")
    return k


def centers_to_interfaces(zc):
    """
    Convert cell-center depths to cell interfaces.

    How the interfaces are built:
      * interior interfaces are midpoints between neighbouring centers;
      * the top interface is fixed at z=0;
      * the bottom interface is extrapolated by half of the last center spacing.

    This reproduces the true grid exactly for uniform spacing with the first
    center at dz/2. For stretched grids it is an approximation.

    zc: (Nz,) cell-center depths, positive down, strictly increasing
    Returns:
      zi: (Nz+1,) interfaces with zi[0]=0
      dz: (Nz,) layer thickness
    Raises:
      ValueError: zc is not 1-D, has < 2 levels, is not strictly increasing
      (this includes NaN or duplicate depths), or yields a non-positive thickness.
    """
    zc = np.asarray(zc, dtype="f8").squeeze()
    if zc.ndim != 1:
        raise ValueError(f"zc must be 1D, got shape {zc.shape}")
    if zc.size < 2:
        raise ValueError("Need at least 2 z levels to infer interfaces.")
    if not np.all(np.diff(zc) > 0):
        # The centers used to be sorted silently here. The returned interfaces
        # then no longer matched the caller's (unsorted) profile columns, which
        # corrupted the fit without any error. Callers must instead sort depths
        # and profiles together.
        raise ValueError(
            "zc must be strictly increasing (positive down); "
            "sort depths and profiles together before calling."
        )

    Nz = zc.size
    zi = np.empty(Nz + 1, dtype="f8")
    zi[0] = 0.0
    zi[1:-1] = 0.5 * (zc[:-1] + zc[1:])
    # bottom interface extrapolated using last spacing
    zi[-1] = zc[-1] + 0.5 * (zc[-1] - zc[-2])

    dz = zi[1:] - zi[:-1]
    if np.any(dz <= 0):
        raise ValueError("Non-positive layer thickness inferred; check z centers.")
    return zi, dz


def layermean_exp_basis(zi, k):
    """
    Build the layer-mean basis for exp(-2*k*z) over each layer.

    For layer [z_top, z_bot] with dz = z_bot - z_top:
      <exp(-2*k*z)> = (exp(-2*k*z_top) - exp(-2*k*z_bot)) / (2*k*dz)
                    = exp(-2*k*z_top) * (-expm1(-2*k*dz)) / (2*k*dz)

    The second form avoids the cancellation of the difference form when
    2*k*dz is small (e.g. k=0.006 1/m over ~1 m layers). As 2*k*dz -> 0 its
    limit is exp(-2*k*z_top), and that limit is used exactly when k == 0 or
    dz == 0. The previous version returned NaN for k == 0.

    zi: (Nz+1,) interfaces
    k:  (Nb,)
    Returns Abar: (Nz, Nb)
    """
    zi = np.asarray(zi, dtype="f8").ravel()
    k  = np.asarray(k,  dtype="f8").ravel()

    ztop = zi[:-1][:, None]
    dz   = (zi[1:] - zi[:-1])[:, None]

    x = 2.0 * k[None, :] * dz                     # (Nz, Nb)
    is_zero = x == 0.0
    x_safe = np.where(is_zero, 1.0, x)            # avoid 0/0; masked out below
    ratio = np.where(is_zero, 1.0, -np.expm1(-x) / x_safe)
    return np.exp(-2.0 * k[None, :] * ztop) * ratio


def fit_exp_modes(zc, prof_tz, k, zmax=100.0, ridge_lam=0.0):
    """
    LAYER-MEAN fit (consistent with LES z being cell centers / bulk):
      prof_layermean(z,t) ≈ sum_i c_i(t) * <exp(-2*k_i*z)>_layer
    where < > is the average over each cell (built by layermean_exp_basis).

    Inputs:
      zc: cell-center depths (Nz,), positive down
      prof_tz: (Nt, Nz) layer-mean profile (Us or Vs)
      k: (Nb,) band wavenumbers (14 in this script; any Nb >= 1 works)
      zmax: only layers with center <= zmax enter the fit
      ridge_lam: Tikhonov weight; minimizes ||A c - y||^2 + ridge_lam*||c||^2
    Returns:
      coeff: (Nt, Nb) surface amplitudes for each band. A row is NaN when that
      time step has fewer than Nb finite layers in the fit range.
    Raises:
      ValueError: bad shapes, or too few layers with center <= zmax.
    """
    zc = np.asarray(zc, dtype="f8").squeeze()
    prof = np.asarray(prof_tz, dtype="f8")
    k = np.asarray(k, dtype="f8").ravel()
    nb = k.size

    if prof.ndim != 2:
        raise ValueError(f"profile must be 2D (Nt,Nz). Got {prof.shape}")
    Nt, Nz = prof.shape
    if Nz != zc.size:
        raise ValueError(f"z size {zc.size} != profile Nz {Nz}")

    # The interface construction needs increasing depth. Reorder the depths and
    # the profile columns together so basis rows stay aligned with the data.
    if not np.all(np.diff(zc) > 0):
        order = np.argsort(zc, kind="stable")
        zc = zc[order]
        prof = prof[:, order]

    # build interfaces and layer-mean basis
    zi, _ = centers_to_interfaces(zc)
    Abar_full = layermean_exp_basis(zi, k)  # (Nz, Nb)

    # only use layers with center <= zmax
    m = np.isfinite(zc) & (zc <= zmax)
    if m.sum() < nb:
        raise ValueError(f"Too few z layers with center <= {zmax}m: {m.sum()} (<{nb})")

    A = Abar_full[m, :]  # (Nzfit, Nb)

    # Solve the ridge problem as ordinary least squares on [A; sqrt(lam) I].
    # It has the same minimizer as the normal equations without squaring the
    # basis' condition number. With lam == 0 it does not raise on a singular system.
    reg_rows = np.sqrt(ridge_lam) * np.eye(nb) if ridge_lam > 0 else None
    reg_rhs = np.zeros(nb)

    coeff = np.full((Nt, nb), np.nan, dtype="f8")
    for it in range(Nt):
        y = prof[it, m]
        good = np.isfinite(y)
        if good.sum() < nb:
            continue
        Ag = A[good, :]
        yg = y[good]
        if reg_rows is not None:
            Ag = np.vstack([Ag, reg_rows])
            yg = np.concatenate([yg, reg_rhs])
        coeff[it, :] = np.linalg.lstsq(Ag, yg, rcond=None)[0]

    return coeff


def interp_1d(x_old, y_old, x_new):
    """
    Linearly interpolate the samples (x_old, y_old) onto x_new, ignoring NaNs.

    Pairs where either coordinate is non-finite are dropped, and the rest are
    sorted by x before np.interp is applied. Outside the sampled x range,
    np.interp holds the end values constant (no extrapolation). If fewer than
    two valid pairs remain, the result is all-NaN, shaped like x_new.
    x_old and x_new must use the same units (minutes here).
    """
    xs = np.asarray(x_old, dtype="f8")
    ys = np.asarray(y_old, dtype="f8")
    xq = np.asarray(x_new, dtype="f8")

    valid = np.isfinite(xs) & np.isfinite(ys)
    if np.count_nonzero(valid) < 2:
        return np.full_like(xq, np.nan, dtype="f8")

    xs, ys = xs[valid], ys[valid]
    order = np.argsort(xs)
    return np.interp(xq, xs[order], ys[order])


def write_forcing_nc(out_path, t_min_i8, k,
                     taux_t, tauy_t, usx_t14, usy_t14,
                     description="Stokes drift for location XXX"):
    """
    Write one forcing NetCDF file (overwriting ``out_path``).

    Dimensions and coordinates:
      * Time: unlimited, int64, with TIME_UNITS / TIME_CAL attributes.
      * wavenumber: Nb = len(k).
      * Lat, Lon: coordinates LAT_2 / LON_2.

    Taux/Tauy and the band amplitudes Usx1..UsxNb / Usy1..UsyNb are spatially
    uniform: each time value is replicated to every Lat/Lon point.

    Parameters
    ----------
    out_path : output file path.
    t_min_i8 : (Nt,) times in TIME_UNITS.
    k : (Nb,) wavenumbers; Nb sets the number of Usx*/Usy* variables (14 here).
    taux_t, tauy_t : (Nt,) wind-stress series.
    usx_t14, usy_t14 : (Nt, Nb) band amplitudes.
    description : global ``description`` attribute.

    Raises
    ------
    ValueError
        If the array shapes are inconsistent with t_min_i8 / k.
    """
    t = np.asarray(t_min_i8)
    k = np.asarray(k, dtype="f8").ravel()
    taux_t = np.asarray(taux_t, dtype="f8").ravel()
    tauy_t = np.asarray(tauy_t, dtype="f8").ravel()
    usx = np.asarray(usx_t14, dtype="f8")
    usy = np.asarray(usy_t14, dtype="f8")

    nt, nb = t.size, k.size
    nlat, nlon = LAT_2.size, LON_2.size
    if taux_t.shape != (nt,) or tauy_t.shape != (nt,):
        raise ValueError(f"Taux/Tauy must have shape ({nt},), got {taux_t.shape} / {tauy_t.shape}")
    if usx.shape != (nt, nb) or usy.shape != (nt, nb):
        raise ValueError(f"Usx/Usy must have shape ({nt}, {nb}), got {usx.shape} / {usy.shape}")

    def _spread(series):
        # Replicate a (Nt,) series over the Lat/Lon grid explicitly (contiguous
        # copy) instead of relying on netCDF4's implicit broadcasting when
        # writing into an unlimited dimension.
        return np.broadcast_to(series[:, None, None], (nt, nlat, nlon)).copy()

    with Dataset(out_path, "w", format="NETCDF4") as nc:
        nc.createDimension("Time", None)
        nc.createDimension("Lat", nlat)
        nc.createDimension("Lon", nlon)
        nc.createDimension("wavenumber", nb)

        vtime = nc.createVariable("Time", "i8", ("Time",))
        vtime.units = TIME_UNITS
        vtime.calendar = TIME_CAL
        vtime[:] = t

        vwn = nc.createVariable("wavenumber", "f8", ("wavenumber",), fill_value=np.nan)
        vwn[:] = k

        vlat = nc.createVariable("Lat", "f8", ("Lat",), fill_value=np.nan)
        vlon = nc.createVariable("Lon", "f8", ("Lon",), fill_value=np.nan)
        vlat[:] = LAT_2
        vlon[:] = LON_2

        vtaux = nc.createVariable("Taux", "f8", ("Time", "Lat", "Lon"), fill_value=np.nan)
        vtauy = nc.createVariable("Tauy", "f8", ("Time", "Lat", "Lon"), fill_value=np.nan)
        vtaux[:, :, :] = _spread(taux_t)
        vtauy[:, :, :] = _spread(tauy_t)

        for i in range(nb):
            vx = nc.createVariable(f"Usx{i+1}", "f8", ("Time", "Lat", "Lon"), fill_value=FILL_US)
            vy = nc.createVariable(f"Usy{i+1}", "f8", ("Time", "Lat", "Lon"), fill_value=FILL_US)
            vx[:, :, :] = _spread(usx[:, i])
            vy[:, :, :] = _spread(usy[:, i])

        nc.description = description


def main():
    """
    Build one forcing NetCDF per case in CASE_LIST, using the layer-mean band fit.

    For each case:
      * load LESout.mat;
      * take the wind stress from column 0 of tau13l/tau23l, times 1000
        (the project's convention);
      * fit the 14 Stokes bands (layer-mean basis) at the LES output times;
      * linearly interpolate stress and band amplitudes onto a regular
        DT_MIN_OUT-minute grid spanning 0..3 days;
      * write OUT_DIR/10_<case_id>.nc.

    Raises RuntimeError if DT_MIN_OUT is not a positive divisor of the run length.
    """
    k = get_k(TEMPLATE_NC, WAVENUMBERS)

    # The output grid is identical for every case, so build and validate it once.
    dur_min = 3 * 24 * 60  # 3 days = 4320 minutes
    if DT_MIN_OUT <= 0 or dur_min % DT_MIN_OUT != 0:
        raise RuntimeError(f"DT_MIN_OUT={DT_MIN_OUT} must be a positive divisor of {dur_min} minutes.")
    t_min_i8 = np.arange(0, dur_min + DT_MIN_OUT, DT_MIN_OUT, dtype="i8")  # 0..dur_min inclusive
    expected_nt = dur_min // DT_MIN_OUT + 1  # 433 for 10-minute output
    if t_min_i8.size != expected_nt:
        raise RuntimeError(f"Expected {expected_nt} time steps, got {t_min_i8.size}. Check DT_MIN_OUT.")
    t_min_out = t_min_i8.astype("f8")

    for case_id in CASE_LIST:
        print(f"\n=== {case_id} ===")
        t_sec, z, Us, Vs, tau13l, tau23l = load_les_mat(case_id, LES_DIR)

        # LES original time in minutes (since 2011-04-01 00:00:00)
        t_min_les = t_sec / 60.0

        # interp_1d holds end values constant outside the LES time span; report
        # it rather than silently padding the forcing.
        t_lo, t_hi = float(np.nanmin(t_min_les)), float(np.nanmax(t_min_les))
        if t_lo > t_min_out[0] or t_hi < t_min_out[-1]:
            print(f"  WARNING: LES spans {t_lo:.1f}-{t_hi:.1f} min, output spans "
                  f"{t_min_out[0]:.0f}-{t_min_out[-1]:.0f} min; end values are held constant.")

        # stresses (project convention: LES tau * 1000)
        taux_les = tau13l[:, 0] * 1000.0
        tauy_les = tau23l[:, 0] * 1000.0

        # band fit on LES native times (LAYER-MEAN basis)
        usx_les = fit_exp_modes(z, Us, k, zmax=ZMAX_FIT, ridge_lam=RIDGE_LAM)
        usy_les = fit_exp_modes(z, Vs, k, zmax=ZMAX_FIT, ridge_lam=RIDGE_LAM)

        # interpolate stress and every band amplitude onto the output grid
        taux = interp_1d(t_min_les, taux_les, t_min_out)
        tauy = interp_1d(t_min_les, tauy_les, t_min_out)
        usx14 = np.column_stack([interp_1d(t_min_les, col, t_min_out) for col in usx_les.T])
        usy14 = np.column_stack([interp_1d(t_min_les, col, t_min_out) for col in usy_les.T])

        out_nc = os.path.join(OUT_DIR, f"10_{case_id}.nc")
        write_forcing_nc(
            out_nc, t_min_i8, k,
            taux, tauy, usx14, usy14,
            description=f"Stokes drift + wind stress forcing from LES {case_id}"
        )
        print("Wrote:", out_nc)


# Entry point: build every forcing file when this cell/file runs as the top-level script.
if __name__ == "__main__":
    main()

    


In [None]:
#!/usr/bin/env python3
import os
import glob
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime, timedelta
from netCDF4 import Dataset

# ============================================================
# Compare forcing NetCDF vs LES (LESout.mat)
#   Purpose: check my own wind-wave forcing output files (10 m MLD, 5 m/s).
#   The forcing bands are evaluated at the LES top-cell depth z0 and compared with
#   the LES top cell, which can be read as a layer-averaged (bulk) Stokes drift.
#   Panels:
#     Row1: Taux, Tauy
#     Row2: Usx(top cell), Usy(top cell)
#     Row3: |Us|(top cell)
#     Row4: profile check @ one forcing time (Usx(z), Usy(z), |Us|(z))
#   No saving, only plt.show()
# ============================================================

# -----------------------
# USER SETTINGS
# -----------------------
# Forcing NetCDF file to compare against the LES output.
FORCING_NC = "/archive/Qian.Xiao/Qian.Xiao/MOM6_kappa_ePBL/ice/10m_Forcing/10_046.nc"
LES_ROOT   = "/archive/bgr/Datasets/LES/MoreHurr/TS10_ML10/LT/"  # contains 021/031/... and one extra subdir
CASE_ID    = None        # None -> infer from forcing filename "05_081.nc" -> "081"

IT_FORCE_CHECK = -1      # forcing time index for profile check (-1 = last)
ZMAX_PLOT  = 240.0       # meters: deepest LES level shown in the profile plots

# conventions
RHO_W = 1000.0           # LES tau is multiplied by this (the tau*1000 convention used for the forcing)
FILL_US_THRES = 1e20     # treat >1e20 as missing (forcing uses 1e32)

# -----------------------
# time helpers (avoid cftime)
# -----------------------
def parse_time_units(units_str: str):
    """
    Split a CF-style units string such as "minutes since 2011-04-01 00:00:00"
    into its reference datetime and the number of seconds per unit.

    Returns (base_datetime, seconds_per_unit).

    Raises ValueError when:
      * there is no "since";
      * the reference date cannot be parsed;
      * the unit is not seconds/minutes/hours/days or one of their short forms.
    """
    text = units_str.strip()
    if "since" not in text:
        raise ValueError(f"Unrecognized time units (no 'since'): {units_str}")
    unit_part, _, date_part = text.partition("since")
    unit = unit_part.strip().lower()
    date_text = date_part.strip()

    base = None
    for pattern in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M", "%Y-%m-%d"):
        try:
            base = datetime.strptime(date_text, pattern)
        except ValueError:
            continue
        break
    if base is None:
        # Last resort: ISO 8601 (e.g. with a "T" separator).
        base = datetime.fromisoformat(date_text)

    scale_by_unit = {}
    for aliases, seconds in (
        (("seconds", "second", "sec", "s"), 1.0),
        (("minutes", "minute", "min"), 60.0),
        (("hours", "hour", "hr", "h"), 3600.0),
        (("days", "day", "d"), 86400.0),
    ):
        scale_by_unit.update(dict.fromkeys(aliases, seconds))

    try:
        return base, scale_by_unit[unit]
    except KeyError:
        raise ValueError(f"Unsupported time unit '{unit}' in units='{units_str}'") from None

def timevals_to_datetimes(time_vals, units_str):
    """Convert numeric times expressed in CF-style ``units_str`` to an object array of datetimes."""
    origin, unit_seconds = parse_time_units(units_str)
    stamps = []
    for value in np.asarray(time_vals, dtype="f8"):
        stamps.append(origin + timedelta(seconds=float(value) * unit_seconds))
    return np.array(stamps, dtype=object)

# -----------------------
# path + IO helpers
# -----------------------
def infer_case_id_from_forcing(path_nc):
    """
    Return the trailing numeric token of a forcing file name,
    e.g. ".../05_081.nc" -> "081", or None if there is no "_<digits>" suffix.
    """
    name = os.path.basename(path_nc)
    stem, _, _ = name.partition(".nc")
    _, underscore, tail = stem.rpartition("_")
    return tail if underscore and tail.isdigit() else None

def find_les_mat(case_id, root_dir):
    """
    Locate LESout.mat for ``case_id`` anywhere below ``root_dir/<case_id>``.

    The case directory may contain an extra sub-directory layer, so the search
    is recursive. "**" also matches zero directories, so a file directly in the
    case directory is found as well. The directory part is glob-escaped, so
    characters such as "[" in real paths are not treated as patterns. When
    several files match, the lexicographically first is returned and the count
    is reported.

    Raises FileNotFoundError if nothing matches.
    """
    cdir = os.path.join(root_dir, case_id)
    pattern = os.path.join(glob.escape(cdir), "**", "LESout.mat")
    cand = sorted(glob.glob(pattern, recursive=True))
    if not cand:
        raise FileNotFoundError(f"Cannot find LESout.mat under: {cdir} (recursive)")
    if len(cand) > 1:
        print(f"[find_les_mat] {len(cand)} matches under {cdir}; using {cand[0]}")
    return cand[0]

def load_les(les_mat_path):
    """
    Read times, depths, Stokes drift and surface wind stress from an LESout.mat.

    Us/Vs and tau13l/tau23l are transposed to time-first when they are stored
    time-last. The decision for each pair is taken from its first member
    (Us, tau13l). The stress is column 0 of tau13l/tau23l scaled by RHO_W.

    Returns (t_sec, z, Us, Vs, taux, tauy).
    """
    mat = sio.loadmat(les_mat_path)
    t_sec = np.asarray(mat["t"]).squeeze().astype("f8")  # seconds since base
    z = np.asarray(mat["z"]).squeeze().astype("f8")      # meters, positive down
    nt = t_sec.size

    def _as_f8(key):
        return np.asarray(mat[key]).astype("f8")

    def _is_time_last(arr):
        return arr.shape[0] != nt and arr.shape[1] == nt

    Us, Vs = _as_f8("Us"), _as_f8("Vs")
    if _is_time_last(Us):
        Us, Vs = Us.T, Vs.T

    # tau from LES (project convention: scaled by RHO_W)
    tau13l, tau23l = _as_f8("tau13l"), _as_f8("tau23l")
    if _is_time_last(tau13l):
        tau13l, tau23l = tau13l.T, tau23l.T

    return t_sec, z, Us, Vs, tau13l[:, 0] * RHO_W, tau23l[:, 0] * RHO_W

def load_forcing(nc_path):
    """
    Read a forcing NetCDF file (Time, wavenumber, Taux, Tauy, Usx*/Usy*).

    Returns (t, tunits, k, taux, tauy, Usx, Usy):
      * t: Time as float64, with units string tunits;
      * k: the wavenumbers;
      * taux/tauy: the series at Lat/Lon index (0, 0);
      * Usx/Usy: (Nt, Nb) band amplitudes at (0, 0).
    Nb is taken from the wavenumber variable instead of being hard-coded to 14.
    Masked entries become NaN, and band values with |x| > FILL_US_THRES are
    treated as missing (NaN).
    """
    def _read(var, index):
        # netCDF4 may return masked arrays; turn masked entries into NaN
        # explicitly rather than exposing whatever raw fill value lies beneath.
        return np.ma.filled(np.ma.asarray(var[index], dtype="f8"), np.nan)

    series = (slice(None), 0, 0)
    with Dataset(nc_path, "r") as nc:
        t = _read(nc.variables["Time"], slice(None))  # minutes since base (stored as int64)
        tunits = nc.variables["Time"].units
        k = _read(nc.variables["wavenumber"], slice(None))

        # tau in forcing
        taux = _read(nc.variables["Taux"], series)
        tauy = _read(nc.variables["Tauy"], series)

        # band amplitudes
        nb = k.size
        Nt = t.size
        Usx14 = np.zeros((Nt, nb), dtype="f8")
        Usy14 = np.zeros((Nt, nb), dtype="f8")
        for i in range(nb):
            vx = _read(nc.variables[f"Usx{i+1}"], series)
            vy = _read(nc.variables[f"Usy{i+1}"], series)
            Usx14[:, i] = np.where(np.abs(vx) > FILL_US_THRES, np.nan, vx)
            Usy14[:, i] = np.where(np.abs(vy) > FILL_US_THRES, np.nan, vy)

    return t, tunits, k, taux, tauy, Usx14, Usy14

# -----------------------
# reconstruction helpers
# -----------------------
def recon_from_bands_at_z(Us14_ti, k, z):
    """
    Reconstruct a time series at one depth from band amplitudes:
        recon(z,t) = sum_i Us_i(t) * exp(-2*k_i*z)

    Us14_ti: (Nt, Nb) amplitudes (NaN = missing); k: (Nb,); z: depth (m, positive down).
    Missing bands are skipped. A time step whose band terms are all missing
    yields NaN rather than the spurious 0 that a plain nansum would give.
    """
    amps = np.atleast_2d(np.asarray(Us14_ti, dtype="f8"))
    k = np.asarray(k, dtype="f8")
    w = np.exp(-2.0 * k[None, :] * float(z))  # (1, Nb)
    terms = amps * w
    out = np.nansum(terms, axis=1)
    out[np.all(np.isnan(terms), axis=1)] = np.nan
    return out

def recon_profile_from_bands(Us14_i, k, zvec):
    """
    Reconstruct a single-time profile from band amplitudes:
        Us(z) = sum_i Us_i * exp(-2*k_i*z)

    Us14_i: (Nb,) amplitudes (NaN = missing); k: (Nb,); zvec: depths (m, positive down).
    Missing bands are skipped. A depth whose band terms are all missing yields
    NaN rather than the spurious 0 that a plain nansum would give.
    """
    amps = np.asarray(Us14_i, dtype="f8").ravel()
    k = np.asarray(k, dtype="f8").ravel()
    zvec = np.asarray(zvec, dtype="f8").ravel()
    terms = np.exp(-2.0 * zvec[:, None] * k[None, :]) * amps[None, :]  # (Nz, Nb)
    out = np.nansum(terms, axis=1)
    out[np.all(np.isnan(terms), axis=1)] = np.nan
    return out

# -----------------------
# main
# -----------------------
def main():
    """
    Compare one forcing NetCDF file against its LES source and show the figure.

    Steps:
      1. Resolve the case id (CASE_ID, or inferred from FORCING_NC), locate
         LESout.mat and load both datasets.
      2. Build datetime axes. LES times are added as seconds to the forcing
         Time base date (assumes LES t shares that base — confirm).
      3. Time series: Taux/Tauy, plus the band reconstruction evaluated at the
         LES top-cell depth z0 against LES level 0 (x, y and magnitude).
      4. Profile check: at forcing index IT_FORCE_CHECK, compare the band
         reconstruction with the nearest-in-time LES profile for depths <= ZMAX_PLOT.
    The figure is only displayed (plt.show()); nothing is saved.
    """
    case_id = CASE_ID or infer_case_id_from_forcing(FORCING_NC)
    if case_id is None:
        raise ValueError("Cannot infer CASE_ID from forcing filename; set CASE_ID explicitly.")
    les_mat = find_les_mat(case_id, LES_ROOT)

    print("FORCING:", FORCING_NC)
    print("LES MAT :", les_mat)

    # Load
    tF_min, tF_units, k, tauxF, tauyF, Usx14, Usy14 = load_forcing(FORCING_NC)
    tL_sec, zL, UsL, VsL, tauxL, tauyL = load_les(les_mat)

    # Absolute datetime axes
    tF_dt = timevals_to_datetimes(tF_min, tF_units)
    base_dt, _ = parse_time_units(tF_units)  # should be 2011-04-01 00:00:00
    tL_dt = np.array([base_dt + timedelta(seconds=float(s)) for s in tL_sec], dtype=object)

    # top cell depth (LES)
    z0 = float(zL[0])

    # "surface/top cell" Stokes drift series
    UsxF_top = recon_from_bands_at_z(Usx14, k, z0)
    UsyF_top = recon_from_bands_at_z(Usy14, k, z0)
    absF_top = np.sqrt(UsxF_top**2 + UsyF_top**2)

    UsxL_top = UsL[:, 0]
    UsyL_top = VsL[:, 0]
    absL_top = np.sqrt(UsxL_top**2 + UsyL_top**2)

    # profile check at forcing index (negative index counts from the end, then clipped to range)
    itF = IT_FORCE_CHECK if IT_FORCE_CHECK >= 0 else (len(tF_min) + IT_FORCE_CHECK)
    itF = int(np.clip(itF, 0, len(tF_min)-1))

    # nearest LES time to forcing time
    # NOTE(review): this treats the raw forcing Time values as minutes; correct only
    # when tF_units is "minutes since ..." — confirm for other files.
    tF_here_min = float(tF_min[itF])
    tL_min = tL_sec / 60.0
    itL = int(np.argmin(np.abs(tL_min - tF_here_min)))
    dt_min = float(tL_min[itL] - tF_here_min)

    print(f"[profile] forcing it={itF}, t={tF_dt[itF]}")
    print(f"[profile] LES nearest it={itL}, t={tL_dt[itL]}, Δt={dt_min:+.2f} min")

    # Restrict the profile comparison to LES levels no deeper than ZMAX_PLOT.
    zmask = zL <= ZMAX_PLOT
    zP = zL[zmask]

    UsxF_prof = recon_profile_from_bands(Usx14[itF, :], k, zP)
    UsyF_prof = recon_profile_from_bands(Usy14[itF, :], k, zP)
    absF_prof = np.sqrt(UsxF_prof**2 + UsyF_prof**2)

    UsxL_prof = UsL[itL, zmask]
    UsyL_prof = VsL[itL, zmask]
    absL_prof = np.sqrt(UsxL_prof**2 + UsyL_prof**2)

    # -----------------------
    # PLOT (like your screenshot layout)
    # 4 rows x 6 columns grid: rows 0-1 split into halves, row 2 full width,
    # row 3 split into thirds for the three profile panels.
    # -----------------------
    fig = plt.figure(figsize=(13.5, 11.5))
    gs = fig.add_gridspec(nrows=4, ncols=6, height_ratios=[1.0, 1.0, 1.0, 1.25],
                          hspace=0.55, wspace=0.65)

    ax_tx  = fig.add_subplot(gs[0, 0:3])
    ax_ty  = fig.add_subplot(gs[0, 3:6])
    ax_ux  = fig.add_subplot(gs[1, 0:3])
    ax_uy  = fig.add_subplot(gs[1, 3:6])
    ax_abs = fig.add_subplot(gs[2, :])

    ax_px  = fig.add_subplot(gs[3, 0:2])
    ax_py  = fig.add_subplot(gs[3, 2:4])
    ax_pa  = fig.add_subplot(gs[3, 4:6])

    # Date formatting only for the time-series panels.
    fmt = mdates.DateFormatter("%m-%d %H:%M")
    for ax in (ax_tx, ax_ty, ax_ux, ax_uy, ax_abs):
        ax.xaxis.set_major_formatter(fmt)
        ax.tick_params(axis="x", rotation=25)

    # --- tau time series
    ax_tx.plot(tF_dt, tauxF, label="forcing")
    ax_tx.plot(tL_dt, tauxL, "--", label="LES")
    ax_tx.set_title("Taux")
    ax_tx.legend(frameon=True)

    ax_ty.plot(tF_dt, tauyF, label="forcing")
    ax_ty.plot(tL_dt, tauyL, "--", label="LES")
    ax_ty.set_title("Tauy")
    ax_ty.legend(frameon=True)

    # --- Us top-cell series
    ax_ux.plot(tF_dt, UsxF_top, label="forcing")
    ax_ux.plot(tL_dt, UsxL_top, "--", label="LES")
    ax_ux.set_title("Us_x (surface/top cell)")
    ax_ux.legend(frameon=True)

    ax_uy.plot(tF_dt, UsyF_top, label="forcing")
    ax_uy.plot(tL_dt, UsyL_top, "--", label="LES")
    ax_uy.set_title("Us_y (surface/top cell)")
    ax_uy.legend(frameon=True)

    ax_abs.plot(tF_dt, absF_top, label="forcing |Us|")
    ax_abs.plot(tL_dt, absL_top, "--", label="LES |Us|")
    ax_abs.set_title("|Us| comparison")
    ax_abs.legend(frameon=True)

    # profile check header text (figure coordinates, placed above the profile row)
    fig.text(
        0.5, 0.275,
        f"Profile check @ forcing time {tF_dt[itF]} (LES nearest, Δt={dt_min:+.2f} min)",
        ha="center", va="center", fontsize=12
    )

    # profiles (depth on the y axis, inverted so the surface is at the top)
    ax_px.plot(UsxF_prof, zP, label="forcing")
    ax_px.plot(UsxL_prof, zP, "--", label="LES")
    ax_px.set_title("Usx(z)")
    ax_px.set_ylabel("z (m, positive down)")
    ax_px.grid(True, alpha=0.25)
    ax_px.invert_yaxis()

    ax_py.plot(UsyF_prof, zP, label="forcing")
    ax_py.plot(UsyL_prof, zP, "--", label="LES")
    ax_py.set_title("Usy(z)")
    ax_py.grid(True, alpha=0.25)
    ax_py.invert_yaxis()

    ax_pa.plot(absF_prof, zP, label="forcing")
    ax_pa.plot(absL_prof, zP, "--", label="LES")
    ax_pa.set_title("|Us|(z)")
    ax_pa.grid(True, alpha=0.25)
    ax_pa.invert_yaxis()
    ax_pa.legend(frameon=True)

    plt.show()

# Entry point: run the comparison when this cell/file executes as the top-level script.
if __name__ == "__main__":
    main()


In [None]:
import os
import re
import numpy as np
import xarray as xr
import scipy.io as sio
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import datetime as dt
from netCDF4 import num2date

# =========================
# USER SETTINGS: compare the wind-wave input file from SCM_windwaveInputfiles.ipynb against LES data (5 m/s, 32 m MLD)
# =========================
# Forcing file to check and the LES profile file to compare it with.
FORCING_NC= f"/archive/Qian.Xiao/Qian.Xiao/MOM6_kappa_ePBL/ice/10m_Forcing/Hurr05_046.nc"
LES_MAT    = f"/archive/bgr/Datasets/LES/Hurr/LES_HUR/TC021_PROF.mat"   # <- change this to your LESout.mat path

CASE="1"  # NOTE(review): not used in the visible part of this cell — confirm it is needed

NBAND = 14    # number of Stokes bands read from the forcing file
LAT_IDX = 0   # Lat index of the forcing grid point to sample
LON_IDX = 0   # Lon index of the forcing grid point to sample

N_FORC_MAX = 100           # only the first 100 forcing records are needed
PROFILE_FORC_IT = -1       # forcing time index used for the profile check (-1 = last)

LES_TAUX_SCALE = 1000.0    # LES tau is multiplied by 1000 (project convention)
LES_TAUY_SCALE = 1000.0

# =========================
# helpers
# =========================
def parse_units_base_datetime(units_str: str) -> dt.datetime:
    """
    Return the reference datetime of a CF-style units string (the text after
    "since"), e.g. "minutes since 2011-04-01 00:00:00" -> 2011-04-01 00:00.

    A few explicit formats are tried first, then datetime.fromisoformat with
    any "Z" removed. Raises ValueError if there is no "since <date>" part or
    the date cannot be parsed.
    """
    match = re.search(r"since\s+(.+)$", units_str.strip())
    if match is None:
        raise ValueError(f"Cannot parse base datetime from units: {units_str}")
    stamp = match.group(1).strip()

    layouts = ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d")
    for layout in layouts:
        try:
            return dt.datetime.strptime(stamp, layout)
        except ValueError:
            continue

    try:
        return dt.datetime.fromisoformat(stamp.replace("Z",""))
    except Exception as err:
        raise ValueError(f"Cannot parse base datetime '{stamp}' from units '{units_str}': {err}")

def to_py_datetime_list(tlist):
    """
    Convert datetime-like objects to a list of naive python ``datetime.datetime``.
    Matplotlib needs python datetime or numpy datetime64.

    Accepted items:
      * ``datetime.datetime``: passed through unchanged;
      * ``numpy.datetime64``: converted via seconds since 1970-01-01;
      * anything else with ``year``/``month``/``day`` attributes (e.g. cftime
        datetimes): rebuilt from its fields. Missing ``hour``/``minute``/
        ``second``/``microsecond`` default to 0.

    Fractional seconds are rounded to the nearest microsecond and added as a
    timedelta. A value such as 59.9999997 s therefore carries into the next
    minute, where building the datetime directly would raise "microsecond must
    be in 0..999999".
    """
    out = []
    for t in tlist:
        # already python datetime
        if isinstance(t, dt.datetime):
            out.append(t)
            continue

        # numpy datetime64 -> naive python datetime via elapsed seconds since the epoch
        if isinstance(t, np.datetime64):
            ts = (t - np.datetime64("1970-01-01T00:00:00")) / np.timedelta64(1, "s")
            out.append(dt.datetime(1970, 1, 1) + dt.timedelta(seconds=float(ts)))
            continue

        # cftime-like object: rebuild from calendar fields, letting timedelta
        # handle any carry from rounded fractional seconds.
        base = dt.datetime(int(t.year), int(t.month), int(t.day),
                           int(getattr(t, "hour", 0)),
                           int(getattr(t, "minute", 0)))
        usec = int(round(float(getattr(t, "second", 0)) * 1e6))
        usec += int(getattr(t, "microsecond", 0))
        out.append(base + dt.timedelta(microseconds=usec))
    return out

def clean_fill(x):
    """Return *x* as a float array with fill-like magnitudes (|x| > 1e20, e.g. the 1e32 fill) replaced by NaN."""
    arr = np.asarray(x, dtype=float)
    is_fill = np.abs(arr) > 1e20  # handles the 1e32 fill value
    return np.where(is_fill, np.nan, arr)

def pick_point(da, lat_i=0, lon_i=0, nmax=None):
    """
    Extract the series at (Lat=lat_i, Lon=lon_i) from an xarray DataArray.
    The result is optionally truncated to the first ``nmax`` entries, and fill
    values are replaced by NaN via clean_fill.
    """
    series = da.isel(Lat=lat_i, Lon=lon_i).values
    return clean_fill(series if nmax is None else series[:nmax])

def forcing_profile_from_bands(usx_band, usy_band, k, z):
    """
    Reconstruct a Stokes-drift profile from per-band surface amplitudes.

        Us(z) = sum_i Us0_i * exp(-2 * k_i * z)

    Parameters
    ----------
    usx_band, usy_band : array-like, length nband
        Surface amplitude of each band (x and y components).
    k : array-like, length nband
        Band wavenumbers, 1/m.
    z : array-like, length nz
        Depths in m, positive down (exp(-2 k z) decays only for z >= 0).

    Returns
    -------
    usx_z, usy_z, us_z : ndarray, shape (nz,)
        x, y components and magnitude of the profile. NaN band contributions
        are skipped (``np.nansum``).

    Raises
    ------
    ValueError
        If ``usx_band``, ``usy_band`` and ``k`` do not have the same length.
    """
    # Accept lists/tuples as well as arrays; flatten so column vectors work too.
    usx_band = np.asarray(usx_band, dtype=float).ravel()
    usy_band = np.asarray(usy_band, dtype=float).ravel()
    k = np.asarray(k, dtype=float).ravel()
    z = np.asarray(z, dtype=float).ravel()
    # Without this check a length-1 band array would silently broadcast over all k.
    if not (usx_band.size == usy_band.size == k.size):
        raise ValueError(
            f"band/wavenumber length mismatch: usx={usx_band.size}, "
            f"usy={usy_band.size}, k={k.size}"
        )
    decay = np.exp(-2.0 * k[:, None] * z[None, :])   # (nband, nz)
    usx_z = np.nansum(usx_band[:, None] * decay, axis=0)
    usy_z = np.nansum(usy_band[:, None] * decay, axis=0)
    us_z = np.hypot(usx_z, usy_z)
    return usx_z, usy_z, us_z

def nearest_idx_datetime(dts, dt_target):
    """Return the index of the entry in ``dts`` closest in time to ``dt_target``.

    Distance is the absolute difference in seconds. Ties resolve to the
    earliest index (``argmin`` semantics). An empty ``dts`` raises ValueError.
    """
    gaps = np.fromiter(
        (abs((stamp - dt_target).total_seconds()) for stamp in dts),
        dtype=float,
    )
    return int(gaps.argmin())

def guess_les_time_seconds(t_array):
    """
    Heuristically convert an LES time array to seconds, in case some LES
    ``t`` values are stored in other units (the thresholds try to avoid
    misclassification):
    - max < 10   -> treated as days
    - max < 300  -> treated as hours
    - otherwise  -> already seconds, returned unchanged (same object)
    """
    peak = float(np.nanmax(t_array))
    if peak < 10:
        scale = 86400.0
    elif peak < 300:
        scale = 3600.0
    else:
        return t_array
    return t_array * scale

# =========================
# main
# =========================
# Fail fast if either input file is missing.
if not os.path.exists(FORCING_NC):
    raise FileNotFoundError(FORCING_NC)
if not os.path.exists(LES_MAT):
    raise FileNotFoundError(LES_MAT)

# ---- read forcing ----
# The context manager releases the file handle even if a lookup below raises
# (e.g. a missing Usx<i> variable). Every value needed later is materialised
# as a numpy array inside the block.
with xr.open_dataset(FORCING_NC, decode_times=False) as dsF:
    units = dsF["Time"].attrs.get("units", "minutes since 2011-04-01 00:00:00")
    cal   = dsF["Time"].attrs.get("calendar", "gregorian")

    tF_raw = np.asarray(dsF["Time"].values, dtype=float)
    if N_FORC_MAX is not None:
        tF_raw = tF_raw[:N_FORC_MAX]

    # Ask num2date for python datetimes where possible, then normalise any
    # cftime/datetime64 results with to_py_datetime_list.
    try:
        tF_tmp = num2date(tF_raw, units=units, calendar=cal,
                          only_use_cftime_datetimes=False, only_use_python_datetimes=True)
    except (TypeError, ValueError):
        # TypeError: older netCDF4/cftime without these keyword arguments.
        # ValueError: the calendar cannot be represented as python datetimes;
        # fall back to cftime objects and convert them below.
        tF_tmp = num2date(tF_raw, units=units, calendar=cal)
    tF_dt = to_py_datetime_list(tF_tmp)

    tauxF = pick_point(dsF["Taux"], LAT_IDX, LON_IDX, N_FORC_MAX)
    tauyF = pick_point(dsF["Tauy"], LAT_IDX, LON_IDX, N_FORC_MAX)

    # Band wavenumbers, used as 1/m in the decay factor exp(-2 k z).
    k = np.asarray(dsF["wavenumber"].values, dtype=float)[:NBAND]
    if k.size < NBAND:
        raise ValueError(
            f"forcing file has {k.size} wavenumbers but NBAND={NBAND}"
        )

    # One row per band; column index is the forcing time step.
    usx_bands = np.zeros((NBAND, len(tF_dt)), dtype=float)
    usy_bands = np.zeros((NBAND, len(tF_dt)), dtype=float)
    for i in range(1, NBAND + 1):
        usx_bands[i-1, :] = pick_point(dsF[f"Usx{i}"], LAT_IDX, LON_IDX, N_FORC_MAX)
        usy_bands[i-1, :] = pick_point(dsF[f"Usy{i}"], LAT_IDX, LON_IDX, N_FORC_MAX)

# Surface Stokes drift = sum over bands (the decay factor is 1 at z=0).
usx0F = np.nansum(usx_bands, axis=0)
usy0F = np.nansum(usy_bands, axis=0)
us0F  = np.sqrt(usx0F**2 + usy0F**2)

# ---- read LES ----
M = sio.loadmat(LES_MAT)

tL_raw = np.asarray(M["t"]).squeeze().astype(float)
# Convert days/hours to seconds; values already in seconds pass through.
tL_sec = guess_les_time_seconds(tL_raw)

# NOTE(review): LES times are placed on the forcing file's epoch, i.e. this
# assumes LES t=0 coincides with the forcing "since" date — confirm.
base_dt = parse_units_base_datetime(units)
tL_dt = [base_dt + dt.timedelta(seconds=float(s)) for s in tL_sec]

z = np.asarray(M["z"]).T.squeeze().astype(float)

# Column 0 is taken as the top cell; presumably LES arrays are (time, depth)
# — TODO confirm against the LES output layout.
tauxL = np.asarray(M["tau13l"])[:, 0].astype(float) * LES_TAUX_SCALE
tauyL = np.asarray(M["tau23l"])[:, 0].astype(float) * LES_TAUY_SCALE

if "Us" in M and "Vs" in M:
    usx0L = np.asarray(M["Us"])[:, 0].astype(float)
    usy0L = np.asarray(M["Vs"])[:, 0].astype(float)
    us0L  = np.sqrt(usx0L**2 + usy0L**2)
else:
    raise KeyError("LES mat missing 'Us'/'Vs' arrays for Stokes drift components.")

print("[forcing] start/end:", tF_dt[0], tF_dt[-1], "Nt=", len(tF_dt))
print("[LES]     start/end:", tL_dt[0], tL_dt[-1], "Nt=", len(tL_dt))

# =========================
# plotting (ABSOLUTE time axis)
# =========================
# Layout: Taux | Tauy on the top row, Us_x | Us_y in the middle,
# and |Us| spanning the whole bottom row.
fig = plt.figure(figsize=(11.5, 10.0))
gs = fig.add_gridspec(3, 2, height_ratios=[1, 1, 1.1], hspace=0.35, wspace=0.25)

ax1, ax2 = fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])
ax3, ax4 = fig.add_subplot(gs[1, 0]), fig.add_subplot(gs[1, 1])
ax5 = fig.add_subplot(gs[2, :])

# Each entry: (axes, forcing series, LES series, title, forcing label, LES label).
# Forcing is drawn solid on tF_dt, LES dashed on tL_dt.
panel_specs = (
    (ax1, tauxF, tauxL, "Taux", "forcing", "LES"),
    (ax2, tauyF, tauyL, "Tauy", "forcing", "LES"),
    (ax3, usx0F, usx0L, "Us_x (surface/top cell)", "forcing", "LES"),
    (ax4, usy0F, usy0L, "Us_y (surface/top cell)", "forcing", "LES"),
    (ax5, us0F, us0L, "|Us| comparison", "forcing |Us|", "LES |Us|"),
)
for panel, y_forc, y_les, title, lab_forc, lab_les in panel_specs:
    panel.plot(tF_dt, y_forc, label=lab_forc)
    panel.plot(tL_dt, y_les, "--", label=lab_les)
    panel.set_title(title)
    panel.legend()

# Absolute calendar ticks on every panel.
for ax in (ax1, ax2, ax3, ax4, ax5):
    ax.xaxis.set_major_locator(mdates.AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d %H:%M"))
    ax.tick_params(axis="x", rotation=25)

fig.suptitle(f"ABS time comparison (CASE {CASE})", y=0.995, fontsize=13)
plt.show()

# =========================
# Profile check at one time
# =========================
# Pick one forcing snapshot and the LES output time closest to it.
itF = PROFILE_FORC_IT
tF_pick = tF_dt[itF]
itL = nearest_idx_datetime(tL_dt, tF_pick)

print(f"[profile] forcing it={itF}, t={tF_pick}")
print(f"[profile] LES nearest it={itL}, t={tL_dt[itL]}, Δt={(tL_dt[itL]-tF_pick).total_seconds()/60:.2f} min")

# Reconstruct the forcing Stokes profile on the LES z grid from its band amplitudes.
# NOTE(review): forcing_profile_from_bands expects z positive down; confirm the
# sign convention of the LES "z" variable, otherwise exp(-2 k z) grows with depth.
usx_band_t = usx_bands[:, itF]
usy_band_t = usy_bands[:, itF]
usxF_z, usyF_z, usF_z = forcing_profile_from_bands(usx_band_t, usy_band_t, k, z)

# LES profile at the matched time step (row itL, all depths).
usxL_z = np.asarray(M["Us"])[itL, :].astype(float)
usyL_z = np.asarray(M["Vs"])[itL, :].astype(float)
usL_z  = np.sqrt(usxL_z**2 + usyL_z**2)

fig2, axs = plt.subplots(1, 3, figsize=(10.8, 3.8), sharey=True)
axs[0].plot(usxF_z, z, label="forcing"); axs[0].plot(usxL_z, z, "--", label="LES"); axs[0].set_title("Usx(z)")
axs[1].plot(usyF_z, z, label="forcing"); axs[1].plot(usyL_z, z, "--", label="LES"); axs[1].set_title("Usy(z)")
axs[2].plot(usF_z,  z, label="forcing"); axs[2].plot(usL_z,  z, "--", label="LES"); axs[2].set_title("|Us|(z)")

for ax in axs:
    ax.grid(True, alpha=0.25)
# The y-axis is shared (sharey=True), so inversion must be applied once.
# invert_yaxis() toggles the shared axis, so calling it on every panel only
# ends up inverted when the panel count happens to be odd.
axs[0].yaxis.set_inverted(True)
axs[0].set_ylabel("z (m, positive down)")
axs[2].legend(loc="best")
fig2.suptitle(f"Profile check @ forcing time {tF_pick} (LES nearest)", y=1.02)
fig2.tight_layout()
plt.show()


