In [1]:
import jax
import jax.numpy as jnp
import numpy as np

from flax import nnx

In [3]:
class Count(nnx.Variable[nnx.A]):
  """Variable subclass for counter state, kept as a type distinct from ``nnx.Param``."""


class SingleLayer(nnx.Module):
  """One affine map followed by an activation: ``activation(x @ w + b)``.

  Attributes:
    count: ``Count`` variable, starts at 0 and is incremented on every call.
    w: ``nnx.Param`` of shape ``(dim_in, dim_out)`` produced by
      ``nnx.initializers.uniform(scale=1.0)``.
    b: ``nnx.Param`` of shape ``(dim_out,)`` initialised to zeros.
    activation: callable applied to the affine output.
    dim_in, dim_out: the sizes given at construction time.
  """

  def __init__(self, dim_in: int, dim_out: int, activation, *, rngs: nnx.Rngs):
    """Create the layer's state.

    Args:
      dim_in: number of input features (rows of ``w``).
      dim_out: number of output features (columns of ``w``, length of ``b``).
      activation: callable applied to ``x @ w + b``.
      rngs: RNG container; its ``params`` stream supplies the key for ``w``.
    """
    # Exactly one key is drawn from the ``params`` stream.
    weight_key = rngs.params()
    self.count = Count(jnp.array(0))
    weight_init = nnx.initializers.uniform(scale=1.0)
    weight_shape = (dim_in, dim_out)
    self.w = nnx.Param(weight_init(weight_key, weight_shape))
    self.b = nnx.Param(jnp.zeros((dim_out,)))
    self.activation = activation
    self.dim_in = dim_in
    self.dim_out = dim_out

  def __call__(self, x):
    """Apply the layer to ``x`` and bump the call counter.

    Args:
      x: input whose last axis must be compatible with ``w`` in ``x @ w``.

    Returns:
      ``activation(x @ w + b)``.
    """
    # Side effect: record that the layer was invoked once more.
    self.count.value += 1
    pre_activation = x @ self.w + self.b
    return self.activation(pre_activation)

In [None]:
  # Build a 1 -> 10 tanh layer. The original cell left `rngs=` without a value,
  # which is a SyntaxError; an explicitly seeded nnx.Rngs makes the cell run and
  # keeps the initial weights reproducible across kernel restarts.
  model = SingleLayer(
    dim_in=1,
    dim_out=10,
    activation=jnp.tanh,
    rngs=nnx.Rngs(0),
  )

In [6]:
# Integration interval [a, b].
a = 0.0
b = 1.0

# Gauss-Legendre nodes and weights on the reference interval [-1, 1].
ref_nodes, ref_weights = np.polynomial.legendre.leggauss(1)

# Affine map from [-1, 1] onto [a, b]: scale by the half-width, shift to the midpoint.
half_width = 0.5 * (b - a)
midpoint = 0.5 * (b + a)
x = half_width * ref_nodes + midpoint
w = half_width * ref_weights  # weights scale with the interval's half-width

In [7]:
# Show the quadrature node(s) after mapping onto [a, b].
x

array([0.5])

In [8]:
# Show the quadrature weight(s) after scaling to [a, b].
w

array([1.])