In [1]:
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Dense, ReLU, GlobalAveragePooling2D, Input, Flatten, Conv1D
from tensorflow import random, Variable, add, multiply, transpose, Tensor, reshape
from tensorflow.keras.datasets import mnist
from tensorflow.keras.backend import sigmoid, softmax
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD

In [2]:
def channel_shuffle(input_tensor: Tensor, group_num: int = 2):
    """Interleave the channels of an NHWC tensor across ``group_num`` groups.

    The channel axis is split into ``group_num`` contiguous groups of
    ``channel // group_num`` channels. The group axis and the within-group axis
    are then swapped and the result is flattened back. For example, with 6
    channels and 2 groups the order ``[0, 1, 2, 3, 4, 5]`` becomes
    ``[0, 3, 1, 4, 2, 5]``. Shuffling with ``channel // group_num`` groups
    undoes a shuffle with ``group_num`` groups.

    References:
        https://blog.csdn.net/baidu_23388287/article/details/94456951
        https://blog.csdn.net/qq_36758914/article/details/106967780

    :param input_tensor: rank-4 tensor laid out as (batch, height, width, channels).
        Height, width and channels must be statically known, because they are
        written into the reshape targets.
    :param group_num: number of channel groups; must be positive and divide the
        channel count.
    :return: tensor with the same shape as ``input_tensor`` and permuted channels.
    :raises ValueError: if the tensor is not rank 4, if its height/width/channel
        dims are unknown, or if ``group_num`` does not divide the channel count.
    """
    static_shape = input_tensor.shape
    if static_shape.rank != 4:
        raise ValueError(f"channel_shuffle expects a rank-4 NHWC tensor, got shape {static_shape}")
    # The batch dim is not needed: it is recovered with -1 in both reshapes.
    _, h, w, channel = static_shape
    if h is None or w is None or channel is None:
        raise ValueError(f"channel_shuffle needs static height/width/channel dims, got shape {static_shape}")
    if group_num <= 0 or channel % group_num != 0:
        raise ValueError(f"group_num={group_num} must be positive and divide the channel count {channel}")

    # (B, H, W, C) -> (B, H, W, g, C/g): split channels into groups.
    input_reshaped = reshape(input_tensor, [-1, h, w, group_num, channel // group_num])
    # Swap the group axis with the within-group axis.
    input_transpose = transpose(input_reshaped, [0, 1, 2, 4, 3])
    # Flatten back to (B, H, W, C); channels are now interleaved across groups.
    return reshape(input_transpose, [-1, h, w, channel])

def _select_layer():
    """Return a Keras layer computing ``w * a + (1 - w) * b`` with a learnable scalar ``w``.

    ``w`` is created with ``add_weight`` so Keras tracks it. It appears in
    ``model.trainable_weights``, receives gradients and is saved with the model.
    A bare ``tf.Variable`` combined with KerasTensors gets none of this. It
    initialises uniformly in [0, 1] (as the original ``random.uniform([1], 0, 1)``
    did) and is not constrained afterwards.
    """
    from tensorflow.keras.initializers import RandomUniform
    from tensorflow.keras.layers import Layer

    class SelectWeight(Layer):
        def build(self, input_shape):
            self.select_weight = self.add_weight(
                name="select_weight",
                shape=(1,),
                initializer=RandomUniform(minval=0., maxval=1.),
                trainable=True,
            )
            super().build(input_shape)

        def call(self, inputs):
            first, second = inputs
            return self.select_weight * first + (1. - self.select_weight) * second

    return SelectWeight()


def _eca_attention(feature_map, kernel_size: int, shuffle: bool, group_num: int):
    """ECA-style channel attention weights for an NHWC feature map.

    The steps are:
    1. Global-average-pool to one statistic per channel.
    2. Optionally channel-shuffle those statistics.
    3. Run a single-filter 1-D convolution along the channel axis, so each
       weight depends on ``kernel_size`` neighbouring channels.
    4. Squash the result to (0, 1) with a sigmoid.

    :param feature_map: tensor shaped (batch, height, width, channels) with a
        static channel count.
    :param kernel_size: width of the convolution across channels.
    :param shuffle: whether to channel-shuffle the statistics before the
        convolution. If so, the weights are un-shuffled afterwards so that
        weight j still scales channel j.
    :param group_num: group count passed to ``channel_shuffle`` when ``shuffle``.
    :return: attention weights shaped (batch, 1, 1, channels).
    """
    from tensorflow.keras.layers import Reshape

    channels = feature_map.shape[-1]
    pooled = GlobalAveragePooling2D(keepdims=True)(feature_map)  # (batch, 1, 1, C)
    if shuffle:
        pooled = channel_shuffle(pooled, group_num)
    # Lay the C statistics out as a length-C sequence with one feature, so the
    # Conv1D slides over channels. Applying Conv1D directly to the
    # (batch, 1, 1, C) tensor would instead convolve over a length-1 axis with
    # C input channels, which is effectively a dense C->C layer rather than ECA.
    sequence = Reshape((channels, 1))(pooled)
    sequence = Conv1D(filters=1, kernel_size=kernel_size, padding="same", use_bias=False)(sequence)
    weights = Reshape((1, 1, channels))(sequence)
    if shuffle:
        # A shuffle with C // g groups inverts a shuffle with g groups.
        weights = channel_shuffle(weights, channels // group_num)
    return sigmoid(weights)


def sECAnet(input_shape: tuple, classes: int, eca_kernel_size: int = 3, group_num: int = 2):
    """Build the sECA-Net classifier: three 3x3 convs, dual ECA attention, dense head.

    The 256-channel output of the last conv is re-weighted per channel by a
    learnable blend of two ECA attention vectors. One is computed on
    channel-shuffled statistics and one on the original channel order. The
    result is flattened into a 512-unit ReLU layer and a softmax over
    ``classes``.

    :param input_shape: per-sample input shape, e.g. ``(28, 28, 1)``.
    :param classes: number of output classes.
    :param eca_kernel_size: width of the 1-D convolution across channels in each
        attention branch (default 3, as in the original).
    :param group_num: group count for the channel shuffle in the first branch;
        must divide 256.
    :return: an uncompiled ``Model`` mapping inputs to class probabilities.
    """
    x = Input(shape=input_shape)
    y = Conv2D(filters=64, kernel_size=(3, 3), padding="same")(x)
    y = Conv2D(filters=128, kernel_size=(3, 3), padding="same")(y)
    y_ = Conv2D(filters=256, kernel_size=(3, 3), padding="same")(y)
    # Attention 1: ECA over channel-shuffled statistics.
    a1 = _eca_attention(y_, eca_kernel_size, shuffle=True, group_num=group_num)
    # Attention 2: ECA over the original channel order.
    a2 = _eca_attention(y_, eca_kernel_size, shuffle=False, group_num=group_num)
    # Select: trainable blend w * a1 + (1 - w) * a2.
    attention_v = _select_layer()([a1, a2])
    # Multiply: (batch, 1, 1, C) weights broadcast over (batch, H, W, C).
    y = multiply(attention_v, y_)

    y = Flatten()(y)
    y = Dense(512)(y)
    y = ReLU()(y)
    y = Dense(classes)(y)
    y = softmax(y)
    return Model(x, y)

In [3]:
# Set TensorFlow's global seed before any weights are created, so the model's
# random initialisation is repeatable under Restart & Run All.
# GPU kernels may still introduce some nondeterminism.
SEED = 42
random.set_seed(SEED)  # `random` is `tensorflow.random` (imported in the first cell)
snet = sECAnet((28, 28, 1), 10)
snet.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 28, 28, 64)   640         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 128)  73856       conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 28, 28, 256)  295168      conv2d_1[0][0]                   
______________________________________________________________________________________________

In [4]:

def data_load():
    """Load MNIST via Keras and scale pixel values from [0, 255] to [0, 1].

    Labels are returned unchanged. Images keep the shape returned by
    ``mnist.load_data()`` (no channel axis is added here).

    :return: ``((x_train, y_train), (x_test, y_test))`` with scaled images.
    """
    train_split, test_split = mnist.load_data()
    images_train, labels_train = train_split
    images_test, labels_test = test_split
    scale = 255.
    return (images_train / scale, labels_train), (images_test / scale, labels_test)

In [5]:
# Load the scaled MNIST splits, then train with plain SGD.
# The test split doubles as the validation set.
(x_train, y_train), (x_test, y_test) = data_load()
sgd = SGD(learning_rate=0.001)
snet.compile(
    optimizer=sgd,
    loss=SparseCategoricalCrossentropy(),  # model already ends in softmax
    metrics=['acc'],
)
history = snet.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_test, y_test),
    batch_size=64,
    epochs=20,
    verbose=1,
)
history

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x27d7fbde400>