In [2]:
from __future__ import annotations

import tempfile
from pathlib import Path
from dataclasses import dataclass
from typing import Optional


import numpy as np
import xarray as xr
import pygmt
from pygmt.clib import Session


# ==========================
# Configuration
# ==========================
@dataclass
class Config:
    """Settings for the expand-then-crop complete Bouguer anomaly (CBA) run.

    Every field has a default, so ``Config()`` gives a runnable setup. The path
    fields (``faa_file``, ``bathy_file``, ``out_dir``) also accept plain strings
    (or any ``os.PathLike``); ``__post_init__`` converts them to
    :class:`pathlib.Path`, because downstream code relies on ``Path`` operators
    and methods such as ``cfg.out_dir / name`` and ``.mkdir()``.
    """

    # Target region (W, E, S, N) in degrees.
    target_region: tuple[float, float, float, float] = (111.0, 117.0, 14.5, 18.5)

    # Buffer in degrees added on each side of the target region.
    buffer_deg_lon: float = 5.0
    buffer_deg_lat: float = 5.0

    # Densities (kg/m^3).
    rho_water: float = 1030.0
    rho_crust: float = 2800.0

    # Parker / gmt gravfft parameters.
    nterms: int = 8  # -E: number of expansion terms (gmt_gravfft accepts 1..10)
    taper_percent: int = 15  # -N...+t<percent>
    extension_mode: str = "m"  # -N...+e / +m / +n
    fft_mode: str = "a"  # -N<mode>
    # -N detrend modifier; None leaves it to GMT's default. Recommended: "a"
    # (remove the mean) or "l" (leave the data alone).
    detrend_mode: Optional[str] = None
    # -F option. "f+s": free-air field, adding the slab back when the mean was
    # removed (i.e. together with detrend_mode="a").
    field_opt: str = "f+s"

    # Local grids for the target window. If they do not cover the expanded
    # region, the remote GMT grids are downloaded instead (auto_fetch_if_needed).
    faa_file: Path = Path(r"E:\wjy\Gravity\SCS_Gravity\data\Download_Data\faa_01m_111E-117E_14.5N-18.5N.nc")
    bathy_file: Path = Path(r"E:\wjy\Gravity\SCS_Gravity\data\Download_Data\gebco_01m_111E-117E_14.5N-18.5N.nc")

    # Output directory (results; also the destination of downloaded expanded grids).
    out_dir: Path = Path(r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\CBA_expand_crop")

    # Download GMT remote datasets when the local grids are too small.
    auto_fetch_if_needed: bool = True
    fetch_resolution: str = "01m"  # @earth_faa_01m / @earth_relief_01m

    # Preferred registration suffix ("g" gridline / "p" pixel) for remote data.
    # It is a preference, not a requirement: other variants are tried as fallbacks.
    fetch_reg_preference: str = "g"
    # Resample the fetched relief grid onto the fetched FAA grid's nodes.
    force_match_to_faa: bool = True

    # Whether to keep the temporary gravfft working directory.
    keep_temp_dir: bool = True

    # NaN filling & QC thresholds.
    max_nan_frac_allowed: float = 0.02  # max NaN fraction tolerated in the expanded grids
    ocean_negative_frac_warn: float = 0.50  # warn if fewer than this fraction of bathy cells are < 0
    ocean_negative_frac_flip: float = 0.30  # below this, a depth-positive grid may be sign-flipped

    # Plotting.
    map_proj: str = "M12c"
    cmap: str = "turbo"
    dpi: int = 300

    # GMT verbosity: q (quiet) / n (normal) / v (verbose).
    gmt_verbose_level: str = "n"

    def __post_init__(self) -> None:
        # Accept str / os.PathLike for the path fields so that later
        # ``cfg.out_dir / name`` and ``.mkdir()`` calls work for any caller.
        for field_name in ("faa_file", "bathy_file", "out_dir"):
            value = getattr(self, field_name)
            if not isinstance(value, Path):
                setattr(self, field_name, Path(value))


# ==========================
# Utility functions
# ==========================
def _guess_grid_var(ds: xr.Dataset, preferred: list[str] | None = None) -> str:
    """Return the name of the grid variable to use from *ds*.

    The first entry of *preferred* that is a data variable of *ds* wins,
    regardless of its dimensionality. Failing that, the first data variable
    (in dataset order) with exactly two dimensions is returned.

    Raises:
        ValueError: if no preferred name matches and *ds* has no 2-D variable.
    """
    for wanted in (preferred if preferred is not None else []):
        if wanted in ds.data_vars:
            return wanted

    first_2d = next(
        (var_name for var_name, var in ds.data_vars.items() if var.ndim == 2),
        None,
    )
    if first_2d is None:
        raise ValueError(f"No 2D grid found. data_vars={list(ds.data_vars)}")
    return first_2d


def _standardize_lonlat_da(da: xr.DataArray) -> xr.DataArray:
    """Return *da* with coordinates named ``lon``/``lat``, both ascending.

    Grids name their horizontal axes differently. When ``lon`` (``lat``) is
    missing, the first present alias among ``x``, ``longitude``
    (``y``, ``latitude``) is renamed to it. Each axis that is not already
    non-decreasing is then sorted, so that
    ``sel(lon=slice(W, E), lat=slice(S, N))`` selects the intended window.

    Raises:
        ValueError: if no longitude or latitude coordinate can be identified.
    """
    # (canonical name, aliases in priority order); "x"/"y" keep top priority
    # to stay compatible with earlier behaviour.
    for canonical, aliases in (("lon", ("x", "longitude")), ("lat", ("y", "latitude"))):
        if canonical in da.coords:
            continue
        for alias in aliases:
            if alias in da.coords:
                da = da.rename({alias: canonical})
                break

    if "lon" not in da.coords or "lat" not in da.coords:
        raise ValueError(f"Grid missing lon/lat coords, coords={list(da.coords)}")

    # Ascending axes are required for label slicing with (W, E) / (S, N) bounds.
    for axis in ("lon", "lat"):
        if np.any(np.diff(da[axis].values) < 0):
            da = da.sortby(axis)
    return da


def _subset_region_da(da: xr.DataArray, region: tuple[float, float, float, float]) -> xr.DataArray:
    """Select the part of *da* inside ``region = (W, E, S, N)``.

    The grid is first normalised by ``_standardize_lonlat_da`` (lon/lat names,
    ascending axes). Selection is label-based and inclusive of both bounds.
    """
    lon_lo, lon_hi, lat_lo, lat_hi = region
    normalized = _standardize_lonlat_da(da)
    window = {"lon": slice(lon_lo, lon_hi), "lat": slice(lat_lo, lat_hi)}
    return normalized.sel(**window)


def _expand_region(region: tuple[float, float, float, float], dlon: float, dlat: float) -> tuple[float, float, float, float]:
    """Pad a ``(W, E, S, N)`` box by *dlon* degrees west/east and *dlat* degrees south/north.

    No clipping to valid longitude/latitude ranges is applied.
    """
    west, east, south, north = region
    padded_lon = (west - dlon, east + dlon)
    padded_lat = (south - dlat, north + dlat)
    return padded_lon + padded_lat


def _grid_step_deg(coord: xr.DataArray) -> float:
    """Return the typical spacing of a 1-D coordinate: the median absolute step.

    Non-finite differences (e.g. next to NaN coordinate values) are ignored,
    and the median keeps the result robust to an occasional irregular step.

    Raises:
        ValueError: if no finite difference exists (fewer than two usable points).
    """
    positions = np.asarray(coord.values, dtype=float)
    steps = np.diff(positions)
    usable = steps[np.isfinite(steps)]
    if not usable.size:
        raise ValueError("Cannot determine grid step (empty diffs).")
    magnitudes = np.abs(usable)
    return float(np.median(magnitudes))


def _coverage_ok(da: xr.DataArray, region: tuple[float, float, float, float], tol_cells: float = 0.55) -> bool:
    """Return True if the lon/lat extent of *da* covers ``region = (W, E, S, N)``.

    Each edge may fall short of the requested bound by up to
    ``tol_cells * grid spacing``. With the default 0.55 this presumably lets a
    pixel-registered grid, whose outermost nodes sit half a cell inside the
    region, count as covering it. A ``[QC]`` line with the extents, spacings,
    tolerances and verdict is printed.
    """
    west, east, south, north = region
    grid = _standardize_lonlat_da(da)

    lon_min, lon_max = float(grid.lon.min()), float(grid.lon.max())
    lat_min, lat_max = float(grid.lat.min()), float(grid.lat.max())

    dx = _grid_step_deg(grid.lon)
    dy = _grid_step_deg(grid.lat)
    tol_lon, tol_lat = tol_cells * dx, tol_cells * dy

    spans_lon = lon_min <= west + tol_lon and lon_max >= east - tol_lon
    spans_lat = lat_min <= south + tol_lat and lat_max >= north - tol_lat
    ok = spans_lon and spans_lat

    print(
        f"[QC] coverage lon[{lon_min:.6f},{lon_max:.6f}] lat[{lat_min:.6f},{lat_max:.6f}] "
        f"dx={dx:.6f} dy={dy:.6f} tol_lon={tol_lon:.6f} tol_lat={tol_lat:.6f}  want={region}  ok={ok}"
    )
    return ok


def _require_nan_frac_below(name: str, da: xr.DataArray, max_nan_frac: float) -> None:
    """Raise if more than *max_nan_frac* of the values in *da* are NaN.

    The NaN fraction is printed as a ``[QC]`` line. An empty grid (for example
    a subset that misses the data entirely) is rejected outright. Its NaN
    fraction would be 0/0 = NaN, and ``NaN > max_nan_frac`` is False, so it
    would otherwise pass the check silently.

    Raises:
        ValueError: if *da* has no values, or its NaN fraction exceeds *max_nan_frac*.
    """
    values = np.asarray(da.values)
    if values.size == 0:
        raise ValueError(f"{name}: grid is empty (no values to check).")
    frac = np.isnan(values).sum() / values.size
    print(f"[QC] {name}: NaN fraction = {frac:.3%}")
    if frac > max_nan_frac:
        raise ValueError(f"{name}: NaN fraction too high ({frac:.3%} > {max_nan_frac:.3%}).")


def _fill_nans_2d(da: xr.DataArray, name: str, prefer_nearest: bool = False) -> xr.DataArray:
    """Fill NaNs in a lon/lat grid by 1-D interpolation along lon, then along lat.

    Two scipy methods are used in turn: ``nearest`` then ``linear`` when
    *prefer_nearest* is true, and the reverse otherwise. The second method runs
    only if NaNs survive the first pass. Values beyond the valid data along a
    line are extrapolated (``fill_value="extrapolate"``). Any NaNs still left
    after both passes are replaced by 0.0 with a warning. A grid without NaNs
    is returned (standardised) unchanged.

    NOTE(review): with ``method="linear"``, a row/column holding exactly one
    valid value may make scipy's interp1d raise - confirm if sparse grids occur.
    """
    grid = _standardize_lonlat_da(da)
    if not np.isnan(grid.values).any():
        return grid

    print(f"[INFO] Filling NaNs for {name} ... prefer_nearest={prefer_nearest}")

    def _sweep(arr: xr.DataArray, method: str) -> xr.DataArray:
        # One interpolation pass along each horizontal axis, lon first.
        for dim in ("lon", "lat"):
            arr = arr.interpolate_na(dim=dim, method=method, fill_value="extrapolate")
        return arr

    primary, fallback = ("nearest", "linear") if prefer_nearest else ("linear", "nearest")
    filled = _sweep(grid, primary)
    if np.isnan(filled.values).any():
        filled = _sweep(filled, fallback)

    nleft = int(np.isnan(filled.values).sum())
    if nleft:
        print(f"[WARN] {name}: still has {nleft} NaNs after interpolation; fill with 0.")
        filled = filled.fillna(0.0)

    return filled


def _qc_stats(name: str, da: xr.DataArray) -> None:
    """Print NaN-ignoring min/max/mean/std of *da* on one ``[QC]`` line (3 decimals)."""
    arr = da.values
    reducers = (
        ("min", np.nanmin),
        ("max", np.nanmax),
        ("mean", np.nanmean),
        ("std", np.nanstd),
    )
    summary = " ".join(f"{label}={float(reduce(arr)):.3f}" for label, reduce in reducers)
    print(f"[QC] {name}: {summary}")


def _fix_bathy_sign_if_needed(bathy: xr.DataArray, cfg: Config) -> xr.DataArray:
    """Heuristically detect depth-positive bathymetry and convert it to elevation.

    The fraction of cells below zero is printed, together with summary stats.
    - If it is below ``cfg.ocean_negative_frac_warn``, a warning is printed.
    - If it is below ``cfg.ocean_negative_frac_flip`` AND the grid has
      essentially no negative values (min > -1e-3) AND some positive values,
      the grid is negated (elevation = -depth).
    Otherwise the (standardised) grid is returned unchanged.
    """
    grid = _standardize_lonlat_da(bathy)
    neg_frac = float((grid < 0).sum() / grid.size)
    _qc_stats("Bathy(before sign check)", grid)
    print(f"[QC] bathy<0 fraction = {neg_frac:.3f}")

    if neg_frac < cfg.ocean_negative_frac_warn:
        print(f"[WARN] bathy<0 fraction is low (<{cfg.ocean_negative_frac_warn}). Check if bathy sign is correct.")

    # Written as "not <" (rather than ">=") so a NaN fraction also skips the flip.
    if not neg_frac < cfg.ocean_negative_frac_flip:
        return grid

    lowest = float(np.nanmin(grid.values))
    highest = float(np.nanmax(grid.values))
    looks_depth_positive = lowest > -1e-3 and highest > 0
    if looks_depth_positive:
        print("[WARN] Bathy seems depth-positive. Flipping sign: elevation = -depth.")
        grid = -grid
        _qc_stats("Bathy(after flip)", grid)
    return grid


def _to_tmp_grid(da: xr.DataArray, outpath: Path) -> None:
    """Write *da* to *outpath* as a float32 netCDF variable ``z`` on ``x``/``y`` axes.

    The array is cast to float32 and standardised (lon/lat, ascending) first.
    The file is written with xarray's ``scipy`` engine.
    """
    grid32 = _standardize_lonlat_da(da.astype(np.float32))
    dataset = grid32.to_dataset(name="z")
    dataset = dataset.rename({"lon": "x", "lat": "y"})
    dataset.to_netcdf(outpath, engine="scipy")


def _read_grid(path: Path) -> xr.DataArray:
    """Read the grid variable (preferring ``z``) of a netCDF file into memory.

    The variable is chosen by ``_guess_grid_var`` and standardised to
    lon/lat with ascending axes.
    """
    # Load while the file is open, then let ``with`` close it. The previous
    # open/close sequence returned a lazily backed array after close(), so the
    # data still depended on a file that may later be deleted or overwritten.
    # On Windows, an open handle also blocks removing that file.
    with xr.open_dataset(path) as ds:
        vname = _guess_grid_var(ds, preferred=["z"])
        da = ds[vname].load()
    return _standardize_lonlat_da(da)


def _coerce_grid_like(src: xr.DataArray, tmpl: xr.DataArray) -> xr.DataArray:
    """Put *src* on the lon/lat coordinates of *tmpl*.

    When both grids have the same number of lon and lat points, the template's
    coordinate values are assigned directly, with no resampling. This assumes
    the two grids already describe the same nodes. Otherwise *src* is linearly
    interpolated onto the template axes.
    """
    source = _standardize_lonlat_da(src)
    template = _standardize_lonlat_da(tmpl)
    shape_differs = (
        source.sizes["lon"] != template.sizes["lon"]
        or source.sizes["lat"] != template.sizes["lat"]
    )
    if shape_differs:
        return source.interp(lon=template["lon"], lat=template["lat"], method="linear")
    return source.assign_coords(lon=template["lon"].values, lat=template["lat"].values)


def _snap_or_interp_like(da: xr.DataArray, tmpl: xr.DataArray, name: str) -> xr.DataArray:
    """Return *da* on the exact lon/lat nodes of *tmpl*.

    If both axes have the same length and agree to within 1e-10 (absolute),
    the template coordinates are assigned directly. Otherwise *da* is linearly
    interpolated onto them; by xarray's default, points outside *da*'s extent
    become NaN. A ``[QC]`` line reports which path was taken.
    """
    grid = _standardize_lonlat_da(da)
    ref = _standardize_lonlat_da(tmpl)

    def _same_axis(axis: str) -> bool:
        # Equal length first, so allclose never sees mismatched shapes.
        if grid.sizes[axis] != ref.sizes[axis]:
            return False
        return np.allclose(grid[axis].values, ref[axis].values, atol=1e-10, rtol=0)

    if not (_same_axis("lon") and _same_axis("lat")):
        print(f"[QC] {name}: coords differ; do linear interpolation.")
        return grid.interp(lon=ref.lon, lat=ref.lat, method="linear")

    print(f"[QC] {name}: coords match template; skip interpolation.")
    return grid.assign_coords(lon=ref.lon.values, lat=ref.lat.values)


def _fit_plane_slope_km(da: xr.DataArray) -> tuple[float, float, float]:
    """Least-squares fit of ``z = a*x + b*y + c`` with x, y in km.

    x and y are offsets from the mean longitude/latitude, converted with
    111.32 km/deg * cos(mean latitude) and 110.57 km/deg respectively. Only
    finite samples of *da* enter the fit. Returns ``(a, b, c)``: the E-W and
    N-S gradients (data units per km) and the fitted value at the mean lon/lat.

    The mask built from ``da.values`` indexes the (lat, lon) meshgrid. This
    assumes *da* is ordered (lat, lon) - TODO confirm for all inputs.
    """
    grid = _standardize_lonlat_da(da)
    lon2d, lat2d = np.meshgrid(grid.lon.values, grid.lat.values)

    mean_lat_rad = np.deg2rad(float(np.nanmean(lat2d)))
    x_km = (lon2d - np.nanmean(lon2d)) * 111.32 * np.cos(mean_lat_rad)
    y_km = (lat2d - np.nanmean(lat2d)) * 110.57

    z = grid.values
    valid = np.isfinite(z)
    design = np.column_stack([x_km[valid], y_km[valid], np.ones(valid.sum())])
    (a, b, c), *_ = np.linalg.lstsq(design, z[valid], rcond=None)
    return float(a), float(b), float(c)


# ==========================
# GMT wrappers
# ==========================
def gmt_gravfft(
    ingrid_nc: Path,
    outgrid_nc: Path,
    drho: float,
    nterms: int,
    taper_percent: int = 15,
    extension_mode: str = "m",
    fft_mode: str = "a",
    detrend_mode: Optional[str] = None,
    field_opt: str = "f+s",  # passed as the -F option
    gmt_verbose_level: str = "n",
    verbose: bool = True,
) -> None:
    """Run ``gmt gravfft`` on *ingrid_nc* and write the computed field to *outgrid_nc*.

    The command line is assembled as::

        <in> -D<drho> -G<out> -E<nterms> -F<field_opt>
        -N<fft_mode>[+<detrend_mode>]+<extension_mode>+t<taper_percent>+v -fg -V<level>

    Args:
        ingrid_nc / outgrid_nc: input grid and output grid paths.
        drho: density contrast for -D.
        nterms: -E value, 1..10.
        taper_percent: -N ...+t value, 0..100.
        extension_mode: 'e', 'm' or 'n'.
        fft_mode: 'a', 'f', 'm', 'r' or 's'.
        detrend_mode: 'a', 'd', 'h', 'l', or None to omit the modifier.
        field_opt: text appended to -F.
        gmt_verbose_level: 'q', 'n' or 'v'.
        verbose: print the command before running it.

    Raises:
        ValueError: on an out-of-range option (checked in the order listed above
            in the code, before GMT is called).
        FileNotFoundError: if GMT finished but *outgrid_nc* does not exist.

    NOTE(review): the arguments are joined with spaces into one string, so a
    path containing spaces would presumably be split by GMT's parser - confirm
    if the temp dir or output paths can contain spaces.
    """
    # Lazily evaluated (check, message) pairs: the first failing check raises.
    rules = (
        (lambda: extension_mode in ("e", "m", "n"), "extension_mode must be one of 'e','m','n'"),
        (lambda: fft_mode in ("a", "f", "m", "r", "s"), "fft_mode must be one of 'a','f','m','r','s'"),
        (lambda: detrend_mode is None or detrend_mode in ("a", "d", "h", "l"),
         "detrend_mode must be one of 'a','d','h','l', or None"),
        (lambda: 0 <= taper_percent <= 100, "taper_percent must be in [0,100]"),
        (lambda: 1 <= nterms <= 10, "nterms must be in [1,10]"),
        (lambda: gmt_verbose_level in ("q", "n", "v"), "gmt_verbose_level must be one of 'q','n','v'"),
    )
    for passes, message in rules:
        if not passes():
            raise ValueError(message)

    detrend_part = "" if detrend_mode is None else f"+{detrend_mode}"
    n_option = f"-N{fft_mode}{detrend_part}+{extension_mode}+t{taper_percent}+v"

    cmd = " ".join(
        (
            ingrid_nc.as_posix(),
            f"-D{drho}",
            f"-G{outgrid_nc.as_posix()}",
            f"-E{nterms}",
            f"-F{field_opt}",
            n_option,
            "-fg",
            f"-V{gmt_verbose_level}",
        )
    )

    if verbose:
        print("[GMT] gravfft", cmd)

    with Session() as lib:
        lib.call_module("gravfft", cmd)

    if not outgrid_nc.exists():
        raise FileNotFoundError(f"gravfft did not create output grid: {outgrid_nc}")


def _grdcut_remote(dataset: str, region: tuple[float, float, float, float], out_nc: Path) -> None:
    """Cut ``region = (W, E, S, N)`` out of grid *dataset* into the netCDF file *out_nc*.

    *dataset* is anything ``pygmt.grdcut`` accepts as a grid, e.g. a GMT remote
    dataset name such as ``"@earth_faa_01m_g"``. The parent directory of
    *out_nc* is created if needed, and GMT runs quietly.

    Raises:
        FileNotFoundError: if *out_nc* does not exist after the cut. A file left
            over from an earlier run would also satisfy this check.
    """
    west, east, south, north = region
    out_nc.parent.mkdir(parents=True, exist_ok=True)
    bounds = [west, east, south, north]
    pygmt.grdcut(grid=dataset, region=bounds, outgrid=out_nc.as_posix(), verbose="q")
    if out_nc.exists():
        return
    raise FileNotFoundError(f"grdcut did not create output: {out_nc}")


def _grdcut_remote_with_fallback(
    base: str,
    res: str,
    reg_pref: str,
    region: tuple[float, float, float, float],
    out_nc: Path,
) -> str:
    """Cut *region* from a GMT remote dataset, trying several registrations.

    Candidates are tried in this order:
    ``@{base}_{res}_{reg_pref}``, then the unsuffixed ``@{base}_{res}``, then
    the opposite registration (``_p`` for a ``"g"`` preference, ``_g`` for
    ``"p"``). The first candidate whose cut succeeds wins.

    Returns:
        The dataset name that was used.

    Raises:
        RuntimeError: if every candidate fails. The last failure is chained as
            ``__cause__``, so its traceback is preserved.
    """
    candidates = [f"@{base}_{res}_{reg_pref}", f"@{base}_{res}"]
    if reg_pref == "g":
        candidates.append(f"@{base}_{res}_p")
    elif reg_pref == "p":
        candidates.append(f"@{base}_{res}_g")

    last_err: Exception | None = None
    for dsname in candidates:
        print(f"[INFO] Trying remote dataset: {dsname}")
        try:
            _grdcut_remote(dsname, region, out_nc)
        except Exception as e:  # deliberate: any failure just moves on to the next candidate
            last_err = e
            print(f"[WARN] Failed: {dsname}  ({type(e).__name__}: {e})")
            continue
        print(f"[INFO] Success: {dsname}")
        return dsname

    raise RuntimeError(
        f"All remote dataset candidates failed for {base}_{res}. Last error: {last_err}"
    ) from last_err


def _load_local_grid(path: Path, preferred_vars: list[str]) -> xr.DataArray:
    """Load the grid variable of a netCDF file into memory as a lon/lat DataArray.

    The variable is chosen by ``_guess_grid_var``: the first name in
    *preferred_vars* that exists, else the first 2-D variable. It is then
    standardised to lon/lat with ascending axes. The whole variable is read
    into memory.
    """
    # Load while the file is open and let ``with`` close it afterwards. The
    # previous open/close sequence returned a lazily backed array after
    # close(), so the data still depended on the file. That file may be
    # overwritten later (e.g. re-fetched grids in the output directory), and on
    # Windows an open handle blocks replacing it.
    with xr.open_dataset(path) as ds:
        vname = _guess_grid_var(ds, preferred=preferred_vars)
        da = ds[vname].load()
    return _standardize_lonlat_da(da)


def _resample_grid_to_template_file(
    src_grid_nc: Path,
    tmpl: xr.DataArray,
    out_nc: Path,
    registration: str,
) -> None:
    """Resample a grid file onto the lon/lat nodes of *tmpl* with ``gmt grdsample``.

    The spacing is the template's median coordinate step. The region is chosen
    so that the output nodes coincide with the template nodes:

    - gridline registration ("g"): the region is the first/last node.
    - pixel registration ("p"): GMT treats the region as the outer cell edges
      and puts nodes at cell centres. The node extent is therefore padded by
      half a cell on every side. Passing the node extent directly would give a
      grid one cell smaller per axis, shifted by half a cell from the template.

    Bicubic interpolation is used (``interpolation="b"``).

    Raises:
        FileNotFoundError: if grdsample did not create *out_nc*.
    """
    tmpl = _standardize_lonlat_da(tmpl)
    dx = _grid_step_deg(tmpl.lon)
    dy = _grid_step_deg(tmpl.lat)

    # Half-cell padding only for pixel registration (see docstring).
    pad_x, pad_y = (0.5 * dx, 0.5 * dy) if registration == "p" else (0.0, 0.0)
    region = [
        float(tmpl.lon.min()) - pad_x,
        float(tmpl.lon.max()) + pad_x,
        float(tmpl.lat.min()) - pad_y,
        float(tmpl.lat.max()) + pad_y,
    ]

    out_nc.parent.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] grdsample -> match template: {src_grid_nc.name} -> {out_nc.name}  reg={registration} dx={dx} dy={dy}")
    pygmt.grdsample(
        grid=str(src_grid_nc),
        outgrid=str(out_nc),
        region=region,
        spacing=[dx, dy],
        registration=registration,
        interpolation="b",
        verbose="q",
    )
    if not out_nc.exists():
        raise FileNotFoundError(f"grdsample did not create output: {out_nc}")


def _load_faa_and_bathy_covering_region(cfg: Config, region_exp: tuple[float, float, float, float]) -> tuple[xr.DataArray, xr.DataArray]:
    """Return ``(faa, bathy)`` grids that cover *region_exp*.

    The local files ``cfg.faa_file`` / ``cfg.bathy_file`` are returned as-is
    when both cover the region (``_coverage_ok``). Otherwise, if
    ``cfg.auto_fetch_if_needed`` is set, both grids are cut from the GMT remote
    datasets ``earth_faa`` / ``earth_relief`` at ``cfg.fetch_resolution`` into
    ``cfg.out_dir``. With ``cfg.force_match_to_faa``, the relief grid is then
    resampled onto the FAA grid's nodes.

    Raises:
        ValueError: if the local grids are too small and auto-fetch is disabled.
    """
    faa_local = _load_local_grid(cfg.faa_file, ["faa", "FAA", "free_air", "faa_mgal", "z"])
    bathy_local = _load_local_grid(cfg.bathy_file, ["elevation", "z", "gebco", "topo", "bathy"])

    both_cover = _coverage_ok(faa_local, region_exp) and _coverage_ok(bathy_local, region_exp)
    if both_cover:
        return faa_local, bathy_local

    if not cfg.auto_fetch_if_needed:
        raise ValueError("Local grids do not cover expanded region and auto_fetch_if_needed=False.")

    print("[INFO] Local grids too small. Fetching expanded grids via GMT remote datasets...")
    res, pref = cfg.fetch_resolution, cfg.fetch_reg_preference
    bbox_tag = f"{region_exp[0]:.1f}_{region_exp[1]:.1f}_{region_exp[2]:.1f}_{region_exp[3]:.1f}"

    faa_nc = cfg.out_dir / f"faa_{res}_expanded_{bbox_tag}.nc"
    faa_ds_used = _grdcut_remote_with_fallback("earth_faa", res, pref, region_exp, faa_nc)

    rel_nc = cfg.out_dir / f"relief_{res}_expanded_{bbox_tag}.nc"
    rel_ds_used = _grdcut_remote_with_fallback("earth_relief", res, pref, region_exp, rel_nc)

    faa_fetched = _load_local_grid(faa_nc, ["z", "faa", "FAA"])
    bathy_fetched = _load_local_grid(rel_nc, ["z", "elevation", "topo", "bathy"])

    print(f"[INFO] FAA dataset used: {faa_ds_used}")
    print(f"[INFO] Relief dataset used: {rel_ds_used}")

    if not cfg.force_match_to_faa:
        return faa_fetched, bathy_fetched

    # Infer the FAA grid's registration from where its first node falls
    # relative to multiples of the spacing. A ~0.5 fractional offset means
    # pixel ("p") registration; anything else is treated as gridline ("g").
    dx = _grid_step_deg(faa_fetched.lon)
    west_in_cells = float(faa_fetched.lon.min()) / dx
    phase = west_in_cells - np.floor(west_in_cells)
    reg = "p" if abs(phase - 0.5) < 1e-2 else "g"

    rel_rs_nc = cfg.out_dir / f"relief_{res}_expanded_MATCHFAA_{reg}.nc"
    _resample_grid_to_template_file(rel_nc, faa_fetched, rel_rs_nc, registration=reg)
    bathy_matched = _load_local_grid(rel_rs_nc, ["z", "elevation", "topo", "bathy"])
    return faa_fetched, bathy_matched


def compute_cba_expand_then_crop(cfg: Config) -> dict[str, xr.DataArray]:
    """Compute the raw complete Bouguer anomaly on an expanded grid, then crop it.

    Steps:
      1. Expand ``cfg.target_region`` by the configured buffer and load FAA and
         bathymetry grids covering it (local files or GMT remote data).
      2. Check/fill NaNs, and flip the bathymetry sign if it looks depth-positive.
      3. Split relief into its negative (ocean) and positive (land) parts. Run
         ``gmt gravfft`` on each, with density contrasts
         ``rho_crust - rho_water`` and ``rho_crust`` respectively.
      4. ``CBA_raw = FAA - (g_wc + g_topo)`` on the expanded grid.
      5. Crop to the target region, put all fields on the nodes of the local
         FAA file, print plane-fit QC and save a netCDF file in ``cfg.out_dir``.

    The gravfft working files go to a fresh temporary directory. It is removed
    afterwards unless ``cfg.keep_temp_dir`` is true (the default).

    Returns:
        dict with keys "CBA_raw", "FAA", "g_wc", "g_topo" (target-region grids).
    """
    import shutil  # only needed for the optional temp-dir cleanup below

    region_exp = _expand_region(cfg.target_region, cfg.buffer_deg_lon, cfg.buffer_deg_lat)
    print("Target region:", cfg.target_region)
    print("Expanded region:", region_exp)

    faa0, bathy0 = _load_faa_and_bathy_covering_region(cfg, region_exp)

    faa_exp = _subset_region_da(faa0, region_exp)
    bathy_exp = _subset_region_da(bathy0, region_exp)

    _require_nan_frac_below("FAA(expanded)", faa_exp, cfg.max_nan_frac_allowed)
    _require_nan_frac_below("Bathy(expanded)", bathy_exp, cfg.max_nan_frac_allowed)

    faa_exp = _fill_nans_2d(faa_exp, "FAA(expanded)", prefer_nearest=False)
    bathy_exp = _fill_nans_2d(bathy_exp, "Bathy(expanded)", prefer_nearest=True)
    bathy_exp = _fix_bathy_sign_if_needed(bathy_exp, cfg)

    # Negative part (below sea level) and positive part (above); each grid is
    # zero where the other one is non-zero.
    h_ocean = bathy_exp.where(bathy_exp < 0.0, 0.0)
    h_land = bathy_exp.where(bathy_exp > 0.0, 0.0)

    # Crust vs. water below sea level; crust vs. air (density 0) above it.
    drho_wc = cfg.rho_crust - cfg.rho_water
    drho_topo = cfg.rho_crust - 0.0

    fft_opts = dict(
        taper_percent=cfg.taper_percent,
        extension_mode=cfg.extension_mode,
        fft_mode=cfg.fft_mode,
        detrend_mode=cfg.detrend_mode,
        field_opt=cfg.field_opt,
        gmt_verbose_level=cfg.gmt_verbose_level,
    )

    td = Path(tempfile.mkdtemp(prefix="cba_gravfft_"))
    print("[INFO] Using temp dir:", td)
    try:
        ocean_nc_in = td / "h_ocean.nc"
        land_nc_in = td / "h_land.nc"
        _to_tmp_grid(h_ocean, ocean_nc_in)
        _to_tmp_grid(h_land, land_nc_in)

        g_wc_nc = td / "g_wc.nc"
        g_topo_nc = td / "g_topo.nc"
        gmt_gravfft(ocean_nc_in, g_wc_nc, drho_wc, cfg.nterms, **fft_opts)
        gmt_gravfft(land_nc_in, g_topo_nc, drho_topo, cfg.nterms, **fft_opts)

        # Pull results into memory so nothing below depends on files in td,
        # which may be removed in the finally-clause.
        g_wc = _read_grid(g_wc_nc).load()
        g_topo = _read_grid(g_topo_nc).load()
    finally:
        # Honour cfg.keep_temp_dir; previously the directory was never removed.
        if not cfg.keep_temp_dir:
            shutil.rmtree(td, ignore_errors=True)
            print("[INFO] Removed temp dir:", td)

    g_wc = _coerce_grid_like(g_wc, faa_exp)
    g_topo = _coerce_grid_like(g_topo, faa_exp)
    faa_exp, g_wc, g_topo = xr.align(faa_exp, g_wc, g_topo, join="override")

    _qc_stats("FAA(exp)", faa_exp)
    _qc_stats("g_wc(exp)", g_wc)
    _qc_stats("g_topo(exp)", g_topo)

    cba_exp = faa_exp - (g_wc + g_topo)
    cba_exp.name = "CBA_raw"
    _qc_stats("CBA_raw(exp)", cba_exp)

    # Crop every field to the target window.
    w, e, s, n = cfg.target_region
    window = dict(lon=slice(w, e), lat=slice(s, n))
    cba_tgt = cba_exp.sel(**window)
    g_wc_tgt = g_wc.sel(**window)
    g_topo_tgt = g_topo.sel(**window)
    faa_tgt = faa_exp.sel(**window)

    # Put the cropped fields on the exact nodes of the local FAA file.
    tmpl0 = _load_local_grid(cfg.faa_file, ["faa", "FAA", "free_air", "faa_mgal", "z"])
    tmpl = _subset_region_da(tmpl0, cfg.target_region)

    cba_tgt = _snap_or_interp_like(cba_tgt, tmpl, "CBA_raw(tgt)")
    g_wc_tgt = _snap_or_interp_like(g_wc_tgt, tmpl, "g_wc(tgt)")
    g_topo_tgt = _snap_or_interp_like(g_topo_tgt, tmpl, "g_topo(tgt)")
    faa_tgt = _snap_or_interp_like(faa_tgt, tmpl, "FAA(tgt)")

    def _print_plane(name: str, da: xr.DataArray) -> None:
        a, b, _ = _fit_plane_slope_km(da)
        print(f"[QC] {name} plane slope: a={a:.4f} mGal/km (E-W), b={b:.4f} mGal/km (N-S)")

    _print_plane("FAA(tgt)", faa_tgt)
    _print_plane("g_wc(tgt)", g_wc_tgt)
    _print_plane("g_topo(tgt)", g_topo_tgt)
    _print_plane("g_wc+g_topo(tgt)", (g_wc_tgt + g_topo_tgt))
    _print_plane("CBA_raw(tgt)", cba_tgt)

    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    out_nc = cfg.out_dir / f"CBA_raw_expand{cfg.buffer_deg_lon:.1f}deg_{cfg.buffer_deg_lat:.1f}deg.nc"
    xr.Dataset(
        data_vars=dict(
            FAA=faa_tgt.astype("float32"),
            g_wc=g_wc_tgt.astype("float32"),
            g_topo=g_topo_tgt.astype("float32"),
            CBA_raw=cba_tgt.astype("float32"),
        ),
        coords=dict(lon=cba_tgt.lon, lat=cba_tgt.lat),
    ).to_netcdf(out_nc)
    print("Saved:", out_nc)

    return {"CBA_raw": cba_tgt, "FAA": faa_tgt, "g_wc": g_wc_tgt, "g_topo": g_topo_tgt}


def quick_map_fullrange(da: xr.DataArray, title: str, out_png: Path, cfg: Config) -> None:
    """Plot *da* over its full lon/lat extent and save the figure to *out_png*.

    The colour scale spans the data range. When the data straddle zero, it is
    made symmetric about zero (+/- the largest absolute value). Coastlines and
    national borders are overlaid, and the image is written at ``cfg.dpi``.
    """
    fig = pygmt.Figure()
    region = [float(da.lon.min()), float(da.lon.max()), float(da.lat.min()), float(da.lat.max())]

    grid = da.load()
    lowest = float(np.nanmin(grid.values))
    highest = float(np.nanmax(grid.values))
    if lowest < 0 < highest:
        bound = max(abs(lowest), abs(highest))
        series = [-bound, bound]
    else:
        series = [lowest, highest]

    pygmt.makecpt(cmap=cfg.cmap, series=series, background=True)
    fig.grdimage(grid, region=region, projection=cfg.map_proj, frame=["a", f"+t{title}"], cmap=True)
    fig.coast(region=region, projection=cfg.map_proj, shorelines="0.8p,black", borders="1/0.5p,black")
    fig.colorbar(frame='af+l"mGal"')
    fig.savefig(out_png, dpi=cfg.dpi)
    print("Saved:", out_png)


def main() -> None:
    """Script entry point: configure, compute CBA_raw, print stats and save a QC map."""
    cfg = Config()

    # Remove the mean before the FFT (-N...+a) and add the slab back (-Ff+s).
    # Per the original author's note, this avoids the large baseline drift seen
    # with GMT's default mid-value removal (+h) - verify against gravfft docs.
    cfg.detrend_mode = "a"
    cfg.field_opt = "f+s"

    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    print(cfg)

    results = compute_cba_expand_then_crop(cfg)

    cba = results["CBA_raw"].load()
    stats = (
        float(np.nanmin(cba)),
        float(np.nanmax(cba)),
        float(np.nanmean(cba)),
        float(np.nanstd(cba)),
    )
    print("[QC] CBA_raw(tgt) stats:", *stats)

    qc_png = cfg.out_dir / "QC_CBA_fullrange.png"
    quick_map_fullrange(results["CBA_raw"], "CBA_raw (expanded->crop)", qc_png, cfg)


# Run the workflow when executed as a script. In a Jupyter cell, __name__ is
# also "__main__", so main() runs there too.
if __name__ == "__main__":
    main()


Config(target_region=(111.0, 117.0, 14.5, 18.5), buffer_deg_lon=5.0, buffer_deg_lat=5.0, rho_water=1030.0, rho_crust=2800.0, nterms=8, taper_percent=15, extension_mode='m', fft_mode='a', detrend_mode=None, faa_file=WindowsPath('E:/wjy/Gravity/SCS_Gravity/data/Download_Data/faa_01m_111E-117E_14.5N-18.5N.nc'), bathy_file=WindowsPath('E:/wjy/Gravity/SCS_Gravity/data/Download_Data/gebco_01m_111E-117E_14.5N-18.5N.nc'), out_dir=WindowsPath('E:/wjy/Gravity/SCS_Gravity/out/Outdata/CBA_expand_crop'), auto_fetch_if_needed=True, fetch_resolution='01m', fetch_reg_preference='g', force_match_to_faa=True, keep_temp_dir=True, max_nan_frac_allowed=0.02, ocean_negative_frac_warn=0.5, ocean_negative_frac_flip=0.3, map_proj='M12c', cmap='turbo', dpi=300, gmt_verbose_level='n')
Target region: (111.0, 117.0, 14.5, 18.5)
Expanded region: (106.0, 122.0, 9.5, 23.5)
[QC] coverage lon[111.008333,116.991667] lat[14.508333,18.491667] dx=0.016667 dy=0.016667 tol_lon=0.009167 tol_lat=0.009167  want=(106.0, 122.0, 9