目前的绘图逻辑从上到下是
1.prior和post；
2.parameter和hydrograph
3.parameter的ensemble version和 mean-std version
4.hydrograph的frame和anaimation
，然后中间可能有不同station的hydrograph

In [1]:
import sys
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import hydroeval as he
import copy
current_file = 'visualize_improved.ipynb'
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(current_file), '..', '..')))
from utils import process_yaml
from ifc_usgs_fileorder import load_usgs_mapping_from_path


# 全局 matplotlib 设置
plt.rcParams['font.size'] = 12
plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'STIXGeneral'

ModuleNotFoundError: No module named 'utils'

In [2]:
# ===================== 评价指标函数 =====================
def kge_metric(obs, sim):
    """计算模拟与观测数据的 NSE 指标（可改为真实KGE计算）"""
    return he.evaluator(he.nse, sim[obs > 0], obs[obs > 0])

def peak_relative_diff(obs, sim):
    """计算模拟与观测峰值的相对差异"""
    return (np.max(sim) - np.max(obs)) / np.max(obs)

def peak_timing_diff(obs, sim):
    """计算模拟与观测达到峰值时刻的索引差"""
    return np.argmax(sim) - np.argmax(obs)

In [7]:
# ------------------ 工具函数：清空并创建文件夹 ------------------
def clear_and_create_dir(dir_path):
    """若目录存在，则清空（删除整个目录），再重新创建该目录"""
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)


# ===================== 动画帧绘制函数 =====================
def draw_animation_frame(iter_idx, ensemble_sim, station_idx, time_axis, measured_data, station_label):
    """
    绘制单个动画帧。

    参数:
      iter_idx: 当前同化迭代编号（格式化为两位数）。
      ensemble_sim: 模型模拟粒子数据，形状为 (ensemble_size, time_steps, num_stations)。
      station_idx: 当前绘制的观测站索引（直接显示，无格式宽度）。
      time_axis: 时间序列 (pandas DatetimeIndex)。
      measured_data: 观测数据数组，形状为 (time_steps, num_stations)。
      station_label: 站点名称，用于图标题显示（此处我们希望显示 gauge id）。
    """
    plt.clf()  # 清空当前图形
    station_ensemble = ensemble_sim[:, :, station_idx]
    median_sim = np.median(station_ensemble, axis=0)
    obs_series = measured_data[:, station_idx]

    plt.plot(time_axis, median_sim, 'b-', label='Particle median')
    plt.fill_between(time_axis,
                     np.percentile(station_ensemble, 5, axis=0),
                     np.percentile(station_ensemble, 95, axis=0),
                     color='blue', alpha=0.3)
    plt.plot(time_axis, obs_series, 'k--', label='Observed')

    try:
        kge_val = kge_metric(obs_series, median_sim)
        pr_diff = peak_relative_diff(obs_series, median_sim)
        pt_diff = peak_timing_diff(obs_series, median_sim)
        print(f"Iteration {iter_idx:02d}, Gauge {station_label}: KGE={kge_val}, PeakDiff={pr_diff}, PeakTiming={pt_diff}")
    except Exception:
        print(f"Iteration {iter_idx:02d}, Gauge {station_label}: Metric calculation failed.")

    plt.title(f'EKI iteration {iter_idx:02d} - Gauge {station_label}')
    plt.xlabel('Time')
    plt.ylabel('Discharge (m^3/s)')
    plt.legend()
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()

# ===================== 数据加载辅助函数 =====================
def load_ensemble(assimilation_phase, iter_idx):
    """
    根据同化阶段和迭代编号加载模拟粒子数据。

    参数:
      assimilation_phase: 字符串 'post' 或 'prior'
      iter_idx: 当前迭代编号（对于 post，iter_idx==0 使用先验文件）
    返回:
      模型模拟粒子数据数组。
    """
    if assimilation_phase == 'post':
        if iter_idx == 0:
            file_path = 'npy/0_prior_particles.npy'
        else:
            file_path = f'npy/{iter_idx - 1}_post_particles.npy'
    elif assimilation_phase == 'prior':
        file_path = f'npy/{iter_idx}_prior_particles.npy'
    else:
        raise ValueError("assimilation_phase must be 'post' or 'prior'")
    with open(file_path, 'rb') as f:
        return np.load(f)

# ===================== 动画生成函数 =====================
def generate_hydrograph_animation(num_iters, station_indices, station_names, measured_data, time_axis, assimilation_phase, base_output_dir):
    """
    生成并保存 hydrograph 动画 GIF。

    参数:
      num_iters: 同化迭代步数（post: num_iters+1；prior: num_iters）
      station_indices: 要绘制的站点在观测数据中的列索引列表
      station_names: 对应的站点名称列表（通常与 desired_usgs_ids 一致，即 gauge id）
      measured_data: 观测数据数组 (time_steps, num_stations)
      time_axis: 时间序列 (pandas DatetimeIndex)
      assimilation_phase: 'post' 或 'prior'
      base_output_dir: 顶层输出文件夹（例如 visualization）
    """
    hydrograph_frames_dir = os.path.join(base_output_dir, assimilation_phase, "hydrograph", "frames")
    hydrograph_anim_dir = os.path.join(base_output_dir, assimilation_phase, "hydrograph", "animation")
    clear_and_create_dir(hydrograph_frames_dir)
    clear_and_create_dir(hydrograph_anim_dir)
    
    iter_range = range(num_iters + 1) if assimilation_phase == 'post' else range(num_iters)
    
    # 遍历传入的站点索引及对应的站点名称（gauge id）
    for i, station_idx in enumerate(station_indices):
        station_label = station_names[i]
        frame_imgs = []
        for iter_idx in iter_range:
            plt.clf()
            ensemble_sim = load_ensemble(assimilation_phase, iter_idx)
            y_limits = [0, 3 * np.max(measured_data[:, station_idx])]
            draw_animation_frame(iter_idx, ensemble_sim, station_idx, time_axis, measured_data,
                                 station_label=station_label)
            plt.ylim(*y_limits)
            # 文件名：iter_XX_station_Y_hydrograph.png（iteration 保留两位，station 不填充）
            frame_filepath = os.path.join(hydrograph_frames_dir, f"iter_{iter_idx:02d}_gauge_{station_label}_hydrograph.png")
            plt.savefig(frame_filepath)
            frame_imgs.append(Image.open(frame_filepath))
        gif_filepath = os.path.join(hydrograph_anim_dir, f"gauge_{station_label}_hydrograph_animation.gif")
        frame_imgs[0].save(gif_filepath, save_all=True, append_images=frame_imgs[1:], duration=1000, loop=0)
        print(f"Animation saved to {gif_filepath}")

# ===================== 参数演变绘图函数 =====================
def plot_parameter_evolution(param_array, active_param_indices, param_labels, param_ranges, assimilation_phase, base_output_dir, iter_range):
    """
    绘制参数演变图，分别保存 ensemble 版本和 mean_std 版本。

    参数:
      param_array: 参数粒子数据数组，形状 (num_iters, num_active_params, num_stations, particle_dim)
      active_param_indices: 原始参数列表中参与同化的参数索引列表
      param_labels: 参数名称列表
      param_ranges: 参数取值范围列表
      assimilation_phase: 'post' 或 'prior'
      base_output_dir: 顶层输出文件夹（例如 visualization）
      iter_range: 迭代编号数组
    """
    param_ensemble_dir = os.path.join(base_output_dir, assimilation_phase, "parameter", "ensemble")
    param_mean_std_dir = os.path.join(base_output_dir, assimilation_phase, "parameter", "mean_std")
    clear_and_create_dir(param_ensemble_dir)
    clear_and_create_dir(param_mean_std_dir)
    
    num_iters = len(iter_range)
    num_stations = param_array.shape[2]
    
    # Ensemble 版本：保存所有粒子的轨迹
    for idx_active, orig_idx in enumerate(active_param_indices):
        for station_idx in range(num_stations):
            plt.figure()
            plt.plot(iter_range, param_array[:, idx_active, station_idx, :])
            plt.ylabel(param_labels[orig_idx])
            plt.xlabel('EKI Iterations')
            plt.ylim(*param_ranges[orig_idx])
            # 文件名：parameter_X_station_Y_ensemble.png
            out_path = os.path.join(param_ensemble_dir, f"parameter_{orig_idx}_station_{station_idx}_ensemble.png")
            plt.savefig(out_path)
            plt.close()
            print(f"Saved parameter ensemble plot {out_path}")
    
    # Mean-Std 版本：均值及标准差统计图
    for idx_active, orig_idx in enumerate(active_param_indices):
        for station_idx in range(num_stations):
            plt.figure()
            param_mean = np.mean(param_array[:, idx_active, station_idx, :], axis=1)
            param_std = np.std(param_array[:, idx_active, station_idx, :], axis=1)
            plt.plot(iter_range, param_mean, 'k-', lw=2, label='Mean')
            plt.fill_between(iter_range, param_mean - param_std, param_mean + param_std,
                             color='gray', alpha=0.3, label='Mean ± Std')
            plt.ylabel(param_labels[orig_idx])
            plt.xlabel('EKI Iterations')
            plt.ylim(*param_ranges[orig_idx])
            # 文件名：parameter_X_station_Y_mean_std.png
            out_path = os.path.join(param_mean_std_dir, f"parameter_{orig_idx}_station_{station_idx}_mean_std.png")
            plt.savefig(out_path)
            plt.close()
            print(f"Saved parameter mean-std plot {out_path}")

# ===================== 事件统计绘图函数 =====================
def plot_event_statistics(assimilation_phase, base_output_dir):
    """
    计算并绘制事件统计图，将三个统计量（峰值、均值、标准差）转移到图形中保存。

    这里假设从 CSV 文件中读取观测数据用于计算事件统计量，
    实际项目中可替换为更复杂的事件检测算法。
    """
    # 读取观测数据（假设同一文件）
    measured_data = np.genfromtxt("csv/meas_mean.csv", delimiter=',', skip_header=1)
    measured_data[measured_data == 0] = np.nan

    # 对每个站点计算统计量（这里简单采用整个时间序列的统计量作为示例）
    event_peaks = np.nanmax(measured_data, axis=0)
    event_means = np.nanmean(measured_data, axis=0)
    event_stds  = np.nanstd(measured_data, axis=0)

    event_stats_dir = os.path.join(base_output_dir, assimilation_phase, "event_statistics")
    clear_and_create_dir(event_stats_dir)

    stations = np.arange(measured_data.shape[1])
    plt.figure()
    plt.plot(stations, event_peaks, 'r-o', label="Peak")
    plt.xlabel("Station")
    plt.ylabel("Peak Value")
    plt.title(f"Event Peak Values ({assimilation_phase})")
    plt.legend()
    out_path = os.path.join(event_stats_dir, "event_peak.png")
    plt.savefig(out_path)
    plt.close()
    print(f"Saved event peak plot {out_path}")

    plt.figure()
    plt.plot(stations, event_means, 'g-o', label="Mean")
    plt.xlabel("Station")
    plt.ylabel("Mean Value")
    plt.title(f"Event Mean Values ({assimilation_phase})")
    plt.legend()
    out_path = os.path.join(event_stats_dir, "event_mean.png")
    plt.savefig(out_path)
    plt.close()
    print(f"Saved event mean plot {out_path}")

    plt.figure()
    plt.plot(stations, event_stds, 'b-o', label="Std")
    plt.xlabel("Station")
    plt.ylabel("Standard Deviation")
    plt.title(f"Event Standard Deviation ({assimilation_phase})")
    plt.legend()
    out_path = os.path.join(event_stats_dir, "event_std.png")
    plt.savefig(out_path)
    plt.close()
    print(f"Saved event std plot {out_path}")



In [8]:
# ===================== 主流程 =====================
def main_visualization():
    base_output_dir = "visualization"
    assimilation_phases = ['prior', 'post']
    
    test_dict = process_yaml("test_config.j2")
    start_time_str = test_dict["time_start"]
    end_time_str = test_dict["time_end"]
    time_axis = pd.date_range(start=start_time_str, end=end_time_str, freq='H')

    num_assimilation_steps = test_dict["steps"]
    max_station_count = 5  # 默认值，当 desired_usgs_ids 未找到时使用

    # 读取 desired_usgs_ids（即 gauge id），确保为列表格式
    desired_usgs_ids = test_dict.get("meas_usgs", [])
    if isinstance(desired_usgs_ids, str):
        desired_usgs_ids = [desired_usgs_ids]
        
    # 读取 USGS 映射关系（调整路径根据实际情况）
    usgs_2_id, id_2_usgs, file_order = load_usgs_mapping_from_path("../../" + test_dict["usgs_csv"])
    
    # 根据 desired_usgs_ids 计算绘图所用的站点索引和 gauge 名称
    plot_station_indices = []
    plot_station_names = []
    for usgs in desired_usgs_ids:
        link_id = usgs_2_id.get(usgs)
        if link_id is None:
            print(f"Warning: USGS ID {usgs} not found in mapping.")
            continue
        idx_arr = np.where(file_order == link_id)[0]
        if idx_arr.size > 0:
            plot_station_indices.append(idx_arr[0])
            plot_station_names.append(usgs)  # 使用 USGS ID 作为 gauge id 名称
        else:
            print(f"Warning: Link id {link_id} for USGS {usgs} not found in file_order.")
    if not plot_station_indices:
        print("No desired station indices found, using default range.")
        plot_station_indices = list(range(max_station_count))
        plot_station_names = [str(i) for i in range(max_station_count)]
    
    observed_data = np.genfromtxt("csv/meas_mean.csv", delimiter=',', skip_header=1)
    observed_data_clean = observed_data.copy()
    observed_data[observed_data == 0] = np.nan

    # 参数设置: 对已经提取的数据进行“取名”和“限制显示范围”
    param_labels = ["$Cr$"]
    param_ranges = [[0.00, 1.0]]
    active_param_indices = [0]

    # -------------------- 后验 (post) 部分 --------------------
    post_param_list = []
    for i in range(num_assimilation_steps + 1):
        if i > 0:
            file_path = f'npy/{i-1}_post_params_particles.npy'
        else:
            file_path = 'npy/0_prior_params_particles.npy'
        with open(file_path, 'rb') as f:
            post_param_list.append(np.load(f))
    post_param_array = np.stack(post_param_list, axis=0)
    post_param_array = post_param_array.reshape(num_assimilation_steps + 1,
                                                 len(active_param_indices),
                                                 -1,
                                                 post_param_array.shape[-1])
    iter_range_post = np.arange(0, num_assimilation_steps + 1)
    
    # 调用时，将 station_indices 和 station_names 一并传入
    generate_hydrograph_animation(num_assimilation_steps, plot_station_indices, plot_station_names,
                                  observed_data_clean, time_axis,
                                  assimilation_phase='post', base_output_dir=base_output_dir)
    plot_parameter_evolution(post_param_array, active_param_indices, param_labels, param_ranges,
                             assimilation_phase='post', base_output_dir=base_output_dir, iter_range=iter_range_post)
    plot_event_statistics('post', base_output_dir)

    # -------------------- 先验 (prior) 部分 --------------------
    prior_param_list = []
    for i in range(num_assimilation_steps):
        file_path = f'npy/{i}_prior_params_particles.npy'
        with open(file_path, 'rb') as f:
            prior_param_list.append(np.load(f))
    prior_param_array = np.stack(prior_param_list, axis=0)
    prior_param_array = prior_param_array.reshape(num_assimilation_steps,
                                                   len(active_param_indices),
                                                   -1,
                                                   prior_param_array.shape[-1])
    iter_range_prior = np.arange(0, num_assimilation_steps)
    generate_hydrograph_animation(num_assimilation_steps, plot_station_indices, plot_station_names,
                                  observed_data_clean, time_axis,
                                  assimilation_phase='prior', base_output_dir=base_output_dir)
    plot_parameter_evolution(prior_param_array, active_param_indices, param_labels, param_ranges,
                             assimilation_phase='prior', base_output_dir=base_output_dir, iter_range=iter_range_prior)
    plot_event_statistics('prior', base_output_dir)
    
    plt.close('all')
    print("Visualization complete.")


if __name__ == '__main__':
    main_visualization()


  time_axis = pd.date_range(start=start_time_str, end=end_time_str, freq='H')


Iteration 00, Gauge 5570910: KGE=[-4.28385832], PeakDiff=-0.2528212363524527, PeakTiming=1
Iteration 01, Gauge 5570910: KGE=[-11.45155087], PeakDiff=2.353718283868983, PeakTiming=4
Iteration 02, Gauge 5570910: KGE=[-7.28569755], PeakDiff=0.9784061202521912, PeakTiming=6
Iteration 03, Gauge 5570910: KGE=[-7.62073893], PeakDiff=1.1082250499769337, PeakTiming=5
Iteration 04, Gauge 5570910: KGE=[-6.8218508], PeakDiff=0.7888858988159311, PeakTiming=6
Iteration 05, Gauge 5570910: KGE=[-6.90565835], PeakDiff=0.8242807165923419, PeakTiming=6
Iteration 06, Gauge 5570910: KGE=[-6.53182793], PeakDiff=0.6635944948485315, PeakTiming=6
Iteration 07, Gauge 5570910: KGE=[-6.62273629], PeakDiff=0.703463017069045, PeakTiming=6
Iteration 08, Gauge 5570910: KGE=[-6.43504462], PeakDiff=0.6207488851299401, PeakTiming=6
Iteration 09, Gauge 5570910: KGE=[-6.58453504], PeakDiff=0.6867868675995694, PeakTiming=6
Iteration 10, Gauge 5570910: KGE=[-6.39710911], PeakDiff=0.6044391050284484, PeakTiming=7
Animation s

  event_peaks = np.nanmax(measured_data, axis=0)
  event_means = np.nanmean(measured_data, axis=0)
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


Saved event mean plot visualization/post/event_statistics/event_mean.png
Saved event std plot visualization/post/event_statistics/event_std.png
Iteration 00, Gauge 5570910: KGE=[-4.28385832], PeakDiff=-0.2528212363524527, PeakTiming=1
Iteration 01, Gauge 5570910: KGE=[-11.45155087], PeakDiff=2.353718283868983, PeakTiming=4
Iteration 02, Gauge 5570910: KGE=[-7.28569755], PeakDiff=0.9784061202521912, PeakTiming=6
Iteration 03, Gauge 5570910: KGE=[-7.62073893], PeakDiff=1.1082250499769337, PeakTiming=5
Iteration 04, Gauge 5570910: KGE=[-6.8218508], PeakDiff=0.7888858988159311, PeakTiming=6
Iteration 05, Gauge 5570910: KGE=[-6.90565835], PeakDiff=0.8242807165923419, PeakTiming=6
Iteration 06, Gauge 5570910: KGE=[-6.53182793], PeakDiff=0.6635944948485315, PeakTiming=6
Iteration 07, Gauge 5570910: KGE=[-6.62273629], PeakDiff=0.703463017069045, PeakTiming=6
Iteration 08, Gauge 5570910: KGE=[-6.43504462], PeakDiff=0.6207488851299401, PeakTiming=6
Iteration 09, Gauge 5570910: KGE=[-6.58453504],