# Automatic differentiation lab work

*Notebook prepared by Mathieu Blondel, November 2020.
The accompanying slides are available [here](https://www.mblondel.org/teaching/autodiff-2020.pdf).*

In this lab work, we are going to implement reverse-mode differentiation (a.k.a. backpropagation) for a feedforward network, that is, the composition of a **sequence** or **chain** of functions.

## Numerical differentiation utilities

In this section, I define utility functions for numerically computing Jacobians, Jacobian-vector products (JVPs), and vector-Jacobian products (VJPs) by finite differences. You will use them to check the correctness of your implementations.

```python
import numpy as np

def num_jvp(f, x, v, eps=1e-6):
  """
  Args:
    f: a function returning an array.
    x: an array.
    v: an array (same shape as x).
    eps: finite-difference step size.

  Returns:
    numerical_jvp: central-difference estimate of the JVP of f at x along v.
  """
  if x.shape != v.shape:
    raise ValueError("x and v should have the same shape.")

  return (f(x + eps * v) - f(x - eps * v)) / (2 * eps)

def num_jacobian(f, x, eps=1e-6):
  """
  Args:
    f: a function returning an array.
    x: an array (only 1d and 2d arrays supported).

  Returns:
    numerical_jacobian. For an output of shape (m,), its shape is (m, n)
    for x of shape (n,) and (m, a, b) for x of shape (a, b).
  """
  def e(i):
    # i-th standard basis vector (1d case).
    ret = np.zeros_like(x)
    ret[i] = 1
    return ret

  def E(i, j):
    # (i, j)-th standard basis matrix (2d case).
    ret = np.zeros_like(x)
    ret[i, j] = 1
    return ret

  if len(x.shape) == 1:
    return np.array([num_jvp(f, x, e(i), eps=eps) for i in range(len(x))]).T
  elif len(x.shape) == 2:
    return np.array([[num_jvp(f, x, E(i, j), eps=eps)
                      for i in range(x.shape[0])]
                     for j in range(x.shape[1])]).T
  else:
    raise NotImplementedError

def num_vjp(f, x, u, eps=1e-6):
  """
  Args:
    f: a function returning an array.
    x: an array (only 1d and 2d arrays supported).
    u: an array (same shape as f(x)).

  Returns:
    numerical_vjp (same shape as x)
  """
  J = num_jacobian(f, x, eps=eps)
  if len(J.shape) == 2:
    return J.T.dot(u)
  elif len(J.shape) == 3:
    # Flatten the input dimensions, contract with u, then restore x's shape.
    shape = J.shape[1:]
    J = J.reshape(J.shape[0], -1)
    return u.dot(J).reshape(shape)
  else:
    raise NotImplementedError
```
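As a quick illustration of what these utilities rely on, the sketch below compares a central-difference JVP (the same formula as `num_jvp`) with the exact JVP of a small hand-differentiated function, and checks the identity $\langle u, Jv \rangle = \langle J^\top u, v \rangle$ that links JVPs and VJPs. The function `f`, its Jacobian `jacobian_f`, and the helper `central_jvp` are hypothetical and exist only for this example.

```python
import numpy as np

# Hypothetical test function f: R^3 -> R^2 with a hand-derived Jacobian.
def f(x):
  return np.array([x[0] * x[1], np.sin(x[2])])

def jacobian_f(x):
  return np.array([[x[1], x[0], 0.],
                   [0., 0., np.cos(x[2])]])

def central_jvp(f, x, v, eps=1e-6):
  # Central difference, as in num_jvp: error is O(eps**2) for smooth f.
  return (f(x + eps * v) - f(x - eps * v)) / (2 * eps)

rng = np.random.RandomState(0)
x = rng.randn(3)
v = rng.randn(3)  # direction in input space
u = rng.randn(2)  # cotangent in output space

J = jacobian_f(x)
jvp = central_jvp(f, x, v)
vjp = J.T.dot(u)  # exact VJP

print(np.allclose(jvp, J.dot(v)))               # True: JVP matches J v
print(np.isclose(u.dot(J.dot(v)), vjp.dot(v)))  # True: <u, J v> == <J^T u, v>
```

The second check is the reason `num_vjp` can simply build the Jacobian column by column from JVPs and then contract it with `u`.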


## Vector Jacobian products (VJPs) for basic primitives

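In this section, we are going to define VJPs for basic primitives. Each primitive `f` gets a `make_vjp` attribute: `f.make_vjp(*args)` takes the same arguments as `f` and returns a function `vjp(u)`, which returns a tuple with one VJP per argument of `f`.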

```python
def dot(x, W):
  # Matrix-vector product W x; W is the learnable parameter.
  return np.dot(W, x)

def dot_make_vjp(x, W):
  def vjp(u):
    # VJPs w.r.t. x and W, in the same order as the arguments.
    return W.T.dot(u), np.outer(u, x)
  return vjp

dot.make_vjp = dot_make_vjp

def squared_loss(y_pred, y):
  # The code requires every output to be an array.
  return np.array([0.5 * np.sum((y - y_pred) ** 2)])

def squared_loss_make_vjp(y_pred, y):
  diff = y_pred - y

  def vjp(u):
    return diff * u, -diff * u

  return vjp

squared_loss.make_vjp = squared_loss_make_vjp

def add(a, b):
  return a + b

def add_make_vjp(a, b):
  gprime = np.ones(len(a))

  def vjp(u):
    return u * gprime, u * gprime

  return vjp

add.make_vjp = add_make_vjp

def mul(a, b):
  return a * b

def mul_make_vjp(a, b):
  gprime_a = b
  gprime_b = a

  def vjp(u):
    return u * gprime_a, u * gprime_b

  return vjp

mul.make_vjp = mul_make_vjp

def exp(x):
  return np.exp(x)

def exp_make_vjp(x):
  gprime = exp(x)

  def vjp(u):
    # Unary primitives return a 1-tuple (note the trailing comma).
    return u * gprime,

  return vjp

exp.make_vjp = exp_make_vjp

def sqrt(x):
  return np.sqrt(x)

def sqrt_make_vjp(x):
  gprime = 1. / (2 * sqrt(x))

  def vjp(u):
    return u * gprime,

  return vjp

sqrt.make_vjp = sqrt_make_vjp
```
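A VJP can also be checked without forming the full Jacobian: for a primitive `g` and any direction $d$, the derivative of $t \mapsto \langle u, g(x + t d) \rangle$ at $t = 0$ equals $\langle \mathrm{VJP}, d \rangle$. The minimal sketch below applies this to `dot` for both of its arguments, copying its definition from above. The helper `central_diff` and the random directions `dx` and `dW` are assumptions introduced for illustration; for matrices, the inner product is the sum of elementwise products.

```python
import numpy as np

def dot(x, W):
  return np.dot(W, x)

def dot_make_vjp(x, W):
  def vjp(u):
    return W.T.dot(u), np.outer(u, x)
  return vjp

def central_diff(g, eps=1e-6):
  # Hypothetical helper: central-difference derivative at t = 0 of a scalar g(t).
  return (g(eps) - g(-eps)) / (2 * eps)

rng = np.random.RandomState(0)
x, W = rng.randn(3), rng.randn(2, 3)
u = rng.randn(2)                        # cotangent for the output W x
dx, dW = rng.randn(3), rng.randn(2, 3)  # hypothetical perturbation directions

vjp_x, vjp_W = dot_make_vjp(x, W)(u)

# d/dt <u, dot(x + t dx, W)> at t = 0 should equal <vjp_x, dx>; same for W.
lhs_x = central_diff(lambda t: u.dot(dot(x + t * dx, W)))
lhs_W = central_diff(lambda t: u.dot(dot(x, W + t * dW)))
print(np.isclose(lhs_x, vjp_x.dot(dx)))       # True
print(np.isclose(lhs_W, np.sum(vjp_W * dW)))  # True
```

This costs one pair of function evaluations per direction instead of one per input coordinate, which matters when the parameter is a large matrix.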

**Exercise 1.**

Following the `exp` and `sqrt` examples above, define the primitive and its associated VJP for the ReLU function `relu(x) = np.maximum(x, 0)`, and attach the VJP as `relu.make_vjp`. Check the correctness of your implementation using the `num_vjp` utility function.

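```python
def relu(x):
  return  # TODO: implement relu(x) = np.maximum(x, 0).

# TODO: define relu_make_vjp(x) and attach it with relu.make_vjp = relu_make_vjp.

rng = np.random.RandomState(0)
x = rng.randn(5)
u = rng.randn(5)

# Check the correctness of your VJP against num_vjp
# (num_vjp takes the function itself, not its VJP):
# vjp_x, = relu.make_vjp(x)(u)
# print(np.allclose(vjp_x, num_vjp(relu, x, u)))
```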
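If the pattern is unclear, here is the same construction for a different elementwise primitive, `tanh`. It is a hypothetical addition used only for illustration (it is not part of the network below), checked against a finite-difference Jacobian built one basis vector at a time, the same way `num_jacobian` does for 1d inputs.

```python
import numpy as np

def tanh(x):
  return np.tanh(x)

def tanh_make_vjp(x):
  # d/dx tanh(x) = 1 - tanh(x)**2, applied elementwise.
  gprime = 1. - np.tanh(x) ** 2

  def vjp(u):
    # Unary primitives return a 1-tuple, like exp and sqrt above.
    return u * gprime,

  return vjp

tanh.make_vjp = tanh_make_vjp

rng = np.random.RandomState(0)
x = rng.randn(5)
u = rng.randn(5)

# Finite-difference Jacobian: row i of the list is the JVP along e_i, i.e. column i of J.
eps = 1e-6
J = np.array([(tanh(x + eps * e) - tanh(x - eps * e)) / (2 * eps)
              for e in np.eye(len(x))]).T

vjp_x, = tanh.make_vjp(x)(u)
print(np.allclose(vjp_x, J.T.dot(u)))  # True
```

For an elementwise function the Jacobian is diagonal, so the VJP reduces to an elementwise product with the derivative, which is why `vjp` never builds `J`.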

## Reverse differentiation of feedforward networks

A feedforward network is a sequence of functions. Each function is either of the form `func(x, param)` if it takes a parameter (e.g., `dot(x, W)`, where `W` is learnable) or `func(x)` if it takes none (e.g., `exp(x)`). The target `y` of `squared_loss(y_pred, y)` is passed the same way as a parameter, even though it is not learned.

We represent a feedforward network using a list of functions and a list of parameters. Let us create a small utility function for creating such a network.

```python
def create_feed_forward(n, y, seed=0):
  rng = np.random.RandomState(seed)

  funcs = [
    dot,
    relu,
    dot,
    relu,
    dot,
    squared_loss
  ]

  params = [
    rng.randn(3, n),
    None,
    rng.randn(4, 3),
    None,
    rng.randn(1, 4),
    y
  ]

  return funcs, params
```
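To see how shapes flow through this network, here is the forward pass written out by hand with the same layer sizes as `create_feed_forward` for `n = 2`. It is a sketch with its own random draws, so only the shapes, not the values, match the network above; `relu` is written directly as `np.maximum(z, 0)`. Because the last weight matrix has a single row, the network output and the loss both have shape `(1,)`, so the Jacobian in Exercise 3 has a single row and shape `(1, 2)`.

```python
import numpy as np

def relu(z):
  return np.maximum(z, 0)

rng = np.random.RandomState(0)
n = 2
x = rng.randn(n)
W1, W3, W5 = rng.randn(3, n), rng.randn(4, 3), rng.randn(1, 4)
y = 1.5

x1 = np.dot(W1, x)   # dot(x, W1): shape (3,)
x2 = relu(x1)        # shape (3,)
x3 = np.dot(W3, x2)  # dot(x2, W3): shape (4,)
x4 = relu(x3)        # shape (4,)
x5 = np.dot(W5, x4)  # dot(x4, W5): shape (1,)
loss = np.array([0.5 * np.sum((y - x5) ** 2)])  # squared_loss(x5, y): shape (1,)

print([a.shape for a in (x, x1, x2, x3, x4, x5, loss)])
# [(2,), (3,), (3,), (4,), (4,), (1,), (1,)]
```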

Next, let us create a small utility function that calls each function correctly, depending on whether it takes one or two arguments.

```python
def call_func(x, func, param):
  """Make sure the function is called with the correct number of arguments."""

  if param is None:
    # Unary function
    return func(x)
  else:
    # Binary function
    return func(x, param)
```

**Exercise 2.** 

Implement the following function for evaluating the feedforward network. Check that the returned value matches the manual implementation given in the comments at the end of the cell.

```python
def evaluate_chain(x, funcs, params, return_all=False):
  """
  Evaluate a chain of functions.

  Args:
    x: initial input to the chain.
    funcs: a list of functions of the form func(x) or func(x, param).
    params: a list of parameters, with len(params) = len(funcs).
            If a function doesn't have parameters, use None.
    return_all: whether to return all intermediate values or only the last one.

  Returns:
    value (if return_all=False) or list of values (if return_all=True)
  """
  if len(funcs) != len(params):
    raise ValueError("len(funcs) and len(params) should be equal.")

  if return_all:
    return
  else:
    return

rng = np.random.RandomState(0)
x = rng.randn(2)
y = 1.5

funcs, params = create_feed_forward(n=len(x), y=y, seed=0)
W1, _, W3, _, W5, y = params

# Make sure that `evaluate_chain(x, funcs, params)` returns the same value as
# a manual implementation:
# x1 = dot(x, W1)
# x2 = relu(x1)
# x3 = dot(x2, W3)
# x4 = relu(x3)
# x5 = dot(x4, W5)
# value = squared_loss(x5, y)
```

**Exercise 3.**

Reusing the previous function with `return_all=True`, implement the following function, which returns both the network value and the Jacobian of the network w.r.t. `x`. Check the correctness of the Jacobian using `num_jacobian`.
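The backward loop can be summarized by the standard reverse-mode recursion, sketched below; the symbols are introduced here for illustration. Write the network as $f = f_K \circ \dots \circ f_1$, with $x_0 = x$ and $x_k = f_k(x_{k-1})$ (parameters left implicit), and let $J_k = \partial f_k(x_{k-1})$ be the Jacobian of the $k$-th function at its input. With $m$ outputs:

$$
\partial f(x) = J_K J_{K-1} \cdots J_1, \qquad
U_K = I_m, \qquad
U_{k-1} = U_k J_k \quad (k = K, \dots, 1), \qquad
U_0 = \partial f(x).
$$

Each row $u^\top$ of $U_k$ is mapped to $u^\top J_k$, which is exactly one VJP of $f_k$ evaluated at $x_{k-1}$. This is why `U` is stored as a list of $m$ rows, and why the intermediate values `xs` from the forward pass are kept.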

```python
def call_vjp(x, func, param, u):
  """Make sure the vjp is called with the correct number of arguments."""
  if param is None:
    vjp = func.make_vjp(x)
    vjp_x, = vjp(u)
    vjp_param = None
  else:
    vjp = func.make_vjp(x, param)
    vjp_x, vjp_param = vjp(u)
  return vjp_x, vjp_param


def reverse_diff_chain(x, funcs, params):
  """
  Reverse-mode differentiation of a chain of computations.

  Args:
    x: initial input to the chain.
    funcs: a list of functions of the form func(x) or func(x, param).
    params: a list of parameters, with len(params) = len(funcs).
            If a function doesn't have parameters, use None.

  Returns:
    value, Jacobian w.r.t. x
  """
  # Evaluate the feedforward model and store intermediate computations,
  # as they will be needed during the backward pass.
  xs = evaluate_chain(x, funcs, params, return_all=True)

  m = xs[-1].shape[0]  # Output size
  K = len(funcs)  # Number of functions.

  # U holds the rows of the Jacobian being accumulated (one row per output).
  # We use a list of rows because the shape of each row changes as we move
  # backward through the chain.
  U = list(np.eye(m))

  for k in reversed(range(K)):
    # Implement backward differentiation here.
    pass

  return xs[-1], np.array(U)

# Check correctness of the Jacobian using `num_jacobian`.
# def f(x):
#   return evaluate_chain(x, funcs, params)
# # num_jacobian only accepts functions of one argument.
# num_jac = num_jacobian(f, x)
# value, jac = reverse_diff_chain(x, funcs, params)
# print(np.allclose(num_jac, jac))
```
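The backward loop can be sanity-checked on a case where the Jacobian is known in closed form: for a chain of linear maps, the Jacobian is the product of the weight matrices. The sketch below uses plain matrices instead of the `funcs`/`params` machinery (an assumption made to keep it short). It pulls the rows of the identity back through the chain with VJPs, last map first, and compares the result with the explicit product.

```python
import numpy as np

rng = np.random.RandomState(0)
# A chain of three linear maps R^2 -> R^3 -> R^4 -> R^1:
# f(x) = Ws[2] Ws[1] Ws[0] x.
Ws = [rng.randn(3, 2), rng.randn(4, 3), rng.randn(1, 4)]

# Reverse accumulation: start from the rows of the identity in output space
# and pull each row back through the chain, last map first.
m = Ws[-1].shape[0]
U = list(np.eye(m))
for W in reversed(Ws):
  # For a linear map z -> W z, the VJP of a row u is W.T.dot(u).
  U = [W.T.dot(u) for u in U]

jac_reverse = np.array(U)                   # shape (1, 2)
jac_forward = Ws[2].dot(Ws[1]).dot(Ws[0])   # explicit product
print(np.allclose(jac_reverse, jac_forward))  # True
```

With a single output, the backward pass needs only one VJP per map, whereas building the product from the input side would carry a full matrix through every step.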

**Bonus exercise.**

Modify the above function to also return the Jacobians w.r.t. W1, W3, W5. Check correctness using `num_jacobian`.
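The same backward pass yields the parameter Jacobians. A sketch of the idea, with notation assumed here: $x_0 = x$, $x_k = f_k(x_{k-1}, W_k)$, $J_k$ is the Jacobian of $f_k$ w.r.t. its input, and $U_k = J_K \cdots J_{k+1}$ (with $U_K = I_m$) holds the rows just before $f_k$ is processed. By the chain rule,

$$
\frac{\partial f(x)}{\partial W_k} = U_k \, \frac{\partial f_k(x_{k-1}, W_k)}{\partial W_k}.
$$

For each row $u^\top$ of $U_k$, the corresponding slice of the Jacobian w.r.t. $W_k$ is the parameter VJP `vjp_param` returned by `call_vjp` (for `dot`, this is `np.outer(u, x)` with `x` the layer input $x_{k-1}$). It must be computed from $U_k$ before `U` is overwritten by $U_{k-1}$. Stacking the $m$ slices gives an array of shape `(m,) + W_k.shape`, which matches the layout returned by `num_jacobian` for 2d inputs.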