# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

# 已定义的模块
from model.CNN_feature_extractor import CNNFeatureExtractor
from model.model1 import TemporalLSTM
from dataloader1 import ImageSequenceDataset 


# In [None]:
# Image preprocessing
import os
from model.model1 import OzonePredictor


transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


def _resolve_root_dir():
    """Return the parent of the directory containing this code.

    When run as a script this is ``dirname(dirname(abspath(__file__)))``.
    ``__file__`` is not defined when this code runs as a Jupyter notebook
    cell, so in that case the current working directory is used as the
    code's directory instead of crashing with ``NameError``.
    """
    try:
        code_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # Interactive session (notebook / REPL): no __file__ exists.
        code_dir = os.getcwd()
    return os.path.dirname(code_dir)


# Paths: the dataset index lives at <root_dir>/data/dataset.json
root_dir = _resolve_root_dir()
json_path = os.path.join(root_dir, 'data/dataset.json')
json_path = os.path.normpath(json_path)

# Hyperparameters
batch_size = 4
num_epochs = 20
lr = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data loading
dataset = ImageSequenceDataset(json_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model construction
model = OzonePredictor(cnn_out_dim=128, lstm_hidden_dim=64, output_dim=3)
model.to(device)

# Loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in dataloader:
        # Shapes below are the original author's expectations — TODO confirm
        # against ImageSequenceDataset.
        imgs = batch['images'].to(device)   # expected (B, T, 4, C, H, W)
        labels = batch['npy'].to(device)    # expected (B, T) or (B, output_dim)

        optimizer.zero_grad()
        outputs = model(imgs)               # expected (B, output_dim)
        # Match the target dtype to the prediction: targets loaded from .npy
        # files may be float64 while the model emits float32, which MSELoss
        # does not accept. No-op when dtypes already agree.
        loss = criterion(outputs, labels.to(outputs.dtype))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # total_loss sums per-batch mean losses; report the per-batch average so
    # the value does not depend on how many batches an epoch has.
    avg_loss = total_loss / max(len(dataloader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")