In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import torch
import copy
import tqdm
import random
import numpy as np
import xarray as xr
from collections import defaultdict
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.metrics import confusion_matrix, recall_score, f1_score, precision_score, r2_score
import matplotlib.pyplot as plt
# Plotting the confusion matrix
import seaborn as sns

In [None]:
def seed_torch(seed=1029):
    """Seed every random number generator used in this notebook and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, the torch CPU RNG and the torch CUDA RNGs
    (current device and all devices), then disables cuDNN autotuning and requests
    deterministic cuDNN kernels.

    Note: ``PYTHONHASHSEED`` is exported for child processes started afterwards; the hash seed
    of the already-running interpreter is fixed at startup and is not changed by this call.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # relevant when training on several GPUs
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Autotuning may pick different kernels from run to run; turn it off and ask for
    # deterministic implementations instead.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
class Model_class(nn.Module):
    """Small CNN with a shared trunk and two output heads for multi-task learning.

    One convolutional block (5x5 conv -> ReLU -> 2x2 average pool) feeds a shared fully
    connected layer. From there, one linear head produces a single regression value
    (Niño3.4) and another produces class logits (5 classes by default).

    Args:
        in_channels: Number of input channels (default 6).
        input_size: (height, width) of the input grid (default (24, 36)). Only used to size
            the first fully connected layer, via a dummy forward pass.
        num_classes: Number of logits produced by the classification head (default 5).
        hidden_dim: Width of the shared fully connected layer (default 100).
        dropout: Dropout probability used in the shared layers (default 0.2).
    """

    def __init__(self, in_channels=6, input_size=(24, 36), num_classes=5, hidden_dim=100, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        # Infer the flattened feature count with a dummy forward pass instead of hard-coding it,
        # so the first Linear layer stays correct if the input size or the conv block changes.
        with torch.no_grad():
            dummy_input = torch.zeros(1, in_channels, *input_size)
            n_features = self.conv1(dummy_input).flatten(start_dim=1).shape[1]

        # Shared fully connected layers
        self.fc_shared = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(n_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Separate output heads
        self.fc_regression = nn.Linear(hidden_dim, 1)                # Niño3.4 regression output
        self.fc_classification = nn.Linear(hidden_dim, num_classes)  # classification logits

    def forward(self, x):
        """Run the shared trunk and both heads.

        Args:
            x: Input batch of shape (batch, in_channels, height, width).

        Returns:
            Tuple ``(regression, logits)`` with shapes (batch, 1) and (batch, num_classes).
        """
        x = self.conv1(x)
        x = torch.flatten(x, start_dim=1)
        x_shared = self.fc_shared(x)
        return self.fc_regression(x_shared), self.fc_classification(x_shared)


In [None]:
# Shape walk-through: push one random sample through each stage of Model_class and print the
# output shape after every layer. eval() + no_grad() disable dropout and skip building an
# autograd graph; neither changes the shapes being inspected.
net = Model_class()
net.eval()
with torch.no_grad():
    X = torch.randn(1, 6, 24, 36)

    # Convolutional feature extractor
    for layer in net.conv1:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:\t', X.shape)

    # Flatten all dimensions except the batch dimension
    X = torch.flatten(X, 1)

    # Shared fully connected layers
    for layer in net.fc_shared:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:\t', X.shape)

    # Regression head (Niño3.4)
    X_pr = net.fc_regression(X)
    print(f"Regression output for label_nino34形状:\t{X_pr.shape}")

    # Classification head (softmax5) -- previously mislabelled as a regression output
    X_month = net.fc_classification(X)
    print(f"Classification output for label_softmax5形状:\t{X_month.shape}")


In [None]:
class MyDataset(Dataset):
    """
    Custom Dataset for loading images and labels from NetCDF files for multi-task learning.

    Each sample consists of:
    - image: 'ssta' data at a given time index (the model expects shape (6, 24, 36))
    - label_nino34: 'Niño3.4' variable at the same time index, scalar regression target
    - label_softmax5: 'softmax5' variable at the same time index, integer class label
      (expected 0-4 to match the model's 5-way classification head)

    Train/val split: the time indices are permuted with a dedicated RNG seeded by
    ``split_seed``, and the first ``train_fraction`` of the permutation becomes 'train'.
    The global NumPy RNG is not involved, so a 'train' dataset and a 'val' dataset built
    with the same arguments always get complementary, non-overlapping index sets.
    """

    def __init__(self,
                 data_root='.',    # Root directory containing the 'train' folder
                 phase='train',    # 'train' or 'val' to specify dataset phase
                 normalize=False,  # Whether to normalize the input data
                 transform=None,
                 input_filename='',
                 label_filename='',
                 split_seed=42,
                 train_fraction=0.7):
        """
        Load the input and label NetCDF files and select this phase's sample indices.

        Args:
            data_root (str): Root directory containing the 'train' folder with the data files.
            phase (str): Which split to expose, 'train' or 'val'.
            normalize (bool): If True, standardize images with self.mean / self.std
                (currently placeholder values 0 and 1, i.e. a no-op).
            transform (callable, optional): Applied to each image (a float32 NumPy array);
                it may return either a NumPy array or a tensor.
            input_filename (str): File name of the input ('ssta') NetCDF inside '<data_root>/train'.
            label_filename (str): File name of the label NetCDF inside '<data_root>/train'.
            split_seed (int): Seed of the RNG that permutes indices before splitting. Use the
                same value for the 'train' and 'val' datasets so they do not overlap.
            train_fraction (float): Fraction of samples assigned to 'train' (default 0.7).

        Raises:
            ValueError: If ``phase`` is not 'train'/'val', or a label variable does not have
                the same number of time steps as the input images.
        """
        # ----------------------------
        # Load input and label data
        # ----------------------------

        # NOTE(review): with the default empty file names both paths resolve to the
        # '<data_root>/train/' directory itself, which looks like a placeholder; pass the
        # real NetCDF file names via input_filename / label_filename.
        input_file = os.path.join(data_root, 'train', input_filename)
        label_file = os.path.join(data_root, 'train', label_filename)

        # Input data: the 'ssta' variable, indexed along its 'time' dimension
        self.input_ds = xr.open_dataset(input_file)
        self.images = self.input_ds['ssta']

        # Labels: regression target 'Niño3.4' and class target 'softmax5'
        self.label_ds = xr.open_dataset(label_file)
        self.labels_nino34 = self.label_ds['Niño3.4']
        self.labels_softmax5 = self.label_ds['softmax5']

        # ----------------------------
        # Check time alignment (lengths only; the time coordinate values are not compared)
        # ----------------------------
        self.length = self.images.sizes['time']
        label_lengths = {
            'Niño3.4': self.labels_nino34.sizes['time'],
            'softmax5': self.labels_softmax5.sizes['time'],
        }
        mismatched = {name: n for name, n in label_lengths.items() if n != self.length}
        if mismatched:
            raise ValueError(
                f"label time length(s) {mismatched} do not match the {self.length} input time steps")

        # ----------------------------
        # Split data into train and val sets
        # ----------------------------
        self.indices = self._split_indices(self.length, phase, split_seed, train_fraction)

        self.transform = transform
        self.normalize = normalize
        if self.normalize:
            # Placeholder statistics: mean 0 / std 1 leave the data unchanged.
            self.mean = 0.0
            self.std = 1.0

    @staticmethod
    def _split_indices(length, phase, split_seed=42, train_fraction=0.7):
        """Return ``phase``'s share of a seeded random permutation of ``range(length)``.

        A local ``np.random.default_rng(split_seed)`` is used (not the global RNG), so the
        'train' and 'val' calls with the same seed partition one identical permutation.

        Raises:
            ValueError: If ``phase`` is not 'train' or 'val'.
        """
        if phase not in ('train', 'val'):
            raise ValueError("phase must be 'train' or 'val'")
        indices = np.random.default_rng(split_seed).permutation(length)
        split_idx = int(train_fraction * length)
        return indices[:split_idx] if phase == 'train' else indices[split_idx:]

    def __len__(self):
        """Returns the number of samples in this phase."""
        return len(self.indices)

    def __getitem__(self, idx):
        """
        Retrieves the image and labels at the specified position of this phase.

        Args:
            idx (int): Position within this phase's index list.

        Returns:
            image (Tensor): float32 image tensor for that time step.
            labels (tuple): ``(label_nino34, label_softmax5)`` where label_nino34 is a float32
                tensor built from 'Niño3.4' and label_softmax5 is a scalar int64 tensor.
        """
        # Map the phase-local position to the time index in the full dataset
        actual_idx = self.indices[idx]

        # ---- image ----
        image = self.images.isel(time=actual_idx).values.astype(np.float32)
        if self.normalize:
            image = (image - self.mean) / self.std
        if self.transform:
            image = self.transform(image)
        # torch.as_tensor accepts both a NumPy array and a tensor returned by a transform;
        # torch.from_numpy would raise on the latter.
        image = torch.as_tensor(image)

        # ---- 'Niño3.4' regression label ----
        label_nino34 = self.labels_nino34.isel(time=actual_idx).values.astype(np.float32)
        label_nino34 = torch.tensor(label_nino34)

        # ---- 'softmax5' classification label (int64, as CrossEntropyLoss expects) ----
        label_softmax5 = int(self.labels_softmax5.isel(time=actual_idx).values)
        label_softmax5 = torch.tensor(label_softmax5, dtype=torch.long)

        return image, (label_nino34, label_softmax5)


In [None]:
# Smoke test for MyDataset: build both splits and inspect one training sample.
# NumPy's global RNG is re-seeded before each split so both phases see the same permutation,
# even when MyDataset shuffles with np.random; the assertion then checks there is no leakage.
np.random.seed(1024)
train_dataset = MyDataset(data_root='.', phase='train', normalize=False)
np.random.seed(1024)
val_dataset = MyDataset(data_root='.', phase='val', normalize=False)
assert set(train_dataset.indices.tolist()).isdisjoint(val_dataset.indices.tolist()), "train/val splits overlap"

# Access a sample from the training dataset
image, (label_nino34, label_softmax5) = train_dataset[0]
assert tuple(image.shape) == (6, 24, 36), f"unexpected image shape {tuple(image.shape)}"

print("Image shape:", image.shape)
print("Niño3.4 label:", label_nino34.item())
print("Softmax5 label:", label_softmax5.item())


In [None]:
def _run_phase(model, dataloader, phase, optimizer, device, loss_weights=(0.8, 0.2), num_classes=5):
    """Run one full pass over ``dataloader`` in 'train' or 'val' mode and return its metrics.

    In 'train' mode the model is in train() mode, gradients are computed and ``optimizer``
    steps once per batch. Otherwise the model is in eval() mode and runs without gradients.

    Returns:
        dict with keys 'loss', 'mae_nino34', 'mse_nino34', 'r2_nino34', 'acc_softmax5'
        (per-sample averages / scores), 'cm' (num_classes x num_classes confusion matrix),
        and 'precision', 'recall', 'f1' (macro averages).

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train() for 'train', eval() otherwise
    cls_weight, reg_weight = loss_weights
    mse_criterion = nn.MSELoss()
    ce_criterion = nn.CrossEntropyLoss()

    sums = defaultdict(float)
    total_samples = 0
    # Predictions and targets are collected to compute epoch-level statistics
    preds_nino34_list, labels_nino34_list = [], []
    preds_softmax5_list, labels_softmax5_list = [], []

    bar = tqdm.tqdm(dataloader)
    for inputs, (labels_nino34, labels_softmax5) in bar:
        inputs = inputs.to(device)
        labels_nino34 = labels_nino34.to(device)
        labels_softmax5 = labels_softmax5.to(device)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs_nino34, outputs_softmax5 = model(inputs)
            # Drop only the singleton output dimension: (B, 1) -> (B,). A bare squeeze() turns a
            # batch of size 1 into a 0-d tensor, which mismatches the (B,) labels and makes
            # np.concatenate fail below.
            outputs_nino34 = outputs_nino34.squeeze(1)
            loss_nino34 = mse_criterion(outputs_nino34, labels_nino34)
            loss_softmax5 = ce_criterion(outputs_softmax5, labels_softmax5)
            loss = cls_weight * loss_softmax5 + reg_weight * loss_nino34  # multi-task loss

            if is_train:
                loss.backward()
                optimizer.step()

        preds_nino34 = outputs_nino34.detach()
        preds_softmax5 = torch.argmax(outputs_softmax5.detach(), dim=1)

        # Accumulate sample-weighted sums so the epoch averages are exact per-sample means
        batch_size = inputs.size(0)
        sums['loss'] += loss.item() * batch_size
        sums['mae_nino34'] += F.l1_loss(preds_nino34, labels_nino34, reduction='mean').item() * batch_size
        sums['mse_nino34'] += F.mse_loss(preds_nino34, labels_nino34, reduction='mean').item() * batch_size
        sums['corrects_softmax5'] += torch.sum(preds_softmax5 == labels_softmax5).item()
        total_samples += batch_size

        preds_nino34_list.append(preds_nino34.cpu().numpy())
        labels_nino34_list.append(labels_nino34.cpu().numpy())
        preds_softmax5_list.append(preds_softmax5.cpu().numpy())
        labels_softmax5_list.append(labels_softmax5.cpu().numpy())

        bar.set_description(f"{phase.capitalize()} Loss: {sums['loss'] / total_samples :.4f}")
        bar.set_postfix({
            "MSE_Nino34": f"{sums['mse_nino34']/total_samples:.4f}",
            "Acc_Softmax5": f"{sums['corrects_softmax5']/total_samples:.4f}"
        })

    if total_samples == 0:
        raise ValueError(
            f"The {phase} dataloader yielded no samples; check dataset size, batch_size and drop_last.")

    labels_nino34_all = np.concatenate(labels_nino34_list, axis=0)
    preds_nino34_all = np.concatenate(preds_nino34_list, axis=0)
    labels_softmax5_all = np.concatenate(labels_softmax5_list, axis=0)
    preds_softmax5_all = np.concatenate(preds_softmax5_list, axis=0)

    return {
        'loss': sums['loss'] / total_samples,
        'mae_nino34': sums['mae_nino34'] / total_samples,
        'mse_nino34': sums['mse_nino34'] / total_samples,
        # R2 closer to 1 means a better regression fit
        'r2_nino34': r2_score(labels_nino34_all, preds_nino34_all),
        'acc_softmax5': sums['corrects_softmax5'] / total_samples,
        # Fixed label set so the matrix is always num_classes x num_classes, even when some
        # classes are absent from this phase's labels and predictions.
        'cm': confusion_matrix(labels_softmax5_all, preds_softmax5_all, labels=list(range(num_classes))),
        'precision': precision_score(labels_softmax5_all, preds_softmax5_all, average='macro', zero_division=0),
        'recall': recall_score(labels_softmax5_all, preds_softmax5_all, average='macro', zero_division=0),
        'f1': f1_score(labels_softmax5_all, preds_softmax5_all, average='macro', zero_division=0),
    }


def _print_phase_stats(phase, stats):
    """Print one phase's regression and classification metrics in the notebook's log format."""
    print(f"{phase.capitalize()} Loss: {stats['loss']:.4f}\n"
          f"MSE_Nino34: {stats['mse_nino34']:.4f}\n"
          f"MAE_Nino34: {stats['mae_nino34']:.4f}\n"
          f"R2_Nino34: {stats['r2_nino34']:.4f}\n Acc_Softmax5: {stats['acc_softmax5']:.4f}\n")
    print(f"{phase.capitalize()} Softmax5 Classification Metrics:")
    print("Confusion Matrix:")
    print(stats['cm'])
    print(f"Precision_Softmax5: {stats['precision'] :.4f} Recall_Softmax5: {stats['recall'] :.4f} F1-score_Softmax5: {stats['f1'] :.4f}")


def _write_best_metrics(path, epoch, stats):
    """Overwrite ``path`` with a text summary of the validation ``stats`` at 0-based ``epoch``."""
    with open(path, "w") as f:
        f.write(f"Epoch: {epoch + 1}\n")
        f.write(f"Val Loss: {stats['loss']:.4f}\n")
        f.write(f"Val MAE_Nino34: {stats['mae_nino34']:.4f}\n")
        f.write(f"Val MSE_Nino34: {stats['mse_nino34']:.4f}\n")
        f.write(f"Val R2_Nino34: {stats['r2_nino34']:.4f}\n")
        f.write(f"Val Acc_Softmax5: {stats['acc_softmax5']:.4f}\n")
        f.write(f"Precision: {stats['precision']:.4f}\n")
        f.write(f"Recall: {stats['recall']:.4f}\n")
        f.write(f"F1-score: {stats['f1']:.4f}\n")
        f.write("Confusion Matrix:\n")
        f.write(f"{stats['cm']}\n")


def _plot_training_curves(results, save_dir):
    """Redraw every train/val metric curve from ``results`` and overwrite the PNGs in ``save_dir``."""
    epochs = np.arange(1, len(results["train_loss"]) + 1)
    # (metric key, train label, val label, y-axis label, title, output file name)
    curve_specs = (
        ("loss", "Train Loss", "Val Loss", "Loss", "Loss Curve", "loss_curve.png"),
        ("acc_softmax5", "Train Acc Softmax5", "Val Acc Softmax5", "Accuracy",
         "Accuracy Curve (Softmax5 Classification)", "accuracy_curve_softmax5.png"),
        ("mae_nino34", "Train MAE Nino34", "Val MAE Nino34", "MAE",
         "MAE Curve (Nino34 Regression)", "mae_curve_nino34.png"),
        ("r2_nino34", "Train R2 Nino34", "Val R2 Nino34", "R2 Score",
         "R2 Score Curve (Nino34 Regression)", "r2_curve_nino34.png"),
    )
    for key, train_label, val_label, ylabel, title, filename in curve_specs:
        fig, ax = plt.subplots()
        ax.plot(epochs, results[f"train_{key}"], label=train_label)
        ax.plot(epochs, results[f"val_{key}"], label=val_label)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        fig.savefig(os.path.join(save_dir, filename))
        plt.close(fig)  # avoid accumulating open figures across epochs


def train_model(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epochs=25,
                save_dir="checkpoints", loss_weights=(0.8, 0.2), num_classes=5, checkpoint_every=10):
    """
    Train the multi-task (Niño3.4 regression + softmax5 classification) model.

    Every epoch runs a training pass and a validation pass, prints their metrics, and steps
    ``scheduler`` once. It also redraws the metric curves in ``save_dir``. Whenever the
    validation loss improves, the weights are saved to ``best_model.pth`` and the metrics to
    ``best_metrics.txt``. At the end the best weights are loaded back into ``model``.

    Args:
        model: The model to train (returns ``(regression, logits)`` like Model_class).
        train_dataloader: DataLoader for the training set.
        val_dataloader: DataLoader for the validation set.
        optimizer: Optimizer over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to train.
        save_dir: Directory for checkpoints, metric files and curve images (created if missing).
        loss_weights: ``(classification_weight, regression_weight)`` of the combined loss.
        num_classes: Number of classes; fixes the confusion-matrix size.
        checkpoint_every: Save ``epoch_<n>.pth`` every this many epochs (0/None disables).

    Returns:
        The model with the best-validation-loss weights loaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_loss = float('inf')

    metric_keys = ("loss", "mae_nino34", "mse_nino34", "r2_nino34", "acc_softmax5")
    results = {f"{phase}_{key}": [] for phase in ("train", "val") for key in metric_keys}

    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        print('-' * 30)
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print('-' * 30)

        for phase, dataloader in (('train', train_dataloader), ('val', val_dataloader)):
            stats = _run_phase(model, dataloader, phase, optimizer, device, loss_weights, num_classes)
            _print_phase_stats(phase, stats)
            for key in metric_keys:
                results[f"{phase}_{key}"].append(stats[key])

            if phase == 'val' and stats['loss'] < best_val_loss:
                print("Saving the best model based on Validation Loss.")
                best_val_loss = stats['loss']
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, os.path.join(save_dir, "best_model.pth"))
                print("Best model saved")
                _write_best_metrics(os.path.join(save_dir, "best_metrics.txt"), epoch, stats)

        # Periodic checkpoint, saved once per qualifying epoch (it used to run inside the
        # phase loop and therefore save the same weights twice).
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            torch.save(model.state_dict(), os.path.join(save_dir, f"epoch_{epoch + 1}.pth"))
            print(f"Model saved at epoch {epoch + 1}")

        scheduler.step()
        _plot_training_curves(results, save_dir)

    model.load_state_dict(best_model_wts)
    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))
    return model


In [None]:
# Configuration for this training run.
SEED = 1024
BATCH_SIZE = 30
NUM_EPOCHS = 25
COSINE_T_MAX = 30          # larger than NUM_EPOCHS, so the LR is only annealed part-way toward eta_min
SAVE_DIR = "checkpoints/train_1"
RESUME_CHECKPOINT = None   # e.g. "checkpoints/train_1/best_model.pth" to warm-start from saved weights

seed_torch(SEED)

model = Model_class()
if RESUME_CHECKPOINT is not None:
    model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location="cpu"))

# Re-seed NumPy's global RNG before building each split. If MyDataset shuffles with np.random,
# both phases then draw the same permutation and their index sets are complementary. Without
# this, the RNG advances between the two constructions and train/val samples overlap.
np.random.seed(SEED)
train_set = MyDataset(data_root=".", phase="train", normalize=False)
np.random.seed(SEED)
valid_set = MyDataset(data_root=".", phase="val", normalize=False)
assert set(train_set.indices.tolist()).isdisjoint(valid_set.indices.tolist()), "train/val splits overlap"

train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)
# Keep every validation sample: drop_last=True would silently exclude the final partial batch from the metrics.
val_dataloader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, drop_last=False)

# All model parameters are optimized
optimizer_ft = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(optimizer_ft, T_max=COSINE_T_MAX, eta_min=1e-6)

model = train_model(model, train_dataloader, val_dataloader, optimizer_ft, exp_lr_scheduler,
                    num_epochs=NUM_EPOCHS, save_dir=SAVE_DIR)