# 课时117 TF2.0-UNet语义分割模型

In [1]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
import glob

## 1. 读取数据集

In [7]:
# Collect the training image paths. glob.glob returns matches in arbitrary,
# filesystem-dependent order, so both lists are sorted: images and labels are
# paired purely by list index further down.
# NOTE(review): assumes every image has a label sharing the same file-name
# stem (standard Cityscapes layout) so that sorted order lines up - confirm.
train_imgs = sorted(glob.glob('../cityscapes/leftImg8bit/train/*/*.png'))
# Collect the training label-id mask paths, sorted the same way.
train_labels = sorted(glob.glob('../cityscapes/gtFine/train/*/*_gtFine_labelIds.png'))

In [9]:
# Shuffle the training set: draw one random permutation and apply it to both
# path lists so that each image stays paired with its label.
index = np.random.permutation(len(train_imgs))
train_imgs, train_labels = [
    np.array(paths)[index] for paths in (train_imgs, train_labels)
]

In [11]:
# Collect the validation image paths. glob.glob returns matches in arbitrary,
# filesystem-dependent order, so both lists are sorted: images and labels are
# paired purely by list index when the dataset is built.
# NOTE(review): assumes every image has a label sharing the same file-name
# stem (standard Cityscapes layout) so that sorted order lines up - confirm.
val_imgs = sorted(glob.glob('../cityscapes/leftImg8bit/val/*/*.png'))
# Collect the validation label-id mask paths, sorted the same way.
val_labels = sorted(glob.glob('../cityscapes/gtFine/val/*/*_gtFine_labelIds.png'))

In [14]:
# Build tf.data datasets whose elements are (image path, label path) pairs.
dataset_train = tf.data.Dataset.from_tensor_slices(
    tensors=(train_imgs, train_labels),
)
dataset_val = tf.data.Dataset.from_tensor_slices(
    tensors=(val_imgs, val_labels),
)

In [3]:
# Image reader
def read_png(path):
    """Read the file at `path` and decode it as a 3-channel PNG tensor."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_png(raw_bytes, channels=3)

# Label reader
def read_png_label(path):
    """Read the file at `path` and decode it as a single-channel PNG tensor."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_png(raw_bytes, channels=1)

In [None]:
# Decode the first training image / label pair for a quick sanity check.
img_1, label_1 = read_png(train_imgs[0]), read_png_label(train_labels[0])

In [None]:
# Inspect the shape of the decoded sample image.
img_1.shape

In [None]:
# Inspect the shape of the decoded sample label.
label_1.shape

## 2. 数据增强
1. 随机翻转：以 50% 的概率对图像和标签同时执行 tf.image.flip_left_right()，保证两者翻转一致。
2. 随机裁切：原始图像和分割标签必须逐像素对应，因此先将两者沿通道方向合并，再对合并后的张量做随机裁切，这样图像和标签得到的是同一块区域。另外，原始图像尺寸较大，直接输入模型可能导致显存不足，所以先将图像 resize 到较小的尺寸（280×280），再在其上随机裁切出 256×256 的区域；与直接在原图上裁切相比，这样得到的图像视野更大。

In [4]:
def crop_img(img, mask):
    """Randomly crop an aligned 256x256 patch from an image / mask pair.

    The image and mask are stacked along the channel axis so that a single
    random crop selects the same region from both. The stack is resized to
    280x280 with nearest-neighbour interpolation before cropping.

    Returns:
        (image_patch, mask_patch): the first three channels and the remaining
        channel of the cropped stack.
    """
    stacked = tf.concat([img, mask], axis=-1)
    stacked = tf.image.resize(
        stacked,
        (280, 280),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    patch = tf.image.random_crop(stacked, [256, 256, 4])
    image_patch = patch[:, :, :3]
    # Slicing with `3:` (instead of indexing with `3`) keeps the trailing
    # channel axis of size 1 on the mask.
    mask_patch = patch[:, :, 3:]
    return image_patch, mask_patch

In [None]:
# Crop the sample pair and show image and label side by side.
img_1, label_1 = crop_img(img_1, label_1)
for position in (1, 2):
    plt.subplot(1, 2, position)
    if position == 1:
        plt.imshow(img_1.numpy())
    else:
        plt.imshow(np.squeeze(label_1.numpy()))

In [5]:
# Input normalisation
def normal(img, mask):
    """Cast the image to float32 scaled as `x / 127.5 - 1`; cast the mask to int32."""
    scaled_img = tf.cast(img, tf.float32) / 127.5 - 1
    int_mask = tf.cast(mask, tf.int32)
    return scaled_img, int_mask

In [6]:
# Training-sample loader
def load_image_train(img_path, mask_path):
    """Load one training pair: decode, random-crop, random-flip, normalise.

    Returns the `(img, mask)` tuple produced by `normal`.
    """
    img, mask = crop_img(read_png(img_path), read_png_label(mask_path))

    # Flip image and mask together when the uniform draw exceeds 0.5.
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        img = tf.image.flip_left_right(img)
        mask = tf.image.flip_left_right(mask)

    return normal(img, mask)

In [7]:
# Validation-sample loader
def load_image_val(img_path, mask_path):
    """Load one validation pair: decode, resize to 256x256, normalise.

    No random augmentation is applied.

    Args:
        img_path: path of the image PNG.
        mask_path: path of the label PNG.

    Returns:
        The `(img, mask)` tuple produced by `normal`.
    """
    img = read_png(img_path)
    mask = read_png_label(mask_path)

    img = tf.image.resize(img, (256, 256))
    # Label values are class ids, not intensities: the default bilinear
    # interpolation would blend neighbouring ids into meaningless values, so
    # the mask must be resized with nearest-neighbour sampling.
    mask = tf.image.resize(mask, (256, 256),
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    img, mask = normal(img, mask)
    return img, mask

In [10]:
# Pipeline constants
BATCH_SIZE = 32
BUFFER_SIZE = 300

# Number of whole batches available per pass over each split.
train_step_per_epoch, val_step_per_epoch = (
    len(train_imgs) // BATCH_SIZE,
    len(val_imgs) // BATCH_SIZE,
)

auto = tf.data.experimental.AUTOTUNE

# Turn the (image path, label path) datasets into (image, mask) datasets.
dataset_train = dataset_train.map(
    load_image_train,
    num_parallel_calls=auto,
)
dataset_val = dataset_val.map(
    load_image_val,
    num_parallel_calls=auto,
)

In [None]:
# Preview one sample produced by the training pipeline. The tensors yielded by
# the dataset (i, m) are the ones plotted; the earlier img_1 / label_1 sample
# did not go through `normal`, so it must not be used here.
for i, m in dataset_train.take(1):
    plt.subplot(1, 2, 1)
    # Undo the x / 127.5 - 1 scaling so imshow receives values in [0, 1].
    plt.imshow((i.numpy() + 1) / 2)
    plt.subplot(1, 2, 2)
    plt.imshow(np.squeeze(m.numpy()))

In [None]:
# Training pipeline. No cache() here: dataset_train already contains the random
# crop/flip map, and caching after it would replay the first epoch's
# augmentations forever. Shuffle comes before repeat() so each pass over the
# data is shuffled on its own instead of mixing samples across epochs.
# The result is an infinite dataset; consumers must bound it (for example with
# take(train_step_per_epoch) or steps_per_epoch).
dataset_train = dataset_train.shuffle(BUFFER_SIZE).repeat().batch(BATCH_SIZE).prefetch(auto)
# Validation preprocessing is deterministic, so caching it is safe.
dataset_val = dataset_val.cache().batch(BATCH_SIZE)

In [None]:
# List the distinct values present in the sample label (to see how many classes appear).
np.unique(label_1.numpy())

## 3. UNet模型搭建(使用自定义层)

In [11]:
# Custom layer implementing the down-sampling (encoder) block
class Downsample(tf.keras.layers.Layer):
    """Optional 2x max-pool followed by two 3x3 "same" convolutions with ReLU."""

    def __init__(self, filters):
        super().__init__()
        conv_kwargs = dict(filters=filters, kernel_size=3, padding="same")
        self.conv_layer_1 = tf.keras.layers.Conv2D(**conv_kwargs)
        self.conv_layer_2 = tf.keras.layers.Conv2D(**conv_kwargs)
        self.pool_layer = tf.keras.layers.MaxPooling2D()

    def call(self, x, is_pool=True):
        """Apply the block; skip the pooling step when `is_pool` is False."""
        features = self.pool_layer(x) if is_pool else x
        features = tf.nn.relu(self.conv_layer_1(features))
        features = tf.nn.relu(self.conv_layer_2(features))
        return features

In [12]:
# Custom layer implementing the up-sampling (decoder) block
class Upsample(tf.keras.layers.Layer):
    """Two 3x3 "same" convolutions, then a stride-2 transposed convolution.

    The transposed convolution outputs `filters // 2` channels. Every
    layer is followed by ReLU.
    """

    def __init__(self, filters):
        super().__init__()
        conv_kwargs = dict(filters=filters, kernel_size=3, padding="same")
        self.conv_layer_1 = tf.keras.layers.Conv2D(**conv_kwargs)
        self.conv_layer_2 = tf.keras.layers.Conv2D(**conv_kwargs)
        self.deconv_layer = tf.keras.layers.Conv2DTranspose(
            filters=filters // 2,
            kernel_size=3,
            strides=2,
            padding="same",
        )

    def call(self, x):
        """Run conv -> ReLU -> conv -> ReLU -> transposed conv -> ReLU."""
        features = tf.nn.relu(self.conv_layer_1(x))
        features = tf.nn.relu(self.conv_layer_2(features))
        features = tf.nn.relu(self.deconv_layer(features))
        return features

In [13]:
class UNet_model(tf.keras.Model):
    """U-Net built from the Downsample / Upsample blocks.

    Encoder: five Downsample blocks (64 -> 1024 filters), the first one
    without pooling. Decoder: a stride-2 transposed convolution followed by
    three Upsample blocks; before each decoder stage the encoder feature map
    from the matching level is concatenated along the channel axis. The head is
    a non-pooling Downsample(64) block and a 1x1 convolution producing 34
    channels of raw scores (no softmax is applied here).
    """

    def __init__(self):
        super(UNet_model, self).__init__()
        self.downsample_1 = Downsample(filters=64)
        self.downsample_2 = Downsample(filters=128)
        self.downsample_3 = Downsample(filters=256)
        self.downsample_4 = Downsample(filters=512)
        self.downsample_5 = Downsample(filters=1024)
        self.upsample_0 = tf.keras.layers.Conv2DTranspose(filters=512,
                                                          kernel_size=2,
                                                          strides=2,
                                                          padding="same")
        self.upsample_1 = Upsample(filters=512)
        self.upsample_2 = Upsample(filters=256)
        self.upsample_3 = Upsample(filters=128)
        self.second_output_layer = Downsample(filters=64)
        self.last_output_layer = tf.keras.layers.Conv2D(filters=34,
                                                        kernel_size=1,
                                                        padding="same")

    def call(self, x):
        """Return 34-channel per-pixel scores for the input batch `x`."""
        # Encoder path; x1..x4 are kept for the skip connections.
        x1 = self.downsample_1(x, is_pool=False)
        x2 = self.downsample_2(x1)
        x3 = self.downsample_3(x2)
        x4 = self.downsample_4(x3)
        x5 = self.downsample_5(x4)
        x5 = self.upsample_0(x5)

        # tf.concat requires an explicit axis; skip connections are joined
        # along the channel (last) axis.
        x5 = tf.concat([x4, x5], axis=-1)
        x5 = self.upsample_1(x5)

        x5 = tf.concat([x3, x5], axis=-1)
        x5 = self.upsample_2(x5)

        x5 = tf.concat([x2, x5], axis=-1)
        x5 = self.upsample_3(x5)

        x5 = tf.concat([x1, x5], axis=-1)
        x5 = self.second_output_layer(x5, is_pool=False)
        x5 = self.last_output_layer(x5)
        return x5

## 4. 定义优化器，损失函数等

In [14]:
# Model, optimiser and loss. The loss takes integer labels and raw logits.
model = UNet_model()
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [17]:
# IoU metric that accepts per-class scores
class MeanIOU(tf.keras.metrics.MeanIoU):
    """MeanIoU variant that accepts per-class scores instead of class ids."""

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Convert scores to class ids, then update and return the metric.

        Args:
            y_true: integer class labels.
            y_pred: per-class scores with the class dimension last.
            sample_weight: optional weights forwarded to the parent metric.
        """
        # Reduce over the class (last) axis. Without an explicit axis,
        # tf.argmax reduces over axis 0, i.e. the batch dimension.
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().__call__(y_true, y_pred, sample_weight=sample_weight)

In [19]:
# Running metrics for the training split.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
train_iou = MeanIOU(num_classes=34, name='train_iou')

# Running metrics for the validation ("test") split.
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')
test_iou = MeanIOU(num_classes=34, name='test_iou')

In [21]:
@tf.function
def train_step(images, labels):
    """Run one optimisation step on a batch and update the training metrics."""
    with tf.GradientTape() as tape:
        logits = model(images)
        batch_loss = loss_fn(labels, logits)
    gradients = tape.gradient(batch_loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))

    # Accumulate loss, accuracy and IoU for the current epoch.
    train_loss(batch_loss)
    train_acc(labels, logits)
    train_iou(labels, logits)

In [None]:
@tf.function
def test_step(images, labels):
    """Evaluate one batch (no weight update) and update the test metrics."""
    logits = model(images)
    batch_loss = loss_fn(labels, logits)

    # Accumulate loss, accuracy and IoU for the current epoch.
    test_loss(batch_loss)
    test_acc(labels, logits)
    test_iou(labels, logits)

In [None]:
EPOCHS = 60

# Keras-style history object: the loss-plotting cell later in this file reads
# history.history['loss'] and history.history['val_loss'], so the custom loop
# records the per-epoch losses in that shape.
history = tf.keras.callbacks.History()
history.history = {'loss': [], 'val_loss': []}

template = ('Epoch: {:d}, Loss {:.3f}, Accuracy {:.3f}, IOU {:.3f}, '
            'Test Loss {:.3f}, Test Accuracy {:.3f}, Test IOU {:.3f}')

for epoch in range(EPOCHS):
    # Metrics accumulate across calls, so clear them at the start of each epoch.
    for metric in (train_loss, train_acc, train_iou,
                   test_loss, test_acc, test_iou):
        metric.reset_states()

    # The training dataset is built with repeat() and is therefore infinite;
    # bound each epoch explicitly, otherwise this loop never ends.
    for images, labels in dataset_train.take(train_step_per_epoch):
        train_step(images, labels)

    for test_images, test_labels in dataset_val:
        test_step(test_images, test_labels)

    # Convert the scalar metric tensors to Python floats once per epoch.
    epoch_loss = float(train_loss.result())
    epoch_val_loss = float(test_loss.result())
    history.history['loss'].append(epoch_loss)
    history.history['val_loss'].append(epoch_val_loss)

    print(template.format(epoch + 1,
                          epoch_loss,
                          float(train_acc.result()) * 100,
                          float(train_iou.result()),
                          epoch_val_loss,
                          float(test_acc.result()) * 100,
                          float(test_iou.result())))

## 5. 查看模型训练效果

In [None]:
# Plot the per-epoch training and validation loss curves.
# NOTE(review): this expects `history` to expose a Keras-style `.history` dict
# with 'loss' and 'val_loss' lists of length EPOCHS. No `model.fit` call is
# visible in this section - confirm the training cell populates `history`.
loss = history.history['loss']
val_loss = history.history['val_loss']

# X axis: one point per epoch (0 .. EPOCHS - 1).
epochs = range(EPOCHS)

plt.figure()
plt.plot(epochs, loss, 'r', label='Training_loss')
plt.plot(epochs, val_loss, 'bo', label='Validation_loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss Value')
plt.legend()
plt.show()

In [None]:
# Inspect the trained model: for the first validation batch, show `num` rows
# of (input image, ground-truth mask, predicted mask).
num = 3
for image, mask in dataset_val.take(1):
    pred_mask = model.predict(image)
    # Collapse the per-class scores to a class id per pixel, then restore a
    # trailing channel axis so array_to_img receives an (H, W, 1) array.
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]

    # Never index past the end of a batch that is smaller than `num`.
    rows = min(num, int(image.shape[0]))

    plt.figure(figsize=(10, 10))
    for i in range(rows):
        panels = (image[i], mask[i], pred_mask[i])
        for col, panel in enumerate(panels, start=1):
            # The grid always has 3 columns, so the flat index is row * 3 + col
            # regardless of how many rows are shown.
            plt.subplot(rows, 3, i * 3 + col)
            plt.imshow(tf.keras.preprocessing.image.array_to_img(panel))