In [38]:
def slax_full():
    """Build the slax full-precision benchmark hooks.

    Third-party imports are done inside the function so the benchmark can be
    registered without importing slax/flax at module load time.

    Returns:
        tuple: ``(prepare_fn, forward_fn, backward_fn, benchmark_title)``.
            ``prepare_fn`` builds the model and input and returns a dict that
            ``forward_fn`` and ``backward_fn`` consume; ``benchmark_title`` is
            a human-readable label for the benchmark.
    """
    import jax
    import jax.numpy as jnp
    import slax as sl
    import flax.linen as nn

    key = jax.random.PRNGKey(0)

    benchmark_title = "slax full-precision v0.0.1"

    def prepare_fn(batch_size, n_steps, n_neurons, n_layers, device):
        """Build the SNN, its variables, a random input and jitted eval/grad fns.

        Args:
            batch_size: number of samples per batch.
            n_steps: number of time steps; the input is laid out as
                ``(n_steps, batch_size, n_neurons)``.
            n_neurons: width of every Dense layer (and of the input features).
            n_layers: number of Dense -> slax LIF blocks stacked in the model.
            device: accepted for interface compatibility with other
                benchmarks; not used here (JAX places arrays on its default
                device).

        Returns:
            dict with keys ``model`` (``(net_eval, variables)``), ``grad_fn``
            (jitted gradient of ``net_eval`` w.r.t. its first argument),
            ``input`` and ``n_neurons``.
        """
        class Model(nn.Module):
            @nn.compact
            def __call__(self, x):
                # Looping keeps flax's auto-generated submodule names
                # (Dense_0, RNN_0, Dense_1, ...) identical to the unrolled
                # three-block version when n_layers == 3.
                for _ in range(n_layers):
                    x = nn.Dense(n_neurons)(x)
                    x = sl.RNN(sl.LIF(2., spike_fn=sl.atan()))(x)
                return x

        # Use independent keys for the input draw and parameter init; reusing
        # one key for both would correlate the two random streams.
        input_key, init_key = jax.random.split(key)
        input_static = jax.random.normal(
            input_key, shape=(n_steps, batch_size, n_neurons), dtype=jnp.float32
        )

        snn = Model()
        # Only init consumes randomness; apply below is called without an RNG.
        variables = snn.init(init_key, input_static)

        @jax.jit
        def net_eval(weights, events):
            """Run the network on ``events`` and reduce its output to a scalar sum."""
            # With mutable=['carry'], apply returns (output, updated_vars);
            # only the output is needed for the benchmark loss.
            traces, _ = snn.apply(weights, events, mutable=['carry'])
            return traces.sum()

        # Build the gradient transform once so repeated backward_fn calls reuse
        # the compiled function instead of re-tracing jax.grad every time.
        grad_fn = jax.jit(jax.grad(net_eval))

        model = (net_eval, variables)

        return dict(model=model, grad_fn=grad_fn, input=input_static, n_neurons=n_neurons)

    def forward_fn(bench_dict):
        """Run one forward pass and store the scalar result under ``"output"``.

        Args:
            bench_dict: dict produced by ``prepare_fn``.

        Returns:
            The same ``bench_dict``, with ``"output"`` set.
        """
        net_eval, params = bench_dict["model"]
        output = net_eval(params, bench_dict["input"])
        # JAX dispatches asynchronously: block so that a timer around this
        # call measures the computation rather than just its dispatch.
        output.block_until_ready()
        bench_dict["output"] = output
        return bench_dict

    def backward_fn(bench_dict):
        """Compute gradients of the forward-pass sum w.r.t. the model variables.

        Args:
            bench_dict: dict produced by ``prepare_fn``.

        Returns:
            The same ``bench_dict`` (gradients are computed for timing only).
        """
        net_eval, params = bench_dict["model"]
        grad_fn = bench_dict.get("grad_fn")
        if grad_fn is None:
            # Fallback for dicts that were not built by prepare_fn above.
            grad_fn = jax.grad(net_eval)
        grads = grad_fn(params, bench_dict["input"])
        # Wait for every gradient leaf so the timing covers the full backward.
        for leaf in jax.tree_util.tree_leaves(grads):
            leaf.block_until_ready()
        return bench_dict

    return prepare_fn, forward_fn, backward_fn, benchmark_title

In [39]:
# Unpack the benchmark hooks; `b` holds the human-readable benchmark title.
(prepare_fn, forward_fn, backward_fn, b) = slax_full()

In [40]:
# Same configuration as before, spelled out by parameter name.
d = prepare_fn(batch_size=64, n_steps=500, n_neurons=1024, n_layers=3, device=3)

In [42]:
# One forward pass; the trailing `;` hides the returned dict's repr.
forward_fn(bench_dict=d);

In [34]:
# One backward (gradient) pass. NOTE(review): this cell carries an older
# execution count (In [34]) than the cells above; re-run top-to-bottom.
backward_fn(bench_dict=d);