In [2]:
import jax 
import jax.numpy as jnp
import jax.tree_util as jtu

import pytreeclass.src.tree_util as ptu
import pytreeclass as pytc

import treex as tx 
import treeo as to 
import equinox as eqx 

### Benchmark Linear layer creation

In [11]:
%%timeit

@pytc.treeclass
class pytc_Linear :
   weight : jnp.ndarray
   bias   : jnp.ndarray

   def __init__(self,key,in_dim,out_dim):
       self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
       self.bias = jnp.ones((1,out_dim))

   def __call__(self,x):
       return x @ self.weight + self.bias

@pytc.treeclass
class pytc_MLP:

    def __init__(self,key,in_dim,out_dim,hidden_dim):
        keys= jax.random.split(key,3)
        self.l1 = pytc_Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
        self.l2 = pytc_Linear(key=keys[1],in_dim=hidden_dim,out_dim=hidden_dim)
        self.l3 = pytc_Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)

    def __call__(self,x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)

        return x
        
model = pytc_MLP(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))

1.67 ms ± 16.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
%%timeit

class eqx_Linear(eqx.Module) :
   weight : jnp.ndarray
   bias   : jnp.ndarray

   def __init__(self,key,in_dim,out_dim):
       self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
       self.bias = jnp.ones((1,out_dim))

   def __call__(self,x):
       return x @ self.weight + self.bias

class eqx_MLP(eqx.Module):
    l1 : eqx_Linear
    l2 : eqx_Linear
    l3 : eqx_Linear

    def __init__(self,key,in_dim,out_dim,hidden_dim):
        keys= jax.random.split(key,3)
        self.l1 = eqx_Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
        self.l2 = eqx_Linear(key=keys[1],in_dim=hidden_dim,out_dim=hidden_dim)
        self.l3 = eqx_Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)

    def __call__(self,x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)

        return x
        
model = eqx_MLP(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))

2.17 ms ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
%%timeit

class tx_Linear(tx.Module) :
   weight : jnp.ndarray = to.field(node=True)
   bias   : jnp.ndarray = to.field(node=True)

   def __init__(self,key,in_dim,out_dim):
       self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
       self.bias = jnp.ones((1,out_dim))

   def __call__(self,x):
       return x @ self.weight + self.bias

class tx_MLP(tx.Module):
    l1 : tx_Linear = to.field(node=True)
    l2 : tx_Linear = to.field(node=True)
    l3 : tx_Linear = to.field(node=True)

    def __init__(self,key,in_dim,out_dim,hidden_dim):
        keys= jax.random.split(key,3)
        self.l1 = tx_Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
        self.l2 = tx_Linear(key=keys[1],in_dim=hidden_dim,out_dim=hidden_dim)
        self.l3 = tx_Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)

    def __call__(self,x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)

        return x
        
model = tx_MLP(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))


1.78 ms ± 46.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Benchmark vanilla training

In [49]:
%%timeit

@pytc.treeclass
class pytc_Linear :
   weight : jnp.ndarray
   bias   : jnp.ndarray

   def __init__(self,key,in_dim,out_dim):
       self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
       self.bias = jnp.ones((1,out_dim))

   def __call__(self,x):
       return x @ self.weight + self.bias

@pytc.treeclass
class pytc_MLP:

    def __init__(self,key,in_dim,out_dim,hidden_dim):
        keys= jax.random.split(key,3)
        self.l1 = pytc_Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
        self.l2 = pytc_Linear(key=keys[1],in_dim=hidden_dim,out_dim=hidden_dim)
        self.l3 = pytc_Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)

    def __call__(self,x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)

        return x
        
pytc_model = pytc_MLP(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))
x = jnp.linspace(0,1,100)[:,None]
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01

@jax.value_and_grad
def loss_func(pytc_model, x, y):
    return jnp.mean((pytc_model(x) - y)**2)

@jax.jit
def update(pytc_model, x, y, lr=1e-3):
    value,grad = loss_func(pytc_model, x, y)
    return value,jtu.tree_map(lambda x,y: x - lr * y, pytc_model, grad)

for _ in range(1,20_001):
    value,pytc_model = update(pytc_model, x, y)

881 ms ± 12.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
%%timeit

class tx_Linear(tx.Module) :
   weight : jnp.ndarray = to.field(node=True)
   bias   : jnp.ndarray = to.field(node=True)

   def __init__(self,key,in_dim,out_dim):
       self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
       self.bias = jnp.ones((1,out_dim))

   def __call__(self,x):
       return x @ self.weight + self.bias

class tx_MLP(tx.Module):
    l1 : tx_Linear = to.field(node=True)
    l2 : tx_Linear = to.field(node=True)
    l3 : tx_Linear = to.field(node=True)

    def __init__(self,key,in_dim,out_dim,hidden_dim):
        keys= jax.random.split(key,3)
        self.l1 = tx_Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
        self.l2 = tx_Linear(key=keys[1],in_dim=hidden_dim,out_dim=hidden_dim)
        self.l3 = tx_Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)

    def __call__(self,x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)

        return x
        
tx_model = tx_MLP(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))
x = jnp.linspace(0,1,100)[:,None]
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01

@jax.value_and_grad
def loss_func(tx_model, x, y):
    return jnp.mean((tx_model(x) - y)**2)

@jax.jit
def update(tx_model, x, y, lr=1e-3):
    value,grad = loss_func(tx_model, x, y)
    return value,jtu.tree_map(lambda x,y: x - lr * y, tx_model, grad)

for _ in range(1,20_001):
    value,tx_model = update(tx_model, x, y)

857 ms ± 6.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [48]:
%%timeit

class eqx_Linear(eqx.Module) :
   weight : jnp.ndarray
   bias   : jnp.ndarray

   def __init__(self,key,in_dim,out_dim):
       self.weight = jax.random.normal(key,shape=(in_dim, out_dim)) * jnp.sqrt(2/in_dim)
       self.bias = jnp.ones((1,out_dim))

   def __call__(self,x):
       return x @ self.weight + self.bias

class eqx_MLP(eqx.Module):
    l1 : eqx_Linear
    l2 : eqx_Linear
    l3 : eqx_Linear

    def __init__(self,key,in_dim,out_dim,hidden_dim):
        keys= jax.random.split(key,3)
        self.l1 = eqx_Linear(key=keys[0],in_dim=in_dim,out_dim=hidden_dim)
        self.l2 = eqx_Linear(key=keys[1],in_dim=hidden_dim,out_dim=hidden_dim)
        self.l3 = eqx_Linear(key=keys[2],in_dim=hidden_dim,out_dim=out_dim)

    def __call__(self,x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        x = jax.nn.tanh(x)
        x = self.l3(x)

        return x
        
eqx_model = tx_MLP(in_dim=1,out_dim=1,hidden_dim=10,key=jax.random.PRNGKey(0))
x = jnp.linspace(0,1,100)[:,None]
y = x**3 + jax.random.uniform(jax.random.PRNGKey(0),(100,1))*0.01

@jax.value_and_grad
def loss_func(eqx_model, x, y):
    return jnp.mean((eqx_model(x) - y)**2)

@jax.jit
def update(eqx_model, x, y, lr=1e-3):
    value,grad = loss_func(eqx_model, x, y)
    return value,jtu.tree_map(lambda x,y: x - lr * y, eqx_model, grad)

for _ in range(1,20_001):
    value,eqx_model = update(eqx_model, x, y)

873 ms ± 16.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
value

DeviceArray(0.00066327, dtype=float32)

In [39]:
value

DeviceArray(0.00066327, dtype=float32)