In [31]:
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit

from ngrad.models import init_params, mlp
from ngrad.domains import Square, SquareBoundary
from ngrad.integrators import DeterministicIntegrator
from ngrad.utility import laplace, grid_line_search_factory
from ngrad.inner import model_laplace, model_identity
from ngrad.gram import gram_factory, nat_grad_factory, pre_gram_factory

In [2]:
jax.config.update("jax_enable_x64", True)

In [27]:
# random seed
seed = 0

# domains
interior = Square(1.)
boundary = SquareBoundary(1.)

# integrators
interior_integrator = DeterministicIntegrator(interior, 30)
boundary_integrator = DeterministicIntegrator(boundary, 30)
eval_integrator = DeterministicIntegrator(interior, 200)

# model
activation = lambda x : jnp.tanh(x)
layer_sizes = [2, 32, 1]
params = init_params(layer_sizes, random.PRNGKey(seed))
model = mlp(activation)
# model(params, x) 对数据维度进行 vmap
v_model = vmap(model, (None, 0))

In [28]:
# solution
@jit
def u_star(x):
    return jnp.prod(jnp.sin(jnp.pi * x))

# rhs
@jit
def f(x):
    return 2. * jnp.pi**2 * u_star(x)

# gramians
gram_bdry = gram_factory(
    model = model,
    trafo = model_identity,
    integrator = boundary_integrator
)

gram_laplace = gram_factory(
    model = model,
    trafo = model_laplace,
    integrator = interior_integrator
)

In [32]:
g = pre_gram_factory(
    model = model,
    trafo = model_laplace,
)

In [34]:
data = interior.deterministic_integration_points(30)

In [37]:
data[0]

Array([0.03448276, 0.03448276], dtype=float64)

In [39]:
g(params, data[0]).shape

(129, 129)

In [29]:
g_la = gram_laplace(params)

In [50]:
g_la.shape

(129, 129)

In [25]:
p_flatten = jax.flatten_util.ravel_pytree(params)[0]

In [26]:
p_flatten.shape

(2241,)

In [13]:
interior.deterministic_integration_points(30).shape

(784, 2)

In [52]:
def test_func(x):
    return jnp.cos(x) + 1j*jnp.sin(x)

In [55]:
grad_func = jax.grad(test_func)