# 🧮 Building a mini optimizer library

In this notebook, a stateful exponential moving average (`EMA`) and an `Adam` optimizer are built using `PyTreeClass`.

In [1]:
import jax
import jax.numpy as jnp
import pytreeclass as pytc


class Tree(pytc.TreeClass):
    """Toy parameter container holding a single float field ``a``."""

    a: float


def loss_func(tree: Tree, x):
    """Return ``tree.a`` raised to the power ``x``.

    Args:
        tree: Object exposing the base value as its ``a`` attribute.
        x: The exponent.
    """
    base = tree.a
    return base**x


def _moment_update(grads, moments, *, beta: float, order: int):
    """Blend new gradients into running moment estimates, leaf by leaf.

    Each leaf is updated as ``beta * moment + (1 - beta) * grad**order``.

    Args:
        grads: Pytree of gradients.
        moments: Pytree of current moment estimates with the same structure
            as ``grads``.
        beta: Weight kept on the previous moment.
        order: Power applied to each gradient leaf before blending
            (the callers in this file pass 1 and 2).

    Returns:
        A pytree with the same structure as ``grads`` holding the updated
        moments.
    """

    def _moment_step(grad, moment):
        return beta * moment + (1 - beta) * (grad**order)

    # ``jax.tree_util.tree_map`` is used instead of the top-level
    # ``jax.tree_map`` alias, which is deprecated and missing from recent
    # JAX releases.
    return jax.tree_util.tree_map(_moment_step, grads, moments)


def _debias_update(moments, *, beta: float, count: int):
    """Apply the bias correction ``moment / (1 - beta**count)`` to every leaf.

    Args:
        moments: Pytree of moment estimates.
        beta: Decay rate that was used to accumulate ``moments``.
        count: Number of updates accumulated so far. The callers in this
            file increment it before calling, so it is at least 1 there;
            ``count == 0`` would make the denominator zero.

    Returns:
        A pytree with the same structure as ``moments`` holding the
        corrected values.
    """

    def _debias_step(moment):
        return moment / (1 - beta**count)

    # ``jax.tree_util.tree_map`` is used instead of the top-level
    # ``jax.tree_map`` alias, which is deprecated and missing from recent
    # JAX releases.
    return jax.tree_util.tree_map(_debias_step, moments)


def ema(decay_rate: float, debias: bool = True):
    """Exponential moving average

    Args:
        decay_rate: The decay rate of the moving average.
        debias: Whether to debias the moving average.

    Returns:
        The ``EMA`` class. Instantiate it with a pytree to create a
        zero-initialised state, then call ``update(grads)`` on the instance.
    """

    class EMA(pytc.TreeClass):
        def __init__(self, tree):
            # running average starts at zero with the same structure as `tree`
            self.state = jax.tree_util.tree_map(jnp.zeros_like, tree)
            self.count = 0

        def _update(self, grads):
            self.count += 1
            self.state = _moment_update(
                grads,
                self.state,
                beta=decay_rate,
                order=1,
            )
            if debias:
                # the stored state stays uncorrected; only the returned
                # value is debiased
                return _debias_update(
                    self.state,
                    beta=decay_rate,
                    count=self.count,
                )
            return self.state

        def update(self, grads):
            # since self._update mutates the state, we need to use self.at
            # to return the method value and the mutated state
            return self.at["_update"](grads)

    return EMA


def adam(*, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8):
    """Adam optimizer

    Args:
        beta1: The decay rate of the first moment.
        beta2: The decay rate of the second moment.
        eps: A small value to prevent division by zero.

    Returns:
        The ``Adam`` class. Instantiate it with a pytree to create
        zero-initialised moment estimates, then call ``update(grads)`` on the
        instance. No learning rate or sign is applied to the result.
    """

    class Adam(pytc.TreeClass):
        def __init__(self, tree):
            # first (mu) and second (nu) moment estimates start at zero with
            # the same structure as `tree`
            self.mu = jax.tree_util.tree_map(jnp.zeros_like, tree)
            self.nu = jax.tree_util.tree_map(jnp.zeros_like, tree)
            self.count = 0

        def _update(self, grads):
            self.count += 1
            self.mu = _moment_update(
                grads,
                self.mu,
                beta=beta1,
                order=1,
            )
            self.nu = _moment_update(
                grads,
                self.nu,
                beta=beta2,
                order=2,
            )

            # the bias-corrected copies are used only for the returned value;
            # the stored moments stay uncorrected
            mu_hat = _debias_update(
                self.mu,
                beta=beta1,
                count=self.count,
            )
            nu_hat = _debias_update(
                self.nu,
                beta=beta2,
                count=self.count,
            )

            # ``jax.tree_util.tree_map`` is used instead of the top-level
            # ``jax.tree_map`` alias, which is deprecated and missing from
            # recent JAX releases.
            return jax.tree_util.tree_map(
                lambda mu, nu: mu / (jnp.sqrt(nu) + eps), mu_hat, nu_hat
            )

        def update(self, grads):
            # since self._update mutates the state, we need to use self.at
            # to return the method value and the mutated state
            return self.at["_update"](grads)

    return Adam


tree = Tree(a=2.0)
grads = jax.grad(loss_func)(tree, 2.0)
optim = adam(beta1=0.5, beta2=0.1)
optim_state = optim(tree)

# `update` hands back the update direction together with the new optimizer state
updates, optim_state = optim_state.update(grads)
print(optim_state)

# updated tree
# NOTE(review): the raw direction is added as-is (no learning rate, no sign
# flip), so this step moves `a` from 2.0 to 3.0; a descent step on the loss
# would subtract a scaled update instead -- confirm this is intended.
# ``jax.tree_util.tree_map`` replaces the deprecated top-level ``jax.tree_map``.
tree = jax.tree_util.tree_map(lambda x, y: x + y, tree, updates)
print(tree)

Adam(mu=Tree(a=2.0), nu=Tree(a=14.4), count=1)
Tree(a=3.0)
