In [None]:
from mydataset import MyDataset
import os
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
from unet import UNet
import time

In [None]:
class MyDataset(Dataset):
    """In-memory semantic-segmentation dataset of (image, class-index mask) pairs.

    Every file in ``image_dir`` and ``mask_dir`` is read, preprocessed and
    cached as a tensor while the dataset is being constructed, so
    ``__getitem__`` is a plain list lookup.

    Images are paired with masks by position after sorting both directory
    listings by file name. This assumes an image and its mask sort to the same
    position in their respective folders (e.g. identical or same-prefix file
    names) -- TODO confirm against the actual dataset layout.

    The module-level helpers ``one_hot_encode`` and ``reverse_one_hot`` must be
    defined before the dataset is instantiated.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the RGB colour-coded masks.
        train_dir: Directory containing ``labels_class_dict.csv`` with the
            columns ``class_names``, ``r``, ``g`` and ``b``.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, image_dir, mask_dir, train_dir):
        # Target size passed to transforms.Resize for every image and mask.
        self.Size = (256, 256)
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # os.listdir gives no ordering guarantee; sorting both listings makes
        # the index-based image/mask pairing deterministic.
        self.image_fns = self._sorted_file_names(image_dir)
        self.mask_fns = self._sorted_file_names(mask_dir)
        if len(self.image_fns) != len(self.mask_fns):
            raise ValueError(
                'image_dir contains {} files but mask_dir contains {}; '
                'images and masks must match one-to-one'.format(
                    len(self.image_fns), len(self.mask_fns)))
        self.class_dict = pd.read_csv(os.path.join(train_dir, 'labels_class_dict.csv'))
        # Class names, in CSV row order.
        self.class_names = self.class_dict['class_names'].tolist()
        # One [r, g, b] triple per class, in the same order as class_names.
        self.class_rgb_values = self.class_dict[['r', 'g', 'b']].values.tolist()
        self.image_preprocessed = []
        self.mask_preprocessed = []

        # Masks hold discrete class ids, so they must be resized with
        # nearest-neighbour sampling: the default bilinear mode would blend
        # neighbouring ids into labels that do not exist in the source mask.
        mask_resize = transforms.Resize(
            size=self.Size,
            interpolation=transforms.InterpolationMode.NEAREST)

        for image_file_name, mask_file_name in zip(self.image_fns, self.mask_fns):
            image_path = os.path.join(self.image_dir, image_file_name)
            mask_path = os.path.join(self.mask_dir, mask_file_name)
            self.image_preprocessed.append(self._load_image(image_path))
            self.mask_preprocessed.append(self._load_mask(mask_path, mask_resize))

    @staticmethod
    def _sorted_file_names(directory):
        """Return the entries of ``directory`` sorted by name."""
        return sorted(os.listdir(directory))

    def _load_image(self, image_path):
        """Read one image as RGB and return its normalised, resized tensor."""
        with Image.open(image_path) as image_file:
            image = np.array(image_file.convert('RGB'))
        return self.transform(image)

    def _load_mask(self, mask_path, mask_resize):
        """Read one RGB mask and return a resized long tensor of class indices.

        Pixels whose colour matches no row of ``class_rgb_values`` end up as
        class 0, because ``argmax`` over an all-False vector returns 0.
        """
        with Image.open(mask_path) as mask_file:
            mask_rgb = np.array(mask_file.convert('RGB'))
        one_hot = one_hot_encode(mask_rgb, self.class_rgb_values)
        class_ids = reverse_one_hot(one_hot)
        mask = torch.from_numpy(class_ids).long()
        # Resize works on [..., H, W]; add a leading channel axis for the
        # call and strip it again afterwards.
        return mask_resize(mask.unsqueeze(0)).squeeze(0)

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_fns)

    def transform(self, image):
        """Convert an image array to a normalised tensor resized to ``self.Size``.

        Applies ``ToTensor``, then ``Normalize`` with the mean/std constants
        below, then ``Resize``.
        """
        transform_ops = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            transforms.Resize(size=self.Size)
        ])
        return transform_ops(image)

    def __getitem__(self, index):
        """Return the cached ``(image, mask)`` pair stored at ``index``."""
        return self.image_preprocessed[index], self.mask_preprocessed[index]


In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder producing one score map per class.

    Four encoder stages (a double 3x3 convolution followed by 2x2 max pooling)
    lead into a 1024-channel bottleneck. Four decoder stages then double the
    resolution with a stride-2 transposed convolution, concatenate the encoder
    feature map of the same resolution along the channel axis, and fuse the
    result with another double convolution. A final 3x3 convolution maps the
    64 decoder channels to ``num_classes`` channels; no activation is applied
    to it.

    The input must have 3 channels. Its height and width should be multiples
    of 16: the encoder halves the resolution four times and the decoder
    doubles it four times, so other sizes make the skip concatenations
    mismatch.

    Args:
        num_classes: Number of channels in the output tensor.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Encoder: channel width doubles at each stage, pooling halves H and W.
        self.contracting_11 = self.conv_block(in_channels=3, out_channels=64)
        self.contracting_12 = self._pool()
        self.contracting_21 = self.conv_block(in_channels=64, out_channels=128)
        self.contracting_22 = self._pool()
        self.contracting_31 = self.conv_block(in_channels=128, out_channels=256)
        self.contracting_32 = self._pool()
        self.contracting_41 = self.conv_block(in_channels=256, out_channels=512)
        self.contracting_42 = self._pool()

        # Bottleneck.
        self.middle = self.conv_block(in_channels=512, out_channels=1024)

        # Decoder: each *_1 layer upsamples and halves the channels; each *_2
        # block takes the upsampled map concatenated with the encoder skip
        # (hence twice the channels on its input) and reduces it again.
        self.expansive_11 = self._up_conv(1024, 512)
        self.expansive_12 = self.conv_block(in_channels=1024, out_channels=512)
        self.expansive_21 = self._up_conv(512, 256)
        self.expansive_22 = self.conv_block(in_channels=512, out_channels=256)
        self.expansive_31 = self._up_conv(256, 128)
        self.expansive_32 = self.conv_block(in_channels=256, out_channels=128)
        self.expansive_41 = self._up_conv(128, 64)
        self.expansive_42 = self.conv_block(in_channels=128, out_channels=64)

        # Per-pixel class scores.
        self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _pool():
        """Build the 2x2, stride-2 max pooling used between encoder stages."""
        return nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _up_conv(in_channels, out_channels):
        """Build the transposed convolution that doubles height and width."""
        return nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def conv_block(self, in_channels, out_channels):
        """Build two consecutive Conv3x3 -> ReLU -> BatchNorm groups.

        The first convolution maps ``in_channels`` to ``out_channels``; the
        second keeps ``out_channels``. Padding 1 preserves height and width.
        """
        layers = []
        current_channels = in_channels
        for _ in range(2):
            layers.extend([
                nn.Conv2d(in_channels=current_channels, out_channels=out_channels,
                          kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(num_features=out_channels),
            ])
            current_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, X):
        """Run the network on ``X`` of shape ``[N, 3, H, W]``.

        Returns a tensor of shape ``[N, num_classes, H, W]``.
        """
        # Encoder path; each skip_* is kept for the matching decoder stage.
        skip_1 = self.contracting_11(X)                             # [N, 64, H, W]
        skip_2 = self.contracting_21(self.contracting_12(skip_1))   # [N, 128, H/2, W/2]
        skip_3 = self.contracting_31(self.contracting_22(skip_2))   # [N, 256, H/4, W/4]
        skip_4 = self.contracting_41(self.contracting_32(skip_3))   # [N, 512, H/8, W/8]
        bottleneck = self.middle(self.contracting_42(skip_4))       # [N, 1024, H/16, W/16]

        # Decoder path: upsample, concatenate with the skip on the channel
        # axis (upsampled map first), then fuse.
        up_4 = self.expansive_11(bottleneck)                        # [N, 512, H/8, W/8]
        dec_4 = self.expansive_12(torch.cat((up_4, skip_4), dim=1))
        up_3 = self.expansive_21(dec_4)                             # [N, 256, H/4, W/4]
        dec_3 = self.expansive_22(torch.cat((up_3, skip_3), dim=1))
        up_2 = self.expansive_31(dec_3)                             # [N, 128, H/2, W/2]
        dec_2 = self.expansive_32(torch.cat((up_2, skip_2), dim=1))
        up_1 = self.expansive_41(dec_2)                             # [N, 64, H, W]
        dec_1 = self.expansive_42(torch.cat((up_1, skip_1), dim=1))

        return self.output(dec_1)                                   # [N, num_classes, H, W]


In [None]:
def one_hot_encode(label, label_values):
    """Turn a colour-coded label image into a stack of per-class boolean masks.

    For each colour in ``label_values`` the label is compared element-wise
    against that colour and reduced over the last axis, giving a mask that is
    True where every channel matches. The masks are stacked along a new last
    axis in the order of ``label_values``, so the depth of the result equals
    the number of classes.

    Args:
        label: Label image whose last axis holds the colour channels.
        label_values: Iterable of per-class colour values.

    Returns:
        Boolean array with one channel per entry of ``label_values``.
    """
    per_class_masks = [
        np.all(np.equal(label, class_colour), axis=-1)
        for class_colour in label_values
    ]
    return np.stack(per_class_masks, axis=-1)

def reverse_one_hot(image):
    """Collapse a one-hot encoded array to class indices.

    Takes the argmax over the last (channel) axis. Where several channels tie,
    including the all-zero case, the lowest index is returned.
    """
    return np.argmax(image, axis=-1)