In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm.notebook import tqdm

导入数据

In [None]:
DATA_DIR = os.path.join("/kaggle", "input", "semantic-segmentation", "semantic-segmentation", "data")

# Class table: one row per class, with its name and the RGB colour used to
# paint that class in the mask images.
class_dict = pd.read_csv(os.path.join(DATA_DIR, 'labels_class_dict.csv'))
class_names = class_dict['class_names'].tolist()
class_rgb_values = class_dict[['r', 'g', 'b']].values.tolist()

print('All dataset classes and their corresponding RGB values in labels:')
for caption, listing in (('Class Names: ', class_names),
                         ('Class RGB values: ', class_rgb_values)):
    print(caption, listing)

In [None]:
def _read_rgb(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Raises FileNotFoundError instead of letting cv2.cvtColor fail on the
    None that cv2.imread returns for a missing or unreadable file.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"cv2.imread could not read {path!r}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


# Sorted so the sample order -- and therefore the train/test split further
# down -- does not depend on the arbitrary order returned by os.listdir.
image_names = sorted(os.listdir(os.path.join(DATA_DIR, 'images')))
image_paths = [os.path.join(DATA_DIR, 'images', image_name) for image_name in image_names]
# Each mask shares its image's base name but is stored as .png. Swap only the
# extension; str.replace('jpg', 'png') would also rewrite "jpg" elsewhere in a name.
mask_paths = [os.path.join(DATA_DIR, 'masks', os.path.splitext(image_name)[0] + '.png')
              for image_name in image_names]

# OpenCV loads BGR; convert to RGB so mask colours compare against class_rgb_values.
images = [_read_rgb(image_path) for image_path in image_paths]
masks = [_read_rgb(mask_path) for mask_path in mask_paths]
print(len(images), len(masks))

# Target size (width, height); the UNet needs H and W divisible by 16.
IMAGE_SIZE = (320, 320)
images = [cv2.resize(image, IMAGE_SIZE) for image in images]
# Masks must use nearest-neighbour interpolation: the default bilinear filter
# blends neighbouring class colours along region borders, creating RGB values
# that match no class colour and therefore one-hot encode to all zeros.
masks = [cv2.resize(mask, IMAGE_SIZE, interpolation=cv2.INTER_NEAREST) for mask in masks]

# Show one sample image next to its mask.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(images[1])
axes[0].set_title("Image")
axes[1].imshow(masks[1])
axes[1].set_title("Mask");

图片预处理

one-hot 编码参考：https://blog.csdn.net/baidu_36511315/article/details/105528546

In [None]:
def one_hot_encode(label, label_values):
    """
    Turn a colour-coded segmentation mask into a stack of per-class boolean maps.

    # Arguments
        label: array whose last axis holds a pixel's colour components
            (e.g. an (H, W, 3) RGB mask).
        label_values: sequence of colours, one per class; each colour has as
            many components as the last axis of `label`.

    # Returns
        Boolean array with the same leading shape as `label` and a last axis
        of length len(label_values). Channel k is True exactly where every
        colour component of the pixel equals label_values[k].
    """
    per_class_maps = [np.all(np.equal(label, colour), axis=-1) for colour in label_values]
    return np.stack(per_class_maps, axis=-1)

# Perform reverse one-hot-encoding on labels / preds


def reverse_one_hot(image):
    """
    Collapse a one-hot (or per-class score) map back to class indices.

    # Arguments
        image: array whose last axis has one entry per class.

    # Returns
        For every position, the index of the largest entry along the last
        axis. The class axis is removed, so an (H, W, K) input yields an
        (H, W) array of class keys.
    """
    return np.argmax(image, axis=-1)

In [None]:
# One-hot encode each RGB mask to shape (H, W, num_classes). float32 matches the
# default dtype of the model's parameters and needs half the memory of
# astype('float') (float64), which matters because every mask is held in RAM.
masks = [one_hot_encode(mask, class_rgb_values).astype(np.float32) for mask in masks]
print('Image shape: ', images[0].shape)
print('Mask shape: ', masks[0].shape)
# Pixels whose colour matches no class become an all-zero vector; report how many.
unmatched_pixels = sum(int((mask.sum(axis=-1) == 0).sum()) for mask in masks)
print('Pixels matching no class colour: ', unmatched_pixels)

分割数据集（前50张为训练数据，其余为测试数据）

In [None]:
# Split in (sorted) file order: the first N_TRAIN samples are used for training,
# all remaining samples are held out for testing.
N_TRAIN = 50
assert len(images) == len(masks), "every image needs exactly one mask"
assert 0 < N_TRAIN < len(images), "need at least one training and one test sample"

train_image = images[:N_TRAIN]
test_image = images[N_TRAIN:]
print(len(train_image), len(test_image))
train_mask = masks[:N_TRAIN]
test_mask = masks[N_TRAIN:]
print(len(train_mask), len(test_mask))

定义dataset

In [None]:
class imageDataset(Dataset):
    """Map-style dataset that pairs each image with its mask.

    Both items go through torchvision's ``ToTensor``, which turns an HWC
    array into a CHW tensor (rescaling uint8 input to [0, 1]).
    """

    def __init__(self, images, masks):
        # Parallel sequences: images[i] belongs with masks[i].
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        to_tensor = transforms.ToTensor()
        pair = (self.images[index], self.masks[index])
        return tuple(to_tensor(np.array(item)) for item in pair)

In [None]:
# Sanity-check one training sample: expect image (3, H, W) and mask
# (num_classes, H, W) after ToTensor moves channels first.
dataset = imageDataset(train_image, train_mask)
first_sample = dataset[0]
image, mask = first_sample
print(*(tensor.shape for tensor in first_sample))

UNET模型
选择原因：cnn二维模型有UNet、ResNet（网络结构过于深，超过1000层）和FCN
UNET和FCN区别：https://blog.csdn.net/weixin_41108334/article/details/87917748

In [None]:
class UNet(nn.Module):
    """U-Net encoder-decoder that produces per-pixel class scores.

    Encoder: four stages of conv_block followed by 2x2 max-pooling
    (64, 128, 256, 512 channels), then a 1024-channel bottleneck conv_block.
    Decoder: four stages, each a stride-2 transposed convolution that doubles
    H and W, concatenated with the encoder output of the same resolution
    (skip connection), followed by a conv_block. A final 3x3 convolution maps
    64 channels to num_classes.

    Input: tensor of shape (N, 3, H, W). H and W must be divisible by 16;
    otherwise max-pooling floors an odd size, the upsampled map no longer
    matches its skip tensor, and torch.cat fails (this notebook resizes
    inputs to 320x320).
    Output: tensor of shape (N, num_classes, H, W) holding raw logits
    (no softmax is applied here).
    """
    
    def __init__(self, num_classes: int):
        """Build all layers.

        Args:
            num_classes: number of output channels, one per segmentation class.
        """
        super(UNet, self).__init__()
        self.num_classes = num_classes
        # Encoder: conv_block keeps H, W; each MaxPool2d halves them.
        self.contracting_11 = self.conv_block(in_channels=3, out_channels=64)
        self.contracting_12 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_21 = self.conv_block(in_channels=64, out_channels=128)
        self.contracting_22 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_31 = self.conv_block(in_channels=128, out_channels=256)
        self.contracting_32 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.contracting_41 = self.conv_block(in_channels=256, out_channels=512)
        self.contracting_42 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck.
        self.middle = self.conv_block(in_channels=512, out_channels=1024)
        # Decoder: ConvTranspose2d(k=3, s=2, p=1, output_padding=1) outputs exactly
        # 2x the input H and W. The following conv_block takes twice the upsampled
        # channel count because the skip tensor is concatenated onto it.
        self.expansive_11 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512)
        self.expansive_21 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_22 = self.conv_block(in_channels=512, out_channels=256)
        self.expansive_31 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_32 = self.conv_block(in_channels=256, out_channels=128)
        self.expansive_41 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.expansive_42 = self.conv_block(in_channels=128, out_channels=64)
        # Output head: 3x3 conv to one logit channel per class (no activation).
        self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, stride=1, padding=1)
        
    def conv_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Return two 3x3 convolutions (stride 1, padding 1, so H and W are
        preserved), each followed by ReLU and then BatchNorm2d.

        NOTE(review): the order here is Conv -> ReLU -> BatchNorm, whereas the
        more common ordering is Conv -> BatchNorm -> ReLU. Changing it would alter
        the model and the Sequential indices in the state_dict (breaking saved
        checkpoints), so it is only flagged here.
        """
        block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                                    nn.ReLU(),
                                    nn.BatchNorm2d(num_features=out_channels),
                                    nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                                    nn.ReLU(),
                                    nn.BatchNorm2d(num_features=out_channels))
        return block
    
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map a batch of shape (N, 3, H, W) to logits of shape (N, num_classes, H, W).

        Shape comments below use N for the batch size and H, W for the input size.
        """
        contracting_11_out = self.contracting_11(X) # [N, 64, H, W]
        contracting_12_out = self.contracting_12(contracting_11_out) # [N, 64, H/2, W/2]
        contracting_21_out = self.contracting_21(contracting_12_out) # [N, 128, H/2, W/2]
        contracting_22_out = self.contracting_22(contracting_21_out) # [N, 128, H/4, W/4]
        contracting_31_out = self.contracting_31(contracting_22_out) # [N, 256, H/4, W/4]
        contracting_32_out = self.contracting_32(contracting_31_out) # [N, 256, H/8, W/8]
        contracting_41_out = self.contracting_41(contracting_32_out) # [N, 512, H/8, W/8]
        contracting_42_out = self.contracting_42(contracting_41_out) # [N, 512, H/16, W/16]
        middle_out = self.middle(contracting_42_out) # [N, 1024, H/16, W/16]
        expansive_11_out = self.expansive_11(middle_out) # [N, 512, H/8, W/8]
        expansive_12_out = self.expansive_12(torch.cat((expansive_11_out, contracting_41_out), dim=1)) # [N, 1024, H/8, W/8] -> [N, 512, H/8, W/8]
        expansive_21_out = self.expansive_21(expansive_12_out) # [N, 256, H/4, W/4]
        expansive_22_out = self.expansive_22(torch.cat((expansive_21_out, contracting_31_out), dim=1)) # [N, 512, H/4, W/4] -> [N, 256, H/4, W/4]
        expansive_31_out = self.expansive_31(expansive_22_out) # [N, 128, H/2, W/2]
        expansive_32_out = self.expansive_32(torch.cat((expansive_31_out, contracting_21_out), dim=1)) # [N, 256, H/2, W/2] -> [N, 128, H/2, W/2]
        expansive_41_out = self.expansive_41(expansive_32_out) # [N, 64, H, W]
        expansive_42_out = self.expansive_42(torch.cat((expansive_41_out, contracting_11_out), dim=1)) # [N, 128, H, W] -> [N, 64, H, W]
        output_out = self.output(expansive_42_out) # [N, num_classes, H, W] (raw logits)
        return output_out


模型训练

In [None]:
batch_size = 16
epochs = 10
lr = 0.01
# Number of output channels = number of labelled classes. Derived from the
# class table so it always equals the depth of the one-hot masks, instead of
# a hard-coded count that silently breaks the loss if the table changes.
num_classes = len(class_rgb_values)

In [None]:
dataset = imageDataset(train_image, train_mask)
# shuffle=True reshuffles every epoch so mini-batches differ between epochs;
# without it the optimiser sees the identical batch sequence each epoch.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
model = UNet(num_classes)  # one output logit channel per class

In [None]:
# Per-pixel cross-entropy between the logits (N, C, H, W) and the float
# one-hot targets of the same shape produced by imageDataset.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
step_losses = []   # loss of every optimisation step
epoch_losses = []  # mean step loss of each epoch
for epoch in tqdm(range(epochs)):
    epoch_loss = 0
    for X, Y in tqdm(data_loader, total=len(data_loader), leave=False):
        # One optimisation step: clear gradients, forward, backprop, update.
        optimizer.zero_grad()
        Y_pred = model(X)
        loss = criterion(Y_pred, Y)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        step_losses.append(batch_loss)
        epoch_loss += batch_loss
    epoch_losses.append(epoch_loss / len(data_loader))

In [None]:
# Training curves: raw per-step loss (left) and per-epoch mean loss (right).
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].plot(step_losses)
axes[0].set(title="Training loss per step", xlabel="step", ylabel="loss")
axes[1].plot(epoch_losses)
axes[1].set(title="Mean training loss per epoch", xlabel="epoch", ylabel="loss")
fig.tight_layout()

In [None]:
# Persist only the learned weights; rebuild a UNet and load them to restore it.
model_name = "U-Net.pth"
torch.save(obj=model.state_dict(), f=model_name)

检查模型

In [None]:
# Load from the same relative path the save cell wrote to (resolved against the
# working directory, /kaggle/working on Kaggle) instead of a separately
# hard-coded absolute path that can drift from it.
model_path = os.path.abspath(model_name)
model_ = UNet(num_classes=num_classes)
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only kernel.
model_.load_state_dict(torch.load(model_path, map_location="cpu"))

In [None]:
# Evaluation data: every sample held out from training.
# NOTE(review): this rebinds `dataset` and `data_loader`; re-running the
# training cell after this one would train on the test split.
test_batch_size = 8
dataset = imageDataset(test_image, test_mask)
data_loader = DataLoader(dataset, test_batch_size)

In [None]:
X, Y = next(iter(data_loader))
# eval(): BatchNorm layers must use the running statistics learned during
# training instead of this batch's statistics (and must not update them).
model_.eval()
# Inference only: skip building the autograd graph.
with torch.no_grad():
    Y_pred = model_(X)
print(Y_pred.shape)
# Predicted class per pixel = index of the largest logit along the class axis.
Y_pred = torch.argmax(Y_pred, dim=1)
print(Y_pred.shape)

In [None]:
# Show at most test_batch_size samples; the fetched batch may be smaller.
n_show = min(test_batch_size, len(X))
# squeeze=False keeps `axes` 2-D so axes[i, j] also works when n_show == 1.
fig, axes = plt.subplots(n_show, 3, figsize=(3*5, n_show*5), squeeze=False)

for i in range(n_show):

    landscape = X[i].permute(1, 2, 0).cpu().detach().numpy()
    # Convert to NumPy before calling the NumPy helper rather than relying on
    # np.argmax accepting a torch tensor.
    label_class = reverse_one_hot(Y[i].permute(1, 2, 0).cpu().detach().numpy())
    label_class_predicted = Y_pred[i].cpu().detach().numpy()

    axes[i, 0].imshow(landscape)
    axes[i, 0].set_title("Landscape")
    # Fixed vmin/vmax so the same class gets the same colour in both panels,
    # regardless of which classes happen to appear in each image.
    axes[i, 1].imshow(label_class, vmin=0, vmax=num_classes - 1)
    axes[i, 1].set_title("Label Class")
    axes[i, 2].imshow(label_class_predicted, vmin=0, vmax=num_classes - 1)
    axes[i, 2].set_title("Label Class - Predicted")

fig.tight_layout()

计算 mIoU（平均交并比）
https://zhuanlan.zhihu.com/p/406706860