In [29]:
import matplotlib.pyplot as plt

from tensorflow import expand_dims
from sklearn.metrics import classification_report
from tensorflow.keras.layers import Layer, Dense, GlobalAveragePooling2D, ReLU, Activation, Conv2D, BatchNormalization, MaxPooling2D, Flatten, multiply
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [30]:
class SE_Module(Layer):
    def __init__(self, channels: int, reduction: int = 16):
        super(SE_Module, self).__init__()
        self._avg_pool = GlobalAveragePooling2D()
        self._d1 = Dense(1 * 1 * (channels / reduction))
        self._r = ReLU()
        self._d2 = Dense(1 * 1 * channels)
        self._act = Activation("sigmoid")

    def call(self, inputs, *args, **kwargs):
        output = self._avg_pool(inputs)
        output = self._d1(output)
        output = self._r(output)
        output = self._d2(output)
        output = self._act(output)
        output = expand_dims(input=output, axis=1)
        output = expand_dims(input=output, axis=1)
        return multiply([inputs, output])

In [31]:
class SENet(Model):
    def get_config(self):
        super(SENet, self).get_config()

    def __init__(self, classes: int):
        super(SENet, self).__init__()
        self._con1 = Conv2D(filters=32, kernel_size=(3, 3), strides=2, padding="same")
        self._a1 = ReLU()
        self._b1 = BatchNormalization()
        self._m1 = MaxPooling2D(pool_size=(3, 3), strides=1, padding="same")

        self._con2 = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding="same")
        self._a2 = ReLU()
        self._b2 = BatchNormalization()
        self._m2 = MaxPooling2D(pool_size=(3, 3), strides=1, padding="same")
        # self._se2 = SE_Module(channels=64)

        # self._con3 = Conv2D(filters=128, kernel_size=(3, 3), strides=2, padding="same")
        # self._a3 = ReLU()
        # self._b3 = BatchNormalization()
        # self._m3 = MaxPooling2D(pool_size=(3, 3), strides=1, padding="same")
        self._se2 = SE_Module(channels=64)
        self._avgpool = GlobalAveragePooling2D()
        self._fc = Dense(classes)
        self._a4 = Activation("sigmoid")

    def call(self, inputs, training=None, mask=None):
        output = self._con1(inputs)
        output = self._a1(output)
        output = self._b1(output)
        output = self._m1(output)
        # output = self._se1(output)

        output = self._con2(output)
        output = self._a2(output)
        output = self._b2(output)
        output = self._m2(output)
        output = self._se2(output)

        output = self._avgpool(output)
        output = self._fc(output)
        output = self._a4(output)
        return output

In [32]:
senet = SENet(classes=2)
senet.build(input_shape=(None, 64, 64, 3))
senet.summary()

Model: "se_net_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_10 (Conv2D)           multiple                  896       
_________________________________________________________________
re_lu_12 (ReLU)              multiple                  0         
_________________________________________________________________
batch_normalization_10 (Batc multiple                  128       
_________________________________________________________________
max_pooling2d_10 (MaxPooling multiple                  0         
_________________________________________________________________
conv2d_11 (Conv2D)           multiple                  18496     
_________________________________________________________________
re_lu_13 (ReLU)              multiple                  0         
_________________________________________________________________
batch_normalization_11 (Batc multiple                  256

In [33]:
idg = ImageDataGenerator(rescale=1./255, validation_split=0.2)
train_img_gen = idg.flow_from_directory("data/training_set", target_size=(64, 64),
                                        class_mode='categorical', subset='training')
valid_img_gen = idg.flow_from_directory("data/training_set", target_size=(64, 64),
                                        class_mode='categorical', subset='validation')
test_idg = ImageDataGenerator(rescale=1./255)
test_img_gen = test_idg.flow_from_directory("data/test_set", target_size=(64, 64),
                                       class_mode='categorical')

assert train_img_gen.class_indices == test_img_gen.class_indices

Found 6404 images belonging to 2 classes.
Found 1601 images belonging to 2 classes.
Found 2023 images belonging to 2 classes.


In [35]:
sgd = SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
senet.compile(optimizer=sgd, loss="binary_crossentropy", metrics=["acc"])
history = senet.fit(train_img_gen, validation_data=valid_img_gen, epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [38]:
senet.evaluate(test_img_gen)
senet.predict(next(test_img_gen)[0])



array([[5.71659170e-02, 9.45141792e-01],
       [9.56549719e-02, 8.99565041e-01],
       [9.98662353e-01, 1.38015754e-03],
       [6.58368299e-05, 9.99934077e-01],
       [7.88451254e-01, 2.25163341e-01],
       [9.42080002e-03, 9.90814626e-01],
       [9.76712525e-01, 2.54333168e-02],
       [9.50281918e-01, 5.51555119e-02],
       [8.21961323e-04, 9.99243140e-01],
       [1.65730596e-01, 8.53235424e-01],
       [9.07347083e-01, 1.09939903e-01],
       [6.75304700e-03, 9.94063914e-01],
       [9.66065098e-04, 9.99089956e-01],
       [3.74713033e-01, 6.39181912e-01],
       [8.54200006e-01, 1.53269857e-01],
       [1.45217905e-06, 9.99998689e-01],
       [9.70629394e-01, 3.20013352e-02],
       [2.94716447e-03, 9.96960700e-01],
       [2.40428358e-01, 7.92013228e-01],
       [2.47849464e-01, 7.43531287e-01],
       [6.99926138e-01, 3.14027190e-01],
       [1.85606295e-05, 9.99980211e-01],
       [6.40053770e-08, 9.99999881e-01],
       [4.12293058e-03, 9.95705426e-01],
       [2.610899