In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
def _load_array(path):
    """Load a tensor saved with ``torch.save`` at ``path`` and return it as a NumPy array.

    ``map_location="cpu"`` allows tensors that were saved from a GPU to be loaded on a
    CPU-only machine, and ``detach()`` keeps ``.numpy()`` from raising if the saved
    tensor still has ``requires_grad=True``.
    """
    return torch.load(path, map_location="cpu").detach().numpy()


def _mean_abs_error_grid(pred, real, grid_shape, is_all=False):
    """Return the per-time-step mean absolute error, reshaped into 2-D grids.

    Args:
        pred: prediction array. Axis 0 is averaged away; after that average the
            last axis must hold ``rows * cols`` values. NOTE(review): the original
            comments describe the error as ``[716, 12, 576]`` (samples x time steps
            x flattened grid) — confirm against the code that saves these tensors.
        real: ground-truth array with the same shape as ``pred``.
        grid_shape: ``(rows, cols)`` used to unflatten the spatial axis.
        is_all: if True, the last axis is a channel axis; the absolute errors of
            channel 0 (salt) and channel 1 (temperature) are averaged element-wise.

    Returns:
        Array of shape ``(n_steps, rows, cols)``.

    Raises:
        ValueError: if there are no samples to average over, or if the spatial
            size does not equal ``rows * cols``.
    """
    rows, cols = grid_shape
    if is_all:
        # Element-wise mean of the two channel errors. NOTE(review): if salt and
        # temperature are on different scales, the larger one dominates this value.
        salt_err = np.abs(pred[..., 0] - real[..., 0])
        temp_err = np.abs(pred[..., 1] - real[..., 1])
        err = (salt_err + temp_err) / 2
    else:
        err = np.abs(pred - real)

    # Drop singleton axes but never the sample axis (0): a bare ``squeeze()`` would
    # remove it when there is only one sample, and the mean below would then
    # average over time steps instead of samples.
    singleton_axes = tuple(i for i in range(1, err.ndim) if err.shape[i] == 1)
    if singleton_axes:
        err = np.squeeze(err, axis=singleton_axes)

    if err.ndim == 0 or err.shape[0] == 0:
        raise ValueError("no samples to average over")

    # Average over samples; the result is (n_steps, rows * cols), or just
    # (rows * cols,) when the time axis had length 1 and was squeezed above.
    err_mean = np.atleast_1d(err.mean(axis=0))
    if err_mean.shape[-1] != rows * cols:
        raise ValueError(
            f"spatial size {err_mean.shape[-1]} does not match grid_shape {tuple(grid_shape)}"
        )
    return err_mean.reshape(-1, rows, cols)


def plot_heatmaps(checkpoint_path, grid_shape=(24, 24), n_cols=4):
    """Save per-time-step mean-absolute-error heatmaps for one checkpoint directory.

    Reads ``salt_pred.pt``/``salt_real.pt``, ``temperature_pred.pt``/
    ``temperature_real.pt`` and ``all_pred.pt``/``all_real.pt`` from
    ``checkpoint_path``, averages the absolute error over axis 0 (samples), and
    writes ``Salt_Error_heatmap.png``, ``Temperature_Error_heatmap.png`` and
    ``Combined_Error_heatmap.png`` into the same directory. Each figure holds one
    heatmap per time step, all sharing that figure's colour range.

    Args:
        checkpoint_path: directory containing the saved tensors; the PNGs are
            written here as well.
        grid_shape: ``(rows, cols)`` of the spatial grid. Defaults to ``(24, 24)``,
            the layout that used to be hard-coded.
        n_cols: number of subplot columns; rows are added as needed for the
            number of time steps. Defaults to 4, which reproduces the original
            3 x 4 layout for 12 time steps.

    Raises:
        FileNotFoundError: if any of the six tensor files is missing.
        ValueError: if ``n_cols`` < 1, a tensor has no samples, or its spatial
            size does not match ``grid_shape``.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    salt_err = _mean_abs_error_grid(
        _load_array(f"{checkpoint_path}/salt_pred.pt"),
        _load_array(f"{checkpoint_path}/salt_real.pt"),
        grid_shape,
    )
    temp_err = _mean_abs_error_grid(
        _load_array(f"{checkpoint_path}/temperature_pred.pt"),
        _load_array(f"{checkpoint_path}/temperature_real.pt"),
        grid_shape,
    )
    all_err = _mean_abs_error_grid(
        _load_array(f"{checkpoint_path}/all_pred.pt"),
        _load_array(f"{checkpoint_path}/all_real.pt"),
        grid_shape,
        is_all=True,
    )

    panels = [
        ('Salt Error', salt_err),
        ('Temperature Error', temp_err),
        ('Combined Error', all_err),
    ]

    for var_name, data in panels:
        n_steps = data.shape[0]
        n_rows = -(-n_steps // n_cols)  # ceiling division
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows), squeeze=False
        )
        try:
            fig.suptitle(var_name, fontsize=16, y=0.95)

            # One colour scale per figure so the time steps are directly comparable.
            vmin, vmax = np.min(data), np.max(data)

            for t, ax in enumerate(axes.ravel()):
                if t >= n_steps:
                    ax.set_visible(False)  # unused panel in the last row
                    continue
                sns.heatmap(data[t], ax=ax, cmap='viridis',
                            vmin=vmin, vmax=vmax,
                            cbar_kws={'label': 'Error Value'})
                ax.set_title(f'Time Step {t+1}')
                ax.set_xlabel('X Grid')
                ax.set_ylabel('Y Grid')

            fig.tight_layout()
            fig.savefig(f"{checkpoint_path}/{var_name.replace(' ', '_')}_heatmap.png")
        finally:
            # Close this specific figure even if plotting fails, so repeated calls
            # in a long-lived notebook kernel do not accumulate open figures.
            plt.close(fig)

In [3]:
if __name__ == "__main__":
    # Example run: render the error heatmaps for a single training run.
    # Point this at the log directory that holds the saved *_pred.pt / *_real.pt files.
    checkpoint_path = "./logs/2025-04-13-12-03-12-eastsea"
    plot_heatmaps(checkpoint_path=checkpoint_path)