In [2]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.cuda.amp import autocast, GradScaler
from torchvision import models
from torchvision.models.vgg import VGG
from tqdm.auto import tqdm
import rasterio
from rasterio.windows import Window
from collections import Counter
import random
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import WeightedRandomSampler
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from torchvision.transforms import ToTensor, Normalize

In [3]:
import os
# Ask the CUDA runtime to launch kernels synchronously so that a device-side
# error surfaces at the Python call that triggered it.
# NOTE(review): this is a debugging aid and normally slows GPU execution;
# consider removing it for real training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

图像预处理 / Предварительная обработка изображений / Image preprocessing

In [None]:
class CustomTransform:
    """Random augmentation for a channel-first ``(C, H, W)`` image tensor.

    Geometric steps (horizontal flip, vertical flip, rotation by a multiple
    of 90 degrees) may optionally be mirrored onto a label mask so that the
    image and its per-pixel annotation stay aligned. Photometric steps
    (brightness scaling in ``[0.9, 1.1]`` and additive Gaussian noise with
    standard deviation 0.02) only ever touch the image.

    Calling the transform with the image alone behaves exactly as before and
    returns just the augmented image.
    """

    # Advertises that ``__call__`` accepts an optional label mask.
    supports_label = True

    def __call__(self, x, label=None):
        """Augment ``x`` and, when given, apply the same geometry to ``label``.

        Args:
            x: image tensor whose spatial axes are dimensions 1 and 2.
            label: optional mask tensor whose last two dimensions are the
                spatial axes (for example ``(H, W)``).

        Returns:
            The augmented image when ``label`` is ``None``; otherwise the
            tuple ``(image, label)``.
        """
        has_label = label is not None

        # Horizontal flip (width axis).
        if random.random() > 0.5:
            x = torch.flip(x, [2])
            if has_label:
                label = torch.flip(label, [-1])

        # Vertical flip (height axis).
        if random.random() > 0.5:
            x = torch.flip(x, [1])
            if has_label:
                label = torch.flip(label, [-2])

        # Rotation by k * 90 degrees in the spatial plane.
        k = random.choice([0, 1, 2, 3])
        x = torch.rot90(x, k, [1, 2])
        if has_label:
            label = torch.rot90(label, k, [-2, -1])

        # Photometric jitter: image only, class ids must not be altered.
        if random.random() > 0.5:
            brightness_factor = random.uniform(0.9, 1.1)
            x = x * brightness_factor

        if random.random() > 0.5:
            noise = torch.randn_like(x) * 0.02
            x = x + noise

        if has_label:
            return x, label
        return x
transform = CustomTransform()

图像分割 / Сегментация изображений / Image segmentation

In [None]:
def split_images_and_labels(image_paths, label_paths, tile_size=(128, 128), overlap=0.1, ignore_label=0):
    """Cut paired rasters into overlapping tiles, dropping fully ignored tiles.

    Args:
        image_paths: paths of the image rasters.
        label_paths: paths of the mask rasters, in the same order.
        tile_size: ``(width, height)`` of each tile in pixels.
        overlap: fraction of a tile shared with its neighbour, ``0 <= overlap < 1``.
        ignore_label: a tile is kept only if at least one mask pixel differs
            from this value.

    Returns:
        ``(image_tiles, label_tiles)``: two parallel lists of NumPy arrays.
        Image tiles hold every band of the window; label tiles hold band 1.

    Raises:
        ValueError: if the two path lists differ in length or ``overlap`` is
            outside ``[0, 1)``.
    """
    image_paths = list(image_paths)
    label_paths = list(label_paths)
    if len(image_paths) != len(label_paths):
        # zip() would silently drop the unmatched rasters.
        raise ValueError(
            f"image_paths and label_paths must have the same length, "
            f"got {len(image_paths)} and {len(label_paths)}"
        )
    if not 0 <= overlap < 1:
        raise ValueError(f"overlap must be in [0, 1), got {overlap}")

    tile_width, tile_height = tile_size[0], tile_size[1]
    # One stride per axis so non-square tiles overlap by the same fraction in
    # both directions; never let the stride collapse to zero.
    stride_x = max(1, int(tile_width * (1 - overlap)))
    stride_y = max(1, int(tile_height * (1 - overlap)))

    image_tiles = []
    label_tiles = []

    for image_path, label_path in zip(image_paths, label_paths):
        with rasterio.open(image_path) as img, rasterio.open(label_path) as lbl:
            # The tiling grid is driven by the mask raster's dimensions.
            for top in range(0, lbl.height - tile_height + 1, stride_y):
                for left in range(0, lbl.width - tile_width + 1, stride_x):
                    window = Window(left, top, tile_width, tile_height)
                    img_tile = img.read(window=window)
                    lbl_tile = lbl.read(1, window=window)

                    if np.any(lbl_tile != ignore_label):
                        image_tiles.append(img_tile)
                        label_tiles.append(lbl_tile)

    return image_tiles, label_tiles

In [None]:
class CustomDataset(Dataset):
    """Dataset of pre-cut image tiles and their per-pixel label masks.

    Each item is ``(image, label)`` where ``image`` is a ``float32`` tensor
    and ``label`` an ``int64`` tensor.

    Args:
        image_tiles: sequence of NumPy image arrays.
        label_tiles: sequence of NumPy mask arrays, parallel to ``image_tiles``.
        transform: optional callable. If it exposes a truthy
            ``supports_label`` attribute it is called as
            ``transform(image, label)`` and must return ``(image, label)``,
            which keeps geometric augmentations aligned with the mask.
            Otherwise it is called as ``transform(image)`` and the mask is
            left untouched.
    """

    def __init__(self, image_tiles, label_tiles, transform=None):
        self.image_tiles = image_tiles
        self.label_tiles = label_tiles
        self.transform = transform

    def __len__(self):
        return len(self.image_tiles)

    def __getitem__(self, idx):
        if idx >= len(self.image_tiles) or idx >= len(self.label_tiles):
            raise IndexError("Index out of range")

        image = torch.from_numpy(self.image_tiles[idx].astype(np.float32))

        label = torch.from_numpy(self.label_tiles[idx].astype(np.int64))
        if label.ndim > 2:
            # Drop a leading singleton band axis, if present.
            label = label.squeeze(0)

        if self.transform:
            if getattr(self.transform, "supports_label", False):
                # Paired transform: image and mask receive the same geometry.
                image, label = self.transform(image, label)
            else:
                image = self.transform(image)

        if image.dim() == 4:
            image = image.permute(0, 2, 1, 3)

        return image, label

数据集加载 / Загрузка набора данных / loading dataset

In [None]:
# Directory that holds the six source rasters and their masks.
# NOTE(review): this is an absolute, machine-specific Windows path.
_DATA_DIR = r'E:\Research\code\newdata'
_SCENE_IDS = range(1, 7)

# Scene N is stored as "N.tiff" and its mask as "N_mask.tiff".
image_paths = [_DATA_DIR + '\\' + f'{scene_id}.tiff' for scene_id in _SCENE_IDS]

label_paths = [_DATA_DIR + '\\' + f'{scene_id}_mask.tiff' for scene_id in _SCENE_IDS]

image_tiles, label_tiles = split_images_and_labels(image_paths, label_paths)

print("Loaded image tiles:", len(image_tiles))
print("Loaded label tiles:", len(label_tiles))

num_classes = 7

# Two views over the same tiles: the training view is augmented, the
# evaluation view is not, so validation metrics and the saved validation set
# are computed on unmodified data.
full_dataset = CustomDataset(image_tiles, label_tiles, transform=transform)
eval_dataset = CustomDataset(image_tiles, label_tiles, transform=None)

dataset_size = len(full_dataset)
indices = list(range(dataset_size))
# Seeded split: training later resumes from a saved checkpoint, so the
# validation tiles must be the same on every run.
# NOTE(review): tiles overlap, so a tile-level split still shares some pixels
# between training and validation.
SPLIT_SEED = 42
np.random.default_rng(SPLIT_SEED).shuffle(indices)

train_size = int(0.8 * dataset_size)
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# Pixel count per class over all tiles (int64: int32 can overflow on large
# tile sets).
class_counts = np.zeros(num_classes, dtype=np.int64)

for labels in label_tiles:
    class_counts += np.bincount(labels.flatten().astype(np.int64), minlength=num_classes)

# Inverse-frequency class weights; classes that never occur get weight 0.
class_weights = np.zeros(num_classes, dtype=np.float64)
np.divide(1.0, class_counts, out=class_weights, where=class_counts > 0)

# One weight per TILE (the unit the sampler draws): the mean inverse-frequency
# weight of the tile's pixels, so tiles rich in rare classes are drawn more
# often. Indexing these weights with tile indices is only valid because there
# is exactly one entry per tile.
sample_weights = np.array(
    [class_weights[tile.flatten().astype(np.int64)].mean() for tile in label_tiles],
    dtype=np.float64,
)

sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

train_sample_weights = sample_weights[train_indices].tolist()

train_sampler = WeightedRandomSampler(train_sample_weights, len(train_sample_weights), replacement=True)

train_loader = DataLoader(Subset(full_dataset, train_indices), batch_size=16, sampler=train_sampler, drop_last=True)

# No shuffling and no augmentation for validation.
val_loader = DataLoader(Subset(eval_dataset, val_indices), batch_size=16, shuffle=False)

valid_data = [(data, label) for data, label in DataLoader(Subset(eval_dataset, val_indices), batch_size=1)]
torch.save(valid_data, 'new_valid_data.pth')

# Sanity check: print the shape of every training batch.
for images, labels in train_loader:
    print(images.shape, labels.shape)

# Label range of the last batch seen above.
print("Label max:", labels.max())
print("Label min:", labels.min())

def print_batch_class_distribution(loader, num_classes, max_batches=None):
    """Print, for each batch, how many label pixels belong to each class.

    Args:
        loader: iterable yielding ``(images, labels)`` pairs where ``labels``
            is an integer tensor of class ids.
        num_classes: classes ``0 .. num_classes - 1`` are reported; label
            values outside that range are ignored.
        max_batches: optional cap on the number of batches inspected;
            ``None`` (default) walks the whole loader.
    """
    if max_batches is not None and max_batches <= 0:
        return

    for batch_number, (_, labels) in enumerate(loader, start=1):
        # reshape/cpu also cope with non-contiguous or GPU-resident labels.
        flat = labels.detach().reshape(-1).cpu().numpy().astype(np.int64)
        in_range = flat[(flat >= 0) & (flat < num_classes)]
        # Vectorised histogram instead of hashing every pixel in Python.
        counts = np.bincount(in_range, minlength=num_classes)
        distribution = {k: int(counts[k]) for k in range(num_classes)}
        print(f"Batch class distribution: {distribution}")

        if max_batches is not None and batch_number >= max_batches:
            break

# Print the class histogram of every batch the training loader yields, to
# check what the weighted sampler actually feeds the model.
print_batch_class_distribution(train_loader, num_classes)

FCN


In [9]:
class FCN10Channel(nn.Module):
    """FCN-style segmentation network on a VGG16 convolutional backbone.

    The first VGG16 convolution is replaced so the network accepts
    ``num_channels`` input bands. A stack of conv/BN/ReLU layers reduces the
    512 backbone channels to ``num_classes`` score maps, and a single
    transposed convolution enlarges them by a factor of 32.

    Args:
        num_channels: number of input bands.
        num_classes: number of output classes.
        pretrained: forwarded to ``torchvision.models.vgg16``. When true, the
            new input convolution is seeded from the pretrained RGB filters.

    NOTE(review): the output matches the input size only if the backbone
    reduces the resolution by exactly 32, i.e. input height and width are
    multiples of 32 — confirm for other tile sizes.
    """

    def __init__(self, num_channels=4, num_classes=7, pretrained=True):
        super(FCN10Channel, self).__init__()
        vgg = models.vgg16(pretrained=pretrained)
        features = list(vgg.features.children())

        # Not used in forward(); kept so the module's attributes are unchanged.
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.7)

        # Replace the 3-channel input convolution, honouring num_channels.
        rgb_conv = features[0]
        input_conv = nn.Conv2d(num_channels, 64, kernel_size=3, padding=1)
        if pretrained:
            # Reuse the pretrained filters for the bands they cover and give
            # any extra band the mean RGB filter, rather than random weights
            # in front of a pretrained backbone.
            with torch.no_grad():
                shared = min(num_channels, rgb_conv.in_channels)
                input_conv.weight[:, :shared] = rgb_conv.weight[:, :shared]
                if num_channels > shared:
                    input_conv.weight[:, shared:] = rgb_conv.weight.mean(dim=1, keepdim=True)
                input_conv.bias.copy_(rgb_conv.bias)
        features[0] = input_conv
        self.features = nn.Sequential(*features)

        # Freeze everything except the last six backbone layers ...
        for layer in self.features[:-6]:
            for param in layer.parameters():
                param.requires_grad = False
        # ... but keep the new input convolution trainable: it has no
        # ImageNet-trained counterpart for the extra bands, and freezing it
        # would lock in whatever it was initialised with.
        for param in self.features[0].parameters():
            param.requires_grad = True

        self.fcn = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

        # kernel 64 / stride 32 / padding 16 -> output size = 32 * input size.
        self.upsample = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, stride=32, padding=16)

    def forward(self, x):
        """Return per-pixel class scores of shape ``(N, num_classes, H', W')``."""
        x = self.features(x)
        x = self.fcn(x)
        x = self.upsample(x)
        return x


早停 / Ранняя остановка / Early stopping

In [10]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    A call counts as an improvement only when the loss drops below the best
    value seen so far by more than ``min_delta``. After ``patience``
    consecutive calls without improvement, ``early_stop`` becomes ``True``.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Record one validation loss and update the stopping state."""
        gain = self.best_loss - val_loss
        if gain > self.min_delta:
            # New best: remember it and restart the patience window.
            self.best_loss, self.counter = val_loss, 0
            return

        self.counter += 1
        print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

GPU / Определите, можно ли использовать CUDA / Check whether CUDA can be used

In [None]:
# Pick the compute device once: GPU when CUDA is usable, otherwise CPU.
cuda_ready = torch.cuda.is_available()
print("CUDA is available. GPU is used now." if cuda_ready else "CUDA is not available. Using CPU.")
device = torch.device("cuda" if cuda_ready else "cpu")

模型训练  / обучение модели / model training

In [None]:
model = FCN10Channel(num_channels=4, num_classes=7).to(device)

# map_location lets a checkpoint written on a GPU machine load on a CPU-only
# one (and vice versa).
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
pretrained_dict = torch.load('best_fcn.pth', map_location=device)

model_dict = model.state_dict()
# Keep only the entries whose name and shape match the current architecture.
checkpoint_key_count = len(pretrained_dict)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
# Report what was dropped instead of discarding it silently.
print(f"Loaded {len(pretrained_dict)} of {checkpoint_key_count} checkpoint tensors "
      f"({len(model_dict) - len(pretrained_dict)} model tensors keep their initial values)")

model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
model.to(device)

def save_model(model, path):
    """Write the model's ``state_dict`` (weights only) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)

# --- Training state --------------------------------------------------------
# Stop after 5 validation rounds without an improvement larger than 1e-4.
early_stopping = EarlyStopping(min_delta=0.0001, patience=5)

# TensorBoard event writer (default log directory).
writer = SummaryWriter()

optimizer = optim.Adam(params=model.parameters(), lr=1e-6, weight_decay=1e-6)

# Pixels labelled -1 do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)
# Loss scaler for mixed-precision training.
scaler = GradScaler()
# NOTE(review): min_lr equals the optimizer's initial lr (1e-6), so this
# scheduler looks unable to lower the learning rate — confirm the intended
# values.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=1,
    min_lr=1e-6,
    verbose=True,
)

# Per-validation-round history.
train_losses, val_losses = [], []
overall_accuracies = []
precision_scores, f1_scores, recall_scores = [], [], []
best_val_loss = float('inf')

# Main training loop.
# Every epoch runs one mixed-precision pass over train_loader. Every second
# epoch the model is also evaluated on val_loader; only then are metrics
# logged, the best checkpoint saved, the LR scheduler stepped and the
# early-stopping criterion checked. The history lists therefore hold one entry
# per validation round, not one per epoch.
epochs = 100
for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}', leave=True):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        # Forward pass and loss under automatic mixed precision.
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Scale the loss before backward, then step through the scaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

    # Mean training loss per batch for this epoch.
    train_loss /= len(train_loader)

    # Validate on every second epoch only.
    if (epoch +1) % 2 == 0:
        model.eval()

        val_loss = 0.0
        correct_pixels = 0
        total_pixels = 0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                # Predicted class = index of the highest score along dim 1.
                _, predicted = torch.max(outputs, 1)
                correct_pixels += (predicted == labels).sum().item()  # Accumulate the number of correctly predicted pixels
                total_pixels += labels.nelement()  # Accumulate the total number of pixels
                all_predictions.append(predicted.cpu().numpy())
                all_targets.append(labels.cpu().numpy())

        # Merge the per-batch arrays with np.concatenate, then flatten them
        all_predictions_flattened = np.concatenate(all_predictions).reshape(-1)
        all_targets_flattened = np.concatenate(all_targets).reshape(-1)

        val_loss /= len(val_loader)  # Average validation loss per batch
        overall_accuracy = correct_pixels / total_pixels  # Overall pixel accuracy
        precision = precision_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)  # Macro-averaged precision
        recall = recall_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)  # Macro-averaged recall
        f1 = f1_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)  # Macro-averaged F1

        print(f"Epoch {epoch+1}/{epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}, Overall Accuracy: {overall_accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        overall_accuracies.append(overall_accuracy)
        precision_scores.append(precision)
        recall_scores.append(recall)
        f1_scores.append(f1)

        # Record to TensorBoard: weight/gradient histograms and scalar metrics
        for name, param in model.named_parameters():
            writer.add_histogram(f'Weights/{name}', param, epoch)
            if param.grad is not None:
                writer.add_histogram(f'Gradients/{name}', param.grad, epoch)
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/overall', overall_accuracy, epoch)
        writer.add_scalar('Precision', precision, epoch)
        writer.add_scalar('Recall', recall, epoch)
        writer.add_scalar('F1', f1, epoch)
        writer.add_scalar('Learning rate', optimizer.param_groups[0]['lr'], epoch)

        # Keep the weights with the lowest validation loss seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, 'best_new_fcn_state_dict.pth')
            print(f"Model saved at Epoch {epoch+1}: Improved validation loss to {best_val_loss:.4f}")

        scheduler.step(val_loss)  # Update the learning rate
        early_stopping(val_loss)  # Check whether the early-stopping criterion is met
        if early_stopping.early_stop:
            print("Early stopping")
            break

writer.close()


计算mIoU / Рассчитать mIoU / Calculate mIoU

In [13]:
def calculate_iou(predicted, target, num_classes):
    """Mean intersection-over-union over the classes that actually occur.

    Args:
        predicted: array/tensor of predicted class ids.
        target: array/tensor of reference class ids, same shape.
        num_classes: classes ``0 .. num_classes - 1`` are evaluated.

    Returns:
        The average IoU over every class present in ``predicted`` or
        ``target``; ``nan`` when none of the classes is present.
    """
    per_class_iou = []
    for class_id in range(num_classes):
        pred_mask = predicted == class_id
        true_mask = target == class_id
        overlap = (pred_mask & true_mask).sum().item()
        combined = pred_mask.sum().item() + true_mask.sum().item() - overlap
        # A class absent from both prediction and target has no defined IoU
        # and is left out of the average (this also avoids dividing by zero).
        if combined != 0:
            per_class_iou.append(overlap / combined)

    if not per_class_iou:
        return float('nan')
    return sum(per_class_iou) / len(per_class_iou)

# 模型验证和计算Mean IoU / Проверка модели и вычисление среднего IoU / Model validation and calculating Mean IoU
# Model validation and Mean IoU calculation
def validate_and_calculate_iou(model, loader, device, num_classes):
    """Evaluate ``model`` on ``loader`` and report segmentation metrics.

    Mean IoU is computed once over all pixels of the loader rather than as the
    average of per-batch values. The result therefore does not depend on
    batch size, batch order, or which classes share a batch.

    Args:
        model: network returning per-pixel class scores with classes on dim 1.
        loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.
        num_classes: number of classes used for the IoU computation.

    Returns:
        dict with keys ``mean_iou``, ``overall_accuracy``, ``precision``,
        ``recall`` and ``f1``. The same values are also printed.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    all_targets = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Predicted class = index of the highest score along dim 1.
            _, predicted = torch.max(outputs, 1)
            all_predictions.append(predicted.cpu().numpy())
            all_targets.append(labels.cpu().numpy())

    if not all_predictions:
        raise ValueError("loader yielded no batches; nothing to evaluate")

    # Flatten the prediction and target arrays
    all_predictions_flattened = np.concatenate(all_predictions).reshape(-1)
    all_targets_flattened = np.concatenate(all_targets).reshape(-1)

    # Dataset-level IoU: intersections and unions are counted over every pixel.
    mean_iou = calculate_iou(all_predictions_flattened, all_targets_flattened, num_classes)
    # Overall pixel accuracy
    overall_accuracy = float((all_predictions_flattened == all_targets_flattened).mean())
    precision = precision_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)  # Macro-averaged precision
    recall = recall_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)  # Macro-averaged recall
    f1 = f1_score(all_targets_flattened, all_predictions_flattened, average='macro', zero_division=0)  # Macro-averaged F1
    print(f"Mean IoU on validation set: {mean_iou}, Overall Accuracy: {overall_accuracy:.4f}", f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

    return {
        'mean_iou': mean_iou,
        'overall_accuracy': overall_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

验证 / Проверить модель / Validate model

In [None]:
num_classes = 7

# Call the validation function
validate_and_calculate_iou(model, val_loader, device, num_classes)

# Count the model parameters. "Trainable" means requires_grad is set; frozen
# backbone layers are included in the total only.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")
print(f"Total trainable parameters: {trainable_params}")