<a href="https://colab.research.google.com/github/ASEM000/PyTreeClass/blob/main/PyTreeClass_benchmarks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pytreeclass 
!pip install equinox
!pip install treex
!pip install treeo

In [2]:
import jax 
import jax.numpy as jnp
import jax.tree_util as jtu

import pytreeclass as pytc

import treex as tx 
import treeo as to 

import equinox as eqx 

import numpy.testing as npt
import matplotlib.pyplot as plt 

def tree_copy(tree):
    """Rebuild ``tree`` by flattening it to leaves and unflattening it again.

    This flatten/unflatten round trip is the operation timed by the
    benchmark further down.
    """
    leaves, treedef = jtu.tree_flatten(tree)
    return jtu.tree_unflatten(treedef, leaves)


## Model definition

### PyTreeClass model

In [3]:
@pytc.treeclass
class pytc_Linear:
    """Affine layer computing ``x @ weight + bias`` (PyTreeClass version)."""

    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        """Create the layer parameters.

        Args:
            key: PRNG key passed to ``jax.random.normal`` for the weight.
            in_dim: number of rows of the weight matrix.
            out_dim: number of columns of the weight matrix and of the bias.
        """
        # Normal samples scaled by sqrt(2 / in_dim).
        scale = jnp.sqrt(2 / in_dim)
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * scale
        # Bias starts at one, shaped (1, out_dim).
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        """Return ``x @ weight + bias``."""
        projected = x @ self.weight
        return projected + self.bias

@pytc.treeclass
class pytc_MLP:
    """Nine-layer MLP built from ``pytc_Linear`` with tanh between layers."""

    def __init__(self, key, in_dim, out_dim, hidden_dim):
        """Build the nine linear layers, each with its own PRNG sub-key.

        Args:
            key: PRNG key, split into nine sub-keys (one per layer).
            in_dim: input width of the first layer.
            out_dim: output width of the last layer.
            hidden_dim: width of every intermediate layer.
        """
        layer_keys = jax.random.split(key, 9)

        # Input layer: in_dim -> hidden_dim.
        self.l1 = pytc_Linear(layer_keys[0], in_dim=in_dim, out_dim=hidden_dim)
        # Seven hidden layers: hidden_dim -> hidden_dim.
        self.l2 = pytc_Linear(layer_keys[1], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l3 = pytc_Linear(layer_keys[2], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l4 = pytc_Linear(layer_keys[3], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l5 = pytc_Linear(layer_keys[4], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l6 = pytc_Linear(layer_keys[5], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l7 = pytc_Linear(layer_keys[6], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l8 = pytc_Linear(layer_keys[7], in_dim=hidden_dim, out_dim=hidden_dim)
        # Output layer: hidden_dim -> out_dim.
        self.l9 = pytc_Linear(layer_keys[8], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        """Apply l1..l8, each followed by tanh, then l9 with no activation."""
        activated_layers = (
            self.l1, self.l2, self.l3, self.l4,
            self.l5, self.l6, self.l7, self.l8,
        )
        for layer in activated_layers:
            x = jax.nn.tanh(layer(x))
        return self.l9(x)

### Treex model

In [4]:

class tx_Linear(tx.Module):
    """Affine layer computing ``x @ weight + bias`` (Treex version)."""

    # Both parameters are declared as node fields.
    weight: jnp.ndarray = to.field(node=True)
    bias: jnp.ndarray = to.field(node=True)

    def __init__(self, key, in_dim, out_dim):
        """Create the layer parameters.

        Args:
            key: PRNG key passed to ``jax.random.normal`` for the weight.
            in_dim: number of rows of the weight matrix.
            out_dim: number of columns of the weight matrix and of the bias.
        """
        # Normal samples scaled by sqrt(2 / in_dim).
        scale = jnp.sqrt(2 / in_dim)
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * scale
        # Bias starts at one, shaped (1, out_dim).
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        """Return ``x @ weight + bias``."""
        projected = x @ self.weight
        return projected + self.bias

class tx_MLP(tx.Module):
    """Nine-layer MLP built from ``tx_Linear`` with tanh between layers."""

    # Every sub-layer is declared as a node field.
    l1: tx_Linear = to.field(node=True)
    l2: tx_Linear = to.field(node=True)
    l3: tx_Linear = to.field(node=True)
    l4: tx_Linear = to.field(node=True)
    l5: tx_Linear = to.field(node=True)
    l6: tx_Linear = to.field(node=True)
    l7: tx_Linear = to.field(node=True)
    l8: tx_Linear = to.field(node=True)
    l9: tx_Linear = to.field(node=True)

    def __init__(self, key, in_dim, out_dim, hidden_dim):
        """Build the nine linear layers, each with its own PRNG sub-key.

        Args:
            key: PRNG key, split into nine sub-keys (one per layer).
            in_dim: input width of the first layer.
            out_dim: output width of the last layer.
            hidden_dim: width of every intermediate layer.
        """
        layer_keys = jax.random.split(key, 9)

        # Input layer: in_dim -> hidden_dim.
        self.l1 = tx_Linear(layer_keys[0], in_dim=in_dim, out_dim=hidden_dim)
        # Seven hidden layers: hidden_dim -> hidden_dim.
        self.l2 = tx_Linear(layer_keys[1], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l3 = tx_Linear(layer_keys[2], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l4 = tx_Linear(layer_keys[3], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l5 = tx_Linear(layer_keys[4], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l6 = tx_Linear(layer_keys[5], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l7 = tx_Linear(layer_keys[6], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l8 = tx_Linear(layer_keys[7], in_dim=hidden_dim, out_dim=hidden_dim)
        # Output layer: hidden_dim -> out_dim.
        self.l9 = tx_Linear(layer_keys[8], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        """Apply l1..l8, each followed by tanh, then l9 with no activation."""
        activated_layers = (
            self.l1, self.l2, self.l3, self.l4,
            self.l5, self.l6, self.l7, self.l8,
        )
        for layer in activated_layers:
            x = jax.nn.tanh(layer(x))
        return self.l9(x)


### Equinox model

In [5]:
class eqx_Linear(eqx.Module):
    """Affine layer computing ``x @ weight + bias`` (Equinox version)."""

    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        """Create the layer parameters.

        Args:
            key: PRNG key passed to ``jax.random.normal`` for the weight.
            in_dim: number of rows of the weight matrix.
            out_dim: number of columns of the weight matrix and of the bias.
        """
        # Normal samples scaled by sqrt(2 / in_dim).
        scale = jnp.sqrt(2 / in_dim)
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * scale
        # Bias starts at one, shaped (1, out_dim).
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        """Return ``x @ weight + bias``."""
        projected = x @ self.weight
        return projected + self.bias

class eqx_MLP(eqx.Module):
    """Nine-layer MLP built from ``eqx_Linear`` with tanh between layers."""

    l1: eqx_Linear
    l2: eqx_Linear
    l3: eqx_Linear
    l4: eqx_Linear
    l5: eqx_Linear
    l6: eqx_Linear
    l7: eqx_Linear
    l8: eqx_Linear
    l9: eqx_Linear

    def __init__(self, key, in_dim, out_dim, hidden_dim):
        """Build the nine linear layers, each with its own PRNG sub-key.

        Args:
            key: PRNG key, split into nine sub-keys (one per layer).
            in_dim: input width of the first layer.
            out_dim: output width of the last layer.
            hidden_dim: width of every intermediate layer.
        """
        layer_keys = jax.random.split(key, 9)

        # Input layer: in_dim -> hidden_dim.
        self.l1 = eqx_Linear(layer_keys[0], in_dim=in_dim, out_dim=hidden_dim)
        # Seven hidden layers: hidden_dim -> hidden_dim.
        self.l2 = eqx_Linear(layer_keys[1], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l3 = eqx_Linear(layer_keys[2], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l4 = eqx_Linear(layer_keys[3], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l5 = eqx_Linear(layer_keys[4], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l6 = eqx_Linear(layer_keys[5], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l7 = eqx_Linear(layer_keys[6], in_dim=hidden_dim, out_dim=hidden_dim)
        self.l8 = eqx_Linear(layer_keys[7], in_dim=hidden_dim, out_dim=hidden_dim)
        # Output layer: hidden_dim -> out_dim.
        self.l9 = eqx_Linear(layer_keys[8], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        """Apply l1..l8, each followed by tanh, then l9 with no activation."""
        activated_layers = (
            self.l1, self.l2, self.l3, self.l4,
            self.l5, self.l6, self.l7, self.l8,
        )
        for layer in activated_layers:
            x = jax.nn.tanh(layer(x))
        return self.l9(x)

## Benchmarking flatten/unflatten

In [None]:
# Shared model input: 100 evenly spaced points on [0, 1] as a (100, 1) column.
x = jnp.linspace(0, 1, 100).reshape(-1, 1)

# Per-library timing results, keyed by hidden dimension.
t_pytc = {}
t_tx = {}
t_eqx = {}


def diff(func):
    """Return a function giving the gradient of ``sum(func(*args))``.

    The output of ``func`` is reduced with ``jnp.sum`` so that ``jax.grad``
    receives a scalar; with ``jax.grad``'s default settings the gradient is
    taken with respect to the first positional argument.
    """
    def summed_output(*args):
        return jnp.sum(func(*args))

    return jax.grad(summed_output)

# Hidden-layer widths to benchmark.
hidden_dims = [10, 100, 1_000, 10_000]

for hidden_dim in hidden_dims:
    # Build the same architecture in each library from the same PRNG key.
    pytc_model = pytc_MLP(in_dim=1,out_dim=1,hidden_dim=hidden_dim,key=jax.random.PRNGKey(0))
    tx_model = tx_MLP(in_dim=1,out_dim=1,hidden_dim=hidden_dim,key=jax.random.PRNGKey(0))
    eqx_model = eqx_MLP(in_dim=1,out_dim=1,hidden_dim=hidden_dim,key=jax.random.PRNGKey(0))

    # Sanity check: the three models must produce matching outputs on x ...
    npt.assert_allclose(pytc_model(x),tx_model(x))
    npt.assert_allclose(pytc_model(x),eqx_model(x))

    # ... and matching gradients of the summed output with respect to x.
    npt.assert_allclose(diff(pytc_model)(x),diff(tx_model)(x))
    npt.assert_allclose(diff(pytc_model)(x),diff(eqx_model)(x))

    # Time the flatten/unflatten round trip for each library; `-o` makes
    # %timeit return its result object, which is stored per hidden_dim.
    t_pytc[hidden_dim] = %timeit -o tree_copy(pytc_model)
    t_tx[hidden_dim] = %timeit -o tree_copy(tx_model)
    t_eqx[hidden_dim] = %timeit -o tree_copy(eqx_model)





In [None]:
# Plot mean flatten/unflatten time (with standard-deviation error bars) against
# hidden dimension for each library, on log-log axes, and save it as an SVG.
plt.figure(figsize=(10,5))

# %timeit's result object reports `average` and `stdev` in seconds, while the
# y-axis is labelled in microseconds, so values are scaled before plotting.
# Each series also uses its own dict's keys for the x-axis instead of
# borrowing the PyTreeClass keys.
for timings, label, color in (
    (t_pytc, "PyTreeClass", "r"),
    (t_tx, "Treex", "b"),
    (t_eqx, "Equinox", "g"),
):
    dims = jnp.array(list(timings.keys()))
    mean_us = jnp.array([t.average * 1e6 for t in timings.values()])
    stdev_us = jnp.array([t.stdev * 1e6 for t in timings.values()])
    plt.errorbar(dims, mean_us, stdev_us, label=label, c=color, fmt="o")


plt.title("Flatten/Unflatten on Colab GPU")
plt.xlabel("Hidden dimension")
plt.ylabel("Time in us")
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.savefig("tree_copy_colab_gpu.svg")