# Quickstart: KFAC for PINNs
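This notebook demonstrates how to train a physics-informed neural network (PINN) for a 1D Poisson equation using `KFACPINNSolver`, a solver built around the Kronecker-factored approximate curvature (KFAC) optimizer.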

## Step-by-step

1. Create a network with Equinox.
2. Define the residual or loss function that the solver will minimize.
3. Instantiate `KFACPINNSolver`.
4. Call `solver.run` to train.
5. Inspect the loss history returned by `solver.run`.

In [None]:
import jax, jax.numpy as jnp
import equinox as eqx
from bsde_dsgE.kfac import KFACPINNSolver, pinn_loss

# Small MLP with a length-1 input and a length-1 output, seeded deterministically.
net = eqx.nn.MLP(
    in_size=1,
    out_size=1,
    width_size=16,
    depth=2,
    key=jax.random.PRNGKey(0),
)


def loss_fn(net, x):
    """Return the PINN loss of ``net`` evaluated on the points ``x``.

    ``x`` is forwarded unchanged as the interior points; the boundary points
    are the two ends of the unit interval, 0.0 and 1.0.
    """

    def u(z):
        # Wrap the network so it yields a scalar: squeeze drops the size-1 axis.
        return net(z).squeeze()

    # NOTE(review): the boundary points are 1-D (shape (2,)) while the interior
    # points built below are (N, 1); confirm pinn_loss handles both layouts.
    return pinn_loss(u, x, jnp.array([0.0, 1.0]))


solver = KFACPINNSolver(net=net, loss_fn=loss_fn, lr=1e-2, num_steps=10)

# 16 evenly spaced points on [0, 1], arranged as a (16, 1) column.
xs = jnp.linspace(0.0, 1.0, 16)[:, None]

# Train; the last entry of the returned history is the final loss.
losses = solver.run(xs, jax.random.PRNGKey(1))
print('final loss', float(losses[-1]))