In [None]:
# Water-Net: flood extraction network combining Transformer and CNN from SAR
# Author: Teng Zhao, Xiaoping Du and Xiangtao Fan.
# Version: 1.0
# Date:27/11/2022
# Import required libraries
import os
import imageio as io
import cv2
from skimage import img_as_ubyte
import tensorflow as tf
import numpy as np
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Input, Conv2D, MaxPooling2D, \
    UpSampling2D, Concatenate, Dense, multiply, Permute, Add, Lambda,BatchNormalization, Activation, Dropout, Conv2DTranspose
from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow import keras
from datetime import datetime
# Hyperparameter settings:
dropout_rate    = 0.1   # Dropout rate passed to the Swin Transformer blocks
w = 0.7  # Loss weighting coefficient: weight of the BCE term in dice_p_bce
EPOCHS =100  # Number of training epochs passed to model.fit

In [None]:
# Define the building-block modules

def cbam_block(cbam_feature, ratio=8):
    """Apply CBAM attention: channel attention, then spatial attention.

    Args:
        cbam_feature: Feature tensor to refine.
        ratio: Reduction ratio forwarded to ``channel_attention``.

    Returns:
        The feature tensor after both attention stages.
    """
    channel_refined = channel_attention(cbam_feature, ratio)
    return spatial_attention(channel_refined)


def channel_attention(input_feature, ratio=2):
    """Re-weight the channels of ``input_feature`` with a sigmoid gate.

    Globally average-pooled and max-pooled descriptors are passed through
    the same two Dense layers, summed, squashed with a sigmoid and
    multiplied back onto the input.

    Args:
        input_feature: Feature tensor; the channel axis is chosen from
            ``K.image_data_format()``.
        ratio: The hidden Dense layer has ``channel // ratio`` units.

    Returns:
        ``input_feature`` multiplied by the channel gate.
    """
    channels_first = K.image_data_format() == "channels_first"
    channel = input_feature.shape[1 if channels_first else -1]
    hidden_units = channel // ratio

    # Both pooled descriptors go through the SAME pair of Dense layers.
    squeeze = Dense(hidden_units,
                    activation='relu',
                    kernel_initializer='he_normal',
                    use_bias=True,
                    bias_initializer='zeros')
    excite = Dense(channel,
                   kernel_initializer='he_normal',
                   use_bias=True,
                   bias_initializer='zeros')

    def shared_mlp(pooling_layer):
        # Pool -> (1, 1, C) -> squeeze -> excite, checking shapes on the way.
        descriptor = pooling_layer(input_feature)
        descriptor = Reshape((1, 1, channel))(descriptor)
        assert descriptor.shape[1:] == (1, 1, channel)
        descriptor = squeeze(descriptor)
        assert descriptor.shape[1:] == (1, 1, hidden_units)
        descriptor = excite(descriptor)
        assert descriptor.shape[1:] == (1, 1, channel)
        return descriptor

    avg_branch = shared_mlp(GlobalAveragePooling2D())
    max_branch = shared_mlp(GlobalMaxPooling2D())

    gate = Activation('sigmoid')(Add()([avg_branch, max_branch]))

    if channels_first:
        # Move the channel axis of the (1, 1, C) gate back to the front.
        gate = Permute((3, 1, 2))(gate)

    return multiply([input_feature, gate])


def spatial_attention(input_feature):
    """Re-weight the spatial positions of ``input_feature``.

    The channel-wise mean and max maps are concatenated and fed to a 7x7
    single-filter sigmoid convolution; the resulting map is multiplied
    onto the input.

    Args:
        input_feature: Feature tensor; permuted to channels-last internally
            when ``K.image_data_format()`` is ``"channels_first"``.

    Returns:
        ``input_feature`` multiplied by the spatial gate.
    """
    channels_first = K.image_data_format() == "channels_first"

    # Work in channels-last layout so the channel axis is always axis 3.
    feature_last = Permute((2, 3, 1))(input_feature) if channels_first else input_feature

    mean_map = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(feature_last)
    assert mean_map.shape[-1] == 1
    max_map = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(feature_last)
    assert max_map.shape[-1] == 1
    stacked = Concatenate(axis=3)([mean_map, max_map])
    assert stacked.shape[-1] == 2

    gate = Conv2D(filters=1,
                  kernel_size=7,
                  strides=1,
                  padding='same',
                  activation='sigmoid',
                  kernel_initializer='he_normal',
                  use_bias=False)(stacked)
    assert gate.shape[-1] == 1

    if channels_first:
        gate = Permute((3, 1, 2))(gate)

    return multiply([input_feature, gate])



def window_partition(x, window_size):
    """Cut a 4-D feature map into non-overlapping square windows.

    Args:
        x: Tensor whose static shape unpacks as
            ``(batch, height, width, channels)``.
        window_size: Side length of each window; height and width are
            integer-divided by it.

    Returns:
        Tensor of shape ``(-1, window_size, window_size, channels)``.
    """
    _, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    grid = tf.reshape(
        x, shape=(-1, rows, window_size, cols, window_size, channels)
    )
    # Bring the two window-grid axes together, ahead of the in-window axes.
    grid = tf.transpose(grid, (0, 1, 3, 2, 4, 5))
    return tf.reshape(grid, shape=(-1, window_size, window_size, channels))


def window_reverse(windows, window_size, height, width, channels):
    """Reassemble windows from ``window_partition`` into a feature map.

    Args:
        windows: Tensor reshapeable to
            ``(-1, rows, cols, window_size, window_size, channels)``.
        window_size: Side length of each window.
        height: Height of the reassembled map.
        width: Width of the reassembled map.
        channels: Number of channels.

    Returns:
        Tensor of shape ``(-1, height, width, channels)``.
    """
    rows = height // window_size
    cols = width // window_size
    grid = tf.reshape(
        windows,
        shape=(-1, rows, cols, window_size, window_size, channels),
    )
    # Interleave grid axes with in-window axes again.
    grid = tf.transpose(grid, perm=(0, 1, 3, 2, 4, 5))
    return tf.reshape(grid, shape=(-1, height, width, channels))


class DropPath(layers.Layer):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training each sample in the batch is kept with probability
    ``1 - drop_prob`` and rescaled by ``1 / (1 - drop_prob)``; otherwise the
    layer is the identity.

    Args:
        drop_prob: Probability of dropping a sample's path. ``None`` or
            ``0.0`` disables dropping.
    """

    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, inputs, training=None):
        """Return ``inputs`` unchanged, or with random per-sample dropping.

        Args:
            inputs: Tensor whose first axis is the batch axis.
            training: Truthy only in training mode.

        Returns:
            ``inputs`` when not training or when dropping is disabled;
            otherwise a float32 tensor with dropped samples zeroed and kept
            samples rescaled.
        """
        # ``not self.drop_prob`` covers both 0.0 and the default None
        # (the original crashed on ``1 - None``).
        if not self.drop_prob or not training:
            return inputs

        keep_prob = 1.0 - self.drop_prob
        batch_size = tf.shape(inputs)[0]
        # Use the static rank: len() of a symbolic shape tensor is undefined.
        rank = inputs.shape.rank
        path_mask_shape = (batch_size,) + (1,) * (rank - 1)
        # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob and
        # 0 otherwise. This replaces the call on the undefined name
        # ``backend``, which raised NameError.
        path_mask = tf.floor(
            keep_prob + tf.random.uniform(path_mask_shape, dtype=tf.float32)
        )
        outputs = (
            tf.math.divide(tf.cast(inputs, dtype=tf.float32), keep_prob) * path_mask
        )
        return outputs

    def get_config(self):
        """Return the layer config including ``drop_prob``."""
        config = super().get_config()
        config.update(
            {
                "drop_prob": self.drop_prob,
            }
        )
        return config
class WindowAttention(layers.Layer):
    """Multi-head self-attention inside one window, with a learned
    relative position bias.

    Args:
        dim: Channel width; the qkv projection has ``dim * 3`` units and
            the output projection has ``dim`` units.
        window_size: Pair ``(window_height, window_width)``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the qkv projection uses a bias.
        dropout_rate: Rate of the Dropout layer, which is applied both to
            the attention weights and to the projected output.
        return_attention_scores: If true, ``call`` also returns the
            attention weights.
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        dropout_rate=0.0,
        return_attention_scores=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        # Scaling factor applied to the queries: (per-head width) ** -0.5.
        self.scale = (dim // num_heads) ** -0.5
        self.return_attention_scores = return_attention_scores
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

    def build(self, input_shape):
        """Create the bias table and the index used to look it up."""
        # One trainable bias per head for each of the
        # (2*Wh - 1) * (2*Ww - 1) table rows; initialised to zero.
        self.relative_position_bias_table = self.add_weight(
            shape=(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
                self.num_heads,
            ),
            initializer="zeros",
            trainable=True,
            name="relative_position_bias_table",
        )

        self.relative_position_index = self.get_relative_position_index(
            self.window_size[0], self.window_size[1]
        )
        super().build(input_shape)

    def get_relative_position_index(self, window_height, window_width):
        """Return, for every pair of window positions, a row index into
        ``relative_position_bias_table`` (used by ``tf.gather`` in ``call``).
        """
        x_x, y_y = tf.meshgrid(range(window_height), range(window_width))
        coords = tf.stack([y_y, x_x], axis=0)
        coords_flatten = tf.reshape(coords, [2, -1])

        # Pairwise coordinate differences, moved to shape (N, N, 2).
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = tf.transpose(relative_coords, perm=[1, 2, 0])

        # Shift offsets to start at 0, then flatten the 2-D offset into one
        # index: first component scaled by (2*Ww - 1) plus the second.
        x_x = (relative_coords[:, :, 0] + window_height - 1) * (2 * window_width - 1)
        y_y = relative_coords[:, :, 1] + window_width - 1
        relative_coords = tf.stack([x_x, y_y], axis=-1)

        return tf.reduce_sum(relative_coords, axis=-1)

    def call(self, x, mask=None):
        """Apply windowed attention.

        Args:
            x: Rank-3 tensor; its static shape unpacks as
                ``(_, size, channels)``.
            mask: Optional tensor whose first dimension ``nW`` is used to
                regroup the attention logits before the mask is added.

        Returns:
            The projected output of shape ``(-1, size, channels)``, plus the
            attention weights when ``return_attention_scores`` is set.
        """
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        # Split the joint projection; after the transpose the leading axis
        # selects q / k / v.
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
        x_qkv = tf.transpose(x_qkv, perm=(2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = tf.transpose(k, perm=(0, 1, 3, 2))
        attn = q @ k

        # Look up the bias for every position pair and move the head axis
        # to the front so it broadcasts over the leading axis of attn.
        relative_position_bias = tf.gather(
            self.relative_position_bias_table,
            self.relative_position_index,
            axis=0,
        )
        relative_position_bias = tf.transpose(relative_position_bias, [2, 0, 1])
        attn = attn + tf.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.get_shape()[0]
            mask_float = tf.cast(
                tf.expand_dims(tf.expand_dims(mask, axis=1), axis=0), tf.float32
            )
            # Regroup logits as (-1, nW, heads, size, size) so the mask
            # broadcasts over the leading axis and over heads.
            attn = (
                tf.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
                + mask_float
            )
            attn = tf.reshape(attn, shape=(-1, self.num_heads, size, size))
            attn = tf.nn.softmax(attn, axis=-1)
        else:
            attn = tf.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        # Merge heads back into the channel axis.
        x_qkv = tf.transpose(x_qkv, perm=(0, 2, 1, 3))
        x_qkv = tf.reshape(x_qkv, shape=(-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)

        if self.return_attention_scores:
            return x_qkv, attn
        else:
            return x_qkv
    
    
class SwinTransformer(layers.Layer):
    """Swin Transformer block: (shifted-)window attention plus an MLP,
    each wrapped in a pre-norm residual connection.

    Args:
        dim: Channel width of the tokens.
        num_patch: Pair ``(height, width)`` of the token grid.
        num_heads: Number of attention heads.
        window_size: Side length of the attention window. Clamped to
            ``min(num_patch)`` when the grid is smaller than the window.
        shift_size: Cyclic shift applied before windowing; ``0`` disables
            shifting and the attention mask.
        num_mlp: Hidden width of the MLP.
        qkv_bias: Whether the attention qkv projection uses a bias.
        dropout_rate: Rate used for attention dropout, MLP dropout and
            ``DropPath``.
    """

    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super(SwinTransformer, self).__init__(**kwargs)

        self.dim = dim
        self.num_patch = num_patch
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_mlp = num_mlp

        # Clamp BEFORE building the attention layer. Previously the clamp
        # ran afterwards, so WindowAttention kept the unclamped window size
        # and its bias table no longer matched the windows used in call().
        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = (
            DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
        )
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

    def build(self, input_shape):
        """Precompute the attention mask used when windows are shifted."""
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.num_patch
            # Three bands per axis: the untouched part, the part between
            # window_size and shift_size from the end, and the last
            # shift_size rows/columns.
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            # Label each of the 3 x 3 regions with a distinct id.
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h_slice in h_slices:
                for w_slice in w_slices:
                    mask_array[:, h_slice, w_slice, :] = count
                    count += 1
            mask_array = tf.convert_to_tensor(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size * self.window_size]
            )
            # Position pairs with different region ids get -100 added to
            # their logits; pairs within one region get 0.
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)
        super().build(input_shape)

    def call(self, x):
        """Apply the block.

        Args:
            x: Rank-3 tensor whose static shape unpacks as
                ``(_, num_patches, channels)``; it is reshaped to the
                ``num_patch`` grid internally.

        Returns:
            float32 tensor of shape ``(-1, height * width, channels)``.
        """
        _, num_patches_before, channels = x.shape
        height, width = self.num_patch
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            # Cyclic shift so the windows straddle the previous boundaries.
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            # Undo the cyclic shift.
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        # First residual connection (attention branch).
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        # Second residual connection (MLP branch).
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        return x

# Swin Transformer settings for images with a (512, 512) input size: one
# two-block stage (window attention, then shifted-window attention) per
# token-grid size.
def _make_swin_stage(name, dim, num_patch, num_heads, num_mlp):
    """Build a Sequential of two SwinTransformer blocks (shift 0, then 4)."""
    stage = keras.Sequential(name=name)
    for shift in (0, 4):
        stage.add(
            SwinTransformer(
                dim=dim,
                num_patch=num_patch,
                num_heads=num_heads,
                window_size=8,
                shift_size=shift,
                num_mlp=num_mlp,
                qkv_bias=False,
                dropout_rate=dropout_rate,
            )
        )
    return stage


swin_sequences_128 = _make_swin_stage(
    "swin_blocks_128", dim=32, num_patch=(128, 128), num_heads=4, num_mlp=128
)
swin_sequences_64 = _make_swin_stage(
    "swin_blocks_64", dim=64, num_patch=(64, 64), num_heads=8, num_mlp=256
)
swin_sequences_32 = _make_swin_stage(
    "swin_blocks_32", dim=96, num_patch=(32, 32), num_heads=12, num_mlp=384
)
swin_sequences_16 = _make_swin_stage(
    "swin_blocks_16", dim=128, num_patch=(16, 16), num_heads=16, num_mlp=512
)
# Indexed by branch position in MSA_Module_1 (largest grid first).
transformer_block = [
    swin_sequences_128,
    swin_sequences_64,
    swin_sequences_32,
    swin_sequences_16,
]
# Loss functions
def dice_coef(y_true, y_pred, smooth=1, threshold=0.5):
    """Return a Dice-based loss term: ``1 - mean(Dice)``.

    Despite its name, this returns one minus the mean smoothed Dice
    ratio, so lower is better.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor, same shape as ``y_true``.
        smooth: Smoothing constant added to numerator and denominator.
        threshold: Predictions ``>= threshold`` become 1, others 0, before
            the Dice ratio is computed. The default ``0.5`` reproduces the
            previous hard-coded behaviour. Pass ``None`` to skip
            binarisation and use the raw predictions (soft Dice).

    Returns:
        Scalar tensor ``1 - mean((2 * intersection + smooth) /
        (union + smooth))``, with sums taken over the last axis.
    """
    if threshold is not None:
        # NOTE: the binarised prediction is built from constants, so this
        # term passes no gradient back to y_pred while thresholding is on.
        y_pred = tf.where(y_pred >= threshold, 1., 0.)
    intersection = K.sum(y_true * y_pred, axis=-1)
    union = K.sum(y_true, axis=-1) + K.sum(y_pred, axis=-1)
    return 1 - K.mean((2. * intersection + smooth) / (union + smooth))

def dice_p_bce(in_gt, in_pred):
    """Combined loss: ``w * BCE + (1 - w) * dice_coef``.

    ``w`` is the module-level loss weighting coefficient.

    Args:
        in_gt: Ground-truth tensor.
        in_pred: Predicted tensor.

    Returns:
        The weighted sum of binary cross-entropy and the Dice term.
    """
    bce_term = tf.keras.losses.binary_crossentropy(in_gt, in_pred)
    dice_term = dice_coef(in_gt, in_pred)
    return w * bce_term + (1 - w) * dice_term
# Convolution block
def conv_block(input, num_filters):
    """Apply two rounds of 3x3 Conv2D -> BatchNormalization -> Mish.

    Args:
        input: Input feature tensor.
        num_filters: Number of filters of both convolutions.

    Returns:
        The transformed feature tensor.
    """
    x = input
    for _ in range(2):
        x = Conv2D(num_filters, 3, padding="same", kernel_initializer="he_normal")(x)
        x = BatchNormalization()(x)
        x = tfa.activations.mish(x)
    return x

# Decoder block: bilinear upsampling, skip concatenation, one convolution
def decoder_block(input, skip_features, num_filters):
    """Upsample ``input`` by 2, concatenate the skip tensor and convolve.

    Args:
        input: Feature tensor to upsample (bilinear, factor 2).
        skip_features: Skip tensor concatenated with the upsampled input.
        num_filters: Number of filters of the 3x3 convolution.

    Returns:
        The tensor after Conv2D -> BatchNormalization -> Mish.
    """
    upsampled = tf.keras.layers.UpSampling2D(2, interpolation="bilinear")(input)
    merged = Concatenate()([upsampled, skip_features])
    features = Conv2D(num_filters, 3, padding="same", kernel_initializer="he_normal")(merged)
    features = BatchNormalization()(features)
    return tfa.activations.mish(features)
def fusion_module(x, n_feature, stride_size=(1, 1), shortcut=False, **kwargs):
    """Residual block of two 3x3 Conv-BN stages with Mish activations.

    The input ``x`` is added to the second Conv-BN output before the final
    activation. ``shortcut`` and ``**kwargs`` are accepted but not used.

    Args:
        x: Input feature tensor.
        n_feature: Number of filters of both convolutions.
        stride_size: Strides of the first convolution.
        shortcut: Unused.

    Returns:
        The activated sum of the residual branch and ``x``.
    """
    def conv_bn(tensor, strides=None):
        # Bias-free 3x3 convolution followed by batch normalisation.
        extra = {} if strides is None else {"strides": strides}
        tensor = tf.keras.layers.Conv2D(n_feature, 3, padding="same", use_bias=False,
                                        kernel_initializer="he_normal",
                                        bias_initializer="zeros", **extra)(tensor)
        return tf.keras.layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(tensor)

    residual = tfa.activations.mish(conv_bn(x, strides=stride_size))
    residual = conv_bn(residual)
    merged = tf.keras.layers.Add()([residual, x])
    return tfa.activations.mish(merged)
# The two stages of the multi-scale interaction module.
def _cross_scale_exchange(x):
    """Fuse every branch with resampled copies of all other branches.

    For target branch ``index``, each earlier branch is brought down by
    ``index - seq`` strided 3x3 Conv-BN steps (Mish between steps, none
    after the last), each later branch goes through a 1x1 Conv-BN and a
    bilinear upsampling by ``2 ** (seq - index)``, and the branch itself
    is passed through unchanged. The results are summed and activated.

    Args:
        x: A tensor or a list of tensors (one per scale).

    Returns:
        Tuple ``(outs, n_feature)``: the fused tensors and the channel
        count of each input branch.
    """
    if not isinstance(x, list):
        x = [x]
    n_feature = [tf.keras.backend.int_shape(_x)[-1] for _x in x]
    outs = []
    for index, _n_feature in enumerate(n_feature):
        _out = []
        for seq, o in enumerate(x):
            if seq < index:
                steps = index - seq
                for k in range(steps):
                    o = tf.keras.layers.Conv2D(_n_feature, 3, strides=2, padding="same", use_bias=False,
                                               kernel_initializer="he_normal", bias_initializer="zeros")(o)
                    o = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(o)
                    if k != steps - 1:
                        o = tfa.activations.mish(o)
            elif seq > index:
                o = tf.keras.layers.Conv2D(_n_feature, 1, use_bias=False, kernel_initializer="he_normal",
                                           bias_initializer="zeros")(o)
                o = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-5)(o)
                upsample_size = [2 ** (seq - index)] * 2
                o = tf.keras.layers.UpSampling2D(upsample_size, interpolation="bilinear")(o)
            # seq == index: the branch itself is used as-is.
            _out.append(o)
        _out = tf.keras.layers.Add()(_out)
        _out = tfa.activations.mish(_out)
        outs.append(_out)
    return outs, n_feature


def MSA_Module_1(x, n_branch=4, shortcut=False, **kwargs):
    """Cross-scale exchange followed by Swin Transformer blocks.

    After the cross-scale fusion, branch ``i`` is flattened to a token
    sequence, passed through ``transformer_block[i]`` and reshaped back.
    Only ``shape[1]`` is read for the spatial size, so the feature maps
    are treated as square. ``n_branch``, ``shortcut`` and ``**kwargs`` are
    accepted but not used.

    Args:
        x: A tensor or a list of tensors (one per scale).

    Returns:
        List of tensors, one per scale.
    """
    # Interaction across scales
    outs, _ = _cross_scale_exchange(x)
    # Global-dimension modelling
    for i, transform_out in enumerate(outs):
        patch_dim = transform_out.shape[-1]
        patch_num = transform_out.shape[1]
        print(patch_dim,patch_num)
        transform_out = tf.reshape(transform_out, (-1, patch_num * patch_num, patch_dim))
        transform_out = transformer_block[i](transform_out)
        transform_out = tf.reshape(transform_out, (-1, patch_num,patch_num, patch_dim))
        outs[i] = transform_out
    return outs

def MSA_Module_2(x, n_branch=4, shortcut=False, **kwargs):
    """Cross-scale exchange followed by stacked ``fusion_module`` blocks.

    Args:
        x: A tensor or a list of tensors (one per scale).
        n_branch: Number of ``fusion_module`` rounds applied to every scale.
        shortcut: Forwarded to ``fusion_module``.
        **kwargs: Forwarded to ``fusion_module``.

    Returns:
        List of tensors, one per scale.
    """
    outs, n_feature = _cross_scale_exchange(x)
    for _ in range(n_branch):
        for index, _n_feature in enumerate(n_feature):
            outs[index] = fusion_module(outs[index], _n_feature, shortcut=shortcut, **kwargs)
    return outs
def MSAM(x, n_channel=(32, 64, 96, 128), **kwargs):
    """Multi-scale interaction module.

    Each input scale is reduced with a 1x1 convolution, passed through
    ``MSA_Module_1`` and ``MSA_Module_2``, upsampled to the spatial size of
    the first scale, concatenated, projected to 128 channels and refined
    with CBAM.

    Args:
        x: Iterable of feature tensors, one per scale, first one largest.
        n_channel: Output channels of the 1x1 reduction for each scale.
        **kwargs: Forwarded to ``MSA_Module_1`` and ``MSA_Module_2``.

    Returns:
        A single fused feature tensor at the spatial size of the first
        scale.
    """
    x = list(x)
    # 1x1 convolution for channel reduction
    for index, feature in enumerate(x):
        o = tf.keras.layers.Conv2D(n_channel[index], 1, padding="same", use_bias=False,
                                   kernel_initializer="he_normal", bias_initializer="zeros")(feature)
        o = BatchNormalization()(o)
        o = tfa.activations.mish(o)
        x[index] = o
    # The two multi-scale interaction stages
    out = MSA_Module_1(x, n_branch=4, shortcut=False, **kwargs)
    out = MSA_Module_2(out, n_branch=4, shortcut=False, **kwargs)
    # Upsample every scale to the size of the first one
    target_size = tf.keras.backend.int_shape(out[0])[-3:-1]
    for index in range(1, len(out)):
        current_size = tf.keras.backend.int_shape(out[index])[-3:-1]
        upsample_size = tuple(int(t // c) for t, c in zip(target_size, current_size))
        out[index] = tf.keras.layers.UpSampling2D(upsample_size, interpolation="bilinear")(out[index])
    # Fusion happens once, after the loop. The original indentation of the
    # following lines matched no enclosing block and did not parse.
    out = tf.keras.layers.Concatenate(axis=-1)(out)
    out = Conv2D(128, 1, padding="same", kernel_initializer="he_normal")(out)
    out = BatchNormalization()(out)
    out = tfa.activations.mish(out)
    out = cbam_block(out)

    return out
# Building the Water-Net
def Water_Net(input_shape):
    """Build the Water-Net model on a Keras EfficientNetB0 backbone.

    Four backbone activations are fused by ``MSAM``; two decoder blocks
    then merge the result with two shallower skip tensors, and a 1x1
    sigmoid convolution produces the single-channel output.

    Args:
        input_shape: Shape of one input image, without the batch axis.

    Returns:
        A ``Model`` named ``"EfficientNetB4_U-Net"``.
    """
    inputs = Input(shape=input_shape, name='input_image')

    # EfficientNetB0 backbone (the model name string below says B4).
    backbone = tf.keras.applications.EfficientNetB0(input_tensor=inputs, include_top=False)

    def tap(layer_name):
        # Output tensor of a named backbone layer.
        return backbone.get_layer(layer_name).output

    # Encoder. Sizes are those given by the original comments for a
    # 512 x 512 input.
    s1 = conv_block(tap("input_image"), 32)   # 512 x 512
    s2 = tap("block1a_activation")            # 256 x 256
    multi_scale = [
        tap("block2a_activation"),            # 128 x 128
        tap("block3a_activation"),            # 64 x 64
        tap("block4a_activation"),            # 32 x 32
        tap("block7a_activation"),            # bottleneck
    ]

    # Decoder
    fused = MSAM(multi_scale)
    d4 = decoder_block(fused, s2, 64)         # 256 x 256
    d5 = decoder_block(d4, s1, 32)            # 512 x 512

    # Output
    outputs = Conv2D(1, 1, padding="same", activation="sigmoid",name ="outputs")(d5)
    return Model(inputs, outputs, name="EfficientNetB4_U-Net")

# Plot the loss and metric curves
def plot_metrics(history):
    """Plot train/validation curves for loss, auc, precision and recall.

    Args:
        history: Object with an ``epoch`` sequence and a ``history`` dict
            holding each metric and its ``val_``-prefixed counterpart.
    """
    plt.figure()
    for position, metric in enumerate(('loss', 'auc', 'precision', 'recall'), start=1):
        plt.subplot(2, 2, position)
        plt.plot(history.epoch, history.history[metric], color="green", label='Train')
        plt.plot(history.epoch, history.history['val_' + metric],
                 color="red", linestyle="--", label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        # Per-metric y-range: loss keeps its automatic upper bound.
        if metric == 'auc':
            limits = [0.8, 1]
        elif metric == 'loss':
            limits = [0, plt.ylim()[1]]
        else:
            limits = [0, 1]
        plt.ylim(limits)

    plt.legend()
    plt.show()
# IoU metric functions
def get_iou_vector(A, B):
    """Return the intersection-over-union of a label and a prediction.

    Args:
        A: Ground-truth array; non-zero entries count as positive.
        B: Predicted array; entries ``>= 0.5`` count as positive.

    Returns:
        ``numpy.float64`` IoU over all elements. When neither array has
        any positive element the union is empty and ``1.0`` is returned
        (perfect agreement) instead of the NaN that 0/0 produced before.
    """
    B = np.where(B >= 0.5, 1., 0.)
    intersection = np.logical_and(A, B).sum()
    union = np.logical_or(A, B).sum()
    if union == 0:
        # Keep the float64 return type expected by tf.numpy_function.
        return np.float64(1.0)
    return intersection / union
def my_iou_metric(label, pred):
    """Wrap ``get_iou_vector`` as a TensorFlow op returning a float64.

    Args:
        label: Ground-truth tensor.
        pred: Prediction tensor.

    Returns:
        The output of ``tf.numpy_function`` applied to ``get_iou_vector``.
    """
    return tf.numpy_function(func=get_iou_vector, inp=[label, pred], Tout=tf.float64)


In [None]:
# Create the output directory (plain Python replacement for the IPython
# shell command "mkdir -p ./data", which is not valid outside a notebook).
os.makedirs("./data", exist_ok=True)

In [None]:
# Directories of the image / label PNGs for the train, test and validation
# splits. Each label file is looked up under the same filename as its image.
train_image_paths = r"../input/avert-data/process_data/image/"
train_mask_paths = r"../input/avert-data/process_data/label/"
test_image_paths =  r"../input/avert-data/process_data/test_image/"
test_mask_paths = r"../input/avert-data/process_data/test_label/"
val_image_paths =  r"../input/avert-data/process_data/val_image/"
val_mask_paths = r"../input/avert-data/process_data/val_label/"
# File names are taken from the image directories only.
train_filenames = os.listdir(train_image_paths)
test_filenames = os.listdir(test_image_paths)
val_filenames = os.listdir(val_image_paths)
num_image = len(train_filenames)
test_num = len(test_filenames)
print(num_image,test_num)

def load_img_and_mask(image_path,mask_path,filename):
    """Read and decode one image and the mask with the same file name.

    Args:
        image_path: Directory prefix of the image (concatenated with
            ``filename``).
        mask_path: Directory prefix of the mask.
        filename: File name shared by image and mask.

    Returns:
        Tuple ``(image, mask)``: the PNG decoded with 3 channels and the
        PNG decoded with 1 channel.
    """
    image_bytes = tf.io.read_file(image_path + filename)
    mask_bytes = tf.io.read_file(mask_path + filename)
    image = tf.image.decode_png(image_bytes, channels=3)
    mask = tf.image.decode_png(mask_bytes, channels=1)
    return (image, mask)

def scale_values(image, mask, mask_split_threshold=128):
    """Cast the image to float32 and binarise the mask.

    Args:
        image: Image tensor; cast to float32 without rescaling.
        mask: Mask tensor; cast to float32 and thresholded.
        mask_split_threshold: Mask values strictly greater than this
            become 1.0, all others 0.0. Previously this argument was
            ignored and 128 was hard-coded; the default keeps that value.

    Returns:
        Tuple ``(image, mask)`` as float32 tensors.
    """
    image = tf.cast(image, tf.float32)
    mask = tf.cast(mask, tf.float32)
    mask = tf.where(mask > mask_split_threshold, 1., 0.)
    return (image, mask)


print("加载及处理数据中——————————")


def _build_split(filenames, image_dir, mask_dir):
    """Turn file names into an unbatched dataset of (image, mask) pairs."""
    names = tf.data.Dataset.from_tensor_slices(filenames)
    pairs = names.map(lambda x: load_img_and_mask(image_dir,
                                                  mask_dir,
                                                  filename=x))
    return pairs.map(scale_values)


# Unbatched splits: decoded, cast and mask-binarised.
dataset = _build_split(train_filenames, train_image_paths, train_mask_paths)
test_dataset = _build_split(test_filenames, test_image_paths, test_mask_paths)
val_dataset = _build_split(val_filenames, val_image_paths, val_mask_paths)

# Batches of 8 for every split.
train_dataset = dataset.batch(8)
val_dataset = val_dataset.batch(8)
test_dataset = test_dataset.batch(8)

print(f'Original Set: {dataset}')
print(f'Training Set: {train_dataset}')
print(f'Validation Set: {val_dataset}')
print(f'Testing Set: {test_dataset}')
# Metrics reported during training; plot_metrics reads the 'auc',
# 'precision' and 'recall' entries from the training history.
METRICS = [
    my_iou_metric,
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
    tf.keras.metrics.AUC(name='prc', curve='PR'),  # precision-recall curve

]
# Save only the weights of the epoch with the lowest validation loss.
checkpoint = ModelCheckpoint(filepath='./model_best_cbam_tenth_2.h5',
                             monitor='val_loss',
                             verbose=1,
                             save_best_only=True,
                             save_weights_only=True,
                             mode='min')
# Stop when the validation loss has not improved for 10 epochs.
early_stopping = EarlyStopping(monitor='val_loss',mode='min', patience=10)
tqdm_callback = tfa.callbacks.TQDMProgressBar()
callbacks = [checkpoint,
             early_stopping,
            tqdm_callback]
model = Water_Net((512, 512, 3))
# Cosine decay with warm restarts starting from a learning rate of 3e-4.
# NOTE(review): the second positional argument is first_decay_steps, which
# counts optimizer steps, not epochs — confirm that 10 is intended.
lr_decayed_fn = (
  tf.keras.optimizers.schedules.CosineDecayRestarts(
     3e-4,10))
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_decayed_fn)
model.compile(optimizer=optimizer, loss=dice_p_bce, metrics = METRICS)
# NOTE(review): the training dataset is batched but never shuffled.
history = model.fit(train_dataset ,validation_data=val_dataset,epochs=EPOCHS, callbacks=callbacks,verbose = 2)
plot_metrics(history)
# Evaluation stage: reload the best checkpoint, predict the test set, write
# image / label / prediction PNGs to ./data and report IoU, recall and
# precision.
model.load_weights('./model_best_cbam_tenth_2.h5')
meanIoU = tf.keras.metrics.MeanIoU(num_classes=2)  # define meanIoU
meanRecall = tf.keras.metrics.Recall()
meanPrecision = tf.keras.metrics.Precision()
# Clear any previously accumulated state so each metric starts from zero.
meanIoU.reset_states()
meanRecall.reset_states()
meanPrecision.reset_states()
k = 0  # batch counter
intersection = 0
union = 0
q = 0  # running index of the written sample files
now1 = datetime.now()  # start time
for ele in test_dataset:
    image, y_true = ele
    prediction = model.predict(image)  # make model prediction based on image
    prediction = np.where(prediction >= 0.5, 1., 0.)

    # Convert both masks to 0/255 once per batch instead of once per sample,
    # and cast to uint8 so cv2.imwrite receives a supported dtype.
    prediction_1 = np.where(prediction > 0.5, 255, 0).astype(np.uint8)
    y_true_1 = tf.where(y_true >= 0.5, 255, 0).numpy().astype(np.uint8)

    for i in range(prediction.shape[0]):
        sample = image[i]
        first_channel = sample[:, :, 0].numpy()
        # Channel 0 is written twice, followed by channel 1 (as before).
        test = cv2.merge([first_channel, first_channel, sample[:, :, 1].numpy()])
        cv2.imwrite("./data/%d_image.png" % q, test)
        cv2.imwrite("./data/%d_true.png" % q, y_true_1[i])
        cv2.imwrite("./data/%d_prediction.png" % q, prediction_1[i])
        q = q + 1
    intersection += np.logical_and(y_true, prediction).sum()
    union += np.logical_or(y_true, prediction).sum()
    meanIoU.update_state(y_true, prediction)  # update the state of the meanIoU metric
    meanRecall.update_state(y_true, prediction)
    meanPrecision.update_state(y_true, prediction)
    print("第{}个影像结束：".format(k))
    k = k + 1
now2 = datetime.now()  # end time
print(now2-now1)  # total run time
# IoU of the positive class accumulated over the whole test set; NaN when
# no positive pixel occurred in either labels or predictions.
print("meanIOU:", intersection / union if union else float("nan"))
IoU_result = meanIoU.result().numpy()  # mean IoU over the two classes
Recall_result = meanRecall.result().numpy()  # recall
Precision_result = meanPrecision.result().numpy()  # precision
print(f'Mean IoU: {IoU_result}')  # print result
print(f'Mean Recall: {Recall_result}')  # print result
print(f'Mean Precision: {Precision_result}')  # print result