In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import torch
import copy
import tqdm
import random
import numpy as np
import xarray as xr
from collections import defaultdict
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.metrics import confusion_matrix, recall_score, f1_score, precision_score, r2_score
import matplotlib.pyplot as plt
# seaborn: for plotting confusion-matrix heatmaps
import seaborn as sns

In [None]:
def seed_torch(seed=1029):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), sets ``PYTHONHASHSEED`` and switches cuDNN to its
    deterministic, non-autotuned mode so that runs are reproducible.

    Args:
        seed (int): Seed value applied to every generator. Defaults to 1029.
    """
    # PYTHONHASHSEED only influences Python processes started after this
    # point; the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Seed each library's generator with the same value.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN autotuning may pick different kernels between runs; disable it and
    # request deterministic algorithms instead.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


In [None]:
class AlexNet(nn.Module):
    """
    Modified AlexNet model for multi-task learning.

    This model outputs:
    - Regression output of size ``num_pr`` (default 23) for label_pr.
    - Classification logits of size ``num_month_classes`` (default 12) for label_month.

    Every constructor argument is optional and defaults to the original
    configuration (3 x 24 x 36 input, 4096-wide shared layers). ``AlexNet()``
    therefore builds exactly the same layers and parameter names as before,
    and previously saved state dicts still load.

    Args:
        in_channels (int): Number of input channels. Defaults to 3.
        input_size (tuple[int, int]): (height, width) of the input images, used
            only to infer the flattened feature size. Defaults to (24, 36).
        num_pr (int): Output size of the regression head. Defaults to 23.
        num_month_classes (int): Number of classes of the classification head.
            Defaults to 12.
        hidden_dim (int): Width of the shared fully connected layers.
            Defaults to 4096.
    """
    def __init__(self, in_channels=3, input_size=(24, 36), num_pr=23,
                 num_month_classes=12, hidden_dim=4096):
        super(AlexNet, self).__init__()
        # ----------------------------
        # Feature extraction layers
        # ----------------------------
        self.features = nn.Sequential(
            # Convolutional layer 1
            nn.Conv2d(in_channels, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # Convolutional layer 2
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # Convolutional layer 3
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # Convolutional layer 4
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # Convolutional layer 5
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # Small final pooling (k=2, s=1) so the feature map does not
            # collapse to zero size for small inputs such as 24 x 36.
            nn.MaxPool2d(kernel_size=2, stride=1),
        )

        # ----------------------------
        # Infer the number of features after the feature extractor by running
        # a dummy input through it, so the first Linear layer always matches.
        # ----------------------------
        with torch.no_grad():
            dummy_input = torch.zeros(1, in_channels, *input_size)
            n_features = self.features(dummy_input).flatten(1).shape[1]
            # For the default 3 x 24 x 36 input this is 512 (256 * 1 * 2).

        # ----------------------------
        # Shared fully connected layers
        # ----------------------------
        self.fc_shared = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(n_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
        )

        # ----------------------------
        # Output layers for each task
        # ----------------------------
        # Regression head for label_pr
        self.fc_pr = nn.Linear(hidden_dim, num_pr)

        # Classification head (raw logits) for label_month
        self.fc_month = nn.Linear(hidden_dim, num_month_classes)

    def forward(self, x):
        """Run the network.

        Args:
            x (Tensor): Input batch of shape (batch, in_channels, height, width).

        Returns:
            tuple[Tensor, Tensor]: ``(output_pr, output_month)`` with shapes
            (batch, num_pr) and (batch, num_month_classes).
        """
        # Pass input through feature extractor
        x = self.features(x)
        # Flatten all dimensions except the batch dimension
        x = torch.flatten(x, 1)

        # Pass through shared fully connected layers
        x = self.fc_shared(x)

        # Pass through task-specific output layers
        output_pr = self.fc_pr(x)        # Regression output for label_pr
        output_month = self.fc_month(x)  # Classification logits for label_month

        return output_pr, output_month


In [None]:
# Sanity check: push a random 3 x 24 x 36 input through every stage of the
# network and print the tensor shape after each layer.
net = AlexNet()
X = torch.randn(1, 3, 24, 36)

# Feature-extraction part
for layer in net.features:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)

# Flatten the features into a 1D tensor
X = torch.flatten(X, 1)  # Flatten all dimensions except batch dimension

# Shared fully connected layers
for layer in net.fc_shared:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)

# Regression output for label_pr
X_pr = net.fc_pr(X)
print(f"Regression output for label_pr形状:\t{X_pr.shape}")

# Classification output for label_month
X_month = net.fc_month(X)
print(f"Classification output for label_month形状:\t{X_month.shape}")


In [None]:
class MyDataset(Dataset):
    """
    Custom Dataset for loading images and labels from NetCDF files for multi-task learning.

    Each sample consists of:
    - image: sst data at a given time index, shape (3, 24, 36)
    - label_pr: 'pr' variable at the same time index, shape (23,), regression target
    - label_month: 'month_label' variable at the same time index, integer from 0 to 11 (12 classes)
    """

    def __init__(self,
                 data_root='.',      # Current working directory
                 phase='train',      # 'train' or 'val' to specify dataset phase
                 normalize=False,    # Whether to normalize the input data
                 transform=None,
                 input_filename='',
                 label_filename=None,
                 mean=0.0,
                 std=1.0):
        """
        Initializes the dataset by opening the NetCDF files with xarray.

        Args:
            data_root (str): Root directory containing the data folders 'train' and 'val'.
            phase (str): Indicates whether to load the 'train' or 'val' dataset.
            normalize (bool): If True, images are standardized as (image - mean) / std.
            transform (callable, optional): Applied to each image (a float32 NumPy
                array) after normalization. It may return a NumPy array or a tensor.
            input_filename (str): File name of the image dataset inside
                ``data_root/phase``. The default '' reproduces the original path,
                which is the phase directory itself.
            label_filename (str, optional): File name of the label dataset inside
                ``data_root/phase``. Defaults to ``input_filename``, so the same file
                serves both images and labels.
            mean (float or array-like): Mean used when ``normalize`` is True.
                Defaults to 0.0. An array must broadcast against (lev, lat, lon).
            std (float or array-like): Standard deviation used when ``normalize``
                is True. Defaults to 1.0.

        Raises:
            ValueError: If ``phase`` is not 'train' or 'val'.
        """

        # ----------------------------
        # Set dataset paths based on phase
        # ----------------------------
        if phase not in ('train', 'val'):
            raise ValueError("phase must be 'train' or 'val'")

        if label_filename is None:
            label_filename = input_filename

        # NOTE(review): with the default empty file names these paths point at the
        # phase directory (e.g. './train/'); pass the real NetCDF file names.
        input_file = os.path.join(data_root, phase, input_filename)
        label_file = os.path.join(data_root, phase, label_filename)

        # ----------------------------
        # Load image data
        # ----------------------------
        self.image_ds = xr.open_dataset(input_file)
        # 'sst' is expected to have dimensions (time, lev, lat, lon)
        self.images = self.image_ds['sst']

        # ----------------------------
        # Load label data
        # ----------------------------
        self.label_ds = xr.open_dataset(label_file)
        # 'pr' is expected as (time, lev, lat, lon) and 'month_label' as (time, lat, lon)
        self.labels_pr = self.label_ds['pr']
        self.labels_month = self.label_ds['month_label']

        # The number of samples is the length of the time dimension
        self.length = self.images.sizes['time']

        self.transform = transform

        # Normalization settings. The defaults (0, 1) leave the data unchanged;
        # pass dataset statistics to actually standardize, e.g.
        # self.images.mean(dim=('time', 'lat', 'lon')).values reshaped to (lev, 1, 1).
        self.normalize = normalize
        self.mean = mean
        self.std = std

    def __len__(self):
        """Returns the total number of samples."""
        return self.length

    def __getitem__(self, idx):
        """
        Retrieves the image and labels at the specified index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            image (Tensor): The input image tensor, expected shape (3, 24, 36), float32
                unless ``transform`` returns another dtype.
            labels (tuple): A tuple containing:
                - label_pr (Tensor): float32 tensor, expected shape (23,), the 'pr' variable.
                - label_month (Tensor): Scalar long tensor, the month label shifted by -1.
        """

        # ----------------------------
        # Load and process the image
        # ----------------------------
        # Select the sample at this time index: expected (lev, lat, lon) = (3, 24, 36)
        image = self.images.isel(time=idx).values.astype(np.float32)

        if self.normalize:
            # Cast back to float32 in case mean/std are float64 arrays.
            image = ((image - self.mean) / self.std).astype(np.float32, copy=False)

        if self.transform:
            image = self.transform(image)

        # The transform may already return a tensor (e.g. torchvision's ToTensor);
        # only convert when we still hold a NumPy array.
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image)

        # ----------------------------
        # Load and process the 'pr' label
        # ----------------------------
        # Expected initial shape (lev, lat, lon) = (23, 1, 1); squeezing the
        # singleton lat/lon dimensions gives (23,).
        label_pr = self.labels_pr.isel(time=idx).values.squeeze().astype(np.float32)
        label_pr = torch.from_numpy(label_pr)

        # ----------------------------
        # Load and process the 'month_label'
        # ----------------------------
        # Expected initial shape (lat, lon) = (1, 1); squeeze to a scalar.
        label_month = self.labels_month.isel(time=idx).values.squeeze()

        # NOTE(review): assumes month labels are stored as 1-12; shift to 0-11 for
        # CrossEntropyLoss. If the file already stores 0-11 this yields -1.
        label_month = torch.tensor(int(label_month) - 1, dtype=torch.long)

        # Return the image and a tuple of labels for multi-task learning
        return image, (label_pr, label_month)


In [None]:
def _save_curve(epochs, train_values, val_values, train_label, val_label, ylabel, title, path):
    """Plot one train-vs-validation metric over epochs and save the figure to ``path``."""
    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, label=train_label)
    ax.plot(epochs, val_values, label=val_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    fig.savefig(path)
    # This runs every epoch; close the figure so open figures do not accumulate.
    plt.close(fig)


def train_model(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epochs=25,
                save_dir="checkpoints", loss_weights=(0.8, 0.2)):
    """
    Train the multi-task model (regression on label_pr + classification on label_month).

    Each epoch runs a training phase and a validation phase. After validation,
    the weights are saved to ``save_dir/best_model.pth`` when the validation loss
    improves. A checkpoint is also written every 10 epochs. Loss, accuracy, MAE
    and R2 curves are re-drawn into ``save_dir`` after every epoch.

    Args:
        model: The model to train (AlexNet); must return ``(outputs_pr, outputs_month)``.
        train_dataloader: DataLoader for the training set, yielding
            ``(inputs, (labels_pr, labels_month))``.
        val_dataloader: DataLoader for the validation set, same format.
        optimizer: Optimizer updating the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs (int): Number of training epochs.
        save_dir (str): Directory for checkpoints, metrics and plots.
        loss_weights (tuple[float, float]): Weights ``(w_pr, w_month)`` of the
            combined loss ``w_pr * MSE + w_month * CrossEntropy``. Defaults to (0.8, 0.2).

    Returns:
        The model with the weights of the epoch with the lowest validation loss loaded.

    Raises:
        ValueError: If a dataloader yields no batches in a phase.
    """
    # Select the device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # The loss modules are stateless, so build them once instead of once per batch.
    criterion_pr = nn.MSELoss()             # regression loss
    criterion_month = nn.CrossEntropyLoss()  # classification loss
    weight_pr, weight_month = loss_weights

    # Track the best weights and the best validation loss
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_loss = float('inf')

    # Per-epoch statistics for both phases
    results = {
        "train_loss": [],
        "train_mae_pr": [],
        "train_mse_pr": [],
        "train_r2_pr": [],
        "train_acc_month": [],
        "val_loss": [],
        "val_mae_pr": [],
        "val_mse_pr": [],
        "val_r2_pr": [],
        "val_acc_month": [],
    }

    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(num_epochs):
        print('-' * 30)
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print('-' * 30)

        # Each epoch has a training phase followed by a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                dataloader = train_dataloader
            else:
                model.eval()
                dataloader = val_dataloader

            # Running sums for this phase (each divided by total_samples at the end)
            metrics = defaultdict(float)
            total_samples = 0
            # Predictions and targets, kept for R2 and the classification metrics
            preds_pr_list = []
            labels_pr_list = []
            preds_month_list = []
            labels_month_list = []

            bar = tqdm.tqdm(dataloader)
            for inputs, (labels_pr, labels_month) in bar:
                inputs = inputs.to(device)
                labels_pr = labels_pr.to(device)
                labels_month = labels_month.to(device)

                optimizer.zero_grad()

                # Gradients are tracked only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs_pr, outputs_month = model(inputs)

                    loss_pr = criterion_pr(outputs_pr, labels_pr)
                    loss_month = criterion_month(outputs_month, labels_month)
                    loss = weight_pr * loss_pr + weight_month * loss_month

                    # Detached regression predictions; class index = argmax over dim 1
                    # (dim 0 is the batch dimension).
                    preds_pr = outputs_pr.detach()
                    preds_month = torch.argmax(outputs_month, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate batch-size-weighted sums so uneven batches average correctly
                batch_size = inputs.size(0)
                metrics['loss'] += loss.item() * batch_size

                mae_pr = F.l1_loss(preds_pr, labels_pr, reduction='mean')
                mse_pr = F.mse_loss(preds_pr, labels_pr, reduction='mean')
                metrics['mae_pr'] += mae_pr.item() * batch_size
                metrics['mse_pr'] += mse_pr.item() * batch_size

                metrics['corrects_month'] += torch.sum(preds_month == labels_month).item()

                total_samples += batch_size

                preds_pr_list.append(preds_pr.cpu().numpy())
                labels_pr_list.append(labels_pr.cpu().numpy())
                preds_month_list.append(preds_month.cpu().numpy())
                labels_month_list.append(labels_month.cpu().numpy())

                bar.set_description(f"{phase.capitalize()} Loss: {metrics['loss'] / total_samples :.4f}")
                bar.set_postfix({
                    "MSE_PR": f"{metrics['mse_pr']/total_samples:.4f}",
                    "Acc_Month": f"{metrics['corrects_month']/total_samples:.4f}"
                })

            if total_samples == 0:
                raise ValueError(
                    f"The {phase} dataloader yielded no batches; check the dataset size "
                    "against batch_size / drop_last."
                )

            # Epoch averages. R2 closer to 1 means a better regression fit.
            epoch_loss = metrics['loss'] / total_samples
            epoch_mae_pr = metrics['mae_pr'] / total_samples
            epoch_mse_pr = metrics['mse_pr'] / total_samples
            epoch_r2_pr = r2_score(np.concatenate(labels_pr_list, axis=0),
                                   np.concatenate(preds_pr_list, axis=0))
            epoch_acc_month = metrics['corrects_month'] / total_samples

            print(f"{phase.capitalize()} Loss: {epoch_loss:.4f}\n"
                  f"MSE_PR: {epoch_mse_pr:.4f}\n"
                  f"MAE_PR: {epoch_mae_pr:.4f}\n"
                  f"R2_PR: {epoch_r2_pr:.4f}\n Acc_Month: {epoch_acc_month:.4f}\n")

            # Additional classification metrics (zero_division=0 keeps undefined
            # per-class scores at 0 without warnings, consistently for all three).
            labels_month_array = np.concatenate(labels_month_list, axis=0)
            preds_month_array = np.concatenate(preds_month_list, axis=0)
            cm_month = confusion_matrix(labels_month_array, preds_month_array)
            precision_month = precision_score(labels_month_array, preds_month_array, average='macro', zero_division=0)
            recall_month = recall_score(labels_month_array, preds_month_array, average='macro', zero_division=0)
            f1_month = f1_score(labels_month_array, preds_month_array, average='macro', zero_division=0)

            print(f"{phase.capitalize()} Month Classification Metrics:")
            print("Confusion Matrix:")
            print(cm_month)
            print(f"Precision_month: {precision_month :.4f} Recall_month: {recall_month :.4f} F1-score_month: {f1_month :.4f}")

            # Record the epoch statistics
            if phase == 'train':
                results["train_loss"].append(epoch_loss)
                results["train_mae_pr"].append(epoch_mae_pr)
                results["train_mse_pr"].append(epoch_mse_pr)
                results["train_r2_pr"].append(epoch_r2_pr)
                results["train_acc_month"].append(epoch_acc_month)
            else:
                results["val_loss"].append(epoch_loss)
                results["val_mae_pr"].append(epoch_mae_pr)
                results["val_mse_pr"].append(epoch_mse_pr)
                results["val_r2_pr"].append(epoch_r2_pr)
                results["val_acc_month"].append(epoch_acc_month)

                # Keep the weights with the lowest validation loss
                if epoch_loss < best_val_loss:
                    print("Saving the best model based on validation loss.")
                    best_val_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(best_model_wts, os.path.join(save_dir, "best_model.pth"))
                    print("Best model saved")

                    # Write the metrics of the best epoch to a text file
                    with open(os.path.join(save_dir, "best_metrics.txt"), "w") as f:
                        f.write(f"Epoch: {epoch + 1}\n")
                        f.write(f"Val Loss: {epoch_loss:.4f}\n")
                        f.write(f"Val MAE_PR: {epoch_mae_pr:.4f}\n")
                        f.write(f"Val MSE_PR: {epoch_mse_pr:.4f}\n")
                        f.write(f"Val R2_PR: {epoch_r2_pr:.4f}\n")
                        f.write(f"Val Acc_Month: {epoch_acc_month:.4f}\n")
                        f.write(f"Precision: {precision_month:.4f}\n")
                        f.write(f"Recall: {recall_month:.4f}\n")
                        f.write(f"F1-score: {f1_month:.4f}\n")
                        f.write("Confusion Matrix:\n")
                        f.write(f"{cm_month}\n")

        # Periodic checkpoint every 10 epochs, written once per epoch after
        # validation (it used to sit inside the phase loop and ran twice).
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), os.path.join(save_dir, f"epoch_{epoch + 1}.pth"))
            print(f"Model saved at epoch {epoch + 1}")

        # One scheduler step per epoch
        scheduler.step()

        # Re-draw the training curves with all epochs so far
        epochs = np.arange(1, epoch + 2)
        _save_curve(epochs, results["train_loss"], results["val_loss"],
                    "Train Loss", "Val Loss", "Loss", "Loss Curve",
                    os.path.join(save_dir, "loss_curve.png"))
        _save_curve(epochs, results["train_acc_month"], results["val_acc_month"],
                    "Train Acc Month", "Val Acc Month", "Accuracy",
                    "Accuracy Curve (Month Classification)",
                    os.path.join(save_dir, "accuracy_curve_month.png"))
        _save_curve(epochs, results["train_mae_pr"], results["val_mae_pr"],
                    "Train MAE PR", "Val MAE PR", "MAE", "MAE Curve (PR Regression)",
                    os.path.join(save_dir, "mae_curve_pr.png"))
        _save_curve(epochs, results["train_r2_pr"], results["val_r2_pr"],
                    "Train R2 PR", "Val R2 PR", "R2 Score", "R2 Score Curve (PR Regression)",
                    os.path.join(save_dir, "r2_curve_pr.png"))

    # Restore the best weights before returning
    model.load_state_dict(best_model_wts)
    print("Training complete. Best validation loss: {:.4f}".format(best_val_loss))
    return model


In [None]:
# NOTE: this training function is not called anywhere in this notebook;
# the final cell uses train_model instead.
def train2_model(model,
               train_dataloader,
               val_dataloader,
               optimizer,
               scheduler,
               num_epochs=25,
               save_dir="checkpoints"):
    """
    Trains and validates the given model using the provided dataloaders, optimizer, and scheduler.
    Implements multi-task learning with two outputs: regression for label_pr and classification for label_month.

    The combined loss is ``0.8 * MSE(pr) + 0.2 * CrossEntropy(month)``. The best
    model is selected by validation accuracy on label_month.

    Args:
        model (nn.Module): The CNN model to train; must return ``(outputs_pr, outputs_month)``.
        train_dataloader (DataLoader): DataLoader for the training dataset.
        val_dataloader (DataLoader): DataLoader for the validation dataset.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler, stepped once per epoch.
        num_epochs (int, optional): Number of training epochs. Defaults to 25.
        save_dir (str, optional): Directory to save model checkpoints and metrics. Defaults to "checkpoints".

    Returns:
        nn.Module: The model with the weights of the epoch with the best validation
        accuracy for label_month loaded.
    """

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Track the best model by validation accuracy on label_month
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_acc_month = 0

    # Per-epoch training and validation metrics
    results = {
        "train_loss_pr": [],
        "train_loss_month": [],
        "train_loss_total": [],
        "train_rmse_pr": [],
        "train_acc_month": [],
        "val_loss_pr": [],
        "val_loss_month": [],
        "val_loss_total": [],
        "val_rmse_pr": [],
        "val_acc_month": []
    }

    # Move the model to the appropriate device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Loss functions for each task
    criterion_pr = nn.MSELoss()              # regression task
    criterion_month = nn.CrossEntropyLoss()  # classification task

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 30)

        # ----------------------------
        # Training Phase
        # ----------------------------
        model.train()

        # Batch-size-weighted running sums (defaultdict starts every key at 0.0)
        train_metrics = defaultdict(float)
        train_samples = 0

        train_bar = tqdm.tqdm(train_dataloader, desc="Training", leave=False)
        for images, (labels_pr, labels_month) in train_bar:
            images = images.to(device)
            labels_pr = labels_pr.to(device)
            labels_month = labels_month.to(device)
            batch_size = images.size(0)

            optimizer.zero_grad()

            outputs_pr, outputs_month = model(images)
            loss_pr = criterion_pr(outputs_pr, labels_pr)
            loss_month = criterion_month(outputs_month, labels_month)
            loss = 0.8 * loss_pr + 0.2 * loss_month

            loss.backward()
            optimizer.step()

            with torch.no_grad():
                _, preds_month = torch.max(outputs_month, 1)
                corrects_month = torch.sum(preds_month == labels_month).item()

            train_metrics['loss_pr'] += loss_pr.item() * batch_size
            train_metrics['loss_month'] += loss_month.item() * batch_size
            train_metrics['loss_total'] += loss.item() * batch_size
            train_metrics['corrects_month'] += corrects_month
            train_samples += batch_size

            # RMSE is derived from the running MSE (see the epoch computation below)
            running_mse_pr = train_metrics['loss_pr'] / train_samples
            train_bar.set_postfix({
                "Loss_PR": f"{running_mse_pr:.4f}",
                "Loss_Month": f"{train_metrics['loss_month']/train_samples:.4f}",
                "RMSE_PR": f"{np.sqrt(running_mse_pr):.4f}",
                "Acc_Month": f"{train_metrics['corrects_month']/train_samples:.4f}"
            })

        epoch_train_loss_pr = train_metrics['loss_pr'] / train_samples
        epoch_train_loss_month = train_metrics['loss_month'] / train_samples
        epoch_train_loss_total = train_metrics['loss_total'] / train_samples
        # The epoch RMSE is sqrt of the epoch MSE. Averaging per-batch RMSE values
        # (as before) systematically under-estimates it because sqrt is concave.
        epoch_train_rmse_pr = float(np.sqrt(epoch_train_loss_pr))
        epoch_train_acc_month = train_metrics['corrects_month'] / train_samples

        results["train_loss_pr"].append(epoch_train_loss_pr)
        results["train_loss_month"].append(epoch_train_loss_month)
        results["train_loss_total"].append(epoch_train_loss_total)
        results["train_rmse_pr"].append(epoch_train_rmse_pr)
        results["train_acc_month"].append(epoch_train_acc_month)

        # ----------------------------
        # Validation Phase
        # ----------------------------
        model.eval()

        val_metrics = defaultdict(float)
        val_samples = 0

        # True and predicted labels for the confusion matrix and macro scores
        true_month = []
        pred_month = []

        val_bar = tqdm.tqdm(val_dataloader, desc="Validation", leave=False)
        for images, (labels_pr, labels_month) in val_bar:
            images = images.to(device)
            labels_pr = labels_pr.to(device)
            labels_month = labels_month.to(device)
            batch_size = images.size(0)

            with torch.no_grad():
                outputs_pr, outputs_month = model(images)
                loss_pr = criterion_pr(outputs_pr, labels_pr)
                loss_month = criterion_month(outputs_month, labels_month)
                loss = 0.8 * loss_pr + 0.2 * loss_month

                _, preds_month_batch = torch.max(outputs_month, 1)
                corrects_month = torch.sum(preds_month_batch == labels_month).item()

            val_metrics['loss_pr'] += loss_pr.item() * batch_size
            val_metrics['loss_month'] += loss_month.item() * batch_size
            val_metrics['loss_total'] += loss.item() * batch_size
            val_metrics['corrects_month'] += corrects_month
            val_samples += batch_size

            true_month.extend(labels_month.cpu().numpy())
            pred_month.extend(preds_month_batch.cpu().numpy())

            running_mse_pr = val_metrics['loss_pr'] / val_samples
            val_bar.set_postfix({
                "Loss_PR": f"{running_mse_pr:.4f}",
                "Loss_Month": f"{val_metrics['loss_month']/val_samples:.4f}",
                "RMSE_PR": f"{np.sqrt(running_mse_pr):.4f}",
                "Acc_Month": f"{val_metrics['corrects_month']/val_samples:.4f}"
            })

        epoch_val_loss_pr = val_metrics['loss_pr'] / val_samples
        epoch_val_loss_month = val_metrics['loss_month'] / val_samples
        epoch_val_loss_total = val_metrics['loss_total'] / val_samples
        epoch_val_rmse_pr = float(np.sqrt(epoch_val_loss_pr))
        epoch_val_acc_month = val_metrics['corrects_month'] / val_samples

        results["val_loss_pr"].append(epoch_val_loss_pr)
        results["val_loss_month"].append(epoch_val_loss_month)
        results["val_loss_total"].append(epoch_val_loss_total)
        results["val_rmse_pr"].append(epoch_val_rmse_pr)
        results["val_acc_month"].append(epoch_val_acc_month)

        # Classification metrics for label_month; zero_division=0 avoids warnings
        # for classes that are never predicted / never present.
        cm_month = confusion_matrix(true_month, pred_month)
        precision_month = precision_score(true_month, pred_month, average="macro", zero_division=0)
        recall_month = recall_score(true_month, pred_month, average="macro", zero_division=0)
        f1_month = f1_score(true_month, pred_month, average="macro", zero_division=0)

        print(f"Validation Loss PR: {epoch_val_loss_pr:.4f}, Loss Month: {epoch_val_loss_month:.4f}, Total Loss: {epoch_val_loss_total:.4f}")
        print(f"Validation RMSE PR: {epoch_val_rmse_pr:.4f}")
        print(f"Validation Accuracy Month: {epoch_val_acc_month:.4f}")
        print("Confusion Matrix for label_month:")
        print(cm_month)
        print(f"Precision Month: {precision_month:.4f}, Recall Month: {recall_month:.4f}, F1 Score Month: {f1_month:.4f}")

        # One scheduler step per epoch
        scheduler.step()

        # ----------------------------
        # Save the Best Model (criterion: validation accuracy for label_month)
        # ----------------------------
        if epoch_val_acc_month > best_val_acc_month:
            print("Saving the best model based on Validation Accuracy for label_month.")
            best_val_acc_month = epoch_val_acc_month
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, os.path.join(save_dir, "best_model.pth"))

            with open(os.path.join(save_dir, "metrics_best.txt"), "w") as f:
                f.write(f"Epoch: {epoch+1}\n")
                f.write(f"Validation Loss PR: {epoch_val_loss_pr:.4f}\n")
                f.write(f"Validation Loss Month: {epoch_val_loss_month:.4f}\n")
                f.write(f"Validation RMSE PR: {epoch_val_rmse_pr:.4f}\n")
                f.write(f"Validation Accuracy Month: {epoch_val_acc_month:.4f}\n")
                f.write("Confusion Matrix for label_month:\n")
                f.write(np.array2string(cm_month))
                f.write(f"\nPrecision Month: {precision_month:.4f}\n")
                f.write(f"Recall Month: {recall_month:.4f}\n")
                f.write(f"F1 Score Month: {f1_month:.4f}\n")

        # ----------------------------
        # Save Model Checkpoints Every 10 Epochs
        # ----------------------------
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved checkpoint: {checkpoint_path}")

        # ----------------------------
        # Plot and Save Training Progress
        # ----------------------------
        epochs_axis = range(1, epoch + 2)

        plt.figure()
        plt.plot(epochs_axis, results["train_acc_month"], label="Train Accuracy Month")
        plt.plot(epochs_axis, results["val_acc_month"], label="Validation Accuracy Month")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.title("Training and Validation Accuracy for label_month")
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, "accuracy.png"))
        plt.close()

        plt.figure()
        plt.plot(epochs_axis, results["train_loss_pr"], label="Train Loss PR")
        plt.plot(epochs_axis, results["val_loss_pr"], label="Validation Loss PR")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss for label_pr")
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, "loss_pr.png"))
        plt.close()

        plt.figure()
        plt.plot(epochs_axis, results["train_loss_month"], label="Train Loss Month")
        plt.plot(epochs_axis, results["val_loss_month"], label="Validation Loss Month")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss for label_month")
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_dir, "loss_month.png"))
        plt.close()

    # Load best model weights
    model.load_state_dict(best_model_wts)
    print(f"\nTraining complete. Best Validation Accuracy Month: {best_val_acc_month:.4f}")

    return model


In [None]:
def initialize_weights(model):
    """Re-initialize every Conv2d and Linear layer of ``model`` in place.

    Weights get Xavier/Glorot-uniform initialization and biases (when present)
    are set to zero. All other module types are left untouched.

    Args:
        model (nn.Module): The network whose layers are re-initialized.
    """
    initializable = (nn.Conv2d, nn.Linear)
    for module in model.modules():
        if not isinstance(module, initializable):
            continue
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


In [None]:
# Reproducibility: seed every RNG before building the model and the loaders.
seed_torch(1024)

NUM_EPOCHS = 30
BATCH_SIZE = 30

model = AlexNet()

# Fine-tune from the previous run's best weights. map_location='cpu' lets a
# checkpoint that was saved from CUDA tensors load on any machine;
# train_model moves the model to the available device afterwards.
model.load_state_dict(torch.load('checkpoints/train_1/best_model.pth', map_location='cpu'))

train_set = MyDataset(data_root=".", phase="train", normalize=False)
valid_set = MyDataset(data_root=".", phase="val", normalize=False)
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)
# Keep the final partial batch when validating so every validation sample is
# scored; train_model weights its metrics by batch size, so uneven batches are fine.
val_dataloader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, drop_last=False)

# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# T_max equals the number of epochs, so the cosine schedule decays the learning
# rate from 1e-3 toward eta_min over the whole run (one scheduler step per epoch).
exp_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_ft, T_max=NUM_EPOCHS, eta_min=1e-6)

model = train_model(model, train_dataloader, val_dataloader, optimizer_ft, exp_lr_scheduler, num_epochs=NUM_EPOCHS, save_dir="checkpoints/train_2")