In [71]:
import numpy as np
from fimjax.util.mesh_generation import generate_identity_2d_mesh
from fimjax.util.datastructures import InitialValues
from fimjax.main import Solver

# Reference forward solve on the identity-metric mesh built by
# generate_identity_2d_mesh(30) (grid resolution semantics: see fimjax.util.mesh_generation).
# One initial location (index 0) carries the value 0; the solver runs 200 iterations.
mesh, metrics = generate_identity_2d_mesh(30)
initial_values = InitialValues(
    locations=np.array([0]),
    values=np.array([0.0]),
)
solver = Solver()
result = solver.solve(mesh=mesh, initial_values=initial_values, metrics=metrics, iter=200)

In [72]:
import jax.numpy as jnp
import jax


# The unperturbed solution is the target of the squared-error loss.
ground_truth = result.solution


def loss(x):
    """Return the sum of squared differences between `x` and `ground_truth`."""
    residual = x - ground_truth
    return jnp.sum(residual**2)


# d(loss)/dx; its value at a candidate solution is used as the adjoint vector later on.
loss_gradient = jax.grad(loss)

In [73]:
# Scale the reference metrics uniformly by alpha and re-solve with everything else fixed.
alpha = 1.8
metrics_prime = alpha * metrics
solution_prime = solver.solve(
    mesh=mesh, initial_values=initial_values, metrics=metrics_prime, iter=200
).solution

# Gradient of the loss at the perturbed solution: the adjoint vector for the VJP below.
adjoint = loss_gradient(solution_prime)

In [74]:
# Solve at metrics_prime and pull `adjoint` back through the solver.
# `vjp` is presumably dL/d(metrics) (one entry per metric tensor) -- confirm against
# Solver.value_and_vjp; its shape is inspected in the next cell.
val, vjp = solver.value_and_vjp(
    mesh=mesh, initial_values=initial_values, metrics=metrics_prime,
    iter=200, adjoint_vector=adjoint,
)

In [75]:
# A vector-Jacobian product w.r.t. the metrics must have exactly the metrics' shape;
# fail loudly here rather than get a confusing broadcast error in the chain rule below.
assert vjp.shape == metrics_prime.shape, (vjp.shape, metrics_prime.shape)
vjp.shape

(1682, 2, 2)

In [83]:
# `vjp` is the cotangent evaluated at M = alpha * metrics, so the chain rule
# dL/dm = <dL/dM, dM/dm> must be linearized at that same point, m = alpha.
# (metric_tensor is linear in m, so the number happens not to depend on m, but
# linearizing at an unrelated point such as 2.0 is only correct by accident.)
m = alpha


def metric_tensor(m: float):
    """Return the base `metrics` scaled by the scalar `m`."""
    return m * metrics


# jax.vjp returns (primal_out, vjp_fn); vjp_fn maps an output cotangent to a tuple
# with one cotangent per positional input -- here a 1-tuple holding dL/dm.
_, vjp_fun = jax.vjp(metric_tensor, m)
dl_dm = vjp_fun(vjp)
dl_dm

(Array(265.15057, dtype=float32, weak_type=True),)

In [81]:
# Independent check of the autodiff result: metric_tensor(m) = m * metrics is linear,
# so dL/dm is the cotangent `vjp` (dL/dM) times the base `metrics`, summed over all entries.
dl_dm_closed_form = jnp.sum(vjp * metrics)
assert jnp.allclose(vjp_fun(vjp)[0], dl_dm_closed_form, rtol=1e-4), dl_dm_closed_form
dl_dm_closed_form

(Array(265.15057, dtype=float32, weak_type=True),)