<a href="https://colab.research.google.com/github/Rishit-dagli/Gradient-Centralization-TensorFlow/blob/main/example/gctf_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GCTF MNIST

This notebook shows how to use the [`gradient-centralization-tf`](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow) Python package to train on the [Fashion MNIST](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/fashion_mnist) dataset available from [`tf.keras.datasets`](https://www.tensorflow.org/api_docs/python/tf/keras/datasets). Gradient Centralization (GC) is a simple and effective optimization technique for deep neural networks (DNNs), proposed by Yong et al. in the paper [Gradient Centralization: A New Optimization Technique for Deep Neural Networks](https://arxiv.org/abs/2004.01461). It can both speed up training and improve the final generalization performance of DNNs.

## A bit about GC

Gradient Centralization operates directly on gradients by centralizing the gradient vectors to have zero mean. The illustration below shows the GC operation on the gradient matrix/tensor of the weights in a fully-connected layer (left) and a convolutional layer (right): GC computes the mean of each column/slice of the gradient matrix/tensor and subtracts it, so that each column/slice has zero mean.
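As a minimal sketch of the operation in the figure, assuming NumPy and the Keras kernel layout (the last axis indexes output units or filters: `(fan_in, units)` for dense kernels, `(kh, kw, in_channels, out_channels)` for conv kernels), GC subtracts the mean taken over every axis except the last one. The name `centralize_gradient` is hypothetical and not part of the `gctf` API, and leaving rank-1 tensors such as biases untouched is an assumption of this sketch.

```python
import numpy as np


def centralize_gradient(grad: np.ndarray) -> np.ndarray:
    """Return grad with zero mean over every axis except the last one."""
    if grad.ndim <= 1:
        # Assumption: rank-1 tensors (e.g. biases) are not centralized.
        return grad
    axes = tuple(range(grad.ndim - 1))  # all axes except the output axis
    return grad - grad.mean(axis=axes, keepdims=True)


# Dense-layer gradient: 2 inputs x 2 output units; column means are 2 and 4.
dense_grad = np.array([[1.0, 2.0],
                       [3.0, 6.0]])
centered = centralize_gradient(dense_grad)  # [[-1, -2], [1, 2]]

# Conv-layer gradient: 3x3 kernel, 4 input channels, 8 filters.
conv_grad = np.random.default_rng(0).normal(size=(3, 3, 4, 8))
centered_conv = centralize_gradient(conv_grad)
print(np.allclose(centered_conv.mean(axis=(0, 1, 2)), 0.0))  # True
```

Each column (dense) or per-filter slice (conv) now sums to zero, while the tensor shape is unchanged.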

![](https://i.imgur.com/KitoO8J.png)

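GC can also be viewed as a projected gradient descent method with a constrained loss function. Geometrically, the gradient is projected onto the hyperplane $e^T(w - w^t) = 0$, where $w^t$ is the weight vector at iteration $t$ and $e = \frac{1}{\sqrt{M}}\mathbf{1}$ is the unit vector along the all-ones direction for a weight vector of length $M$; the projected gradient is then used to update the weight.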
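The equivalence between centralizing and projecting can be checked numerically. This is a minimal sketch, assuming NumPy and the projection matrix $P = I - ee^T$ with $e = \frac{1}{\sqrt{M}}\mathbf{1}$; the sizes and learning rate are arbitrary values chosen for illustration. It shows that applying $P$ to a gradient is the same as subtracting its column mean, and that a step along the projected gradient keeps $e^T(w - w^t) = 0$.

```python
import numpy as np

M = 4                             # length of each weight column (fan-in), illustrative
e = np.ones((M, 1)) / np.sqrt(M)  # unit vector along the all-ones direction
P = np.eye(M) - e @ e.T           # projects onto the hyperplane orthogonal to e

rng = np.random.default_rng(0)
grad = rng.normal(size=(M, 3))    # gradient of a (fan_in=4, units=3) dense kernel

projected = P @ grad
centralized = grad - grad.mean(axis=0, keepdims=True)

# Projection with P and column-mean subtraction are the same operation.
print(np.allclose(projected, centralized))  # True

# A step w^{t+1} = w^t - lr * projected satisfies e^T (w^{t+1} - w^t) = 0.
lr = 0.1
step = -lr * projected
print(np.allclose(e.T @ step, 0.0))         # True
```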

![](https://i.imgur.com/ekHhQv0.png)

## Setup

In [1]:
import tensorflow as tf
from time import time

### Install the package

For now the package is installed from a clone of the GitHub repository; this step will soon be replaced by `pip install`.

In [2]:
import os
from getpass import getpass
import urllib.parse

user = input('User name: ')
password = getpass('Password: ')
password = urllib.parse.quote(password, safe='')  # percent-encode every reserved character so the password is URL-safe
repo_name = input('Repo name: ')

cmd_string = 'git clone https://{0}:{1}@github.com/{0}/{2}.git'.format(user, password, repo_name)

os.system(cmd_string)
cmd_string, password = "", ""  # clear the credentials from these variables

User name: Rishit-dagli
Password: ··········
Repo name: Gradient-Centralization-TensorFlow


In [3]:
%cd Gradient-Centralization-TensorFlow/

# Install the package in editable (development) mode
!pip install -e .
import gctf

/content/Gradient-Centralization-TensorFlow
Obtaining file:///content/Gradient-Centralization-TensorFlow
Installing collected packages: gradient-centralization-tf
  Running setup.py develop for gradient-centralization-tf
Successfully installed gradient-centralization-tf


## Load the data and define the model

In [4]:
mnist = tf.keras.datasets.fashion_mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1]
training_images = training_images / 255.0
test_images = test_images / 255.0

# Model architecture: a fully-connected network on flattened 28x28 images
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(64, activation=tf.nn.relu),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(64, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),
])

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


## Train a model without `gctf`

Define a callback that records the wall-clock time of each training epoch.


In [5]:
class TimeHistory(tf.keras.callbacks.Callback):
    """Records the wall-clock duration of each epoch (in seconds) in self.times."""

    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time() - self.epoch_time_start)

In [7]:
time_callback_no_gctf = TimeHistory()

model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history_no_gctf = model.fit(training_images, training_labels, epochs=5, callbacks=[time_callback_no_gctf])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [9]:
print("gctf not used")
print(f"Execution time: {sum(time_callback_no_gctf.times)} s")
print(f"Accuracy: {history_no_gctf.history['accuracy'][-1]}")
print(f"Loss: {history_no_gctf.history['loss'][-1]}")

gctf not used
Execution time: 20.81371283531189 s
Accuracy: 0.8871999979019165
Loss: 0.30977925658226013


## Train a model with `gctf`

In [10]:
time_callback_gctf = TimeHistory()

model.compile(optimizer=gctf.optimizers.adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

history_gctf = model.fit(training_images, training_labels, epochs=5, callbacks=[time_callback_gctf])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
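Here `gctf.optimizers.adam()` is used as a drop-in optimizer. To make the mechanism concrete, the following is a minimal sketch that applies GC by hand in a custom training step with a stock `tf.keras.optimizers.Adam`; it is not the `gctf` implementation. The helper `centralize_gradient`, the small stand-in model, and the random data are hypothetical, and skipping rank-1 tensors (biases) is an assumption.

```python
import numpy as np
import tensorflow as tf


def centralize_gradient(grad):
    """Subtract the mean over every axis except the last (output) axis.

    Assumption: only tensors with rank > 1 (kernels) are centralized;
    rank-1 tensors such as biases are returned unchanged.
    """
    if grad is None or len(grad.shape) <= 1:
        return grad
    axes = list(range(len(grad.shape) - 1))
    return grad - tf.reduce_mean(grad, axis=axes, keepdims=True)


# Small stand-in model, only to exercise one training step.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()


def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    # The only GC-specific line: centralize each gradient before the update.
    grads = [centralize_gradient(g) for g in grads]
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


# Random stand-in batch shaped like Fashion MNIST images and labels.
rng = np.random.default_rng(0)
x = rng.random((32, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=(32,))
loss = train_step(x, y)
```

The optimizer itself is unchanged; GC only transforms the gradients between `tape.gradient` and `apply_gradients`, which is why it can be packaged as a wrapper around an existing optimizer.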


In [11]:
print("gctf used")
print(f"Execution time: {sum(time_callback_gctf.times)} s")
print(f"Accuracy: {history_gctf.history['accuracy'][-1]}")
print(f"Loss: {history_gctf.history['loss'][-1]}")

gctf used
Execution time: 19.401079177856445 s
Accuracy: 0.9035999774932861
Loss: 0.2570233643054962
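As written, both training runs use the same `model` object, and `model.compile()` does not re-initialize weights, so the gctf run would continue from the weights learned in the first run rather than start from scratch. One way to make the comparison like-for-like is to snapshot the initial weights and restore them before the second run. This is a minimal sketch assuming the standard Keras `get_weights`/`set_weights` API; the `w + 1.0` update is only a stand-in for a real training run.

```python
import numpy as np
import tensorflow as tf

# Same architecture as above; the explicit Input creates the weights
# immediately, so they can be snapshotted before any call to fit().
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(64, activation=tf.nn.relu),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(64, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),
])

# Snapshot of the freshly initialized weights (a list of NumPy arrays).
initial_weights = model.get_weights()

# Stand-in for the first training run: it changes every weight.
model.set_weights([w + 1.0 for w in initial_weights])

# compile() attaches a new optimizer but leaves the current weights in place...
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
still_trained = all(np.array_equal(w, w0 + 1.0)
                    for w, w0 in zip(model.get_weights(), initial_weights))

# ...so restore the snapshot explicitly before the second run.
model.set_weights(initial_weights)
```

In the notebook this would mean taking the snapshot right after the model is defined (adding `tf.keras.Input(shape=(28, 28))` as above, or calling `model.build((None, 28, 28))`, so the weights exist before the first `fit()`) and calling `model.set_weights(initial_weights)` before compiling with `gctf.optimizers.adam()`.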
