In [3]:
import torch

def _absolute_error(predictions, ground_truth, metric_type):
    """Elementwise |prediction - truth|; for "dirm" the shortest distance on a 360-degree circle."""
    diff = predictions - ground_truth
    if metric_type == "dirm":
        # Wrap into [0, 360) first so differences larger than one full turn
        # (or negative ones) still give a distance in [0, 180].
        wrapped = torch.remainder(diff, 360.0)
        return torch.min(wrapped, 360.0 - wrapped)
    return torch.abs(diff)


def compute_metrics(predictions, ground_truth, metric_type="hs"):
    """
    Compute RMSE and CRPS between a deterministic prediction and the ground truth.

    Parameters:
    - predictions: torch.Tensor, model predictions. Any shape is accepted as long
      as it equals ``ground_truth.shape`` (presumably [N, 4, 128, 128] for the
      wave fields in this notebook — TODO confirm).
    - ground_truth: torch.Tensor, reference values with the same shape.
    - metric_type: str. "dirm" treats the values as angles in degrees and uses
      the shortest angular distance; any other value (e.g. "hs", "tm") uses the
      plain difference.

    Returns:
    - rmse: float, root-mean-square error over all elements.
    - crps: float, continuous ranked probability score averaged over all
      elements. CRPS(F, y) = E|X - y| - 0.5 * E|X - X'|; for a single
      deterministic forecast X is a point mass, so the spread term is 0 and the
      CRPS equals the mean absolute error, which is what is returned.

    Raises:
    - AssertionError if the two tensors do not have the same shape.
    """
    # Shapes must match elementwise.
    assert predictions.shape == ground_truth.shape, "预测值和真实值的形状必须一致"

    abs_error = _absolute_error(predictions, ground_truth, metric_type)
    if not abs_error.is_floating_point():
        # torch.mean is undefined for integer tensors.
        abs_error = abs_error.float()

    # ---- RMSE ---- (|e|^2 == e^2, so the absolute error can be reused)
    rmse = torch.sqrt(torch.mean(abs_error ** 2)).item()

    # ---- CRPS ---- (point forecast: reduces exactly to MAE; deterministic,
    # and no extra full-size allocations)
    crps = torch.mean(abs_error).item()

    return rmse, crps


In [4]:
import wave_filed_data_prepare
import Utils
import numpy as np
# Ground truth: SWAN fields for 2019-2021 combined from the monthly files;
# predictions: generated fields saved to disk.
# NOTE(review): absolute, user-specific path — consider moving it to a config cell.
SWAN_FILE_PREFIX = "/home/hy4080/met_waves/Swan_cropped/swanSula"
GENERATED_FIELDS_PATH = "data/generated_wave_fields.npy"

# torch.as_tensor reuses the loaded buffers; the previous
# tensor -> .numpy() -> torch.tensor round trip made a full extra copy
# of the multi-year field array.
ground_truths = torch.as_tensor(
    wave_filed_data_prepare.combine_monthly_data(SWAN_FILE_PREFIX, 2019, 2021)
)
predictions = torch.as_tensor(np.load(GENERATED_FIELDS_PATH))
print(tuple(ground_truths.shape), tuple(predictions.shape))

# Compute RMSE and CRPS.
# NOTE(review): both calls score every channel of the full tensor together; the
# per-file shapes show 4 channels, so if they hold different variables, select
# the matching channel for "hs" / "dirm" first — TODO confirm channel order.
rmse_hs, crps_hs = compute_metrics(predictions, ground_truths, metric_type="hs")
rmse_dirm, crps_dirm = compute_metrics(predictions, ground_truths, metric_type="dirm")

print(f"RMSE (hs): {rmse_hs:.4f}, CRPS (hs): {crps_hs:.4f}")
print(f"RMSE (dirm): {rmse_dirm:.4f}, CRPS (dirm): {crps_dirm:.4f}")


Processing file for 201901...
文件 /home/hy4080/met_waves/Swan_cropped/swanSula201901_cropped.nc 处理后形状: (744, 128, 128, 4)
Processing file for 201902...
文件 /home/hy4080/met_waves/Swan_cropped/swanSula201902_cropped.nc 处理后形状: (672, 128, 128, 4)
Processing file for 201903...
文件 /home/hy4080/met_waves/Swan_cropped/swanSula201903_cropped.nc 处理后形状: (744, 128, 128, 4)
Processing file for 201904...
文件 /home/hy4080/met_waves/Swan_cropped/swanSula201904_cropped.nc 处理后形状: (720, 128, 128, 4)
Processing file for 201905...
文件 /home/hy4080/met_waves/Swan_cropped/swanSula201905_cropped.nc 处理后形状: (744, 128, 128, 4)
Processing file for 201906...
文件 /home/hy4080/met_waves/Swan_cropped/swanSula201906_cropped.nc 处理后形状: (720, 128, 128, 4)
Processing file for 201907...
文件 /home/hy4080/met_waves/Swan_cropped/swanSula201907_cropped.nc 处理后形状: (744, 128, 128, 4)
Processing file for 201908...
文件 /home/hy4080/met_waves/Swan_cropped/swanSula201908_cropped.nc 处理后形状: (744, 128, 128, 4)
Processing file for 201909...
文件

In [8]:
import numpy as np
import matplotlib.pyplot as plt
# Keep the first 20 entries along the leading axis of each tensor for a quick look.
# NOTE(review): despite the "_np" suffix these are torch tensors (slices of
# `predictions` / `ground_truths`), not numpy arrays.
predictions_np, ground_truth_np = predictions[:20], ground_truths[:20]

# Print the same 3x3 patch (indices [0, 0, 10:13, 0:3]) from both tensors.
# NOTE(review): the saved output shows ground-truth values of ~1.0 and ~9e-19
# next to predictions in roughly [-0.4, 0.8]; check that both tensors use the
# same scale/normalization before trusting the metrics.
print(predictions_np[0, 0, 10:13, :3], ground_truth_np[0, 0, 10:13, :3])


tensor([[ 0.6268,  0.2383,  0.7844],
        [ 0.4868, -0.2833, -0.4021],
        [ 0.1106,  0.0072, -0.1512]]) tensor([[1.0000e+00, 9.2519e-19, 1.0000e+00],
        [9.2519e-19, 9.2519e-19, 9.2519e-19],
        [9.2519e-19, 9.2519e-19, 9.2519e-19]])
