In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from dataloader import PASTIS_Dataset
from collate import pad_collate
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchsummary import summary

Порядок комментариев к коду: китайский / русский / английский.

双卷积 / двойная свертка / double convolution

In [33]:
class DoubleConv(nn.Module):
    # (convolution => [BN] => ReLU) * 2
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

注意力 / механизм внимания / attention mechanism

In [34]:
class AttentionBlock(nn.Module):
    # 注意力模块-FPNUNet / Блок внимания для FPNUNet(Пока не используется) / Attention Block for FPNUNet

    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)


    def forward(self, g, x):
        g1 = self.W_g(g)
         # 在相加前上采样g以匹配x的尺寸 / Перед сложением g повышается до размеров x / Upsample g to match the size of x before adding
        g1 = F.interpolate(g1, size=x.shape[2:], mode='bilinear', align_corners=True)

        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        
        return x * psi

金字塔U-NET/ Модель FPN-UNet / FPN-UNet model

In [35]:
class FPNUNet(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.5):
        super().__init__()
        self.initial_conv = nn.Conv2d(300, 64, kernel_size=1)
        # 编码器 / Кодировщик / Encoder
        """
        使用了4层编码器 / Используется 4 слоя кодировщика / 4 layers of encoder
        正则化设定0.5 / Настройка регуляризации 0.5 / Regularization setting 0.5
        """
        self.encoder1 = DoubleConv(64, 64)
        self.droupout1 = nn.Dropout(dropout_rate)
        self.encoder2 = DoubleConv(64, 128)
        self.droupout2 = nn.Dropout(dropout_rate)
        self.encoder3 = DoubleConv(128, 256)
        self.droupout3 = nn.Dropout(dropout_rate)
        self.encoder4 = DoubleConv(256, 512)
        self.droupout4 = nn.Dropout(dropout_rate)

        # 解码器 / Декодер / Decoder
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(128, 64)
        self.upconv1 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.decoder1 = DoubleConv(128, 64)
        # FPN桥 / Мост FPN / FPN Bridge
        self.fpn_bridge = DoubleConv(512, 512)
        self.fpn_bridge_dropout = nn.Dropout(dropout_rate)
        # 最终分类器 / Финальный классификатор / Final classifier
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        # 注意力层（可选） / Слой внимания (опционально) / Attention Layer (Optional)
        self.attention1 = AttentionBlock(F_g=512, F_l=512, F_int=256)
        self.attention2 = AttentionBlock(F_g=256, F_l=256, F_int=128)
        self.attention3 = AttentionBlock(F_g=128, F_l=128, F_int=64)
        self.attention4 = AttentionBlock(F_g=64, F_l=64, F_int=32)

    def forward(self, x):
        x = self.initial_conv(x)
        # 编码 / Путь кодировщика / Encoder path
        x1 = self.droupout1(self.encoder1(x))
        x2 = self.droupout2(self.encoder2(x1))
        x3 = self.droupout3(self.encoder3(x2))
        x4 = self.droupout4(self.encoder4(x3))
       
        # FPN桥 / Мост FPN / FPN Bridge
        """
        FPN桥连接了编码器和解码器 / Мост FPN соединяет кодировщик и декодер / FPN Bridge connects the encoder and decoder
        FPN桥的作用额外增加了一层正则化 / Роль моста FPN дополнительно увеличивает уровень регуляризации / The role of FPN bridge additionally increases the level of regularization
        """
        x_bridge = self.fpn_bridge_dropout(self.fpn_bridge(x4))
        
        # 解码 / Путь декодера / Decoder path
        x = self.upconv4(x_bridge)
        # 使用双线性插值上采样 / Восстановление с использованием билинейной интерполяции / Upsampling using bilinear interpolation
        x = F.interpolate(x, size=x3.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat((x, x3), dim=1)
        # x = self.attention1(g=x, x=x4)
        x = self.decoder4(x)
        
        x = self.upconv3(x)
        x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat((x, x2), dim=1)
        # x = self.attention2(g=x, x=x3)
        x = self.decoder3(x)

        x = self.upconv2(x)
        x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat((x, x1), dim=1)
        x = self.decoder2(x)

        x = self.upconv1(x)
        x1_upsampled = F.interpolate(x1, size=x.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat((x, x1_upsampled), dim=1) 
        x = self.decoder1(x)

        x = F.interpolate(x, size=(128, 128), mode='bilinear', align_corners=False)
        # Final classification
        x = self.final_conv(x)
        return x

GPU / Определите, можно ли использовать Cuda / To see if Cuda can be used

In [36]:
if torch.cuda.is_available():
    print("CUDA is available. GPU support enabled.")
    device = torch.device("cuda")
else:
    print("CUDA is not available. Using CPU.")
    device = torch.device("cpu")

CUDA is available. GPU support enabled.


数据集加载 / Загрузка набора данных / loading dataset

In [37]:
# 获取并处理数据集 / Получение и обработка набора данных / Getting and processing the dataset
path_to_dataset = '/content/drive/MyDrive/Colab Notebooks/data/PASTIS'
dataset = PASTIS_Dataset(path_to_dataset, norm=True, target='semantic') # 使用语义分割标签 / Использование меток семантической сегментации / Using semantic segmentation labels
subset_indices = torch.randperm(len(dataset))[:1500].tolist()
subset_dataset = Subset(dataset, subset_indices)

# 划分训练集和验证集 / Разделение на обучающий и проверочный наборы / Splitting into training and validation sets
train_size = int(0.8 * len(subset_dataset))
valid_size = len(subset_dataset) - train_size
train_dataset, valid_dataset = random_split(subset_dataset, [train_size, valid_size])

# 创建 DataLoader / Создание DataLoader / Creating DataLoader
train_loader = DataLoader(train_dataset, batch_size=4, collate_fn=pad_collate, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=4, collate_fn=pad_collate)

# 类别数 / Количество классов / Number of classes
num_classes = 20

SyntaxError: invalid syntax. Perhaps you forgot a comma? (3035178815.py, line 8)

早停 / Ранняя остановка / Early stopping

In [31]:
"""
patience: 训练过程中没有改进的次数。
patience: Количество раз, когда обучение не улучшается.
patience: The number of times the training does not improve.

min_delta: 被认为是改进的最小变化量。
min_delta: Минимальное изменение, которое считается улучшением.
min_delta: The minimum change that is considered an improvement.
"""
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

训练模型 / обучение модели / model training

In [28]:
# 初始化模型和优化器 / Инициализация модели и оптимизатора / Initializing the model and optimizer
model = FPNUNet(num_classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

early_stopping = EarlyStopping(patience=3, min_delta=0.01)

# 初始化学习率调度器 / Инициализация планировщика скорости обучения / Initializing the learning rate scheduler
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1, verbose=True, min_lr=1e-6)

# 训练循环 / Цикл обучения / Training loop
epochs = 30 # 训练周期 / Эпохи обучения / Training epochs
for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for batch_idx, batch_data in tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}/{epochs}', leave=False):
        ((inputs_dict, dates), targets) = batch_data
        # 将三十个时间点合并 / Объединение тридцати временных точек / Combining thirty time points
        inputs_combined = torch.cat([inputs_dict['S2'][:, i, :, :, :] for i in range(30)], dim=1).to(device) 
        targets = targets.to(device).long()

        optimizer.zero_grad()
        outputs = model(inputs_combined)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  # 累加训练损失 / Накопление потерь обучения / Accumulating training loss

    train_loss /= len(train_loader)  # 计算平均训练损失 / Вычисление средних потерь обучения / Calculating average training loss

    # 验证阶段 / Валидация / Validation phase
    model.eval()  # 设置模型为评估模式 / Установка модели в режим оценки / Setting the model to evaluation mode
    val_loss = 0.0
    correct_pixels = 0
    total_pixels = 0
    with torch.no_grad():  # 在这个阶段不计算梯度 / На этом этапе градиенты не вычисляются / Gradients are not calculated at this stage
        for batch_data in valid_loader:
            ((inputs_dict, dates), targets) = batch_data
            inputs_combined = torch.cat([inputs_dict['S2'][:, i, :, :, :] for i in range(30)], dim=1).to(device)
            targets = targets.to(device).long()

            outputs = model(inputs_combined)
            loss = criterion(outputs, targets)
            
            val_loss += loss.item()  # 累加验证损失 / Накопление потерь валидации / Accumulating validation loss
            # 计算准确率 /  Вычисление точности / Calculating accuracy
            _, predicted = torch.max(outputs, 1)  # 获取最大概率的预测结果 / Получение предсказанных результатов с максимальной вероятностью / Getting predicted results with maximum probability
            correct_pixels += (predicted == targets).sum().item()  # 累加正确预测的像素数 / Накопление количества правильно предсказанных пикселей / Accumulating the number of correctly predicted pixels
            total_pixels += targets.nelement()  # 累加总像素数 / Накопление общего количества пикселей / Accumulating the total number of pixels

    val_loss /= len(valid_loader)  # 计算平均验证损失 / Вычисление средних потерь валидации / Calculating average validation loss
    overall_accuracy = correct_pixels / total_pixels  # 计算总体准确率 / Вычисление общей точности / Calculating overall accuracy

    print(f"Epoch {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}, Overall Accuracy: {overall_accuracy:.4f}")
    
    # 在这里调用学习率调度器，基于验证损失 / Вызов планировщика скорости обучения на основе потерь валидации / Calling the learning rate scheduler here, based on validation loss
    scheduler.step(val_loss)

    # 检查是否需要早停 / Проверка на необходимость досрочной остановки / Checking if early stopping is needed
    early_stopping(loss)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break



Epoch 1/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 1/30, Training Loss: 1.6537331926822663, Validation Loss: 1.2932375892003378, Overall Accuracy: 0.6040


Epoch 2/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 2/30, Training Loss: 1.342187801003456, Validation Loss: 1.1917772380510967, Overall Accuracy: 0.6375


Epoch 3/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 3/30, Training Loss: 1.2566123658418655, Validation Loss: 1.1145870892206828, Overall Accuracy: 0.6558


Epoch 4/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 4/30, Training Loss: 1.2027760229508082, Validation Loss: 1.123127597173055, Overall Accuracy: 0.6425
EarlyStopping counter: 1 out of 3


Epoch 5/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 5/30, Training Loss: 1.1574705028533936, Validation Loss: 1.04104363600413, Overall Accuracy: 0.6640


Epoch 6/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 6/30, Training Loss: 1.116131874124209, Validation Loss: 1.0366944082578022, Overall Accuracy: 0.6710
EarlyStopping counter: 1 out of 3


Epoch 7/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 7/30, Training Loss: 1.0964794039726258, Validation Loss: 1.0506828173001608, Overall Accuracy: 0.6597
EarlyStopping counter: 2 out of 3


Epoch 8/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 8/30, Training Loss: 1.0720294284820557, Validation Loss: 0.9907202394803365, Overall Accuracy: 0.6769


Epoch 9/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 9/30, Training Loss: 1.0548770779371262, Validation Loss: 0.9840745250384013, Overall Accuracy: 0.6798


Epoch 10/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 10/30, Training Loss: 1.045912851492564, Validation Loss: 0.9822000702222188, Overall Accuracy: 0.6813
EarlyStopping counter: 1 out of 3


Epoch 11/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 11/30, Training Loss: 1.031415196855863, Validation Loss: 0.9667937167485555, Overall Accuracy: 0.6845


Epoch 12/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 12/30, Training Loss: 1.0193106210231782, Validation Loss: 0.9627852114041646, Overall Accuracy: 0.6863
EarlyStopping counter: 1 out of 3


Epoch 13/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 13/30, Training Loss: 1.0043124649922053, Validation Loss: 0.9636413995424906, Overall Accuracy: 0.6864
EarlyStopping counter: 2 out of 3


Epoch 14/30:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 14/30, Training Loss: 1.0004757742087047, Validation Loss: 0.9452791961034139, Overall Accuracy: 0.6920
EarlyStopping counter: 3 out of 3
Early stopping triggered.


计算mIoU / Рассчитать mIoU / Calculate mIoU

In [29]:
def calculate_iou(predicted, target, num_classes):
    """
    计算平均IoU，对每个类别计算IoU，然后取平均值。
    IoU = TP / (TP + FP + FN) 指的是交集与并集的比值。
    Вычисление среднего IoU, вычисление IoU для каждого класса, а затем взятие среднего значения.
    IoU = TP / (TP + FP + FN) отношение пересечения к объединению.
    Calculate mean IoU, calculate IoU for each class, then take the average.
    IoU = TP / (TP + FP + FN) the ratio of intersection to union.
    TP: True Positive, FP: False Positive, FN: False Negative
    """
    iou_list = []
    for cls in range(num_classes):
        pred_inds = predicted == cls
        target_inds = target == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        if union == 0:
            # 避免除以0 / Избегание деления на 0 / Avoiding division by zero
            iou_list.append(float('nan'))  # 该类别未出现在预测和目标中 / Этот класс не появляется в прогнозе и цели / This class does not appear in the prediction and target
        else:
            iou_list.append(intersection / union)
    # 忽略nan值计算平均IoU / Игнорирование значений nan при вычислении среднего IoU / Ignoring nan values when calculating mean IoU
    iou_list = [x for x in iou_list if not np.isnan(x)]
    mean_iou = sum(iou_list) / len(iou_list) if iou_list else float('nan')
    return mean_iou

# 模型验证和计算Mean IoU / Проверка модели и вычисление среднего IoU / Model validation and calculating Mean IoU
def validate_and_calculate_iou(model, loader, device, num_classes):
    model.eval()
    total_iou = 0.0
    correct_pixels = 0
    total_pixels = 0
    with torch.no_grad():
        for ((inputs_dict, dates), targets) in loader:
            inputs_combined = torch.cat([inputs_dict['S2'][:, i, :, :, :] for i in range(30)], dim=1).to(device)
            targets = targets.to(device).long()

            outputs = model(inputs_combined)
            _, predicted = torch.max(outputs, 1)
            total_iou += calculate_iou(predicted, targets, num_classes)
            # 计算准确率 / Вычисление точности / Calculating accuracy
            _, predicted = torch.max(outputs, 1)  # 获取最大概率的预测结果 / Получение предсказанных результатов с максимальной вероятностью / Getting predicted results with maximum probability
            correct_pixels += (predicted == targets).sum().item()  # 累加正确预测的像素数 / Накопление количества правильно предсказанных пикселей / Accumulating the number of correctly predicted pixels
            total_pixels += targets.nelement()  # 累加总像素数 / Накопление общего количества пикселей / Accumulating the total number of pixels

    mean_iou = total_iou / len(loader)
    overall_accuracy = correct_pixels / total_pixels  # 计算总体准确率 / Вычисление общей точности / Calculating overall accuracy
    print(f"Mean IoU on validation set: {mean_iou}, Overall Accuracy: {overall_accuracy:.4f}")

验证 / Проверить модель / Validate model

In [30]:
# 调用验证函数 / Вызов функции валидации / Calling the validation function
validate_and_calculate_iou(model, valid_loader, device, num_classes)

# 计算模型参数数量 / Вычисление количества параметров модели / Calculating the number of model parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total trainable parameters: {total_params}")


Mean IoU on validation set: 0.26635117980806505, Overall Accuracy: 0.6920
Total trainable parameters: 12959232
