# 1. 加载依赖库

In [1]:
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Reshape, Permute, Dense, Activation, Flatten, Conv2D
from keras.layers import MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D, GlobalMaxPool2D, BatchNormalization
from keras.layers import Convolution2D, UpSampling2D, AtrousConvolution2D, ZeroPadding2D, Lambda, Conv2DTranspose
from keras.layers import multiply, add, concatenate, Concatenate
from keras.layers import LocallyConnected2D
from keras.layers import add
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
from keras.optimizers import SGD
from keras.callbacks import ModelCheckpoint, EarlyStopping
import matplotlib.pyplot as plt
import numpy as np
from keras.utils import np_utils
import cv2
import random
from keras.layers.advanced_activations import PReLU, LeakyReLU
import tifffile
from keras.backend import tf as ktf
from glob import glob
from random import choice

Using TensorFlow backend.


In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# 2. 全局变量

In [3]:
NB_CLASS = 5
IMG_WIDTH = 512
IMG_HEIGHT = 512

# 数据路径
img_type = '.tif'
TRAIN_TOP_PATH = 'E:/Semantic-Segmentation/four_dataset/VAIHINGEN/train_crop_top/'

VAL_TOP_PATH = 'E:/Semantic-Segmentation/four_dataset/VAIHINGEN/test_crop_top/'

# 3. 模型定义

In [4]:
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Conv1D, Conv2D, BatchNormalization, Activation
from keras.layers import UpSampling2D, add, concatenate, Dropout
from keras_superpixel_pooling_new import *
from keras_superpixel_unpooling_new import *


def conv3x3(x, out_filters, strides=(1, 1)):
    x = Conv2D(out_filters, 3, padding='same', strides=strides, use_bias=False, kernel_initializer='he_normal')(x)
    return x


def basic_Block(input, out_filters, strides=(1, 1), with_conv_shortcut=False):
    x = conv3x3(input, out_filters, strides)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)

    x = conv3x3(x, out_filters)
    x = BatchNormalization(axis=3)(x)

    if with_conv_shortcut:
        residual = Conv2D(out_filters, 1, strides=strides, use_bias=False, kernel_initializer='he_normal')(input)
        residual = BatchNormalization(axis=3)(residual)
        x = add([x, residual])
    else:
        x = add([x, input])

    x = Activation('relu')(x)
    return x


def bottleneck_Block(input, out_filters, strides=(1, 1), with_conv_shortcut=False):
    expansion = 4
    de_filters = int(out_filters / expansion)

    x = Conv2D(de_filters, 1, use_bias=False, kernel_initializer='he_normal')(input)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)

    x = Conv2D(de_filters, 3, strides=strides, padding='same', use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)

    x = Conv2D(out_filters, 1, use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization(axis=3)(x)

    if with_conv_shortcut:
        residual = Conv2D(out_filters, 1, strides=strides, use_bias=False, kernel_initializer='he_normal')(input)
        residual = BatchNormalization(axis=3)(residual)
        x = add([x, residual])
    else:
        x = add([x, input])

    x = Activation('relu')(x)
    return x


def stem_net(input):
    x = Conv2D(64, 3, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(input)
    x = BatchNormalization(axis=3)(x)
    x = Activation('relu')(x)

    # x = Conv2D(64, 3, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(x)
    # x = BatchNormalization(axis=3)(x)
    # x = Activation('relu')(x)

    x = bottleneck_Block(x, 256, with_conv_shortcut=True)
    x = bottleneck_Block(x, 256, with_conv_shortcut=False)
    x = bottleneck_Block(x, 256, with_conv_shortcut=False)
    x = bottleneck_Block(x, 256, with_conv_shortcut=False)

    return x


def transition_layer1(x, out_filters_list=[32, 64]):
    x0 = Conv2D(out_filters_list[0], 3, padding='same', use_bias=False, kernel_initializer='he_normal')(x)
    x0 = BatchNormalization(axis=3)(x0)
    x0 = Activation('relu')(x0)

    x1 = Conv2D(out_filters_list[1], 3, strides=(2, 2),
                padding='same', use_bias=False, kernel_initializer='he_normal')(x)
    x1 = BatchNormalization(axis=3)(x1)
    x1 = Activation('relu')(x1)

    return [x0, x1]


def make_branch1_0(x, out_filters=32):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def make_branch1_1(x, out_filters=64):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def fuse_layer1(x):
    x0_0 = x[0]
    x0_1 = Conv2D(32, 1, use_bias=False, kernel_initializer='he_normal')(x[1])
    x0_1 = BatchNormalization(axis=3)(x0_1)
    x0_1 = UpSampling2D(size=(2, 2))(x0_1)
    x0 = add([x0_0, x0_1])

    x1_0 = Conv2D(64, 3, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(x[0])
    x1_0 = BatchNormalization(axis=3)(x1_0)
    x1_1 = x[1]
    x1 = add([x1_0, x1_1])
    return [x0, x1]


def transition_layer2(x, out_filters_list=[32, 64, 128]):
    x0 = Conv2D(out_filters_list[0], 3, padding='same', use_bias=False, kernel_initializer='he_normal')(x[0])
    x0 = BatchNormalization(axis=3)(x0)
    x0 = Activation('relu')(x0)

    x1 = Conv2D(out_filters_list[1], 3, padding='same', use_bias=False, kernel_initializer='he_normal')(x[1])
    x1 = BatchNormalization(axis=3)(x1)
    x1 = Activation('relu')(x1)

    x2 = Conv2D(out_filters_list[2], 3, strides=(2, 2),
                padding='same', use_bias=False, kernel_initializer='he_normal')(x[1])
    x2 = BatchNormalization(axis=3)(x2)
    x2 = Activation('relu')(x2)

    return [x0, x1, x2]


def make_branch2_0(x, out_filters=32):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def make_branch2_1(x, out_filters=64):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def make_branch2_2(x, out_filters=128):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def fuse_layer2(x):
    x0_0 = x[0]
    x0_1 = Conv2D(32, 1, use_bias=False, kernel_initializer='he_normal')(x[1])
    x0_1 = BatchNormalization(axis=3)(x0_1)
    x0_1 = UpSampling2D(size=(2, 2))(x0_1)
    x0_2 = Conv2D(32, 1, use_bias=False, kernel_initializer='he_normal')(x[2])
    x0_2 = BatchNormalization(axis=3)(x0_2)
    x0_2 = UpSampling2D(size=(4, 4))(x0_2)
    x0 = add([x0_0, x0_1, x0_2])

    x1_0 = Conv2D(64, 3, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(x[0])
    x1_0 = BatchNormalization(axis=3)(x1_0)
    x1_1 = x[1]
    x1_2 = Conv2D(64, 1, use_bias=False, kernel_initializer='he_normal')(x[2])
    x1_2 = BatchNormalization(axis=3)(x1_2)
    x1_2 = UpSampling2D(size=(2, 2))(x1_2)
    x1 = add([x1_0, x1_1, x1_2])

    x2_0 = Conv2D(32, 3, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(x[0])
    x2_0 = BatchNormalization(axis=3)(x2_0)
    x2_0 = Activation('relu')(x2_0)
    x2_0 = Conv2D(128, 3, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(x2_0)
    x2_0 = BatchNormalization(axis=3)(x2_0)
    x2_1 = Conv2D(128, 3, strides=(2, 2), padding='same', use_bias=False, kernel_initializer='he_normal')(x[1])
    x2_1 = BatchNormalization(axis=3)(x2_1)
    x2_2 = x[2]
    x2 = add([x2_0, x2_1, x2_2])
    return [x0, x1, x2]


def transition_layer3(x, out_filters_list=[32, 64, 128, 256]):
    x0 = Conv2D(out_filters_list[0], 3, padding='same', use_bias=False, kernel_initializer='he_normal')(x[0])
    x0 = BatchNormalization(axis=3)(x0)
    x0 = Activation('relu')(x0)

    x1 = Conv2D(out_filters_list[1], 3, padding='same', use_bias=False, kernel_initializer='he_normal')(x[1])
    x1 = BatchNormalization(axis=3)(x1)
    x1 = Activation('relu')(x1)

    x2 = Conv2D(out_filters_list[2], 3, padding='same', use_bias=False, kernel_initializer='he_normal')(x[2])
    x2 = BatchNormalization(axis=3)(x2)
    x2 = Activation('relu')(x2)

    x3 = Conv2D(out_filters_list[3], 3, strides=(2, 2),
                padding='same', use_bias=False, kernel_initializer='he_normal')(x[2])
    x3 = BatchNormalization(axis=3)(x3)
    x3 = Activation('relu')(x3)

    return [x0, x1, x2, x3]


def make_branch3_0(x, out_filters=32):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def make_branch3_1(x, out_filters=64):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def make_branch3_2(x, out_filters=128):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def make_branch3_3(x, out_filters=256):
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    x = basic_Block(x, out_filters, with_conv_shortcut=False)
    return x


def fuse_layer3(x):
    x0_0 = x[0]
    x0_1 = Conv2D(32, 1, use_bias=False, kernel_initializer='he_normal')(x[1])
    x0_1 = BatchNormalization(axis=3)(x0_1)
    x0_1 = UpSampling2D(size=(2, 2))(x0_1)
    x0_2 = Conv2D(32, 1, use_bias=False, kernel_initializer='he_normal')(x[2])
    x0_2 = BatchNormalization(axis=3)(x0_2)
    x0_2 = UpSampling2D(size=(4, 4))(x0_2)
    x0_3 = Conv2D(32, 1, use_bias=False, kernel_initializer='he_normal')(x[3])
    x0_3 = BatchNormalization(axis=3)(x0_3)
    x0_3 = UpSampling2D(size=(8, 8))(x0_3)
    x0 = concatenate([x0_0, x0_1, x0_2, x0_3], axis=-1)
    return x0


def final_layer(x, classes=1):
    x = Activation('relu')(x)
    x = Conv2D(classes, 1, use_bias=False, kernel_initializer='he_normal')(x)
    x = BatchNormalization(axis=3)(x)
    x = Activation('softmax', name='Classification')(x)
    return x


def seg_hrnet(batch=1, height=256, width=256, channel=3, classes=6):
    inputs = Input(batch_shape=(batch,) + (height, width, channel))
    slic_feature_map = Input(batch_shape=(batch,) + (height, width))

    x0 = stem_net(inputs)

    x1 = transition_layer1(x0)
    x1_0 = make_branch1_0(x1[0])
    x1_1 = make_branch1_1(x1[1])
    x1 = fuse_layer1([x1_0, x1_1])

    x2 = transition_layer2(x1)
    x2_0 = make_branch2_0(x2[0])
    x2_1 = make_branch2_1(x2[1])
    x2_2 = make_branch2_2(x2[2])
    x2 = fuse_layer2([x2_0, x2_1, x2_2])

    x3 = transition_layer3(x2)
    x3_0 = make_branch3_0(x3[0])
    x3_1 = make_branch3_1(x3[1])
    x3_2 = make_branch3_2(x3[2])
    x3_3 = make_branch3_3(x3[3])
    x3 = fuse_layer3([x3_0, x3_1, x3_2, x3_3])
    x3 = UpSampling2D(size=(2, 2))(x3)

    # ORIGINAL
    x4_1 = Conv2D(32, 1, use_bias=False, kernel_initializer='he_normal')(x3)
    x4_1 = BatchNormalization(axis=-1)(x4_1)

    # SUPERPIXEL
    x4_2 = SuperpixelPooling(num_superpixels=100)([x3, slic_feature_map])
    x4_2 = Conv1D(32, 1, use_bias=False, kernel_initializer='he_normal')(x4_2)
    x4_2 = BatchNormalization(axis=-1)(x4_2)
    x4_2 = SuperpixelUnpooling(num_superpixels=100)([x4_2, slic_feature_map])
    print(x4_2)

    x4 = add([x4_1, x4_2])

    out = final_layer(x4_1, classes=classes)

    model = Model(inputs=[inputs, slic_feature_map], outputs=[out])

    return model

# 4. 数据加载generator

## 4.1 读取一个batch的图片

In [5]:
from skimage.segmentation import slic
# 读取图片函数
def read_img(top_paths):
    top_imgs = []
    label_imgs = []
    slic_segs = []
    for top_path in top_paths:
        label_path = top_path.replace('top', 'label')
        
        top_img = tifffile.imread(top_path)
        label_img = tifffile.imread(label_path)
        slic_seg = slic(top_img, n_segments=100, compactness=30, max_iter=10, convert2lab=False, enforce_connectivity=False)
        
        top_img = top_img / 255.0
        
        label_img = np.expand_dims(label_img, axis=2)
        label_img = np_utils.to_categorical(label_img, num_classes=6)
        label_img = label_img[:, :, 0:5]

        top_imgs.append(top_img)
        label_imgs.append(label_img)
        slic_segs.append(slic_seg)

    return np.array(top_imgs), np.array(label_imgs), np.array(slic_segs)

## 4.2 获取批次函数，其实就是一个generator

In [6]:
def batch_generator(top_path, batch_size):
    while 1:
        for i in range(0, len(top_path), batch_size):
            top, label, slic_seg = read_img(top_path[i:i + batch_size])

            yield ([{'input_1': ]top, 'input_2': slic_seg}, {'Classification': label})

## 4.3 读取数据路径

In [7]:
def get_data_paths(train_crop_top_dir, test_crop_top_dir):
    train_crop_top_paths = glob(os.path.join(train_crop_top_dir, '*.tif'))
    test_crop_top_paths = glob(os.path.join(test_crop_top_dir, '*.tif'))

    # 随机打乱训练数据
    index = [m for m in range(len(train_crop_top_paths))]
    random.shuffle(index)
    train_crop_top_paths = np.array(train_crop_top_paths)[index]

    print(index)
    return train_crop_top_paths, test_crop_top_paths

## 5. 定义评价指标

In [8]:
def f1(y_true, y_pred):
    def recall(y_true, y_pred):
        """Recall metric.

        Only computes a batch-wise average of recall.

        Computes the recall, a metric for multi-label classification of
        how many relevant items are selected.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = true_positives / (possible_positives + K.epsilon())
        return recall

    def precision(y_true, y_pred):
        """Precision metric.

        Only computes a batch-wise average of precision.

        Computes the precision, a metric for multi-label classification of
        how many selected items are relevant.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision
    precision = precision(y_true, y_pred)
    recall = recall(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

## 6. 主函数

In [9]:
if __name__ == '__main__':
    # Create a Keras Model
    model = seg_hrnet(1, IMG_WIDTH, IMG_HEIGHT, 3, NB_CLASS)
    model.summary()

    # 优化器函数
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=[f1, 'acc'])

    # 获取路径列表
    train_crop_top_paths, test_crop_top_paths = get_data_paths(TRAIN_TOP_PATH, VAL_TOP_PATH)

    model_checkpoint = ModelCheckpoint('hrnet_superpixel.hdf5', monitor='val_f1', mode='max', verbose=1, save_best_only=True)
    early_stop = EarlyStopping(monitor='val_f1', mode='max', patience=20)
    check_point_list = [model_checkpoint, early_stop]

    result = model.fit_generator(
        generator=batch_generator(train_crop_top_paths, 1),
        steps_per_epoch=10196,
        epochs=500,
        verbose=1,
        validation_data=batch_generator(test_crop_top_paths, 1),
        validation_steps=398,
        callbacks=check_point_list,
        class_weight=[0.7880542, 0.84201391, 1.05645948, 0.94926901, 18.17903545])

Tensor("superpixel_unpooling_1/Reshape_5:0", shape=(1, 512, 512, 32), dtype=float32)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (1, 512, 512, 3)     0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (1, 256, 256, 64)    1728        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (1, 256, 256, 64)    256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (1, 256, 256, 64)    0           batch_normalization_1[0][0]      
________________________

[6343, 3960, 2540, 9557, 168, 9068, 6525, 9199, 7703, 5285, 8141, 2817, 8712, 7562, 10050, 3429, 2151, 4364, 8229, 6207, 655, 149, 5927, 4398, 5830, 5544, 10193, 3916, 6145, 4964, 9351, 3331, 3422, 9508, 8151, 9477, 1948, 9978, 2274, 3189, 7413, 8356, 1042, 6667, 6832, 8400, 8958, 6130, 6448, 2857, 7993, 2720, 9918, 1774, 260, 8549, 2975, 685, 352, 2266, 2516, 8432, 7701, 7311, 546, 1531, 5938, 3078, 7609, 9323, 5181, 9963, 5894, 5701, 1962, 5764, 4859, 3449, 9246, 396, 2330, 10043, 3538, 4643, 3864, 4020, 7935, 8647, 3461, 3360, 612, 4435, 3762, 3393, 4253, 3448, 6205, 1960, 4824, 5670, 10166, 2362, 4086, 3937, 8989, 8986, 495, 7003, 2106, 8679, 5598, 2287, 1652, 2382, 2402, 2276, 9939, 1465, 1137, 403, 7187, 348, 7425, 1443, 6071, 6506, 446, 9003, 3138, 3798, 2749, 1702, 3571, 2775, 2634, 7040, 1155, 1074, 9093, 1744, 4344, 3745, 8279, 4733, 8193, 2198, 5997, 2750, 1809, 9319, 4065, 2656, 1999, 4869, 6700, 3016, 8462, 7079, 2531, 8457, 6164, 1823, 876, 9046, 9151, 4145, 6822, 6753, 1

Epoch 1/500

Epoch 00001: val_f1 improved from -inf to 0.87195, saving model to hrnet_superpixel.hdf5
Epoch 2/500

Epoch 00002: val_f1 improved from 0.87195 to 0.90019, saving model to hrnet_superpixel.hdf5
Epoch 3/500

KeyboardInterrupt: 

In [None]:
plt.figure()
plt.plot(result.epoch, result.history['f1'], label="acc")
plt.plot(result.epoch, result.history['val_f1'], label="val_acc")
plt.scatter(result.epoch, result.history['f1'], marker='*')
plt.scatter(result.epoch, result.history['val_f1'])
plt.legend(loc='under right')
plt.show()

plt.figure()
plt.plot(result.epoch, result.history['loss'], label="loss")
plt.plot(result.epoch, result.history['val_loss'], label="val_loss")
plt.scatter(result.epoch, result.history['loss'], marker='*')
plt.scatter(result.epoch, result.history['val_loss'], marker='*')
plt.legend(loc='upper right')
plt.show()

with open('unet_resnet_101.txt', 'w') as f:
    f.write(str(result.history))