In [1]:
%reload_ext autoreload
%autoreload 2
import sys
import uuid

import arrow
import numpy as np
from torch import load, no_grad
from torch.utils.data import DataLoader

# TODO: avoid the hard-coded absolute project path (e.g. read it from an
# environment variable) so the notebook runs on other machines.
sys.path.append('/home/morisi/Workspace/3D-Ocean')

from src.config.params import PROJECT_PATH
from src.plot.profile import plot_3d_temperature, plot_3d_temperature_error
from src.config.area import Area
from src.trainer.base import BaseTrainer
from src.dataset.Argo import Argo3DTemperatureDataset, ArgoDepthMap
from src.models.Profile.UNet3DReconstructor import UNet3DReconstructor



In [2]:
# Training configuration for the global UNet-3D temperature-profile reconstructor.
area = Area("global_rf_model", [-180, 180], [-80, 80], "global")

# Depth-level range given to the dataset. `depth[1] - depth[0]` is used as the
# number of depth layers (here and in the evaluation cell), so the model's
# `n_depth` below is derived from it rather than hard-coded.
depth = [0, 10]

dataset_params = {
    "depth": depth,
    "resolution": 2
}

# Trainer parameters
trainer_params = {
    "epochs": 300,
    "batch_size": 32,
    "num_workers": 12,
}

# UNet-3D model parameters.
model_params = {
    "n_channels": 1,
    # Derived from `depth` so the model and dataset cannot drift apart when the
    # depth range changes (was hard-coded to 10).
    "n_depth": depth[1] - depth[0],
    "base_channels": 128,
    "learning_rate": 1e-3
}

# Use a freshly generated UUID to avoid wandb run-id collisions.
# `hex[:8]` yields the same 8 characters as `str(uuid4())[:8]`.
run_id = f"rf_{uuid.uuid4().hex[:8]}"
print(f"Training run ID: {run_id}")

rf_trainer = BaseTrainer(
    uid=run_id,
    title="rf",
    area=area,
    model_class=UNet3DReconstructor,
    dataset_class=Argo3DTemperatureDataset,
    save_path=f"{PROJECT_PATH}/out/models/unet-3d.pkl",
    use_checkpoint=False,
    dataset_params=dataset_params,
    trainer_params=trainer_params,
    model_params=model_params,
    use_wandb=False
)


Training run ID: rf_ef121da7


In [3]:
# Set TRAIN_MODEL = True to retrain; otherwise the saved checkpoint at
# rf_trainer.save_path is loaded in the evaluation cell below.
TRAIN_MODEL = False
model = rf_trainer.train() if TRAIN_MODEL else None

In [None]:
# Evaluate on a single month: load the reconstructor (if it was not trained
# above) and the SST model, report RMSEs, and plot truth / reconstructions / error.

if model is None:
    # NOTE(review): weights_only=False unpickles arbitrary objects (can execute
    # code) -- only load checkpoints from a trusted source.
    model = load(rf_trainer.save_path, weights_only=False)

sst_model = load(f"{PROJECT_PATH}/out/models/seq_len-2/conv.pkl", weights_only=False)
# Inference mode: a pickled module keeps the train/eval mode it was saved in,
# so switch explicitly to make dropout / batch-norm layers (if any) deterministic.
sst_model.eval()

# Month offset from 2004-01 (246 -> 202407).
# NOTE(review): the saved output of this cell prints 202311 (offset 238), so it
# is stale relative to this value -- re-run to refresh.
month = 246

time = arrow.get('2004-01-01').shift(months=month).format('YYYYMM')
print(f'{time}')

# Number of depth layers, consistent with model_params["n_depth"].
n_depth = depth[1] - depth[0]

dataset = Argo3DTemperatureDataset(
    lon=area.lon,
    lat=area.lat,
    offset=month,
    **dataset_params
)

loader = DataLoader(dataset, batch_size=1, shuffle=False)

# First sample: SST input and the temperature-profile ground truth.
sst, temp = next(iter(loader))

# SST reshaped to 5-D (1, 1, 1, H, W) as input to the SST model.
f_sst = sst.reshape((1, 1, 1, sst.shape[1], sst.shape[2]))
with no_grad():
    n_sst = sst_model(f_sst)

# SST model error (kept separate from the profile error below).
sst_diff = (n_sst - f_sst).detach().numpy()
sst_rmse = np.sqrt(np.nanmean(sst_diff ** 2))
print('sst rmse:', sst_rmse)

# Reconstruct the 3-D profile from the true SST and from the predicted SST.
profile = model.predict(f_sst[0, 0, :, :, :]).reshape((sst.shape[1], sst.shape[2], n_depth))
# NOTE(review): if n_sst is 5-D like f_sst, `n_sst[0]` keeps one more leading
# dim than `f_sst[0, 0]` above -- confirm predict() treats both the same.
n_profile = model.predict(n_sst[0, :, :, :]).reshape((sst.shape[1], sst.shape[2], n_depth))

# Ground truth; NaNs mark land and act as the mask.
orgin_profile = temp.detach().numpy()[0, :, :, :]

# Reconstruction error: NaNs propagate (land stays NaN) and nanmean skips them.
diff = profile - orgin_profile
rmse = np.sqrt(np.nanmean(diff ** 2))
print(f'rmse: {rmse}')

# diff is (lat, lon, depth); plot_3d_temperature_error is given (lon, lat, depth).
# NOTE(review): the plot_3d_temperature calls below get (lat, lon, depth) arrays
# untransposed -- confirm the two plotting functions expect different axis orders.
diff_transposed = np.transpose(diff, (1, 0, 2))

print(f'原始 diff.shape: {diff.shape} (lat, lon, depth)')
print(f'转置后 diff_transposed.shape: {diff_transposed.shape} (lon, lat, depth)')
print(f'diff 统计: min={np.nanmin(diff):.4f}, max={np.nanmax(diff):.4f}, nan_count={np.isnan(diff).sum()}/{diff.size}')

# Shared plotting options: depth levels of the selected range and PCHIP depth
# interpolation at 0.2 m spacing.
depth_levels = ArgoDepthMap.get(depth)
plot_kwargs = dict(step=2, interpolate=True, interp_interval=0.2, interp_method='pchip')

# Ground truth, reconstruction from true SST, reconstruction from predicted SST.
for field in (orgin_profile, profile, n_profile):
    plot_3d_temperature(field, area.lon, area.lat, depth_levels, **plot_kwargs)

# Error field, drawn with the blue-white-red error colormap.
plot_3d_temperature_error(
    diff_transposed,
    area.lon,
    area.lat,
    depth_levels,
    filename='3d_profile_error.png',
    title='3D Temperature Profile Error',
    **plot_kwargs
)


202311
sst rmse: 1.6185969
rmse: 0.7342395782470703
原始 diff.shape: (80, 180, 10) (lat, lon, depth)
转置后 diff_transposed.shape: (180, 80, 10) (lon, lat, depth)
diff 统计: min=-7.2635, max=10.8126, nan_count=61250/144000
[92mINFO: [0m     DepthInterpolator 初始化:
[92mINFO: [0m       原始深度层数: 10
[92mINFO: [0m       原始深度范围: 0m - 80m
[92mINFO: [0m       目标插值间隔: 0.2m
[92mINFO: [0m       目标深度层数: 401
[92mINFO: [0m       插值方法: pchip
[90mDEBUG: [0m     插值后数据形状: (401, 180, 80)
[90mDEBUG: [0m     转置回后数据形状: (80, 180, 401)
[92mINFO: [0m     已应用深度插值: 401 层，间隔 0.2m
[92mINFO: [0m     DepthInterpolator 初始化:
[92mINFO: [0m       原始深度层数: 10
[92mINFO: [0m       原始深度范围: 0m - 80m
[92mINFO: [0m       目标插值间隔: 0.2m
[92mINFO: [0m       目标深度层数: 401
[92mINFO: [0m       插值方法: pchip
[90mDEBUG: [0m     插值后数据形状: (401, 180, 80)
[90mDEBUG: [0m     转置回后数据形状: (80, 180, 401)
[92mINFO: [0m     已应用深度插值: 401 层，间隔 0.2m
[92mINFO: [0m     DepthInterpolator 初始化:
[92mINFO: [0m       原始深度层数: 10
[92