# Custom Op Gradients in TensorFlow
In my [forward prop implementation](https://github.com/IdRatherBeCoding/sparse_cnn/blob/master/sparse_cnn.ipynb) for sparse CNNs, I used [tf.py_func](https://www.tensorflow.org/api_docs/python/tf/py_func) to create a custom op to build $H_\mathrm{out}$ and $Q$ from the sparse representation of the previous layer activations, $a^{[l-1]}$. The output activations are computed from Q using TensorFlow matmul and relu ($g$) ops:

\begin{equation*}
a^{[l]} = g(Q(a^{[l-1]})\cdot W + b).
\end{equation*}

Since we are using TensorFlow ops to compute the matrix product and relu, TensorFlow will handle the derivatives for $g$ and the $Q \cdot W$ product; we only have to implement the gradient of the custom py_func op itself. Specifically, given the gradient of the loss with respect to our function's output, $\frac{\partial L}{\partial Q}$, our gradient function needs to compute

\begin{equation*}
\frac{\partial L}{\partial a^{[l-1]}_{ij}} = \sum_{pq} \frac{\partial L}{\partial Q_{pq}} \frac{\partial Q_{pq}}{\partial a^{[l-1]}_{ij}}.
\end{equation*}

Recall that in place of the dense input activations $a^{[l-1]}$, we are using custom sparse representations: *SparseDataValue* and *SparseDataTensor*. These store the indices of the active sites, $H_\mathrm{in}$, the values of the active sites, $M_\mathrm{in}$, the dense shape, and the ground-state value for each channel. I will now introduce the notation $t_{\mathrm{in},c}$ to represent the ground-state value of the $c^\mathrm{th}$ channel.
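
As a rough illustration of this representation (a NumPy sketch, not the actual *SparseDataValue* code), here is how $H$ and $M$ could be extracted from a dense array. The helper `to_sparse` is hypothetical, and I'm assuming that a site counts as active when any of its channels differs from that channel's ground state.

```python
import numpy as np

def to_sparse(x_dense, ground_state):
    """Hypothetical helper: split a dense (height, width, channels) array
    into active-site indices H and active-site values M."""
    # Assumption: a site is active if any channel differs from its ground state.
    active = np.any(x_dense != ground_state, axis=-1)
    H = np.argwhere(active)   # (a, 2) row/col indices, in row-major order
    M = x_dense[active]       # (a, channels) values at those sites, same order
    return H, M

x_dense = np.zeros((4, 4, 2))
x_dense[0, 0] = [1.7, 0.7]
x_dense[2, 2] = [7.9, 0.9]
x_dense[2, 3] = [4.8, 0.8]
t = np.zeros(2)               # ground-state value per channel

H, M = to_sparse(x_dense, t)
print(H)
print(M)
```

Only the three non-ground-state sites are kept, so gradients only ever need to be stored for the rows of $M$ and for the per-channel ground state $t$.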

To enable back propagation, we will need to provide gradients with respect to $M_\mathrm{in}$ and $t_\mathrm{in}$. Dropping the subscript $\mathrm{in}$ for clarity:

\begin{equation*}
\begin{aligned}
\frac{\partial L}{\partial M_{ij}} &= \sum_{pq} \frac{\partial L}{\partial Q_{pq}} \frac{\partial Q_{pq}}{\partial M_{ij}}\\
\frac{\partial L}{\partial t_{c}} &= \sum_{pq} \frac{\partial L}{\partial Q_{pq}} \frac{\partial Q_{pq}}{\partial t_{c}}
\end{aligned}
\end{equation*}

## Gradients of py_func ops
I came across several discussions concerning this ([issue#1095](https://github.com/tensorflow/tensorflow/issues/1095), [SO1](https://datascience.stackexchange.com/questions/12974/tensorflow-how-to-set-gradient-of-an-external-process-py-func), [issue#3710](https://github.com/tensorflow/tensorflow/issues/3710), [SO2](https://stackoverflow.com/questions/38833934/write-custom-python-based-gradient-function-for-an-operation-without-c-imple)), but there doesn't appear to be an official guide specifically for py_func ops.

The [adding an op](https://www.tensorflow.org/extend/adding_an_op#implement_the_gradient_in_python) guide describes how to register a gradient function using the [tf.RegisterGradient](https://www.tensorflow.org/api_docs/python/tf/RegisterGradient) decorator for an Op registered in C++. Unfortunately, RegisterGradient only registers functions to ops by type name. Since we're using py_func, the type of our custom op is always PyFunc. From the links above, there are two possible approaches: *Defun* and *gradient_override_map*.

## The Defun approach
This approach is based on [this SO answer](https://stackoverflow.com/questions/38833934/write-custom-python-based-gradient-function-for-an-operation-without-c-imple). It is only [experimental](https://github.com/tensorflow/tensorflow/issues/14080) and [not ready for py_func](https://github.com/tensorflow/tensorflow/issues/10282), which I'll show below.

### Simple example: custom gradient for tf.square

In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import function

In [2]:
def squared_back_prop(op, grad):
    return tf.multiply(op.inputs[0] * 2.0, grad)

@function.Defun(tf.float32, python_grad_func=squared_back_prop)
def squared_forward_prop(a):
    return tf.square(a)

In [3]:
tf.reset_default_graph()

x = tf.Variable(tf.constant(np.array([1., 2., 3., 4.]), dtype=tf.float32))
x2 = squared_forward_prop(x)
L = tf.reduce_sum(x2)
dL = tf.gradients(L, [x])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dL))
    print("error:", tf.test.compute_gradient_error(x, [4], L, [1]))

[array([ 2.,  4.,  6.,  8.], dtype=float32)]
error: 6.85453414917e-05
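
The "error" printed above comes from a finite-difference check. As I understand it, `tf.test.compute_gradient_error` compares a numerically estimated Jacobian with the one produced by the registered gradients and reports the largest difference. The NumPy sketch below shows the same idea for a scalar loss. It is not TensorFlow's implementation: `max_gradient_error` is a hypothetical helper, and the central difference and the step size `delta` are my assumptions.

```python
import numpy as np

def max_gradient_error(f, grad_f, x, delta=1e-3):
    """Largest gap between an analytic gradient and a central-difference estimate."""
    numeric = np.zeros_like(x)
    for k in range(x.size):
        step = np.zeros_like(x)
        step.flat[k] = delta
        # central difference for the k-th input element
        numeric.flat[k] = (f(x + step) - f(x - step)) / (2 * delta)
    return np.max(np.abs(numeric - grad_f(x)))

x = np.array([1., 2., 3., 4.])
loss = lambda v: np.sum(np.square(v))      # L = sum(x^2)
good_grad = lambda v: 2.0 * v              # correct dL/dx
bad_grad = lambda v: 2.0 * v + 0.1         # deliberately wrong

err_good = max_gradient_error(loss, good_grad, x)
err_bad = max_gradient_error(loss, bad_grad, x)
print("good:", err_good, "bad:", err_bad)
```

A correct gradient gives an error limited only by floating-point precision, which is why the float32 run above reports an error around 1e-5 rather than exactly zero.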


### Defun example with py_func

In [4]:
def square_numpy(x):
    return np.square(x)

@function.Defun(tf.float32, python_grad_func=squared_back_prop)
def squared_forward_prop_py_func(a):
    return tf.py_func(square_numpy, [a], tf.float32)

In [5]:
x2 = squared_forward_prop_py_func(x)
L = tf.reduce_sum(x2)
dL = tf.gradients(L, [x])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
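    try:
        print(sess.run(dL))
    except Exception as e:
        # report the failure instead of silently swallowing it
        print("Defun with py_func failed:", e)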

[array([ 2.,  4.,  6.,  8.], dtype=float32)]


## The gradient_override_map approach
I will use the approach suggested in [issue#1095](https://github.com/tensorflow/tensorflow/issues/1095), and demonstrated in [this gist](https://gist.github.com/harpone/3453185b41d8d985356cbe5e57d67342).

A custom py_func function is defined, which takes a grad function. The grad function is given a random name and registered with tf.RegisterGradient.

Finally, *gradient_override_map* is called before calling tf.py_func.
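
To make the mechanism concrete, here is a toy model in plain Python of gradient lookup by name with an override map. These are stand-ins of my own (`GRADIENT_REGISTRY`, `make_op`, `py_func_with_grad` are all hypothetical), not TensorFlow internals. The point is only that every py_func op has the same type, so each one needs its own uniquely named gradient entry.

```python
import uuid

# Toy stand-ins (NOT TensorFlow internals): a global registry keyed by name.
GRADIENT_REGISTRY = {}

def register_gradient(name, fn):
    if name in GRADIENT_REGISTRY:
        raise KeyError("gradient already registered for " + name)
    GRADIENT_REGISTRY[name] = fn

def make_op(op_type, override_map=None):
    # The override map is consulted when the op is created, so only ops
    # built while the map is active pick up the substituted name.
    grad_name = (override_map or {}).get(op_type, op_type)
    return {"type": op_type, "grad_name": grad_name}

def gradient_for(op):
    return GRADIENT_REGISTRY[op["grad_name"]]

def py_func_with_grad(grad_fn):
    # A unique name per call avoids clashing with earlier registrations.
    unique = "PyFuncGrad" + uuid.uuid4().hex
    register_gradient(unique, grad_fn)
    return make_op("PyFunc", override_map={"PyFunc": unique})

square_op = py_func_with_grad(lambda x, g: g * 2 * x)
cube_op = py_func_with_grad(lambda x, g: g * 3 * x ** 2)

print(square_op["type"], cube_op["type"])   # both ops share the type "PyFunc"
print(gradient_for(square_op)(3.0, 1.0))    # gradient of x^2 at x = 3
print(gradient_for(cube_op)(2.0, 1.0))      # gradient of x^3 at x = 2
```

Both ops report the same type, yet each resolves to its own gradient function because the lookup goes through the name recorded at creation time.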

In [6]:
from tensorflow.python.framework import ops

# directly taken from https://gist.github.com/harpone/3453185b41d8d985356cbe5e57d67342#gistcomment-2011084
#
# Define custom py_func which takes also a grad op as argument:
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    
    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 10**8))
    
    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

# Actual gradient:
def _MySquareGrad(op, grad):
    x = op.inputs[0]
    return grad * 2 * x  # chain rule: upstream gradient times d(x^2)/dx

# Def custom square function using np.square instead of tf.square:
def mysquare(x, name=None):
    
    with ops.name_scope(name, "Mysquare", [x]) as name:
        sqr_x = py_func(np.square,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MySquareGrad)  # <-- here's the call to the gradient
        return sqr_x[0]

In [7]:
tf.reset_default_graph()

x = tf.Variable(tf.constant(np.array([1., 2., 3., 4.]), dtype=tf.float32))
x2 = mysquare(x)
L = tf.reduce_sum(x2)
dL = tf.gradients(L, [x])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dL))
    print("error:", tf.test.compute_gradient_error(x, [4], L, [1]))

[array([ 2.,  4.,  6.,  8.], dtype=float32)]
error: 6.30617141724e-05


##### Great, that worked. Now let's try with a py_func op for the gradient too.

In [8]:
def _MyCubeGrad(op, grad):
    name = "MyCubeGrad"
    x = op.inputs[0]
    cube_x_grad = py_func(lambda a: np.power(a, 2) * 3,
                    [x],
                    [tf.float32],
                    name=name)
    # scale the local derivative by the upstream gradient (chain rule)
    return grad * cube_x_grad[0]

def my_cube(x, name=None):
    
    with ops.name_scope(name, "MyCube", [x]) as name:
        cube_x = py_func(lambda a: np.power(a, 3),
                        [x],
                        [tf.float32],
                        name=name,
                        grad=_MyCubeGrad)
        return cube_x[0]

In [9]:
tf.reset_default_graph()

x = tf.Variable(tf.constant(np.array([1., 2., 3., 4.]), dtype=tf.float32))
x3 = my_cube(x)
L = tf.reduce_sum(x3)
dL = tf.gradients(L, [x])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dL))
    print("error:", tf.test.compute_gradient_error(x, [4], L, [1]))

[array([  3.,  12.,  27.,  48.], dtype=float32)]
error: 2.02655792236e-05
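
One caveat about this check: because `L = tf.reduce_sum(x3)`, the gradient arriving at the cube op is all ones, so the test cannot tell whether a gradient function applies the incoming `grad` at all. The NumPy sketch below (hypothetical function names, with a weighted sum as the loss) shows that the difference only appears when the upstream gradient is not uniform.

```python
import numpy as np

def cube_grad_ignoring_upstream(x, grad):
    return 3.0 * np.power(x, 2)            # local derivative only

def cube_grad(x, grad):
    return grad * 3.0 * np.power(x, 2)     # chain rule: upstream * local

x = np.array([1., 2., 3., 4.])
w = np.array([1., 0.5, 2., 0.])            # weights in L = sum(w * x**3)

upstream = w                               # dL/d(x**3)
expected = w * 3.0 * x ** 2                # analytic dL/dx

ones = np.ones_like(x)
same_when_ones = np.allclose(cube_grad_ignoring_upstream(x, ones),
                             cube_grad(x, ones))

print(cube_grad(x, upstream))
print(cube_grad_ignoring_upstream(x, upstream))
```

With a plain sum as the loss the two versions agree, so a more discriminating test needs a loss such as a weighted sum.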


##### Ok, that's all good, now we can implement the gradient of Q.

## Compute $\frac{\partial L}{\partial M_\mathrm{in}}$
We need to compute the derivative with respect to the active-site values, $M_\mathrm{in}$,

\begin{equation*}
\frac{\partial L}{\partial M_{\mathrm{in},ij}} = \sum_{pq} \frac{\partial L}{\partial Q_{pq}} \frac{\partial Q_{pq}}{\partial M_{\mathrm{in},ij}},
\end{equation*}

and the derivative with respect to the ground-state values, $t$,

\begin{equation*}
\frac{\partial L}{\partial t_{c}} = \sum_{pq} \frac{\partial L}{\partial Q_{pq}} \frac{\partial Q_{pq}}{\partial t_{c}}.
\end{equation*}

Recall how $Q$ is constructed: each row corresponds to an active site in the output; the values in each row correspond to the elements of $M_\mathrm{in}$, ordered according to the filter weights to which they are visible. The derivative $\frac{\partial Q_{pq}}{\partial M_{\mathrm{in},ij}}$ is equal to 1 when the value $Q_{pq}$ was taken from $M_{\mathrm{in},ij}$; otherwise it is equal to zero. So a given element $ij$ of the loss gradient is the sum of the $\frac{\partial L}{\partial Q}$ elements for which $Q$ was assigned the value of $M_{\mathrm{in},ij}$. A quick way to implement this is to reuse the loop structure used to build $Q$.
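
Since each $Q_{pq}$ is a copy of at most one element of $M_\mathrm{in}$, this sum is a scatter-add. Here is a small NumPy sketch under an assumption of my own: a hypothetical `source` array records, for every entry of $Q$, which flattened element of $M$ it was copied from (or -1 for a ground-state entry). The implementation further down does not store such an array and recomputes the positions in a loop instead.

```python
import numpy as np

M = np.array([[1.7, 0.7], [7.9, 0.9]])     # 2 active sites, 2 channels

# Hypothetical bookkeeping: Q[p, q] was copied from M.flat[source[p, q]],
# or source[p, q] == -1 where Q holds a ground-state value instead.
source = np.array([[ 0,  1, -1, -1],
                   [ 0,  1,  2,  3],
                   [-1, -1,  2,  3]])
dLdQ = np.arange(12, dtype=float).reshape(3, 4)

dM = np.zeros_like(M)
mask = source >= 0
# np.add.at accumulates repeated indices, which a plain fancy-indexed += would not.
np.add.at(dM.reshape(-1), source[mask], dLdQ[mask])

print(dM)
```

Each element of `dM` ends up as the sum of every `dLdQ` entry whose $Q$ value was copied from it, which is exactly the chain-rule sum with a 0/1 derivative.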

For $\frac{\partial L}{\partial t_{c}}$ we need to sum all the $\frac{\partial L}{\partial Q_{pq}}$ values for which index $q$ corresponds to the ground state of channel $c$. To achieve this we will reshape $\frac{\partial L}{\partial Q}$ from $(a_\mathrm{out}, f^2n_\mathrm{in})$ to $(a_\mathrm{out}, f^2, n_\mathrm{in})$ and initialize the output array as the sum over the first two axes. All we need to do next is subtract out the contributions from the active-site values, which can be done at the same time as building $\frac{\partial L}{\partial M_\mathrm{in}}$.
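
The "sum everything, then subtract the active entries" trick can be checked on a tiny example. In this NumPy sketch the mask `from_active` is hypothetical: it marks which filter positions of each output site were filled from an active input site. I'm also assuming that the channel index varies fastest within a row of $Q$, which is what the reshape to $(a_\mathrm{out}, f^2, n_\mathrm{in})$ implies.

```python
import numpy as np

a_out, f, n_in = 2, 2, 2
dLdQ = np.arange(a_out * f * f * n_in, dtype=float).reshape(a_out, f * f * n_in)

# Hypothetical mask: True where a filter position of an output site
# saw an active input site (so Q was filled from M, not from t).
from_active = np.array([[True, False, False, False],
                        [False, True, True, False]])

per_slot = dLdQ.reshape(a_out, f * f, n_in)
dt = per_slot.sum(axis=(0, 1))             # start from the sum over every slot
dt -= per_slot[from_active].sum(axis=0)    # remove slots filled from M

# Direct computation for comparison: sum only the ground-state slots.
direct = per_slot[~from_active].sum(axis=0)

print(dt, direct)
```

Both routes give the same per-channel totals. The subtraction form is convenient because the active entries are already being visited while accumulating the gradient for $M_\mathrm{in}$.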

In [10]:
def grad_Q(dLdQ, *inputs):
    (H_in, M_in, dense_shape, f, n_in, ground_state) = inputs

    dM = np.zeros_like(M_in)

    height = dense_shape[0]
    width = dense_shape[1]

    output_sites = {}
    # enumerate all output active sites and store the positions
    # these could be reused from forward prop with a slight refactoring
    i_out = 0
    for [row, col] in H_in:
        for i, j in filter_positions(row, col, height, width, f):
            if (i, j) not in output_sites:
                output_sites[(i, j)] = i_out
                i_out += 1

    a_out = i_out
    # initialize dt by summing over all elements of dLdQ for each channel
    dt = np.sum(dLdQ.reshape((a_out, f*f, n_in)), axis=(0, 1))
    
    for idx, [row, col] in enumerate(H_in):
        # summing with explicit loops could be replaced by generating list of index permutations and summing slices
        for i, j in filter_positions(row, col, height, width, f):
            i_out = output_sites[(i, j)]
            for i_val in range(n_in):
                d = dLdQ[i_out, position_in_filter(i, j, row, col, f, i_val, n_in)]
                dM[idx, i_val] += d
                dt[i_val] -= d        
       
    return [
        dM,
        dt
    ]

def _grad_Q(op, *grads):
    dM, dt = tf.py_func(grad_Q, [grads[1], *op.inputs], [op.inputs[1].dtype, op.inputs[5].dtype])
    return [None, dM, None, None, None, dt]

Now use the modified py_func function to create the forward prop operation.

In [11]:
from sparse_cnn_tensorflow.sparse_cnn import build_h_out_and_Q, next_ground_state, filter_positions, position_in_filter
from sparse_cnn_tensorflow.sparse_data_tensor import SparseDataTensor

def sparse_conv_2d(sparse_input, W, f, n_out, b):
    H_in = sparse_input.H
    M_in = sparse_input.M
    dense_shape = sparse_input.dense_shape
    n_in = dense_shape[2]
    ground_state = sparse_input.ground_state

    output_spatial_shape = (dense_shape[0] - f + 1, dense_shape[1] - f + 1)

    H_out, Q = py_func(build_h_out_and_Q,
                          [H_in, M_in, dense_shape, f, n_in, ground_state],
                          [H_in.dtype, M_in.dtype], grad=_grad_Q)

    M_out = tf.add(tf.matmul(Q, W), b)

    output_dense_shape = (output_spatial_shape[0], output_spatial_shape[1], n_out)

    output_ground_state = next_ground_state(W, ground_state, f) + b

    return SparseDataTensor(H_out, M_out, output_dense_shape, output_ground_state)

In [12]:
from sparse_cnn_tensorflow.sparse_data_tensor import SparseDataValue

# using double precision to improve finite difference accuracy of tf.test.compute_gradient_error.

f1 = 2
n_in_1 = 2
n_out_1 = 4

W1 = tf.Variable(np.random.rand(f1*f1*n_in_1, n_out_1), dtype=tf.float64)
b1 = tf.Variable(np.random.rand(n_out_1), dtype=tf.float64)

f2 = 2
n_in_2 = n_out_1
n_out_2 = 8

W2 = tf.Variable(np.random.rand(f2*f2*n_in_2, n_out_2), dtype=tf.float64)
b2 = tf.Variable(np.random.rand(n_out_2), dtype=tf.float64)

x_dense = np.array([
    [[1.7, 0.7], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
    [[0.0, 0.0], [0.0, 0.0], [7.9, 0.9], [4.8, 0.8]],
    [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
], dtype=np.float64)

x_sparse = SparseDataValue(x_dense)

sparse_tensor = SparseDataTensor(
        tf.constant(x_sparse.H),
        tf.constant(x_sparse.M),
        x_sparse.dense_shape,
        tf.constant(x_sparse.ground_state))

forward1 = sparse_conv_2d(sparse_tensor, W1, f1, n_out_1, b1)
# gradient1 = tf.gradients(forward1.M, sparse_tensor.M)
forward2 = sparse_conv_2d(forward1, W2, f2, n_out_2, b2)
# gradient2 = tf.gradients(forward2.M, sparse_tensor.M)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
#     print(sess.run(gradient[0]))
    print("error dM1/dM0:", tf.test.compute_gradient_error(sparse_tensor.M, [3, 2], forward1.M, [5, 4]))
    print("error dM1/dt0:", tf.test.compute_gradient_error(sparse_tensor.ground_state, [2], forward1.M, [5, 4]))
    print("error dM2/dM0:", tf.test.compute_gradient_error(sparse_tensor.M, [3, 2], forward2.M, [4, 8]))
    print("error dM2/dt0:", tf.test.compute_gradient_error(sparse_tensor.ground_state, [2], forward2.M, [4, 8]))
    print("error dt1/dt0:", tf.test.compute_gradient_error(sparse_tensor.ground_state, [2], forward1.ground_state, [4]))
    print("error dt2/dt0:", tf.test.compute_gradient_error(sparse_tensor.ground_state, [2], forward2.ground_state, [8]))

error dM1/dM0: 1.89182003396e-13
error dM1/dt0: 1.4199752485e-12
error dM2/dM0: 1.18438592267e-12
error dM2/dt0: 9.92272930489e-12
error dt1/dt0: 5.50670620214e-14
error dt2/dt0: 2.90967250294e-12
