In [1]:
import jax
import jax.numpy as jnp
from jax import nn, random

from flax import nnx


In [None]:

class MLP(nnx.Module):
    """Sequential fully connected network with GELU activations.

    The network contains ``nb_layer`` linear layers in total:

    * an input projection ``dim_input -> dim_middle``,
    * ``nb_layer - 2`` hidden layers ``dim_middle -> dim_middle``,
    * an output projection ``dim_middle -> dim_output``.

    GELU is applied after every linear layer except the output projection.

    Args:
        dim_input: Size of the last axis of the input.
        dim_output: Size of the last axis of the output.
        dim_middle: Width of the hidden layers.
        nb_layer: Total number of linear layers; must be at least 2.
        rngs: ``nnx.Rngs`` used to initialise the layer parameters.

    Raises:
        ValueError: If ``nb_layer`` is smaller than 2.
    """

    def __init__(self, dim_input, dim_output, dim_middle, nb_layer, rngs):
        super().__init__()

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept nb_layer < 2 and build a
        # 2-layer network (range(nb_layer - 2) is then simply empty).
        if nb_layer < 2:
            raise ValueError(f"nb_layer should be at least 2, got {nb_layer}")

        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_middle = dim_middle
        self.nb_layer = nb_layer

        # NOTE: construction order (hidden layers, then input, then output
        # projection) determines which RNG keys each layer draws from
        # ``rngs``; keep it stable so initial weights stay reproducible.
        self.layers = nnx.List(
            [
                nnx.Linear(in_features=dim_middle, out_features=dim_middle, rngs=rngs)
                for _ in range(nb_layer - 2)
            ]
        )

        self.input_mapping = nnx.Linear(
            in_features=dim_input, out_features=dim_middle, rngs=rngs
        )
        self.output_mapping = nnx.Linear(
            in_features=dim_middle, out_features=dim_output, rngs=rngs
        )

    def __call__(self, x):
        """Apply the network to ``x``.

        Args:
            x: Array whose last axis has size ``dim_input``.

        Returns:
            Array whose last axis has size ``dim_output`` (no activation is
            applied to the output projection).
        """
        x = nn.gelu(self.input_mapping(x))
        for layer in self.layers:
            x = nn.gelu(layer(x))
        return self.output_mapping(x)

class Ferminet(nnx.Module):
    """Simplified FermiNet-style network (construction only so far).

    Builds three stages:

    * ``input_mapping``: linear projection ``dim_input -> dim_middle``;
    * ``layers``: ``nb_layer`` two-layer ``MLP`` blocks, each mapping
      ``dim_middle -> dim_middle``;
    * ``last_layer``: linear projection
      ``dim_middle -> nb_elements * nb_determinant``.

    No forward pass is defined in this class yet. The widths above assume
    the stages are applied in sequence
    (input_mapping -> layers -> last_layer); TODO confirm once ``__call__``
    is written.

    Args:
        nb_elements: Number of elements; multiplies the output width.
        dim_input: Size of the last axis of the raw input.
        dim_output: Stored on the module; not used during construction.
        dim_middle: Hidden width shared by all intermediate stages.
        nb_layer: Number of ``MLP`` blocks in ``layers``.
        nb_determinant: Number of determinants; multiplies the output width.
        rngs: ``nnx.Rngs`` used to initialise parameters. Although it
            defaults to ``None``, it must be provided in practice, because
            every ``nnx.Linear`` draws its initialisation keys from it.
    """

    def __init__(
        self,
        nb_elements=1,
        dim_input=3,
        dim_output=1,
        dim_middle=32,
        nb_layer=2,
        nb_determinant=1,
        rngs=None,
    ):
        super().__init__()

        self.nb_elements = nb_elements
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_middle = dim_middle
        self.nb_layer = nb_layer
        self.nb_determinant = nb_determinant

        # Project the raw input features to the hidden width.
        self.input_mapping = nnx.Linear(
            in_features=dim_input, out_features=dim_middle, rngs=rngs
        )

        # Stack of MLP blocks. Each block consumes the dim_middle-wide output
        # of the previous stage (input_mapping or the preceding block), so its
        # input width must be dim_middle. Using dim_input here would make the
        # shapes mismatch whenever dim_input != dim_middle.
        self.layers = nnx.List(
            [
                MLP(
                    dim_middle,
                    dim_output=dim_middle,
                    dim_middle=dim_middle,
                    nb_layer=2,
                    rngs=rngs,
                )
                for _ in range(nb_layer)
            ]
        )

        # Final projection to one output per (element, determinant) pair.
        # Presumably these are the orbital values fed into the determinants;
        # TODO confirm once the forward pass exists.
        self.last_layer = nnx.Linear(
            in_features=dim_middle,
            out_features=nb_elements * nb_determinant,
            rngs=rngs,
        )





