# Full KFAC demo

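Train a 1D Poisson PINN using the full KFAC optimiser. The cell below builds the model, samples interior and boundary points, takes a single optimiser step, and prints the resulting loss.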

In [None]:
import sys, os
sys.path.append(os.path.abspath('..'))
import jax
import jax.numpy as jnp
from kfac_pinn import PINNKFAC, pinn, pdes, training

# Fixed seed so repeated runs of the demo start from the same PRNG state.
key = jax.random.PRNGKey(0)
# JAX PRNG keys must not be reused: `key` is split again further down to
# sample points, so hand the model its own sub-key instead of `key` itself.
key, model_key = jax.random.split(key)
model = pinn.make_mlp(in_dim=1, key=model_key)
# `lr` is presumably the optimiser's learning rate -- confirm against PINNKFAC.
opt = PINNKFAC(lr=1e-2)

def rhs(x):
    """Return pi**2 * sin(pi * x), the forcing term handed to the interior loss."""
    amplitude = jnp.pi ** 2
    wave = jnp.sin(jnp.pi * x)
    return amplitude * wave

def exact(x):
    """Return sin(pi * x), the reference values handed to the boundary loss."""
    phase = jnp.pi * x
    return jnp.sin(phase)

def loss_fn(m, batch):
    """Return the interior loss plus the boundary loss of model ``m``.

    ``batch`` is a 2-tuple ``(interior, boundary)``. The first element is
    passed to ``pinn.interior_loss`` together with ``rhs``; the second is
    passed to ``pinn.boundary_loss`` together with ``exact``.
    """
    interior_pts, boundary_pts = batch
    return (
        pinn.interior_loss(m, interior_pts, rhs)
        + pinn.boundary_loss(m, boundary_pts, exact)
    )

# Two rows, [0.0] and [1.0]; they are passed below as the second and third
# sampler arguments (presumably the lower and upper ends of the interval --
# confirm against `pdes`).
domain = jnp.array([[0.0], [1.0]])
# Replace `key` and derive two further sub-keys, one per sampler.
key, k1, k2 = jax.random.split(key, 3)
# The trailing 64 is presumably the number of points to draw -- confirm
# against `pdes.sample_interior` / `pdes.sample_boundary`.
interior = pdes.sample_interior(k1, domain[0], domain[1], 64)
boundary = pdes.sample_boundary(k2, domain[0], domain[1], 64)
# Packed in the (interior, boundary) order that loss_fn unpacks.
batch = (interior, boundary)
state = opt.init(model)
# A single optimiser step; no training loop is run in this cell.
model, state = opt.step(model, loss_fn, batch, state)
# Loss of the updated model, evaluated on the same batch used for the step.
print('Loss:', loss_fn(model, batch))
