Experiment 1: Combination of BatchNorm and GroupNorm

In [None]:
try:
  %tensorflow_version 2.x
except:
  pass

import tensorflow as tf

We need to install the TensorFlow Addons module (`tfa`) to use Group Normalization.

In [None]:
pip install -q  --no-deps tensorflow-addons~=0.7

In [None]:
import tensorflow_addons as tfa

Loading MNIST and splitting it into training and test sets

In [None]:
# Fetch MNIST through the Keras dataset helper; load_data() returns the
# (train, test) splits as (images, labels) pairs.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale the image arrays by 255.0 (presumably mapping 8-bit pixel values
# into [0, 1] -- confirm against the dataset's dtype).
x_train = x_train / 255.0
x_test = x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


Experiments:

1. BatchNorm only
2. GroupNorm only
3. BatchNorm, then GroupNorm
4. GroupNorm, then BatchNorm

The model is the same in every experiment; only the normalization layers that are added or removed differ.

In [None]:
#Instance 1 - Only BatchNorm
# Two-convolution classifier with a single BatchNormalization layer placed
# after the first convolution.
model = tf.keras.models.Sequential([
  # Reshape into "channels last" setup.
  tf.keras.layers.Reshape((28,28,1), input_shape=(28,28)),
  tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3),data_format="channels_last"),
  # axis=3 is the channel axis of the channels-last conv output.
  tf.keras.layers.BatchNormalization(axis=3),
  tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.summary()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Train on the training split and report loss/accuracy on the held-out test
# split. Fitting directly on (x_test, y_test) would turn the reported accuracy
# into a training metric measured on the test data.
model.fit(x_train, y_train, validation_data=(x_test, y_test))

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 10)        100       
_________________________________________________________________
batch_normalization (BatchNo (None, 26, 26, 10)        40        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 24, 24, 10)        910       
_________________________________________________________________
flatten (Flatten)            (None, 5760)              0         
_________________________________________________________________
dense (Dense)                (None, 128)               737408    
_________________________________________________________________
dropout (Dropout)            (None, 128)               0

<tensorflow.python.keras.callbacks.History at 0x7f6d31a1ce10>

In [None]:
#Instance 2 - Only GroupNorm
# Two-convolution classifier with a single GroupNormalization layer placed
# after the first convolution.
model = tf.keras.models.Sequential([
  # Reshape into "channels last" setup.
  tf.keras.layers.Reshape((28,28,1), input_shape=(28,28)),
  tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3),data_format="channels_last"),
  # 10 conv filters split into 5 groups; axis=3 is the channel axis.
  tfa.layers.GroupNormalization(groups=5, axis=3),
  tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.summary()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Train on the training split and report loss/accuracy on the held-out test
# split. Fitting directly on (x_test, y_test) would turn the reported accuracy
# into a training metric measured on the test data.
model.fit(x_train, y_train, validation_data=(x_test, y_test))

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 26, 26, 10)        100       
_________________________________________________________________
group_normalization (GroupNo (None, 26, 26, 10)        20        
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 24, 24, 10)        910       
_________________________________________________________________
flatten_1 (Flatten)          (None, 5760)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               737408    
_________________________________________________________________
dropout_1 (Dropout)          (None, 128)              

<tensorflow.python.keras.callbacks.History at 0x7f6d2eaf1668>

In [None]:
#Instance 3 - Batch and then Group
# Two-convolution classifier: BatchNormalization after the first convolution,
# GroupNormalization after the second. The layer order now follows the cell
# label (previously GroupNormalization came first, contradicting it).
model = tf.keras.models.Sequential([
  # Reshape into "channels last" setup.
  tf.keras.layers.Reshape((28,28,1), input_shape=(28,28)),
  tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3),data_format="channels_last"),
  # axis=3 is the channel axis of the channels-last conv output.
  tf.keras.layers.BatchNormalization(axis=3),
  tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3)),
  # 10 conv filters split into 5 groups.
  tfa.layers.GroupNormalization(groups=5, axis=3),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.summary()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Train on the training split and report loss/accuracy on the held-out test
# split. Fitting directly on (x_test, y_test) would turn the reported accuracy
# into a training metric measured on the test data.
model.fit(x_train, y_train, validation_data=(x_test, y_test))

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_2 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 26, 26, 10)        100       
_________________________________________________________________
group_normalization_1 (Group (None, 26, 26, 10)        20        
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 24, 24, 10)        910       
_________________________________________________________________
batch_normalization_1 (Batch (None, 24, 24, 10)        40        
_________________________________________________________________
flatten_2 (Flatten)          (None, 5760)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 128)              

<tensorflow.python.keras.callbacks.History at 0x7f6d2e8a4160>

In [None]:
#Instance 4 - Group Then Batch
# Two-convolution classifier: GroupNormalization after the first convolution,
# BatchNormalization after the second. The layer order now follows the cell
# label (previously BatchNormalization came first, contradicting it).
model = tf.keras.models.Sequential([
  # Reshape into "channels last" setup.
  tf.keras.layers.Reshape((28,28,1), input_shape=(28,28)),
  tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3),data_format="channels_last"),
  # 10 conv filters split into 5 groups; axis=3 is the channel axis.
  tfa.layers.GroupNormalization(groups=5, axis=3),
  tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3)),
  tf.keras.layers.BatchNormalization(axis=3),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.summary()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Train on the training split and report loss/accuracy on the held-out test
# split. Fitting directly on (x_test, y_test) would turn the reported accuracy
# into a training metric measured on the test data.
model.fit(x_train, y_train, validation_data=(x_test, y_test))

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape_3 (Reshape)          (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 26, 26, 10)        100       
_________________________________________________________________
batch_normalization_2 (Batch (None, 26, 26, 10)        40        
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 24, 24, 10)        910       
_________________________________________________________________
group_normalization_2 (Group (None, 24, 24, 10)        20        
_________________________________________________________________
flatten_3 (Flatten)          (None, 5760)              0         
_________________________________________________________________
dense_6 (Dense)              (None, 128)              

<tensorflow.python.keras.callbacks.History at 0x7f6d2decb470>

Final Results:
BatchNorm Alone -   loss: 0.4021 - accuracy: 0.8783  
GroupNorm Alone -   loss: 0.4263 - accuracy: 0.8714  
BatchNorm, GroupNorm -   loss: 0.4598 - accuracy: 0.8695  
GroupNorm, Batch Norm -  loss: 0.4683 - accuracy: 0.8658  