# Custom network example

This example defines a custom Equinox network and trains it with the KFAC optimiser.

In [None]:
import sys, os
sys.path.append(os.path.abspath('..'))
import jax
import jax.numpy as jnp
import equinox as eqx
from kfac_pinn import KFAC, training

class Net(eqx.Module):
    """Single-layer network computing ``tanh(W x + b)`` with one input and one output.

    ``eqx.nn.Linear`` operates on one unbatched sample of shape ``(1,)``, so
    ``__call__`` maps itself over any leading batch axes of its input.
    """

    # The only layer: a 1 -> 1 affine map.
    linear: eqx.nn.Linear

    def __init__(self, key):
        # key: JAX PRNG key used to initialise the linear layer's parameters.
        self.linear = eqx.nn.Linear(1, 1, key=key)

    def __call__(self, x):
        """Apply the network to ``x``.

        Args:
            x: array of shape ``(1,)`` for a single sample, or ``(..., 1)``
                for a batch (e.g. ``(16, 1)`` as used in the training loop).

        Returns:
            Array with the same shape as ``x``, with values in ``[-1, 1]``.
        """
        if x.ndim > 1:
            # Peel off one leading batch axis per vmap level until the
            # input reaching the linear layer has shape (1,).
            return jax.vmap(self)(x)
        return jax.nn.tanh(self.linear(x))

# Derive independent PRNG keys for parameter initialisation and data sampling.
# Reusing one key for both (or for every batch) yields correlated or
# identical random draws.
key = jax.random.PRNGKey(0)
model_key, data_key = jax.random.split(key)
model = Net(model_key)

def loss_fn(m, x):
    """Mean-squared error between the model's prediction and ``sin(x)``.

    Args:
        m: the model; called directly on the batch ``x``.
        x: batch of inputs, shape ``(16, 1)`` as sampled below.

    Returns:
        Scalar mean of the squared residuals.
    """
    return jnp.mean((m(x) - jnp.sin(x)) ** 2)

num_steps = 50
# One independent batch of 16 points drawn uniformly from [0, 1) per training
# step; each batch gets its own key so the batches actually differ.
data = [jax.random.uniform(k, (16, 1)) for k in jax.random.split(data_key, num_steps)]
opt = KFAC(lr=1e-2)
model, state = training.train(model, opt, loss_fn, data, steps=num_steps)
