In [None]:
import os
# Expose only GPU 0 to CUDA. This is set before torch is imported so that the
# restriction is already in place when CUDA is first initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt
import albumentations as albu
import torch
import segmentation_models_pytorch as smp
from timm.layers.activations import Swish
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
# Imported explicitly so that the smp.utils.* names used below resolve.
import segmentation_models_pytorch.utils
import ssl
# WARNING: replaces the default HTTPS context factory with the unverified one,
# which turns off TLS certificate verification for the whole process.
# NOTE(review): presumably a workaround for downloading the pretrained encoder
# weights behind a broken certificate chain -- confirm it is still needed.
ssl._create_default_https_context = ssl._create_unverified_context


# ---------------------------------------------------------------
### 加载数据

class Dataset(BaseDataset):
    """CamVid dataset: reads images and masks, then augments and preprocesses.

    Every file found in ``images_dir`` is paired with the file of the same
    name in ``masks_dir``. Indexing returns an ``(image, mask)`` pair where
    ``mask`` has one channel per requested class.

    Args:
        images_dir (str): path to the folder containing the input images
        masks_dir (str): path to the folder containing the label masks
        classes (list | None): class names to extract from the label mask
            (case-insensitive); ``None`` selects every entry of ``CLASSES``
        augmentation (albumentations.Compose): augmentation pipeline
        preprocessing (albumentations.Compose): preprocessing pipeline

    Raises:
        ValueError: if a requested class name is not in ``CLASSES``
        FileNotFoundError: if an image or mask cannot be read when indexed
    """
    # All segmentation label classes of CamVid; the list position is the
    # pixel value used for that class in the label masks.
    CLASSES = ['sky', 'building', 'pole', 'road', 'pavement',
               'tree', 'signsymbol', 'fence', 'car',
               'pedestrian', 'bicyclist', 'unlabelled']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # The declared default is None, which cannot be iterated: treat it as
        # "use every known class".
        if classes is None:
            classes = self.CLASSES

        # Convert class names to the pixel values used in the masks.
        self.class_values = [self.CLASSES.index(name.lower()) for name in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        # cv2.imread returns None instead of raising on a missing or unreadable
        # file, so check explicitly to get an error that names the file.
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError('Cannot read image: {}'.format(self.images_fps[i]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_fps[i], 0)
        if mask is None:
            raise FileNotFoundError('Cannot read mask: {}'.format(self.masks_fps[i]))

        # Amount of bottom/right padding needed to make height and width
        # multiples of 32 (0 when the size is already a multiple).
        h, w = image.shape[:2]
        pad_h = -h % 32
        pad_w = -w % 32

        # Pad image and mask identically so they stay aligned.
        # NOTE: the mask border is filled with 0, which is also the index of
        # 'sky' in CLASSES, so padded pixels count as 'sky' when that class is
        # selected.
        if pad_h or pad_w:
            image = cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            mask = cv2.copyMakeBorder(mask, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)

        # Extract the requested classes (e.g. cars): one binary channel each.
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # Apply augmentation.
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # Apply preprocessing.
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)


# ---------------------------------------------------------------
### 图像增强

def get_training_augmentation():
    """Build the training augmentation pipeline.

    The stages are applied in this order: random flip and shift/scale,
    pad-then-crop to 320x320, noise and perspective warp, and finally three
    "pick one" groups for tone, sharpness and colour.

    Return:
        transform: albumentations.Compose
    """
    # Flip and shift/scale; rotate_limit=0, border_mode=0 is cv2.BORDER_CONSTANT.
    flip_and_affine = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
    ]

    # Pad up to at least 320x320, then take a random 320x320 crop.
    crop_to_320 = [
        albu.PadIfNeeded(min_height=320, min_width=320, border_mode=0),
        albu.RandomCrop(height=320, width=320),
    ]

    noise_and_warp = [
        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),
    ]

    # Each group below is applied with probability 0.9 and uses one member.
    tone = albu.OneOf(
        [
            albu.CLAHE(p=1),
            albu.RandomBrightnessContrast(p=1),
            albu.RandomGamma(p=1),
        ],
        p=0.9,
    )
    sharpness = albu.OneOf(
        [
            albu.Sharpen(p=1),
            albu.Blur(blur_limit=3, p=1),
            albu.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.9,
    )
    colour = albu.OneOf(
        [
            albu.RandomBrightnessContrast(p=1),
            albu.HueSaturationValue(p=1),
        ],
        p=0.9,
    )

    return albu.Compose(
        [*flip_and_affine, *crop_to_320, *noise_and_warp, tone, sharpness, colour]
    )


def get_validation_augmentation():
    """Pad image and mask so that height and width are divisible by 32.

    Zero padding is added on the bottom and right edges only. The same
    padding function is registered for both the image and the mask, so the
    two stay the same size.

    Return:
        transform: albumentations.Compose
    """
    def _pad_to_divisible(array, **kwargs):
        h, w = array.shape[:2]
        # -h % 32 is the distance up to the next multiple of 32 (0 if aligned).
        pad_h, pad_w = -h % 32, -w % 32
        # Pad only the two spatial axes; leave any channel axis untouched.
        # np.pad keeps a trailing singleton channel axis, which the mask of a
        # single-class dataset relies on.
        pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (array.ndim - 2)
        return np.pad(array, pad_width, mode='constant', constant_values=0)

    test_transform = [
        # Registering the function for `image` only would leave the mask at
        # its original size.
        albu.Lambda(image=_pad_to_divisible, mask=_pad_to_divisible),
    ]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline.

    The image is first passed through ``preprocessing_fn``; image and mask
    are then both passed through ``to_tensor``.

    Args:
        preprocessing_fn (callable): data normalization function
            (specific to each pretrained network)
    Return:
        transform: albumentations.Compose
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, channels_first])


def main():
    """Train a UNet++ segmentation model on CamVid and keep the best checkpoint.

    Looks for the dataset in ``./CamVid/``; if it is missing, the
    SegNet-Tutorial repository is cloned into ``./data`` and its ``CamVid``
    folder is used instead. Trains for 40 epochs, saving the model to
    ``./best_model.pth`` whenever the validation IoU improves, and lowers the
    learning rate to 1e-5 after epoch 25.

    Raises:
        subprocess.CalledProcessError: if the dataset has to be downloaded
            and ``git clone`` fails
    """
    import subprocess

    # Directory containing the dataset.
    DATA_DIR = './CamVid/'

    # If the dataset is not present locally, fall back to a cloned copy.
    if not os.path.exists(DATA_DIR):
        cloned_data_dir = './data/CamVid/'
        if not os.path.exists(cloned_data_dir):
            print('Loading data...')
            # Passing an argument list keeps the URL and the target directory
            # as separate arguments (the former single shell string had them
            # fused together), and check=True stops on a failed download
            # instead of continuing without data.
            subprocess.run(
                ['git', 'clone', 'https://github.com/alexgkendall/SegNet-Tutorial', './data'],
                check=True,
            )
            print('Done!')
        # The clone places the dataset under ./data, not under ./CamVid.
        DATA_DIR = cloned_data_dir

    # Training set.
    x_train_dir = os.path.join(DATA_DIR, 'train')
    y_train_dir = os.path.join(DATA_DIR, 'trainannot')

    # Validation set.
    x_valid_dir = os.path.join(DATA_DIR, 'val')
    y_valid_dir = os.path.join(DATA_DIR, 'valannot')

    ENCODER = 'se_resnext50_32x4d'
    ENCODER_WEIGHTS = 'imagenet'
    # NOTE(review): 'car' maps to mask pixel value 8 in Dataset.CLASSES. A run
    # that reports dice_loss ~1.0 together with iou_score 1.0 would be
    # consistent with masks containing no pixels of that value -- verify the
    # label encoding of the masks actually used.
    CLASSES = ['car']
    ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation

    # Use the GPU when one is available, otherwise fall back to the CPU
    # rather than failing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the segmentation model (UNet++) on top of a pretrained encoder.
    model = smp.UnetPlusPlus(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=len(CLASSES),
        activation=ACTIVATION,
    )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    # Training dataset.
    train_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    # Validation dataset.
    valid_dataset = Dataset(
        x_valid_dir,
        y_valid_dir,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    # Tune to the GPU: batch_size is the number of images per iteration and
    # num_workers the number of loader processes. On a weak GPU or with little
    # GPU memory, lower batch_size and keep num_workers at 0.
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

    loss = smp.utils.losses.DiceLoss()
    metrics = [
        smp.utils.metrics.IoU(threshold=0.5),
    ]

    optimizer = torch.optim.Adam([
        dict(params=model.parameters(), lr=0.0001),
    ])

    # Epoch runners that iterate over the data loaders.
    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=device,
        verbose=True,
    )

    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=device,
        verbose=True,
    )

    # Train for 40 epochs.
    max_score = 0
    for i in range(0, 40):
        print('\nEpoch: {}'.format(i))
        train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # Keep the model with the best validation IoU seen so far.
        # torch.save(model, ...) stores the whole module object, not just a
        # state_dict; the file format is kept as-is for existing loaders.
        if max_score < valid_logs['iou_score']:
            max_score = valid_logs['iou_score']
            torch.save(model, './best_model.pth')
            print('Model saved!')

        # The optimizer has a single parameter group covering the whole model,
        # so this lowers the learning rate for every parameter.
        if i == 25:
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')

# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()

  from .autonotebook import tqdm as notebook_tqdm
  original_init(self, **validated_kwargs)



Epoch: 0
train: 100%|██████████| 69/69 [00:28<00:00,  2.38it/s, dice_loss - 1.0, iou_score - 2.968e-11]
valid: 100%|██████████| 14/14 [25:13<00:00, 108.09s/it, dice_loss - 1.0, iou_score - 4.596e-11] 
Model saved!

Epoch: 1
train: 100%|██████████| 69/69 [00:52<00:00,  1.32it/s, dice_loss - 1.0, iou_score - 0.01449]  
valid: 100%|██████████| 14/14 [09:20<00:00, 40.02s/it, dice_loss - 1.0, iou_score - 9.107e-08]
Model saved!

Epoch: 2
train: 100%|██████████| 69/69 [01:25<00:00,  1.24s/it, dice_loss - 0.9999, iou_score - 0.9275]  
valid: 100%|██████████| 14/14 [09:27<00:00, 40.52s/it, dice_loss - 1.0, iou_score - 1.0]
Model saved!

Epoch: 3
train: 100%|██████████| 69/69 [01:11<00:00,  1.03s/it, dice_loss - 0.9996, iou_score - 1.0]
valid: 100%|██████████| 14/14 [07:46<00:00, 33.31s/it, dice_loss - 0.9999, iou_score - 1.0]

Epoch: 4
train: 100%|██████████| 69/69 [01:03<00:00,  1.09it/s, dice_loss - 0.998, iou_score - 1.0] 
valid: 100%|██████████| 14/14 [06:05<00:00, 26.09s/it, dice_loss - 