In [4]:
from jax import tree_util

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [5]:
class CustomClass:
  """Toy object whose jitted method works because the class is a pytree.

  ``x`` is a pytree *child* (a dynamic value that ``jax.jit`` traces), and
  ``mul`` is pytree *auxiliary data* (static, stored in the treedef). Because
  ``mul`` is static, it may drive plain Python control flow inside ``calc``.
  Changing it also produces a different treedef, so a jitted call retraces
  instead of reusing a stale compiled result.

  Args:
    x: Value multiplied into ``y`` when ``mul`` is true. It is a pytree child,
      so it does not need to be hashable (e.g. a ``jnp`` array is fine).
    mul: Static flag: ``calc`` returns ``x * y`` when true, else ``y``.
  """

  def __init__(self, x: jnp.ndarray, mul: bool):
    self.x = x
    self.mul = mul

  @jax.jit
  def calc(self, y):
    """Return ``self.x * y`` if ``self.mul`` is true, otherwise ``y``.

    ``self`` reaches ``jax.jit`` as a pytree argument: ``self.x`` is traced,
    while ``self.mul`` lives in the treedef as a concrete Python value, which
    is why the ``if`` below is allowed under tracing.
    """
    if self.mul:
      return self.x * y
    return y

  def _tree_flatten(self):
    """Split the instance into ``(children, aux_data)`` for ``tree_util``.

    ``aux_data`` is stored in the treedef, which JAX compares (and, per the
    ``register_pytree_node`` contract, may hash) when looking up compiled
    functions. It is therefore a tuple rather than a dict, since dicts are
    not hashable.
    """
    children = (self.x,)  # arrays / dynamic values
    aux_data = (self.mul,)  # static values; must be hashable
    return (children, aux_data)

  @classmethod
  def _tree_unflatten(cls, aux_data, children):
    """Rebuild an instance from the output of ``_tree_flatten``."""
    (mul,) = aux_data
    return cls(*children, mul=mul)


# Register the class so jax.jit / tree_util treat instances as pytrees.
tree_util.register_pytree_node(CustomClass,
                               CustomClass._tree_flatten,
                               CustomClass._tree_unflatten)



In [6]:
# Python scalar `x` with multiplication enabled: 2 * 3.
c = CustomClass(2, True)
print(c.calc(3))
assert c.calc(3) == 6

# Mutating the static flag is picked up: `c` is re-flattened on every call,
# so the new `mul` yields a different treedef and `calc` is retraced rather
# than reusing the previously compiled version.
c.mul = False
print(c.calc(3))
assert c.calc(3) == 3

# A JAX array for `x` works even though arrays are not hashable: `x` is a
# pytree child (traced), not part of the static aux data.
c = CustomClass(jnp.array(2), True)
print(c.calc(3))
assert c.calc(3) == 6

6
3
6
