In [None]:
# Testing the trained AM-SegNet:

# Construct the model structure
# Load the trained model weights
# Load the images for model testing
# Perform semantic segmentation

In [None]:
#import libraries for model development
from keras import models, layers
from keras import backend as K
import numpy as np
import glob
import cv2
import os

In [None]:
# Definition of the lightweight block (lw_conv_block)

def lw_conv_block(inputs, filter_size, filter_num, dropout, batch_norm=True):
    """Lightweight three-branch convolution block used by the AM-SegNet encoder.

    A 1x1 Conv2D + ReLU first projects ``inputs`` to ``filter_num`` channels.
    That projection feeds three parallel branches, whose outputs are
    concatenated along axis 3:

    * squeeze branch   - 1x1 Conv2D with ``filter_num`` filters
    * separable branch - DepthwiseConv2D (depth_multiplier=2) followed by a
      1x1 Conv2D with ``2 * filter_num`` filters
    * expand branch    - ``filter_size`` x ``filter_size`` Conv2D with
      ``filter_num`` filters

    Every branch ends with optional BatchNormalization and a ReLU, so the
    block outputs ``4 * filter_num`` channels.

    Args:
        inputs: input tensor; channels are taken to be on axis 3
            (assumes channels_last - TODO confirm if data_format changes).
        filter_size: spatial kernel size of the depthwise and expand branches.
        filter_num: base number of filters.
        dropout: dropout rate applied after the concatenation, only if > 0.
        batch_norm: BatchNormalization layers are added only when this is
            exactly ``True`` (identity check, not truthiness).

    Returns:
        The concatenated (and optionally dropped-out) feature tensor.
    """
    use_bn = batch_norm is True

    def _bn_relu(tensor):
        # Shared branch tail: optional BatchNormalization, then ReLU.
        if use_bn:
            tensor = layers.BatchNormalization(axis=3)(tensor)
        return layers.Activation("relu")(tensor)

    # 1x1 projection shared by all three branches
    stem = layers.Conv2D(filter_num, (1, 1), padding='same')(inputs)
    stem = layers.Activation('relu')(stem)

    # Squeeze branch: 1x1 convolution
    squeeze = _bn_relu(layers.Conv2D(filter_num, (1, 1), padding='same')(stem))

    # Separable branch: depthwise conv (2 maps per input channel) + 1x1 pointwise conv
    depthwise = layers.DepthwiseConv2D(kernel_size=(filter_size, filter_size), strides=(1, 1),
                                       padding='same', depth_multiplier=2)(stem)
    separable = _bn_relu(layers.Conv2D(2 * filter_num, (1, 1), padding='same')(depthwise))

    # Expand branch: full filter_size x filter_size convolution
    expand = _bn_relu(layers.Conv2D(filter_num, (filter_size, filter_size), padding='same')(stem))

    # Concatenation order (squeeze, separable, expand) is part of the trained topology
    merged = layers.concatenate([squeeze, separable, expand], axis=3)

    if dropout > 0:
        merged = layers.Dropout(dropout)(merged)
    return merged

In [None]:
# Definition of the standard convolution block

def stand_conv_block(inputs, filter_size, filter_num, dropout, batch_norm=True):
    """Standard Conv2D -> (BatchNorm) -> ReLU -> (Dropout) block.

    Args:
        inputs: input tensor; channels are taken to be on axis 3
            (assumes channels_last - TODO confirm if data_format changes).
        filter_size: spatial kernel size of the convolution.
        filter_num: number of convolution filters.
        dropout: dropout rate applied after the activation, only if > 0.
        batch_norm: BatchNormalization is added only when this is exactly ``True``.

    Returns:
        The activated (and optionally dropped-out) feature tensor.
    """
    kernel = (filter_size, filter_size)
    features = layers.Conv2D(filter_num, kernel, padding='same')(inputs)

    if batch_norm is True:
        features = layers.BatchNormalization(axis=3)(features)

    features = layers.Activation('relu')(features)

    return layers.Dropout(dropout)(features) if dropout > 0 else features

In [None]:
# Definition of the attention block

def attention_block(x, gating, size):
    """Additive attention gate that re-weights ``x`` using ``gating``.

    Both inputs are projected to ``size`` channels with 1x1 convolutions and
    summed. ``layers.add`` needs matching shapes, so ``x`` and ``gating`` must
    share spatial size. The sum goes through ReLU, then a 1x1 convolution
    collapses it to one channel, and a sigmoid turns that into an attention
    map. The map multiplies ``x`` and the product is batch-normalised.

    NOTE(review): the single-channel map is multiplied directly with ``x``;
    this presumably relies on Keras' Multiply layer broadcasting over the
    channel axis (``repeat_elem`` would tile it explicitly but is not used).

    Args:
        x: feature tensor to be gated.
        gating: gating-signal tensor with the same spatial size as ``x``.
        size: number of channels of the intermediate 1x1 projections.

    Returns:
        The attention-weighted, batch-normalised version of ``x``.
    """
    # phi(g): 1x1 projection of the gating signal
    phi_g = layers.Conv2D(size, (1, 1), padding='same')(gating)
    # theta(x): 1x1 projection of x to the same channel count
    theta_x = layers.Conv2D(size, (1, 1), padding='same')(x)

    # ReLU(phi(g) + theta(x)) - additive fusion of the two projections
    fused = layers.add([phi_g, theta_x])
    fused = layers.Activation('relu')(fused)

    # psi: collapse to one channel and squash into (0, 1) attention weights
    psi = layers.Conv2D(1, (1, 1), padding='same')(fused)
    attention_map = layers.Activation('sigmoid')(psi)

    # Re-weight x by the attention map, then batch-normalise
    gated = layers.multiply([attention_map, x])
    return layers.BatchNormalization()(gated)

In [None]:
# Definition of repeat_elem for element repeating

def repeat_elem(tensor, rep):
    """Wrap ``K.repeat_elements`` in a Lambda layer, repeating along axis 3.

    Each element along axis 3 of ``tensor`` is repeated ``rep`` times.
    The repeat count is passed to the Lambda through its ``arguments`` dict
    (key ``'repnum'``), not captured in a closure.

    NOTE(review): not called anywhere in this notebook view.

    Args:
        tensor: input tensor.
        rep: number of repetitions per element along axis 3.

    Returns:
        The output tensor of the Lambda layer.
    """
    def _repeat_axis3(x, repnum):
        return K.repeat_elements(x, repnum, axis=3)

    repeat_layer = layers.Lambda(_repeat_axis3, arguments={'repnum': rep})
    return repeat_layer(tensor)

In [None]:
# Definition of AM-SegNet with lightweight block and attention mechanism

def AM_SegNet(input_shape, num_classes, dropout, batch_norm, filter_num=12, filter_size=3):
    """Build the AM-SegNet encoder-decoder segmentation model.

    Architecture (all convolutions use 'same' padding):
      * Encoder: four ``lw_conv_block`` stages with 1x, 2x, 4x and 8x
        ``filter_num`` base filters, each followed by 2x2 max-pooling.
      * Bottleneck: three ``stand_conv_block`` layers with 64 * ``filter_num``
        filters. ``attention_block`` gates the first bottleneck output with
        the last one, and the result is added onto the last.
      * Decoder: four 2x upsampling steps. Each step is concatenated with the
        matching encoder stage (skip connection) and refined by a
        ``stand_conv_block``.
      * Head: 1x1 Conv2D to ``num_classes`` channels, BatchNorm, softmax.

    Args:
        input_shape: (height, width, channels) of one input image. Known
            height/width values must be multiples of 16 (four 2x2 poolings);
            ``None`` dimensions are not checked.
        num_classes: number of segmentation classes (softmax channels).
        dropout: dropout rate passed to every conv block (0 adds no Dropout).
        batch_norm: whether the conv blocks use BatchNormalization.
        filter_num: base number of filters (previously hard-coded to 12).
            Keep the default to load weights trained with the original code.
        filter_size: conv kernel size (previously hard-coded to 3).
            Keep the default to load weights trained with the original code.

    Returns:
        An uncompiled ``keras.Model`` named "AM-SegNet". Its summary is
        printed as a side effect.

    Raises:
        ValueError: if a known spatial dimension of ``input_shape`` is not a
            multiple of 16.
    """
    up_samp_size = 2  # upsampling factor; mirrors the 2x2 max-pooling below

    # Four 2x2 poolings are undone by four 2x upsamplings, so the skip
    # concatenations only line up if height and width are multiples of 2**4.
    # Fail early with a readable message instead of a shape mismatch inside
    # `concatenate`.
    downsample_factor = 2 ** 4
    for dim in input_shape[:2]:
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                f"AM_SegNet input height and width must be multiples of "
                f"{downsample_factor}, got input_shape={tuple(input_shape)}")

    inputs = layers.Input(input_shape)

    # ---- Encoder (downsampling) ----

    # Downsampling step 1
    conv_1 = lw_conv_block(inputs, filter_size, 1*filter_num, dropout, batch_norm)
    pool_1 = layers.MaxPooling2D(pool_size=(2, 2))(conv_1)

    # Downsampling step 2
    conv_2 = lw_conv_block(pool_1, filter_size, 2*filter_num, dropout, batch_norm)
    pool_2 = layers.MaxPooling2D(pool_size=(2, 2))(conv_2)

    # Downsampling step 3
    conv_3 = lw_conv_block(pool_2, filter_size, 4*filter_num, dropout, batch_norm)
    pool_3 = layers.MaxPooling2D(pool_size=(2, 2))(conv_3)

    # Downsampling step 4
    conv_4 = lw_conv_block(pool_3, filter_size, 8*filter_num, dropout, batch_norm)
    pool_4 = layers.MaxPooling2D(pool_size=(2, 2))(conv_4)

    # ---- Bottleneck: standard convolutions only ----
    conv_5_1 = stand_conv_block(pool_4, filter_size, 64*filter_num, dropout, batch_norm)
    conv_5_2 = stand_conv_block(conv_5_1, filter_size, 64*filter_num, dropout, batch_norm)
    conv_5_3 = stand_conv_block(conv_5_2, filter_size, 64*filter_num, dropout, batch_norm)

    # Attention: gate the first bottleneck output with the last, add back as a residual
    conv_att = attention_block(conv_5_1, conv_5_3, 64*filter_num)
    conv_5 = layers.add([conv_5_3, conv_att])

    # ---- Decoder (upsampling with skip connections) ----

    # Upsampling step 1
    up_1 = layers.UpSampling2D(size=(up_samp_size, up_samp_size), data_format="channels_last")(conv_5)
    up_1 = layers.concatenate([up_1, conv_4], axis=3)
    up_conv_1 = stand_conv_block(up_1, filter_size, 16*filter_num, dropout, batch_norm)

    # Upsampling step 2
    up_conv_1 = layers.Conv2D(16*filter_num, (filter_size, filter_size), padding='same')(up_conv_1)
    up_2 = layers.UpSampling2D(size=(up_samp_size, up_samp_size), data_format="channels_last")(up_conv_1)
    up_2 = layers.concatenate([up_2, conv_3], axis=3)
    up_conv_2 = stand_conv_block(up_2, filter_size, 8*filter_num, dropout, batch_norm)

    # Upsampling step 3
    up_conv_2 = layers.Conv2D(8*filter_num, (filter_size, filter_size), padding='same')(up_conv_2)
    up_3 = layers.UpSampling2D(size=(up_samp_size, up_samp_size), data_format="channels_last")(up_conv_2)
    up_3 = layers.concatenate([up_3, conv_2], axis=3)
    up_conv_3 = stand_conv_block(up_3, filter_size, 4*filter_num, dropout, batch_norm)

    # Upsampling step 4
    up_conv_3 = layers.Conv2D(4*filter_num, (filter_size, filter_size), padding='same')(up_conv_3)
    up_4 = layers.UpSampling2D(size=(up_samp_size, up_samp_size), data_format="channels_last")(up_conv_3)
    up_4 = layers.concatenate([up_4, conv_1], axis=3)
    up_conv_4 = stand_conv_block(up_4, filter_size, 2*filter_num, dropout, batch_norm)

    # ---- Head: 1x1 convolution to per-pixel class scores ----
    conv_final = layers.Conv2D(num_classes, kernel_size=(1, 1))(up_conv_4)
    conv_final = layers.BatchNormalization(axis=3)(conv_final)
    conv_final = layers.Activation('softmax')(conv_final)  # per-pixel class probabilities

    model = models.Model(inputs, conv_final, name="AM-SegNet")

    # summary() prints the table itself and returns None; wrapping it in
    # print() only added a stray "None" line.
    model.summary()

    return model

In [None]:
# Model input geometry. The tuple is (rows, columns, channels): the network is
# channels_last and the X-ray frames are read as single-channel grayscale.
input_size_x, input_size_y = 256, 512
input_size = (input_size_x, input_size_y, 1)

In [None]:
# Model configuration:
#   class_num - number of pixel labels: keyhole, pore, substrate, background and powder
#   dropout   - dropout rate for the conv blocks (0.0 means no Dropout layers are added)
class_num, dropout = 5, 0.0

In [None]:
# Create the AM-SegNet model (its layer summary is printed while building)
model = AM_SegNet(input_shape=input_size, num_classes=class_num,
                  dropout=dropout, batch_norm=True)

In [None]:
# Load the grayscale X-ray images for model testing.
# Paths are sorted so the order of `test_images` - and therefore which frame is
# "the first one" plotted below - is reproducible (raw glob order is arbitrary).
test_image_dir = "test images"
test_image_paths = sorted(glob.glob(os.path.join(test_image_dir, "*.jpg")))
if not test_image_paths:
    raise FileNotFoundError(f"No .jpg images found in {test_image_dir!r}")

expected_hw = tuple(input_size[:2])  # (rows, cols) the model was built for
loaded_images = []
for img_path in test_image_paths:
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:  # cv2.imread reports unreadable files by returning None
        raise IOError(f"Could not read image {img_path!r}")
    if img.shape != expected_hw:
        raise ValueError(f"{img_path!r} has shape {img.shape}, expected {expected_hw}")
    loaded_images.append(img)

# Stack to (N, rows, cols) and add the channel axis the model expects -> (N, rows, cols, 1)
test_images = np.expand_dims(np.array(loaded_images), axis=-1)

In [None]:
# Restore the pre-trained weights (HDF5 path relative to the working directory)
weights_path = 'Model weights of AM-SegNet.hdf5'
model.load_weights(weights_path)

In [None]:
# Perform semantic segmentation: the model ends in a softmax over the last
# (class) axis, so the arg-max over that axis is the per-pixel label map.
y_pred = model.predict(test_images)
y_pred_argmax = y_pred.argmax(axis=-1)

In [None]:
# Plot segmentation results using a customised colormap
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from matplotlib.colors import ListedColormap

# One colour per class label: label i is drawn with colors[i]
colors = ['#541352FF', '#d62728', '#2f9aa0FF', '#10a53dFF', '#ffcf20FF']
assert len(colors) == class_num, "need exactly one colour per class"
custom_cmap = ListedColormap(colors)

# Plot the segmentation result of one image, e.g. the first one in the image list.
# vmin/vmax pin label i to colors[i]. Without them imshow rescales to the labels
# actually present, so a frame missing some classes would be drawn in the wrong
# colours. 'nearest' interpolation stops neighbouring labels from being blended
# into the colour of a third class.
fig = figure(figsize=(8, 4))
ax = fig.add_subplot()
ax.imshow(y_pred_argmax[0], cmap=custom_cmap, vmin=0, vmax=len(colors) - 1,
          interpolation='nearest', alpha=0.60)
ax.set_title('AM-SegNet segmentation (image 0)')
ax.axis('off')
plt.show()