In [5]:
from __future__ import annotations

import math
from typing import Callable, NamedTuple


class Node(NamedTuple):
    """A scalar value in the computation graph.

    Leaf nodes (inputs and constants) keep the defaults for ``parents`` and
    ``grad_fn``. Nodes produced by an operation record both, so the graph can
    be walked backwards from the output.

    Note:
        Because this is a ``NamedTuple``, nodes compare and hash by *value*:
        two distinct leaves holding the same number are equal. Code that has
        to tell nodes apart must key on ``id(node)``.
    """

    # Result of the forward computation at this node.
    val: float
    # Nodes this one was computed from, in operand order; empty for leaves.
    parents: tuple[Node, ...] = ()
    # Maps the gradient flowing into this node to one gradient per parent
    # (same order as ``parents``); ``None`` for leaves.
    grad_fn: Callable[[float], tuple[float, ...]] | None = None


def add(
    x: Node,
    y: Node,
) -> Node:
    """Return a node holding ``x.val + y.val``.

    Backward rule: the incoming gradient is passed unchanged to both operands.
    """
    total = x.val + y.val
    return Node(val=total, parents=(x, y), grad_fn=lambda g: (g, g))


# Route the `+` operator on Node instances to add(). This replaces the tuple
# concatenation that Node would otherwise inherit from NamedTuple.
Node.__add__ = add


def sub(
    x: Node,
    y: Node,
) -> Node:
    """Return a node holding ``x.val - y.val``.

    Backward rule: the incoming gradient goes to ``x`` unchanged and to ``y``
    negated.
    """
    difference = x.val - y.val
    return Node(val=difference, parents=(x, y), grad_fn=lambda g: (g, -g))


# Route the binary `-` operator on Node instances to sub().
Node.__sub__ = sub


def mul(
    x: Node,
    y: Node,
) -> Node:
    """Return a node holding ``x.val * y.val``.

    Backward rule (product rule): each operand receives the incoming gradient
    scaled by the *other* operand's value.
    """
    left, right = x.val, y.val
    return Node(
        val=left * right,
        parents=(x, y),
        grad_fn=lambda g: (g * right, g * left),
    )


# Route the `*` operator on Node instances to mul(). This replaces the tuple
# repetition that Node would otherwise inherit from NamedTuple.
Node.__mul__ = mul


def div(
    x: Node,
    y: Node,
) -> Node:
    """Return a node holding ``x.val / y.val``.

    Backward rule (quotient rule): ``x`` receives ``g / y.val`` and ``y``
    receives ``-(g * x.val) / y.val**2``.
    """
    numerator, denominator = x.val, y.val
    return Node(
        val=numerator / denominator,
        parents=(x, y),
        grad_fn=lambda g: (
            g / denominator,
            -(g * numerator) / (denominator**2),
        ),
    )


# Python 3 dispatches the `/` operator to __truediv__; __div__ is the Python 2
# name and is never consulted, so binding only __div__ leaves `a / b` raising
# TypeError for Node operands.
Node.__truediv__ = div
# Kept as an alias so any code that calls Node.__div__ explicitly still works.
Node.__div__ = div


def pow(
    x: Node,
    y: Node,
) -> Node:
    """Return a node holding ``x.val ** y.val`` (computed with ``math.pow``).

    Backward rule, for incoming gradient ``g``:
        * base ``x``:     ``g * y.val * x.val ** (y.val - 1)``
        * exponent ``y``: ``g * x.val ** y.val * log(x.val)``

    Both gradients are computed together, so the backward pass calls
    ``math.log(x.val)`` and therefore raises ``ValueError`` when
    ``x.val <= 0``, even if only the gradient w.r.t. the base is wanted.

    Args:
        x: The base.
        y: The exponent.

    Returns:
        A new node whose parents are ``(x, y)``.
    """
    out = Node(
        val=math.pow(x.val, y.val),
        parents=(x, y),
        grad_fn=lambda g: (
            # The field is `val`; the previous `y.value` raised AttributeError
            # as soon as a gradient flowed through a pow node.
            g * y.val * math.pow(x.val, y.val - 1),
            g * math.pow(x.val, y.val) * math.log(x.val),
        ),
    )
    return out


# Route the `**` operator on Node instances to pow().
Node.__pow__ = pow


def sin(
    x: Node,
) -> Node:
    """Return a node holding ``sin(x.val)``.

    Backward rule: the incoming gradient is scaled by ``cos(x.val)``.
    """
    angle = x.val
    return Node(
        val=math.sin(angle),
        parents=(x,),
        grad_fn=lambda g: (g * math.cos(angle),),
    )


def cos(x: Node) -> Node:
    """Return a node holding ``cos(x.val)``.

    Backward rule: the incoming gradient is negated and scaled by
    ``sin(x.val)``.
    """
    angle = x.val
    return Node(
        val=math.cos(angle),
        parents=(x,),
        grad_fn=lambda g: (-g * math.sin(angle),),
    )


def toposort(node: Node):
    """Iterate over the graph rooted at ``node`` in reverse topological order.

    The root comes first and every node is yielded before any of its parents,
    which is the order in which gradients must be propagated.

    Nodes are tracked by identity (``id``), not by equality. ``Node`` is a
    ``NamedTuple`` and compares/hashes by value, so a value-based visited set
    would merge distinct leaves that hold the same number and would re-hash
    an entire subgraph on every membership test.

    Args:
        node: The output node to start from.

    Returns:
        An iterator over every node reachable from ``node`` via ``parents``,
        each distinct object exactly once.
    """
    seen_ids = set()
    post_order = []

    def dfs(n):
        if id(n) in seen_ids:
            return
        seen_ids.add(id(n))
        for parent in n.parents:
            dfs(parent)
        # Appended only after all parents, so reversing gives children first.
        post_order.append(n)

    dfs(node)
    return reversed(post_order)


def contribute_grad_to_parents(
    node: Node, curr_grads: dict[int, float]
) -> dict[int, float]:
    """Propagate the gradient accumulated at ``node`` to its parents.

    ``curr_grads`` is updated in place and also returned.

    Args:
        node: The node whose gradient is pushed to its parents. Its own
            gradient must already be present under ``id(node)``.
        curr_grads: Maps ``id(node)`` to the gradient accumulated so far for
            that node (a number, not a ``Node``).

    Returns:
        The same ``curr_grads`` dict, with each parent's entry created or
        incremented by this node's contribution.

    Raises:
        KeyError: If ``id(node)`` is not in ``curr_grads``.
    """
    if not node.parents:
        # A leaf has nothing upstream and its grad_fn is None; calling it
        # would raise TypeError, so treat this as a no-op.
        return curr_grads

    g = curr_grads[id(node)]
    for parent, parent_grad in zip(node.parents, node.grad_fn(g)):
        key = id(parent)
        if key in curr_grads:
            # Parent feeds several consumers: sum the contributions.
            curr_grads[key] += parent_grad
        else:
            curr_grads[key] = parent_grad

    return curr_grads


def value_and_grad(f: Callable[..., Node]):
    """Wrap ``f`` so that it returns both its value and its gradients.

    Args:
        f: A function that takes ``Node`` arguments and returns a ``Node``
            built from them with the operations in this module.

    Returns:
        A function taking the same positional arguments as ``f`` and
        returning ``(value, grads)``: ``value`` is the ``val`` of ``f``'s
        output and ``grads`` is a tuple with one gradient per positional
        argument, in argument order. An argument that does not appear in the
        output's graph gets a gradient of ``0.0``.
    """

    def _value_and_grad(*at: Node):
        out = f(*at)  # forward pass; builds the graph as a side effect

        # Gradient accumulated so far for each node, keyed by id(node).
        # Seeded with d(out)/d(out) = 1.
        grads = {id(out): 1.0}

        for node in toposort(out):
            # Leaves have no parents to push a gradient to.
            if node.parents:
                contribute_grad_to_parents(node, grads)

        # Look gradients up by the identity of the arguments themselves.
        # Collecting leaves in traversal order instead ties the result to the
        # order in which f happens to use its inputs, drops unused inputs and
        # reports constants created inside f as if they were inputs.
        return out.val, tuple(grads.get(id(arg), 0.0) for arg in at)

    return _value_and_grad


def grad(f: callable):
    """Wrap ``f`` so that it returns only its gradients.

    The wrapper delegates to ``value_and_grad(f)`` and discards the value
    component of the result.
    """

    def _grad(*at: tuple[float, ...]):
        return value_and_grad(f)(*at)[1]

    return _grad


def build_recursive_operator(op: callable):
    """Lift a binary operator to a variadic one via a right fold.

    The returned function evaluates ``op(a0, op(a1, ... op(a_{n-2}, a_{n-1})))``.
    A single argument is returned as-is; calling it with no arguments raises
    ``IndexError``.
    """

    def fn(*args):
        # Start from the right-most operand and fold towards the left.
        acc = args[-1]
        for operand in reversed(args[:-1]):
            acc = op(operand, acc)
        return acc

    return fn


# Variadic versions of the binary operators. build_recursive_operator folds
# from the right, so e.g. recursive_sub(a, b, c) computes a - (b - c) and
# recursive_div(a, b, c) computes a / (b / c).
recursive_add: callable = build_recursive_operator(add)
recursive_sub: callable = build_recursive_operator(sub)
recursive_mul: callable = build_recursive_operator(mul)
recursive_div: callable = build_recursive_operator(div)
recursive_pow: callable = build_recursive_operator(pow)


# def test():
#     import jax
#     import jax.numpy as jnp

#     x, y, z = Node(1.0), Node(2.0), Node(3.0)

#     def f_jax(inputs):
#         x, y, z = inputs["x"], inputs["y"], inputs["z"]
#         a = x * y
#         b = x * x
#         c = z * z * jnp.sin(x)
#         d = z - y + x
#         return a + b + c + d

#     def f(x, y, z):
#         a = x * y
#         b = x * x
#         c = z * z * sin(x)
#         d = z - y + x
#         return recursive_add(a, b, c, d)

#     our_grad = grad(f)(x, y, z)
#     jax_grad = jax.grad(f_jax)({"x": 1.0, "y": 2.0, "z": 3.0})

#     print(f"Our grad: {our_grad}")
#     print(f"Jax grad: {jax_grad}")

#     # check values match
#     assert jnp.allclose(jnp.array(our_grad), jnp.array(list(jax_grad.values())))


# if __name__ == "__main__":
#     test()

In [7]:
# Two leaf nodes (no parents, no grad_fn).
x = Node(1.0)
y = Node(2.0)

# Python precedence groups this as x + (y * (x ** y)); the operators dispatch
# to the add / mul / pow functions bound on Node above. Only the forward
# values are computed here; no grad_fn is called.
x + y * x ** y

Node(val=3.0, parents=(Node(val=1.0, parents=(), grad_fn=None), Node(val=2.0, parents=(Node(val=2.0, parents=(), grad_fn=None), Node(val=1.0, parents=(Node(val=1.0, parents=(), grad_fn=None), Node(val=2.0, parents=(), grad_fn=None)), grad_fn=<function pow.<locals>.<lambda> at 0x1267931a0>)), grad_fn=<function mul.<locals>.<lambda> at 0x126793920>)), grad_fn=<function add.<locals>.<lambda> at 0x127a1a840>)