# Synthetic Gradients

### Resources:
Papers
> [Decoupled Neural Interfaces using Synthetic Gradients, Max Jaderberg et al., 2016](https://arxiv.org/abs/1608.05343)
> [Understanding Synthetic Gradients and Decoupled Neural Interfaces, Wojciech Marian Czarnecki et al., 2017](https://arxiv.org/abs/1703.00522)

Youtube
> [Synthetic Gradients Tutorial by Aurélien Géron](https://youtu.be/1z_Gv98-mkQ)

Github
> [github; jupyter notebook by Nitarshan Rajkumar](https://github.com/nitarshan/decoupled-neural-interfaces/)

In [2]:
# NOTE: this is a custom cell that contains the common imports I personally 
# use these may/may not be necessary for the following examples

# DL framework
import tensorflow as tf

from datetime import datetime

# common packages
import numpy as np
import os # handling file i/o
import sys
import math
import time # timing epochs
import random

# for ordered dict when building layer components
import collections

# plotting pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import pyplot
from matplotlib import colors # making colors consistent
from mpl_toolkits.axes_grid1 import make_axes_locatable # colorbar helper


# from imageio import imread # read image from disk
# + data augmentation
from scipy import ndimage
from scipy import misc


import pickle # manually saving best params
from sklearn.utils import shuffle # shuffling data batches
from tqdm import tqdm # display training progress bar

# const
SEED = 42

# Helper to make the output consistent
def reset_graph(seed=SEED):
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)

# helper to create dirs if they don't already exist
def maybe_create_dir(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
        print("{} created".format(dir_path))
    else:
        print("{} already exists".format(dir_path))
    
def make_standard_dirs(saver=True, best_params=True, tf_logs=True):
    # `saver/` will hold tf saver files
    maybe_create_dir("saver")
    # `best_params/` will hold a serialized version of the best params
    # I like to keep this as a backup in case I run into issues with
    # the saver files
    maybe_create_dir("best_params")
    # `tf_logs/` will hold the logs that will be visable in tensorboard
    maybe_create_dir("tf_logs")

    
# set tf log level to supress messages, unless an error
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Important Version information
print("Python: {}".format(sys.version_info[:]))
print('TensorFlow: {}'.format(tf.__version__))

# Check if using GPU
if not tf.test.gpu_device_name():
    print('No GPU')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))
    
reset_graph()

Python: (3, 5, 4, 'final', 0)
TensorFlow: 1.4.0
Default GPU Device: /device:GPU:0


In [3]:
make_standard_dirs()

saver already exists
best_params already exists
tf_logs already exists


In [4]:
### Clean all logs
## WARNING! You likely don't want to do this (but if you do, this is a convenient call)
# !rm -r -f ./tf_logs/*

In [5]:
# these two functions (get_model_params and restore_model_params) are 
# ad[a|o]pted from; 
# https://github.com/ageron/handson-ml/blob/master/11_deep_learning.ipynb
def get_model_params():
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    return {global_vars.op.name: value for global_vars, value in 
            zip(global_vars, tf.get_default_session().run(global_vars))}

def restore_model_params(model_params, g, sess):
    gvar_names = list(model_params.keys())
    assign_ops = {gvar_name: g.get_operation_by_name(gvar_name + "/Assign")
                  for gvar_name in gvar_names}
    init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
    feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
    sess.run(assign_ops, feed_dict=feed_dict)

# these two functions are used to manually save the best
# model params to disk
def save_obj(obj, name):
    with open('best_params/'+ name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    with open('best_params/' + name + '.pkl', 'rb') as f:
        return pickle.load(f)

## Dataset

In [6]:
# `!conda install -c conda-forge tqdm`
# from tqdm import tqdm # Used to display training progress bar

In [7]:
ROOT_DATA = "../../ROOT_DATA/"
DATA_DIR = "mnist_data"

MNIST_TRAINING_PATH = os.path.join(ROOT_DATA, DATA_DIR)
# ensure we have the correct directory
for _, _, files in os.walk(MNIST_TRAINING_PATH):
    files = sorted(files)
    for filename in files:
        print(filename)

t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz


In [8]:
from tensorflow.examples.tutorials.mnist import input_data
MNIST = input_data.read_data_sets(MNIST_TRAINING_PATH, one_hot=True)

Extracting ../../ROOT_DATA/mnist_data/train-images-idx3-ubyte.gz
Extracting ../../ROOT_DATA/mnist_data/train-labels-idx1-ubyte.gz
Extracting ../../ROOT_DATA/mnist_data/t10k-images-idx3-ubyte.gz
Extracting ../../ROOT_DATA/mnist_data/t10k-labels-idx1-ubyte.gz


In [9]:
# image dimensions (GLOBAL) - [MG_WIDTH x IMG_HEIGHT, CHANNELS]
#SQUARE_DIM = 299
SQUARE_DIM = 28
if SQUARE_DIM:
    IMG_WIDTH = SQUARE_DIM
    IMG_HEIGHT = SQUARE_DIM
#CHANNELS = 3
CHANNELS = 1
    
# ROOT_DIR = "../../dataset/record_holder"
# # ensure we have the correct directory
# for _, _, files in os.walk(ROOT_DIR):
#     files = sorted(files)
#     for filename in files:
#         print(filename)

### Read tf records

In [10]:
# GLOBAL_SET_TYPE = None

# def _parse_function(example_proto):
#     global GLOBAL_SET_TYPE
#     labelName = str(GLOBAL_SET_TYPE) + '/label'
#     featureName = str(GLOBAL_SET_TYPE) + '/image'
#     feature = {featureName: tf.FixedLenFeature([], tf.string),
#                labelName: tf.FixedLenFeature([], tf.int64)}
    
#     # decode
#     parsed_features = tf.parse_single_example(example_proto, features=feature)
    
#     # convert image data from string to number
#     image = tf.decode_raw(parsed_features[featureName], tf.float32)
#     image = tf.reshape(image, [IMG_WIDTH, IMG_HEIGHT, CHANNELS])
#     label = tf.cast(parsed_features[labelName], tf.int64)
    
#     # [do any preprocessing here]
    
#     return image, label

# def return_batched_iter(setType, data_params, sess):
#     global GLOBAL_SET_TYPE
#     GLOBAL_SET_TYPE = setType
    
#     filenames_ph = tf.placeholder(tf.string, shape=[None])

#     dataset = tf.data.TFRecordDataset(filenames_ph)
#     dataset = dataset.map(_parse_function)  # Parse the record into tensors.
#     dataset = dataset.shuffle(buffer_size=data_params['buffer_size'])
#     dataset = dataset.batch(data_params['batch_size'])
#     dataset = dataset.repeat(data_params['n_epochs'])
    
#     iterator = dataset.make_initializable_iterator()
    
#     tfrecords_file_name = str(GLOBAL_SET_TYPE) + '.tfrecords'
#     tfrecord_file_path = os.path.join(FINAL_DIR, tfrecords_file_name)
    
#     # initialize
#     sess.run(iterator.initializer, feed_dict={filenames_ph: [tfrecord_file_path]})
    
#     return iterator

### setup

In [11]:
def create_hyper_params():
    data_params = {}
    data_params['n_epochs'] = 20
    data_params['batch_size'] = 512
    data_params['buffer_size'] = 128 # for shuffling

    data_params['init_lr'] = 1e-5
    #data_params['lr_div'] = 10
    #lr_low = int(data_params['n_epochs'] * 0.6)
    #lr_high = int(data_params['n_epochs'] * 0.8)
    #data_params['lr_div_steps'] = set([lr_low, lr_high])

    data_params['update_prob'] = 0.2 # Probability of updating a decoupled layer
    
    return data_params

validation_checkpoint = 1 # How often (epochs) to validate model

In [12]:
data_params = create_hyper_params()

In [13]:
# helpers for creating layers
def dense_layer(inputs, units, name, output=False):
    with tf.variable_scope(name):
        x = tf.layers.dense(inputs, units, name="fc")
        if not output:
            x = tf.layers.batch_normalization(x, name="bn")
            x = tf.nn.relu(x, name="activation_relu")
    return x

def sg_module(inputs, units, name, label):
    with tf.variable_scope(name):
        inputs_c = tf.concat([inputs, label], 1)
        x = tf.layers.dense(inputs_c, units, name="fc", kernel_initializer=tf.zeros_initializer())
    return x


def create_conv_layer(inputs, units, kern_shape, name, stride_len, output=False):
    with tf.variable_scope(name):
        x = tf.layers.conv2d(inputs, filters=units, kernel_size=kern_shape, 
                             padding='SAME', strides=stride_len,
                             name="conv")
        if not output:
            x = tf.nn.relu(x, name="activation_relu")
    return x


def create_conv_sg_module(inputs, units, kern_shape, name, label, stride_len):
    with tf.variable_scope(name):
        # TODO: really unsure about this concat here.
        # using in attempt to adapt: https://github.com/vyraun/DNI-tensorflow/blob/master/utils.py
        out_shape = inputs.get_shape().as_list()
        label_shape = label.get_shape().as_list()
        #label_tile = tf.reshape(tf.tile(label, [1,out_shape[1]*out_shape[2]]), 
         #                       [out_shape[0], out_shape[1], out_shape[2], label_shape[1]])
        #inputs_c = tf.concat(3, [inputs, label_tile])
        #inputs_c = tf.concat(3, [inputs, label])
        x = tf.layers.conv2d(inputs, filters=units, kernel_size=kern_shape, 
                             padding='SAME', strides=stride_len,
                             name="conv")
    return x

In [14]:
reset_graph()
syntgrad_sess = tf.Session()
backprop_sess = tf.Session()

In [15]:
# network architecture
n_outputs = 10
with tf.variable_scope("architecture"):
    # inputs
    with tf.variable_scope("inputs"):
        #X = tf.placeholder(tf.float32, shape=[None, IMG_HEIGHT, IMG_WIDTH, CHANNELS], name="X")
        # labels
        X = tf.placeholder(tf.float32, shape=(None, 784), name="data") # Input
        X_reshaped = tf.reshape(X, shape=[-1, IMG_HEIGHT, IMG_WIDTH, CHANNELS])
        y = tf.placeholder(tf.float32, shape=(None, n_outputs), name="labels") # Target
        
    # Inference Layers
    # conv
    h_1 = create_conv_layer(X_reshaped, units=12, kern_shape=3, name="layer_01", stride_len=1)
    h_2 = create_conv_layer(h_1, units=24, kern_shape=3, name="layer_02", stride_len=1)
    h_3 = create_conv_layer(h_2, units=36, kern_shape=3, name="layer_03", stride_len=2)
    # pooling
    h_4 = tf.layers.max_pooling2d(h_3, pool_size=[2,2], strides=2, name="max_pool_01")
    last_shape = int(np.prod(h_4.get_shape()[1:]))
    pool_flat = tf.reshape(h_4, shape=[-1, last_shape])
    
    # fc
    h_5 = dense_layer(pool_flat, 64, "layer_05")
    logits = dense_layer(h_5, n_outputs, name="layer_06", output=True)
    
    # Synthetic Gradient Layers
    sg_1 = create_conv_sg_module(h_1, units=12, kern_shape=3, name="sg_02", label=y, stride_len=1)
    sg_2 = create_conv_sg_module(h_2, units=24, kern_shape=3, name="sg_03", label=y, stride_len=1)
    sg_4 = create_conv_sg_module(h_3, units=36, kern_shape=3, name="sg_05", label=y, stride_len=2)
    sg_5 = sg_module(pool_flat, 64, "sg_06", y)
    
# collections of trainable variables in each block
layer_vars = [tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_01/"),
              tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_02/"),
              tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_03/"),
              None,
              tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_05/"),
              tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/layer_06/")]
sg_vars = [None,
           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/sg_02/"),
           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/sg_03/"),
           None,
           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/sg_05/"),
           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="architecture/sg_06/")]

In [16]:
# optimize layer and the sythetic gradient module
def train_layer_n(n, h_m, h_n, sg_m, class_loss, d_n=None):
    with tf.variable_scope("layer_0"+str(n)):
        layer_grads = tf.gradients(h_n, [h_m]+layer_vars[n-1], d_n)
        layer_gv = list(zip(layer_grads[1:], layer_vars[n-1]))
        layer_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).apply_gradients(layer_gv)
    with tf.variable_scope("sg_0"+str(n)):
        d_m = layer_grads[0]
        sg_loss = tf.divide(tf.losses.mean_squared_error(sg_m, d_m), class_loss)
        sg_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(sg_loss, var_list=sg_vars[n-1])
    return layer_opt, sg_opt

# Ops: training
with tf.variable_scope("train"):
    with tf.variable_scope("learning_rate"):
        learning_rate = tf.Variable(data_params['init_lr'], dtype=tf.float32, name="lr")
        #reduce_lr = tf.assign(learning_Rate, learning_rate/lr_div, name="lr_decrease")
        
    #pred_loss = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logits, scope="prediction_loss")
    with tf.variable_scope("prediction_loss"):
        #pred_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y)
        pred_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
        batch_loss = tf.reduce_mean(pred_loss)
    
    # Optimizers when using synthetic gradients
    with tf.variable_scope("synthetic"):
        layer6_opt, sg6_opt = train_layer_n(6, h_5, pred_loss, sg_5, pred_loss)
        layer5_opt, sg5_opt = train_layer_n(5, h_4, h_5, sg_4, pred_loss, sg_5)
        # none
        layer3_opt, sg3_opt = train_layer_n(3, h_2, h_3, sg_2, pred_loss, h_3)
        layer2_opt, sg2_opt = train_layer_n(2, h_1, h_2, sg_1, pred_loss, sg_2)
        with tf.variable_scope("layer_01"):
            layer1_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(h_1, var_list=layer_vars[0], grad_loss=sg_1)
        
    with tf.variable_scope("backprop"):
        backprop_opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(pred_loss)
        
init_global = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

In [17]:
# Ops: training metrics
with tf.name_scope("metrics") as scope:
    with tf.name_scope("train_metrics") as scope:
        preds = tf.nn.softmax(logits, name="prediction")
        train_y_true_cls = tf.argmax(y,1)
        train_y_pred_cls = tf.argmax(preds,1)

        train_correct_prediction = tf.equal(train_y_pred_cls, train_y_true_cls, name="correct_predictions")
        train_batch_acc = tf.reduce_mean(tf.cast(train_correct_prediction, tf.float32))

        train_auc, train_auc_update = tf.metrics.auc(labels=y, predictions=preds)
        train_acc, train_acc_update = tf.metrics.accuracy(labels=train_y_true_cls, predictions=train_y_pred_cls)

        train_acc_vars = tf.contrib.framework.get_variables(scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
        train_acc_reset_op = tf.variables_initializer(train_acc_vars, name="train_acc_reset_op")

        #for node in (y_, preds, train_y_true_cls, train_y_pred_cls, correct_prediction, train_batch_acc):
                #g.add_to_collection("label_nodes", node)

    # Ops: validation metrics
    with tf.name_scope("validation_metrics") as scope:
        preds = tf.nn.softmax(logits, name="prediction")
        val_y_true_cls = tf.argmax(y,1)
        val_y_pred_cls = tf.argmax(preds,1)        

        val_correct_prediction = tf.equal(val_y_pred_cls, val_y_true_cls)
        val_batch_acc = tf.reduce_mean(tf.cast(val_correct_prediction, tf.float32))

        val_auc, val_auc_update = tf.metrics.auc(labels=y, predictions=preds)
        val_acc, val_acc_update = tf.metrics.accuracy(labels=val_y_true_cls, predictions=val_y_pred_cls)

        val_acc_vars = tf.contrib.framework.get_variables(scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
        val_acc_reset_op = tf.variables_initializer(val_acc_vars, name="val_acc_reset_op")

    # Ops: test metrics
    with tf.name_scope("test_metrics") as scope:    
        preds = tf.nn.softmax(logits, name="prediction")
        test_y_true_cls = tf.argmax(y,1)
        test_y_pred_cls = tf.argmax(preds,1)

        test_auc, test_auc_update = tf.metrics.auc(labels=y, predictions=preds)

        test_acc, test_acc_update = tf.metrics.accuracy(labels=test_y_true_cls, predictions=test_y_pred_cls)
        test_acc_vars = tf.contrib.framework.get_variables(scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
        test_acc_reset_op = tf.variables_initializer(test_acc_vars, name="test_acc_reset_op")

    # =============================================== loss 
    with tf.name_scope("train_loss_eval") as scope:
        train_mean_loss, train_mean_loss_update = tf.metrics.mean(batch_loss)
        train_loss_vars = tf.contrib.framework.get_variables(scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
        train_loss_reset_op = tf.variables_initializer(train_loss_vars, name="train_loss_reset_op")

    with tf.name_scope("val_loss_eval") as scope:
        val_mean_loss, val_mean_loss_update = tf.metrics.mean(batch_loss)
        val_loss_vars = tf.contrib.framework.get_variables(scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
        val_loss_reset_op = tf.variables_initializer(val_loss_vars, name="val_loss_reset_op")

    with tf.name_scope("test_loss_eval") as scope:
        test_mean_loss, test_mean_loss_update = tf.metrics.mean(batch_loss)
        test_loss_vars = tf.contrib.framework.get_variables(scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
        test_loss_reset_op = tf.variables_initializer(test_loss_vars, name="test_loss_rest_op")

In [18]:
# =====combine operations
# ===== epoch, train
with tf.name_scope("tensorboard_writer") as scope:
    epoch_train_loss_scalar = tf.summary.scalar('train_epoch_loss', train_mean_loss)
    epoch_train_acc_scalar = tf.summary.scalar('train_epoch_acc', train_acc)
    epoch_train_auc_scalar = tf.summary.scalar('train_epoch_auc', train_auc)
    epoch_train_write_op = tf.summary.merge([epoch_train_loss_scalar, epoch_train_acc_scalar, epoch_train_auc_scalar], name="epoch_train_write_op")

    # ===== epoch, validation
    epoch_validation_loss_scalar = tf.summary.scalar('validation_epoch_loss', val_mean_loss)
    epoch_validation_acc_scalar = tf.summary.scalar('validation_epoch_acc', val_acc)
    epoch_validation_auc_scalar = tf.summary.scalar('validation_epoch_auc', val_auc)
    epoch_validation_write_op = tf.summary.merge([epoch_validation_loss_scalar, epoch_validation_acc_scalar, epoch_validation_auc_scalar], name="epoch_validation_write_op")

    # ====== batch, train
    train_batch_loss_scalar = tf.summary.scalar('train_batch_loss', batch_loss)
    train_batch_acc_scalar = tf.summary.scalar('train_batch_acc', train_batch_acc)
    train_batch_write_op = tf.summary.merge([train_batch_loss_scalar, train_batch_acc_scalar], name="train_batch_write_op")

    # ====== checkpoint, validation
    checkpoint_validation_loss_scalar = tf.summary.scalar('validation_batch_loss', batch_loss)
    checkpoint_validation_acc_scalar = tf.summary.scalar('validation_batch_acc', val_batch_acc)
    checkpoint_validation_write_op = tf.summary.merge([checkpoint_validation_loss_scalar, checkpoint_validation_acc_scalar], name="checkpoint_valdiation_write_op")


# Training

## Backprop (locked)

In [19]:
# backprop
with backprop_sess.as_default():
    backprop_train_path = os.path.join("tf_logs","backprop","train")
    backprop_train_writer = tf.summary.FileWriter(backprop_train_path)
    backprop_validation_path = os.path.join("tf_logs","backprop","validation")
    backprop_validation_writer = tf.summary.FileWriter(backprop_validation_path)
    
    backprop_sess.run([init_global,init_local])
    
    for e in tqdm(range(1,data_params['n_epochs']+1)):
        backprop_sess.run([val_acc_reset_op,val_loss_reset_op,train_acc_reset_op,train_loss_reset_op])
        
        n_batches = int(MNIST.train.num_examples/data_params['batch_size'])
        for i in range(1,n_batches+1):
            data, target = MNIST.train.next_batch(data_params['batch_size'])
            backprop_sess.run([backprop_opt, train_auc_update, train_acc_update, train_mean_loss_update], feed_dict={X:data, y:target})
        
        # write average for epoch
        summary = backprop_sess.run(epoch_train_write_op)    
        backprop_train_writer.add_summary(summary, e)
        backprop_train_writer.flush()
        
        # run validation
        n_batches = int(MNIST.validation.num_examples/data_params['batch_size'])
        for i in range(1,n_batches+1):
            Xb, yb = MNIST.validation.next_batch(data_params['batch_size'])
            backprop_sess.run([val_auc_update, val_acc_update, val_mean_loss_update], feed_dict={X:Xb, y:yb})
        
        #summary = backprop_sess.run([summary_op], feed_dict={X:Xb, y:yb})[0]
        summary = backprop_sess.run(epoch_validation_write_op) 
        backprop_validation_writer.add_summary(summary, e)
        backprop_validation_writer.flush()
        
    # close writers
    backprop_train_writer.close()
    backprop_validation_writer.close()

100%|██████████| 20/20 [01:25<00:00,  4.29s/it]


In [21]:
with syntgrad_sess.as_default():
    sg_train_path = os.path.join("tf_logs","sg","train")
    sg_train_writer = tf.summary.FileWriter(sg_train_path, syntgrad_sess.graph)
    sg_validation_path = os.path.join("tf_logs","sg","validation")
    sg_validation_writer = tf.summary.FileWriter(sg_validation_path)
    
    syntgrad_sess.run([init_global,init_local])
    
    for e in tqdm(range(1,data_params['n_epochs']+1)):
        syntgrad_sess.run([val_acc_reset_op,val_loss_reset_op,train_acc_reset_op,train_loss_reset_op])
        
        n_batches = int(MNIST.train.num_examples/data_params['batch_size'])
        for i in range(1,n_batches+1):
            data, target = MNIST.train.next_batch(data_params['batch_size'])
            # The layers here could be independently updated (data parallism) - device placement
            # > stochastic updates are possible
            if random.random() <= data_params['update_prob']:
                syntgrad_sess.run([layer1_opt], feed_dict={X:data, y:target})
            if random.random() <= data_params['update_prob']:
                syntgrad_sess.run([layer2_opt, sg2_opt], feed_dict={X:data, y:target})
            if random.random() <= data_params['update_prob']:
                syntgrad_sess.run([layer3_opt, sg3_opt], feed_dict={X:data, y:target})
            if random.random() <= data_params['update_prob']:
                syntgrad_sess.run([layer5_opt, sg5_opt], feed_dict={X:data, y:target})
            if random.random() <= data_params['update_prob']:
                syntgrad_sess.run([layer6_opt, sg6_opt, train_auc_update, train_acc_update, train_mean_loss_update], 
                                  feed_dict={X:data, y:target})
                
        # write average for epoch
        summary = syntgrad_sess.run(epoch_train_write_op)    
        sg_train_writer.add_summary(summary, e)
        sg_train_writer.flush()
            
        # validation
        n_batches = int(MNIST.validation.num_examples/data_params['batch_size'])
        for i in range(1,n_batches+1):
            Xb, yb = MNIST.validation.next_batch(data_params['batch_size'])
            syntgrad_sess.run([val_auc_update, val_acc_update, val_mean_loss_update], feed_dict={X:Xb, y:yb})

        summary = syntgrad_sess.run(epoch_validation_write_op) 
        sg_validation_writer.add_summary(summary, e)
        sg_validation_writer.flush()
        
    sg_train_writer.close()
    sg_validation_writer.close()

100%|██████████| 20/20 [00:36<00:00,  1.82s/it]


In [22]:
# Test using backprop
# batching isn't working "perfectly".. we seem to iterate over a slightly different test
# set each time (even if the batch size is a multiple of the num_examples)..
with backprop_sess.as_default():
    backprop_sess.run([test_acc_reset_op, test_loss_reset_op])
    
    n_batches = int(MNIST.test.num_examples/data_params['batch_size'])
    for i in tqdm(range(n_batches)):
        Xb, yb = MNIST.test.next_batch(data_params['batch_size'])
        batch_accuracy, batch_loss, batch_auc = backprop_sess.run([test_acc_update, test_mean_loss_update, test_auc_update], 
                                                                  feed_dict={X:Xb,y:yb})
    # print
    final_test_acc, final_test_loss, final_test_auc = backprop_sess.run([test_acc, test_mean_loss, test_auc])
    print("test auc: {:.3f}% acc: {:.3f}% loss: {:.5f}".format(final_test_auc*100, 
                                                              final_test_acc*100,
                                                              final_test_loss))

100%|██████████| 19/19 [00:00<00:00, 42.28it/s]

test auc: 99.130% acc: 90.234% loss: 0.34111





In [23]:
# Now use synthetic grad
with syntgrad_sess.as_default():
    syntgrad_sess.run([test_acc_reset_op, test_loss_reset_op])
    
    n_batches = int(MNIST.test.num_examples/data_params['batch_size'])
    for i in tqdm(range(n_batches)):
        Xb, yb = MNIST.test.next_batch(data_params['batch_size'])
        batch_accuracy, batch_loss, batch_auc = syntgrad_sess.run([test_acc_update, test_mean_loss_update, test_auc_update], 
                                                                  feed_dict={X:Xb,y:yb})
    # print
    final_test_acc, final_test_loss, final_test_auc = syntgrad_sess.run([test_acc, test_mean_loss, test_auc])
    print("test auc: {:.3f}% acc: {:.3f}% loss: {:.5f}".format(final_test_auc*100, 
                                                              final_test_acc*100,
                                                              final_test_loss))

100%|██████████| 19/19 [00:00<00:00, 44.99it/s]

test auc: 74.709% acc: 35.310% loss: 2.26109





In [23]:
# Cleanup
backprop_sess.close()
syntgrad_sess.close()