Written by Dr. Jonathan Terhorst

In [None]:
import tensorflow as tf
import numpy as np

In [None]:
def reloid(x):
  '''
  Mixed activation: (sigma(x[0]), ..., sigma(x[-2]), relu(x[-1])).

  Sigmoid is applied to every slice of x along axis 0 except the last one,
  ReLU is applied to the last slice, and the two pieces are concatenated
  back together along axis 0.
  '''
  squashed = tf.nn.sigmoid(x[:-1])
  rectified = tf.nn.relu(x[-1:])  # slice (not index) so the leading axis is kept
  return tf.concat([squashed, rectified], axis=0)

def R(x, theta):
  '''
  One-hidden-layer map with positive weights, shifted so that R(0, theta) = 0.

  x: scalar
  theta: (3, H); rows are (w, v, b) = log input weights, log output weights, biases

  Returns a (1, 1) tensor: exp(v)^T reloid(exp(w) x + b) - exp(v)^T reloid(b).
  '''
  w, v, b = theta[..., None]  # each row becomes an (H, 1) column vector
  in_weights = tf.math.exp(w)
  out_weights_t = tf.transpose(tf.math.exp(v))  # (1, H)
  x = tf.reshape(x, (1, 1))
  # Value of the network at x = 0; subtracting it pins R(0, theta) to zero.
  offset = out_weights_t @ reloid(b)
  hidden = reloid(in_weights @ x + b)
  return out_weights_t @ hidden - offset

def Rinv(y, theta):
  '''
  Numerically invert R in its first argument by bisection.

  y: scalar target value; the search starts from x_left = 0 (where R is 0),
     so y is assumed to be nonnegative
  theta: (3, H), same layout as for R

  Runs 50 bisection steps on [0, x_right] and returns the last midpoint, a
  (1, 1) tensor x with R(x, theta) ~= y. This relies on R being increasing
  in x. The assertion below fails if the heuristic upper bracket does not
  satisfy R(x_right, theta) > y.
  '''
  w, v, _ = theta[..., None]  # convert to vectors; the bias row is not needed here
  x_left = 0
  # As x -> oo, R is asymptotic to exp(v[-1] + w[-1]) x.
  # FIXME: this upper bracket is a heuristic that ignores y; calculate it exactly.
  x_right = 10 * tf.math.exp(w[-1] + v[-1])
  tf.debugging.assert_greater(R(x_right, theta), y)
  for _ in range(50):
    x_i = (x_left + x_right) / 2.
    y_i = R(x_i, theta)
    # 1 where the midpoint undershoots y (the root lies to its right), else 0.
    # Cast to R's own dtype instead of Python `float` (which TensorFlow maps to
    # float32) so that a non-float32 theta does not cause a dtype mismatch in
    # the arithmetic blend below.
    left = tf.cast(y_i < y, dtype=y_i.dtype)
    x_left = left * x_i + (1. - left) * x_left
    x_right = (1. - left) * x_i + left * x_right
  return x_i

# Round-trip sanity check: draw a random (3, 10) float32 parameter matrix,
# push x = 5 through R, then check that Rinv recovers x from the resulting y.
theta = tf.convert_to_tensor(np.random.rand(3, 10).astype(np.float32))
x = 5.
y = R(x, theta)
x_star = Rinv(y, theta)
print(x, x_star)

5.0 tf.Tensor([[5.]], shape=(1, 1), dtype=float32)


## Derivative of $R^{-1}$
We have $R(R^{-1}(y, \theta), \theta) = y.$ Therefore, $$0 = \frac{\partial R(R^{-1}(y, \theta), \theta)}{\partial \theta} = \left. \frac{ \partial R(x,\theta)}{\partial x} \right|_{x=R^{-1}(y,\theta)} \times \frac{\partial R^{-1}(y,\theta)}{\partial \theta} + \left.\frac{\partial R(x,\theta)}{\partial \theta}\right|_{x=R^{-1}(y,\theta)}.$$ Hence,
$$\frac{\partial R^{-1}(y,\theta)}{\partial \theta} = \left.-\frac{\partial R(x,\theta)/\partial \theta}{\partial R(x,\theta)/\partial x}\right|_{x=R^{-1}(y,\theta)}.$$



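For the derivative with respect to $y$ we get $$1 = \frac{\partial R(R^{-1}(y, \theta), \theta)}{\partial y} = \left. \frac{\partial{R}}{\partial x} \right|_{x=R^{-1}(y,\theta)} \times \frac{\partial R^{-1}(y,\theta)}{\partial y},$$ and hence $$\frac{\partial R^{-1}(y,\theta)}{\partial y} = \left. \frac{1}{\partial R(x,\theta)/\partial x} \right|_{x=R^{-1}(y,\theta)}.$$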

In [None]:
@tf.custom_gradient
def custom_Rinv(y, theta):
  '''
  Rinv with gradients supplied by implicit differentiation.

  y: scalar
  theta: (3, H)

  The forward value comes from the bisection in Rinv. The gradients are
  obtained by differentiating R at the solution x = Rinv(y, theta):
    dx/dy     =  1 / (dR/dx)
    dx/dtheta = -(dR/dtheta) / (dR/dx)
  Returns x, a (1, 1) tensor.
  '''
  x = Rinv(y, theta)
  # Re-evaluate the forward map at the solution under a tape. The result gets
  # its own name so the argument y stays available to grad() below.
  with tf.GradientTape() as g:
    g.watch([x, theta])
    y_hat = R(x, theta)
  dR_dtheta, dR_dx = g.gradient(y_hat, [theta, x])
  def grad(dx):
    # dx is the upstream gradient with respect to the (1, 1) output x.
    # The gradient for y is reshaped to y's own shape (dR_dx is (1, 1)), so it
    # can be accumulated with other gradients flowing into y.
    dy = tf.reshape(dx / dR_dx, tf.shape(y))
    dtheta = -dx * dR_dtheta / dR_dx
    return dy, dtheta
  return x, grad

# Smoke test: evaluate the inverse at y = 5 using the theta drawn earlier.
custom_Rinv(tf.constant(5.), theta)

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.6888485]], dtype=float32)>

In [None]:
def f(y, theta):
  'Cube of custom_Rinv(y, theta); a test function for checking the custom gradient.'
  x_inv = custom_Rinv(y, theta)
  return x_inv ** 3
# Gradient check at y = 1: tf.test.compute_gradient returns the theoretical
# (here: custom-gradient) and numerical Jacobians of f for each input.
list(tf.test.compute_gradient(f, [tf.constant(1.), theta]))

tf.Tensor([[0.04421389]], shape=(1, 1), dtype=float32)
tf.Tensor([[0.04421389]], shape=(1, 1), dtype=float32)


[(array([[0.00546339]], dtype=float32),
  array([[-3.1790059e-04, -2.5330819e-04, -8.2001847e-04, -2.4853079e-04,
          -2.5501815e-04, -4.4931806e-04, -6.0675596e-04, -7.5498206e-04,
          -1.9648220e-04, -1.4652527e-03, -3.3134641e-04, -2.7074703e-04,
          -8.3309761e-04, -2.5655897e-04, -2.5879024e-04, -4.5377173e-04,
          -6.3056935e-04, -7.6361018e-04, -1.9964894e-04, -1.4652526e-03,
           1.5377718e-04,  1.3415341e-04,  8.6978725e-05,  1.2126378e-04,
           5.1921903e-05,  3.8855134e-05,  1.5877417e-04,  5.6325258e-05,
           4.0780418e-05, -0.0000000e+00]], dtype=float32)),
 (array([[0.00545912]], dtype=float32),
  array([[-3.15951154e-04, -2.51340680e-04, -8.18632485e-04,
          -2.46393029e-04, -2.55589839e-04, -4.48373583e-04,
          -6.05825335e-04, -7.51751882e-04, -1.98313486e-04,
          -1.46491209e-03, -3.33064207e-04, -2.70141754e-04,
          -8.30448640e-04, -2.57277861e-04, -2.55939085e-04,
          -4.50003397e-04, -6.302142