# PINN Training with KFAC

This notebook builds a simple physics-informed neural network (PINN) with JAX/Equinox and trains it using `KFACPINNSolver`.

In [ ]:
import jax
import jax.numpy as jnp
import equinox as eqx
import matplotlib.pyplot as plt
from bsde_seed.bsde_dsgE.optim import KFACPINNSolver


In [ ]:
# Define a simple ODE boundary-value problem on [0, pi]:
#     y''(x) + y(x) = 0,  y(0) = 0,  y(pi) = 0,  y(pi/2) = 1
# The two Dirichlet conditions alone are satisfied by every multiple C*sin(x)
# (including the trivial y = 0), so the extra normalization y(pi/2) = 1 is
# needed to pin down the unique solution y(x) = sin(x).
def pinn_loss(net, x):
    """Return the PINN loss: mean squared ODE residual plus boundary penalty.

    Args:
        net: Callable mapping a length-1 input array to an output array whose
            first entry is the predicted y (e.g. an
            ``eqx.nn.MLP(in_size=1, out_size=1, ...)``).
        x: Collocation points, shape ``(N, 1)`` or ``(N,)``.

    Returns:
        Scalar ``mean(residual**2) + mean(bc**2)``, where the residual is
        ``y'' + y`` at each collocation point and ``bc`` holds the violations
        of ``y(0) = 0``, ``y(pi) = 0`` and ``y(pi/2) = 1``.
    """

    def y_scalar(t):
        # The network consumes a length-1 vector, but jax.grad requires a
        # scalar -> scalar function: wrap t into shape (1,) and take output 0.
        return net(jnp.reshape(t, (1,)))[0]

    # Flatten (N, 1) -> (N,) so each collocation point is a scalar for vmap.
    ts = jnp.reshape(x, (-1,))
    y = jax.vmap(y_scalar)(ts)
    d2y_dx2 = jax.vmap(jax.grad(jax.grad(y_scalar)))(ts)
    resid = d2y_dx2 + y
    bc = jnp.array([
        y_scalar(0.0),
        y_scalar(jnp.pi),
        y_scalar(jnp.pi / 2) - 1.0,  # normalization removing the C*sin(x) ambiguity
    ])
    return jnp.mean(resid ** 2) + jnp.mean(bc ** 2)

key = jax.random.PRNGKey(0)
# Derive independent subkeys: reusing one PRNG key for both the weight init and
# whatever randomness solver.run draws from its key would correlate the two.
init_key, run_key = jax.random.split(key)
net = eqx.nn.MLP(in_size=1, out_size=1, width_size=32, depth=2, key=init_key)
solver = KFACPINNSolver(net=net, loss_fn=pinn_loss, lr=1e-2, num_steps=200)
# 64 collocation points on [0, pi], shaped (64, 1) to match the MLP's in_size=1.
xs = jnp.linspace(0, jnp.pi, 64).reshape(-1, 1)
losses = solver.run(xs, run_key)


In [ ]:
# Plot the training loss on a log-scaled y-axis. PINN losses usually fall by
# several orders of magnitude, and a linear axis would flatten that tail.
plt.semilogy(losses)
plt.xlabel('step')
plt.ylabel('loss (log scale)')
plt.title('PINN convergence with KFAC')
plt.show()
