# Toy example with `kfac_pinn`
Demonstrates optimisation with `KFACPINNSolver` on a trivial quadratic loss: the mean squared network output at a single input point.

## Step-by-step

1. Create a network with Equinox.
2. Define the residual or loss function.
3. Instantiate `KFACPINNSolver`.
4. Call `solver.run` to train.
5. Inspect the loss history.
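For comparison, the five steps above can be sketched in plain JAX. This is a hypothetical stand-in, not the `kfac_pinn` API. The network is a hand-rolled MLP (1 → 8 → 8 → 1 with ReLU, the same layer sizes and default activation as the Equinox `MLP` used below). Plain gradient descent replaces the K-FAC update, so the sketch shows only the shape of the workflow. `init_mlp`, `mlp`, `sgd_step` and `run` are illustrative names.

```python
import jax
import jax.numpy as jnp

# Step 1 (stand-in): a tiny MLP stored as a list of (W, b) pairs.
def init_mlp(key, sizes=(1, 8, 8, 1)):
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, wkey, bkey = jax.random.split(key, 3)
        lim = 1.0 / n_in ** 0.5
        W = jax.random.uniform(wkey, (n_out, n_in), minval=-lim, maxval=lim)
        b = jax.random.uniform(bkey, (n_out,), minval=-lim, maxval=lim)
        params.append((W, b))
    return params

def mlp(params, x):
    # x is ONE unbatched input of shape (1,); ReLU on hidden layers only.
    for W, b in params[:-1]:
        x = jax.nn.relu(W @ x + b)
    W, b = params[-1]
    return W @ x + b

# Step 2: mean squared output over a batch xs of shape (N, 1).
def loss_fn(params, xs):
    out = jax.vmap(mlp, in_axes=(None, 0))(params, xs)
    return jnp.mean(out ** 2)

# Steps 3-4 (stand-in): plain gradient descent instead of K-FAC.
@jax.jit
def sgd_step(params, xs, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

def run(params, xs, lr=1e-2, num_steps=5):
    losses = []
    for _ in range(num_steps):
        params, loss = sgd_step(params, xs, lr)
        losses.append(float(loss))  # loss before this step's update
    return params, losses

# Step 5: inspect the loss history.
params = init_mlp(jax.random.PRNGKey(0))
xs = jnp.zeros((1, 1))
params, losses = run(params, xs, lr=1e-2, num_steps=5)
print("losses:", losses)
```

Each entry of `losses` is the loss evaluated just before that step's update, so the first entry is the loss at initialisation.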

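```python
import jax
import jax.numpy as jnp
import equinox as eqx
from kfac_pinn import KFACPINNSolver

# Use separate keys for network initialisation and for the solver.
key = jax.random.PRNGKey(0)
init_key, run_key = jax.random.split(key)

# Scalar input -> scalar output, two hidden layers of width 8.
net = eqx.nn.MLP(in_size=1, out_size=1, width_size=8, depth=2, key=init_key)

def loss_fn(net, x):
    # Equinox modules act on one unbatched sample, so map over the batch axis.
    return jnp.mean(jax.vmap(net)(x) ** 2)

# lr: learning rate; num_steps: number of optimisation steps.
solver = KFACPINNSolver(net=net, loss_fn=loss_fn, lr=1e-2, num_steps=5)

x = jnp.zeros((1, 1))  # one input point at the origin, shape (N, in_size)
losses = solver.run(x, run_key)
print('losses:', losses)
```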
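Equinox layers such as `eqx.nn.MLP` are written for a single unbatched input, so a batch of points with shape `(N, 1)` has to be mapped over its leading axis with `jax.vmap`. The sketch below illustrates why, using a hypothetical single dense layer `layer` written in the same per-sample style (`W @ x + b`). Applying it directly to the batch broadcasts the bias against the wrong axis instead of producing one output per point.

```python
import jax
import jax.numpy as jnp

# One dense layer written for ONE sample x of shape (1,): y = W @ x + b.
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (8, 1))
b = jnp.zeros(8)

def layer(x):
    return W @ x + b

xs = jnp.zeros((1, 1))  # a batch containing N = 1 point of dimension 1

batched = jax.vmap(layer)(xs)  # maps over the leading (batch) axis
naive = layer(xs)              # treats the whole batch as a single input

print(batched.shape)  # (1, 8): one 8-dim output per point
print(naive.shape)    # (8, 8): (8, 1) result broadcast against the (8,) bias
```

With more than one point (`N > 1`), the naive call fails outright because `W @ xs` has mismatched inner dimensions. With `N = 1` it silently returns the wrong shape, which is the harder case to spot.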