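It turns out that matrix multiplication in JAX is not computed at full float32 precision by default. On accelerators such as GPUs and TPUs, XLA may use reduced-precision arithmetic (e.g. TF32 or bfloat16 passes) for float32 inputs. Flax layers handle this by calling `lax.dot_general` directly and exposing a `precision` argument instead of using `@`. We've done the same for the LBDN code here: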
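```python
from jax import lax
from flax.typing import PrecisionLike

def dot_lax(input1, input2, precision: PrecisionLike = None):
    # Contract the last axis of input1 with axis 1 of input2 (no batch axes).
    # For a 2-D input2 this computes input1 @ input2.T.
    return lax.dot_general(
        input1,
        input2,
        (((input1.ndim - 1,), (1,)), ((), ())),
        precision=precision,
    )
```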
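The wrapper can be compared against plain `@` with a quick check. The sketch below is self-contained, so it redefines `dot_lax` without the `PrecisionLike` annotation to avoid a Flax dependency. It measures both results against a float64 NumPy reference. The shapes, the random seed and the `(out, in)` weight layout are illustrative assumptions. On CPU both errors are usually tiny, because the CPU backend computes float32 matmuls at full precision. Any gap should show up on GPUs (TF32) or TPUs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def dot_lax(input1, input2, precision=None):
    # Same contraction as the LBDN helper: input1 @ input2.T for 2-D input2.
    return lax.dot_general(
        input1,
        input2,
        (((input1.ndim - 1,), (1,)), ((), ())),
        precision=precision,
    )


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 64), dtype=jnp.float32)   # illustrative batch of inputs
W = jax.random.normal(key_w, (32, 64), dtype=jnp.float32)  # assumed (out, in) weight layout

y_default = x @ W.T                                          # backend-default precision
y_highest = dot_lax(x, W, precision=lax.Precision.HIGHEST)   # explicit full precision

# float64 reference computed on the host with NumPy.
y_ref = np.asarray(x, dtype=np.float64) @ np.asarray(W, dtype=np.float64).T

err_default = float(np.max(np.abs(np.asarray(y_default) - y_ref)))
err_highest = float(np.max(np.abs(np.asarray(y_highest) - y_ref)))
print(f"max abs error, default precision: {err_default:.2e}")
print(f"max abs error, HIGHEST precision: {err_highest:.2e}")
```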
At the moment, all of the REN code uses `@` for matrix multiplication, which runs at the backend's default precision. I've noticed that this makes a large difference to the network outputs, particularly in the new `_init_linear_sys()` routine for contracting RENs. We should switch the REN code over to a precision-aware wrapper similar to the one in the LBDN code.
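One possible migration path is sketched below. It is not the existing REN implementation. Because `dot_lax` transposes its second argument, it is not a mechanical replacement for `@`. A thin wrapper around `jnp.matmul`, which accepts a `precision` keyword, keeps exactly the semantics of `@` (including batch broadcasting), so each `a @ b` could become `mm(a, b)`. The names `mm` and `MATMUL_PRECISION` are hypothetical. The matrices are illustrative and are not the quantities built in `_init_linear_sys()`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

# Hypothetical module-level setting for the REN code (not in the existing code base).
MATMUL_PRECISION = lax.Precision.HIGHEST


def mm(a, b, precision=MATMUL_PRECISION):
    """Hypothetical drop-in replacement for `a @ b` with an explicit precision."""
    return jnp.matmul(a, b, precision=precision)


# Illustrative matrices only.
key_a, key_b = jax.random.split(jax.random.PRNGKey(1))
A = jax.random.normal(key_a, (16, 16))
B = jax.random.normal(key_b, (16, 4))

C_at = A @ B     # current style: backend-default precision
C_mm = mm(A, B)  # same shape and semantics, HIGHEST precision
```

Keeping the precision in one module-level constant would make it easy to trade accuracy for speed later without touching every call site.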
As a quick fix, add this to the top of your scripts to globally set a higher precision for matrix multiplication.
```python
import jax
jax.config.update("jax_default_matmul_precision", "highest")
```