In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Conv2D, LayerNormalization
from tensorflow.keras.activations import softmax
from tensorflow.keras.layers import Layer


from tensorflow.keras import Model
from tensorflow.keras import Sequential
import tensorflow.keras.layers as nn

from tensorflow import einsum
from einops import rearrange, repeat
from einops.layers.tensorflow import Rearrange

from tensorflow.image import extract_patches
from tensorflow.keras.layers import Conv2D, Layer, Dense, Embedding

import numpy as np
from tensorflow.keras.layers import Dense, Dropout, Conv2D, LayerNormalization
from tensorflow.keras.activations import softmax
from math import ceil
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import glob
import time
import pickle
import cv2
from tensorflow.python.client import device_lib
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer
from tensorflow.keras import Sequential
import tensorflow.keras.layers as nn
from tensorflow import einsum
from einops import rearrange, repeat
from einops.layers.tensorflow import Rearrange
import numpy as np
from tensorflow.keras.layers import Dense, Dropout, Conv2D, LayerNormalization
from tensorflow.keras.activations import softmax
from PIL import Image

In [None]:
def drop_path_(inputs, drop_prob, is_training):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training every sample in the batch is kept with probability
    ``1 - drop_prob`` and kept samples are rescaled by ``1 / (1 - drop_prob)``
    so the expected activation is unchanged. When not training, or when
    ``drop_prob`` is 0 or None, ``inputs`` is returned unchanged.

    Args:
        inputs: tensor whose first axis is the batch axis.
        drop_prob: probability of dropping a sample; ``None`` is treated as 0.
        is_training: Python bool or None (Keras passes None outside training);
            only a truthy value enables dropping.

    Returns:
        Tensor with the same shape and dtype as ``inputs``.
    """
    # Bypass in non-training mode and when there is nothing to drop
    # (`not None` and `not 0.0` are both True).
    if (not is_training) or (not drop_prob):
        return inputs

    keep_prob = 1.0 - drop_prob

    # One uniform draw per sample, broadcast over the remaining axes.
    # The static rank is used because `len()` of a dynamic `tf.shape`
    # tensor raises on symbolic (graph-mode) tensors.
    rank = inputs.shape.rank
    batch_num = tf.shape(inputs)[0]
    mask_shape = [batch_num] + [1] * (rank - 1)
    random_tensor = keep_prob + tf.random.uniform(mask_shape, dtype=inputs.dtype)
    # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0.
    path_mask = tf.floor(random_tensor)
    return tf.math.divide(inputs, keep_prob) * path_mask

class drop_path(Layer):
    """Keras layer wrapper around ``drop_path_`` (stochastic depth).

    Args:
        drop_prob: probability of dropping each sample's branch during
            training; ``None`` (the default) is passed through unchanged.
        **kwargs: forwarded to ``Layer`` (e.g. ``name``) so the layer can be
            rebuilt from ``get_config``.
    """
    def __init__(self, drop_prob=None, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x, training=None):
        # `training` is supplied by Keras; None/False means inference.
        return drop_path_(x, self.drop_prob, training)

    def get_config(self):
        config = super().get_config()
        config['drop_prob'] = self.drop_prob
        return config

In [None]:
class RMSNorm(tf.keras.layers.Layer):
    """Root-mean-square normalisation over the last axis.

    ``y = x / max(sqrt(mean(x ** 2, axis=-1)), eps) * weight`` where
    ``weight`` is a learnable per-feature scale initialised to 1.

    Args:
        eps: lower bound on the RMS denominator (avoids division by zero).
        **kwargs: forwarded to ``Layer``. Needed so that
            ``from_config(get_config())`` accepts the ``name``/``trainable``/
            ``dtype`` keys emitted by ``Layer.get_config``.
    """
    def __init__(self, eps = 1e-5, **kwargs):
        super(RMSNorm, self).__init__(**kwargs)
        self.eps = eps
    def build(self, input_shape):
        # One scale per feature of the last axis.
        self.weight = self.add_weight(shape = (input_shape[-1],), dtype = tf.float32, trainable = True, initializer = tf.keras.initializers.Constant(1.), name = 'weight')
    def compute_output_shape(self, input_shape):
        return input_shape
    def call(self, inputs):
        stddev = tf.math.maximum(tf.math.sqrt(tf.math.reduce_mean(inputs ** 2, axis = -1, keepdims = True)), self.eps)
        results = inputs / stddev
        results = results * self.weight
        return results
    def get_config(self):
        config = super(RMSNorm, self).get_config()
        config['eps'] = self.eps
        return config
    @classmethod
    def from_config(cls, config):
        return cls(**config)
    
def selective_scan(u, delta, A, B, C, D):
    """Parallel (cumsum-based) evaluation of the selective state-space recurrence.

    Computes, for every time step t (up to the 1e-12 stabiliser below),
        h_t = exp(delta_t * A) * h_{t-1} + delta_t * u_t * B_t
        y_t = sum_n h_t[:, n] * C_t[n] + D * u_t
    without a Python loop: the product of decay factors is written as exp of
    a reverse (suffix) cumulative sum, multiplied in, prefix-summed, and then
    divided back out.

    Args (shapes follow the einsum subscripts below):
        u:     (batch, seq_len, d) input sequence.
        delta: (batch, seq_len, d) per-step, per-channel step sizes.
        A:     (d, n) state matrix.
        B:     (batch, seq_len, n) input matrices.
        C:     (batch, seq_len, n) output matrices.
        D:     skip term broadcast against ``u`` (presumably shape (d,), as
               created by SSM - confirm for other callers).

    Returns:
        (batch, seq_len, d) tensor y.

    NOTE(review): the state is recovered as cumsum(...) / exp(suffix_sum).
    When delta * A is strongly negative over a long sequence (SSM passes
    A = -exp(A_log) < 0 and a softplus delta >= 0), exp(...) can underflow
    to 0 and the division by (0 + 1e-12) can blow up; a sequential or
    log-space scan would be more stable.
    """
    dA = tf.einsum('bld,dn->bldn', delta, A) # first step of A_bar = exp(ΔA), i.e., ΔA
    # delta_t * u_t * B_t for every (t, channel d, state n): shape (b, l, d, n)
    dB_u = tf.einsum('bld,bld,bln->bldn', delta, u, B)
    # Shift left by one step along the sequence axis: position t now holds
    # dA[t+1] and the last position holds 0.
    dA_cumsum = tf.pad(
        dA[:, 1:], [[0, 0], [1, 1], [0, 0], [0, 0]])[:, 1:, :, :]
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip along axis 1
    # Cumulative sum along all the input tokens, parallel prefix sum, calculates dA for all the input tokens parallely
    dA_cumsum = tf.math.cumsum(dA_cumsum, axis=1)  
    dA_cumsum = tf.exp(dA_cumsum)  # second step of A_bar = exp(ΔA), i.e., exp(ΔA)
    dA_cumsum = tf.reverse(dA_cumsum, axis=[1])  # Flip back along axis 1
    # Now dA_cumsum[t] = exp(sum_{k=t+1}^{l-1} dA[k]) (suffix sum of decays).
    x = dB_u * dA_cumsum
    # Dividing the prefix sum by the suffix factor at t leaves
    # x[t] = sum_{s<=t} dB_u[s] * exp(sum_{k=s+1}^{t} dA[k]), i.e. the state h_t.
    x = tf.math.cumsum(x, axis=1)/(dA_cumsum + 1e-12) # 1e-12 to avoid division by 0
    # Contract the state axis with C: y_t[d] = sum_n h_t[d, n] * C_t[n]
    y = tf.einsum('bldn,bln->bld', x, C)
    return y + u * D 

class SSM(tf.keras.layers.Layer):
    """Selective state-space layer used inside ``MambaBlock``.

    The input is projected to a low-rank step size ``delta`` plus input/output
    matrices ``B``/``C``. ``delta`` is expanded back to one step size per
    channel (softplus), and everything is passed to ``selective_scan`` together
    with the learned state matrix ``A = -exp(A_log)`` and skip vector ``D``.

    Args:
        d_model: model width; inputs are expected to carry
            ``d_model * expand`` features (the x-projection's input size).
        expand: channel expansion factor of the surrounding block.
        d_state: latent state size per channel.
        bias: add a bias to the x-projection.
        **kwargs: forwarded to ``Layer`` so that ``from_config(get_config())``
            accepts the base keys (``name``, ``trainable``, ``dtype``).
    """
    def __init__(self, d_model, expand = 2, d_state = 16, bias = False, **kwargs):
        super(SSM, self).__init__(**kwargs)
        self.d_model = d_model
        self.expand = expand
        self.d_state = d_state
        self.bias = bias
        # Rank of the low-rank projection that produces delta.
        self.dt_rank = ceil(self.d_model / 16)
    def build(self, input_shape):
        d_inner = self.d_model * self.expand
        # x -> [delta (dt_rank) | B (d_state) | C (d_state)]
        self.x_proj_weight = self.add_weight(shape = (d_inner, self.dt_rank + 2 * self.d_state), dtype = tf.float32, trainable = True, name = 'x_proj_weight')
        if self.bias:
            # Shape passed as a 1-tuple, like every other weight here.
            self.x_proj_bias = self.add_weight(shape = (self.dt_rank + 2 * self.d_state,), dtype = tf.float32, trainable = True, name = 'x_proj_bias')
        # Low-rank delta -> one step size per inner channel.
        self.dt_proj_weight = self.add_weight(shape = (self.dt_rank, d_inner), dtype = tf.float32, trainable = True, name = 'dt_proj_weight')
        self.dt_proj_bias = self.add_weight(shape = (d_inner,), dtype = tf.float32, trainable = True, name = 'dt_proj_bias')
        self.A_log = self.add_weight(shape = (d_inner, self.d_state), dtype = tf.float32, trainable = True, name = 'A_log')
        # A_log[:, n] = log(n + 1), i.e. A = -[1, 2, ..., d_state] for every channel.
        self.A_log.assign(tf.math.log(tf.tile(tf.expand_dims(tf.range(1, self.d_state + 1, dtype = tf.float32), axis = 0), (d_inner, 1))))
        self.D = self.add_weight(shape = (d_inner,), dtype = tf.float32, trainable = True, initializer = tf.keras.initializers.Constant(1.), name = 'D')
    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], self.d_model * self.expand)
    def call(self, x):
        # x.shape = (batch, seq_len, d_model * expand)
        x_dbl = tf.linalg.matmul(x, self.x_proj_weight) # (batch, seq_len, dt_rank + 2 * d_state)
        if self.bias:
            x_dbl = x_dbl + self.x_proj_bias
        delta, B, C = tf.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], axis = -1)
        # delta: (batch, seq_len, dt_rank); B, C: (batch, seq_len, d_state)
        delta = tf.math.softplus(tf.linalg.matmul(delta, self.dt_proj_weight) + self.dt_proj_bias) # (batch, seq_len, expand * d_model)
        # state(t+1) = A state(t) + B x(t)   (B acts as input gate)
        # y(t)       = C state(t) + D x(t)   (C acts as output gate)
        A = -tf.exp(self.A_log) # (expand * d_model, d_state)
        y = selective_scan(x, delta, A, B, C, self.D)
        return y
    def get_config(self):
        config = super(SSM, self).get_config()
        config['d_model'] = self.d_model
        config['expand'] = self.expand
        config['d_state'] = self.d_state
        config['bias'] = self.bias
        return config
    @classmethod
    def from_config(cls, config):
        return cls(**config)

class MambaBlock(tf.keras.layers.Layer):
    """Mamba mixer: in-projection -> Conv1D(SiLU) -> SSM, gated by SiLU(residual) -> out-projection.

    Args:
        d_model: input/output feature size.
        expand: inner width multiplier (inner width = d_model * expand).
        bias: use biases in the in/out projections and the SSM x-projection.
        d_conv: kernel size of the 1-D convolution along the sequence axis.
        conv_bias: use a bias in the convolution.
        d_state: SSM latent state size.
        **kwargs: forwarded to ``Layer``.
    """
    def __init__(self, d_model, expand = 2, bias = False, d_conv = 4, conv_bias = True, d_state = 16, **kwargs):
        super(MambaBlock, self).__init__(**kwargs)
        self.d_model = d_model
        self.expand = expand
        self.d_state = d_state
        self.bias = bias
        self.dt_rank = ceil(self.d_model / 16)
        self.d_conv = d_conv
        self.conv_bias = conv_bias
        # Inner width (attribute name kept as-is for compatibility).
        self.fliter = d_model*expand
        # Sub-layers are created once here so Keras tracks and trains their
        # weights; layers built inside call() would be re-created (freshly
        # initialised and untracked) on every forward pass.
        self.in_proj = tf.keras.layers.Dense(2 * self.expand * self.d_model, use_bias = self.bias)
        self.conv1d = tf.keras.layers.Conv1D(self.fliter, kernel_size = self.d_conv, padding = 'same', use_bias = self.conv_bias, activation = tf.keras.activations.swish)
        self.ssm = SSM(self.d_model, self.expand, self.d_state, self.bias)
        self.out_proj = tf.keras.layers.Dense(self.d_model, use_bias = self.bias)
    def call(self, x):
        x_and_res = self.in_proj(x) # (batch, seq_len, 2 * expand * d_model)
        x, res = tf.split(x_and_res, 2, axis = -1) # each (batch, seq_len, expand * d_model)
        # spatial & channel mixing
        x = self.conv1d(x)
        # selective state space model
        y = self.ssm(x) # (batch, seq_len, d_model * expand)
        # Swish-gated output, borrowing the SwiGLU idea.
        y = y * tf.nn.silu(res)
        return self.out_proj(y)
    def get_config(self):
        config = super(MambaBlock, self).get_config()
        config.update({'d_model': self.d_model, 'expand': self.expand, 'bias': self.bias,
                       'd_conv': self.d_conv, 'conv_bias': self.conv_bias, 'd_state': self.d_state})
        return config

def ResidualBlock(d_model, expand = 2, bias = False, d_conv = 4, conv_bias = True, d_state = 16):
    """Build a Mamba residual block as a standalone Keras model.

    Four branches are summed:
      * the raw input (identity shortcut),
      * SiLU(depthwise Conv1D, kernel 5) of the raw input,
      * SiLU(depthwise Conv1D, kernel 3) of the RMS-normalised input,
      * a MambaBlock applied to the RMS-normalised input.
    The sum is then mixed by a pointwise (kernel 1) Conv1D with ``d_model`` filters.

    Returns:
        tf.keras.Model mapping (batch, seq_len, d_model) -> (batch, seq_len, d_model).
    """
    tokens = tf.keras.Input((None, d_model))

    wide_branch = tf.nn.silu(
        tf.keras.layers.DepthwiseConv1D(kernel_size=5, strides=1, padding="same")(tokens))

    normed = RMSNorm()(tokens)
    narrow_branch = tf.nn.silu(
        tf.keras.layers.DepthwiseConv1D(kernel_size=3, strides=1, padding="same")(normed))

    mamba_branch = MambaBlock(d_model, expand, bias, d_conv, conv_bias, d_state)(normed)

    merged = tf.keras.layers.Add()([mamba_branch, tokens, wide_branch, narrow_branch])
    fused = tf.keras.layers.Conv1D(d_model, 1, 1, padding="same")(merged)
    return tf.keras.Model(inputs=tokens, outputs=fused)

In [None]:
class Mlp(tf.keras.layers.Layer):
    """Feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout.

    Args:
        filter_num: ``(hidden_units, output_units)`` of the two Dense layers.
        drop: dropout rate applied after each Dense layer.
        name: prefix for the Dense sub-layer names.
    """
    def __init__(self, filter_num, drop=0., name=''):
        
        super().__init__()
        
        # MLP layers
        self.fc1 = Dense(filter_num[0], name='{}_mlp_0'.format(name))
        self.fc2 = Dense(filter_num[1], name='{}_mlp_1'.format(name))
        
        # Dropout layer (weight-free, so it is reused at both positions)
        self.drop = Dropout(drop)
        
        # GELU activation
        self.activation = tf.keras.activations.gelu
        
    def call(self, x):
        
        # MLP --> GELU --> Drop --> MLP --> Drop
        x = self.fc1(x)
        # The activation result must be assigned; discarding it made the block linear.
        x = self.activation(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        
        return x

class patch_extract(Layer):
    '''
    Extract non-overlapping patches from a feature map and flatten them
    into a token sequence.
    
    patches = patch_extract(patch_size)(feature_map)
    
    Reference
    ----------
    Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, 
    T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S. and Uszkoreit, J., 2020. 
    An image is worth 16x16 words: Transformers for image recognition at scale. 
    arXiv preprint arXiv:2010.11929.
    
    Input
    ----------
        patch_size: ``(patch_rows, patch_cols)`` of each patch, or a single
                    int for square patches.
        feature_map (call argument): four-dimensional tensor of
                    (num_sample, height, width, channel) with static height/width.
        
    Output
    ----------
        patches: a three-dimensional tensor of
                 (num_sample, num_patch, patch_rows*patch_cols*channel)
                 where `num_patch = (height // patch_rows) * (width // patch_cols)`.
                 Border pixels that do not fill a whole patch are dropped ('VALID').
                 
    For further information see: https://www.tensorflow.org/api_docs/python/tf/image/extract_patches
        
    '''
    
    def __init__(self, patch_size):
        super(patch_extract, self).__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        # extract_patches uses [1, rows, cols, 1], so *_x is the row size.
        self.patch_size_x = patch_size[0]
        self.patch_size_y = patch_size[1]
        
    def call(self, images):
        
        batch_size = tf.shape(images)[0]
        
        patches = extract_patches(images=images,
                                  sizes=(1, self.patch_size_x, self.patch_size_y, 1),
                                  strides=(1, self.patch_size_x, self.patch_size_y, 1),
                                  rates=(1, 1, 1, 1), padding='VALID',)
        # patches.shape = (num_sample, patch_num_rows, patch_num_cols, patch_rows*patch_cols*channel)
        
        patch_dim = patches.shape[-1]
        patch_numx = patches.shape[1]
        patch_numy = patches.shape[2]
        patches = tf.reshape(patches, (batch_size, patch_numx*patch_numy, patch_dim))
        # patches.shape = (num_sample, patch_num_rows*patch_num_cols, patch_rows*patch_cols*channel)
        
        return patches
    
class patch_embedding(Layer):
    '''
    Project patch vectors to ``embed_dim`` and add a learned position embedding.
    
    patches_embed = patch_embedding(num_patch, embed_dim)(patches)
    
    Reference
    ----------
    Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, 
    T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S. and Uszkoreit, J., 2020. 
    An image is worth 16x16 words: Transformers for image recognition at scale. 
    arXiv preprint arXiv:2010.11929.
    
    Input
    ----------
        num_patch: number of patches (tokens); the position table has one row
                   per patch index 0..num_patch-1.
        embed_dim: number of embedded dimensions.
        
    Output
    ----------
        embed: ``Dense(patch)`` plus the (num_patch, embed_dim) position
               vectors, added by broadcasting.
    
    For further information see: https://keras.io/api/layers/core_layers/embedding/
    
    '''
    
    def __init__(self, num_patch, embed_dim):
        super(patch_embedding, self).__init__()
        self.num_patch = num_patch
        self.proj = Dense(embed_dim)
        self.pos_embed = Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        positions = tf.range(self.num_patch)
        projected = self.proj(patch)
        return projected + self.pos_embed(positions)

class patch_merging(tf.keras.layers.Layer):
    '''
    Downsample a token sequence by 2 along each axis of its patch grid.

    Every 2x2 neighbourhood of patches is concatenated along channels
    (C -> 4C) and linearly projected to 2C, so the number of patches is
    quartered and the embedded dimension doubled (c.f. pooling layers).

    Input
    ----------
        num_patch: (H, W) patch grid of the incoming sequence; both must be even.
        embed_dim: number of embedded dimensions of the incoming sequence.
        name: prefix for the projection layer's name.

    Output
    ----------
        x: tensor of shape (batch, (H//2)*(W//2), 2*embed_dim).
    '''
    def __init__(self, num_patch, embed_dim, name=''):
        super().__init__()

        self.num_patch = num_patch
        self.embed_dim = embed_dim

        # 4C -> 2C projection, no bias
        self.linear_trans = Dense(2*embed_dim, use_bias=False, name='{}_linear_trans'.format(name))

    def call(self, x):
        grid_h, grid_w = self.num_patch
        _, seq_len, channels = x.get_shape().as_list()

        assert (seq_len == grid_h * grid_w), 'input feature has wrong size'
        assert (grid_h % 2 == 0 and grid_w % 2 == 0), '{}-by-{} patches received, they are not even.'.format(grid_h, grid_w)

        # Sequence -> (batch, H, W, C) grid
        grid = tf.reshape(x, shape=(-1, grid_h, grid_w, channels))

        # The four members of each 2x2 neighbourhood, as (row, col) offsets
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        quadrants = [grid[:, r::2, c::2, :] for r, c in offsets]
        stacked = tf.concat(quadrants, axis=-1)  # (batch, H/2, W/2, 4C)

        # Grid -> sequence, then project
        flat = tf.reshape(stacked, shape=(-1, (grid_h//2)*(grid_w//2), 4*channels))
        return self.linear_trans(flat)

class patch_expanding(tf.keras.layers.Layer):
    '''
    Upsample a token sequence by ``upsample_rate`` along each axis of its patch grid.

    A 1x1 convolution expands the channels to ``upsample_rate*embed_dim`` and
    ``tf.nn.depth_to_space`` rearranges them into an ``upsample_rate``-times
    larger grid, leaving ``embed_dim // upsample_rate`` channels
    (``embed_dim`` must be divisible by ``upsample_rate``).

    Input
    ----------
        num_patch: (H, W) patch grid of the incoming sequence.
        embed_dim: number of embedded dimensions of the incoming sequence.
        upsample_rate: spatial upsampling factor (block size of depth_to_space).
        return_vector: True -> token sequence
                       (batch, H*W*upsample_rate**2, embed_dim//upsample_rate);
                       False -> feature map
                       (batch, H*upsample_rate, W*upsample_rate, embed_dim//upsample_rate).
        name: prefix for sub-layer and op names.
    '''

    def __init__(self, num_patch, embed_dim, upsample_rate, return_vector=True, name=''):
        super().__init__()
        
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.upsample_rate = upsample_rate
        self.return_vector = return_vector
        
        # 1x1 conv expanding channels to upsample_rate*embed_dim
        self.linear_trans1 = Conv2D(upsample_rate*embed_dim, kernel_size=1, use_bias=False, name='{}_linear_trans1'.format(name))
        # Not used by call(); kept so existing attribute references still work.
        # Given its own name instead of duplicating linear_trans1's.
        self.linear_trans2 = Conv2D(upsample_rate*embed_dim, kernel_size=1, use_bias=False, name='{}_linear_trans2'.format(name))
        self.prefix = name
        
    def call(self, x):
        
        H, W = self.num_patch
        B, L, C = x.get_shape().as_list()
        
        assert (L == H * W), 'input feature has wrong size'

        x = tf.reshape(x, (-1, H, W, C))
        
        x = self.linear_trans1(x)
        
        # rearrange depth to number of patches
        x = tf.nn.depth_to_space(x, self.upsample_rate, data_format='NHWC', name='{}_d_to_space'.format(self.prefix))
        
        if self.return_vector:
            # Convert aligned patches to a patch sequence. The channel count is
            # embed_dim // upsample_rate (only C//2 when the rate is 2), so it
            # is read from the tensor itself.
            x = tf.reshape(x, (-1, L*self.upsample_rate*self.upsample_rate, x.shape[-1]))

        return x

def residual_block(input_x, input_filters, is_downpooling=False, is_uppooling=False):
    """Convolutional block used by the image branch.

    Two (BatchNorm -> ReLU -> 3x3 Conv2D) stages are applied to ``input_x``.
    Their result is concatenated with ``input_x`` along the last axis and fused
    by a third 3x3 Conv2D. Optionally the output is then 2x2 max-pooled
    (stride 2) and/or upsampled with ``UpSampling2D((2, 2))``.

    Args:
        input_x: 4-D feature map (channels-last).
        input_filters: number of filters of every Conv2D in the block.
        is_downpooling: apply the 2x2 max pooling at the end.
        is_uppooling: apply the 2x upsampling at the end.

    Returns:
        Feature map with ``input_filters`` channels.
    """
    features = input_x
    for _ in range(2):
        features = tf.keras.layers.BatchNormalization()(features)
        features = tf.keras.layers.Activation('relu')(features)
        features = tf.keras.layers.Conv2D(filters=input_filters, kernel_size=3, strides=1, padding='same')(features)

    fused = tf.keras.layers.concatenate([features, input_x], axis=-1)
    fused = tf.keras.layers.Conv2D(filters=input_filters, kernel_size=3, strides=1, padding='same')(fused)

    if is_downpooling:
        fused = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid')(fused)
    if is_uppooling:
        fused = tf.keras.layers.UpSampling2D((2, 2))(fused)

    return fused

def Mamba_unet(input_tensor, filter_num_begin, depth, stack_num_down, stack_num_up, 
                      patch_size, num_heads, window_size, num_mlp, n_labels, name='Mamba_unet'):
    '''
    Build the dual-branch Mamba U-Net and return its softmax output tensor.

    Structure:

    1. Token branch: image -> non-overlapping patches -> patch embedding ->
       ResidualBlock (Mamba), then (depth-1) x [patch merging -> ResidualBlock],
       fused at every level with tokens extracted from the image branch.
    2. Image branch: convolutional ``residual_block``s with 2x max pooling
       provide a CNN feature pyramid used as skip connections.
    3. Decoder: patch expanding -> concat with the token skip -> ResidualBlock.
       Each level is also turned back into an image (patch expanding by
       ``patch_size[0]``) and fused with the upsampled image branch.
    4. Head: every decoder level is brought to input resolution with a
       Conv2DTranspose, all levels are concatenated, and a 1x1 softmax conv
       produces the class map.

    Args:
        input_tensor: Keras input of shape (batch, H, W, channels). H and W
            should be divisible by ``patch_size * 2**(depth-1)``.
        filter_num_begin: embedding size of the first level (doubled per level).
        depth: number of encoder levels; must be >= 2.
        patch_size: (rows, cols) of the patches; should be square because
            tokens are mapped back to pixels with ``patch_size[0]``.
        stack_num_down, stack_num_up, num_heads, window_size, num_mlp:
            accepted for interface compatibility; not used by this model.
        n_labels: number of output classes (softmax channels).
        name: prefix for the decoder concat / projection layer names.

    Returns:
        Softmax tensor with ``n_labels`` channels at input resolution.

    Raises:
        ValueError: if ``depth < 2`` (the head needs at least one decoder level).
    '''
    if depth < 2:
        raise ValueError('Mamba_unet needs depth >= 2, got {}.'.format(depth))

    # Number of patches along each spatial axis at the first level
    input_size = input_tensor.shape.as_list()[1:]
    num_patch_x = input_size[0]//patch_size[0]
    num_patch_y = input_size[1]//patch_size[1]
    
    # Number of embedded dimensions
    embed_dim = filter_num_begin
    
    X_skip = []        # token-branch features per encoder level
    X_imageskip = []   # image-branch features per encoder level

    X = input_tensor
    Img_X = X  # image-branch input
    
    # ---- Encoder, level 0 ----
    # Patch extraction and embedding into tokens
    X = patch_extract(patch_size)(X)
    X = patch_embedding(num_patch_x*num_patch_y, embed_dim)(X)
    
    # The first Mamba stack
    X = ResidualBlock(embed_dim, expand = 2, bias = False, d_conv = 4, conv_bias = True, d_state = 16)(X)
    X_skip.append(X)
    
    # Full-resolution image features for the last decoder skip. Img_X itself
    # stays the raw input, so the first pooled block below starts from the image.
    Img_Xd = residual_block(Img_X,32,is_downpooling=False, is_uppooling=False)
    X_imageskip.append(Img_Xd)
    
    # ---- Encoder, levels 1..depth-1 ----
    for i in range(depth-1):
        
        # Patch merging: 2x2 patches -> 1, channels doubled
        X = patch_merging((num_patch_x, num_patch_y), embed_dim=embed_dim, name='down{}'.format(i))(X)
        
        # update token shape info
        embed_dim = embed_dim*2
        num_patch_x = num_patch_x//2
        num_patch_y = num_patch_y//2
        
        # Image branch: conv block + 2x max pooling, tokenised on the same patch grid
        Img_X = residual_block(Img_X,64*(2*i+1), is_downpooling=True, is_uppooling=False)
        X_imageskip.append(Img_X)
        Img_X_patch = patch_extract(patch_size)(Img_X)
        Img_X_patch = patch_embedding(num_patch_x*num_patch_y, embed_dim)(Img_X_patch)

        # Mamba block, then fuse with the image tokens
        X1 = ResidualBlock(embed_dim, expand = 2, bias = False, d_conv = 2, conv_bias = True, d_state = 16)(X)
        X = tf.concat([X1,Img_X_patch],axis=-1)
        X = Dense(embed_dim, use_bias=False)(X)
        
        # Store tensors for concat
        X_skip.append(X)
        
    # Reverse so that index 0 is the deepest level
    X_skip = X_skip[::-1]
    X_imageskip = X_imageskip[::-1]
    
    # Upsampling begins at the deepest tensor; the others are decoder skips
    X = X_skip[0]
    X_decode = X_skip[1:]
    depth_decode = len(X_decode)
    
    X_pyramid = []
    
    # ---- Decoder ----
    for i in range(depth_decode):
        
        # Patch expanding
        X = patch_expanding(num_patch=(num_patch_x, num_patch_y), 
                                               embed_dim=embed_dim, 
                                               upsample_rate=2, 
                                               return_vector=True)(X)
        # update token shape info
        embed_dim = embed_dim//2
        num_patch_x = num_patch_x*2
        num_patch_y = num_patch_y*2
        
        # Image-branch upsampling together with the matching encoder skip
        Img_X = residual_block(tf.concat([Img_X,X_imageskip[i]],axis=-1),embed_dim,is_downpooling=False, is_uppooling=True)
        
        # Concatenation and linear projection
        X = tf.keras.layers.concatenate([X, X_decode[i]], axis=-1, name='{}_concat_{}'.format(name, i))
        X = Dense(embed_dim, use_bias=False, name='{}_concat_linear_proj_{}'.format(name, i))(X)

        X = ResidualBlock(embed_dim, expand = 2, bias = False, d_conv = 2, conv_bias = True, d_state = 16)(X)
        # linear projection
        X = Dense(embed_dim, use_bias=False)(X)
        # Tokens -> image: expand each patch back to patch_size[0] x patch_size[0] pixels
        X_pyramid_image =  patch_expanding(num_patch=(num_patch_x, num_patch_y), 
                                           embed_dim=embed_dim, 
                                           upsample_rate=patch_size[0], 
                                           return_vector=False)(X)
        # Image fusion
        Img_X = residual_block(tf.concat([Img_X,X_pyramid_image],axis=-1),64*(2*(depth_decode-i)),is_downpooling=False, is_uppooling=False)
        X_pyramid.append(Img_X)

    # ---- Head ----
    # Decoder level i is 2**(depth_decode-1-i) times smaller than the input,
    # so upsample each level by that factor (equals 2**(2-i) for depth=4).
    for i in range(len(X_pyramid)):
        X_pyramid[i]  = tf.keras.layers.Conv2D(64, kernel_size=3, strides=1, padding='same',activation="relu")(X_pyramid[i])
        X_pyramid[i]  = tf.keras.layers.Conv2DTranspose(32, kernel_size=1, strides=2**(depth_decode-1-i), padding='same',activation="relu")(X_pyramid[i])

    # Concatenate every pyramid level (a single level needs no concatenation)
    if len(X_pyramid) == 1:
        Final_output = X_pyramid[0]
    else:
        Final_output = tf.keras.layers.concatenate(X_pyramid, axis=-1)
    Final_output = Conv2D(16, kernel_size=2, padding='same',activation='relu')(Final_output)
    Final_output = Conv2D(n_labels, kernel_size=1,padding='same', activation='softmax')(Final_output)
    
    return Final_output

In [None]:
# ---- Mamba_unet hyper-parameters ----
filter_num_begin = 64       # embedding size / channels of the first level; doubled after every patch merging
depth = 4                   # encoder levels: depth-1 down/upsampling stages plus the bottom level
stack_num_down = 2          # kept for interface compatibility; not used by Mamba_unet
stack_num_up = 2            # kept for interface compatibility; not used by Mamba_unet
patch_size = (4, 4)         # 4x4 patches; keep square, tokens are mapped back to pixels with patch_size[0]
num_heads = [4, 8, 8, 8]    # kept for interface compatibility; not used by Mamba_unet
window_size = [4, 2, 2, 2]  # kept for interface compatibility; not used by Mamba_unet
num_mlp = 128               # kept for interface compatibility; not used by Mamba_unet
n_labels = 2                # softmax output channels (masks are one-hot encoded with depth 2)

# ---- Build the model on a 512x512 RGB input ----
IN = tf.keras.layers.Input((512, 512, 3))
OUT = Mamba_unet(IN, filter_num_begin, depth, stack_num_down, stack_num_up,
                 patch_size, num_heads, window_size, num_mlp, n_labels,
                 name='Mamba_unet')
model = Model(inputs=[IN], outputs=OUT)

# Smoke test: one forward pass on random data, then print the summary.
img = tf.random.normal(shape=[1, 512, 512, 3])
preds = model(img)
model.summary()

In [None]:
########### Load data #############
# Dataset root; set the CCSD_DATA_ROOT environment variable instead of
# editing this absolute path. os.path.join keeps the paths portable.
DATA_ROOT = os.environ.get('CCSD_DATA_ROOT', r'D:\AI in NTU\Tunnel segment\concreteCrackSegmentationDataset')
image_dir = os.path.join(DATA_ROOT, 'rgb')
mask_dir = os.path.join(DATA_ROOT, 'BW')
image_test_dir = os.path.join(DATA_ROOT, 'Test_rgb')
mask_test_dir = os.path.join(DATA_ROOT, 'Test_BW')
def get_image_and_mask_paths(image_dir:str,mask_dir:str):
    '''
    Collect the paths of all JPEG images in ``image_dir`` and their masks.

    A mask is expected in ``mask_dir`` under the same file name as its image;
    its existence is not checked here.

    Args:
        image_dir: directory containing the RGB images.
        mask_dir: directory containing the label masks.

    Returns:
        numpy string array of shape (num_images, 2) whose rows are
        ``[image_path, mask_path]``, sorted by file name; shape (0, 2) when no
        ``.jpg`` file (any letter case) is found.
    '''
    # List the directory once; sort for a reproducible order (os.listdir's
    # order is arbitrary). The extension check is case-insensitive.
    file_names = sorted(
        name for name in os.listdir(image_dir)
        if os.path.splitext(name)[1].lower() == '.jpg'
    )
    # os.path.join instead of "\\" so the paths also work outside Windows.
    pairs = [[os.path.join(image_dir, name), os.path.join(mask_dir, name)] for name in file_names]
    return np.array(pairs, dtype=str).reshape(-1, 2)

paths = get_image_and_mask_paths(image_dir,mask_dir)
validpaths = get_image_and_mask_paths(image_test_dir,mask_test_dir)
BATCH_SIZE = 2

# Datasets of (image_path, mask_path) rows.
# The training shuffle buffer spans the whole training set so each epoch is
# a full random permutation. With a buffer of 4, each element is drawn from
# only the next few unread items, so the order stays close to file order.
db_train= tf.data.Dataset.from_tensor_slices(paths)
db_train = db_train.shuffle(buffer_size=max(len(paths), 1),seed=2024)
db_train = db_train.batch(BATCH_SIZE)

db_test= tf.data.Dataset.from_tensor_slices(validpaths)
db_test = db_test.shuffle(buffer_size=4,seed=2024)
db_test = db_test.batch(BATCH_SIZE)

def load_image(path,ran):
    """Read an RGB JPEG, optionally mirror it, resize it and scale it to [-1, 1].

    Args:
        path: image file path (string or scalar string tensor).
        ran: random value; the image is flipped left-right when ``ran > 0.5``.
             Pass the same value to ``load_image_mask`` to keep the pair aligned.

    Returns:
        float32 tensor of shape (512, 512, 3), pixel values mapped from
        [0, 255] to [-1, 1].
    """
    raw = tf.io.read_file(path)
    pixels = tf.cast(tf.image.decode_jpeg(raw, channels=3), dtype=tf.float32)
    # Random horizontal flip for data augmentation.
    pixels = tf.image.flip_left_right(pixels) if ran > 0.5 else pixels
    pixels = tf.image.resize(pixels, [512, 512])
    return pixels / 255.0 * 2 - 1

def load_image_mask(path,ran):
    """Read a grayscale mask, optionally mirror it, binarise it and one-hot encode it.

    Args:
        path: mask file path (string or scalar string tensor).
        ran: the same random value used for the matching image; the mask is
             flipped left-right when ``ran > 0.5``.

    Returns:
        float32 tensor of shape (512, 512, 1, 2): one-hot over {0, 1} of the
        mask scaled to [0, 1], resized, and rounded to the nearest integer.
    """
    raw = tf.io.read_file(path)
    mask = tf.cast(tf.image.decode_jpeg(raw, channels=1), dtype=tf.float32)
    if ran > 0.5:
        # Mirror exactly like the image so the pair stays aligned.
        mask = tf.image.flip_left_right(mask)
    mask = tf.image.resize(mask / 255.0, [512, 512])
    labels = tf.cast(tf.round(mask), dtype=tf.uint8)
    return tf.one_hot(labels, depth=2)

##### Training ##########
def _stack_masks(masks):
    """Stack per-sample one-hot masks into (batch, H, W, num_classes).

    Reshaping (instead of np.squeeze) keeps the batch axis when a batch holds
    a single sample, e.g. the last batch of an odd-sized dataset.
    """
    batch = np.stack([np.asarray(mk) for mk in masks])
    return batch.reshape(batch.shape[:3] + (batch.shape[-1],))

optimizer = tf.keras.optimizers.Adam(1e-4)
History=[]
valid_num=0
for epoch in range(0,100,1):
    count=0
    count_valid=0
    Average_loss=0
    Average_valid_MeanIoU=0
    for path_batch in db_train:
        count+=1
        train_image = []
        mask_image = []
        for i in range(len(path_batch)):
            # One draw per sample so the image and its mask get the same flip.
            ran = tf.random.uniform(())
            train_image.append(load_image(path_batch[i][0],ran))
            mask_image.append(load_image_mask(path_batch[i][1],ran))
        mask_image = _stack_masks(mask_image)
        train_image = np.array(train_image)

        # Compute the loss. training=True makes BatchNormalization use batch
        # statistics and update its moving averages, and enables dropout.
        with tf.GradientTape() as tape:
            predicted_image = model(train_image, training=True)
            loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(mask_image,predicted_image))

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients,model.trainable_variables))
        Average_loss = Average_loss + loss
        
    # max(..., 1) guards against an empty dataset.
    Average_loss = Average_loss/max(count, 1)
    
    for path_batch in db_test:
        count_valid+=1
        valid_image = []
        valid_mask_image = []
        for i in range(len(path_batch)):
            # No augmentation on validation data (ran=0.0 disables the flip).
            valid_image.append(load_image(path_batch[i][0],0.0))
            valid_mask_image.append(load_image_mask(path_batch[i][1],0.0))
        valid_mask_image = _stack_masks(valid_mask_image)
        valid_image = np.array(valid_image)
        predicted_image = model(valid_image, training=False)
        m = tf.keras.metrics.MeanIoU(num_classes=2)
        m.update_state(tf.argmax(valid_mask_image,axis=-1),tf.argmax(predicted_image,axis=-1))
        Average_valid_MeanIoU += float(m.result().numpy())
    Average_valid_MeanIoU = Average_valid_MeanIoU/max(count_valid, 1)
    History.append([epoch, Average_loss,Average_valid_MeanIoU])
    tf.print("=>Epoch%4d  Averageloss:%4.6f Average_valid_MeanIoU:%4.6f" %(epoch, Average_loss,Average_valid_MeanIoU))   
    # Keep the weights of the best validation epoch.
    if Average_valid_MeanIoU>valid_num:
        valid_num = Average_valid_MeanIoU
        model.save_weights("./MambaUnet_CCSD_crop_pyramid(4×4).h5")

    if count_valid:
        # Images are stored in [-1, 1]; map back to [0, 1] for display.
        plt.imshow((valid_image[0] + 1) / 2)
        plt.show()
        plt.imshow(np.argmax(predicted_image,axis=-1)[0])
        plt.show()

In [None]:
########################### Test ###############
########### sample ##################
BATCH_SIZE = 2
########## Load data #############
# Dataset root; set the CCSD_DATA_ROOT environment variable instead of
# editing this absolute path. os.path.join keeps the paths portable.
DATA_ROOT = os.environ.get('CCSD_DATA_ROOT', r'D:\AI in NTU\Tunnel segment\concreteCrackSegmentationDataset')
image_test_dir = os.path.join(DATA_ROOT, 'Test_rgb')
mask_test_dir = os.path.join(DATA_ROOT, 'Test_BW')
def get_image_and_mask_paths(image_dir:str,mask_dir:str):
    '''
    Collect the paths of all JPEG images in ``image_dir`` and their masks.

    A mask is expected in ``mask_dir`` under the same file name as its image;
    its existence is not checked here.

    Args:
        image_dir: directory containing the RGB images.
        mask_dir: directory containing the label masks.

    Returns:
        numpy string array of shape (num_images, 2) whose rows are
        ``[image_path, mask_path]``, sorted by file name; shape (0, 2) when no
        ``.jpg`` file (any letter case) is found.
    '''
    # List the directory once; sort for a reproducible order (os.listdir's
    # order is arbitrary). The extension check is case-insensitive.
    file_names = sorted(
        name for name in os.listdir(image_dir)
        if os.path.splitext(name)[1].lower() == '.jpg'
    )
    # os.path.join instead of "\\" so the paths also work outside Windows.
    pairs = [[os.path.join(image_dir, name), os.path.join(mask_dir, name)] for name in file_names]
    return np.array(pairs, dtype=str).reshape(-1, 2)

paths = get_image_and_mask_paths(image_test_dir, mask_test_dir)

# Batched dataset of (image_path, mask_path) rows from the *test* split.
# It keeps the name `db_train` because the evaluation loop below reads it.
db_train = (
    tf.data.Dataset.from_tensor_slices(paths)
    .shuffle(buffer_size=4, seed=2023)
    .batch(BATCH_SIZE)
)

def load_image(path):
    """Read an RGB JPEG, resize it to 512x512 and map pixels from [0, 255] to [-1, 1].

    No augmentation is applied (evaluation version).
    """
    decoded = tf.image.decode_jpeg(tf.io.read_file(path), channels=3)
    resized = tf.image.resize(tf.cast(decoded, dtype=tf.float32), [512, 512])
    return resized / 255.0 * 2 - 1

def load_image_mask(path):
    """Read a grayscale mask, binarise it and one-hot encode it (no augmentation).

    The mask is scaled to [0, 1], resized to 512x512 and rounded to the
    nearest integer before encoding.

    Returns:
        float32 tensor of shape (512, 512, 1, 2).
    """
    decoded = tf.image.decode_jpeg(tf.io.read_file(path), channels=1)
    scaled = tf.cast(decoded, dtype=tf.float32) / 255.0
    resized = tf.image.resize(scaled, [512, 512])
    binary = tf.cast(tf.round(resized), dtype=tf.uint8)
    return tf.one_hot(binary, depth=2)

# Network: restore the best checkpoint saved during training
model.load_weights("./MambaUnet_CCSD_crop_pyramid(4×4).h5")

count=0
IoU_list=[]
Precision=[]
Recall=[]
Accuracy=[]
MAE=[]
F1=[]
for path_batch in db_train:
    count+=1
    test_image = []
    mask_image = []
    for i in range(len(path_batch)):
        test_image.append(load_image(path_batch[i][0]))
        mask_image.append(load_image_mask(path_batch[i][1]))
    # Stack to (batch, H, W, 2). Reshape instead of np.squeeze so a final
    # batch holding a single sample keeps its batch axis.
    mask_image = np.stack([np.asarray(mk) for mk in mask_image])
    mask_image = mask_image.reshape(mask_image.shape[:3] + (mask_image.shape[-1],))
    test_image = np.array(test_image)
    predicted_image = model(test_image, training=False)
    segement_image = np.argmax(predicted_image,axis=-1)
    mask_image_r = np.argmax(mask_image,axis=-1)
    print(test_image.shape)
    # Use the real batch length: the last batch may be smaller than BATCH_SIZE.
    for i in range(len(test_image)):
        # Images are stored in [-1, 1]; map back to [0, 1] for display.
        plt.imshow((test_image[i] + 1) / 2)
        plt.show()
        plt.imshow(segement_image[i])
        plt.show()
        m = tf.keras.metrics.MeanIoU(num_classes=2)
        m.update_state(mask_image_r[i],segement_image[i])
        a = tf.keras.metrics.Accuracy()
        a.update_state(mask_image_r[i],segement_image[i])
        p = tf.keras.metrics.Precision()
        p.update_state(mask_image_r[i],segement_image[i])
        r = tf.keras.metrics.Recall()
        r.update_state(mask_image_r[i],segement_image[i])
        # Fraction of mislabelled pixels (mean absolute error of 0/1 labels)
        mae = np.mean(np.abs(mask_image_r[i] - segement_image[i]))
        precision = p.result().numpy()
        recall = r.result().numpy()
        IoU_list.append(m.result().numpy())
        Accuracy.append(a.result().numpy())
        Precision.append(precision)
        Recall.append(recall)
        MAE.append(mae)
        # F1 is defined as 0 when precision and recall are both 0 (avoids NaN).
        F1.append(2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0)
        
print(np.average(IoU_list))
print(np.average(Accuracy))
print(np.average(Precision))
print(np.average(Recall))
print(np.average(F1))
print(np.average(MAE))
plt.bar(range(len(IoU_list)), IoU_list)
plt.show()