# Custom Gradient Functions in TensorFlow

This notebook demonstrates how to implement custom gradient functions in TensorFlow using the [`@tf.custom_gradient`](https://www.tensorflow.org/api_docs/python/tf/custom_gradient) decorator.

## Why Custom Gradients?

Custom gradients are essential for differentiable programming because they allow us to:

- **Define gradients for non-differentiable functions** (like sign, argmax, etc.)
- **Improve numerical stability** by providing more stable gradient computations
- **Implement domain-specific optimizations** for gradient calculations
- **Create differentiable approximations** of discrete operations
- **Control gradient flow** in complex computational graphs

This is a fundamental tool for making traditionally non-differentiable algorithms learnable through backpropagation.

## Chain Rule Fundamentals

Before diving into custom gradients, let's review how the chain rule works in automatic differentiation.

Consider the function $f(x) = x^2$ composed with itself three times, $y = f(f(f(x)))$. To find $\frac{dy}{dx}$, we introduce intermediate variables:

\begin{align}
  y &= x_0^2\\
  x_0 &= x_1^2\\
  x_1 &= x^2
\end{align}

Taking first-order derivatives of each step:

\begin{align}
  \frac{dy}{dx_0} &= 2x_0\\
  \frac{dx_0}{dx_1} &= 2x_1\\
  \frac{dx_1}{dx} &= 2x
\end{align}

Applying the chain rule:

$$\frac{dy}{dx} = \frac{dy}{dx_0} \cdot \frac{dx_0}{dx_1} \cdot \frac{dx_1}{dx}$$

**General Form**: For a computational graph with $n$ intermediate variables $x_0, \dots, x_{n-1}$ and input $x_n = x$:

$$\frac{dy}{dx} = \frac{dy}{dx_{0}} \prod_{i=0}^{n-1} \frac{dx_i}{dx_{i+1}}$$
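
As a quick sanity check of the formula, here is the three-fold composition evaluated at $x = 2$ (the value of $x$ is chosen for illustration and matches the code cells in this notebook):

\begin{align}
  x_1 &= 2^2 = 4, \quad x_0 = 4^2 = 16, \quad y = 16^2 = 256\\
  \frac{dy}{dx} &= (2 x_0)(2 x_1)(2 x) = 32 \cdot 8 \cdot 4 = 1024
\end{align}

This agrees with the closed form $y = x^8$, for which $\frac{dy}{dx} = 8x^7 = 8 \cdot 128 = 1024$.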

### TensorFlow's Gradient Flow Mechanism

In TensorFlow's automatic differentiation, gradients flow backward through the computational graph. Each function receives an **upstream gradient** and computes a **downstream gradient**.

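**Upstream Gradient**: The gradient flowing in from later operations in the graph, i.e. the gradient of the final output with respect to the current function's output $x_i$ (for the outermost operation this is simply $\frac{dy}{dy} = 1$):
$$\text{upstream} = \frac{dy}{dx_i} = \frac{dy}{dx_{i-1}} \cdot \frac{dx_{i-1}}{dx_i}$$

**Current Function Gradient**: The local gradient of the current operation, which maps its input $x_{i+1}$ to its output $x_i$:
$$\frac{dx_i}{dx_{i+1}}$$

**Downstream Gradient**: The product passed to earlier operations:
$$\text{downstream} = \frac{dx_i}{dx_{i+1}} \times \text{upstream} = \frac{dy}{dx_{i+1}}$$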

This mechanism allows each operation to contribute its local gradient while preserving the chain rule structure. Custom gradients integrate seamlessly into this flow by implementing the `grad` function that receives upstream gradients and returns downstream gradients.
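
The upstream/downstream bookkeeping can be sketched without TensorFlow. The snippet below is a minimal illustration in plain Python, not how TensorFlow implements its tape; the names `inputs`, `local`, `upstream`, and `downstream` are chosen here for illustration only.

```python
# Forward pass for y = f(f(f(x))) with f(x) = x ** 2, remembering each call's input.
x = 2.0
inputs = []
value = x
for _ in range(3):
    inputs.append(value)
    value = value ** 2  # one application of f
y = value

# Backward pass: start with dy/dy = 1 and visit the calls in reverse order.
upstream = 1.0
for call_input in reversed(inputs):
    local = 2.0 * call_input       # derivative of this call w.r.t. its own input
    downstream = local * upstream  # one chain rule step
    print(f"x={call_input}\tupstream={upstream}\tcurrent={local}\tdownstream={downstream}")
    upstream = downstream          # becomes the upstream of the earlier call

dy_dx = upstream
print(f"y={y}\tdy/dx={dy_dx}")
```

Each loop iteration plays the role of one `grad` call: it receives the upstream gradient, multiplies it by the local derivative, and hands the result on.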

In [1]:
import tensorflow as tf

## Basic Custom Gradient Implementation

Let's implement a simple custom gradient for $f(x) = x^2$ (the `foo` function in cell In [2]). The debug prints show how gradients flow backward through the nested calls `foo(foo(foo(x)))`.

## Multi-Variable Custom Gradients

For functions with multiple inputs, the custom gradient function must return gradients for each input variable in the same order as the function parameters.

### Example: Product Function
For $z = f(x,y) = xy$, we have:
- $\frac{\partial z}{\partial x} = y$
- $\frac{\partial z}{\partial y} = x$

The gradient function returns `(upstream * y, upstream * x)`, scaling each partial derivative by the upstream gradient. The `bar` function in cell In [3] implements this.
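
The two partial derivatives can be checked independently of TensorFlow with central finite differences. This is a small plain-Python sketch; `product`, `product_grad`, and the step size `eps` are hypothetical names and values introduced for this check.

```python
def product(x, y):
    return x * y

def product_grad(x, y, upstream=1.0):
    # One entry per input, in parameter order: (dz/dx, dz/dy), each scaled by upstream.
    return upstream * y, upstream * x

x, y, eps = 2.0, 3.0, 1e-6

# Central finite differences as an independent check of the analytic partials.
fd_dx = (product(x + eps, y) - product(x - eps, y)) / (2 * eps)
fd_dy = (product(x, y + eps) - product(x, y - eps)) / (2 * eps)

grad_x, grad_y = product_grad(x, y)
print(grad_x, fd_dx)
print(grad_y, fd_dy)
```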

In [2]:
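# @tf.function  # left disabled: the f-strings below need eager tensors with concrete values
@tf.custom_gradient
def foo(x):
    """Compute x ** 2 with a hand-written gradient and debug prints."""
    tf.debugging.assert_rank(x, 0)  # scalars only, to keep the printout readable

    def grad(dy_dx_upstream):
        # Local derivative of this call: d(x ** 2)/dx = 2x
        dy_dx = 2 * x
        # Chain rule: local derivative times the gradient arriving from later ops
        dy_dx_downstream = dy_dx * dy_dx_upstream
        tf.print(f'x={x}\tupstream={dy_dx_upstream}\tcurrent={dy_dx}\t\tdownstream={dy_dx_downstream}')
        return dy_dx_downstream

    y = x ** 2
    tf.print(f'x={x}\ty={y}')

    # tf.custom_gradient expects the forward result and the gradient function
    return y, grad


x = tf.constant(2.0, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # x is a constant, so it must be watched explicitly
    y = foo(foo(foo(x))) # y = x ** 8

tf.print(f'\nfinal dy/dx={tape.gradient(y, x)}')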

x=2.0	y=4.0
x=4.0	y=16.0
x=16.0	y=256.0
x=16.0	upstream=1.0	current=32.0		downstream=32.0
x=4.0	upstream=32.0	current=8.0		downstream=256.0
x=2.0	upstream=256.0	current=4.0		downstream=1024.0

final dy/dx=1024.0
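
A useful habit when writing a custom gradient for a function that is actually differentiable is to compare it against TensorFlow's own autodiff. The sketch below assumes eager execution in TensorFlow 2; `square` is a hypothetical, print-free stand-in for `foo`.

```python
import tensorflow as tf

@tf.custom_gradient
def square(x):
    def grad(upstream):
        # Hand-written local derivative 2x, scaled by the upstream gradient.
        return upstream * 2.0 * x
    return x ** 2, grad

x = tf.constant(2.0)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y_custom = square(square(square(x)))  # uses the custom gradient three times
    y_auto = x ** 8                       # same function, differentiated by autodiff

g_custom = tape.gradient(y_custom, x)
g_auto = tape.gradient(y_auto, x)
tf.print(g_custom, g_auto)
```

If the two values disagree, the hand-written `grad` is the first place to look.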


## Practical Application: Differentiable Approximations

One of the most powerful applications of custom gradients is making non-differentiable functions differentiable by providing smooth approximations for the backward pass.

### Case Study: Differentiable Sign Function

The sign function is discontinuous at $x = 0$ and has zero derivative everywhere else, so its true gradient carries no training signal:

$$\text{sign}(x) = \begin{cases}
  -1, & \text{if } x < 0 \\
  0, & \text{if } x = 0 \\
  1, & \text{if } x > 0
\end{cases}$$

**Strategy**: Keep the discrete sign function in the forward pass, but use a smooth approximation for gradients.

**Gradient Approximation**: We use the sigmoid derivative:
$$\frac{d\text{sign}_{\text{approx}}(x)}{dx} = \sigma(x)(1 - \sigma(x))$$

where $\sigma(x) = \frac{1}{1 + e^{-x}}$ is the sigmoid function.

**Why This Works**: 
- Forward pass maintains the exact discrete behavior we want
- Backward pass provides smooth gradients that enable optimization
- The sigmoid derivative is bell-shaped, providing strongest gradients near $x=0$ where the sign function transitions
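
To make the bell shape concrete, the following plain-Python sketch evaluates the surrogate gradient at a few points. The sample points and the helper names `sigmoid` and `sigmoid_derivative` are chosen here for illustration.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_derivative(x):
    s = sigmoid(x)
    return s * (1.0 - s)

# The values are symmetric around 0 and largest at x = 0.
for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={x:+.1f}  surrogate gradient={sigmoid_derivative(x):.4f}")
```

The peak at $x = 0$ is $0.25$, and the value falls to about $0.045$ at $x = \pm 3$, so inputs far from zero receive only a weak training signal.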

### Gradients with multiple variables

If the function takes multiple inputs, `grad` must return one gradient per input, in the order of the function's parameters, as the `bar` example in cell In [3] demonstrates.

### Testing the Differentiable Sign Function

Let's test our differentiable sign function at `x = 3.0` (cell In [4]):
- **Forward pass**: `sign(3.0) = 1.0` (correct discrete behavior)
- **Gradient**: `≈ 0.045` (smooth, non-zero gradient enables optimization)
- **Loss computation**: With target `-1`, `tf.nn.l2_loss` gives `(1 - (-1))² / 2 = 2`, and its gradient with respect to `x` is `≈ 0.090`, which is what makes training possible

### Training with the Differentiable Sign Function

Now let's train a parameter to minimize the loss `L = ½ (sign(x) - 1)²`, which is what `tf.nn.l2_loss` computes for a scalar. We start with `x = -1` and want to find `x` such that `sign(x) = 1`.

**Expected Behavior**: The optimizer should drive `x` toward positive values to make `sign(x) = 1`.

**Key Insight**: Without the custom gradient, this would be impossible because `sign(x)` has zero gradient almost everywhere. With our smooth approximation, the optimizer can "see" which direction to move!
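
The difference between the true gradient and the surrogate can be sketched in plain Python. This is an illustration under simplifying assumptions, not the notebook's training cell: it uses plain gradient descent with learning rate 1 (the SGD setting that is commented out in cell In [5]) instead of Adam, and `train` and its `surrogate` flag are hypothetical names.

```python
import math

def sign(x):
    return (x > 0) - (x < 0)

def sigmoid_derivative(x):
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 - s)

def train(surrogate, steps=5, lr=1.0):
    x = -1.0
    for _ in range(steps):
        residual = sign(x) - 1.0  # dL/d(sign) for L = 0.5 * (sign(x) - 1) ** 2
        # The true derivative of sign is 0 for x != 0; the surrogate is sigma'(x).
        local = sigmoid_derivative(x) if surrogate else 0.0
        x -= lr * residual * local
    return x

x_true = train(surrogate=False)
x_surrogate = train(surrogate=True)
print(x_true, sign(x_true))
print(x_surrogate, sign(x_surrogate))
```

With the true gradient, `x` never leaves `-1`. With the surrogate, the first gradient is $-2\,\sigma'(-1) \approx -0.393$, the same value the training log reports at step 0, and `x` crosses zero after a few steps.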

### Training Success! 🎉

The training log shows:

1. **Initial State**: `x = -1`, `sign(x) = -1`, `loss = 2`
2. **Optimization**: The surrogate gradient is negative (about `-0.39` at step 0), so the optimizer increases `x`
3. **Final State**: `x ≈ 0.99`, `sign(x) = 1`, `loss = 0`

By step 10 `x` is positive, and both the loss and the gradient are exactly zero. The further drift of `x` toward `0.99` comes from Adam's accumulated moment estimates, not from new gradient information.

**Key Achievements**:
- ✅ **Non-differentiable function made trainable**: Sign function integrated into gradient-based optimization
- ✅ **Discrete behavior preserved**: Forward pass maintains exact sign function behavior
- ✅ **Smooth optimization**: Custom gradients enable efficient parameter updates
- ✅ **Convergence**: Loss reaches zero by step 10 and the target `sign(x) = 1` is achieved

## Summary

Custom gradients are essential for differentiable programming because they enable:

1. **Giving discrete operations usable surrogate gradients** for optimization
2. **Maintaining exact forward behavior** while providing smooth gradients
3. **Integrating non-differentiable functions** into neural network architectures
4. **Enabling end-to-end training** of complex algorithms

This technique is fundamental to many advanced differentiable programming applications, from differentiable sorting to program synthesis!

In [3]:
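@tf.custom_gradient
def bar(x, y):
    """Compute x * y and return one gradient per input."""
    tf.debugging.assert_rank(x, 0)
    tf.debugging.assert_rank(y, 0)

    def grad(upstream):
        dz_dx = y  # partial derivative of x * y with respect to x
        dz_dy = x  # partial derivative of x * y with respect to y
        # One entry per input, in the same order as bar's parameters (x, y)
        return upstream * dz_dx, upstream * dz_dy

    z = x * y

    return z, grad

x = tf.constant(2.0, dtype=tf.float32)
y = tf.constant(3.0, dtype=tf.float32)

# persistent=True because gradient() is called several times on the same tape
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    tape.watch(y)
    z = bar(x, y)

tf.print(z)
tf.print(tape.gradient(z, x))  # dz/dx = y = 3
tf.print(tape.gradient(z, y))  # dz/dy = x = 2
tf.print(tape.gradient(x, y))  # None: x does not depend on y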

6
3
2
None


## Application of Custom Gradients

### Toy Example: Differentiable Approximation of a Non-Differentiable Function

We take the sign function as an example:

$$\text{sign}(x) = \begin{cases}
  -1, & \text{if } x < 0 \\
  0, & \text{if } x = 0 \\
  1, & \text{if } x > 0
\end{cases}$$

By implementing a custom gradient, we keep $\text{sign}(x)$ in the forward pass and use the derivative of a differentiable approximation in the backward pass. Here the approximation is the sigmoid function $\sigma(x)$, up to an additive constant $C$ that does not affect the gradient:

\begin{align}
  \text{sign}_{\text{approx}}(x) &= \sigma(x) + C\\
  \frac{d\,\text{sign}_{\text{approx}}(x)}{dx} &= \sigma(x)(1 - \sigma(x))
\end{align}

In [4]:
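# @tf.function
@tf.custom_gradient
def differentiable_sign(x):
    """Exact sign(x) in the forward pass, sigmoid derivative in the backward pass."""
    tf.debugging.assert_rank(x, 0)

    def grad(upstream):
        # Surrogate derivative: sigma(x) * (1 - sigma(x)), with sigma computed once
        s = tf.math.sigmoid(x)
        dy_dx = s * (1 - s)
        return upstream * dy_dx

    # tf.sign returns -1, 0, or 1, matching the piecewise definition of sign(x)
    return tf.sign(x), grad


x = tf.constant(3.0, dtype=tf.float32)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = differentiable_sign(x)
    # tf.nn.l2_loss(t) = sum(t ** 2) / 2; the target here is -1
    loss = tf.nn.l2_loss(y - tf.constant(-1.0))

tf.print(y)                       # forward value sign(3)
tf.print(tape.gradient(y, x))     # surrogate gradient sigma'(3)
tf.print(loss)
tf.print(tape.gradient(loss, x))  # (y + 1) * sigma'(3)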

1
0.0451766551
2
0.0903533101


In [5]:
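x = tf.Variable(-1.0)  # trainable parameter, starts on the wrong side of zero
opt = tf.keras.optimizers.Adam(1e-1)
# opt = tf.keras.optimizers.SGD(1)

def train_step():
    with tf.GradientTape() as tape:  # tf.Variable objects are watched automatically
        y = differentiable_sign(x)
        loss = tf.nn.l2_loss(y - tf.constant(1.0))  # (y - 1) ** 2 / 2
    grads = tape.gradient(loss, x)
    opt.apply_gradients(zip([grads], [x]))
    return loss, y, grads

for i in range(100):
    loss, y, grads = train_step()
    if i % 10 == 0:
        # Columns: step, loss, gradient, x after the update, sign(x) before the update
        tf.print(i, loss, grads, x, y)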

0 2 -0.393223882 -0.89999783 -1
10 0 0 0.0995881185 1
20 0 0 0.6450876 1
30 0 0 0.855591536 1
40 0 0 0.938390434 1
50 0 0 0.970632374 1
60 0 0 0.983009696 1
70 0 0 0.987700462 1
80 0 0 0.989459515 1
90 0 0 0.990113616 1
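
In the log above the gradient is already zero at step 10, yet `x` keeps growing from about `0.10` to about `0.99`. The plain-Python sketch below reproduces this effect qualitatively with a textbook Adam update. The hyperparameters `beta1`, `beta2`, and `eps` are assumed defaults, and the Keras optimizer may differ in numerical details, so the exact values are not expected to match the log.

```python
import math

def sign(x):
    return (x > 0) - (x < 0)

def surrogate_grad(x):
    # dL/dx for L = 0.5 * (sign(x) - 1) ** 2 with the sigmoid-derivative surrogate
    s = 1.0 / (1.0 + math.exp(-x))
    return (sign(x) - 1.0) * s * (1.0 - s)

lr, beta1, beta2, eps = 0.1, 0.9, 0.999, 1e-7  # assumed textbook Adam settings
x, m, v = -1.0, 0.0, 0.0
history = []
for t in range(1, 101):
    g = surrogate_grad(x)
    m = beta1 * m + (1 - beta1) * g      # first-moment (momentum) estimate
    v = beta2 * v + (1 - beta2) * g * g  # second-moment estimate
    m_hat = m / (1 - beta1 ** t)         # bias corrections
    v_hat = v / (1 - beta2 ** t)
    x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    history.append((g, x))

# Index of the first step whose gradient is exactly zero (x is already positive there).
first_zero = next(i for i, (g, _) in enumerate(history) if g == 0.0)
print(first_zero, history[first_zero][1], history[-1][1])
```

After the gradient vanishes, `m` decays by a factor of `beta1` per step but stays negative, so each update still pushes `x` upward by a shrinking amount. With plain SGD, `x` would stop moving as soon as the gradient reached zero.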
