<a href="https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/benchmark_nn_training_equinox.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install pytreeclass
!pip install equinox
!pip install optax

In [5]:
import jax
import jax.numpy as jnp
import pytreeclass as pytc
import equinox as eqx
import optax 


class PyTCLinear(pytc.TreeClass):
    def __init__(self, in_dim: int, out_dim: int, key: jax.random.KeyArray, name: str):
        self.name = name
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jax.numpy.array(0.0)

    def __call__(self, x: jax.Array):
        return x @ self.weight + self.bias


class EqxLinear(eqx.Module):
    name: str
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_dim: int, out_dim: int, key: jax.random.KeyArray, name:str):
        self.name = name
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jax.numpy.array(0.0)

    def __call__(self, x: jax.Array):
        return x @ self.weight + self.bias



def sequential_linears(layers, x):
    *layers, last = layers
    for layer in layers:
        x = layer(x)
        x = jax.nn.relu(x)
    return last(x)


x = jnp.linspace(100, 1)[:, None]
y = x**2
key = jax.random.PRNGKey(0)
optim = optax.adam(1e-3)


@jax.value_and_grad
def pytc_loss_func(layers,x,y):
    layers = jax.tree_map(pytc.unfreeze, layers, is_leaf=pytc.is_frozen)
    y = sequential_linears(layers, x)
    return jnp.mean((x-y)**2)

@jax.jit
def pytc_train_step(layers, optim_state, x,y):
    loss, grads = pytc_loss_func(layers,x,y)
    updates, optim_state= optim.update(grads, optim_state)
    layers = optax.apply_updates(layers, updates)
    return layers, optim_state, loss

@eqx.filter_value_and_grad
def eqx_loss_func(layers,x,y):
    y = sequential_linears(layers, x)
    return jnp.mean((x-y)**2)

@eqx.filter_jit
def eqx_train_step(layers, optim_state, x,y):
    loss, grads = eqx_loss_func(layers,x,y)
    updates, optim_state= optim.update(grads, optim_state)
    layers = eqx.apply_updates(layers, updates)
    return layers, optim_state, loss


def pytc_train(layers, optim_state, x,y, epochs=100):
    for _ in range(epochs):
        layers, optim_state, loss = pytc_train_step(layers,optim_state, x,y)
    return layers, loss

def eqx_train(layers,optim_state, x,y, epochs=100):
    for _ in range(epochs):
        layers, optim_state, loss = eqx_train_step(layers,optim_state, x,y)
    return layers, loss


for linear_count in [10,100]:
    pytc_linears = [PyTCLinear(1,1, key=jax.random.PRNGKey(i), name=f"linear_{i}") for i in range(linear_count)]
    # mask non-differentiable parameters
    pytc_linears = jax.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, pytc_linears)
    pytc_optim_state = optim.init(pytc_linears)


    eqx_linears = [EqxLinear(1,1, key=jax.random.PRNGKey(i), name=f"linear_{i}") for i in range(linear_count)]
    eqx_optim_state = optim.init(eqx.filter(eqx_linears, eqx.is_array))


    pytc_linears , pytc_loss = pytc_train(pytc_linears, pytc_optim_state, x,y, epochs=1000)
    eqx_linears, eqx_loss = eqx_train(eqx_linears, eqx_optim_state, x,y, epochs=1000)

    assert pytc_loss == eqx_loss

    time_pytc = %timeit -o pytc_train(pytc_linears, pytc_optim_state, x,y, epochs=100)
    time_eqx = %timeit -o eqx_train(eqx_linears, eqx_optim_state, x,y, epochs=100)
    print(f"Eqx/PyTc: {time_eqx.average/time_pytc.average} for {linear_count} layers")

34.4 ms ± 867 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
230 ms ± 93.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Eqx/PyTc: 6.671167451529536 for 10 layers
659 ms ± 19.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.79 s ± 272 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Eqx/PyTc: 2.714461166827432 for 100 layers
