In [1]:
import math
import time
import traceback
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import scipy.stats as stats
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple, List, Optional
from joblib import Parallel, delayed

# 尝试导入加速库
try:
    from scipy.spatial import cKDTree
    HAVE_SCIPY = True
except ImportError:
    HAVE_SCIPY = False

# ==============================================================================
# 1. 全局配置 (Configuration)
# ==============================================================================
# 运行模式: "demo" (生成模拟数据) 或 "real" (读取外部文件)
MODE = "real"  

# 路径配置
OUTDIR = Path(r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\12_30")
# [Real模式下需配置]
SCHEMEB_NC = r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\schemeB_surfaces_mock.nc"
OBS_CSV    = r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\obs_mock.csv"

# MCMC 迭代参数
N_ITER   = 50000
BURN_IN  = 5000
THIN     = 10
SEED     = 20251225
N_CHAINS = 4
N_JOBS   = 1  # Windows下建议设为1避免多进程报错，Linux可设为-1

# 网格与几何参数
DX_KM   = 10.0
DY_KM   = 10.0
DZ_KM   = 2.0
ZMAX_KM = 60.0

# 先验范围 (Voronoi 维度 K)
PRIOR_K_MIN = 2
PRIOR_K_MAX = 30

# 先验标准差 (Density)
SIGMA_RHO_UC = 150.0
SIGMA_RHO_LC = 150.0
SIGMA_RHO_M  = 100.0

# 数据噪声先验 (Jeffreys prior bounds)
SIGMA_MIN_MGAL = 0.05
SIGMA_MAX_MGAL = 5.0

# Proposal 步长
STEP_XY_KM     = 8.0
STEP_RHO       = 50.0
STEP_LOG_SIGMA = 0.08

# Proposal 概率权重
P_BIRTH = 0.15
P_DEATH = 0.15
P_MOVE  = 0.35
P_VALUE = 0.30
P_SIGMA = 0.05

# 绘图配置
PPD_NDEPTH  = 220
PPD_RHO_MIN = -600.0
PPD_RHO_MAX = 600.0
PPD_NBINS   = 160
PPD_COLUMN_MODE = "center" # "center" 或 "manual"
PPD_IX = None
PPD_IY = None

# 物理常量
G_SI = 6.67430e-11
SI_TO_MGAL = 1e5
KM_TO_M = 1000.0

# ==============================================================================
# 2. 核心数据结构 (Core Data Structures)
# ==============================================================================
@dataclass
class VoxelGrid:
    x_edges_km: np.ndarray
    y_edges_km: np.ndarray
    z_edges_km: np.ndarray

    @property
    def nx(self) -> int: return len(self.x_edges_km) - 1
    @property
    def ny(self) -> int: return len(self.y_edges_km) - 1
    @property
    def nz(self) -> int: return len(self.z_edges_km) - 1
    @property
    def nvox(self) -> int: return self.nx * self.ny * self.nz

    def column_centers_km(self) -> np.ndarray:
        xc = 0.5 * (self.x_edges_km[:-1] + self.x_edges_km[1:])
        yc = 0.5 * (self.y_edges_km[:-1] + self.y_edges_km[1:])
        X, Y = np.meshgrid(xc, yc)
        return np.column_stack([X.ravel(), Y.ravel()]).astype(np.float64)

    def voxel_bounds_m(self):
        """返回所有体素的几何边界 (x1, x2, y1, y2, z1, z2) in meters"""
        x1, x2 = self.x_edges_km[:-1], self.x_edges_km[1:]
        y1, y2 = self.y_edges_km[:-1], self.y_edges_km[1:]
        z1, z2 = self.z_edges_km[:-1], self.z_edges_km[1:]
        
        X1, Y1, Z1 = np.meshgrid(x1, y1, z1, indexing="xy")
        X2, Y2, Z2 = np.meshgrid(x2, y2, z2, indexing="xy")
        
        # Transpose ensures fast looping order matches column_index logic
        return (
            np.transpose(X1, (2, 0, 1)).ravel() * KM_TO_M,
            np.transpose(X2, (2, 0, 1)).ravel() * KM_TO_M,
            np.transpose(Y1, (2, 0, 1)).ravel() * KM_TO_M,
            np.transpose(Y2, (2, 0, 1)).ravel() * KM_TO_M,
            np.transpose(Z1, (2, 0, 1)).ravel() * KM_TO_M,
            np.transpose(Z2, (2, 0, 1)).ravel() * KM_TO_M
        )

    def voxel_centers_km(self) -> np.ndarray:
        xc = 0.5 * (self.x_edges_km[:-1] + self.x_edges_km[1:])
        yc = 0.5 * (self.y_edges_km[:-1] + self.y_edges_km[1:])
        zc = 0.5 * (self.z_edges_km[:-1] + self.z_edges_km[1:])
        X, Y, Z = np.meshgrid(xc, yc, zc, indexing="xy")
        return np.column_stack([
            np.transpose(X, (2, 0, 1)).ravel(),
            np.transpose(Y, (2, 0, 1)).ravel(),
            np.transpose(Z, (2, 0, 1)).ravel()
        ]).astype(np.float64)

    def voxel_column_index(self) -> np.ndarray:
        """返回每个体素所属的水平柱索引 (column index)"""
        nx, ny, nz = self.nx, self.ny, self.nz
        ix = np.tile(np.arange(nx, dtype=np.int32), ny)
        iy = np.repeat(np.arange(ny, dtype=np.int32), nx)
        col = iy * nx + ix
        return np.tile(col, nz).astype(np.int32)

@dataclass
class ObservationSet:
    x_km: np.ndarray
    y_km: np.ndarray
    z_km: np.ndarray
    d_mgal: np.ndarray

    @property
    def nobs(self) -> int: return len(self.d_mgal)

@dataclass
class SchemeBSurfaces:
    hW_km: np.ndarray
    hB_km: np.ndarray
    hUC_km: np.ndarray
    hM_km: np.ndarray

    @staticmethod
    def from_arrays(hW, hB, hUC, hM) -> "SchemeBSurfaces":
        return SchemeBSurfaces(
            np.asarray(hW, dtype=np.float64), np.asarray(hB, dtype=np.float64),
            np.asarray(hUC, dtype=np.float64), np.asarray(hM, dtype=np.float64)
        )
    
    @staticmethod
    def from_constant(ny, nx, hW, Ts, hM_arr, f_uc, tmin_uc, tmin_lc):
        hW_arr = np.full((ny, nx), float(hW), dtype=np.float64)
        hB_arr = hW_arr + float(Ts)
        Tc = hM_arr - hB_arr
        hUC = hB_arr + f_uc * Tc
        hUC = np.maximum(hUC, hB_arr + tmin_uc)
        hUC = np.minimum(hUC, hM_arr - tmin_lc)
        return SchemeBSurfaces(hW_arr, hB_arr, hUC, hM_arr.astype(np.float64))

# ==============================================================================
# 3. 物理引擎 (Physics Engine)
# ==============================================================================
def prism_gz_unit_density_mgal(obsx, obsy, obsz, x1, x2, y1, y2, z1, z2, eps=1e-12):
    """计算单位密度长方体对观测点的垂直重力异常 (Nagy/Blakely 公式)"""
    xs = np.stack([x1 - obsx, x2 - obsx], axis=0)
    ys = np.stack([y1 - obsy, y2 - obsy], axis=0)
    zs = np.stack([z1 - obsz, z2 - obsz], axis=0)
    gz = np.zeros_like(x1, dtype=np.float64)

    for i in (0, 1):
        x = xs[i]
        for j in (0, 1):
            y = ys[j]
            for k in (0, 1):
                z = zs[k]
                r = np.sqrt(x*x + y*y + z*z)
                term = (x * np.log(np.maximum(y + r, eps)) +
                        y * np.log(np.maximum(x + r, eps)) -
                        z * np.arctan2(x * y, z * r))
                sgn = -1.0 if ((i + j + k) % 2 == 1) else 1.0
                gz += sgn * term
    return (G_SI * gz * SI_TO_MGAL).astype(np.float64)

class GravityForward:
    def __init__(self, grid: VoxelGrid, obs: ObservationSet, kernel_dtype=np.float32):
        self.grid = grid
        self.obs = obs
        self.K = self._build_kernel(dtype=kernel_dtype)

    def _build_kernel(self, dtype) -> np.ndarray:
        x1, x2, y1, y2, z1, z2 = self.grid.voxel_bounds_m()
        K = np.empty((self.obs.nobs, self.grid.nvox), dtype=np.float64)
        for j in range(self.obs.nobs):
            ox = float(self.obs.x_km[j]) * KM_TO_M
            oy = float(self.obs.y_km[j]) * KM_TO_M
            oz = float(self.obs.z_km[j]) * KM_TO_M
            K[j, :] = prism_gz_unit_density_mgal(ox, oy, oz, x1, x2, y1, y2, z1, z2)
        return K.astype(dtype)

    def predict(self, drho_vox: np.ndarray) -> np.ndarray:
        return (self.K @ drho_vox.astype(self.K.dtype)).astype(np.float64)

def precompute_layer_code_schemeB(grid: VoxelGrid, surfaces: SchemeBSurfaces) -> np.ndarray:
    """计算每个体素属于哪一层 (0:UC, 1:LC, 2:Mantle, -1:Other)"""
    centers = grid.voxel_centers_km()
    zc = centers[:, 2]
    col = grid.voxel_column_index()
    
    # Flatten surfaces
    hB_v  = surfaces.hB_km.ravel()[col]
    hUC_v = surfaces.hUC_km.ravel()[col]
    hM_v  = surfaces.hM_km.ravel()[col]

    layer = np.full(grid.nvox, -1, dtype=np.int8)
    layer[(zc >= hB_v)  & (zc < hUC_v)] = 0
    layer[(zc >= hUC_v) & (zc < hM_v)]  = 1
    layer[zc >= hM_v] = 2
    return layer

# ==============================================================================
# 4. 贝叶斯模型 (Bayesian Model)
# ==============================================================================
@dataclass
class VoronoiModel:
    seeds_xy_km: np.ndarray
    drho_kgm3: np.ndarray
    sigma_mgal: float
    @property
    def K(self) -> int: return self.seeds_xy_km.shape[0]

@dataclass
class PriorConfig:
    K_min: int = PRIOR_K_MIN
    K_max: int = PRIOR_K_MAX
    sigma_rho_uc: float = SIGMA_RHO_UC
    sigma_rho_lc: float = SIGMA_RHO_LC
    sigma_rho_m: float  = SIGMA_RHO_M
    sigma_min_mgal: float = SIGMA_MIN_MGAL
    sigma_max_mgal: float = SIGMA_MAX_MGAL

@dataclass
class ProposalConfig:
    p_birth: float = P_BIRTH
    p_death: float = P_DEATH
    p_move: float = P_MOVE
    p_value: float = P_VALUE
    p_sigma: float = P_SIGMA
    step_xy_km: float = STEP_XY_KM
    step_rho: float = STEP_RHO
    step_log_sigma: float = STEP_LOG_SIGMA

def log_prior(model: VoronoiModel, bounds: Tuple[float,float,float,float], prior: PriorConfig) -> float:
    xmin, xmax, ymin, ymax = bounds
    K = model.K
    if not (prior.K_min <= K <= prior.K_max): return -np.inf
    
    sx, sy = model.seeds_xy_km[:, 0], model.seeds_xy_km[:, 1]
    if np.any(sx < xmin) or np.any(sx > xmax) or np.any(sy < ymin) or np.any(sy > ymax):
        return -np.inf

    sig = float(model.sigma_mgal)
    if not (prior.sigma_min_mgal <= sig <= prior.sigma_max_mgal): return -np.inf

    # Density priors (Gaussian)
    lp = 0.0
    dr = model.drho_kgm3
    lp += -0.5*np.sum((dr[:,0]/prior.sigma_rho_uc)**2) - K*math.log(prior.sigma_rho_uc*math.sqrt(2*math.pi))
    lp += -0.5*np.sum((dr[:,1]/prior.sigma_rho_lc)**2) - K*math.log(prior.sigma_rho_lc*math.sqrt(2*math.pi))
    lp += -0.5*np.sum((dr[:,2]/prior.sigma_rho_m )**2) - K*math.log(prior.sigma_rho_m *math.sqrt(2*math.pi))
    lp += -math.log(sig) # Jeffreys prior for sigma
    return lp

def log_likelihood_gaussian(d_obs, d_pred, sigma_mgal):
    r = d_obs - d_pred
    sig2 = sigma_mgal**2
    return -0.5 * (np.sum(r*r)/sig2 + r.size*math.log(2*math.pi*sig2))

def voronoi_assign_cells(seeds_xy, col_xy):
    if HAVE_SCIPY:
        tree = cKDTree(seeds_xy)
        _, idx = tree.query(col_xy, k=1)
        return idx.astype(np.int32)
    dx = col_xy[:,None,0] - seeds_xy[None,:,0]
    dy = col_xy[:,None,1] - seeds_xy[None,:,1]
    return np.argmin(dx*dx + dy*dy, axis=1).astype(np.int32)

def model_to_drho_vox(model: VoronoiModel, col_xy, col_index_of_vox, layer_code):
    cell_id = voronoi_assign_cells(model.seeds_xy_km, col_xy)
    drho_vox = np.zeros(layer_code.size, dtype=np.float64)
    valid = layer_code >= 0
    if np.any(valid):
        col_of_v = col_index_of_vox[valid]
        cell_of_v = cell_id[col_of_v]
        layer_of_v = layer_code[valid].astype(np.int32)
        drho_vox[valid] = model.drho_kgm3[cell_of_v, layer_of_v]
    return drho_vox

# ==============================================================================
# 5. RJMCMC 采样器 (The Sampler)
# ==============================================================================
class RJMCMCSampler:
    def __init__(self, rng, bounds, prior, prop, forward, col_xy, col_index_of_vox, layer_code, d_obs_mgal):
        self.rng = rng
        self.bounds = bounds
        self.prior = prior
        self.prop = prop
        self.forward = forward
        self.col_xy = col_xy
        self.col_index_of_vox = col_index_of_vox
        self.layer_code = layer_code
        self.d_obs = d_obs_mgal
        self.tot = {k:0 for k in ["birth", "death", "move", "value", "sigma"]}
        self.acc = {k:0 for k in ["birth", "death", "move", "value", "sigma"]}

    def _draw_move_type(self, K):
        keys = ["birth","death","move","value","sigma"]
        w = np.array([self.prop.p_birth, self.prop.p_death, self.prop.p_move, self.prop.p_value, self.prop.p_sigma])
        if K <= self.prior.K_min: w[1] = 0.0
        if K >= self.prior.K_max: w[0] = 0.0
        if w.sum() <= 0: return "value"
        return self.rng.choice(keys, p=w/w.sum())

    def _gauss_logpdf(self, x, sigma):
        return -0.5*(x/sigma)**2 - math.log(sigma * math.sqrt(2.0*math.pi))

    def step(self, cur: VoronoiModel, cur_lp, cur_ll):
        mv = self._draw_move_type(cur.K)
        prop_model, log_qratio = cur, 0.0
        
        # Proposal Logic
        if mv == "birth":
            x_new = self.rng.uniform(self.bounds[0], self.bounds[1])
            y_new = self.rng.uniform(self.bounds[2], self.bounds[3])
            seeds_new = np.vstack([cur.seeds_xy_km, [x_new, y_new]])
            dr_new_row = np.array([
                self.rng.normal(0, self.prior.sigma_rho_uc),
                self.rng.normal(0, self.prior.sigma_rho_lc),
                self.rng.normal(0, self.prior.sigma_rho_m)
            ])
            dr_new = np.vstack([cur.drho_kgm3, dr_new_row])
            prop_model = VoronoiModel(seeds_new, dr_new, cur.sigma_mgal)
            
            pb = max(self.prop.p_birth, 1e-12)
            pd = max(self.prop.p_death, 1e-12)
            logpdf_new = sum(self._gauss_logpdf(dr_new_row[i], s) for i,s in enumerate(
                [self.prior.sigma_rho_uc, self.prior.sigma_rho_lc, self.prior.sigma_rho_m]))
            log_qratio = math.log(pd) - math.log(pb) - math.log(cur.K + 1) - logpdf_new

        elif mv == "death":
            idx = int(self.rng.integers(0, cur.K))
            dr_rem = cur.drho_kgm3[idx]
            seeds_new = np.delete(cur.seeds_xy_km, idx, axis=0)
            dr_new = np.delete(cur.drho_kgm3, idx, axis=0)
            prop_model = VoronoiModel(seeds_new, dr_new, cur.sigma_mgal)
            
            pb = max(self.prop.p_birth, 1e-12)
            pd = max(self.prop.p_death, 1e-12)
            logpdf_rem = sum(self._gauss_logpdf(dr_rem[i], s) for i,s in enumerate(
                [self.prior.sigma_rho_uc, self.prior.sigma_rho_lc, self.prior.sigma_rho_m]))
            log_qratio = math.log(pb) - math.log(pd) + math.log(cur.K) + logpdf_rem

        elif mv == "move":
            seeds = cur.seeds_xy_km.copy()
            i = int(self.rng.integers(0, cur.K))
            seeds[i,0] = np.clip(seeds[i,0] + self.rng.normal(0, self.prop.step_xy_km), self.bounds[0], self.bounds[1])
            seeds[i,1] = np.clip(seeds[i,1] + self.rng.normal(0, self.prop.step_xy_km), self.bounds[2], self.bounds[3])
            prop_model = VoronoiModel(seeds, cur.drho_kgm3.copy(), cur.sigma_mgal)
            log_qratio = 0.0

        elif mv == "value":
            dr = cur.drho_kgm3.copy()
            i = int(self.rng.integers(0, cur.K))
            k = int(self.rng.integers(0, 3))
            dr[i,k] += self.rng.normal(0, self.prop.step_rho)
            prop_model = VoronoiModel(cur.seeds_xy_km.copy(), dr, cur.sigma_mgal)
            log_qratio = 0.0

        elif mv == "sigma":
            log_sig_p = math.log(cur.sigma_mgal) + self.rng.normal(0, self.prop.step_log_sigma)
            prop_model = VoronoiModel(cur.seeds_xy_km.copy(), cur.drho_kgm3.copy(), math.exp(log_sig_p))
            log_qratio = log_sig_p - math.log(cur.sigma_mgal)

        self.tot[mv] += 1
        
        # Accept/Reject
        prop_lp = log_prior(prop_model, self.bounds, self.prior)
        if not np.isfinite(prop_lp): return cur, cur_lp, cur_ll, mv, False

        drho_vox = model_to_drho_vox(prop_model, self.col_xy, self.col_index_of_vox, self.layer_code)
        d_pred = self.forward.predict(drho_vox)
        prop_ll = log_likelihood_gaussian(self.d_obs, d_pred, prop_model.sigma_mgal)
        if not np.isfinite(prop_ll): return cur, cur_lp, cur_ll, mv, False

        log_alpha = (prop_ll + prop_lp) - (cur_ll + cur_lp) + log_qratio
        if math.log(self.rng.uniform(0, 1) + 1e-300) < log_alpha:
            self.acc[mv] += 1
            return prop_model, prop_lp, prop_ll, mv, True
        return cur, cur_lp, cur_ll, mv, False

# ==============================================================================
# 6. IO 与辅助函数 (IO & Utils)
# ==============================================================================
def load_obs_csv(path: Path) -> ObservationSet:
    df = pd.read_csv(path)
    if not all(c in df.columns for c in ["x_km","y_km","d_mgal"]):
        raise ValueError(f"CSV missing columns: {path}")
    z = df["z_km"].to_numpy() if "z_km" in df.columns else np.zeros(len(df))
    return ObservationSet(df["x_km"].values, df["y_km"].values, z, df["d_mgal"].values)

def load_schemeB_surfaces_nc(path: Path):
    ds = xr.open_dataset(path)
    return (ds["x_km"].values, ds["y_km"].values, 
            ds["water_depth_km"].values, ds["basement_depth_km"].values,
            ds["uc_base_depth_km"].values, ds["moho_depth_km"].values)

def make_demo_problem(rng):
    """生成测试用 Demo 数据"""
    xmin, xmax, ymin, ymax = 0., 160., 0., 160.
    dx, dy, dz = 16., 16., 3.
    x_edges = np.arange(xmin, xmax + 0.5*dx, dx)
    y_edges = np.arange(ymin, ymax + 0.5*dy, dy)
    z_edges = np.arange(0., 60. + 0.5*dz, dz)
    grid = VoxelGrid(x_edges, y_edges, z_edges)
    
    col_xy = grid.column_centers_km()
    X, Y = col_xy[:,0].reshape(grid.ny, grid.nx), col_xy[:,1].reshape(grid.ny, grid.nx)
    hM = 25. + 0.02*(X-80) + 0.02*(Y-80) + 4.*np.exp(-((X-95)**2+(Y-60)**2)/(2*25**2))
    
    surfaces = SchemeBSurfaces.from_constant(grid.ny, grid.nx, 3.0, 2.0, hM, 0.55, 3.0, 3.0)
    
    # True Model
    K_true = 7
    seeds = np.column_stack([rng.uniform(xmin, xmax, K_true), rng.uniform(ymin, ymax, K_true)])
    dr = np.column_stack([rng.normal(0,120,K_true), rng.normal(0,120,K_true), rng.normal(0,80,K_true)])
    dr[0,1] += 250.; dr[3,2] -= 180.
    m_true = VoronoiModel(seeds, dr, 0.30)
    
    # Obs
    obs_x, obs_y = col_xy[:,0], col_xy[:,1]
    fwd = GravityForward(grid, ObservationSet(obs_x, obs_y, np.zeros_like(obs_x), np.zeros_like(obs_x)))
    lc = precompute_layer_code_schemeB(grid, surfaces)
    ci = grid.voxel_column_index()
    d_clean = fwd.predict(model_to_drho_vox(m_true, col_xy, ci, lc))
    d_noisy = d_clean + rng.normal(0, 0.30, size=d_clean.size)
    obs = ObservationSet(obs_x, obs_y, np.zeros_like(obs_x), d_noisy)
    
    return grid, surfaces, obs, m_true, (xmin,xmax,ymin,ymax)

# ==============================================================================
# 7. 可视化系统 (Visualization System)
# ==============================================================================
def set_publication_style():
    plt.rcParams.update({
        "font.family": "serif", "font.serif": ["Times New Roman"],
        "mathtext.fontset": "stix", "font.size": 12, "axes.labelsize": 14,
        "axes.titlesize": 16, "savefig.dpi": 300, "xtick.direction": "in", "ytick.direction": "in"
    })

def _xy_edges_from_centers(xc, yc):
    dx, dy = np.median(np.diff(xc)), np.median(np.diff(yc))
    return (np.concatenate([[xc[0]-0.5*dx], xc+0.5*dx]), 
            np.concatenate([[yc[0]-0.5*dy], yc+0.5*dy]))

# ----------------- 图 1-12 核心绘图函数 -----------------
def plot_fig1_ppd(outdir, ppd_npz):
    dat = np.load(ppd_npz)
    dr, zmax = dat["dr_layers"], float(dat["zmax_km"])
    hB, hUC, hM = float(dat["hB_km"]), float(dat["hUC_km"]), float(dat["hM_km"])
    depths = np.linspace(0, zmax, PPD_NDEPTH)
    
    # 统计 PPD
    ppd_grid = np.zeros((PPD_NBINS, len(depths)))
    stats_mean, ci_l, ci_h = np.zeros_like(depths), np.zeros_like(depths), np.zeros_like(depths)
    val_edges = np.linspace(PPD_RHO_MIN, PPD_RHO_MAX, PPD_NBINS+1)
    
    true_model = np.zeros_like(depths) if "true_layers" in dat else None
    true_vals = dat["true_layers"] if true_model is not None else None

    for i, z in enumerate(depths):
        idx = 0 if z<hB else (0 if z<hUC else (1 if z<hM else 2))
        vals = dr[:, idx] if z >= hB else np.zeros(dr.shape[0])
        stats_mean[i], ci_l[i], ci_h[i] = np.mean(vals), np.percentile(vals, 2.5), np.percentile(vals, 97.5)
        hist, _ = np.histogram(vals, bins=val_edges)
        if hist.max() > 0: ppd_grid[:, i] = hist / hist.max()
        if true_model is not None: true_model[i] = true_vals[idx] if z >= hB else 0.0

    fig, ax = plt.subplots(figsize=(8, 6))
    dz = depths[1]-depths[0]
    d_edges = np.concatenate([[depths[0]-0.5*dz], depths+0.5*dz])
    im = ax.pcolormesh(d_edges, val_edges, ppd_grid, cmap="OrRd", shading='flat', rasterized=True)
    ax.plot(depths, stats_mean, "k-", lw=2, label="Mean")
    ax.plot(depths, ci_l, "k--", lw=1, alpha=0.6); ax.plot(depths, ci_h, "k--", lw=1, alpha=0.6)
    if true_model is not None: ax.plot(depths, true_model, "b-.", lw=2, label="True")
    
    ax.set(xlabel="Depth (km)", ylabel=r"$\Delta \rho$ (kg/m$^3$)", title="PPD")
    plt.colorbar(im, ax=ax, label="Normalized Probability")
    ax.legend()
    fig.savefig(outdir / "fig1_ppd.png"); plt.close(fig)

def plot_fig2_k_posterior(outdir, trace_npz):
    K = np.load(trace_npz)["K"]
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.hist(K, bins=np.arange(K.min()-0.5, K.max()+1.5, 1), rwidth=0.8, color='gray', edgecolor='k')
    ax.set(xlabel="k", ylabel="Frequency", title="Posterior Dimension k")
    fig.savefig(outdir / "fig2_k_posterior.png"); plt.close(fig)

def plot_fig3_data_fit(outdir, postpred_npz):
    d = np.load(postpred_npz)
    obs, pred, std, sig = d["d_obs"], d["pred_mean"], d["pred_std"], float(d["sigma_mean"])
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True, gridspec_kw={'height_ratios':[2,1]})
    
    ax1.plot(obs, 'ko', ms=3, alpha=0.6, label='Obs')
    ax1.plot(pred, 'r-', lw=1.5, label='Pred Mean')
    tot_std = np.sqrt(std**2 + sig**2)
    ax1.fill_between(range(len(obs)), pred-1.96*tot_std, pred+1.96*tot_std, color='r', alpha=0.2)
    ax1.legend()
    
    resid = obs - pred
    ax2.plot(resid, 'k-', lw=1); ax2.axhline(0, c='r', ls='--')
    ax2.axhline(sig, c='gray', ls=':'); ax2.axhline(-sig, c='gray', ls=':')
    ax2.set_ylabel("Resid (mGal)")
    
    fig.savefig(outdir / "fig3_data_fit.png"); plt.close(fig)

def plot_fig4_interfaces(outdir, nc_path):
    ds = xr.open_dataset(nc_path)
    fig, ax = plt.subplots(figsize=(6, 8))
    bins = np.linspace(0, 60, 100)
    for name, c in zip(["basement_depth_km","uc_base_depth_km","moho_depth_km"], ['#1f77b4','#ff7f0e','#2ca02c']):
        if name in ds:
            ax.hist(ds[name].values.ravel(), bins=bins, orientation='horizontal', histtype='step', lw=2, label=name, color=c)
    ax.invert_yaxis(); ax.legend(); ax.set_title("Interface Depths")
    fig.savefig(outdir / "fig4_interface_depth_statistics.png"); plt.close(fig)

def plot_fig5_6_9_maps(outdir, nc_path):
    ds = xr.open_dataset(nc_path)
    xe, ye = _xy_edges_from_centers(ds["x_km"].values, ds["y_km"].values)
    
    # Fig 5: Means
    fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))
    all_v = np.concatenate([ds[v].values.ravel() for v in ["drho_uc_mean","drho_lc_mean","drho_m_mean"]])
    norm = mcolors.Normalize(vmin=-np.max(np.abs(all_v)), vmax=np.max(np.abs(all_v)))
    for ax, v in zip(axes, ["drho_uc_mean","drho_lc_mean","drho_m_mean"]):
        im = ax.pcolormesh(xe, ye, ds[v].values, cmap="RdBu_r", norm=norm, shading='flat')
        ax.set_aspect('equal'); ax.set_title(v)
    fig.colorbar(im, ax=axes, label="Mean Density"); fig.savefig(outdir / "fig5_posterior_mean_model.png"); plt.close(fig)

    # Fig 6: Stds
    if "drho_uc_std" in ds:
        fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))
        all_s = np.concatenate([ds[v].values.ravel() for v in ["drho_uc_std","drho_lc_std","drho_m_std"]])
        for ax, v in zip(axes, ["drho_uc_std","drho_lc_std","drho_m_std"]):
            im = ax.pcolormesh(xe, ye, ds[v].values, cmap="inferno_r", vmin=0, vmax=np.percentile(all_s, 99), shading='flat')
            ax.set_aspect('equal'); ax.set_title(v)
        fig.colorbar(im, ax=axes, label="Std Dev"); fig.savefig(outdir / "fig6_posterior_std_model.png"); plt.close(fig)

    # Fig 9: Nuclei
    if "nuclei_density" in ds:
        fig, ax = plt.subplots(figsize=(7, 6))
        im = ax.pcolormesh(xe, ye, ds["nuclei_density"].values, cmap="magma", shading='flat')
        ax.set_aspect('equal'); ax.set_title("Nuclei Density")
        plt.colorbar(im, ax=ax); fig.savefig(outdir / "fig9_nuclei_density.png"); plt.close(fig)

def plot_fig7_convergence(outdir, trace_npz):
    tr = np.load(trace_npz)
    fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
    axes[0].plot(tr["iter"], tr["loglike"], 'k-', lw=0.8); axes[0].set_ylabel("LogLike")
    axes[1].step(tr["iter"], tr["K"], where='post', color='r'); axes[1].set_ylabel("K")
    axes[2].plot(tr["iter"], tr["sigma"], 'b-'); axes[2].set_ylabel(r"$\sigma$")
    fig.savefig(outdir / "fig7_convergence_traces.png"); plt.close(fig)

def plot_fig8_spatial_resid(outdir, postpred_npz, obs_csv):
    d = np.load(postpred_npz)
    df = pd.read_csv(obs_csv)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    sc1 = ax1.scatter(df["x_km"], df["y_km"], c=d["d_obs"]-d["pred_mean"], cmap="RdBu_r", s=20)
    plt.colorbar(sc1, ax=ax1, label="Resid")
    ax1.set_title("Spatial Residuals")
    
    sc2 = ax2.scatter(df["x_km"], df["y_km"], c=d["pred_std"], cmap="viridis", s=20)
    plt.colorbar(sc2, ax=ax2, label="Uncertainty")
    ax2.set_title("Predictive Std")
    fig.savefig(outdir / "fig8_spatial_residuals.png"); plt.close(fig)

def plot_fig12_complexity(outdir, trace_npz):
    tr = np.load(trace_npz)
    fig, ax = plt.subplots(figsize=(8, 6))
    sc = ax.scatter(tr["K"]+np.random.uniform(-0.2,0.2,len(tr["K"])), tr["loglike"], c=tr["iter"], cmap="viridis", s=10, alpha=0.5)
    plt.colorbar(sc, ax=ax, label="Iter")
    ax.set(xlabel="K", ylabel="LogL", title="Occam's Razor")
    fig.savefig(outdir / "fig12_complexity_tradeoff.png"); plt.close(fig)

def plot_multichain_traces(outdir, chain_dirs):
    if len(chain_dirs) < 2: return
    fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
    colors = plt.cm.jet(np.linspace(0, 1, len(chain_dirs)))
    for i, cd in enumerate(chain_dirs):
        if not (cd/"trace.npz").exists(): continue
        tr = np.load(cd/"trace.npz")
        axes[0].plot(tr["iter"], tr["loglike"], c=colors[i], alpha=0.6, lw=0.8)
        axes[1].plot(tr["iter"], tr["K"] + np.random.uniform(-0.2,0.2,len(tr["K"])), c=colors[i], alpha=0.6, lw=0.8)
        axes[2].plot(tr["iter"], tr["sigma"], c=colors[i], alpha=0.6, lw=0.8)
    axes[0].set_title("Multi-chain Convergence"); fig.savefig(outdir / "fig_multichain.png"); plt.close(fig)

def generate_all_figures(outdir, trace_path, ppd_path, postpred_path, nc_path, obs_path):
    print("Generating Figures...")
    set_publication_style()
    plot_fig1_ppd(outdir, ppd_path)
    plot_fig2_k_posterior(outdir, trace_path)
    plot_fig3_data_fit(outdir, postpred_path)
    plot_fig4_interfaces(outdir, nc_path)
    plot_fig5_6_9_maps(outdir, nc_path)
    plot_fig7_convergence(outdir, trace_path)
    plot_fig8_spatial_resid(outdir, postpred_path, obs_path)
    plot_fig12_complexity(outdir, trace_path)

# ==============================================================================
# 8. MCMC 主引擎 (MCMC Engine)
# ==============================================================================
def initial_model_random(rng, bounds, prior, K0=4, sigma0=0.6) -> VoronoiModel:
    seeds = np.column_stack([rng.uniform(bounds[0], bounds[1], K0), rng.uniform(bounds[2], bounds[3], K0)])
    dr = np.column_stack([rng.normal(0, prior.sigma_rho_uc, K0), 
                          rng.normal(0, prior.sigma_rho_lc, K0), 
                          rng.normal(0, prior.sigma_rho_m, K0)])
    return VoronoiModel(seeds, dr, float(sigma0))

def run_mcmc(outdir, grid, surfaces, obs, obs_csv_path, bounds, seed, true_model=None):
    outdir.mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(seed)
    
    # Setup
    col_xy = grid.column_centers_km()
    col_idx = grid.voxel_column_index()
    layer_code = precompute_layer_code_schemeB(grid, surfaces)
    fwd = GravityForward(grid, obs)
    
    # Configs
    prior = PriorConfig()
    prop = ProposalConfig()
    
    # Init
    cur = initial_model_random(rng, bounds, prior)
    cur_lp = log_prior(cur, bounds, prior)
    if not np.isfinite(cur_lp): raise RuntimeError("Bad init prior")
    cur_ll = log_likelihood_gaussian(obs.d_mgal, fwd.predict(model_to_drho_vox(cur, col_xy, col_idx, layer_code)), cur.sigma_mgal)
    
    sampler = RJMCMCSampler(rng, bounds, prior, prop, fwd, col_xy, col_idx, layer_code, obs.d_mgal)
    
    # Storage & Online Stats (Welford)
    traces = {"iter":[], "K":[], "sigma":[], "loglike":[]}
    ppd_samples = []
    
    # PPD Column selection
    ix, iy = (PPD_IX, PPD_IY) if PPD_COLUMN_MODE=="manual" else (grid.nx//2, grid.ny//2)
    ppd_col_id = iy * grid.nx + ix
    
    # Online accumulators
    n_post = 0
    sum_dr = [np.zeros(len(col_xy)), np.zeros(len(col_xy)), np.zeros(len(col_xy))]
    sq_dr  = [np.zeros(len(col_xy)), np.zeros(len(col_xy)), np.zeros(len(col_xy))]
    nuclei_map = np.zeros((grid.ny, grid.nx), dtype=np.int32)
    sum_pred, sq_pred, sum_sig = np.zeros(obs.nobs), np.zeros(obs.nobs), 0.0

    print(f"Start MCMC: {N_ITER} iters, Seed={seed}")
    t0 = time.time()
    
    for it in range(1, N_ITER+1):
        cur, cur_lp, cur_ll, _, _ = sampler.step(cur, cur_lp, cur_ll)
        
        if it % 1000 == 0:
            print(f"Iter {it}/{N_ITER}, K={cur.K}, LL={cur_ll:.1f}, Time={time.time()-t0:.1f}s")
            
        if it > BURN_IN and (it - BURN_IN) % THIN == 0:
            # Trace
            traces["iter"].append(it); traces["K"].append(cur.K)
            traces["sigma"].append(cur.sigma_mgal); traces["loglike"].append(cur_ll)
            
            # Spatial Stats
            cell_ids = voronoi_assign_cells(cur.seeds_xy_km, col_xy)
            for k in range(3):
                val = cur.drho_kgm3[cell_ids, k]
                sum_dr[k] += val; sq_dr[k] += val**2
            
            # Nuclei Density
            H, _, _ = np.histogram2d(cur.seeds_xy_km[:,1], cur.seeds_xy_km[:,0], bins=[grid.y_edges_km, grid.x_edges_km])
            nuclei_map += H.astype(np.int32)
            
            # Predictive
            d_pred = fwd.predict(model_to_drho_vox(cur, col_xy, col_idx, layer_code))
            sum_pred += d_pred; sq_pred += d_pred**2; sum_sig += cur.sigma_mgal
            
            # PPD Sample
            ppd_samples.append(cur.drho_kgm3[cell_ids[ppd_col_id], :])
            n_post += 1

    # Save Results
    # 1. Trace
    np.savez_compressed(outdir/"trace.npz", **{k:np.array(v) for k,v in traces.items()})
    
    # 2. NetCDF Stats
    if n_post > 0:
        means = [s/n_post for s in sum_dr]
        stds  = [np.sqrt(np.maximum(0, q/n_post - m**2)) for s, q, m in zip(sum_dr, sq_dr, means)]
        ds_out = xr.Dataset(
            data_vars={
                **{f"drho_{tag}_mean":(("y_km","x_km"), m.reshape(grid.ny,grid.nx)) for tag,m in zip(["uc","lc","m"], means)},
                **{f"drho_{tag}_std":(("y_km","x_km"), s.reshape(grid.ny,grid.nx)) for tag,s in zip(["uc","lc","m"], stds)},
                "nuclei_density": (("y_km","x_km"), nuclei_map/n_post),
                "basement_depth_km": (("y_km","x_km"), surfaces.hB_km),
                "uc_base_depth_km": (("y_km","x_km"), surfaces.hUC_km),
                "moho_depth_km": (("y_km","x_km"), surfaces.hM_km)
            },
            coords={"x_km": 0.5*(grid.x_edges_km[1:]+grid.x_edges_km[:-1]), "y_km": 0.5*(grid.y_edges_km[1:]+grid.y_edges_km[:-1])}
        )
        ds_out.to_netcdf(outdir/"posterior_mean_columns.nc")
        
        # 3. Post Pred
        p_mean = sum_pred/n_post
        p_std = np.sqrt(np.maximum(0, sq_pred/n_post - p_mean**2))
        np.savez_compressed(outdir/"posterior_predictive.npz", d_obs=obs.d_mgal, pred_mean=p_mean, pred_std=p_std, sigma_mean=sum_sig/n_post)
        
        # 4. PPD Samples
        save_dict = {"dr_layers": np.stack(ppd_samples), "hB_km":surfaces.hB_km[iy,ix], "hUC_km":surfaces.hUC_km[iy,ix], "hM_km":surfaces.hM_km[iy,ix], "zmax_km": grid.z_edges_km[-1]}
        if true_model:
             t_cid = voronoi_assign_cells(true_model.seeds_xy_km, col_xy)[ppd_col_id]
             save_dict["true_layers"] = true_model.drho_kgm3[t_cid,:]
        np.savez_compressed(outdir/"ppd_column_samples.npz", **save_dict)
        
        # Plot
        generate_all_figures(outdir, outdir/"trace.npz", outdir/"ppd_column_samples.npz", outdir/"posterior_predictive.npz", outdir/"posterior_mean_columns.nc", obs_csv_path)

# ==============================================================================
# 9. 主入口 (Main Entry)
# ==============================================================================
def _single_chain_task(chain_id, base_outdir, common_kwargs):
    chain_dir = base_outdir / f"chain_{chain_id}"
    seed = common_kwargs["seed"] + chain_id + 999
    print(f"--- Chain {chain_id} Started ---")
    try:
        run_mcmc(outdir=chain_dir, **{**common_kwargs, "seed": seed})
        return chain_dir
    except Exception:
        traceback.print_exc()
        return None

def main():
    print(f"=== RJMCMC Gravity Inversion ({MODE.upper()} Mode) ===")
    
    # 准备数据
    if MODE == "demo":
        grid, surfaces, obs, m_true, bounds = make_demo_problem(np.random.default_rng(SEED))
        obs_path_fake = OUTDIR / "demo_dummy.csv"
        obs_path_fake.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame({"x_km":obs.x_km,"y_km":obs.y_km,"d_mgal":obs.d_mgal}).to_csv(obs_path_fake, index=False)
        kwargs = {"grid":grid, "surfaces":surfaces, "obs":obs, "obs_csv_path":obs_path_fake, "bounds":bounds, "seed":SEED, "true_model":m_true}
    else:
        obs = load_obs_csv(OBS_CSV)
        x, y, hW, hB, hUC, hM = load_schemeB_surfaces_nc(SCHEMEB_NC)
        dx, dy = np.median(np.diff(x)), np.median(np.diff(y))
        grid = VoxelGrid(np.concatenate([[x[0]-0.5*dx], x+0.5*dx]), np.concatenate([[y[0]-0.5*dy], y+0.5*dy]), np.arange(0, ZMAX_KM+0.5*DZ_KM, DZ_KM))
        surfaces = SchemeBSurfaces.from_arrays(hW, hB, hUC, hM)
        bounds = (grid.x_edges_km[0], grid.x_edges_km[-1], grid.y_edges_km[0], grid.y_edges_km[-1])
        kwargs = {"grid":grid, "surfaces":surfaces, "obs":obs, "obs_csv_path":OBS_CSV, "bounds":bounds, "seed":SEED}

    # 并行运行多链
    chain_dirs = Parallel(n_jobs=N_JOBS)(
        delayed(_single_chain_task)(i, OUTDIR, kwargs) for i in range(N_CHAINS)
    )
    
    # 汇总绘图
    valid_chains = [d for d in chain_dirs if d is not None]
    if valid_chains:
        plot_multichain_traces(OUTDIR, valid_chains)
    print("=== All Done ===")

if __name__ == "__main__":
    main()

=== RJMCMC Gravity Inversion (REAL Mode) ===
--- Chain 0 Started ---
Start MCMC: 50000 iters, Seed=20252224
Iter 1000/50000, K=7, LL=-2978.3, Time=0.7s
Iter 2000/50000, K=8, LL=-1762.3, Time=1.5s
Iter 3000/50000, K=10, LL=-1692.3, Time=2.2s
Iter 4000/50000, K=10, LL=-1605.6, Time=2.9s
Iter 5000/50000, K=10, LL=-1543.6, Time=3.6s
Iter 6000/50000, K=10, LL=-1516.4, Time=4.4s
Iter 7000/50000, K=10, LL=-1495.6, Time=5.2s
Iter 8000/50000, K=10, LL=-1466.3, Time=6.0s
Iter 9000/50000, K=10, LL=-1454.7, Time=6.8s
Iter 10000/50000, K=10, LL=-1436.9, Time=7.5s
Iter 11000/50000, K=11, LL=-1393.0, Time=8.3s
Iter 12000/50000, K=10, LL=-1379.6, Time=9.1s
Iter 13000/50000, K=10, LL=-1380.7, Time=9.9s
Iter 14000/50000, K=10, LL=-1366.6, Time=10.7s
Iter 15000/50000, K=10, LL=-1356.1, Time=11.4s
Iter 16000/50000, K=10, LL=-1354.9, Time=12.2s
Iter 17000/50000, K=10, LL=-1352.3, Time=12.9s
Iter 18000/50000, K=10, LL=-1349.7, Time=13.7s
Iter 19000/50000, K=10, LL=-1343.6, Time=14.4s
Iter 20000/50000, K=10,

In [2]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from scipy.spatial import cKDTree
from pathlib import Path

# ================= 配置 =================
# 输入文件路径（请根据实际情况修改）
NC_PATH = Path(r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\schemeB_surfaces_mock.nc")
NPZ_PATH = Path(r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\truth_model_mock.npz")
# 输出图片路径
OUT_PATH = Path(r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\12_30\comparison_01_true_model.png")

# ***关键***：设置统一的色标范围，以便与结果图对比
# 根据经验或数据范围设置，例如 +/- 300 或 +/- 400 kg/m3
VMIN = -350
VMAX = 350
CMAP = "RdBu_r" # 红蓝反向色标，适合展示正负异常
# =======================================

def plot_true_model():
    print(f"Loading geometry from {NC_PATH}...")
    try:
        ds_surf = xr.open_dataset(NC_PATH)
        x_edges = ds_surf.x_km.values
        y_edges = ds_surf.y_km.values
    except FileNotFoundError:
        print(f"[Error] Needed file not found: {NC_PATH}")
        return

    # 计算网格中心点
    xc = 0.5 * (x_edges[:-1] + x_edges[1:])
    yc = 0.5 * (y_edges[:-1] + y_edges[1:])
    X, Y = np.meshgrid(xc, yc)
    grid_points = np.column_stack([X.ravel(), Y.ravel()])
    ny, nx = X.shape

    print(f"Loading true model definition from {NPZ_PATH}...")
    try:
        truth = np.load(NPZ_PATH)
        # 假设 npz 文件包含 'seeds_xy_km' 和 'drho_kgm3' 这两个键
        # 如果你的文件名不一样，请在这里修改
        seeds = truth['seeds_xy_km']     # Shape: (K, 2)
        drho_vals = truth['drho_kgm3']   # Shape: (K, 3)
    except FileNotFoundError:
        print(f"[Error] Needed file not found: {NPZ_PATH}")
        return
    except KeyError as e:
        print(f"[Error] Key not found in NPZ file: {e}. Please check file structure.")
        print(f"Available keys: {list(truth.keys())}")
        return

    print("Reconstructing Voronoi map on grid...")
    # 使用 KDTree 快速找到每个网格点最近的种子点
    tree = cKDTree(seeds)
    _, nearest_seed_idx = tree.query(grid_points, k=1)

    # 将每个Voronoi单元的密度值映射回网格
    # drho_vals[nearest_seed_idx] 的形状是 (N_grid_points, 3)
    # 我们需要把它 reshape 成 (ny, nx, 3)
    drho_map_flat = drho_vals[nearest_seed_idx]
    drho_uc = drho_map_flat[:, 0].reshape(ny, nx)
    drho_lc = drho_map_flat[:, 1].reshape(ny, nx)
    drho_m  = drho_map_flat[:, 2].reshape(ny, nx)

    # 绘图
    print("Plotting...")
    fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharex=True, sharey=True)
    
    norm = mcolors.Normalize(vmin=VMIN, vmax=VMAX)
    
    maps = [drho_uc, drho_lc, drho_m]
    titles = ["True Upper Crust (UC)", "True Lower Crust (LC)", "True Mantle (M)"]
    
    im = None
    for ax, data, title in zip(axes, maps, titles):
        # 使用 pcolormesh 绘制，注意传入的是网格边缘
        im = ax.pcolormesh(x_edges, y_edges, data, cmap=CMAP, norm=norm, shading='flat')
        ax.set_aspect('equal')
        ax.set_title(title, fontsize=14)
        ax.set_xlabel("X (km)")
        axes[0].set_ylabel("Y (km)")

    # 添加统一色标尺
    cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7]) # [left, bottom, width, height]
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label(r"Density Contrast ($\Delta\rho$) [kg/m$^3$]", fontsize=12)
    
    fig.suptitle(f"True Model Structure (Mock Data)\nFixed Color Range: [{VMIN}, {VMAX}]", fontsize=16, y=1.05)
    
    # 保存
    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_PATH, dpi=300, bbox_inches='tight')
    print(f"True model plot saved to: {OUT_PATH}")
    plt.close(fig)

if __name__ == "__main__":
    # 设置出版级绘图风格
    plt.rcParams.update({
        "font.family": "serif", "mathtext.fontset": "stix", 
        "font.size": 12, "xtick.direction": "in", "ytick.direction": "in"
    })
    plot_true_model()

Loading geometry from E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\schemeB_surfaces_mock.nc...
Loading true model definition from E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\truth_model_mock.npz...
Reconstructing Voronoi map on grid...
Plotting...
True model plot saved to: E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\12_30\comparison_01_true_model.png


In [4]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path

# ================= 配置 =================
# 输入文件路径：指向你主程序输出的 NC 文件
# 假设你的输出在 output/Outdata/rjmcmc/chatgpt 下
RESULT_NC_PATH = Path(r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\12_30\chain_0\posterior_mean_columns.nc")
# 输出图片路径
OUT_PATH = Path(r"E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\12_30\chain_0\comparison_02_inversion_mean.png")

# ***关键***：必须与真实模型使用相同的范围!!!
VMIN = -350
VMAX = 350
CMAP = "RdBu_r"
# =======================================

def plot_inversion_result():
    print(f"Loading inversion results from {RESULT_NC_PATH}...")
    try:
        ds = xr.open_dataset(RESULT_NC_PATH)
    except FileNotFoundError:
        print(f"[Error] Result file not found: {RESULT_NC_PATH}")
        print("Please ensure you have run the main RJMCMC script and produced this file.")
        return

    # 获取网格中心和边缘用于绘图
    xc = ds.x_km.values
    yc = ds.y_km.values
    # 推算边缘 (pcolormesh 需要)
    dx, dy = xc[1]-xc[0], yc[1]-yc[0]
    x_edges = np.concatenate([[xc[0]-0.5*dx], xc+0.5*dx])
    y_edges = np.concatenate([[yc[0]-0.5*dy], yc+0.5*dy])

    # 提取后验均值数据
    # 变量名需要与主程序中保存的一致 (drho_uc_mean, drho_lc_mean, drho_m_mean)
    try:
        drho_uc_mean = ds.drho_uc_mean.values
        drho_lc_mean = ds.drho_lc_mean.values
        drho_m_mean  = ds.drho_m_mean.values
    except AttributeError as e:
         print(f"[Error] Missing variable in NetCDF: {e}.")
         print("Did the main script run successfully and save density means?")
         return

    # 绘图
    print("Plotting...")
    fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharex=True, sharey=True)
    
    # 使用与真实模型相同的 Normalize
    norm = mcolors.Normalize(vmin=VMIN, vmax=VMAX)
    
    maps = [drho_uc_mean, drho_lc_mean, drho_m_mean]
    titles = ["Posterior Mean UC", "Posterior Mean LC", "Posterior Mean Mantle"]
    
    im = None
    for ax, data, title in zip(axes, maps, titles):
        im = ax.pcolormesh(x_edges, y_edges, data, cmap=CMAP, norm=norm, shading='flat')
        ax.set_aspect('equal')
        ax.set_title(title, fontsize=14)
        ax.set_xlabel("X (km)")
        axes[0].set_ylabel("Y (km)")

    # 添加统一色标尺
    cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.set_label(r"Density Contrast ($\Delta\rho$) [kg/m$^3$]", fontsize=12)
    
    fig.suptitle(f"Inversion Result (Posterior Mean)\nFixed Color Range: [{VMIN}, {VMAX}]", fontsize=16, y=1.05)
    
    # 保存
    # 存到和脚本相同的目录下
    OUT_PATH_LOCAL = Path(OUT_PATH.name) 
    fig.savefig(OUT_PATH_LOCAL, dpi=300, bbox_inches='tight')
    print(f"Inversion result plot saved to: {OUT_PATH_LOCAL.absolute()}")
    plt.close(fig)

if __name__ == "__main__":
    plt.rcParams.update({
        "font.family": "serif", "mathtext.fontset": "stix", 
        "font.size": 12, "xtick.direction": "in", "ytick.direction": "in"
    })
    plot_inversion_result()

Loading inversion results from E:\wjy\Gravity\SCS_Gravity\out\Outdata\rjmcmc\12_30\chain_0\posterior_mean_columns.nc...
Plotting...
Inversion result plot saved to: e:\wjy\Gravity\SCS_Gravity\src\comparison_02_inversion_mean.png
