# Introduction to `kfac_pinn`
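This notebook demonstrates how to solve a 1D Poisson equation with the package: a small Equinox MLP is trained by `KFACPINNSolver` for a few steps, and the final loss is displayed.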

In [None]:
import jax, jax.numpy as jnp
import equinox as eqx
from kfac_pinn import KFACPINNSolver, pinn_loss

# Small Equinox MLP u(x) mapping a length-1 input to a length-1 output,
# initialised from a fixed PRNG key so the demo is reproducible.
# A smooth activation is required here: the Poisson residual involves u''(x),
# and with Equinox's default ReLU activation the network's second derivative
# is zero almost everywhere, so the PDE term could not be fit.
net = eqx.nn.MLP(
    in_size=1,
    out_size=1,
    width_size=8,
    depth=2,
    activation=jnp.tanh,
    key=jax.random.PRNGKey(0),
)
def loss_fn(net, x):
    """Return the PINN loss of ``net`` on the collocation points ``x``.

    ``net`` is wrapped so each evaluation yields a scalar (its output is
    squeezed). The wrapped function, ``x`` and the boundary values
    ``[0.0, 1.0]`` are forwarded to ``pinn_loss``.
    """
    def scalar_net(z):
        # Collapse the MLP's length-1 output to a scalar.
        return net(z).squeeze()

    # NOTE(review): presumably u(0) = 0 and u(1) = 1; confirm against pinn_loss.
    boundary_values = jnp.array([0.0, 1.0])
    return pinn_loss(scalar_net, x, boundary_values)

# Optimiser setup: learning rate 1e-2, 5 steps.
solver = KFACPINNSolver(
    net=net,
    loss_fn=loss_fn,
    lr=1e-2,
    num_steps=5,
)

# 8 evenly spaced collocation points on [0, 1], one per row: shape (8, 1).
xs = jnp.linspace(0.0, 1.0, 8)[:, None]

# Run the solver with its own PRNG key; presumably one loss per step.
losses = solver.run(xs, jax.random.PRNGKey(1))

# Last expression of the cell: show the final loss as a Python float.
float(losses[-1])