# SST Prediction Project
- 海表温度 (SST)
- 这个文件作为一个简单开始，搭建一个简单的预测框架，仅利用ERA5的数据

In [20]:
# Verify the PyTorch installation and report whether a CUDA device is usable.
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
# get_device_name(0) raises when no CUDA device is present, so only query it
# when CUDA is actually available; otherwise report that there is no GPU.
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("GPU:", None)


PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
GPU: NVIDIA GeForce RTX 4060 Laptop GPU


In [21]:
# -----------------------------------------------------------------------------
# 模块 0: 全局配置和导入
# -----------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd
import xarray as xr
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import os
import glob
from datetime import datetime, timedelta

# --- Data configuration ---
# File names are built as FILE_PREFIX + <date formatted with FILENAME_DATE_FORMAT> + FILE_SUFFIX
# and looked up inside DATA_DIR, one file per day between the start and end dates (inclusive).
DATA_DIR = '../data/ERA5/'
START_DATE_STR = '2020-01-01'
END_DATE_STR = '2020-03-31' # can be widened to a longer range to obtain more data
FILE_PREFIX = ''
FILE_SUFFIX = '_ERA5_daily_mean_sst.nc'
FILENAME_DATE_FORMAT = '%Y%m%d'

# Variable to read and the grid point to extract (nearest-neighbour selection is used).
VARIABLE_NAME = 'sst'
LATITUDE_POINT = 30.0
LONGITUDE_POINT = 120.0

# --- Model configuration ---
LOOK_BACK = 15          # number of past days used as input (try a smaller value when data is scarce)
PREDICT_STEPS = 1       # number of future days to predict
HIDDEN_SIZE = 64
NUM_LAYERS = 2
DROPOUT_RATE = 0.2

# --- Training configuration ---
NUM_EPOCHS = 50
BATCH_SIZE = 16         # try a smaller value when data is scarce
LEARNING_RATE = 0.001
TRAIN_SPLIT_RATIO = 0.8 # 80% training, 20% test
BEST_MODEL_PATH = 'model/best_sst_model.pth'

# --- Device configuration ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Random seed ---
SEED = 42

def set_seed(seed_value):
    """Seed the NumPy and PyTorch random generators so runs are reproducible.

    When CUDA is available, every CUDA device is seeded as well.
    """
    torch.manual_seed(seed_value)
    np.random.seed(seed_value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed_value)

# Report the selected compute device and apply the global random seed.
print(f"Using device: {DEVICE}")
set_seed(SEED)

Using device: cuda


In [22]:
# -----------------------------------------------------------------------------
# 模块 1: 数据处理模块
# -----------------------------------------------------------------------------

def get_file_paths_in_range(data_dir, start_date_str, end_date_str,
                            file_prefix, file_suffix, date_format):
    """Collect the daily NetCDF files that exist for an inclusive date range.

    One file name is built per calendar day as
    ``file_prefix + day.strftime(date_format) + file_suffix`` and looked up in
    ``data_dir``. Missing days are reported on stdout and skipped.

    Args:
        data_dir: Directory that holds the files.
        start_date_str: First day, formatted ``YYYY-MM-DD``.
        end_date_str: Last day (inclusive), formatted ``YYYY-MM-DD``.
        file_prefix: Text placed before the date in each file name.
        file_suffix: Text placed after the date in each file name.
        date_format: ``strftime`` pattern of the date inside the file name.

    Returns:
        A sorted list of existing file paths, or ``None`` when no file in the
        range exists.
    """
    iso_pattern = '%Y-%m-%d'
    first_day = datetime.strptime(start_date_str, iso_pattern)
    last_day = datetime.strptime(end_date_str, iso_pattern)
    # Both bounds are midnight timestamps, so the difference is a whole number
    # of days; a reversed range yields a non-positive count and an empty loop.
    day_count = (last_day - first_day).days + 1

    found = []
    for offset in range(day_count):
        day = first_day + timedelta(days=offset)
        candidate = os.path.join(
            data_dir, f"{file_prefix}{day.strftime(date_format)}{file_suffix}")
        if not os.path.exists(candidate):
            print(f"警告: 文件 {candidate} 未找到，将被跳过。")
            continue
        found.append(candidate)

    if found:
        return sorted(found)

    print(f"错误: 在目录 '{data_dir}' 中未找到 {start_date_str} 到 {end_date_str} 范围内的任何文件。")
    return None

def load_and_preprocess_sst_data(file_paths, lat, lon, var_name):
    """Load the SST series at one grid point and clean it for modelling.

    The files are merged with ``xarray.open_mfdataset``, the variable is
    sampled at the grid point nearest to (``lat``, ``lon``), converted from
    Kelvin to Celsius when its ``units`` attribute is ``K`` (case-insensitive),
    and NaN gaps are filled forward and then backward.

    Args:
        file_paths: NetCDF file paths to merge; an empty or ``None`` value
            short-circuits to ``(None, None)``.
        lat: Latitude of the point to extract.
        lon: Longitude of the point to extract.
        var_name: Name of the data variable inside the dataset.

    Returns:
        ``(sst_values, time_coords)`` as NumPy arrays, or ``(None, None)`` when
        there is nothing to load or any step fails (the error is printed).
    """
    if not file_paths:
        return None, None

    try:
        print(f"找到 {len(file_paths)} 个文件，将使用 xarray.open_mfdataset 进行合并。")
        # The context manager releases every underlying file handle as soon as
        # the values have been pulled into memory.
        with xr.open_mfdataset(file_paths, combine='by_coords', engine='netcdf4') as ds_combined:
            print("\n合并后的数据集信息:")
            print(ds_combined)

            sst_series_xr = ds_combined[var_name].sel(latitude=lat, longitude=lon, method='nearest')

            units = sst_series_xr.attrs.get('units', '')
            # `.values` materialises the (possibly lazy) arrays before the files close.
            time_coords = sst_series_xr['time'].values
            sst_values = sst_series_xr.values

        if units.lower() == 'k':
            sst_values = sst_values - 273.15
            print("SST数据已从开尔文转换为摄氏度。")

        sst_series = pd.Series(sst_values)
        nan_count = sst_series.isnull().sum()
        if nan_count > 0:
            print(f"数据中发现 {nan_count} 个NaN值，将使用前向填充然后后向填充处理。")
            # Series.ffill()/bfill() replace the deprecated fillna(method=...) form.
            sst_series = sst_series.ffill().bfill()
            sst_values = sst_series.values
            # Only an all-NaN series can still contain NaN after both passes.
            if sst_series.isnull().any():
                raise ValueError("处理后仍存在NaN值，请检查数据源。")

        # Roughly one extra batch beyond a single sample is wanted for testing.
        min_required_length = LOOK_BACK + PREDICT_STEPS + BATCH_SIZE
        if len(sst_values) < min_required_length:
            print(f"警告: 数据点较少 ({len(sst_values)}), 最小需求约为 {min_required_length}。可能不足以进行有效的训练和测试。")
            if len(sst_values) <= LOOK_BACK + PREDICT_STEPS:
                raise ValueError(f"数据点 ({len(sst_values)}) 过少，无法满足 look_back和predict_steps的要求。")
        return sst_values, time_coords
    except Exception as e:
        # Best-effort loader: report the problem and let the caller decide to stop.
        print(f"加载或处理数据时发生错误: {e}")
        return None, None

def preview_raw_data(sst_values, time_coords, lat, lon, start_date_str, end_date_str):
    """Plot the raw SST series as a quick visual sanity check.

    The x axis shows dates when ``time_coords`` is given and has the same
    length as ``sst_values``; otherwise it shows the plain sample index.
    Nothing is drawn when there is no data.
    """
    has_data = sst_values is not None and len(sst_values) > 0
    if not has_data:
        print("无有效SST数据，无法进行预览。")
        return

    series_label = f'SST at ({lat}, {lon})'

    plt.figure(figsize=(14, 6))
    use_dates = time_coords is not None and len(time_coords) == len(sst_values)
    if use_dates:
        plt.plot(time_coords, sst_values, label=series_label)
        plt.xlabel('日期 (Date)')
    else:
        plt.plot(sst_values, label=series_label)
        plt.xlabel('时间步 (Time Step)')

    plt.title(f'原始SST时间序列预览 ({start_date_str} to {end_date_str})')
    plt.ylabel('SST (°C)')
    plt.legend()
    plt.grid(True)
    plt.show()

def scale_data(sst_values):
    """Min-max scale the series into the range [0, 1].

    Returns:
        ``(sst_scaled, scaler)``: the scaled values as a single-column 2-D
        array, and the fitted ``MinMaxScaler`` needed to invert the scaling.
    """
    column = sst_values.reshape(-1, 1)  # the scaler expects (n_samples, n_features)
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(column)
    return scaler.transform(column), scaler

def create_sequences(data, look_back, predict_steps):
    """Slice a single-column series into sliding (input, target) windows.

    Args:
        data: 2-D array whose column 0 holds the series.
        look_back: Number of consecutive values in each input window.
        predict_steps: Number of values that follow the window as the target.

    Returns:
        ``(X, y)`` arrays with one row per window. Both are empty arrays when
        ``data`` is shorter than ``look_back + predict_steps``.
    """
    window = look_back + predict_steps
    if len(data) < window:
        print(f"数据长度 {len(data)} 不足以创建 look_back={look_back}, predict_steps={predict_steps} 的序列。")
        return np.array([]), np.array([])

    starts = range(len(data) - window + 1)
    inputs = [data[s:s + look_back, 0] for s in starts]
    targets = [data[s + look_back:s + window, 0] for s in starts]
    return np.array(inputs), np.array(targets)

def split_and_prepare_data(X_data, y_data, train_split_ratio, batch_size, look_back):
    """Split the windows chronologically and wrap both parts in DataLoaders.

    The first ``train_split_ratio`` share of the windows is used for training
    and the remainder for testing. When the test part would hold fewer than
    ``max(1, batch_size // 2)`` windows the split point is moved, or all
    windows go to training when there are too few to split.

    Returns:
        ``(train_loader, test_loader, X_test_tensor, y_test_tensor)``. The
        last three are ``None`` when the test part is empty; all four are
        ``None`` when no training data can be produced.
    """
    nothing = (None, None, None, None)
    if X_data.size == 0 or y_data.size == 0:
        print("未能从数据中创建任何训练/测试序列。")
        return nothing

    total = len(X_data)
    cut = int(total * train_split_ratio)

    # Smallest test part considered meaningful: half a batch, but at least one window.
    min_test = max(1, batch_size // 2)
    if total - cut < min_test:
        if total > look_back + min_test:
            # Enough windows overall: move the cut so the test part reaches the minimum.
            cut = total - min_test
            print(f"调整了训练/测试分割点，以确保测试集至少有 {min_test} 个样本。新的分割索引：{cut}")
        else:
            print("警告：总数据量过少，无法划分出有意义的测试集。将使用所有数据进行训练，评估将不可靠。")
            cut = total

    if cut == 0 and total > 0:
        # The ratio produced an empty training part; try to recover.
        print("警告：计算得到的训练集为空，将尝试调整。")
        if total > min_test:
            cut = total - min_test
        elif total > 0:
            cut = total
        else:
            print("错误：没有序列数据可用于训练。")
            return nothing

    X_train, y_train = X_data[:cut], y_data[:cut]
    X_test, y_test = X_data[cut:], y_data[cut:]

    print(f"训练集序列数: {len(X_train)}, 测试集序列数: {len(X_test)}")

    if len(X_train) == 0:
        print("错误：训练集为空，无法继续。")
        return nothing

    def as_tensors(features, targets):
        # Inputs gain a trailing feature axis of size 1; targets keep their shape.
        return (torch.from_numpy(features).float().unsqueeze(-1),
                torch.from_numpy(targets).float())

    X_train_tensor, y_train_tensor = as_tensors(X_train, y_train)
    train_loader = DataLoader(SSTDataset(X_train_tensor, y_train_tensor),
                              batch_size=batch_size, shuffle=True)

    if len(X_test) == 0:
        print("警告: 测试集为空。模型将仅在训练数据上训练，无法进行独立验证或最终评估。")
        return train_loader, None, None, None

    X_test_tensor, y_test_tensor = as_tensors(X_test, y_test)
    # Test order is kept so predictions can be aligned with the original series.
    test_loader = DataLoader(SSTDataset(X_test_tensor, y_test_tensor),
                             batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, X_test_tensor, y_test_tensor

In [23]:
# -----------------------------------------------------------------------------
# 模块 2: 模型定义模块
# -----------------------------------------------------------------------------

class SSTDataset(Dataset):
    """Dataset that pairs each input window with its target.

    ``X`` and ``y`` are stored as given and indexed in lockstep, so sample
    ``i`` is ``(X[i], y[i])``.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # The number of samples is defined by the inputs.
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear head on the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer's hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced per sample.
        dropout_rate: Dropout applied between LSTM layers; ignored when
            ``num_layers`` is 1 because there is no layer boundary to apply
            it to.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout_rate):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # nn.LSTM only applies dropout between layers, so it is meaningless
        # (and warned about) for a single layer.
        lstm_dropout = dropout_rate if num_layers > 1 else 0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch shaped (batch, seq_len, input_size) to (batch, output_size)."""
        # nn.LSTM starts from zero hidden/cell states when none are passed.
        # Relying on that avoids allocating them on the CPU and copying them
        # to the device on every call, and keeps them in the model's dtype.
        out, _ = self.lstm(x)
        # Only the output at the final time step feeds the prediction head.
        return self.fc(out[:, -1, :])

In [24]:
# -----------------------------------------------------------------------------
# 模块 3: 训练模块
# -----------------------------------------------------------------------------

def train_model_epoch(model, data_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``data_loader``.

    Returns:
        The mean of the per-batch losses for this epoch.
    """
    model.train()
    running_total = 0
    for inputs, targets in data_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_loss = criterion(model(inputs), targets)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        running_total += batch_loss.item()
    return running_total / len(data_loader)

def validate_model_epoch(model, data_loader, criterion, device):
    """Compute the mean per-batch loss over ``data_loader`` without training.

    The model is switched to eval mode in every case. ``nan`` is returned
    when the loader is ``None`` or yields no batches.
    """
    model.eval()
    if data_loader is None or len(data_loader) == 0:
        return float('nan')

    batch_losses = []
    with torch.no_grad():
        for inputs, targets in data_loader:
            prediction = model(inputs.to(device))
            batch_losses.append(criterion(prediction, targets.to(device)).item())
    return sum(batch_losses) / len(data_loader)

def run_training_loop(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, best_model_path):
    """Train for ``num_epochs`` epochs, checkpointing the best model.

    After every epoch the model is validated on ``val_loader``. Whenever the
    validation loss improves, the state dict is saved to ``best_model_path``.
    When no validation loss is available (``nan``), a fallback checkpoint
    named ``sst_model_epoch_<n>.pth`` is written to the working directory
    every 10th epoch and after the final epoch.

    Args:
        model: Network to train.
        train_loader: Batches used for optimisation.
        val_loader: Batches used for validation; may be ``None``.
        criterion: Loss function.
        optimizer: Optimiser bound to ``model``'s parameters.
        num_epochs: Number of epochs to run.
        device: Device the batches are moved to.
        best_model_path: Destination of the best checkpoint.

    Returns:
        ``(train_losses, val_losses)``, one entry per epoch; validation
        entries are ``nan`` when no validation data exists.
    """
    # torch.save does not create missing directories, so make sure the
    # checkpoint's parent exists before any training time is spent.
    checkpoint_dir = os.path.dirname(best_model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')

    print("\n开始训练模型...")
    for epoch in range(num_epochs):
        avg_train_loss = train_model_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(avg_train_loss)

        avg_val_loss = validate_model_epoch(model, val_loader, criterion, device)
        # NaN entries are recorded too, so both lists stay aligned by epoch.
        val_losses.append(avg_val_loss)
        has_val_loss = not np.isnan(avg_val_loss)

        val_loss_str = f"{avg_val_loss:.6f}" if has_val_loss else "N/A (无验证集)"
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss_str}")

        if has_val_loss and avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            print(f"  验证损失改善，模型已保存至 {best_model_path}")
        elif not has_val_loss and ((epoch + 1) % 10 == 0 or (epoch + 1) == num_epochs):
            # Without validation there is no "best" model, so save periodically instead.
            fallback_model_path = f"sst_model_epoch_{epoch+1}.pth"
            torch.save(model.state_dict(), fallback_model_path)
            print(f"  模型已保存至 {fallback_model_path} (无验证，按轮次保存)")

    print("训练完成!")
    return train_losses, val_losses

In [25]:
# -----------------------------------------------------------------------------
# 模块 4: 评估模块
# -----------------------------------------------------------------------------

def evaluate_model(model, test_loader, criterion, device, scaler, best_model_path):
    """Evaluate the best checkpoint on the test set in original units.

    The state dict at ``best_model_path`` is loaded when possible; otherwise
    the model's current weights are used. Predictions and targets are mapped
    back through ``scaler.inverse_transform`` before the RMSE is computed.

    Args:
        model: Network to evaluate.
        test_loader: Batches of ``(inputs, targets)``; may be ``None``.
        criterion: Accepted for interface symmetry; not used here.
        device: Device used for inference.
        scaler: Fitted scaler used to undo the normalisation.
        best_model_path: Path of the checkpoint to load.

    Returns:
        ``(predictions, actuals, rmse)`` in original units, or
        ``(None, None, nan)`` when there is no test data.
    """
    if test_loader is None or len(test_loader) == 0:
        print("\n无测试数据，跳过模型评估。")
        return None, None, float('nan')

    try:
        # The checkpoint is a plain state dict, so restrict unpickling to
        # tensors and containers instead of arbitrary Python objects.
        state_dict = torch.load(best_model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"\n已加载最佳模型 {best_model_path} 进行评估...")
    except FileNotFoundError:
        print(f"警告: 未找到最佳模型文件 {best_model_path}。将使用当前模型状态进行评估 (可能不是最佳)。")
    except Exception as e:
        # Deliberate best-effort: fall back to the in-memory weights.
        print(f"加载模型时出错: {e}。将使用当前模型状态进行评估。")

    model.eval()
    all_test_predictions_scaled = []
    all_test_actuals_scaled = []

    with torch.no_grad():
        for batch_X_test, batch_y_test in test_loader:
            batch_X_test = batch_X_test.to(device)
            predictions_scaled = model(batch_X_test)
            all_test_predictions_scaled.append(predictions_scaled.cpu().numpy())
            # Targets were never moved to the device, so no .cpu() is needed.
            all_test_actuals_scaled.append(batch_y_test.numpy())

    if not all_test_predictions_scaled:
        print("未能生成测试集预测。")
        return None, None, float('nan')

    test_predictions_s = np.concatenate(all_test_predictions_scaled)
    test_actuals_s = np.concatenate(all_test_actuals_scaled)

    # Undo the min-max scaling so the error is reported in degrees Celsius.
    test_predictions_orig = scaler.inverse_transform(test_predictions_s)
    test_actuals_orig = scaler.inverse_transform(test_actuals_s)

    rmse = np.sqrt(mean_squared_error(test_actuals_orig, test_predictions_orig))
    print(f"测试集均方根误差 (RMSE): {rmse:.4f} °C")

    return test_predictions_orig, test_actuals_orig, rmse

In [26]:
# -----------------------------------------------------------------------------
# 模块 5: 可视化模块
# -----------------------------------------------------------------------------

def plot_losses(train_losses, val_losses):
    """Plot the training and validation loss curves against the epoch index.

    NaN validation entries (epochs without validation data) are skipped; every
    remaining value is drawn at the epoch it was recorded in.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='训练损失 (Training Loss)')

    # Pair each finite validation loss with its own epoch index. Filtering the
    # values first and right-aligning them would shift the curve whenever the
    # NaN entries are not all at the beginning.
    val_points = [(epoch, loss) for epoch, loss in enumerate(val_losses)
                  if not np.isnan(loss)]
    if val_points:
        val_epochs, val_values = zip(*val_points)
        plt.plot(val_epochs, val_values, label='验证损失 (Validation Loss)')

    plt.title('模型训练和验证损失')
    plt.xlabel('迭代次数 (Epoch)')
    plt.ylabel('损失 (MSE)')
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_test_predictions(test_actuals_orig, test_predictions_orig, predict_steps, y_test_len_before_concat):
    """Plot predicted against actual SST for the test samples.

    Only the first ``y_test_len_before_concat`` samples are drawn. For
    multi-step forecasts only the first predicted step of each sample is shown.
    Nothing is drawn when either array is ``None``.
    """
    if test_actuals_orig is None or test_predictions_orig is None:
        print("无测试结果可供绘制。")
        return

    shown = y_test_len_before_concat
    plt.figure(figsize=(15, 7))
    # The x axis is simply the position of each sample within the test set.
    sample_axis = np.arange(shown)

    if predict_steps == 1:
        plt.plot(sample_axis, test_actuals_orig.flatten()[:shown],
                 label='真实SST (Actual Test SST)', color='blue')
        plt.plot(sample_axis, test_predictions_orig.flatten()[:shown],
                 label='预测SST (Predicted Test SST)', color='red', linestyle='--')
    else:
        # Multi-step forecast: column 0 is the first predicted step.
        plt.plot(sample_axis, test_actuals_orig[:shown, 0],
                 label='真实SST (Actual Test SST - Step 1)', color='blue')
        plt.plot(sample_axis, test_predictions_orig[:shown, 0],
                 label='预测SST (Predicted Test SST - Step 1)', color='red', linestyle='--')

    plt.title(f'测试集SST预测 vs 真实值 (未来 {predict_steps} 天)')
    plt.xlabel(f'测试集样本索引')
    plt.ylabel('SST (°C)')
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_predictions_on_original(original_sst_values, original_time_coords,
                                 test_predictions_orig, test_actuals_orig,
                                 split_index_in_sequences, look_back, predict_steps):
    """Overlay the test-set predictions on the full original series.

    The first test window starts at ``split_index_in_sequences`` in the
    original series, so its target (and the first prediction) sits
    ``look_back`` positions later. Predictions that would run past the end of
    the series are dropped. For multi-step forecasts only the first predicted
    step is drawn. Nothing is drawn when either result array is ``None``.
    """
    if test_predictions_orig is None or test_actuals_orig is None:
        print("无测试预测结果，无法绘制叠加图。")
        return

    plt.figure(figsize=(18, 8))

    # Background: the whole original series, against dates when they line up.
    x_axis_is_time = (original_time_coords is not None
                      and len(original_time_coords) == len(original_sst_values))
    if x_axis_is_time:
        plt.plot(original_time_coords, original_sst_values,
                 label='原始SST数据', color='grey', alpha=0.7)
    else:
        plt.plot(original_sst_values, label='原始SST数据', color='grey', alpha=0.7)

    # Position in the original series of the first predicted value.
    first_target = split_index_in_sequences + look_back
    # Never draw past the end of the original series.
    shown = min(len(test_predictions_orig), len(original_sst_values) - first_target)

    if shown <= 0:
        print("警告：无法在原始数据图上叠加测试集预测，可能是索引计算问题或数据太少。")
    else:
        head = test_predictions_orig[:shown]
        curve = head[:, 0] if predict_steps > 1 else head.flatten()
        index_axis = np.arange(first_target, first_target + shown)

        if not x_axis_is_time:
            plt.plot(index_axis, curve,
                     label=f'测试集预测 (Step 1)', color='red', linestyle='-')
        else:
            time_axis = original_time_coords[first_target:first_target + shown]
            if len(time_axis) == len(curve):
                plt.plot(time_axis, curve,
                         label=f'测试集预测 (Step 1)', color='red', linestyle='-')
            else:
                # Fall back to index positions when dates and predictions disagree.
                print("时间坐标与预测数据长度不匹配，绘图可能不准确。")
                plt.plot(index_axis, curve,
                         label=f'测试集预测 (Step 1, approx.)', color='red', linestyle='-')

        # Vertical marker at the first predicted position.
        if not x_axis_is_time:
            plt.axvline(x=first_target, color='green', linestyle='--', label=f'测试集预测开始')
        elif first_target < len(original_time_coords):
            plt.axvline(x=original_time_coords[first_target], color='green',
                        linestyle='--', label=f'测试集预测开始')

    plt.title('SST预测叠加在原始数据上')
    plt.xlabel('日期 (Date)' if x_axis_is_time else '时间步 (原始数据索引)')
    plt.ylabel('SST (°C)')
    plt.legend()
    plt.grid(True)
    plt.show()

In [27]:
# -----------------------------------------------------------------------------
# 模块 6: 主执行流程
# -----------------------------------------------------------------------------
def main():
    """Run the full pipeline: load data, train, evaluate and plot.

    Every stage prints its progress; the function returns early (with a
    message) as soon as a stage cannot produce what the next one needs.
    """
    print("--- 开始SST预测流程 ---")

    # 1. Data loading and preprocessing
    print("\n--- 1. 数据加载和预处理 ---")
    file_paths = get_file_paths_in_range(DATA_DIR, START_DATE_STR, END_DATE_STR,
                                         FILE_PREFIX, FILE_SUFFIX, FILENAME_DATE_FORMAT)
    if not file_paths:
        print("未能获取文件路径，程序退出。")
        return

    sst_values, time_coords = load_and_preprocess_sst_data(file_paths, LATITUDE_POINT, LONGITUDE_POINT, VARIABLE_NAME)
    if sst_values is None:
        print("未能加载SST数据，程序退出。")
        return

    preview_raw_data(sst_values, time_coords, LATITUDE_POINT, LONGITUDE_POINT, START_DATE_STR, END_DATE_STR)
    
    sst_scaled, scaler = scale_data(sst_values)
    
    X_sequences, y_sequences = create_sequences(sst_scaled, LOOK_BACK, PREDICT_STEPS)
    if X_sequences.size == 0:
        print("未能创建序列数据，程序退出。")
        return
    
    # Remember how many windows exist before splitting; used below as a
    # fallback when computing where the overlay plot starts.
    num_original_sequences = len(X_sequences)

    train_loader, test_loader, X_test_tensor, y_test_tensor = split_and_prepare_data(
        X_sequences, y_sequences, TRAIN_SPLIT_RATIO, BATCH_SIZE, LOOK_BACK
    )

    if train_loader is None:
        print("未能准备训练数据加载器，程序退出。")
        return
    
    # Index of the first test window within the full list of windows, needed
    # later for the overlay plot. The split point may have been adjusted inside
    # split_and_prepare_data, so the size of the training dataset is read
    # directly; the ratio-based value is only a fallback when the dataset does
    # not expose `X`.
    split_index_in_sequences = len(train_loader.dataset.X) if hasattr(train_loader.dataset, 'X') else int(num_original_sequences * TRAIN_SPLIT_RATIO)


    # 2. Model initialisation
    print("\n--- 2. 模型初始化 ---")
    model = LSTMModel(input_size=1, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS,
                      output_size=PREDICT_STEPS, dropout_rate=DROPOUT_RATE).to(DEVICE)
    print(model)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # 3. Model training
    print("\n--- 3. 模型训练 ---")
    # NOTE: the test loader doubles as the validation set inside the training loop.
    train_losses, val_losses = run_training_loop(model, train_loader, test_loader, criterion, optimizer,
                                                 NUM_EPOCHS, DEVICE, BEST_MODEL_PATH)
    plot_losses(train_losses, val_losses)

    # 4. Model evaluation
    print("\n--- 4. 模型评估 ---")
    # y_test_tensor holds all test targets (not batched by the DataLoader); its
    # length tells the plot how many test samples there are.
    y_test_len_for_plot = len(y_test_tensor) if y_test_tensor is not None else 0

    test_predictions_orig, test_actuals_orig, rmse = evaluate_model(
        model, test_loader, criterion, DEVICE, scaler, BEST_MODEL_PATH
    )
    
    # 5. Result visualisation (only when evaluation produced results)
    if test_predictions_orig is not None and test_actuals_orig is not None:
        print("\n--- 5. 结果可视化 ---")
        plot_test_predictions(test_actuals_orig, test_predictions_orig, PREDICT_STEPS, y_test_len_for_plot)
        
        plot_predictions_on_original(sst_values, time_coords,
                                     test_predictions_orig, test_actuals_orig,
                                     split_index_in_sequences, LOOK_BACK, PREDICT_STEPS)
    else:
        print("由于评估未成功或无测试数据，跳过部分结果可视化。")

    print("\n--- SST预测流程结束 ---")

# Run the pipeline only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()

--- 开始SST预测流程 ---

--- 1. 数据加载和预处理 ---
找到 91 个文件，将使用 xarray.open_mfdataset 进行合并。
加载或处理数据时发生错误: [Errno -101] NetCDF: HDF error: 'D:\\SCUT\\Second Year\\Second Semester\\地理所Project\\SST_Prediciton\\data\\ERA5\\20200101_ERA5_daily_mean_sst.nc'
未能加载SST数据，程序退出。


In [35]:
import os

import xarray as xr

# Diagnostic: open one ERA5 file directly to reproduce the load failure in
# isolation from the multi-file pipeline.
# NOTE(review): the reported "NetCDF: HDF error" may be related to the
# non-ASCII characters in this path or to a corrupted download. TODO confirm
# by copying the file to an ASCII-only path and retrying.
root_dir = "D:\\SCUT\\Second Year\\Second Semester\\地理所Project\\SST_Prediciton"
file = os.path.join(root_dir, "data", "ERA5", "20200101_ERA5_daily_mean_sst.nc")
try:
    dt = xr.open_dataset(file)
    # Dataset.info() writes the summary to stdout itself and returns None, so
    # wrapping it in print() would only append a stray "None" line.
    dt.info()
except Exception as e:
    print(e)

[Errno -101] NetCDF: HDF error: 'D:\\SCUT\\Second Year\\Second Semester\\地理所Project\\SST_Prediciton\\data\\ERA5\\20200101_ERA5_daily_mean_sst.nc'
