In [1]:
%load_ext autoreload
%autoreload 2

# 训练 Conv-LSTM 模型。
# 该模型通过同一个月的前 14 天的 SST 数据预测未来 1 天的 SST 数据。

# 导入数据集
from torch.utils.data import DataLoader

from src.dataset.ERA5 import ERA5SSTDataset

# 定义参数
offset = 1461
width = 15
step = 15
lon = [60, 80]
lat=[160, 180]

# 创建全新的数据集
train_data_set = ERA5SSTDataset(width, step, offset, lon, lat)
test_data_set = ERA5SSTDataset(width, step, offset + 10, lon, lat)

train_dataloader = DataLoader(train_data_set, batch_size=10, shuffle=False)
test_dataloader = DataLoader(test_data_set, batch_size=10, shuffle=False)

In [None]:
from lightning import Trainer, seed_everything

from src.models.LSTM import ConvLSTM
from src.config.params import CHECK_POINT

# Training configuration, kept together so runs are easy to tune and reproduce.
SEED = 42                 # seeds Python/NumPy/torch RNGs so weight init and training are reproducible
MAX_EPOCHS = 100
LIMIT_TRAIN_BATCHES = 20  # number of train_dataloader batches used per epoch

seed_everything(SEED)

# NOTE(review): arguments appear to be (input_dim=1, hidden_dim=5); the printed repr shows a
# first-layer Conv2d(6, 20), consistent with 1 + 5 input channels — confirm in src.models.LSTM.
# NOTE(review): the training summary reports the dense `fc` head at ~204M params (~819 MB);
# a convolutional output head would be far smaller — worth revisiting.
model = ConvLSTM(1, 5, kernel_size=(5, 5), num_layers=2)
print(model)

trainer = Trainer(max_epochs=MAX_EPOCHS, limit_train_batches=LIMIT_TRAIN_BATCHES)
# NOTE(review): ckpt_path makes fit() resume from CHECK_POINT; pass None to train from scratch.
trainer.fit(model, train_dataloaders=train_dataloader, ckpt_path=CHECK_POINT)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


ConvLSTM(
  (cell_list): ModuleList(
    (0): ConvLSTMCell(
      (conv): Conv2d(6, 20, kernel_size=(5, 5), stride=(1, 1), padding=same)
    )
    (1): ConvLSTMCell(
      (conv): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1), padding=same)
    )
  )
  (fc): Linear(in_features=32000, out_features=6400, bias=True)
)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params | Mode 
-------------------------------------------------
0 | cell_list | ModuleList | 8.0 K  | train
1 | fc        | Linear     | 204 M  | train
-------------------------------------------------
204 M     Trainable params
0         Non-trainable params
204 M     Total params
819.258   Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


Epoch 0:   0%|          | 0/20 [00:00<?, ?it/s] output: torch.Size([10, 1, 80, 80])
losses: 660.983642578125
Epoch 0:   5%|▌         | 1/20 [00:00<00:09,  1.94it/s, v_num=0]output: torch.Size([10, 1, 80, 80])
losses: 606.7784423828125
Epoch 0:  10%|█         | 2/20 [00:00<00:08,  2.12it/s, v_num=0]output: torch.Size([10, 1, 80, 80])
losses: 558.1234741210938
Epoch 0:  15%|█▌        | 3/20 [00:01<00:07,  2.21it/s, v_num=0]output: torch.Size([10, 1, 80, 80])
losses: 510.2666015625
Epoch 0:  20%|██        | 4/20 [00:01<00:07,  2.28it/s, v_num=0]output: torch.Size([10, 1, 80, 80])
losses: 457.17413330078125
Epoch 0:  25%|██▌       | 5/20 [00:02<00:06,  2.31it/s, v_num=0]output: torch.Size([10, 1, 80, 80])
losses: 401.9193420410156
Epoch 0:  30%|███       | 6/20 [00:02<00:06,  2.33it/s, v_num=0]output: torch.Size([10, 1, 80, 80])
losses: 354.2073059082031
Epoch 0:  35%|███▌      | 7/20 [00:02<00:05,  2.35it/s, v_num=0]output: torch.Size([10, 1, 80, 80])
losses: 311.63616943359375
Epoch 0:  

In [None]:
# Prediction: run the trained model over the test dataloader and compare with ground truth.
import numpy as np
import torch
from sklearn.metrics import mean_squared_error

from src.utils.plot import plot_sst_distribution_compare  # plots predicted vs. true SST distributions
from src.models.model import ssim_loss  # custom SSIM loss used for training (not used in this cell)

SAMPLE_IDX = 5  # test sample to inspect and plot (0-based, i.e. the 6th sample)


def predict_dataloader(net, loader):
    """Run ``net`` in inference mode over every batch of ``loader``.

    Returns ``(predictions, targets)`` as NumPy arrays concatenated along the batch axis.
    Assumes each batch is an ``(inputs, targets)`` pair and that ``net(inputs)`` returns the
    prediction — TODO confirm against ERA5SSTDataset / ConvLSTM.forward.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    net.eval()
    preds, targets = [], []
    with torch.no_grad():
        for inputs, target in loader:
            # Inputs follow the model's device (Lightning may leave it on CPU after fit()).
            preds.append(net(inputs.to(net.device)).cpu())
            targets.append(target.cpu())
    if not preds:
        raise ValueError("loader yielded no batches")
    return torch.cat(preds).numpy(), torch.cat(targets).numpy()


# `model` is a Lightning/PyTorch module (no Keras-style predict/evaluate), so predictions
# are gathered batch by batch from the held-out loader.
y_pred, y_test = predict_dataloader(model, test_dataloader)

# RMSE over the whole test set; flatten each sample so prediction/target layouts line up
# (the model emits (N, 1, H, W); the target layout is not fixed here).
n_samples = len(y_test)
test_rmse = np.sqrt(
    mean_squared_error(y_test.reshape(n_samples, -1), y_pred.reshape(n_samples, -1))
)
print("Test RMSE (all samples): ", test_rmse)

# Pick one sample and drop singleton (channel) axes to get 2-D SST fields.
y = y_pred[SAMPLE_IDX].squeeze()  # prediction
g = y_test[SAMPLE_IDX].squeeze()  # ground truth
assert y.shape == g.shape, f"prediction shape {y.shape} != target shape {g.shape}"
print(y.shape)
print(g.shape)

print(y)
print(g)

# RMSE is the square root of the MSE; mean_squared_error alone returns the MSE.
rmse = np.sqrt(mean_squared_error(g, y))
print("RMSE: ", rmse)

plot_sst_distribution_compare(y, g)