# ConvNeXt

<img src="https://i.imgur.com/aIZ2IgS.png" width=600/>

- [source paper](https://arxiv.org/abs/2201.03545)

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import datasets, layers, Model, Sequential, losses

In [2]:
(x_train, y_train), (x_test ,y_test) = datasets.mnist.load_data()

x_train = tf.expand_dims(x_train, axis=3, name=None)
x_test = tf.expand_dims(x_test, axis=3, name=None)
print(f'x_train shape:{x_train.shape}')
print(f'x_test shape:{x_test.shape}')
print('----------')

x_train = tf.repeat(x_train, 3, axis=3)
x_test = tf.repeat(x_test, 3, axis=3)
print(f'x_train shape:{x_train.shape}')
print(f'x_test shape:{x_test.shape}')
print('----------')

x_val = x_train[int(x_train.shape[0]*0.8):,:,:,:]
y_val = y_train[int(y_train.shape[0]*0.8):]
x_train = x_train[:int(x_train.shape[0]*0.8),:,:,:]
y_train = y_train[:int(y_train.shape[0]*0.8)]
print(f'x_train shape:{x_train.shape}, x_val shape:{x_val.shape}')
print(f'y_train shape:{y_train.shape}, y_val shape:{y_val.shape}')

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
x_train shape:(60000, 28, 28, 1)
x_test shape:(10000, 28, 28, 1)
----------
x_train shape:(60000, 28, 28, 3)
x_test shape:(10000, 28, 28, 3)
----------
x_train shape:(48000, 28, 28, 3), x_val shape:(12000, 28, 28, 3)
y_train shape:(48000,), y_val shape:(12000,)


In [3]:
labels_num = 10
batch_size = 4

* ## API function

In [4]:
filters_num = 64

inputs = layers.Input(shape=x_train.shape[1:])
x = layers.experimental.preprocessing.Resizing(224, 224, interpolation="bilinear", input_shape=x_train.shape[1:])(inputs)
conv = layers.Conv2D(filters_num, (4,4), strides=(4,4), padding = 'same')(inputs)
conv = layers.LayerNormalization(epsilon = 1e-6)(conv)

# depthwise conv
depthwise = layers.Conv2D(filters_num, (7,7), strides=(1,1), groups = filters_num, padding = 'same')(conv)
depthwise = layers.LayerNormalization(epsilon = 1e-6)(depthwise)

# pointwise conv
pointwise = layers.Conv2D(4 * filters_num, (1,1), strides=(1,1), padding = 'same', activation = 'gelu')(depthwise)
pointwise = layers.Conv2D(filters_num, (1,1), strides=(1,1), padding = 'same')(pointwise)

outputs = layers.Add()([conv, pointwise])
outputs = layers.GlobalAveragePooling2D()(outputs)
outputs = layers.Dense(labels_num)(outputs)

In [5]:
ConvNeXt_model = Model(inputs=inputs, outputs=outputs)

In [6]:
ConvNeXt_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 28, 28, 3)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 7, 7, 64)     3136        ['input_1[0][0]']                
                                                                                                  
 layer_normalization (LayerNorm  (None, 7, 7, 64)    128         ['conv2d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 conv2d_1 (Conv2D)              (None, 7, 7, 64)     3200        ['layer_normalization[0][0]']

In [7]:
inputs = np.ones((batch_size, x_train.shape[1], x_train.shape[2], 3), dtype=np.float32)
ConvNeXt_model(inputs).shape

TensorShape([4, 10])

In [8]:
ConvNeXt_model(inputs)

<tf.Tensor: shape=(4, 10), dtype=float32, numpy=
array([[-0.10945934, -1.0268849 , -0.52640945,  0.7144207 , -0.7330663 ,
         0.00588338,  1.3659196 , -0.13540567, -0.9588703 , -1.2702948 ],
       [-0.10945934, -1.0268849 , -0.52640945,  0.7144207 , -0.7330663 ,
         0.00588338,  1.3659196 , -0.13540567, -0.9588703 , -1.2702948 ],
       [-0.10945934, -1.0268849 , -0.52640945,  0.7144207 , -0.7330663 ,
         0.00588338,  1.3659196 , -0.13540567, -0.9588703 , -1.2702948 ],
       [-0.10945924, -1.0268849 , -0.5264097 ,  0.71442086, -0.73306626,
         0.00588356,  1.3659194 , -0.13540573, -0.9588702 , -1.2702947 ]],
      dtype=float32)>

* ## OOP method

In [9]:
class ConvNeXtBlock(layers.Layer):  
  def __init__(self, filters_num):
    super().__init__()

    # depthwise conv
    self.dwconv = layers.Conv2D(filters_num, (7,7), strides=(1,1), groups = filters_num, padding = 'same')
    self.norm = layers.LayerNormalization(epsilon = 1e-6)

    # pointwise conv
    self.pwconv1 = layers.Conv2D(4 * filters_num, (1,1), strides=(1,1), padding = 'same', activation = 'gelu')
    self.pwconv2 = layers.Conv2D(filters_num, (1,1), strides=(1,1), padding = 'same')

  def call(self, inputs, training = None):
    depthwise = self.dwconv(inputs)
    depthwise = self.norm(depthwise)
    pointwise = self.pwconv1(depthwise)
    pointwise = self.pwconv2(pointwise)

    outputs = pointwise + inputs
    return outputs


In [10]:
class ConvNeXt(Model):
  def __init__(self, input_shape, filters_num, num_class = 10):
    super(ConvNeXt, self).__init__()

    self.conv = Sequential([layers.experimental.preprocessing.Resizing(224, 224, interpolation="bilinear", input_shape=input_shape),
                            layers.Conv2D(filters_num, (4,4), strides=(4,4), padding = 'same'),
                            layers.LayerNormalization(epsilon = 1e-6)])

    self.ConvNeXtBlock = ConvNeXtBlock(filters_num)
    self.GAP = layers.GlobalAveragePooling2D()                
    self.classifier = layers.Dense(num_class)

  def call(self, inputs, training = None):
    x = self.conv(inputs)
    x = self.ConvNeXtBlock(x)
    x = self.GAP(x)
    outputs = self.classifier(x)

    return outputs


In [11]:
ConvNeXt_model = ConvNeXt(input_shape = x_train.shape[1:], filters_num=64)

In [12]:
ConvNeXt_model.build((None, x_train.shape[1], x_train.shape[2], x_train.shape[3]))
ConvNeXt_model.summary()

Model: "conv_ne_xt"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential (Sequential)     (None, 56, 56, 64)        3264      
                                                                 
 conv_ne_xt_block (ConvNeXtB  multiple                 36416     
 lock)                                                           
                                                                 
 global_average_pooling2d_1   multiple                 0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dense_1 (Dense)             multiple                  650       
                                                                 
Total params: 40,330
Trainable params: 40,330
Non-trainable params: 0
_________________________________________________________________


In [13]:
inputs = np.ones((batch_size, x_train.shape[1], x_train.shape[2], 3), dtype=np.float32)
ConvNeXt_model(inputs).shape

TensorShape([4, 10])

In [14]:
ConvNeXt_model(inputs)

<tf.Tensor: shape=(4, 10), dtype=float32, numpy=
array([[ 1.5723835 ,  1.2864344 ,  0.2006754 , -1.3479763 , -1.7962484 ,
         2.033651  ,  2.7925642 , -1.1791577 , -0.22830898,  1.0241026 ],
       [ 1.5723835 ,  1.2864344 ,  0.2006754 , -1.3479763 , -1.7962484 ,
         2.033651  ,  2.7925642 , -1.1791577 , -0.22830898,  1.0241026 ],
       [ 1.5723835 ,  1.2864344 ,  0.2006754 , -1.3479763 , -1.7962484 ,
         2.033651  ,  2.7925642 , -1.1791577 , -0.22830898,  1.0241026 ],
       [ 1.5723835 ,  1.2864344 ,  0.2006754 , -1.3479763 , -1.7962484 ,
         2.033651  ,  2.7925642 , -1.1791577 , -0.22830898,  1.0241026 ]],
      dtype=float32)>