# PINNKFAC Training Loop Demo

This notebook demonstrates using `training.pinn_train` from the `kfac_pinn` package, together with the `PINNKFAC` optimizer, to train a simple 1D Poisson PINN on the interval [0, 1] with zero boundary values.

In [ ]:
import jax, jax.numpy as jnp, equinox as eqx
from kfac_pinn import PINNKFAC, pinn, pdes, training

# Split the root key before use so that model initialisation gets its own
# subkey. A JAX PRNG key should be consumed only once, and `key` is split
# again further down to draw the collocation points.
key, model_key = jax.random.split(jax.random.PRNGKey(0))
model = pinn.make_mlp(key=model_key)
# Fixed learning rate; the optimizer's line search is switched off.
opt = PINNKFAC(lr=1e-2, use_line_search=False)

def rhs(x):
    """Return the source term pi^2 * sin(pi * x) evaluated at ``x``.

    NOTE(review): under the convention -u'' = f this presumably pairs with
    the exact solution u(x) = sin(pi * x); confirm the sign convention used
    by ``pinn.pinn_residual``.
    """
    coefficient = jnp.pi**2
    return coefficient * jnp.sin(jnp.pi * x)

def bc(x):
    """Return boundary values: zeros with the same shape and dtype as ``x``."""
    boundary_values = jnp.zeros_like(x)
    return boundary_values

# The batch lists below are built with the same length as the step count
# (presumably one batch is consumed per step -- confirm against pinn_train).
num_steps = 100
batch_size = 64

# NOTE: key_b is not used in the visible cells.
key_i, key_b = jax.random.split(key)

# Interior collocation points are sampled once between 0 and 1.
lower, upper = jnp.array([0.0]), jnp.array([1.0])
interior_batch = pdes.sample_interior(key_i, lower, upper, batch_size)
# The two boundary points, one per row.
boundary_batch = jnp.array([[0.0], [1.0]])

# List repetition reuses the same array object for every entry, so every
# entry holds identical points.
interior_points = [interior_batch] * num_steps
boundary_points = [boundary_batch] * num_steps

model, state = training.pinn_train(
    model, opt, rhs, bc, interior_points, boundary_points, steps=num_steps
)

In [ ]:
# Evaluate the residual on the first interior batch (points that were also
# passed to training) and report its mean square.
res = pinn.pinn_residual(model, interior_points[0], rhs)
mean_sq_residual = jnp.mean(res**2)
print('Final interior residual', mean_sq_residual)