In [None]:
#!/usr/bin/env python3 
# ------------------------------------------------------------
# Convert WW3 banded surface Stokes drift (USP: ussp/vssp)
# to MOM6 DataOverride format: the wave-forcing file for the 5mps 3D idealized hurricane cases:
#   TIME, LATITUDE, LONGITUDE, frequency, vert
#   x_T, y_T, x_vert_T, y_vert_T
#   Usx1..UsxN, Usy1..UsyN
#  three dimensional
# Assumes WW3 time is: "days since 1990-01-01 00:00:00"
# Writes MOM6 time as: "days since 2011-04-01 00:00:00"
# ------------------------------------------------------------
import numpy as np
import netCDF4 as NC
from netCDF4 import num2date, date2num

# ---------------- user settings ----------------
INP = "/archive/Qian.Xiao/Qian.Xiao/FMS_Wave_Coupling_GOTM_kapp/3D_hurricane/4_deg/ww3.201104_usp.nc"
OUT = "/archive/Qian.Xiao/Qian.Xiao/FMS_Wave_Coupling_GOTM_kapp/3D_hurricane/4_deg/StokesDriftBands_201104_3DHurricaneCases.nc"
MOM6_TIME_UNITS = "days since 2011-04-01 00:00:00"
MOM6_CAL = "gregorian"
# -----------------------------------------------

# Read everything this cell needs from the WW3 USP file. The complete 4-D
# Stokes-drift fields are pulled into memory here; writing happens below.
with NC.Dataset(INP, "r") as ds:
    # Time: WW3 numbers -> calendar dates -> numbers in MOM6_TIME_UNITS.
    ww3_units = ds.variables["time"].units
    ww3_cal = getattr(ds.variables["time"], "calendar", "standard")
    t_ww3 = ds.variables["time"][:]  # (Nt,)
    t_dt = num2date(t_ww3, units=ww3_units, calendar=ww3_cal)
    t_out = date2num(t_dt, units=MOM6_TIME_UNITS, calendar=MOM6_CAL).astype("f4")

    # Frequency axis "f"; one Usx/Usy output pair is written per entry.
    freq = ds.variables["f"][:].astype("f4")  # (Nb,)
    Nb = int(freq.size)

    # 2-D (Ny, Nx) coordinate arrays, reduced to 1-D axes by taking the first
    # row / first column -- only meaningful for a regular lon/lat grid.
    lon2, lat2 = (
        ds.variables[coord][:, :].astype("f4") for coord in ("longitude", "latitude")
    )
    Ny, Nx = lat2.shape
    lon1, lat1 = lon2[0, :].astype("f4"), lat2[:, 0].astype("f4")  # (Nx,), (Ny,)

    # Stokes drift, laid out (time, f, latitude, longitude) = (Nt, Nb, Ny, Nx).
    ussp, vssp = ds.variables["ussp"][:], ds.variables["vssp"][:]

# Grid spacing = first difference of each 1-D axis (0.0 for a single point);
# vertices sit half a spacing either side of every T point.
dlon, dlat = (
    float(axis[1] - axis[0]) if n > 1 else 0.0
    for axis, n in ((lon1, Nx), (lat1, Ny))
)

# 2-D T-point coordinates, shape (Ny, Nx).
x_T = np.repeat(lon1[np.newaxis, :], Ny, axis=0).astype("f4")
y_T = np.repeat(lat1[:, np.newaxis], Nx, axis=1).astype("f4")

# Cell vertices, shape (4, Ny, Nx), ordered
# (x-, y-), (x+, y-), (x-, y+), (x+, y+).
x_vert = np.asarray(np.stack([x_T + sx * (0.5 * dlon) for sx in (-1, 1, -1, 1)]), dtype="f4")
y_vert = np.asarray(np.stack([y_T + sy * (0.5 * dlat) for sy in (-1, -1, 1, 1)]), dtype="f4")

# Write MOM6 override-style NetCDF
def _write_mom6_override(path):
    """Write this cell's MOM6 DataOverride-style NetCDF3 file to *path*.

    Reads module-level results computed earlier in the cell: sizes Ny, Nx, Nb;
    axes t_out, lon1, lat1, freq; geometry x_T, y_T, x_vert, y_vert; and the
    Stokes-drift arrays ussp / vssp (band index on axis 1).
    """
    spatial = ("LATITUDE", "LONGITUDE")
    with NC.Dataset(path, "w", format="NETCDF3_CLASSIC") as nc:
        nc.description = "WW3 derived frequency bands formatted for MOM6."

        # TIME is the unlimited (record) dimension.
        nc.createDimension("TIME", None)
        nc.createDimension("LATITUDE", Ny)
        nc.createDimension("LONGITUDE", Nx)
        nc.createDimension("frequency", Nb)
        nc.createDimension("vert", 4)

        time_var = nc.createVariable("TIME", "f4", ("TIME",))
        time_var.setncatts({
            "units": MOM6_TIME_UNITS,
            "calendar": MOM6_CAL,
            "long_name": "time",
            "standard_name": "time",
            "time_origin": "01-APR-2011 00:00:00",
        })
        time_var[:] = t_out

        # 1-D coordinate axes (written without attributes).
        for axis_name, axis_values in (("LONGITUDE", lon1), ("LATITUDE", lat1)):
            nc.createVariable(axis_name, "f4", (axis_name,))[:] = axis_values

        freq_var = nc.createVariable("frequency", "f4", ("frequency",))
        freq_var.setncatts({
            "long_name": "wave_frequency",
            "standard_name": "wave_frequency",
            "units": "s-1",
        })
        freq_var[:] = freq

        # T-point coordinates and cell vertices.
        for var_name, dims, values in (
            ("x_T", spatial, x_T),
            ("y_T", spatial, y_T),
            ("x_vert_T", ("vert",) + spatial, x_vert),
            ("y_vert_T", ("vert",) + spatial, y_vert),
        ):
            nc.createVariable(var_name, "f4", dims)[:] = values

        # One (TIME, LATITUDE, LONGITUDE) variable per band and component,
        # created in the order Usx1, Usy1, Usx2, Usy2, ...
        for b in range(Nb):
            for prefix, source in (("Usx", ussp), ("Usy", vssp)):
                band_var = nc.createVariable(f"{prefix}{b + 1}", "f4", ("TIME",) + spatial)
                band_var.units = "m s-1"
                band_var[:, :, :] = source[:, b, :, :].astype("f4")


_write_mom6_override(OUT)
print(f"[OK] wrote {OUT}  Nt={len(t_out)}  Nb={Nb}  Ny={Ny}  Nx={Nx}")


In [None]:
#!/usr/bin/env python3
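# Same conversion as above, but for the 2mps cases; the input is streamed one time record at a time and one extra time record is appended.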
import numpy as np
import netCDF4 as NC
from netCDF4 import date2num
import datetime as dt

INP = "/archive/Qian.Xiao/Qian.Xiao/FMS_Wave_Coupling_GOTM_kapp/3D_hurricane/4_deg/2mps/ww3.201104_usp.nc"
OUT = "/archive/Qian.Xiao/Qian.Xiao/FMS_Wave_Coupling_GOTM_kapp/3D_hurricane/4_deg/2mps/201104_3DHurricaneCases.nc"

MOM6_TIME_UNITS = "days since 2011-04-01 00:00:00"
MOM6_CAL = "gregorian"

# Output file format; chunking/compression below only applies to NETCDF4* formats.
OUT_FORMAT = "NETCDF4_CLASSIC"
USE_CHUNK  = True              # zlib level 1 with one time record per chunk (1, Ny, Nx)

def main():
    """Convert WW3 banded Stokes drift (ussp/vssp) in INP to a MOM6 DataOverride file OUT.

    The output holds TIME (in MOM6_TIME_UNITS), the frequency axis, 1-D
    LATITUDE/LONGITUDE axes, 2-D T-point coordinates, cell vertices and one
    (TIME, LATITUDE, LONGITUDE) variable per band and component
    (Usx1..UsxNb, Usy1..UsyNb). The input is streamed one time record at a
    time. One extra record is appended: a copy of the last input record,
    placed one inferred time step after it.

    Raises:
        RuntimeError: if the input has fewer than 2 time records, or the
            inferred time step is not a positive finite number.
    """
    with NC.Dataset(INP, "r") as ds:
        # Faster reads: return plain ndarrays instead of masked arrays.
        # Auto-scaling is deliberately left ON. If a variable is stored packed
        # (integer + scale_factor/add_offset), reading it raw would return the
        # packed integers instead of physical values. For unpacked float
        # variables this gives exactly the same values as disabling both.
        for vn in ["ussp", "vssp", "longitude", "latitude", "time", "f"]:
            if vn in ds.variables:
                ds.variables[vn].set_auto_mask(False)

        # ------------------------------------------------------------
        # TIME: WW3 time -> MOM6 "days since 2011-04-01"
        # Converting through calendar dates gives MOM6_TIME_UNITS whatever
        # the WW3 units are; a plain difference of WW3 numbers would only be
        # in days when the WW3 time axis itself is in days.
        # ------------------------------------------------------------
        time_var  = ds.variables["time"]
        t_ww3     = time_var[:].astype("f8")
        ww3_units = time_var.units
        ww3_cal   = getattr(time_var, "calendar", "standard")

        t_dates = NC.num2date(t_ww3, units=ww3_units, calendar=ww3_cal)
        t_out   = np.asarray(
            date2num(t_dates, units=MOM6_TIME_UNITS, calendar=MOM6_CAL), dtype="f8"
        )  # days since 2011-04-01

        Nt_in = int(t_out.size)
        if Nt_in < 2:
            raise RuntimeError("Input time length < 2, cannot infer dt.")

        # Median spacing is robust to an occasional irregular step.
        dt_out = float(np.nanmedian(np.diff(t_out)))
        if not np.isfinite(dt_out) or dt_out <= 0:
            raise RuntimeError(f"Bad inferred dt_out={dt_out}")

        # Extend the time axis by one inferred step for the appended record.
        t_out_ext = np.concatenate([t_out, [t_out[-1] + dt_out]]).astype("f8")
        Nt_out = int(t_out_ext.size)  # Nt_in + 1

        # ------------------------------------------------------------
        # frequency
        # ------------------------------------------------------------
        freq = ds.variables["f"][:].astype("f4")
        Nb = int(freq.size)

        # ------------------------------------------------------------
        # lon/lat (2D (Ny, Nx) in this WW3 output); 1-D axes are taken from
        # the first row/column, assuming a regular grid.
        # ------------------------------------------------------------
        lon2 = ds.variables["longitude"][:, :].astype("f4")  # (Ny,Nx)
        lat2 = ds.variables["latitude"][:, :].astype("f4")
        lon1 = lon2[0, :].astype("f4")
        lat1 = lat2[:, 0].astype("f4")
        Ny, Nx = lat2.shape

        # ---- grid helpers: spacing, T-point coordinates, cell vertices
        dlon = float(lon1[1] - lon1[0]) if Nx > 1 else 0.0
        dlat = float(lat1[1] - lat1[0]) if Ny > 1 else 0.0

        x_T = np.tile(lon1[None, :], (Ny, 1)).astype("f4")
        y_T = np.tile(lat1[:, None], (1, Nx)).astype("f4")

        # Vertex order: (x-, y-), (x+, y-), (x-, y+), (x+, y+).
        x_vert = np.empty((4, Ny, Nx), dtype="f4")
        y_vert = np.empty((4, Ny, Nx), dtype="f4")
        x_vert[0, :, :] = x_T - 0.5 * dlon;  y_vert[0, :, :] = y_T - 0.5 * dlat
        x_vert[1, :, :] = x_T + 0.5 * dlon;  y_vert[1, :, :] = y_T - 0.5 * dlat
        x_vert[2, :, :] = x_T - 0.5 * dlon;  y_vert[2, :, :] = y_T + 0.5 * dlat
        x_vert[3, :, :] = x_T + 0.5 * dlon;  y_vert[3, :, :] = y_T + 0.5 * dlat

        ussp_var = ds.variables["ussp"]  # (time,f,lat,lon)
        vssp_var = ds.variables["vssp"]

        # ------------------------------------------------------------
        # write output
        # ------------------------------------------------------------
        with NC.Dataset(OUT, "w", format=OUT_FORMAT) as out:
            out.description = "WW3 derived frequency bands formatted for MOM6. (appended +1 time record)"

            out.createDimension("TIME", None)
            out.createDimension("LATITUDE", Ny)
            out.createDimension("LONGITUDE", Nx)
            out.createDimension("frequency", Nb)
            out.createDimension("vert", 4)

            # TIME
            vT = out.createVariable("TIME", "f8", ("TIME",))
            vT.units = MOM6_TIME_UNITS
            vT.calendar = MOM6_CAL
            vT.long_name = "time"
            vT.standard_name = "time"
            vT.time_origin = "01-APR-2011 00:00:00"
            vT[:] = t_out_ext

            vLon = out.createVariable("LONGITUDE", "f4", ("LONGITUDE",))
            vLat = out.createVariable("LATITUDE", "f4", ("LATITUDE",))
            vLon[:] = lon1
            vLat[:] = lat1

            vF = out.createVariable("frequency", "f4", ("frequency",))
            vF.long_name = "wave_frequency"
            vF.standard_name = "wave_frequency"
            vF.units = "s-1"
            vF[:] = freq

            out.createVariable("x_T", "f4", ("LATITUDE", "LONGITUDE"))[:, :] = x_T
            out.createVariable("y_T", "f4", ("LATITUDE", "LONGITUDE"))[:, :] = y_T
            out.createVariable("x_vert_T", "f4", ("vert", "LATITUDE", "LONGITUDE"))[:, :, :] = x_vert
            out.createVariable("y_vert_T", "f4", ("vert", "LATITUDE", "LONGITUDE"))[:, :, :] = y_vert

            # Optional chunking/compression for the large band variables.
            kw = {}
            if OUT_FORMAT.startswith("NETCDF4") and USE_CHUNK:
                # TIME dimension chunk = 1; spatial dimensions as one full block (Ny, Nx)
                kw = dict(zlib=True, complevel=1, chunksizes=(1, Ny, Nx))

            Usx = []
            Usy = []
            for b in range(Nb):
                vx = out.createVariable(f"Usx{b+1}", "f4", ("TIME", "LATITUDE", "LONGITUDE"), **kw)
                vy = out.createVariable(f"Usy{b+1}", "f4", ("TIME", "LATITUDE", "LONGITUDE"), **kw)
                vx.units = "m s-1"
                vy.units = "m s-1"
                Usx.append(vx)
                Usy.append(vy)

            # Stream the input one time record at a time; keep the last
            # record for the appended one (Nt_in >= 2, so it is always set).
            last_us_t = None
            last_vs_t = None

            for t in range(Nt_in):
                if (t % 12 == 0) or (t >= Nt_in - 3):
                    print(f"[write] t={t+1}/{Nt_in} (of input)", flush=True)

                # Read only one (Nb, Ny, Nx) slab per time step.
                us_t = ussp_var[t, :, :, :].astype("f4")
                vs_t = vssp_var[t, :, :, :].astype("f4")
                last_us_t, last_vs_t = us_t, vs_t

                for b in range(Nb):
                    Usx[b][t, :, :] = us_t[b, :, :]
                    Usy[b][t, :, :] = vs_t[b, :, :]

                # Periodically flush to disk.
                if t % 24 == 0:
                    out.sync()

            # Appended record: a copy of the last input record at t_out_ext[-1].
            t = Nt_in
            print(f"[append] extra time record: index {t} / {Nt_out-1}", flush=True)
            for b in range(Nb):
                Usx[b][t, :, :] = last_us_t[b, :, :]
                Usy[b][t, :, :] = last_vs_t[b, :, :]

            out.sync()

    print(f"[OK] wrote {OUT}")
    print(f"     input Nt={Nt_in}, output Nt={Nt_out}, appended dt(days)={dt_out:g}")

if __name__ == "__main__":
    main()


In [None]:
#!/usr/bin/env python3
# ------------------------------------------------------------
# Extract 1D SCM wave forcing (360 files) from 3D MOM6 override-format
# StokesDriftBands file. This produces the 5mps 1D wave-forcing input files at 360 sampling points.
#
# Output per station:
#   TIME, LATITUDE(2), LONGITUDE(2), frequency, vert
#   x_T, y_T, x_vert_T, y_vert_T
#   Usx1..UsxNb, Usy1..UsyNb   (TIME,2,2) constant patch
# ------------------------------------------------------------
import os
import numpy as np
import netCDF4 as NC

# ============================
# USER CONFIG
# ============================
INP = "/archive/Qian.Xiao/Qian.Xiao/FMS_Wave_Coupling_GOTM_kapp/3D_hurricane/4_deg/201104_3DHurricaneCases.nc"
# NOTE(review): the 5mps conversion cell in this notebook writes
# StokesDriftBands_201104_3DHurricaneCases.nc in this directory, not the file
# named above -- confirm which 3-D file is intended.
OUT_DIR = "/archive/Qian.Xiao/Qian.Xiao/MOM_GLS_LT/3D_hurricane/forcing_1D_5mps_360_wave"
os.makedirs(OUT_DIR, exist_ok=True)  # executed as soon as this cell runs

# x position (km) of the column to extract. Intended as a tracer-cell centre
# kept off the boundary; note it lies 52.5 km inside the x = Lx_km edge.
X_REF_KM = 2947.5

# Linear mapping between the 3-D file's LONGITUDE/LATITUDE values and distance:
# lon_min -> 0 km, lon_max -> Lx_km; lat_min -> 0 km, lat_max -> Ly_km.
Lx_km, Ly_km = 3000.0, 1800.0
lon_min, lon_max = -13.5, 13.5
lat_min, lat_max = -8.1, 8.1

# Coordinates of the dummy 2x2 patch written into every station file
# (must match the SCM's fixed horizontal grid).
PATCH_LON = np.array((-1.0, 1.0), dtype="float32")
PATCH_LAT = PATCH_LON.copy()

OUT_FORMAT = "NETCDF4_CLASSIC"  # NETCDF4* formats allow zlib compression in main()
COMPRESS = True
# ============================

def lon_to_xkm(lon_1d):
    """Map LONGITUDE values linearly to x in km (lon_min -> 0, lon_max -> Lx_km).

    Works elementwise on NumPy arrays as well as on scalars.
    """
    fraction = (lon_1d - lon_min) / (lon_max - lon_min)
    return fraction * Lx_km

def lat_to_ykm(lat_1d):
    """Map LATITUDE values linearly to y in km (lat_min -> 0, lat_max -> Ly_km).

    Works elementwise on NumPy arrays as well as on scalars.
    """
    fraction = (lat_1d - lat_min) / (lat_max - lat_min)
    return fraction * Ly_km

def main():
    """Write one 1-D SCM wave-forcing file per LATITUDE row of INP.

    For station ``sid`` (1..360, row ``iy = sid - 1``), the Stokes-drift time
    series of every band is taken at column ``ix``, the LONGITUDE index whose
    mapped x-position is nearest X_REF_KM. It is written to
    ``OUT_DIR/Wave_5mps_<sid>.nc`` as a spatially constant 2x2 patch, together
    with the TIME axis, the frequency axis and the dummy patch geometry.

    Raises:
        RuntimeError: if INP does not have exactly 360 LATITUDE points.
    """
    with NC.Dataset(INP, "r") as ds:
        # Speed: return plain ndarrays instead of masked arrays for every
        # variable. Auto-scaling stays on, so a packed variable (integer +
        # scale_factor/add_offset) would still be unpacked. For plain float
        # variables this reads the same values as disabling mask and scale.
        ds.set_auto_mask(False)

        TIME = ds.variables["TIME"][:]
        time_units = getattr(ds.variables["TIME"], "units", "")
        time_cal   = getattr(ds.variables["TIME"], "calendar", "gregorian")

        freq = ds.variables["frequency"][:].astype("f4")
        Nb = int(freq.size)

        lon = ds.variables["LONGITUDE"][:].astype("f8")
        lat = ds.variables["LATITUDE"][:].astype("f8")
        Nx = lon.size
        Ny = lat.size

        xkm = lon_to_xkm(lon)
        ykm = lat_to_ykm(lat)

        # pick ix by nearest xkm (this guarantees no "lon->x mapping" ambiguity)
        ix = int(np.argmin(np.abs(xkm - X_REF_KM)))
        edge_dist = min(ix, Nx-1-ix)
        print(f"[X] X_REF_KM={X_REF_KM} -> ix={ix}, picked xkm={xkm[ix]:.2f}, edge_dist={edge_dist} cells")

        # sanity: Ny should be 360 and ykm should be 2.5..1797.5
        if Ny != 360:
            raise RuntimeError(f"Unexpected Ny={Ny}, expected 360. Check input file LATITUDE.")
        # optional quick check (warning only)
        if not (abs(ykm[0]-2.5) < 1e-2 and abs(ykm[-1]-1797.5) < 1e-2):
            print(f"[WARN] ykm endpoints look unusual: ykm[0]={ykm[0]:.3f}, ykm[-1]={ykm[-1]:.3f}")

        # patch geometry: T points and vertices of the dummy 2x2 patch
        x_T = np.tile(PATCH_LON[None, :], (2, 1)).astype("f4")
        y_T = np.tile(PATCH_LAT[:, None], (1, 2)).astype("f4")
        dlon = float(PATCH_LON[1] - PATCH_LON[0])
        dlat = float(PATCH_LAT[1] - PATCH_LAT[0])

        # Vertex order: (x-, y-), (x+, y-), (x-, y+), (x+, y+).
        x_vert = np.empty((4, 2, 2), dtype="f4")
        y_vert = np.empty((4, 2, 2), dtype="f4")
        x_vert[0,:,:] = x_T - 0.5*dlon;  y_vert[0,:,:] = y_T - 0.5*dlat
        x_vert[1,:,:] = x_T + 0.5*dlon;  y_vert[1,:,:] = y_T - 0.5*dlat
        x_vert[2,:,:] = x_T - 0.5*dlon;  y_vert[2,:,:] = y_T + 0.5*dlat
        x_vert[3,:,:] = x_T + 0.5*dlon;  y_vert[3,:,:] = y_T + 0.5*dlat

        kw = {}
        if OUT_FORMAT.startswith("NETCDF4") and COMPRESS:
            kw = dict(zlib=True, complevel=1, chunksizes=(1,2,2))

        # Read each band's (TIME, LATITUDE) slab at column ix ONCE, up front.
        # The 3-D file may be chunked per time record (e.g. (1, Ny, Nx) + zlib).
        # Reading one (iy, ix) point per station would then decompress every
        # time chunk again for each of the 360 stations and each band. Slicing
        # these small in-memory slabs yields identical values in one pass.
        slabs_x = [ds.variables[f"Usx{b}"][:, :, ix].astype("f4") for b in range(1, Nb+1)]  # (Nt, Ny) each
        slabs_y = [ds.variables[f"Usy{b}"][:, :, ix].astype("f4") for b in range(1, Nb+1)]

        # station loop: sid=1..360 corresponds to iy=0..359 exactly
        for sid in range(1, 361):
            iy = sid - 1
            y_km = float(ykm[iy])

            out_fn = os.path.join(OUT_DIR, f"Wave_5mps_{sid:03d}.nc")

            with NC.Dataset(out_fn, "w", format=OUT_FORMAT) as out:
                out.description = "1D SCM wave forcing extracted from 3D StokesDriftBands (constant 2x2 patch)."
                out.source_file = INP
                out.X_REF_KM = float(X_REF_KM)
                out.picked_ix = int(ix)
                out.picked_iy = int(iy)
                out.picked_xkm = float(xkm[ix])
                out.picked_ykm = float(y_km)
                out.picked_lon = float(lon[ix])
                out.picked_lat = float(lat[iy])

                out.createDimension("TIME", None)
                out.createDimension("LATITUDE", 2)
                out.createDimension("LONGITUDE", 2)
                out.createDimension("frequency", Nb)
                out.createDimension("vert", 4)

                vT = out.createVariable("TIME", "f8", ("TIME",))
                vT.units = time_units
                vT.calendar = time_cal
                vT[:] = TIME

                out.createVariable("LATITUDE",  "f4", ("LATITUDE",))[:]  = PATCH_LAT
                out.createVariable("LONGITUDE", "f4", ("LONGITUDE",))[:] = PATCH_LON
                out.createVariable("frequency", "f4", ("frequency",))[:] = freq

                out.createVariable("x_T",      "f4", ("LATITUDE","LONGITUDE"))[:,:] = x_T
                out.createVariable("y_T",      "f4", ("LATITUDE","LONGITUDE"))[:,:] = y_T
                out.createVariable("x_vert_T", "f4", ("vert","LATITUDE","LONGITUDE"))[:,:,:] = x_vert
                out.createVariable("y_vert_T", "f4", ("vert","LATITUDE","LONGITUDE"))[:,:,:] = y_vert

                # bands: copy the (Nt,) series at (iy, ix) into every patch cell
                for b in range(1, Nb+1):
                    srx = np.ascontiguousarray(slabs_x[b-1][:, iy])  # (Nt,)
                    sry = np.ascontiguousarray(slabs_y[b-1][:, iy])

                    vx = out.createVariable(f"Usx{b}", "f4", ("TIME","LATITUDE","LONGITUDE"), **kw)
                    vy = out.createVariable(f"Usy{b}", "f4", ("TIME","LATITUDE","LONGITUDE"), **kw)
                    vx.units = "m s-1"
                    vy.units = "m s-1"

                    vx[:, :, :] = srx[:, None, None]
                    vy[:, :, :] = sry[:, None, None]

            if sid % 30 == 0 or sid in (1, 360):
                print(f"[OK] {sid:03d}/360  iy={iy} ix={ix}  y_km={y_km:.1f} -> {out_fn}")

    print(f"[DONE] wrote 360 wave files into {OUT_DIR}")

if __name__ == "__main__":
    main()


In [None]:
#!/usr/bin/env python3
# ------------------------------------------------------------
# Extract 1D SCM wave forcing (360 files) from 3D MOM6 override-format
# StokesDriftBands file. This produces the 1D wave-forcing input files for the 2mps cases.
#
# Output per station:
#   TIME, LATITUDE(2), LONGITUDE(2), frequency, vert
#   x_T, y_T, x_vert_T, y_vert_T
#   Usx1..UsxNb, Usy1..UsyNb   (TIME,2,2) constant patch
# ------------------------------------------------------------
import os
import numpy as np
import netCDF4 as NC

# ============================
# USER CONFIG
# ============================
INP = "/archive/Qian.Xiao/Qian.Xiao/FMS_Wave_Coupling_GOTM_kapp/3D_hurricane/4_deg/2mps/201104_3DHurricaneCases.nc"
OUT_DIR = "/archive/Qian.Xiao/Qian.Xiao/MOM_GLS_LT/3D_hurricane/forcing_1D_2mps_360_wave"
os.makedirs(OUT_DIR, exist_ok=True)  # executed as soon as this cell runs

# x position (km) of the column to extract. Intended as a tracer-cell centre
# kept off the boundary; note it lies 52.5 km inside the x = Lx_km edge.
X_REF_KM = 2947.5

# Linear mapping between the 3-D file's LONGITUDE/LATITUDE values and distance:
# lon_min -> 0 km, lon_max -> Lx_km; lat_min -> 0 km, lat_max -> Ly_km.
Lx_km, Ly_km = 3000.0, 1800.0
lon_min, lon_max = -13.5, 13.5
lat_min, lat_max = -8.1, 8.1

# Coordinates of the dummy 2x2 patch written into every station file
# (must match the SCM's fixed horizontal grid).
PATCH_LON = np.array((-1.0, 1.0), dtype="float32")
PATCH_LAT = PATCH_LON.copy()

OUT_FORMAT = "NETCDF4_CLASSIC"  # NETCDF4* formats allow zlib compression in main()
COMPRESS = True
# ============================

def lon_to_xkm(lon_1d):
    """Map LONGITUDE values linearly to x in km (lon_min -> 0, lon_max -> Lx_km).

    Works elementwise on NumPy arrays as well as on scalars.
    """
    fraction = (lon_1d - lon_min) / (lon_max - lon_min)
    return fraction * Lx_km

def lat_to_ykm(lat_1d):
    """Map LATITUDE values linearly to y in km (lat_min -> 0, lat_max -> Ly_km).

    Works elementwise on NumPy arrays as well as on scalars.
    """
    fraction = (lat_1d - lat_min) / (lat_max - lat_min)
    return fraction * Ly_km

def main():
    """Write one 1-D SCM wave-forcing file per LATITUDE row of INP.

    For station ``sid`` (1..360, row ``iy = sid - 1``), the Stokes-drift time
    series of every band is taken at column ``ix``, the LONGITUDE index whose
    mapped x-position is nearest X_REF_KM. It is written to
    ``OUT_DIR/Wave_2mps_<sid>.nc`` as a spatially constant 2x2 patch, together
    with the TIME axis, the frequency axis and the dummy patch geometry.

    Raises:
        RuntimeError: if INP does not have exactly 360 LATITUDE points.
    """
    with NC.Dataset(INP, "r") as ds:
        # Speed: return plain ndarrays instead of masked arrays for every
        # variable. Auto-scaling stays on, so a packed variable (integer +
        # scale_factor/add_offset) would still be unpacked. For plain float
        # variables this reads the same values as disabling mask and scale.
        ds.set_auto_mask(False)

        TIME = ds.variables["TIME"][:]
        time_units = getattr(ds.variables["TIME"], "units", "")
        time_cal   = getattr(ds.variables["TIME"], "calendar", "gregorian")

        freq = ds.variables["frequency"][:].astype("f4")
        Nb = int(freq.size)

        lon = ds.variables["LONGITUDE"][:].astype("f8")
        lat = ds.variables["LATITUDE"][:].astype("f8")
        Nx = lon.size
        Ny = lat.size

        xkm = lon_to_xkm(lon)
        ykm = lat_to_ykm(lat)

        # pick ix by nearest xkm (this guarantees no "lon->x mapping" ambiguity)
        ix = int(np.argmin(np.abs(xkm - X_REF_KM)))
        edge_dist = min(ix, Nx-1-ix)
        print(f"[X] X_REF_KM={X_REF_KM} -> ix={ix}, picked xkm={xkm[ix]:.2f}, edge_dist={edge_dist} cells")

        # sanity: Ny should be 360 and ykm should be 2.5..1797.5
        if Ny != 360:
            raise RuntimeError(f"Unexpected Ny={Ny}, expected 360. Check input file LATITUDE.")
        # optional quick check (warning only)
        if not (abs(ykm[0]-2.5) < 1e-2 and abs(ykm[-1]-1797.5) < 1e-2):
            print(f"[WARN] ykm endpoints look unusual: ykm[0]={ykm[0]:.3f}, ykm[-1]={ykm[-1]:.3f}")

        # patch geometry: T points and vertices of the dummy 2x2 patch
        x_T = np.tile(PATCH_LON[None, :], (2, 1)).astype("f4")
        y_T = np.tile(PATCH_LAT[:, None], (1, 2)).astype("f4")
        dlon = float(PATCH_LON[1] - PATCH_LON[0])
        dlat = float(PATCH_LAT[1] - PATCH_LAT[0])

        # Vertex order: (x-, y-), (x+, y-), (x-, y+), (x+, y+).
        x_vert = np.empty((4, 2, 2), dtype="f4")
        y_vert = np.empty((4, 2, 2), dtype="f4")
        x_vert[0,:,:] = x_T - 0.5*dlon;  y_vert[0,:,:] = y_T - 0.5*dlat
        x_vert[1,:,:] = x_T + 0.5*dlon;  y_vert[1,:,:] = y_T - 0.5*dlat
        x_vert[2,:,:] = x_T - 0.5*dlon;  y_vert[2,:,:] = y_T + 0.5*dlat
        x_vert[3,:,:] = x_T + 0.5*dlon;  y_vert[3,:,:] = y_T + 0.5*dlat

        kw = {}
        if OUT_FORMAT.startswith("NETCDF4") and COMPRESS:
            kw = dict(zlib=True, complevel=1, chunksizes=(1,2,2))

        # Read each band's (TIME, LATITUDE) slab at column ix ONCE, up front.
        # The 3-D file may be chunked per time record (e.g. (1, Ny, Nx) + zlib).
        # Reading one (iy, ix) point per station would then decompress every
        # time chunk again for each of the 360 stations and each band. Slicing
        # these small in-memory slabs yields identical values in one pass.
        slabs_x = [ds.variables[f"Usx{b}"][:, :, ix].astype("f4") for b in range(1, Nb+1)]  # (Nt, Ny) each
        slabs_y = [ds.variables[f"Usy{b}"][:, :, ix].astype("f4") for b in range(1, Nb+1)]

        # station loop: sid=1..360 corresponds to iy=0..359 exactly
        for sid in range(1, 361):
            iy = sid - 1
            y_km = float(ykm[iy])

            out_fn = os.path.join(OUT_DIR, f"Wave_2mps_{sid:03d}.nc")

            with NC.Dataset(out_fn, "w", format=OUT_FORMAT) as out:
                out.description = "1D SCM wave forcing extracted from 3D StokesDriftBands (constant 2x2 patch)."
                out.source_file = INP
                out.X_REF_KM = float(X_REF_KM)
                out.picked_ix = int(ix)
                out.picked_iy = int(iy)
                out.picked_xkm = float(xkm[ix])
                out.picked_ykm = float(y_km)
                out.picked_lon = float(lon[ix])
                out.picked_lat = float(lat[iy])

                out.createDimension("TIME", None)
                out.createDimension("LATITUDE", 2)
                out.createDimension("LONGITUDE", 2)
                out.createDimension("frequency", Nb)
                out.createDimension("vert", 4)

                vT = out.createVariable("TIME", "f8", ("TIME",))
                vT.units = time_units
                vT.calendar = time_cal
                vT[:] = TIME

                out.createVariable("LATITUDE",  "f4", ("LATITUDE",))[:]  = PATCH_LAT
                out.createVariable("LONGITUDE", "f4", ("LONGITUDE",))[:] = PATCH_LON
                out.createVariable("frequency", "f4", ("frequency",))[:] = freq

                out.createVariable("x_T",      "f4", ("LATITUDE","LONGITUDE"))[:,:] = x_T
                out.createVariable("y_T",      "f4", ("LATITUDE","LONGITUDE"))[:,:] = y_T
                out.createVariable("x_vert_T", "f4", ("vert","LATITUDE","LONGITUDE"))[:,:,:] = x_vert
                out.createVariable("y_vert_T", "f4", ("vert","LATITUDE","LONGITUDE"))[:,:,:] = y_vert

                # bands: copy the (Nt,) series at (iy, ix) into every patch cell
                for b in range(1, Nb+1):
                    srx = np.ascontiguousarray(slabs_x[b-1][:, iy])  # (Nt,)
                    sry = np.ascontiguousarray(slabs_y[b-1][:, iy])

                    vx = out.createVariable(f"Usx{b}", "f4", ("TIME","LATITUDE","LONGITUDE"), **kw)
                    vy = out.createVariable(f"Usy{b}", "f4", ("TIME","LATITUDE","LONGITUDE"), **kw)
                    vx.units = "m s-1"
                    vy.units = "m s-1"

                    vx[:, :, :] = srx[:, None, None]
                    vy[:, :, :] = sry[:, None, None]

            if sid % 30 == 0 or sid in (1, 360):
                print(f"[OK] {sid:03d}/360  iy={iy} ix={ix}  y_km={y_km:.1f} -> {out_fn}")

    print(f"[DONE] wrote 360 wave files into {OUT_DIR}")

if __name__ == "__main__":
    main()
