# **ConvNeXt**
此份程式碼會介紹如何使用 tf.keras 的方式建構 ConvNeXt 的模型架構。

<img src="https://i.imgur.com/aIZ2IgS.png" width=600/>

- [source paper](https://arxiv.org/abs/2201.03545)

## 匯入套件

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Tensorflow 相關套件
import tensorflow as tf
from tensorflow.keras import datasets, layers, Model, Sequential, losses

## 載入資料集

In [None]:
# Fetch the MNIST arrays (images plus integer labels) for train and test.
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()

# Append a trailing channel axis so the images become 4-D tensors.
x_train, x_test = (tf.expand_dims(images, axis=3)
                   for images in (x_train, x_test))
for split_name, images in (('x_train', x_train), ('x_test', x_test)):
    print(f'{split_name} shape: {images.shape}')
print('----------')

# Replicate the single gray channel three times to get RGB-shaped input.
x_train, x_test = (tf.repeat(images, 3, axis=3)
                   for images in (x_train, x_test))
for split_name, images in (('x_train', x_train), ('x_test', x_test)):
    print(f'{split_name} shape: {images.shape}')
print('----------')

# Hold out the last 20% of the training samples as validation data.
# Images and labels have the same length, so one boundary serves both.
split_idx = int(x_train.shape[0] * 0.8)
x_train, x_val = x_train[:split_idx], x_train[split_idx:]
y_train, y_val = y_train[:split_idx], y_train[split_idx:]
print(f'x_train shape: {x_train.shape}, x_val shape: {x_val.shape}')
print(f'y_train shape: {y_train.shape}, y_val shape: {y_val.shape}')

## ConvNeXt Architecture

![](https://i.imgur.com/P9FvSbO.png)

In [None]:
# Number of target classes; used as the width of the final softmax layer.
labels_num = 10

In [None]:
def ConvNeXtBlock(x, filter_num, block_num):
    """Stack ``block_num`` ConvNeXt residual blocks on top of ``x``.

    Each block applies, in order: a 7x7 depthwise convolution, layer
    normalization, a 1x1 convolution expanding to ``4 * filter_num``
    channels with GELU activation, a 1x1 convolution projecting back to
    ``filter_num`` channels, and finally a residual addition with the
    block's input.

    Args:
        x: Input feature map. It is expected to already have
            ``filter_num`` channels so the residual ``Add`` is
            shape-compatible.
        filter_num: Number of output channels of every block.
        block_num: How many blocks to chain.

    Returns:
        The output tensor of the last block, or ``x`` unchanged when
        ``block_num`` is zero or negative.
    """
    for _ in range(block_num):
        # depthwise conv
        depthwise = layers.DepthwiseConv2D((7, 7),
                                           padding='same')(x)
        depthwise = layers.LayerNormalization(epsilon=1e-6)(depthwise)

        # pointwise conv: inverted bottleneck (expand x4, then project back)
        pointwise = layers.Conv2D(4 * filter_num, (1, 1),
                                  strides=(1, 1),
                                  padding='same',
                                  activation='gelu')(depthwise)
        pointwise = layers.Conv2D(filter_num, (1, 1),
                                  strides=(1, 1),
                                  padding='same')(pointwise)

        # skip connection; the sum becomes the input of the next block
        x = layers.Add()([x, pointwise])
    # Returning ``x`` (instead of a loop-local name) keeps the function
    # well-defined when the loop body never runs.
    return x

In [None]:
def Downsample(x, filter_num):
    """Normalize ``x`` and halve its spatial resolution.

    Applies layer normalization followed by a 2x2 convolution with
    stride 2 that outputs ``filter_num`` channels.

    Args:
        x: Input feature map.
        filter_num: Number of channels produced by the convolution.

    Returns:
        The downsampled feature map.
    """
    normalized = layers.LayerNormalization(epsilon=1e-6)(x)
    reduce_conv = layers.Conv2D(filters=filter_num,
                                kernel_size=(2, 2),
                                strides=(2, 2),
                                padding='same')
    return reduce_conv(normalized)

In [None]:
# Per-stage configuration: channel width and number of ConvNeXt blocks.
filter_list = [96, 192, 384, 768]
block_list = [3, 3, 9, 3]

# Reset Keras global state so re-running this cell starts from a clean graph.
tf.keras.backend.clear_session()
inputs = layers.Input(shape=x_train.shape[1:])
# Upscale the images to 224x224. The layer is applied to an explicit Input
# tensor, so no `input_shape` argument is needed here.
x = layers.Resizing(224, 224,
                    interpolation="bilinear")(inputs)

# Stem: non-overlapping 4x4 convolution with stride 4, then layer norm.
x = layers.Conv2D(filter_list[0], (4, 4),
                  strides=(4, 4),
                  padding='same')(x)
x = layers.LayerNormalization(epsilon=1e-6)(x)
# Stage 1 runs directly on the stem output (no downsampling in front).
x = ConvNeXtBlock(x, filter_list[0], block_list[0])

# Remaining stages: downsample to the new width, then stack the blocks.
for filter_num, block_num in zip(filter_list[1:], block_list[1:]):
    x = Downsample(x, filter_num)
    x = ConvNeXtBlock(x, filter_num, block_num)

# Classification head: global pooling, layer norm, softmax over the classes.
x = layers.GlobalAveragePooling2D()(x)
x = layers.LayerNormalization(epsilon=1e-6)(x)
outputs = layers.Dense(labels_num, activation='softmax')(x)

In [None]:
# Wrap the layer graph running from `inputs` to `outputs` in a Keras Model.
ConvNeXt_model = Model(inputs=inputs, outputs=outputs)

In [None]:
# Print the model's summary for inspection.
ConvNeXt_model.summary()

In [None]:
# Smoke test: push a small batch of all-ones images through the model.
# The spatial size is taken from the training data; 3 channels are assumed.
batch_size = 4
dummy_shape = (batch_size, x_train.shape[1], x_train.shape[2], 3)
inputs = np.ones(dummy_shape, dtype=np.float32)
ConvNeXt_model(inputs)