<a href="https://colab.research.google.com/github/PCSchmidt/My-Projects/blob/main/TutorialOnGradientCentralization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Use of Gradient Centralization for Better Training Performance

This is a tutorial on the new optimization technique, [Gradient Centralization](https://arxiv.org/abs/2004.01461), for deep neural networks developed by Hongwei Yong, et all that operates directly on gradients by centralizing the gradient vectors to have mean zero. Viewed as a projected gradient descent method with a constrained loss function, gradient centralization can regularize both the weight space and output feature space to boost the generalization performance of deep neural networks. Gradient centralization improves the Lipschitzness of the loss function and its gradient so that the training process becomes more efficient and stable. 

If `tensorflow_datasets` is needed it can be installed using the
```
pip install tensorflow-datasets
```
command.  This tutorial builds a Gradient Centralization example but there is a package to speed up the process available at [gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow).

In [15]:
pip install tensorflow-datasets



## The Setup

In [16]:
from time import time

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
from tensorflow.keras.optimizers import RMSprop

## Prepare the Data

For this example, we use the [Horses or Humans
dataset](https://www.tensorflow.org/datasets/catalog/horses_or_humans) put together by Laurence Moroney.

In [17]:
num_classes = 2
input_shape = (300, 300, 3)
dataset_name = "horses_or_humans"
batch_size = 128
AUTOTUNE = tf.data.experimental.AUTOTUNE

(train_ds, test_ds), metadata = tfds.load(
    name=dataset_name,
    split=[tfds.Split.TRAIN, tfds.Split.TEST],
    with_info=True,
    as_supervised=True,
)

print(f"Image shape: {metadata.features['image'].shape}")
print(f"Training images: {metadata.splits['train'].num_examples}")
print(f"Test images: {metadata.splits['test'].num_examples}")

Image shape: (300, 300, 3)
Training images: 1027
Test images: 256


## Use Data Augmentation

We rescale the data to `[0, 1]` and perform simple augmentations to our data. 

In [18]:
rescale = layers.experimental.preprocessing.Rescaling(1.0 / 255)

data_augmentation = tf.keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
        layers.experimental.preprocessing.RandomRotation(0.3),
        layers.experimental.preprocessing.RandomZoom(0.2),
    ]
)


def prepare(ds, shuffle=False, augment=False):
    # Rescale dataset
    ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1024)

    # Batch dataset
    ds = ds.batch(batch_size)

    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(
            lambda x, y: (data_augmentation(x, training=True), y),
            num_parallel_calls=AUTOTUNE,
        )

    # Use buffered prefecting
    return ds.prefetch(buffer_size=AUTOTUNE)


Rescale and augment the data

In [19]:
train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)

## Define a Model

This next section defines a convolutional neural network with 16 layers. 

In [20]:
model = tf.keras.Sequential(
    [
        layers.Conv2D(16, (3, 3), activation="relu", input_shape=(300, 300, 3)),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(32, (3, 3), activation="relu"),
        layers.Dropout(0.5),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.Dropout(0.5),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(512, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ]
)

## Implement Gradient Centralization

We will now subclass the `RMSprop` optimizer class modifying the `tf.keras.optimizers.Optimizer.get_gradients()` method where we can implement Gradient Centralization. On a high level the idea is that we obtain the gradients through back  propagation for a Dense or Convolution layer then compute the mean of the column vectors of the weight matrix and then remove the mean from each column vector. 

The experiments in this [this paper](https://arxiv.org/abs/2004.01461) on various
applications, including general image classification, fine-grained image classification,
detection and segmentation and Person ReID demonstrate that GC can consistently improve
the performance of DNN learning.

To keep things at a less complex level, this tutorial does not implement gradient clippling functionality, however this is quite easy to do.

We are just creating a subclass for the `RMSprop` optimizer but this can easily be reproduced for any other optimizer or on a custom optimizer in the same way. We will be using this class in the later section when we train a model with Gradient Centralization. 

In [21]:
class GCRMSprop(RMSprop):
  def get_gradients(self, loss, params):
    # Only need to provide a modfified get_gradients() function since we are
    # trying to only compute the centralized gradients.

    grads = []
    gradients = super().get_gradients()
    for grad in gradients:
      grad_len = len(grad.shape)
      if grad_len > 1:
        axis = list(range(grad_len -1))
        grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)
      grads.append(grad)

    return grads  

optimizer = GCRMSprop(learning_rate=1e-4)


## Training Utilities

We also create a callback which allows us to measure the total training time and the time taken for each epoch so that we have the tools to measure the impact of implementing Gradient Centralization on the model built above. 

In [22]:
class TimeHistory(tf.keras.callbacks.Callback):
  def on_train_begin(self, logs={}):
    self.times = []
  
  def on_epoch_begin(self, batch, logs={}):
    self.epoch_time_start = time()

  def on_epoch_end(self, batch, logs={}):
    self.times.append(time() - self.epoch_time_start)
    

## Train the Model without Gradient Centralization

First, we train the model without Gradient Centralization which we can then compare to the model trained with Gradient Centralization.  

In [23]:
time_callback_no_gc = TimeHistory()

model.compile(
    loss="binary_crossentropy",
    optimizer=RMSprop(learning_rate=1e-4),
    metrics=['accuracy'],
)

model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_5 (Conv2D)            (None, 298, 298, 16)      448       
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 149, 149, 16)      0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 147, 147, 32)      4640      
_________________________________________________________________
dropout_3 (Dropout)          (None, 147, 147, 32)      0         
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 73, 73, 32)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 71, 71, 64)        18496     
_________________________________________________________________
dropout_4 (Dropout)          (None, 71, 71, 64)       

We need to save the history so that we can compare the models with and without Gradient Centralization.

In [24]:
history_no_gc = model.fit(
    train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


## Train the Model with Gradient Centralization

Next, we train the same model with Gradient Centralization. Our optimizer is the one that uses gradient centralization this time. 

In [25]:
time_callback_gc = TimeHistory()

model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=['accuracy']) # optimizer = GCRMSprop(learning_rate=1e-4)

model.summary()

history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])


Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_5 (Conv2D)            (None, 298, 298, 16)      448       
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 149, 149, 16)      0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 147, 147, 32)      4640      
_________________________________________________________________
dropout_3 (Dropout)          (None, 147, 147, 32)      0         
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 73, 73, 32)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 71, 71, 64)        18496     
_________________________________________________________________
dropout_4 (Dropout)          (None, 71, 71, 64)       

##  Comparing Performance

We can now compare the performance of the two approaches. 

In [26]:
  print("Not using Gradient Centralization")
  print(f"Loss: {history_no_gc.history['loss'][-1]}")
  print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
  print(f"Training Time: {sum(time_callback_no_gc.times)}")

  print("Using Gradient Centralization")
  print(f"Loss: {history_gc.history['loss'][-1]}")
  print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
  print(f"Training Time: {sum(time_callback_gc.times)}")

Not using Gradient Centralization
Loss: 0.4718668758869171
Accuracy: 0.7750730514526367
Training Time: 165.5166256427765
Using Gradient Centralization
Loss: 0.3393610119819641
Accuracy: 0.8646543622016907
Training Time: 165.13558745384216
