# Mixed Boundary Poisson Example
Demonstrate KFACPINNSolver on a 1D Poisson equation with Dirichlet and Neumann boundaries.

## Step-by-step
1. Create a network with Equinox.
2. Define a loss using both boundary conditions.
3. Instantiate `KFACPINNSolver`.
4. Train and print the final loss.

In [1]:
import jax, jax.numpy as jnp
import equinox as eqx
from bsde_dsgE.kfac import KFACPINNSolver, poisson_1d_residual

net = eqx.nn.MLP(in_size=1, out_size=1, width_size=16, depth=2, key=jax.random.PRNGKey(0))
net_scalar = lambda z: net(jnp.atleast_1d(z))[0]

def loss_fn(net, x):
    interior = x
    dir_x = jnp.array([0.0])
    neu_x = jnp.array([1.0])
    res = poisson_1d_residual(net_scalar, interior)
    dir_res = poisson_1d_residual(net_scalar, dir_x, dirichlet_bc=lambda z: 0.0)
    neu_res = poisson_1d_residual(net_scalar, neu_x, neumann_bc=lambda z: 1.0)
    return jnp.mean(res**2) + jnp.mean(dir_res**2) + jnp.mean(neu_res**2)

solver = KFACPINNSolver(net=net, loss_fn=loss_fn, lr=1e-2, num_steps=5)
xs = jnp.linspace(0.0, 1.0, 8)
losses = solver.run(xs, jax.random.PRNGKey(1))
print('final loss', float(losses[-1]))

final loss 1.296364665031433
