# GCTF MNIST

This notebook shows the the process of using the [`gradient-centralization-tf`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow) Python package to train on the [Fashion MNIST](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/fashion_mnist) dataset availaible from [`tf.keras.datasets`](https://www.tensorflow.org/api_docs/python/tf/keras/datasets).Gradient Centralization is a simple and effective optimization technique for Deep Neural Networks as suggested by Yong et al. in the paper 
[Gradient Centralization: A New Optimization Technique for Deep Neural Networks](https://arxiv.org/abs/2004.01461). It can both speedup training 
 process and improve the final generalization performance of DNNs.

## A bit about GC

Gradient Centralization operates directly on gradients by centralizing the gradient vectors to have zero mean. It can both speedup training process and improve the final generalization performance of DNNs. Here is an Illustration of the GC operation on gradient matrix/tensor of weights in the fully-connected layer (left) and convolutional layer (right). GC computes the column/slice mean of gradient matrix/tensor and centralizes each column/slice to have zero mean.

![](https://i.imgur.com/KitoO8J.png)

GC can be viewed as a projected gradient descent method with a constrained loss function. The geometrical interpretation of GC. The gradient is projected on a hyperplane $e^T(w-w^t)=0$, where the projected gradient is used to update the weight.

![](https://i.imgur.com/ekHhQv0.png)

## Setup

In [1]:
import tensorflow as tf
from time import time

### Install the package

In [None]:
!pip install gradient-centralization-tf

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gradient-centralization-tf
  Downloading gradient_centralization_tf-0.0.3-py3-none-any.whl (10 kB)
Collecting keras~=2.4.0
  Downloading Keras-2.4.3-py2.py3-none-any.whl (36 kB)
Collecting tensorflow>=2.2.0
  Downloading tensorflow-2.11.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (588.3 MB)
[K     |████████████████████████████████| 588.3 MB 6.4 kB/s 
[?25hCollecting flatbuffers>=2.0
  Downloading flatbuffers-22.10.26-py2.py3-none-any.whl (26 kB)
Collecting tensorflow-estimator<2.12,>=2.11.0
  Downloading tensorflow_estimator-2.11.0-py2.py3-none-any.whl (439 kB)
[K     |████████████████████████████████| 439 kB 68.7 MB/s 
Collecting tensorflow>=2.2.0
  Downloading tensorflow-2.10.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (578.1 MB)
[K     |████████████████████████████████| 578.1 MB 29 kB/s 
[?25h  Downloading tensorflow-2.10.0-cp37-cp37m-man

## Get the data and create model structure

In [None]:
mnist = tf.keras.datasets.fashion_mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()
training_images  = training_images / 255.0
test_images = test_images / 255.0

# Model architecture
model = tf.keras.models.Sequential([
                                    tf.keras.layers.Flatten(), 
                                    tf.keras.layers.Dense(512, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(256, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(64, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(512, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(256, activation=tf.nn.relu),
                                    tf.keras.layers.Dense(64, activation=tf.nn.relu), 
                                    tf.keras.layers.Dense(10, activation=tf.nn.softmax)])

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


## Train a model without `gctf`

Make a Callback to compute computation time


In [None]:
class TimeHistory(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.times = []

    def on_epoch_begin(self, batch, logs={}):
        self.epoch_time_start = time()

    def on_epoch_end(self, batch, logs={}):
        self.times.append(time() - self.epoch_time_start)

In [None]:
time_callback_no_gctf = TimeHistory()

model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss = 'sparse_categorical_crossentropy',
              metrics = ['accuracy'])

history_no_gctf = model.fit(training_images, training_labels, epochs=5, callbacks = [time_callback_no_gctf])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


## Train a model with `gctf`

In [None]:
import gctf #import gctf

time_callback_gctf = TimeHistory()

model.compile(optimizer = gctf.optimizers.adam(),
              loss = 'sparse_categorical_crossentropy',
              metrics=['accuracy'])

history_gctf = model.fit(training_images, training_labels, epochs=5, callbacks=[time_callback_gctf])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


## Compare results

In this example we are further interested in also comparing the results

In [None]:
#Compare Results
from tabulate import tabulate

data = [["Model without gctf:",sum(time_callback_no_gctf.times),history_no_gctf.history['accuracy'][-1],history_no_gctf.history['loss'][-1]],
        ["Model with gctf",sum(time_callback_gctf.times),history_gctf.history['accuracy'][-1],history_gctf.history['loss'][-1]]] 

print(tabulate(data, headers=["Type","Execution time", "Accuracy", "Loss"]))

Type                   Execution time    Accuracy      Loss
-------------------  ----------------  ----------  --------
Model without gctf:            20.183    0.887617  0.310299
Model with gctf                18.464    0.916467  0.22555
