In [1]:
from tensorflow.keras import *

class CAM(layers.Layer):
    """Channel attention module.

    The input is pooled over its spatial axes twice (global average and
    global max). Each pooled descriptor goes through the same two 1x1
    convolutions: a ReLU bottleneck with ``channels // ratio`` filters,
    followed by a linear expansion back to ``channels`` filters. The two
    results are summed and passed through a sigmoid.

    Args:
        ratio: Channel reduction factor of the bottleneck convolution.
            Must be positive.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).

    Call returns:
        A tensor of per-channel gates in (0, 1) shaped
        ``(batch, 1, 1, channels)``; the caller multiplies it with the input.

    Raises:
        ValueError: If ``ratio`` is not positive, or if the channel dimension
            of the input is not statically known at build time.

    Note:
        The channel count is read from ``input_shape[-1]``, i.e. the layer
        assumes channels-last inputs -- TODO confirm against callers.
    """

    def __init__(self, ratio=8, **kwargs):
        super().__init__(**kwargs)
        if ratio <= 0:
            raise ValueError(f'ratio must be positive, got {ratio!r}')
        self.ratio = ratio
        self.gap = layers.GlobalAveragePooling2D()
        self.gmp = layers.GlobalMaxPooling2D()

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is None:
            raise ValueError('CAM requires a statically known channel dimension.')
        channels = int(channels)

        # Never let the bottleneck collapse to zero filters when the input has
        # fewer channels than `ratio`.
        bottleneck = max(1, int(channels // self.ratio))

        # Created once here instead of instantiating a new Reshape layer on
        # every call.
        self.reshape = layers.Reshape((1, 1, channels))
        self.conv1 = layers.Conv2D(bottleneck,
                                   kernel_size=1,
                                   strides=1, padding='same',
                                   activation='relu')
        # Deliberately linear: with a ReLU here both summands in `call` would
        # be non-negative, so the sigmoid could never go below 0.5 and the
        # module would be unable to suppress a channel.
        self.conv2 = layers.Conv2D(channels,
                                   kernel_size=1,
                                   strides=1, padding='same')

    def call(self, inputs):
        # (batch, channels) -> (batch, 1, 1, channels) so the 1x1 convs apply.
        avg_desc = self.reshape(self.gap(inputs))
        max_desc = self.reshape(self.gmp(inputs))

        # Both descriptors share the same conv1/conv2 weights.
        avg_logits = self.conv2(self.conv1(avg_desc))
        max_logits = self.conv2(self.conv1(max_desc))

        return tf.math.sigmoid(avg_logits + max_logits)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({'ratio': self.ratio})
        return config
     


class SAM(layers.Layer):
    """Spatial attention module.

    The input is reduced along axis 3 with both a mean and a max, and the two
    resulting maps are stacked into a 2-channel tensor. That tensor goes
    through three bias-free ReLU convolutions (64, 32 and 16 filters, all with
    ``kernel_size``) and a final bias-free 1x1 convolution with one filter and
    a sigmoid activation.

    Args:
        kernel_size: Kernel size of the first three convolutions.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).

    Call returns:
        A single-channel map of gates in (0, 1) with the same leading
        dimensions as the input (all convolutions use stride 1 and 'same'
        padding).

    Note:
        Axis 3 is treated as the channel axis, so a rank-4 channels-last
        input is assumed -- TODO confirm against callers.
    """

    def __init__(self, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        # Kept so that get_config() can reproduce the constructor call.
        self.kernel_size = kernel_size

        # Arguments common to all four convolutions.
        shared = dict(use_bias=False,
                      kernel_initializer='he_normal',
                      strides=1, padding='same')

        self.conv1 = layers.Conv2D(64, kernel_size=kernel_size,
                                   activation=tf.nn.relu, **shared)
        self.conv2 = layers.Conv2D(32, kernel_size=kernel_size,
                                   activation=tf.nn.relu, **shared)
        self.conv3 = layers.Conv2D(16, kernel_size=kernel_size,
                                   activation=tf.nn.relu, **shared)
        # 1x1 projection down to a single sigmoid-activated map.
        self.conv4 = layers.Conv2D(1, kernel_size=(1, 1),
                                   activation=tf.math.sigmoid, **shared)

    def call(self, inputs):
        avg_out = tf.reduce_mean(inputs, axis=3)
        max_out = tf.reduce_max(inputs, axis=3)
        # Stacking on axis 3 restores a channel axis of size 2.
        x = tf.stack([avg_out, max_out], axis=3)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return self.conv4(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({'kernel_size': self.kernel_size})
        return config

In [2]:
def build_model(input_shape=(256, 256, 3),
                weights_path='/content/tf-efficientnet-noisy-student-weights/efficientnet-b0_noisy-student_notop.h5',
                num_classes=None,
                learning_rate=1e-3,
                label_smoothing=0.15):
    """Build and compile the EfficientNetB0 + CAM/SAM attention classifier.

    Called with no arguments this reproduces the original hard-coded setup.
    The function relies on ``efn``, ``TARGET`` and ``f1`` being defined
    elsewhere in the notebook.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        weights_path: Path to the no-top EfficientNetB0 weights file loaded
            into the backbone. A falsy value (``None`` / ``''``) skips
            loading and leaves the backbone at its initial weights.
        num_classes: Number of units of the softmax head. Defaults to
            ``len(TARGET)``, resolved at call time.
        learning_rate: Learning rate of the Adam optimizer.
        label_smoothing: Label smoothing of the categorical cross-entropy.

    Returns:
        The compiled Keras ``Model`` named ``'efficientnet'``.
    """
    if num_classes is None:
        num_classes = len(TARGET)

    inp = layers.Input(shape=input_shape)
    backbone = efn.EfficientNetB0(include_top=False, weights=None,
                                  input_shape=input_shape)
    if weights_path:
        backbone.load_weights(weights_path)

    # Freeze the first tenth of the backbone layers.
    frozen_count = len(backbone.layers) // 10
    for layer in backbone.layers[:frozen_count]:
        layer.trainable = False

    features = backbone(inp)

    # Channel attention, then spatial attention on the channel-gated features.
    channel_gated = CAM()(features) * features
    spatial_gated = SAM()(channel_gated) * channel_gated
    # A second, independently parameterized SAM: its raw single-channel mask
    # (not multiplied with the features) forms the second branch.
    spatial_mask = SAM()(channel_gated)

    pooled_features = layers.GlobalAveragePooling2D()(spatial_gated)
    pooled_mask = layers.GlobalAveragePooling2D()(spatial_mask)
    # pooled_mask has a single channel while pooled_features keeps the
    # backbone channels; the average presumably relies on Keras broadcasting
    # the size-1 axis -- NOTE(review): confirm this is intended.
    x = layers.Average()([pooled_features, pooled_mask])

    # The head is pinned to float32 regardless of any global dtype policy.
    x = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)

    # Compile
    model = Model(inputs=inp, outputs=x, name='efficientnet')
    loss = losses.CategoricalCrossentropy(label_smoothing=label_smoothing)
    opt = optimizers.Adam(learning_rate=learning_rate)

    model.compile(loss=loss, optimizer=opt, metrics=['accuracy', f1])

    return model