# Quickstart: Training a PINN with KFAC

In [ ]:
import jax
import jax.numpy as jnp
import equinox as eqx
from bsde_seed.bsde_dsgE.optim import KFACPINNSolver

In [ ]:
def residual(net: "eqx.Module", x: jnp.ndarray) -> jnp.ndarray:
    """Return the mean squared network output over the point(s) in ``x``.

    This toy objective pushes the network's output towards zero and stands
    in for a real PDE residual in this quickstart.

    Equinox layers such as ``eqx.nn.MLP`` act on a single, unbatched sample.
    A batch of points stacked along axis 0 is therefore mapped with
    ``jax.vmap``. Calling ``net`` on the whole batch directly would not treat
    the rows as independent samples: it either fails on a shape mismatch or
    broadcasts into a wrongly shaped output.

    Args:
        net: Callable mapping one sample (shape ``(in_size,)`` or a scalar)
            to its output.
        x: Either a single sample (``ndim <= 1``) or a batch of samples
            stacked along the leading axis (``ndim >= 2``).

    Returns:
        A scalar array holding the mean of the squared outputs.
    """
    x = jnp.asarray(x)
    # ndim <= 1 is treated as one sample; anything higher is a batch over
    # axis 0. NOTE: a 1-D batch of scalar inputs (MLP in_size="scalar") is
    # indistinguishable from one vector sample and is evaluated directly.
    y = net(x) if x.ndim <= 1 else jax.vmap(net)(x)
    return jnp.mean(y ** 2)

In [ ]:
# Derive independent PRNG keys for parameter initialisation and for the
# solver; reusing one key for both would correlate their random streams.
key = jax.random.PRNGKey(0)
init_key, run_key = jax.random.split(key)
net = eqx.nn.MLP(in_size=1, out_size=1, width_size=8, depth=2, key=init_key)
solver = KFACPINNSolver(net=net, loss_fn=residual, num_steps=100)
# Training points; presumably laid out as (num_points, in_size) — confirm
# against KFACPINNSolver.run.
x = jnp.zeros((1, 1))
losses = solver.run(x, run_key)
# Display the final recorded loss as a Python float.
float(losses[-1])