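# 2-D Poisson equation with KFAC
This notebook demonstrates how to use `poisson_nd_residual` to build a physics-informed loss for the 2-D Poisson equation on the unit square, and how to optimize that loss with `KFACPINNSolver`.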

In [None]:
import jax, jax.numpy as jnp
import equinox as eqx
from bsde_dsgE.kfac import KFACPINNSolver, poisson_nd_residual

# Surrogate u(x, y): an Equinox MLP taking a single 2-vector and returning
# an array of shape (1,) (callers squeeze it to a scalar).
net = eqx.nn.MLP(
    in_size=2,
    out_size=1,
    width_size=16,
    depth=2,
    key=jax.random.PRNGKey(0),
)

def loss_fn(net, x):
    """Return the PINN loss: mean squared PDE residual plus a corner penalty.

    Args:
        net: Equinox MLP acting on a single 2-vector and returning shape (1,).
        x: interior collocation points; in this notebook shape (16, 2).

    Returns:
        Scalar ``mean(res**2) + mean(bc_res**2)``.
    """
    # Penalise u at the four corners of [0, 1]^2 towards zero.
    bc = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    # NOTE(review): assumes poisson_nd_residual handles the batch of points in
    # `x` itself and returns one residual per point -- confirm in bsde_dsgE.kfac.
    res = poisson_nd_residual(lambda z: net(z).squeeze(), x)
    # eqx.nn.MLP consumes one unbatched input of shape (in_size,); calling it
    # directly on the (4, 2) corner array fails with a matmul shape mismatch,
    # so map it over the rows explicitly. Result (4, 1) -> squeeze -> (4,).
    bc_res = jax.vmap(net)(bc).squeeze()
    return jnp.mean(res ** 2) + jnp.mean(bc_res ** 2)

# Run KFACPINNSolver (lr=1e-2, 10 steps) on 16 interior points drawn
# uniformly from [0, 1)^2, then report the last entry of the returned losses.
solver = KFACPINNSolver(
    net=net,
    loss_fn=loss_fn,
    lr=1e-2,
    num_steps=10,
)
sample_key = jax.random.PRNGKey(1)
run_key = jax.random.PRNGKey(2)
xs = jax.random.uniform(sample_key, (16, 2))
losses = solver.run(xs, run_key)
final_loss = float(losses[-1])
print('final loss', final_loss)