# Synthetic Gradients

### Resources:
Papers
> [Decoupled Neural Interfaces using Synthetic Gradients, Max Jaderberg et al., 2016](https://arxiv.org/abs/1608.05343)
> [Understanding Synthetic Gradients and Decoupled Neural Interfaces, Wojciech Marian Czarnecki et al., 2017](https://arxiv.org/abs/1703.00522)

Youtube
> [Synthetic Gradients Tutorial by Aurélien Géron](https://youtu.be/1z_Gv98-mkQ)

Github
> [github; jupyter notebook by Nitarshan Rajkumar](https://github.com/nitarshan/decoupled-neural-interfaces/)

In [75]:
# NOTE: this is a custom cell that contains the common imports I personally 
# use these may/may not be necessary for the following examples

# DL framework
import tensorflow as tf

from datetime import datetime

# common packages
import numpy as np
import os # handling file i/o
import sys
import math
import time # timing epochs
import random

# for ordered dict when building layer components
import collections

# plotting pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib import colors # making colors consistent
from mpl_toolkits.axes_grid1 import make_axes_locatable # colorbar helper


from imageio import imread # read image from disk
# + data augmentation
from scipy import ndimage
from scipy import misc


import pickle # manually saving best params
from sklearn.utils import shuffle # shuffling data batches
from tqdm import tqdm # display training progress bar

# const
SEED = 42

# Helper to make the output consistent
def reset_graph(seed=SEED):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)

# helper to create dirs if they don't already exist
def maybe_create_dir(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
        print("{} created".format(dir_path))
    else:
        print("{} already exists".format(dir_path))
    
def make_standard_dirs(saver=True, best_params=True, tf_logs=True):
    # `saver/` will hold tf saver files
    maybe_create_dir("saver")
    # `best_params/` will hold a serialized version of the best params
    # I like to keep this as a backup in case I run into issues with
    # the saver files
    maybe_create_dir("best_params")
    # `tf_logs/` will hold the logs that will be visable in tensorboard
    maybe_create_dir("tf_logs")

    
# set tf log level to supress messages, unless an error
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Important Version information
print("Python: {}".format(sys.version_info[:]))
print('TensorFlow: {}'.format(tf.__version__))

# Check if using GPU
if not tf.test.gpu_device_name():
    print('No GPU')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
    
reset_graph()

Python: (3, 5, 4, 'final', 0)
TensorFlow: 1.6.0-dev20180105
No GPU


In [4]:
make_standard_dirs()

saver already exists
best_params already exists
tf_logs already exists


In [5]:
# these two functions (get_model_params and restore_model_params) are 
# ad[a|o]pted from; 
# https://github.com/ageron/handson-ml/blob/master/11_deep_learning.ipynb
def get_model_params():
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    return {global_vars.op.name: value for global_vars, value in 
            zip(global_vars, tf.get_default_session().run(global_vars))}

def restore_model_params(model_params, g, sess):
    gvar_names = list(model_params.keys())
    assign_ops = {gvar_name: g.get_operation_by_name(gvar_name + "/Assign")
                  for gvar_name in gvar_names}
    init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
    feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
    sess.run(assign_ops, feed_dict=feed_dict)

# these two functions are used to manually save the best
# model params to disk
def save_obj(obj, name):
    with open('best_params/'+ name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    with open('best_params/' + name + '.pkl', 'rb') as f:
        return pickle.load(f)

## Dataset

In [None]:
# `!conda install -c conda-forge tqdm`
# from tqdm import tqdm # Used to display training progress bar

In [6]:
ROOT_DATA = "../../ROOT_DATA/"
DATA_DIR = "mnist_data"

MNIST_TRAINING_PATH = os.path.join(ROOT_DATA, DATA_DIR)
# ensure we have the correct directory
for _, _, files in os.walk(MNIST_TRAINING_PATH):
    files = sorted(files)
    for filename in files:
        print(filename)

t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz


In [7]:
from tensorflow.examples.tutorials.mnist import input_data
MNIST = input_data.read_data_sets(MNIST_TRAINING_PATH, one_hot=True)

Extracting ../../ROOT_DATA/mnist_data/train-images-idx3-ubyte.gz
Extracting ../../ROOT_DATA/mnist_data/train-labels-idx1-ubyte.gz
Extracting ../../ROOT_DATA/mnist_data/t10k-images-idx3-ubyte.gz
Extracting ../../ROOT_DATA/mnist_data/t10k-labels-idx1-ubyte.gz


In [8]:
# image dimensions (GLOBAL) - [MG_WIDTH x IMG_HEIGHT, CHANNELS]
# SQUARE_DIM = 299
# if SQUARE_DIM:
#     IMG_WIDTH = SQUARE_DIM
#     IMG_HEIGHT = SQUARE_DIM
# CHANNELS = 3
    
# ROOT_DIR = "../../dataset/record_holder"
# # ensure we have the correct directory
# for _, _, files in os.walk(ROOT_DIR):
#     files = sorted(files)
#     for filename in files:
#         print(filename)

### Read tf records

In [10]:
# GLOBAL_SET_TYPE = None

# def _parse_function(example_proto):
#     global GLOBAL_SET_TYPE
#     labelName = str(GLOBAL_SET_TYPE) + '/label'
#     featureName = str(GLOBAL_SET_TYPE) + '/image'
#     feature = {featureName: tf.FixedLenFeature([], tf.string),
#                labelName: tf.FixedLenFeature([], tf.int64)}
    
#     # decode
#     parsed_features = tf.parse_single_example(example_proto, features=feature)
    
#     # convert image data from string to number
#     image = tf.decode_raw(parsed_features[featureName], tf.float32)
#     image = tf.reshape(image, [IMG_WIDTH, IMG_HEIGHT, CHANNELS])
#     label = tf.cast(parsed_features[labelName], tf.int64)
    
#     # [do any preprocessing here]
    
#     return image, label

# def return_batched_iter(setType, data_params, sess):
#     global GLOBAL_SET_TYPE
#     GLOBAL_SET_TYPE = setType
    
#     filenames_ph = tf.placeholder(tf.string, shape=[None])

#     dataset = tf.data.TFRecordDataset(filenames_ph)
#     dataset = dataset.map(_parse_function)  # Parse the record into tensors.
#     dataset = dataset.shuffle(buffer_size=data_params['buffer_size'])
#     dataset = dataset.batch(data_params['batch_size'])
#     dataset = dataset.repeat(data_params['n_epochs'])
    
#     iterator = dataset.make_initializable_iterator()
    
#     tfrecords_file_name = str(GLOBAL_SET_TYPE) + '.tfrecords'
#     tfrecord_file_path = os.path.join(FINAL_DIR, tfrecords_file_name)
    
#     # initialize
#     sess.run(iterator.initializer, feed_dict={filenames_ph: [tfrecord_file_path]})
    
#     return iterator

### setup

In [9]:
def create_hyper_params():
    data_params = {}
    data_params['n_epochs'] = 5
    data_params['batch_size'] = 16
    data_params['buffer_size'] = 128 # for shuffling

    data_params['init_lr'] = 1e-5
    #data_params['lr_div'] = 10
    #lr_low = int(data_params['n_epochs'] * 0.6)
    #lr_high = int(data_params['n_epochs'] * 0.8)
    #data_params['lr_div_steps'] = set([lr_low, lr_high])

    data_params['update_prob'] = 0.2 # Probability of updating a decoupled layer
    
    return data_params

validation_checkpoint = 1 # How often (epochs) to validate model

In [13]:
data_params = create_hyper_params()

In [11]:
# helpers for creating layers
def dense_layer(inputs, units, name, output=False):
    with tf.variable_scope(name):
        x = tf.layers.dense(inputs, units, name="fc")
        if not output:
            x = tf.layers.batch_normalization(x, name="bn")
            x = tf.nn.relu(x, name="activation_relu")
    return x

def sg_module(inputs, units, name, label):
    with tf.variable_scope(name):
        inputs_c = tf.concat([inputs, label], 1)
        x = tf.layers.dense(inputs_c, units, name="fc", kernel_initializer=tf.zeros_initializer())
    return x

In [76]:
reset_graph()
syntgrad_sess = tf.Session()
backprop_sess = tf.Session()

In [77]:
# network architecture
n_outputs = 10
with tf.variable_scope("architecture"):
    # inputs
    with tf.variable_scope("inputs"):
        #X = tf.placeholder(tf.float32, shape=[None, IMG_HEIGHT, IMG_WIDTH, CHANNELS], name="X")
        # labels
        #y_raw = tf.placeholder(tf.int64, shape=[None, n_outputs], name="y_input")
        #y_ = tf.cast(y_raw, tf.float32)
        X = tf.placeholder(tf.float32, shape=(None, 784), name="data") # Input
        y = tf.placeholder(tf.float32, shape=(None, n_outputs), name="labels") # Target
        
    # Inference Layers
    h_1 = dense_layer(X, 256, "layer_01")
    h_2 = dense_layer(h_1, 256, "layer_02")
    h_3 = dense_layer(h_2, 256, "layer_03")
    logits = dense_layer(h_3, n_outputs, name="layer_04", output=True)
    
    # Synthetic Gradient Layers
    sg_1 = sg_module(h_1, 256, "sg_02", y)
    sg_2 = sg_module(h_2, 256, "sg_03", y)
    sg_3 = sg_module(h_3, 256, "sg_04", y)
    
# collections of trainable variables in each block
layer_vars = [tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_01/"),
              tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_02/"),
              tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_03/"),
              tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_04/")]
sg_vars = [None,
           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/sg_02/"),
           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/sg_03/"),
           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/sg_04/")]

In [78]:
# optimize layer and the sythetic gradient module
def train_layer_n(n, h_m, h_n, sg_m, class_loss, d_n=None):
    with tf.variable_scope("layer"+str(n)):
        layer_grads = tf.gradients(h_n, [h_m]+layer_vars[n-1], d_n)
        layer_gv = list(zip(layer_grads[1:], layer_vars[n-1]))
        layer_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).apply_gradients(layer_gv)
    with tf.variable_scope("sg"+str(n)):
        d_m = layer_grads[0]
        sg_loss = tf.divide(tf.losses.mean_squared_error(sg_m, d_m), class_loss)
        sg_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(sg_loss, var_list=sg_vars[n-1])
    return layer_opt, sg_opt


# Ops: training
with tf.variable_scope("train"):
    with tf.variable_scope("learning_rate"):
        learning_rate = tf.Variable(data_params['init_lr'], dtype=tf.float32, name="lr")
        #reduce_lr = tf.assign(learning_Rate, learning_rate/lr_div, name="lr_decrease")
        
    pred_loss = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logits, scope="prediction_loss")
    
    # Optimizers when using synthetic gradients
    with tf.variable_scope("synthetic"):
        layer4_opt, sg4_opt = train_layer_n(4, h_3, pred_loss, sg_3, pred_loss)
        layer3_opt, sg3_opt = train_layer_n(3, h_2, h_3, sg_2, pred_loss, sg_3)
        layer2_opt, sg2_opt = train_layer_n(2, h_1, h_2, sg_1, pred_loss, sg_2)
        with tf.variable_scope("layer1"):
            layer1_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(h_1, var_list=layer_vars[0], grad_loss=sg_1)
        
    with tf.variable_scope("backprop"):
        backprop_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(pred_loss)
        
init_global = tf.global_variables_initializer()

In [79]:
# Ops: validation+testing
with tf.variable_scope("test"):
    preds = tf.nn.softmax(logits, name="predictions")
    correct_preds = tf.equal(tf.argmax(preds,1), tf.argmax(y,1), name="correct_predictions")
    accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32), name="correct_prediction_count") / data_params['batch_size']

In [80]:
# Ops: tensorboard
with tf.variable_scope("summary"):
    cost_summary_opt = tf.summary.scalar("loss", pred_loss)
    accuracy_summary_opt = tf.summary.scalar("accuracy", accuracy)
    summary_op = tf.summary.merge_all()

# Training

## Backprop (locked)

In [81]:
# backprop
with backprop_sess.as_default():
    backprop_train_path = os.path.join("tf_logs","backprop","train")
    backprop_train_writer = tf.summary.FileWriter(backprop_train_path)
    backprop_validation_path = os.path.join("tf_logs","backprop","validation")
    backprop_validation_writer = tf.summary.FileWriter(backprop_validation_path)
    
    backprop_sess.run(init_global)
    for i in tqdm(range(1,data_params['n_epochs']+1)):
        data, target = MNIST.train.next_batch(data_params['batch_size'])
        _, summary = backprop_sess.run([backprop_opt, summary_op], feed_dict={X:data, y:target})
        backprop_train_writer.add_summary(summary, i)
        
        # run validation
        # TODO: update this to validation
        Xb, yb = MNIST.test.next_batch(data_params['batch_size'])
        summary = backprop_sess.run([summary_op], feed_dict={X:Xb, y:yb})[0]
        backprop_validation_writer.add_summary(summary, i)
        
    # close writers
    backprop_train_writer.close()
    backprop_validation_writer.close()

100%|██████████| 5/5 [00:00<00:00, 53.10it/s]


In [83]:
with syntgrad_sess.as_default():
    sg_train_path = os.path.join("tf_logs","sg","train")
    sg_train_writer = tf.summary.FileWriter(sg_train_path, syntgrad_sess.graph)
    sg_validation_path = os.path.join("tf_logs","sg","train")
    sg_validation_writer = tf.summary.FileWriter(sg_validation_path)
    
    syntgrad_sess.run(init_global)
    for i in tqdm(range(1,data_params['n_epochs']+1)):
        data, target = MNIST.train.next_batch(data_params['batch_size'])
        
        # The layers here could be independently updated (data parallism) - device placement
        # > stochastic updates are possible
        if random.random() <= data_params['update_prob']:
            syntgrad_sess.run([layer1_opt], feed_dict={X:data, y:target})
        if random.random() <= data_params['update_prob']:
            syntgrad_sess.run([layer2_opt, sg2_opt], feed_dict={X:data, y:target})
        if random.random() <= data_params['update_prob']:
            syntgrad_sess.run([layer3_opt, sg3_opt], feed_dict={X:data, y:target})
        if random.random() <= data_params['update_prob']:
            _, _, summary = syntgrad_sess.run([layer4_opt, sg4_opt, summary_op], feed_dict={X:data, y:target})
            sg_train_writer.add_summary(summary, i)
            
        # validation
        # TODO: convert to validation
        Xb, yb = MNIST.test.next_batch(data_params['batch_size'])
        summary = syntgrad_sess.run([summary_op], feed_dict={X:Xb, y:yb})[0]
        sg_validation_writer.add_summary(summary, i)
        
    sg_train_writer.close()
    sg_validation_writer.close()

100%|██████████| 5/5 [00:00<00:00, 26.72it/s]


In [92]:
# Test using backprop
with backprop_sess.as_default():
    n_batches = int(MNIST.test.num_examples/data_params['batch_size'])
    test_accuracy = 0
    test_loss = 0
    for _ in range(n_batches):
        Xb, yb = MNIST.test.next_batch(data_params['batch_size'])
        batch_accuracy, batch_loss = backprop_sess.run([accuracy, pred_loss], feed_dict={X:Xb,y:yb})
        test_accuracy += batch_accuracy
        test_loss += batch_loss
    print("loss: {:.4f}".format(test_loss/n_batches))
    print("acc: {:.4f}".format(test_accuracy/n_batches))

loss: 2.3280
acc: 0.0928


In [93]:
# Test using backprop
with syntgrad_sess.as_default():
    n_batches = int(MNIST.test.num_examples/data_params['batch_size'])
    test_accuracy = 0
    test_loss = 0
    for _ in range(n_batches):
        Xb, yb = MNIST.test.next_batch(data_params['batch_size'])
        batch_accuracy, batch_loss = backprop_sess.run([accuracy, pred_loss], feed_dict={X:Xb,y:yb})
        test_accuracy += batch_accuracy
        test_loss += batch_loss
    print("loss: {:.4f}".format(test_loss/n_batches))
    print("acc: {:.4f}".format(test_accuracy/n_batches))

loss: 2.3293
acc: 0.0914


In [95]:
# Cleanup
backprop_sess.close()
syntgrad_sess.close()