
Imprecise matrix multiplication can cause networks to violate robustness constraints #15

@nic-barbara

Description

It turns out that matrix multiplication in JAX is not performed at full precision by default: on some backends (e.g. TPUs) the default setting trades accuracy for speed. Flax gets around this by writing its own wrapper around the underlying matrix multiplication function (lax.dot_general). We've done the same for the LBDN code here:

from jax import lax
from jax.lax import PrecisionLike

def dot_lax(input1, input2, precision: PrecisionLike = None):
    # Contracts the last axis of input1 with the second axis of input2
    # (i.e. input1 @ input2.T for 2-D input2), with an explicit precision.
    return lax.dot_general(
        input1,
        input2,
        (((input1.ndim - 1,), (1,)), ((), ())),
        precision=precision,
    )
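
For reference, a minimal usage sketch (x and W are placeholder arrays, not the actual LBDN parameters; note that with this contraction the weight is stored as (out_features, in_features)):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 8))    # batch of 4 inputs with 8 features
W = jnp.ones((16, 8))   # weight matrix, shape (out_features, in_features)

# Same result as x @ W.T, but computed at the requested precision.
y = dot_lax(x, W, precision=lax.Precision.HIGHEST)
assert y.shape == (4, 16)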

At the moment, all of the REN code uses @ for matrix multiplication. I've noticed, particularly in the new _init_linear_sys() routine for contracting RENs, that this makes a huge difference to the network outputs. We should switch over to something similar to the LBDN code.
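
To see the effect directly, here's a minimal sketch comparing the default and highest matmul precisions on random matrices (the names are placeholders, not the actual REN parameters; the gap is typically zero on CPU but can be large on TPU/GPU backends that use reduced-precision matmuls by default):

import jax
import jax.numpy as jnp
from jax import lax

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(k1, (512, 512))
B = jax.random.normal(k2, (512, 512))

# Default precision: the backend may round inputs (e.g. to bfloat16 on TPU).
y_default = A @ B

# The same product, forced to highest precision.
y_highest = lax.dot_general(A, B, (((1,), (0,)), ((), ())),
                            precision=lax.Precision.HIGHEST)

# Any nonzero gap here is the error introduced by the default setting.
print(jnp.max(jnp.abs(y_default - y_highest)))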

As a quick fix, add this to the top of your scripts to globally set a higher precision for matrix multiplication.

import jax
jax.config.update("jax_default_matmul_precision", "highest")
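
If you'd rather not change global state, jax.default_matmul_precision can also be used as a context manager to scope the setting; a sketch, where model, params and x stand in for your own network and data:

import jax

with jax.default_matmul_precision("highest"):
    y = model.apply(params, x)  # matmuls inside this block run at highest precision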

Metadata

Labels: bug (Something isn't working)
