# Overview of `bsde_dsgE.kfac`

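This notebook demonstrates the high-level API of the KFAC solver: it builds a small Equinox MLP, trains it for a few steps with `KFACPINNSolver` on a `pinn_loss` objective, and prints the final loss.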

In [ ]:
import jax, jax.numpy as jnp
import equinox as eqx
from bsde_dsgE.kfac import KFACPINNSolver, pinn_loss

# A small Equinox MLP with scalar (size-1) input and output, width 8 and
# depth 2. It is initialised from a fixed PRNG key so the demo is reproducible.
net = eqx.nn.MLP(
    in_size=1,
    out_size=1,
    width_size=8,
    depth=2,
    key=jax.random.PRNGKey(0),
)
def loss_fn(net, x):
    """Return the PINN loss of ``net`` evaluated on the points ``x``.

    The network is wrapped as a scalar-valued function ``u(z)`` (its output
    squeezed to drop size-1 axes). ``u`` is passed to ``pinn_loss`` together
    with ``x`` and the fixed array ``[0.0, 1.0]``.

    NOTE(review): the two fixed values are presumably boundary conditions at
    the ends of the domain. Confirm against ``pinn_loss``'s signature.
    """

    def u(z):
        # Squeeze so pinn_loss sees a scalar output rather than shape (1,).
        return net(z).squeeze()

    boundary_values = jnp.array([0.0, 1.0])
    return pinn_loss(u, x, boundary_values)

# Configure a short KFAC run: learning rate 1e-2 for 5 steps.
solver = KFACPINNSolver(
    net=net,
    loss_fn=loss_fn,
    lr=1e-2,
    num_steps=5,
)

# 8 evenly spaced points on [0, 1], laid out as an (8, 1) column so that each
# row matches the network's in_size=1.
xs = jnp.linspace(0.0, 1.0, 8)[:, None]

run_key = jax.random.PRNGKey(1)
# NOTE(review): `losses` is presumably the per-step loss history; only its
# last entry is reported here.
losses = solver.run(xs, run_key)
print('final loss', float(losses[-1]))