https://blog.csdn.net/qq_36758914/article/details/107511908

In [1]:
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, Conv2D, Concatenate, BatchNormalization, DepthwiseConv2D
from tensorflow.keras.layers import Lambda, Reshape, Layer, Activation, add, Multiply
from math import ceil
import numpy as np

In [None]:
class SEModule(Layer):
  """Squeeze-and-Excitation channel attention with a hard-sigmoid gate.

  Global-average-pools the input to one value per channel, then:
  - a 1x1 conv bottleneck with ReLU;
  - a 1x1 conv expansion back to ``filters`` channels with a hard sigmoid.

  The input is multiplied by the resulting per-channel gate.

  Args:
    filters: channel count of the gate. The input is expected to have this
      many channels so the gate broadcasts against it.
    ratio: reduction factor for the bottleneck width
      (``int(filters / ratio)``, but at least 1).
  """

  def __init__(self, filters, ratio):
    super(SEModule, self).__init__()
    self.pooling = GlobalAveragePooling2D()
    # int(filters / ratio) truncates to 0 when filters < ratio. A zero-filter
    # conv is not a usable bottleneck (newer Keras rejects it outright), so
    # keep at least one channel.
    squeezed = max(1, int(filters / ratio))
    self.conv1 = Conv2D(squeezed, (1, 1), strides=(1, 1), padding='same',
                        use_bias=False, activation=None)
    self.conv2 = Conv2D(int(filters), (1, 1), strides=(1, 1), padding='same',
                        use_bias=False, activation=None)
    self.relu = Activation('relu')
    self.hard_sigmoid = Activation('hard_sigmoid')

  def call(self, inputs):
    x = self.pooling(inputs)
    # (batch, C) -> (batch, 1, 1, C) so the 1x1 convs can run on it.
    # tf.reshape is used instead of constructing a new Reshape layer object
    # on every forward pass.
    x = tf.reshape(x, (-1, 1, 1, int(x.shape[-1])))
    x = self.relu(self.conv1(x))
    excitation = self.hard_sigmoid(self.conv2(x))
    return inputs * excitation

In [None]:
class GhostModule(Layer):
  """Ghost module: a few "intrinsic" feature maps plus cheap "ghost" maps.

  A regular convolution produces ``ceil(out / ratio)`` intrinsic channels.
  A depthwise convolution (depth multiplier ``ratio - 1``) derives extra
  channels from those. The extras are trimmed so the concatenated output has
  exactly ``out`` channels. When ``ratio == 1`` only the regular
  convolution's output is returned.

  Args:
    out: number of output channels.
    ratio: how many output channels each intrinsic channel accounts for.
    convkernel: square kernel size of the primary convolution.
    dwkernel: kernel size of the depthwise ("cheap") convolution.
  """

  def __init__(self, out, ratio, convkernel, dwkernel):
    super().__init__()
    self.ratio = ratio
    self.out = out
    self.conv_out_channel = ceil(out * 1.0 / ratio)
    primary_kernel = (convkernel, convkernel)
    self.conv = Conv2D(int(self.conv_out_channel), primary_kernel,
                       use_bias=False, strides=(1, 1), padding='same',
                       activation=None)
    self.depthconv = DepthwiseConv2D(dwkernel, 1, padding='same',
                                     use_bias=False,
                                     depth_multiplier=ratio - 1,
                                     activation=None)
    self.concat = Concatenate()

  def call(self, inputs):
    intrinsic = self.conv(inputs)
    if self.ratio != 1:
      # Keep only as many ghost channels as needed to reach `out` in total.
      ghost_count = int(self.out - self.conv_out_channel)
      ghosts = self.depthconv(intrinsic)[..., :ghost_count]
      return self.concat([intrinsic, ghosts])
    return intrinsic

In [None]:
class GBNeck(Layer):
  """Ghost bottleneck block.

  Main path, in order:
  - GhostModule expansion to ``exp`` channels -> BN -> ReLU;
  - only when ``strides > 1``: strided depthwise conv -> BN -> ReLU;
  - optional SE on the expanded features;
  - GhostModule projection to ``out`` channels -> BN.

  Shortcut: depthwise conv (same kernel/stride as the main path) -> BN ->
  1x1 conv to ``out`` channels -> BN. This conv shortcut is always used,
  even when stride is 1.

  The two paths are summed with no activation afterwards.

  Args:
    dwkernel: kernel size of both depthwise convolutions.
    strides: integer stride of the depthwise convolutions.
    exp: expansion width (channels produced by the first ghost module).
    out: output channels of the block.
    ratio: ghost ratio for both ghost modules; also passed to SEModule as its
      reduction ratio.
    use_se: whether to apply squeeze-and-excitation on the expanded features.
  """

  def __init__(self, dwkernel, strides, exp, out, ratio, use_se):
    super(GBNeck, self).__init__()
    self.strides = strides
    self.use_se = use_se
    self.conv = Conv2D(out, (1, 1), strides=(1, 1), padding='same',
                       activation=None, use_bias=False)
    self.relu = Activation('relu')
    # Both depthwise convs must preserve the channel count (depth_multiplier=1).
    # Tying the multiplier to the ghost ratio (ratio - 1) caused two problems:
    # - ratio == 1 gave a multiplier of 0;
    # - ratio > 2 multiplied the `exp` channels on the strided main path, so
    #   the SE gate (sized for `exp` channels) no longer matched.
    # For ratio == 2 (the multiplier was 1) behavior is unchanged.
    self.depthconv1 = DepthwiseConv2D(dwkernel, strides, padding='same',
                                      depth_multiplier=1,
                                      activation=None, use_bias=False)
    self.depthconv2 = DepthwiseConv2D(dwkernel, strides, padding='same',
                                      depth_multiplier=1,
                                      activation=None, use_bias=False)
    for i in range(5):
      setattr(self, f"batchnorm{i+1}", BatchNormalization())
    self.ghost1 = GhostModule(exp, ratio, 1, 3)
    self.ghost2 = GhostModule(out, ratio, 1, 3)
    self.se = SEModule(exp, ratio)

  def call(self, inputs):
    # Shortcut path: depthwise -> BN -> 1x1 conv -> BN.
    x = self.batchnorm1(self.depthconv1(inputs))
    x = self.batchnorm2(self.conv(x))

    # Main path: expand -> (downsample) -> (SE) -> project.
    y = self.relu(self.batchnorm3(self.ghost1(inputs)))
    if self.strides > 1:
      y = self.relu(self.batchnorm4(self.depthconv2(y)))
    if self.use_se:
      y = self.se(y)
    y = self.batchnorm5(self.ghost2(y))
    return add([x, y])

In [None]:
class GhostNet(tf.keras.Model):
  """GhostNet image classifier built from a stack of Ghost bottlenecks.

  Pipeline:
  - stem: 3x3 stride-2 conv (16 filters) -> BN -> ReLU;
  - the GBNeck stack described by the per-block configuration lists;
  - 1x1 conv to 960 channels -> BN -> ReLU;
  - global average pooling;
  - 1x1 conv to 1280 channels -> BN -> ReLU;
  - 1x1 conv to ``classes`` channels, then softmax.

  Args:
    classes: number of output classes.

  Calling the model returns softmax probabilities of shape (batch, classes).
  Inputs are assumed to be channels-last images (the Keras default data
  format).
  """

  def __init__(self, classes):
    super(GhostNet, self).__init__()
    self.classes = classes
    self.conv1 = Conv2D(16, (3, 3), strides=(2, 2), padding='same',
                        activation=None, use_bias=False)
    self.conv2 = Conv2D(960, (1, 1), strides=(1, 1), padding='same',
                        activation=None, use_bias=False)
    self.conv3 = Conv2D(1280, (1, 1), strides=(1, 1), padding='same',
                        activation=None, use_bias=False)
    self.conv4 = Conv2D(self.classes, (1, 1), strides=(1, 1), padding='same',
                        activation=None, use_bias=False)
    for i in range(3):
      setattr(self, f"batchnorm{i+1}", BatchNormalization())
    self.relu = Activation('relu')
    self.softmax = Activation('softmax')
    self.pooling = GlobalAveragePooling2D()

    # Per-block configuration of the bottleneck stack: entry i of every list
    # configures block gbneck{i}.
    self.dwkernels = [3, 3, 3, 5, 5, 3, 3, 3, 3, 3, 3, 5, 5, 5, 5, 5]
    self.strides = [1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1]
    self.exps = [16, 48, 72, 72, 120, 240, 200, 184, 184, 480, 672, 672, 960, 960, 960, 960]
    self.outs = [16, 24, 24, 40, 40, 80, 80, 80, 80, 112, 112, 160, 160, 160, 160, 160]
    self.ratios = [2] * 16
    self.use_ses = [False, False, False, True, True, False, False, False,
                    False, True, True, True, False, True, False, True]
    block_configs = list(zip(self.dwkernels, self.strides, self.exps,
                             self.outs, self.ratios, self.use_ses))
    for i, args in enumerate(block_configs):
      setattr(self, f"gbneck{i}", GBNeck(*args))
    # call() iterates over exactly the blocks created here rather than a
    # hard-coded count that could drift from the configuration lists.
    self.num_blocks = len(block_configs)

  def call(self, inputs):
    x = self.relu(self.batchnorm1(self.conv1(inputs)))
    # Ghost bottleneck stack, in construction order.
    for i in range(self.num_blocks):
      x = getattr(self, f"gbneck{i}")(x)
    x = self.relu(self.batchnorm2(self.conv2(x)))
    x = self.pooling(x)
    # (batch, C) -> (batch, 1, 1, C) so the remaining 1x1 convs apply.
    # tf.reshape is used instead of constructing a new Reshape layer object
    # on every forward pass.
    x = tf.reshape(x, (-1, 1, 1, int(x.shape[-1])))
    x = self.relu(self.batchnorm3(self.conv3(x)))
    x = self.conv4(x)
    # Drop both singleton spatial axes: (batch, 1, 1, classes) -> (batch, classes).
    x = tf.squeeze(x, axis=[1, 2])
    return self.softmax(x)

In [None]:
# A dummy all-zeros float32 batch holding one 224x224 image with 3 channels.
# NOTE(review): `inputs` is not used by any of the cells shown here.
inputs = np.zeros(shape=(1, 224, 224, 3), dtype=np.float32)
# A GhostNet with a 10-way softmax head.
model = GhostNet(classes=10)

In [None]:
# Create all weights by building for a single 100x100 image with 3 channels.
model.build((1, 100, 100, 3))

In [None]:
# Display the first Ghost bottleneck block (attribute created in GhostNet.__init__).
getattr(model, "gbneck0")

<__main__.GBNeck at 0x7f56a5716290>

In [None]:
# Print the per-layer summary. For this subclassed model, output shapes are
# reported as "multiple".
model.summary()

Model: "ghost_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  432       
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  153600    
_________________________________________________________________
conv2d_2 (Conv2D)            multiple                  1228800   
_________________________________________________________________
conv2d_3 (Conv2D)            multiple                  12800     
_________________________________________________________________
batch_normalization (BatchNo multiple                  64        
_________________________________________________________________
batch_normalization_1 (Batch multiple                  3840      
_________________________________________________________________
batch_normalization_2 (Batch multiple                  51