# **SENet**
此份程式碼會介紹如何使用 tf.keras 的方式建構 SENet 的模型架構。

<img src="https://i.imgur.com/3xGwreb.png" width=1000/>

- [source paper](https://arxiv.org/abs/1709.01507)

## 匯入套件

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Tensorflow 相關套件
import tensorflow as tf
from tensorflow.keras import datasets, layers, Model, Sequential, losses

## 載入資料集

In [None]:
# Load MNIST as (images, labels) pairs for the training and test splits.
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

# Append a trailing channel axis: (N, H, W) -> (N, H, W, 1)
x_train, x_test = (tf.expand_dims(images, axis=-1)
                   for images in (x_train, x_test))
print(f'x_train shape: {x_train.shape}')
print(f'x_test shape: {x_test.shape}')
print('----------')

# Grayscale to RGB: copy the single channel three times along the last axis
x_train, x_test = (tf.repeat(images, 3, axis=3)
                   for images in (x_train, x_test))
print(f'x_train shape: {x_train.shape}')
print(f'x_test shape: {x_test.shape}')
print('----------')

# Hold out the last 20% of the training samples as validation data.
# The boundary index is computed once, before `x_train` is rebound.
split_idx = int(x_train.shape[0] * 0.8)
x_train, x_val = x_train[:split_idx], x_train[split_idx:]
y_train, y_val = y_train[:split_idx], y_train[split_idx:]
print(f'x_train shape: {x_train.shape}, x_val shape: {x_val.shape}')
print(f'y_train shape: {y_train.shape}, y_val shape: {y_val.shape}')

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
x_train shape:(60000, 28, 28, 1)
x_test shape:(10000, 28, 28, 1)
----------
x_train shape:(60000, 28, 28, 3)
x_test shape:(10000, 28, 28, 3)
----------
x_train shape:(48000, 28, 28, 3), x_val shape:(12000, 28, 28, 3)
y_train shape:(48000,), y_val shape:(12000,)


## SENet Architecture

<img src="https://i.imgur.com/mvqWU9g.png" width=1000/>

- [source paper](https://arxiv.org/abs/1709.01507)

In [None]:
# Number of output classes; used as the unit count of the final Dense layer.
# MNIST has 10 digit labels (0-9).
labels_num = 10

In [None]:
def add_conv(x, filters_num, kernel_size):
    """Run ``x`` through a Conv2D -> BatchNorm -> ReLU -> MaxPooling2D stage.

    Both the convolution and the pooling use stride 1 with 'same' padding,
    so the stage leaves the spatial size of ``x`` unchanged and only sets
    the channel count to ``filters_num``.

    Args:
        x: Input feature map (symbolic Keras tensor).
        filters_num: Number of convolution filters (output channels).
        kernel_size: Side length of the square convolution kernel.

    Returns:
        The output tensor of the pooling layer.
    """
    stage_layers = (
        layers.Conv2D(filters=filters_num,
                      kernel_size=(kernel_size, kernel_size),
                      strides=1,
                      padding='same'),
        layers.BatchNormalization(),
        layers.Activation('relu'),
        layers.MaxPooling2D(pool_size=(2, 2), strides=1, padding='same'),
    )
    # Apply the layers in declaration order.
    for layer in stage_layers:
        x = layer(x)
    return x

# Channel width shared by both conv stages and by the SE block.
filters_num = 64

# Symbolic input with the per-sample shape of the training images.
inputs = layers.Input(shape=x_train.shape[1:])
# Upsample to 224x224. The input shape is already fixed by `inputs`, so no
# `input_shape` argument is passed to the layer (it is redundant in the
# functional API and deprecated in newer Keras versions).
x = layers.Resizing(224, 224, interpolation="bilinear")(inputs)
conv_1 = add_conv(x, filters_num, 7)
conv_2 = add_conv(conv_1, filters_num, 3)

# Squeeze: collapse each channel to one global average, then reshape to
# (1, 1, C) so it can broadcast against the (H, W, C) feature map.
squeeze = layers.GlobalAveragePooling2D()(conv_2)
squeeze = layers.Reshape((1, 1, filters_num))(squeeze)
# Excitation: two dense layers yielding one sigmoid gate in (0, 1) per channel.
# NOTE: both layers keep `filters_num` units, i.e. no bottleneck reduction
# is applied between them.
excitation = layers.Dense(filters_num, activation='relu')(squeeze)
excitation = layers.Dense(filters_num, activation='sigmoid')(excitation)

# Scale: reweight every channel of the feature map by its gate. The Multiply
# layer broadcasts the (1, 1, C) gates over the spatial dimensions.
scale = layers.Multiply()([conv_2, excitation])
# Classification head: global average pooling followed by a linear layer
# that outputs one unnormalized score (logit) per class.
scale = layers.GlobalAveragePooling2D()(scale)
outputs = layers.Dense(labels_num)(scale)

In [None]:
# Wrap the functional graph (input tensor -> class scores) into a Keras Model.
SENet_model = Model(inputs=inputs, outputs=outputs)

In [None]:
# Print the layer-by-layer summary: output shapes, parameter counts and connections.
SENet_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 28, 28, 3)]  0           []                               
                                                                                                  
 resizing (Resizing)            (None, 224, 224, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv2d (Conv2D)                (None, 224, 224, 64  9472        ['resizing[0][0]']               
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 224, 224, 64  256        ['conv2d[0][0]']             

In [None]:
# Smoke test: push an all-ones dummy batch through the model and check that
# the output shape is (batch_size, number of classes).
# NOTE: this rebinds `inputs`, which previously held the Keras Input tensor.
batch_size = 4
dummy_shape = (batch_size, x_train.shape[1], x_train.shape[2], 3)
inputs = np.ones(dummy_shape, dtype=np.float32)
SENet_model(inputs).shape

TensorShape([4, 10])

In [None]:
# Show the raw model outputs for the dummy batch; `inputs` here is the
# NumPy array assigned in the previous cell, not the Keras Input tensor.
SENet_model(inputs)

<tf.Tensor: shape=(4, 10), dtype=float32, numpy=
array([[-0.02457043, -0.0495788 ,  0.24664153, -0.00211908, -0.11318807,
        -0.07496573, -0.07785393, -0.07045375, -0.11746019, -0.03302393],
       [-0.02457043, -0.0495788 ,  0.24664153, -0.00211908, -0.11318807,
        -0.07496573, -0.07785393, -0.07045375, -0.11746019, -0.03302393],
       [-0.02457043, -0.0495788 ,  0.24664153, -0.00211908, -0.11318807,
        -0.07496573, -0.07785393, -0.07045375, -0.11746019, -0.03302393],
       [-0.02457043, -0.0495788 ,  0.24664153, -0.00211908, -0.11318807,
        -0.07496573, -0.07785393, -0.07045375, -0.11746019, -0.03302393]],
      dtype=float32)>