# Automatic differentiation

Basic automatic differentiation implementation in Python with `numpy`.

Inspired by [Understanding autodiff in 30 lines of Python](https://vmartin.fr/understanding-automatic-differentiation-in-30-lines-of-python.html#disqus_thread).

In [1]:
import numpy as np
from typing import TypeVar

T = TypeVar("T")

For starters, let's implement a simple Tensor class with basic arithmetic operations defined.

In [2]:
class BasicTensor:
    __slots__ = ("value",)

    def __init__(self, value: T = None) -> None:
        self.value = value

    def __repr__(self) -> str:
        return f"Tensor(value={self.value})"

    def __add__(self, other: "BasicTensor") -> "BasicTensor":
        return BasicTensor(self.value + other.value)

    def __sub__(self, other: "BasicTensor") -> "BasicTensor":
        return BasicTensor(self.value - other.value)

    def __mul__(self, other: "BasicTensor") -> "BasicTensor":
        return BasicTensor(self.value * other.value)

    def __truediv__(self, other: "BasicTensor") -> "BasicTensor":
        return BasicTensor(self.value / other.value)

Now we can use all the operations and the result will be correct, but we still have a lot to do in terms of differentiation.

In [3]:
x = BasicTensor(1)
y = BasicTensor(2)
z = BasicTensor(3)

z / (x + y + z)

Tensor(value=0.5)

In [4]:
import json
from typing import Callable, Optional

## Operations tree

But to reach our goal, we have to keep the structure of the operations, hence the `Operation`
class. It represents both unary and binary operations, so `left` is required and `right` can
be omitted, but the `func` has to be compatible with the amount of inputs and their types.

In [5]:
class Operation:
    __slots__ = ("left", "right", "func")

    def __init__(
        self,
        func: Callable[[T], T],
        left: "Tensor[T]",
        right: Optional["Tensor[T]"] = None,
    ) -> None:
        self.func = func
        self.left = left
        self.right = right

    def to_dict(self) -> dict:
        if self.right is None:
            return {
                "operation": str(self.func),
                "left": self.left.to_dict(),
            }

        return {
            "operation": str(self.func),
            "left": self.left.to_dict(),
            "right": self.right.to_dict(),
        }

    def forward(self) -> T:
        left = self.left.forward()

        if self.right is None:
            return self.func(left.value)

        right = self.right.forward()
        return self.func(left.value, right.value)

Let's make `Tensor` a little more complex, adding the `forward` method that calculates its
value using the `operation` attribute.

In [6]:
class Tensor:
    __slots__ = ("value", "operation")

    def __init__(
        self, value: T = None, operation: Optional["Operation"] = None
    ) -> None:
        self.value = value
        self.operation = operation

    def to_dict(self) -> dict:
        if self.operation is None:
            return {
                "value": float(self.value),
            }

        return {
            "value": float(self.value),
            "operation": self.operation.to_dict(),
        }

    def __repr__(self) -> str:
        return json.dumps(self.to_dict(), indent=2)

    def __add__(self, other: "Tensor") -> "Tensor":
        operation = Operation(np.add, self, other)
        return Tensor(operation=operation).forward()

    def __sub__(self, other: "Tensor") -> "Tensor":
        operation = Operation(np.subtract, self, other)
        return Tensor(operation=operation).forward()

    def __mul__(self, other: "Tensor") -> "Tensor":
        operation = Operation(np.multiply, self, other)
        return Tensor(operation=operation).forward()

    def __truediv__(self, other: "Tensor") -> "Tensor":
        operation = Operation(np.true_divide, self, other)
        return Tensor(operation=operation).forward()

    def __neg__(self) -> "Tensor":
        operation = Operation(np.negative, self)
        return Tensor(operation=operation).forward()

    def __pow__(self, other: "Tensor") -> "Tensor":
        operation = Operation(np.power, self, other)
        return Tensor(operation=operation).forward()

    def forward(self) -> "Tensor":
        if self.operation is None:
            return self

        self.value = self.operation.forward()
        return self

Now we can see structure of the operations in the form of a binary tree. 
Let's try building this expression:

$$
\frac {(x \cdot y)^z} {x^z \cdot y}
$$

In [7]:
x = Tensor(10)
y = Tensor(34)
z = Tensor(4)

((x * y) ** z) / (x**z * y)

{
  "value": 39304.0,
  "operation": {
    "operation": "<ufunc 'divide'>",
    "left": {
      "value": 13363360000.0,
      "operation": {
        "operation": "<ufunc 'power'>",
        "left": {
          "value": 340.0,
          "operation": {
            "operation": "<ufunc 'multiply'>",
            "left": {
              "value": 10.0
            },
            "right": {
              "value": 34.0
            }
          }
        },
        "right": {
          "value": 4.0
        }
      }
    },
    "right": {
      "value": 340000.0,
      "operation": {
        "operation": "<ufunc 'multiply'>",
        "left": {
          "value": 10000.0,
          "operation": {
            "operation": "<ufunc 'power'>",
            "left": {
              "value": 10.0
            },
            "right": {
              "value": 4.0
            }
          }
        },
        "right": {
          "value": 34.0
        }
      }
    }
  }
}

Storing the operation structure makes it quite easy to differentiate any
expression we have. We just need to add a `.differentiate` method that, well,
*differentiates* with respect to the `target` tensor.

In the multiplication and division methods you could notice a shortcut that
let's us skip a branch of calculations when we know it evaluates to $0$. There
is also a similar shortcut in the power rule. Another shortcut I added is
removing a branch if it does not depend on the target tensor. This is done
using the `.depends_on` method that recursively checks each branch before
calculating its derivative.

In [8]:
class Tensor(Tensor):
    def differentiate(self, target: "Tensor[T]") -> "Tensor":
        if target is self:
            return Tensor(1)

        if self.operation is None:
            return Tensor(0)

        if self.depends_on(target) is False:
            return Tensor(0)

        left = self.operation.left
        right = self.operation.right
        func = self.operation.func

        match func:
            # sum rule
            case np.add:
                return left.differentiate(target) + right.differentiate(target)

            case np.subtract:
                return left.differentiate(target) - right.differentiate(target)

            # product rule
            case np.multiply:
                if left.value == 0 or right.value == 0:
                    return Tensor(0)

                return left.differentiate(target) * right + left * right.differentiate(
                    target
                )

            # quotient rule
            case np.true_divide:
                if left.value == 0:
                    return Tensor(0)

                return (
                    left.differentiate(target) * right
                    - left * right.differentiate(target)
                ) / (right ** Tensor(2))

            # power rule
            case np.power:
                p = right - Tensor(1)

                if p.value == 0:
                    return right

                return right * left**p

    def depends_on(self, target) -> bool:
        if target is self:
            return True

        if self.operation is None:
            return False

        left = self.operation.left
        right = self.operation.right

        if right is None:
            return left.depends_on(target)

        return left.depends_on(target) or right.depends_on(target)

Let's build an expression, find its derivative and compare:

$$
\begin{aligned}
f(x, y, z) &= y^z + xy + x^zz \\
\frac {\partial f} {\partial y} &= z y^{(z - 1)} + x
\end{aligned}
$$

In [9]:
x = Tensor(10)
y = Tensor(34)
z = Tensor(2)

r = y**z + x * y + x**z * z
r.differentiate(y)

{
  "value": 78.0,
  "operation": {
    "operation": "<ufunc 'add'>",
    "left": {
      "value": 78.0,
      "operation": {
        "operation": "<ufunc 'add'>",
        "left": {
          "value": 68.0,
          "operation": {
            "operation": "<ufunc 'multiply'>",
            "left": {
              "value": 2.0
            },
            "right": {
              "value": 34.0,
              "operation": {
                "operation": "<ufunc 'power'>",
                "left": {
                  "value": 34.0
                },
                "right": {
                  "value": 1.0,
                  "operation": {
                    "operation": "<ufunc 'subtract'>",
                    "left": {
                      "value": 2.0
                    },
                    "right": {
                      "value": 1.0
                    }
                  }
                }
              }
            }
          }
        },
        "right": {
          "value":

The results matched perfectly, and you can see that the last term $x^zz$ evaluated
to $0$ immediately, as it doesn't contain the $y$ variable. Let's try something else:
we can change value of $y$, and sure enough, the derivative at that point is $0$.

In [10]:
y.value = 0
r.differentiate(y)

{
  "value": 0.0,
  "operation": {
    "operation": "<ufunc 'add'>",
    "left": {
      "value": 0.0,
      "operation": {
        "operation": "<ufunc 'add'>",
        "left": {
          "value": 0.0,
          "operation": {
            "operation": "<ufunc 'multiply'>",
            "left": {
              "value": 2.0
            },
            "right": {
              "value": 0.0,
              "operation": {
                "operation": "<ufunc 'power'>",
                "left": {
                  "value": 0.0
                },
                "right": {
                  "value": 1.0,
                  "operation": {
                    "operation": "<ufunc 'subtract'>",
                    "left": {
                      "value": 2.0
                    },
                    "right": {
                      "value": 1.0
                    }
                  }
                }
              }
            }
          }
        },
        "right": {
          "value": 0.0
