<a href="https://colab.research.google.com/github/applejxd/colaboratory/blob/master/algorithm/AutomaticDerivative.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Automatic Differentiation from Scratch

[Automatic Derivatives - Ceres Solver](http://ceres-solver.org/automatic_derivatives.html)

In [31]:
import numpy as np

class Jet:
    """Dual number ``a + v`` used for forward-mode automatic differentiation.

    ``a`` is the ordinary value and ``v`` holds the infinitesimal parts, one
    entry per seeded input. Every arithmetic operator returns a new ``Jet``
    whose ``a`` is the value of the expression and whose ``v`` is obtained
    from the sum / product / quotient / power rules applied to the operands.
    Plain ``int`` / ``float`` operands are treated as constants (zero ``v``).
    """

    def __init__(self, a: float, v):
        """Store the value ``a`` and the infinitesimal parts ``v`` (as an array)."""
        self.a = a
        self.v = np.array(v)

    def __str__(self):
        return f"{self.a}+{self.v}"

    def _lift(self, other):
        """Return ``other`` as a Jet; numbers become constants with zero ``v``."""
        if isinstance(other, (int, float)):
            return Jet(other, np.zeros(self.v.shape))
        return other

    def __neg__(self):
        return Jet(-self.a, -self.v)

    def __add__(self, other):
        other = self._lift(other)
        return Jet(self.a + other.a, self.v + other.v)

    def __sub__(self, other):
        other = self._lift(other)
        return Jet(self.a - other.a, self.v - other.v)

    def __mul__(self, other):
        other = self._lift(other)
        # Product rule: (f g)' = f g' + f' g
        return Jet(self.a * other.a, self.a * other.v + self.v * other.a)

    def __truediv__(self, other):
        other = self._lift(other)
        # Quotient rule: (f / g)' = f' / g - f g' / g^2
        return Jet(self.a / other.a,
                   self.v / other.a - self.a * other.v / other.a ** 2)

    def __pow__(self, other):
        other = self._lift(other)
        value = self.a ** other.a
        # Contribution of the base: g * f^(g-1) * f'
        base_term = other.a * self.a ** (other.a - 1) * self.v
        # Contribution of the exponent: f^g * log(f) * g'.
        # Only evaluate log(f) when the exponent actually varies; otherwise a
        # non-positive base would yield nan/-inf and poison the derivative
        # through nan * 0.
        if np.any(other.v):
            exponent_term = value * np.log(self.a) * other.v
        else:
            exponent_term = np.zeros(other.v.shape)
        return Jet(value, base_term + exponent_term)

    # Reflected operators: invoked for ``number <op> Jet``.
    def __radd__(self, other):
        return self.__add__(other)

    def __rsub__(self, other):
        # Subtraction is not commutative: compute other - self.
        return self._lift(other).__sub__(self)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __rtruediv__(self, other):
        return self._lift(other).__truediv__(self)

    def __rpow__(self, other):
        return self._lift(other).__pow__(self)


# Two jets with two infinitesimal components each; v may be given as a tuple
# or as an ndarray.
x = Jet(1, (2, 3))
y = Jet(4, np.array((5, 6)))

# Jet + Jet, then number * Jet (dispatched through __rmul__), one per line.
print(x + y, 2 * x, sep="\n")

5+[7 9]
2+[4. 6.]


In [29]:
def exp(x):
    """Exponential of a Jet: value e^a, infinitesimal part e^a * v (chain rule)."""
    exp_a = np.exp(x.a)
    return Jet(exp_a, exp_a * x.v)


# Sanity check at a = 2 with v = (3, 4): expect e^2 and e^2 * (3, 4).
print(exp(Jet(2, (3, 4))))

7.38905609893065+[22.1671683 29.5562244]


\begin{equation}
\left.\nabla (x^2+2y^2+3xy)\right|_{x=2,\,y=1}
\end{equation}

In [32]:
def target_func(x, y):
    """Evaluate x^2 + 2*y^2 + 3*x*y for any operands supporting **, * and +."""
    x_squared = x**2
    twice_y_squared = 2 * y**2
    cross_term = 3 * x * y
    return x_squared + twice_y_squared + cross_term

# Seed each input with a one-hot infinitesimal part so that component i of
# the result's v is the partial derivative with respect to input i.
x = Jet(a=2, v=(1, 0))  # x = 2, first slot tracks d/dx
y = Jet(a=1, v=(0, 1))  # y = 1, second slot tracks d/dy

print(target_func(x, y))

12+[ 7. 10.]


## Automatic Differentiation with JAX

From [the quickstart](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)

In [34]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

\begin{equation}
    \left.\nabla_x \sum_i (1+e^{-x_i})^{-1}\right|_{x_0=0, x_1=1, x_2=2}
\end{equation}

In [35]:
def sum_logistic(x):
  """Return the sum over all entries of the logistic function 1 / (1 + e^-x)."""
  logistic = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(logistic)

# Evaluation point x = [0., 1., 2.].
x_small = jnp.arange(3.)
# grad() builds a new function that returns the gradient of sum_logistic
# with respect to its argument.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))



[0.25       0.19661197 0.10499357]


Check the result against a numerical derivative (central finite differences).

In [37]:
def first_finite_differences(f, x, eps=1e-3):
  """Approximate the gradient of ``f`` at ``x`` with central differences.

  Args:
    f: Callable applied to arrays shaped like ``x``; expected to return a
      scalar.
    x: 1-D array, the point at which to differentiate.
    eps: Step size of the shift along each coordinate. Defaults to 1e-3.

  Returns:
    Array with one entry per component of ``x``; entry i is
    (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps).
  """
  # Row i of the identity matrix is the unit vector e_i, i.e. a shift that
  # touches only the i-th component of x.
  unit_vectors = jnp.eye(len(x))
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                    for v in unit_vectors])


# Numerical cross-check of the gradient of sum_logistic at x_small.
print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]
