In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
from tqdm import tqdm
from torchvision.models.segmentation import deeplabv3_resnet50
from torch.optim.lr_scheduler import ReduceLROnPlateau
from dataloader import PASTIS_Dataset
from collate import pad_collate

Порядок комментариев к коду: китайский / русский / английский.

数据集加载 / Загрузка набора данных / Loading the dataset

In [None]:
# Dataset location and split configuration, gathered in one place so they are easy to tune.
path_to_dataset = '/content/drive/MyDrive/Colab Notebooks/data/PASTIS'
SEED = 42             # fixes both the subset selection and the train/validation split
SUBSET_SIZE = 2430    # number of samples drawn from the full dataset
TRAIN_FRACTION = 0.8  # share of the subset used for training
BATCH_SIZE = 4

# Getting and processing the dataset (target='semantic': for semantic segmentation tasks).
dataset = PASTIS_Dataset(path_to_dataset, norm=True, target='semantic')

# A dedicated, seeded generator makes the subset and the split reproducible across
# kernel restarts without touching the global RNG state.
split_rng = torch.Generator().manual_seed(SEED)
subset_indices = torch.randperm(len(dataset), generator=split_rng)[:SUBSET_SIZE].tolist()
subset_dataset = Subset(dataset, subset_indices)

# Splitting into training and validation sets.
train_size = int(TRAIN_FRACTION * len(subset_dataset))
valid_size = len(subset_dataset) - train_size
train_dataset, valid_dataset = random_split(
    subset_dataset, [train_size, valid_size], generator=split_rng
)

# Creating the DataLoaders; only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=pad_collate, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=pad_collate)

# Number of classes.
num_classes = 20

GPU / Определите, можно ли использовать CUDA / Check whether CUDA can be used

In [None]:
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
use_cuda = torch.cuda.is_available()
print("CUDA is available. GPU support enabled." if use_cuda else "CUDA is not available. Using CPU.")
device = torch.device("cuda" if use_cuda else "cpu")

早停 / Ранняя остановка / Early stopping

In [None]:
"""
patience: 训练过程中没有改进的次数。
patience: Количество раз, когда обучение не улучшается.
patience: The number of times the training does not improve.

min_delta: 被认为是改进的最小变化量。
min_delta: Минимальное изменение, которое считается улучшением.
min_delta: The minimum change that is considered an improvement.
"""
class EarlyStopping:
    """Track a loss value across calls and flag when it has stopped improving.

    Args:
        patience: Number of consecutive non-improving calls after which
            ``early_stop`` is set to True.
        min_delta: The loss must drop below the best value seen so far by more
            than this amount to count as an improvement.

    Attributes:
        counter: Non-improving calls since the last improvement.
        best_loss: Lowest loss accepted as an improvement so far.
        early_stop: True once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=5, min_delta=0):
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False
        self.min_delta = min_delta
        self.patience = patience

    def __call__(self, val_loss):
        """Record one new loss value and update the stopping state."""
        improvement = self.best_loss - val_loss
        if improvement > self.min_delta:
            # Real progress: remember the new best and restart the count.
            self.best_loss, self.counter = val_loss, 0
            return

        self.counter += 1
        print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

减少deeplab输入通道数 / Уменьшение количества входных каналов deeplab / Reducing the number of input channels for deeplab

In [None]:
"""
由于deepLabv3模型的初始通道数是3，而数据集的通道数是10个卫星通道*30个时间点=300，所以需要修改模型第一层的通道数。
Поскольку исходное количество каналов модели deepLabv3 равно 3, а количество каналов набора данных равно 10 каналам спутника * 30 временным точкам = 300, необходимо изменить количество каналов первого слоя модели.
Since the initial number of channels of the deepLabv3 model is 3, and the number of channels of the dataset is 10 satellite channels * 30 time points = 300, the number of channels of the first layer of the model needs to be modified.
"""
def reduce_channels(model, in_channels=300):
    """Replace ``model.backbone.conv1`` with a convolution taking ``in_channels`` inputs.

    The new layer copies the output channel count, kernel size, stride and
    padding of the layer it replaces and, like the original code, is created
    without a bias term. Its weights are freshly initialised; nothing is
    copied from the old layer.

    Args:
        model: Object exposing ``backbone.conv1`` as an ``nn.Conv2d``.
        in_channels: Number of input channels the new first layer accepts.

    Returns:
        The same ``model`` object, modified in place.
    """
    old_conv = model.backbone.conv1
    new_conv = nn.Conv2d(
        in_channels,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=False,
    )
    # Keep the replacement on the same device and dtype as the layer it replaces,
    # so a model that was already moved (e.g. to the GPU) stays consistent.
    new_conv = new_conv.to(device=old_conv.weight.device, dtype=old_conv.weight.dtype)
    model.backbone.conv1 = new_conv
    return model

模型训练 / Обучение модели / Model training

In [None]:
# Initializing the DeepLab model and optimizer.
# The first convolution is swapped before the model is moved, so a single
# .to(device) places every layer (including the new one) on the target device.
deeplab_model = deeplabv3_resnet50(pretrained=False, num_classes=num_classes)
deeplab_model = reduce_channels(deeplab_model, in_channels=300)
model = deeplab_model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
early_stopping = EarlyStopping(patience=3, min_delta=0.01)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1, verbose=True, min_lr=1e-6)

epochs = 30
for epoch in range(epochs):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    for batch_data in tqdm(train_loader, total=len(train_loader), desc=f'Epoch {epoch+1}/{epochs}', leave=False):
        ((inputs_dict, dates), targets) = batch_data
        # Concatenate the first 30 entries along dim 1 of the 'S2' tensor into the channel axis.
        # Assumes 'S2' is laid out as (batch, time, channel, H, W) with at least 30 time steps -- TODO confirm.
        inputs_combined = torch.cat([inputs_dict['S2'][:, i, :, :, :] for i in range(30)], dim=1).to(device)
        targets = targets.to(device).long()

        optimizer.zero_grad()
        outputs = model(inputs_combined)['out']
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)  # average training loss per batch

    # ---- Validation ----
    model.eval()  # setting the model to evaluation mode
    val_loss = 0.0
    correct_pixels = 0
    total_pixels = 0
    with torch.no_grad():  # gradients are not calculated at this stage
        for batch_data in valid_loader:
            ((inputs_dict, dates), targets) = batch_data
            inputs_combined = torch.cat([inputs_dict['S2'][:, i, :, :, :] for i in range(30)], dim=1).to(device)
            targets = targets.to(device).long()

            outputs = model(inputs_combined)['out']
            loss = criterion(outputs, targets)

            val_loss += loss.item()  # accumulating validation loss
            # Pixel accuracy: the class with the highest score is the prediction.
            _, predicted = torch.max(outputs, 1)
            correct_pixels += (predicted == targets).sum().item()
            total_pixels += targets.nelement()

    val_loss /= len(valid_loader)  # average validation loss per batch
    overall_accuracy = correct_pixels / total_pixels

    print(f"Epoch {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}, Overall Accuracy: {overall_accuracy:.4f}")

    # The learning-rate scheduler reacts to the averaged validation loss.
    scheduler.step(val_loss)

    # Early stopping must watch the same averaged validation loss; `loss` here
    # would only be the loss tensor of the last validation batch.
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break


计算mIoU / Рассчитать mIoU / Calculate mIoU

In [None]:
def calculate_iou(predicted, target, num_classes):
    """Return the mean Intersection-over-Union across classes.

    For every class id in ``range(num_classes)`` the IoU is
    ``TP / (TP + FP + FN)``, i.e. the ratio of intersection to union of the
    predicted and target masks. Classes that appear in neither ``predicted``
    nor ``target`` (empty union) are left out of the average.

    Args:
        predicted: Tensor of predicted class ids.
        target: Tensor of ground-truth class ids, comparable element-wise
            with ``predicted``.
        num_classes: Number of class ids to evaluate, starting at 0.

    Returns:
        The mean IoU as a float, or NaN when no class had a non-empty union.
    """
    per_class_iou = []
    for class_id in range(num_classes):
        pred_mask = predicted == class_id
        true_mask = target == class_id
        overlap = (pred_mask & true_mask).sum().item()
        combined = pred_mask.sum().item() + true_mask.sum().item() - overlap
        # An empty union means the class is absent everywhere: skip it
        # instead of dividing by zero.
        if combined != 0:
            per_class_iou.append(overlap / combined)

    if not per_class_iou:
        return float('nan')
    return sum(per_class_iou) / len(per_class_iou)

# 模型验证和计算Mean IoU /  Проверка модели и вычисление среднего IoU / Model validation and calculation of Mean IoU
def validate_and_calculate_iou(model, loader, device, num_classes):
    """Evaluate ``model`` on ``loader`` and report mean IoU and pixel accuracy.

    The reported mean IoU is the average of the per-batch values returned by
    ``calculate_iou``; it is not a dataset-level IoU built from one global
    confusion matrix. Batches whose IoU is NaN are excluded from that average
    so one undefined batch cannot turn the whole result into NaN.

    Args:
        model: Network whose forward pass returns a mapping with an ``'out'`` entry.
        loader: Iterable yielding ``((inputs_dict, dates), targets)`` batches,
            where ``inputs_dict['S2']`` holds the input tensor.
        device: Device the inputs and targets are moved to.
        num_classes: Number of classes passed on to ``calculate_iou``.

    Returns:
        Tuple ``(mean_iou, overall_accuracy)``. Either value is NaN when
        there was nothing to average (for example, an empty loader).
    """
    model.eval()
    iou_sum = 0.0
    iou_batches = 0
    correct_pixels = 0
    total_pixels = 0
    with torch.no_grad():
        for ((inputs_dict, dates), targets) in loader:
            # Concatenate the first 30 entries along dim 1 of 'S2' into the channel axis.
            # Assumes 'S2' is (batch, time, channel, H, W) with at least 30 time steps -- TODO confirm.
            inputs_combined = torch.cat([inputs_dict['S2'][:, i, :, :, :] for i in range(30)], dim=1).to(device)
            targets = targets.to(device).long()

            outputs = model(inputs_combined)['out']
            # One argmax serves both the IoU and the accuracy computation.
            _, predicted = torch.max(outputs, 1)

            batch_iou = calculate_iou(predicted, targets, num_classes)
            if not np.isnan(batch_iou):
                iou_sum += batch_iou
                iou_batches += 1

            correct_pixels += (predicted == targets).sum().item()
            total_pixels += targets.nelement()

    # Guard both averages so an empty loader reports NaN instead of raising ZeroDivisionError.
    mean_iou = iou_sum / iou_batches if iou_batches else float('nan')
    overall_accuracy = correct_pixels / total_pixels if total_pixels else float('nan')
    print(f"Mean IoU on validation set: {mean_iou}, Overall Accuracy: {overall_accuracy:.4f}")
    return mean_iou, overall_accuracy

验证 / Проверить модель / Validate model

In [None]:
# Calling the validation function on the held-out validation loader.
validate_and_calculate_iou(model, valid_loader, device, num_classes)

# Count only parameters that receive gradients, so the figure matches the
# "trainable" label even if some layers are frozen.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")