In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import *
from keras_self_attention import SeqSelfAttention
from tensorflow.keras import regularizers
from tensorflow.keras.utils import plot_model
import pickle
import numpy as np
from tensorflow.keras.layers import Lambda
from tensorflow.keras import backend as K

In [2]:
# Load the train/test splits stored as pickle files in the working directory.
# NOTE(review): pickle.load can execute arbitrary code - only load these files
# if they come from a trusted source.
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    with open(path, "rb") as fh:
        return pickle.load(fh)

X_train = _load_pickle("X_train")
y_train = _load_pickle("y_train")
X_test = _load_pickle("X_test")
y_test = _load_pickle("y_test")

In [3]:
# Scale pixel intensities by 1/255 (presumably raw [0, 255] images - confirm
# against how the pickles were produced).
# Cast to float32, Keras' default floatx: dividing integer or float64 arrays by
# 255.0 would otherwise produce float64 copies at twice the memory
# (~1.2 GB for X_test's 9057x128x128x1 values alone).
# NOTE: not idempotent - re-running this cell without reloading the data
# divides by 255 a second time.
X_train = np.asarray(X_train, dtype=np.float32) / 255.0
X_test = np.asarray(X_test, dtype=np.float32) / 255.0

In [4]:
# Sanity-check the loaded data against the (128, 128, 1) input that
# create_model() builds, before spending time on training.
assert X_train.shape[1:] == X_test.shape[1:] == (128, 128, 1), "unexpected image shape"
assert len(X_train) == len(y_train), "X_train / y_train length mismatch"
assert len(X_test) == len(y_test), "X_test / y_test length mismatch"
X_test.shape

(9057, 128, 128, 1)

In [5]:
def channel_attention(input_feature, ratio=8):
    """CBAM-style channel-attention gate.

    Squeezes the spatial dimensions of ``input_feature`` with both global
    average pooling and global max pooling, passes each pooled descriptor
    through the same two-layer MLP (bottleneck of ``channels // ratio`` units,
    at least 1), sums the two results, applies a sigmoid, and rescales
    ``input_feature`` channel-wise by that weight.

    Args:
        input_feature: 4-D feature-map tensor laid out according to
            ``K.image_data_format()`` (channel axis last for channels_last,
            axis 1 for channels_first).
        ratio: Reduction ratio of the MLP bottleneck.

    Returns:
        Tensor with the same shape as ``input_feature``.
    """
    channels_first = K.image_data_format() == "channels_first"
    channel_axis = 1 if channels_first else -1
    channel_dim = int(input_feature.shape[channel_axis])
    # Clamp so that channel_dim < ratio does not leave the MLP with zero
    # hidden units.
    hidden_units = max(1, channel_dim // ratio)

    # The MLP weights are shared between the avg- and max-pooled branches.
    shared_layer_one = Dense(hidden_units, activation='relu', kernel_initializer='he_normal', use_bias=True, bias_initializer='zeros')
    shared_layer_two = Dense(channel_dim, kernel_initializer='he_normal', use_bias=True, bias_initializer='zeros')

    def _pooled_descriptor(pooling_layer):
        # Pool away H and W, reshape to (1, 1, C) so Dense acts on the channel
        # axis, then run the shared MLP.
        pooled = pooling_layer(input_feature)
        pooled = Reshape((1, 1, channel_dim))(pooled)
        return shared_layer_two(shared_layer_one(pooled))

    cbam_feature = Add()([_pooled_descriptor(GlobalAveragePooling2D()),
                          _pooled_descriptor(GlobalMaxPooling2D())])
    cbam_feature = Activation('sigmoid')(cbam_feature)
    if channels_first:
        # Move the weights to (C, 1, 1) so they broadcast over a channels_first input.
        cbam_feature = Permute((3, 1, 2))(cbam_feature)
    return Multiply()([input_feature, cbam_feature])

In [6]:
def spatial_attention(input_feature, kernel_size=7):
    """CBAM-style spatial-attention gate.

    Takes the mean and the max of ``input_feature`` across its channel axis,
    stacks the two single-channel maps, and turns them into a per-pixel
    weight with one sigmoid ``Conv2D``. The input is multiplied by that map.

    Args:
        input_feature: 4-D feature-map tensor laid out according to
            ``K.image_data_format()``.
        kernel_size: Size of the convolution kernel that produces the
            attention map (7, as before, by default).

    Returns:
        Tensor with the same shape as ``input_feature``.
    """
    channels_first = K.image_data_format() == "channels_first"
    # Work in channels_last layout so the channel axis is 3 from here on.
    if channels_first:
        cbam_feature = Permute((2, 3, 1))(input_feature)
    else:
        cbam_feature = input_feature

    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    # data_format is pinned because `concat` is always channels_last here;
    # with the global channels_first default the conv would treat the height
    # axis as channels and the final Multiply would not broadcast.
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False,
                          data_format='channels_last')(concat)
    if channels_first:
        # Back to (1, H, W) so the map broadcasts against the channels_first input.
        cbam_feature = Permute((3, 1, 2))(cbam_feature)
    return Multiply()([input_feature, cbam_feature])

In [7]:
def create_model(input_shape=(128, 128, 1)):
    """Build and compile the CNN + spatial-attention binary classifier.

    Architecture:
      * two stacked 3x3 convolutions (96 and 72 filters), each followed by dropout;
      * three parallel conv branches (3x3/72, 5x5/32 and 7x7/32 filters, each
        ReLU + 2x2 max-pool) concatenated on the channel axis, then dropout;
      * a spatial-attention gate;
      * three L2-regularised ReLU Dense layers (84, 128, 128 units) with
        dropout, and a single sigmoid output unit.

    Args:
        input_shape: Shape of one input sample, channels last. Defaults to
            the (128, 128, 1) images this notebook loads.

    Returns:
        A ``Model`` compiled with binary cross-entropy, the Adam optimizer and
        an accuracy metric.
    """
    def _conv_branch(x, filters, kernel_size):
        # One parallel branch: same-padded conv -> ReLU -> 2x2 max-pool.
        x = Conv2D(filters, kernel_size, padding="same")(x)
        x = Activation("relu")(x)
        return MaxPooling2D(pool_size=(2, 2))(x)

    def _regularized_dense(x, units, dropout_rate):
        # ReLU Dense layer with L2 (factor 0.001) on kernel and bias, then dropout.
        # The factor is passed positionally: the legacy `l=` keyword is not
        # accepted by Keras 3's L2 regularizer.
        x = Dense(units, activation="relu",
                  kernel_regularizer=regularizers.l2(0.001),
                  bias_regularizer=regularizers.l2(0.001))(x)
        return Dropout(dropout_rate)(x)

    inputl = Input(shape=input_shape)

    # activation="relu" added: without a nonlinearity these two convolutions
    # (dropout only rescales elementwise) compose into a single linear map.
    x = Conv2D(96, (3, 3), padding="same", activation="relu")(inputl)
    x = Dropout(0.3)(x)
    x = Conv2D(72, (3, 3), padding="same", activation="relu")(x)
    x = Dropout(0.3)(x)

    x = concatenate([_conv_branch(x, 72, (3, 3)),
                     _conv_branch(x, 32, (5, 5)),
                     _conv_branch(x, 32, (7, 7))], axis=3)
    x = Dropout(0.3)(x)
    x = spatial_attention(x)

    x = Flatten()(x)
    x = _regularized_dense(x, 84, 0.3)
    x = _regularized_dense(x, 128, 0.3)
    x = _regularized_dense(x, 128, 0.1)
    outputl = Dense(1, activation="sigmoid")(x)

    model = Model(inputs=inputl, outputs=outputl)
    model.compile(loss="binary_crossentropy",
                  optimizer='adam',
                  metrics=["accuracy"])
    return model


In [8]:
# Seed Python's `random`, NumPy and TensorFlow before building the model so
# weight initialisation and dropout masks repeat across Restart & Run All.
# (Bit-exact GPU runs additionally need tf.config.experimental.enable_op_determinism().)
SEED = 42
tf.keras.utils.set_random_seed(SEED)
model = create_model()

In [9]:
EPOCHS = 20
BATCH_SIZE = 32

# Save weights after every completed epoch so an interrupted run keeps its
# progress. Weights-only avoids serialising the Lambda layers; the
# ".weights.h5" suffix is accepted by both Keras 2 and Keras 3.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "attention_checkpoint.weights.h5",
    save_weights_only=True,
)

# NOTE(review): X_test doubles as the validation set here, so choosing epochs
# or hyperparameters from the val_* metrics leaks test information - consider
# validation_split on the training data instead.
history = model.fit(
    X_train,
    y_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(X_test, y_test),
    callbacks=[checkpoint],
)

Epoch 1/20
Epoch 2/20
Epoch 3/20

KeyboardInterrupt: 

In [None]:
# Persist the full model (architecture + weights + optimizer state) in HDF5 format.
# NOTE(review): the filename says 20 epochs, but the recorded fit() run above
# was interrupted during epoch 3, so a model saved now is not a 20-epoch model.
model.save(filepath="attention_20e.h5")