# Custom Gradient

This is a brief demonstration of TensorFlow [custom gradients](https://www.tensorflow.org/api_docs/python/tf/custom_gradient).

## Chain rule

Let's say we have a function $f(x) = x^2$, and we compose it with itself so that $y = f(f(f(x)))$. We want to find the gradient $\frac{dy}{dx}$.

We first decompose this into intermediate variables:

\begin{align*}
  y &= x_0^2\\
  x_0 &= x_1^2\\
  x_1 &= x^2
\end{align*}

Taking the first-order derivative of each step, we get

\begin{align*}
  \frac{dy}{dx_0} &= 2x_0\\
  \frac{dx_0}{dx_1} &= 2x_1\\
  \frac{dx_1}{dx} &= 2x
\end{align*}

Using the chain rule,

\begin{equation*}
  \frac{dy}{dx} = \frac{dy}{dx_0}  \frac{dx_0}{dx_1}  \frac{dx_1}{dx}
\end{equation*}
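Plugging in $x = 2$ gives a concrete check. The forward values are $x_1 = 4$, $x_0 = 16$ and $y = 256$, and the chain rule multiplies the three local derivatives. Since $y = x^8$, the product should equal $8x^7$:

\begin{align*}
  \frac{dy}{dx} &= 2x_0 \cdot 2x_1 \cdot 2x = 32 \cdot 8 \cdot 4 = 1024\\
  8x^7 &= 8 \cdot 2^7 = 1024
\end{align*}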

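In general, for a longer chain $y = f(x_0)$, $x_i = f(x_{i+1})$, $\ldots$, $x_n = f(x)$, the gradient is the product of all the local derivatives:

\begin{equation*}
  \frac{dy}{dx} = \frac{dy}{dx_{0}} \cdots \frac{dx_i}{dx_{i+1}} \cdots \frac{dx_n}{dx}
\end{equation*}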

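In TensorFlow, gradients are computed in reverse mode, starting from the output and working back toward the input. For a step $x_i = f(x_{i+1})$, the inner function `grad` receives the `upstream` gradient as its argument: the derivative of the final output $y$ with respect to this step's output.

\begin{equation*}
  upstream = \frac{dy}{dx_i} = \frac{dy}{dx_{0}} \cdots \frac{dx_{i-1}}{dx_{i}}
\end{equation*}

`grad` multiplies this upstream gradient by the gradient of the current function, $\frac{dx_{i}}{dx_{i+1}}$, and passes the result downstream to the step that produced $x_{i+1}$. For the outermost call, whose output is $y$ itself, the upstream gradient is $\frac{dy}{dy} = 1$.

\begin{equation*}
  downstream = \frac{dx_{i}}{dx_{i+1}} \cdot upstream = \frac{dy}{dx_{i+1}}
\end{equation*}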


```python
import tensorflow as tf
```

```python
# @tf.function
@tf.custom_gradient
def foo(x):
    tf.debugging.assert_rank(x, 0)  # this demo only handles scalars

    def grad(dy_dx_upstream):
        # Local derivative of f(x) = x**2 at this call's input
        dy_dx = 2 * x
        # Chain rule: local gradient times upstream gradient, passed downstream
        dy_dx_downstream = dy_dx * dy_dx_upstream
        tf.print(f'x={x}\tupstream={dy_dx_upstream}\tcurrent={dy_dx}\t\tdownstream={dy_dx_downstream}')
        return dy_dx_downstream

    # Forward pass
    y = x ** 2
    tf.print(f'x={x}\ty={y}')

    return y, grad


x = tf.constant(2.0, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # constants are not watched automatically
    y = foo(foo(foo(x)))  # y = x ** 8

tf.print(f'\nfinal dy/dx={tape.gradient(y, x)}')
```

```text
x=2.0	y=4.0
x=4.0	y=16.0
x=16.0	y=256.0
x=16.0	upstream=1.0	current=32.0		downstream=32.0
x=4.0	upstream=32.0	current=8.0		downstream=256.0
x=2.0	upstream=256.0	current=4.0		downstream=1024.0

final dy/dx=1024.0
```


### Gradients with multiple variables

If the function takes multiple inputs, `grad` must return one gradient per input, in the same order as the function's arguments, as demonstrated in the example.

```python
@tf.custom_gradient
def bar(x, y):
    tf.debugging.assert_rank(x, 0)
    tf.debugging.assert_rank(y, 0)

    def grad(upstream):
        # z = x * y, so dz/dx = y and dz/dy = x
        dz_dx = y
        dz_dy = x
        # One gradient per input, in the same order as the arguments (x, y)
        return upstream * dz_dx, upstream * dz_dy

    z = x * y

    return z, grad

x = tf.constant(2.0, dtype=tf.float32)
y = tf.constant(3.0, dtype=tf.float32)

# persistent=True allows several tape.gradient calls on the same tape
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    tape.watch(y)
    z = bar(x, y)

tf.print(z)
tf.print(tape.gradient(z, x))
tf.print(tape.gradient(z, y))
tf.print(tape.gradient(x, y))  # None: x does not depend on y
```

```text
6
3
2
None
```


## Application of custom gradients

### Toy example: Differentiable approximation of non-differentiable functions

Take the sign function as an example:

\begin{equation*}
  sign(x) =
  \begin{cases}
    -1, & \text{if}\ x < 0 \\
    0, & \text{if}\ x = 0 \\
    1, & \text{if}\ x > 0
  \end{cases}
\end{equation*}
  
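The derivative of $sign(x)$ is $0$ everywhere except at $x = 0$, where it is undefined, so on its own it gives an optimizer no signal. By implementing a custom gradient, we can keep the exact $sign(x)$ function in the forward pass but use a differentiable approximation in the backward pass. In this case we approximate $sign(x)$ with the sigmoid function $\sigma(x)$ plus a constant $C$. Strictly, $\sigma(x)$ ranges over $(0, 1)$ while $sign(x)$ ranges over $[-1, 1]$; the rescaled $2\sigma(x) - 1$ matches that range, and its derivative $2\sigma(x)(1 - \sigma(x))$ differs from the gradient used here only by a constant factor of 2.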

\begin{align*}
  sign_{approx}(x) &= \sigma(x) + C\\
  \frac{d\,sign_{approx}(x)}{dx} &= \sigma(x) (1 - \sigma(x))
\end{align*}

```python
# @tf.function
@tf.custom_gradient
def differentiable_sign(x):
    tf.debugging.assert_rank(x, 0)

    def grad(upstream):
        # Surrogate gradient: derivative of the sigmoid, sigma(x) * (1 - sigma(x))
        dy_dx = tf.math.sigmoid(x) * (1 - tf.math.sigmoid(x))
        return upstream * dy_dx

    # Forward pass: the exact sign function (-1, 0 or 1, so sign(0) = 0)
    return tf.math.sign(x), grad


x = tf.constant(3.0, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = differentiable_sign(x)
    # tf.nn.l2_loss(t) computes sum(t ** 2) / 2
    loss = tf.nn.l2_loss(y - tf.constant(-1.0))

tf.print(y)
tf.print(tape.gradient(y, x))
tf.print(loss)
tf.print(tape.gradient(loss, x))
```

```text
1
0.0451766551
2
0.0903533101
```


```python
x = tf.Variable(-1.0)
opt = tf.keras.optimizers.Adam(1e-1)
# opt = tf.keras.optimizers.SGD(1)

def train_step():
    with tf.GradientTape() as tape:
        # Trainable variables are watched automatically
        y = differentiable_sign(x)
        loss = tf.nn.l2_loss(y - tf.constant(1.0))
    grads = tape.gradient(loss, x)
    opt.apply_gradients(zip([grads], [x]))
    return loss, y, grads

for i in range(100):
    loss, y, grads = train_step()
    if i % 10 == 0:
        # Columns: step, loss, gradient, x (after the update), y (before the update)
        tf.print(i, loss, grads, x, y)
```

```text
0 2 -0.393223882 -0.89999783 -1
10 0 0 0.0995881185 1
20 0 0 0.6450876 1
30 0 0 0.855591536 1
40 0 0 0.938390434 1
50 0 0 0.970632374 1
60 0 0 0.983009696 1
70 0 0 0.987700462 1
80 0 0 0.989459515 1
90 0 0 0.990113616 1
```
