In [1]:
from jaxtyping import Float, Array
from typing import Any
import jax.tree_util as jtu
import jax.numpy as jnp

# Mytree code:

In [2]:
from mytree import Mytree, param_field, Softplus

class Mytree_SubFoo(Mytree):
    a: Float[Array, "..."] = param_field(bijector=Softplus)
    b: Float[Array, "..."] = param_field(bijector=Softplus)

    def __init__(self, a, b):
        self.a = a
        self.b = b

class Mytree_Foo(Mytree):
    b: list[Mytree_SubFoo]
    a: Float[Array, "..."] = param_field(bijector=Softplus)

    def __init__(self, a, b):
        self.a = a
        self.b = b

  jax.tree_util.register_keypaths(


# Pytree

In [3]:
from simple_pytree import Pytree, static_field

class Param(Pytree):
    value: Any
    bijector: Any = static_field()
    trainable: Any = static_field()

    def __init__(self, value, bijector=Softplus, trainable=True):
        self.value = value
        self.bijector = bijector
        self.trainable = trainable

def _is_param(x):
    return isinstance(x, Param)

def _resolve_bijector_forward(x):
    if _is_param(x):

        return x.replace(value = x.bijector.forward(x.value))
    else:
        return x
    
def _resolve_bijector_inverse(x):
    if _is_param(x):
        return x.replace(value = x.bijector.inverse(x.value))
    else:
        return x

class Module(Pytree):    
    def constrain(self):
        return jtu.tree_map(_resolve_bijector_forward, self, is_leaf=_is_param)
    
    def unconstrain(self):
        return jtu.tree_map(_resolve_bijector_inverse, self, is_leaf=_is_param)

 

class Pytree_SubFoo(Module):
    a: Param 
    b: Param 

    def __init__(self, a: Float[Array, "..."] , b: Float[Array, "..."]):
        self.a = Param(a)
        self.b = Param(b)


class Pytree_Foo(Module):
    b: list[Pytree_SubFoo]
    a: Param

    def __init__(self, b: list[Pytree_SubFoo], a:  Float[Array, "..."]):
        self.b = b
        self.a = Param(a)

# Performance comparison:

Run on a M1 Pro CPU.

- **Initialisation**: is faster for mytree.
- **Transformations**: is faster for mytree.
- **Replacing attributes**: is faster for mytree implimentation.

Completeing further benchmarks is on the todo list.

In [4]:
pytree_classes = [Pytree_Foo, Pytree_SubFoo]
mytree_classes = [Mytree_Foo, Mytree_SubFoo]


def init_tree(tree, subtree):
    return tree([subtree(jnp.array([1.0] * 10000), jnp.array([1.0] * 10000))]*10000, jnp.array([3.0] * 10000))

for name, implimentation in zip(["mytree", "pytree"], [mytree_classes, pytree_classes]):

    print(f"\n {name}:")
    %timeit init_tree(*implimentation)
    foo = init_tree(*implimentation)
    %timeit foo.constrain()
    %timeit foo.replace(a=123)


 mytree:
52.1 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.02 s ± 35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.52 µs ± 15.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

 pytree:
58 ms ± 2.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.08 s ± 20.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.05 µs ± 76.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
