# TF2.0 UNet语义分割模型

In [18]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
import glob

## 1. 读取数据集

In [3]:
# Collect the training image paths. glob returns matches in arbitrary order,
# so both lists are sorted to keep image i aligned with label i.
# NOTE(review): assumes image and label files share the same name prefix
# (Cityscapes convention), so that sorting yields the same order — confirm.
train_imgs = sorted(glob.glob('../cityscapes/leftImg8bit/train/*/*.png'))
# Collect the training label paths (per-pixel label-ID PNGs).
train_labels = sorted(glob.glob('../cityscapes/gtFine/train/*/*_gtFine_labelIds.png'))

NameError: name 'glob' is not defined

In [2]:
# Shuffle the training set: one shared random permutation is applied to both
# path arrays so every image stays paired with its label.
index = np.random.permutation(len(train_imgs))
train_imgs, train_labels = (
    np.array(paths)[index] for paths in (train_imgs, train_labels)
)

NameError: name 'train_imgs' is not defined

In [11]:
# Collect the validation image paths. glob returns matches in arbitrary order,
# so both lists are sorted to keep image i aligned with label i.
# NOTE(review): assumes image and label files share the same name prefix
# (Cityscapes convention), so that sorting yields the same order — confirm.
val_imgs = sorted(glob.glob('../cityscapes/leftImg8bit/val/*/*.png'))
# Collect the validation label paths (per-pixel label-ID PNGs).
val_labels = sorted(glob.glob('../cityscapes/gtFine/val/*/*_gtFine_labelIds.png'))

In [14]:
# Build the tf.data datasets; each element is an (image_path, label_path) pair.
dataset_train, dataset_val = (
    tf.data.Dataset.from_tensor_slices(path_pair)
    for path_pair in ((train_imgs, train_labels), (val_imgs, val_labels))
)

In [3]:
# Image loader.
def read_png(path):
    """Read the file at `path` and decode it as a 3-channel PNG tensor."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_png(raw_bytes, channels=3)

# Label loader.
def read_png_label(path):
    """Read the file at `path` and decode it as a single-channel PNG tensor."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_png(raw_bytes, channels=1)

In [None]:
# Decode the first training image and its label for a quick sanity check.
img_1 = read_png(train_imgs[0])
label_1 = read_png_label(train_labels[0])

In [None]:
# Display the decoded image's shape (decoded with channels=3).
img_1.shape

In [None]:
# Display the decoded label's shape (decoded with channels=1).
label_1.shape

## 2. 数据增强
1. 随机翻转：以 50% 的概率对图像和标签同时做水平翻转，即 img = tf.image.flip_left_right(img)，mask = tf.image.flip_left_right(mask)
2. 随机裁切：由于原始图像和分割标签图像是匹配的，所以需要将两者按照通道方向进行合并，然后再随机裁切，又由于图像比较大，如果直接塞到模型，可能显存不够用，因此可以将图像先resize到较小的尺寸，再在较小尺寸上进行随机裁切，这样获取到的图像视野能够比较大。

In [4]:
def crop_img(img, mask, resize_size=(280, 280), crop_size=(256, 256)):
    """Resize an image/mask pair together, then take one shared random crop.

    The image and mask are concatenated along the channel axis so that the
    resize and the random crop hit exactly the same pixels in both.

    Args:
        img: 3-channel image tensor of shape (H, W, 3).
        mask: 1-channel label tensor of shape (H, W, 1), same dtype as `img`
            (tf.concat requires matching dtypes).
        resize_size: (height, width) the pair is resized to before cropping.
        crop_size: (height, width) of the random crop; must not exceed
            `resize_size`.

    Returns:
        A tuple (cropped_img, cropped_mask) with shapes
        (crop_h, crop_w, 3) and (crop_h, crop_w, 1).
    """
    # Merge image and label along the channel axis: 3 + 1 = 4 channels.
    stacked = tf.concat([img, mask], axis=-1)
    # Nearest-neighbour resizing copies existing pixel values, so label IDs
    # are never blended into values that are not real classes.
    stacked = tf.image.resize(stacked, resize_size,
                              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    cropped = tf.image.random_crop(stacked, [crop_size[0], crop_size[1], 4])
    # Slicing with `3:` (not `3`) keeps the mask's channel axis of size 1.
    return cropped[:, :, :3], cropped[:, :, 3:]

In [None]:
# Crop the sample pair and show the image (left) next to its label (right).
# Note: this rebinds img_1 / label_1 to the cropped tensors.
img_1, label_1 = crop_img(img_1, label_1)
plt.subplot(1, 2, 1)
plt.imshow(img_1.numpy())
plt.subplot(1, 2, 2)
# Drop the size-1 channel axis so imshow receives a 2-D array.
plt.imshow(np.squeeze(label_1.numpy()))

In [5]:
# Normalise the input image and cast the label to integers.
def normal(img, mask):
    """Return `img` as float32 computed as img / 127.5 - 1, and `mask` as int32.

    For pixel values in 0..255 this maps the image into [-1, 1].
    """
    scaled_img = tf.cast(img, tf.float32) / 127.5 - 1
    return scaled_img, tf.cast(mask, tf.int32)

In [6]:
# Load and augment one training sample.
def load_image_train(img_path, mask_path):
    """Decode an image/label pair, crop, randomly flip, and normalise it.

    Args:
        img_path: path of the image PNG.
        mask_path: path of the label PNG.

    Returns:
        The (image, mask) pair produced by `normal`.
    """
    image, label = crop_img(read_png(img_path), read_png_label(mask_path))

    # A single random draw decides the flip for both tensors, so the image
    # and its label always stay aligned.
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        label = tf.image.flip_left_right(label)

    return normal(image, label)

In [7]:
# Load one validation sample (no augmentation).
def load_image_val(img_path, mask_path):
    """Decode an image/label pair, resize both to 256x256, and normalise.

    Args:
        img_path: path of the image PNG.
        mask_path: path of the label PNG.

    Returns:
        The (image, mask) pair produced by `normal`.
    """
    img = read_png(img_path)
    mask = read_png_label(mask_path)

    img = tf.image.resize(img, (256, 256))
    # Labels are categorical IDs: the default bilinear resize would average
    # neighbouring IDs into values that belong to no (or the wrong) class.
    # Nearest-neighbour only copies existing IDs.
    mask = tf.image.resize(mask, (256, 256),
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    img, mask = normal(img, mask)
    return img, mask

In [8]:
# Training constants.
BATCH_SIZE = 32
BUFFER_SIZE = 300
# Number of full batches per epoch for each split.
train_step_per_epoch, val_step_per_epoch = (
    len(paths) // BATCH_SIZE for paths in (train_imgs, val_imgs)
)

auto = tf.data.experimental.AUTOTUNE

# Turn the path datasets into (image, mask) datasets.
dataset_train = dataset_train.map(map_func=load_image_train, num_parallel_calls=auto)
dataset_val = dataset_val.map(map_func=load_image_val, num_parallel_calls=auto)

NameError: name 'train_imgs' is not defined

In [None]:
# Preview one preprocessed sample taken from the training dataset itself
# (the loop variables, not the earlier img_1 / label_1 tensors).
for i, m in dataset_train.take(1):
    plt.subplot(1, 2, 1)
    # Undo the img / 127.5 - 1 normalisation so imshow gets values in [0, 1].
    plt.imshow((i.numpy() + 1) / 2)
    plt.subplot(1, 2, 2)
    # Drop the size-1 channel axis so imshow receives a 2-D array.
    plt.imshow(np.squeeze(m.numpy()))

In [None]:
# Training pipeline. No cache() here: the mapped function applies a random
# crop and flip, and caching its output would replay the first epoch's
# augmentations forever. shuffle() comes before repeat() so each epoch is
# shuffled on its own instead of mixing samples across epoch boundaries.
dataset_train = dataset_train.shuffle(BUFFER_SIZE).repeat().batch(BATCH_SIZE).prefetch(auto)
# Validation preprocessing is deterministic, so caching it is safe.
dataset_val = dataset_val.cache().batch(BATCH_SIZE)

## 3. UNet模型搭建

In [None]:
# List the distinct label IDs present in the sample label.
# NOTE(review): label_1 is a single sample, so this shows only the classes
# that occur in that one image, not necessarily every class in the dataset.
np.unique(label_1.numpy())

In [21]:
def _double_conv(x, filters):
    """Apply two 3x3 same-padded ReLU convolutions, each followed by batch norm."""
    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters=filters, kernel_size=3,
                                   padding='same', activation='relu')(x)
        x = tf.keras.layers.BatchNormalization()(x)
    return x


def _upsample_and_merge(x, skip, filters):
    """Upsample `x` by 2, concatenate the encoder tensor `skip`, then double-conv."""
    # strides=2 is what doubles the height and width; without it the
    # transposed convolution would keep the spatial size unchanged.
    x = tf.keras.layers.Conv2DTranspose(filters=filters, kernel_size=3, strides=2,
                                        padding='same', activation='relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    # U-Net merges by channel concatenation (unlike the element-wise add used
    # in FCN). The Keras layer is used instead of a raw tf.concat so the
    # operation is a regular layer of the functional model.
    x = tf.keras.layers.concatenate([skip, x], axis=-1)
    return _double_conv(x, filters)


def creat_model(input_shape=(256, 256, 3), num_classes=34):
    """Build a U-Net semantic-segmentation model.

    The encoder has four double-conv stages (64, 128, 256, 512 filters), each
    followed by 2x2 max pooling, then a 1024-filter bottleneck. The decoder
    mirrors it: each stage upsamples by 2, concatenates the encoder output of
    the same resolution, and applies a double conv.

    Args:
        input_shape: (height, width, channels) of the input image. Height and
            width should be divisible by 16 so that the four pooling and four
            upsampling steps yield matching sizes for the skip concatenations.
        num_classes: number of output classes, i.e. channels of the softmax map.

    Returns:
        A tf.keras.Model mapping (batch, H, W, C) images to
        (batch, H, W, num_classes) per-pixel class probabilities.
    """
    encoder_filters = (64, 128, 256, 512)

    inputs = tf.keras.layers.Input(shape=input_shape)

    # Encoder: keep each stage's output before pooling for the skip connections.
    # With the default input the stage outputs are 256, 128, 64 and 32 px wide.
    x = inputs
    skips = []
    for filters in encoder_filters:
        x = _double_conv(x, filters)
        skips.append(x)
        x = tf.keras.layers.MaxPooling2D()(x)

    # Bottleneck: [16, 16, 1024] for the default input.
    x = _double_conv(x, 1024)

    # Decoder: walk back up, pairing each stage with its encoder counterpart.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = _upsample_and_merge(x, skip, filters)

    # 1x1 convolution producing per-pixel class probabilities.
    outputs = tf.keras.layers.Conv2D(filters=num_classes, kernel_size=1,
                                     padding='same', activation='softmax')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [22]:
# Build the U-Net and print its layer-by-layer summary.
model = creat_model()
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 256, 256, 64) 1792        input_6[0][0]                    
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 256, 256, 64) 256         conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 256, 256, 64) 36928       batch_normalization_10[0][0]     
______________________________________________________________________________________________

In [24]:
# Draw the model graph; this requires pydot and graphviz to be installed.
tf.keras.utils.plot_model(model)

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.
