<a href="https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/assets/benchmark_nn_training_flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pytreeclass
!pip install flax
!pip install optax

In [None]:
import jax
import jax.numpy as jnp
import pytreeclass as pytc
import flax
import optax 


class PyTCLinear(pytc.TreeClass):
    """Dense layer ``x @ weight + bias`` written as a PyTreeClass ``TreeClass``.

    Args:
        in_dim: Number of rows of ``weight`` (input feature size).
        out_dim: Number of columns of ``weight`` (output feature size).
        key: PRNG key used to sample the initial weights.
        name: Label stored on the layer; the forward pass does not read it.
    """

    # The key is annotated as ``jax.Array``: ``jax.random.KeyArray`` is no
    # longer available in recent JAX releases, and because annotations are
    # evaluated when the ``def`` runs, the old spelling breaks the class there.
    def __init__(self, in_dim: int, out_dim: int, key: jax.Array, name: str):
        self.name = name
        # Initial weights are drawn from a standard normal distribution.
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        # Scalar bias, broadcast over the output when added in ``__call__``.
        self.bias = jax.numpy.array(0.0)

    def __call__(self, x: jax.Array):
        """Return ``x @ weight + bias``."""
        return x @ self.weight + self.bias


def flax_linear(in_dim: int, out_dim: int, key: jax.Array, name: str):
    """Build a dense layer ``x @ weight + bias`` as a ``flax.struct.PyTreeNode``.

    Args:
        in_dim: Number of rows of ``weight`` (input feature size).
        out_dim: Number of columns of ``weight`` (output feature size).
        key: PRNG key used to sample the initial weights. Annotated as
            ``jax.Array`` because ``jax.random.KeyArray`` is no longer
            available in recent JAX releases and annotations are evaluated
            when the ``def`` runs.
        name: Label stored on the layer; the forward pass does not read it.

    Returns:
        A ``FlaxLinear`` instance with normally distributed weights and a
        scalar zero bias.
    """

    # NOTE(review): the class is defined inside the factory, so every call
    # creates a fresh ``FlaxLinear`` class object.
    class FlaxLinear(flax.struct.PyTreeNode):
        # ``pytree_node=False`` keeps the label out of the pytree leaves.
        name: str = flax.struct.field(pytree_node=False)
        weight: jax.Array
        bias: jax.Array

        def __call__(self, x: jax.Array):
            """Return ``x @ weight + bias``."""
            return x @ self.weight + self.bias

    return FlaxLinear(name, jax.random.normal(key, (in_dim, out_dim)), jax.numpy.array(0.0))


def sequential_linears(layers, x):
    """Feed ``x`` through ``layers`` in order.

    Every layer except the final one is followed by a ReLU; the final layer's
    output is returned as-is. ``layers`` must contain at least one callable.
    """
    *hidden, output_layer = layers
    activations = x
    for hidden_layer in hidden:
        activations = jax.nn.relu(hidden_layer(activations))
    return output_layer(activations)


# Training inputs; ``[:, None]`` adds a trailing axis so ``x`` is a column.
# NOTE(review): ``jnp.linspace(100, 1)`` means start=100, stop=1 with the
# default sample count -- confirm that this descending range is intended
# (rather than, e.g., 100 samples between 0 and 1).
x = jnp.linspace(100, 1)[:, None]
# Regression targets: the element-wise square of the inputs.
y = x**2
# Module-level PRNG key. The layers in the benchmark loop build their own
# keys from the layer index, so this key is not referenced in the visible cells.
key = jax.random.PRNGKey(0)
# Adam optimizer (learning rate 1e-3) used by both jitted train steps.
optim = optax.adam(1e-3)


@jax.value_and_grad
def pytc_loss_func(layers, x, y):
    """Mean-squared error of the PyTreeClass network on ``x`` against ``y``.

    Because of the ``jax.value_and_grad`` decorator, calling this returns
    ``(loss, grads)`` where ``grads`` is taken with respect to ``layers``.
    """
    # Unwrap the leaves that were frozen before training so the layers can be
    # called. ``jax.tree_util.tree_map`` replaces the deprecated top-level
    # ``jax.tree_map`` alias.
    layers = jax.tree_util.tree_map(pytc.unfreeze, layers, is_leaf=pytc.is_frozen)
    # The prediction gets its own name: rebinding ``y`` here and comparing it
    # with ``x`` would silently ignore the targets.
    pred = sequential_linears(layers, x)
    return jnp.mean((pred - y) ** 2)


@jax.jit
def pytc_train_step(layers, optim_state, x, y):
    """Run one jitted optimizer step for the PyTreeClass layers.

    Returns:
        ``(layers, optim_state, loss)`` after applying the update computed
        from the gradients of ``pytc_loss_func``.
    """
    loss, grads = pytc_loss_func(layers, x, y)
    updates, optim_state = optim.update(grads, optim_state)
    layers = optax.apply_updates(layers, updates)
    return layers, optim_state, loss


@jax.value_and_grad
def flax_loss_func(layers, x, y):
    """Mean-squared error of the Flax network on ``x`` against ``y``.

    Because of the ``jax.value_and_grad`` decorator, calling this returns
    ``(loss, grads)`` where ``grads`` is taken with respect to ``layers``.
    """
    # Same formula as ``pytc_loss_func`` so the two losses stay comparable;
    # the prediction is kept separate from the targets ``y``.
    pred = sequential_linears(layers, x)
    return jnp.mean((pred - y) ** 2)


@jax.jit
def flax_train_step(layers, optim_state, x, y):
    """Run one jitted optimizer step for the Flax layers.

    Returns:
        ``(layers, optim_state, loss)`` after applying the update computed
        from the gradients of ``flax_loss_func``.
    """
    loss, grads = flax_loss_func(layers, x, y)
    updates, optim_state = optim.update(grads, optim_state)
    layers = optax.apply_updates(layers, updates)
    return layers, optim_state, loss


def pytc_train(layers, optim_state, x, y, epochs=100):
    """Call ``pytc_train_step`` ``epochs`` times, threading the state through.

    Args:
        layers: Layers passed to (and replaced by the result of) each step.
        optim_state: Optimizer state passed to each step.
        x: Inputs forwarded to every step.
        y: Targets forwarded to every step.
        epochs: Number of steps to run.

    Returns:
        ``(layers, loss)``: the layers after the last step and that step's
        loss. ``loss`` is ``None`` when no step ran (``epochs <= 0``).
    """
    # Pre-bind ``loss`` so a zero-epoch call returns cleanly instead of
    # raising UnboundLocalError at the ``return``.
    loss = None
    for _ in range(epochs):
        layers, optim_state, loss = pytc_train_step(layers, optim_state, x, y)
    return layers, loss

def flax_train(layers, optim_state, x, y, epochs=100):
    """Call ``flax_train_step`` ``epochs`` times, threading the state through.

    Args:
        layers: Layers passed to (and replaced by the result of) each step.
        optim_state: Optimizer state passed to each step.
        x: Inputs forwarded to every step.
        y: Targets forwarded to every step.
        epochs: Number of steps to run.

    Returns:
        ``(layers, loss)``: the layers after the last step and that step's
        loss. ``loss`` is ``None`` when no step ran (``epochs <= 0``).
    """
    # Pre-bind ``loss`` so a zero-epoch call returns cleanly instead of
    # raising UnboundLocalError at the ``return``.
    loss = None
    for _ in range(epochs):
        layers, optim_state, loss = flax_train_step(layers, optim_state, x, y)
    return layers, loss


for linear_count in [10, 100]:
    pytc_linears = [PyTCLinear(1,1, key=jax.random.PRNGKey(i), name=f"linear_{i}") for i in range(linear_count)]
    # mask non-differentiable parameters
    pytc_linears = jax.tree_map(lambda x: pytc.freeze(x) if pytc.is_nondiff(x) else x, pytc_linears)
    pytc_optim_state = optim.init(pytc_linears)


    flax_linears = [flax_linear(1,1, key=jax.random.PRNGKey(i), name=f"linear_{i}") for i in range(linear_count)]
    flax_optim_state = optim.init(flax_linears)


    pytc_linears , pytc_loss = pytc_train(pytc_linears, pytc_optim_state, x,y, epochs=1000)
    flax_linears, flax_loss = flax_train(flax_linears, flax_optim_state, x,y, epochs=1000)

    assert pytc_loss == flax_loss

    time_pytc = %timeit -o pytc_train(pytc_linears, pytc_optim_state, x,y, epochs=100)
    time_flax = %timeit -o flax_train(flax_linears, flax_optim_state, x,y, epochs=100)
    print(f"Flax/PyTc: {time_flax.average/time_pytc.average} for {linear_count} layers")

21.6 ms ± 548 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
30.8 ms ± 355 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Flax/PyTc: 1.4270735299354067 for 10 layers
474 ms ± 80.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
528 ms ± 41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Flax/PyTc: 1.113071349681521 for 100 layers
