In [1]:
import dolfinx as dlx
import numpy as np
from dolfinx.fem import petsc
from mpi4py import MPI
from petsc4py import PETSc

from ls_prior import components, fem, prior

In [2]:
kappa, tau = 1.0, 1.0
nx, ny = 2, 2
mpi_communicator = MPI.COMM_WORLD

# Stopping criteria shared by both iterative inverse-matrix solvers.
krylov_stopping_criteria = dict(
    relative_tolerance=1e-6,
    absolute_tolerance=1e-8,
    max_num_iterations=10000,
)

# Mass-matrix inverse: conjugate gradients with Jacobi (diagonal) preconditioning.
cg_solver_settings = components.InverseMatrixSolverSettings(
    solver_type=PETSc.KSP.Type.CG,
    preconditioner_type=PETSc.PC.Type.JACOBI,
    **krylov_stopping_criteria,
)

# SPDE-matrix inverse: conjugate gradients with algebraic multigrid (GAMG)
# preconditioning. CG is used instead of PREONLY because PREONLY applies the
# preconditioner exactly once and ignores the tolerances / iteration cap, which
# would make the "inverse" only a single AMG approximation.
# NOTE(review): CG assumes the SPDE matrix from fem.generate_forms is symmetric
# positive definite — confirm.
amg_solver_settings = components.InverseMatrixSolverSettings(
    solver_type=PETSc.KSP.Type.CG,
    preconditioner_type=PETSc.PC.Type.GAMG,
    **krylov_stopping_criteria,
)

In [3]:
# Unit square with corners (0, 0) and (1, 1), divided into nx x ny intervals
# along the axes and meshed with triangular cells.
lower_left_corner, upper_right_corner = np.array([0, 0]), np.array([1, 1])
mesh = dlx.mesh.create_rectangle(
    mpi_communicator,
    [lower_left_corner, upper_right_corner],
    [nx, ny],
    dlx.mesh.CellType.triangle,
)

In [4]:
# Continuous piecewise-linear (P1 Lagrange) space and the UFL forms of the prior.
function_space = dlx.fem.functionspace(mesh, ("Lagrange", 1))
mass_matrix_form, spde_matrix_form = fem.generate_forms(function_space, kappa, tau)

# Compile and assemble both bilinear forms into PETSc matrices (mass first),
# then finalise them so they are ready for use in solvers.
mass_matrix, spde_matrix = (
    petsc.assemble_matrix(dlx.fem.form(ufl_form))
    for ufl_form in (mass_matrix_form, spde_matrix_form)
)
for petsc_matrix in (mass_matrix, spde_matrix):
    petsc_matrix.assemble()

# Block factorisation of the mass matrix; its assembled factor feeds the
# covariance sampling factor below.
mass_matrix_factorization = fem.FEMMatrixBlockFactorization(
    mesh, function_space, mass_matrix_form
)
mass_matrix_factor = mass_matrix_factorization.assemble()

# Converter between the function-space representation and mesh-based vectors.
converter = fem.FEMConverter(function_space)

In [5]:
# Solver objects that apply the inverse of the mass matrix (CG settings) and of
# the SPDE matrix (amg_solver_settings), in that order.
mass_matrix_inverse, spde_matrix_inverse = (
    components.InverseMatrixSolver(solver_settings, system_matrix, mpi_communicator)
    for solver_settings, system_matrix in (
        (cg_solver_settings, mass_matrix),
        (amg_solver_settings, spde_matrix),
    )
)

In [6]:
# Bilaplacian prior operators. The precision pairs the SPDE matrix with the
# mass-matrix inverse. The covariance and the sampling factor pair the mass
# matrix (or its factor) with the SPDE-matrix inverse.
precision_operator = components.BilaplacianPrecision(
    spde_matrix,
    mass_matrix_inverse,
)
covariance_operator = components.BilaplacianCovariance(
    mass_matrix,
    spde_matrix_inverse,
)
sampling_factor = components.BilaplacianCovarianceFactor(
    mass_matrix_factor,
    spde_matrix_inverse,
)

In [7]:
# Wrap the precision and covariance operators with the FEM converter using the
# default conversion flags.
precision_operator_interface, covariance_operator_interface = (
    components.InterfaceComponent(wrapped_operator, converter)
    for wrapped_operator in (precision_operator, covariance_operator)
)

# The sampling factor is wrapped with input conversion disabled and output
# conversion to the mesh enabled.
sampling_factor_interface = components.InterfaceComponent(
    sampling_factor,
    converter,
    convert_input_from_mesh=False,
    convert_output_to_mesh=True,
)

In [8]:
# Zero prior mean with one float64 entry per row of the mesh geometry
# coordinate array.
# NOTE(review): in parallel, mesh.geometry.x may also contain ghost nodes —
# confirm this length matches what the FEM converter expects.
num_geometry_nodes = mesh.geometry.x.shape[0]
mean_vector = np.zeros(num_geometry_nodes, dtype=np.float64)

bilaplace_prior = prior.Prior(
    mean_vector,
    precision_operator_interface,
    covariance_operator_interface,
    sampling_factor_interface,
    seed=0,  # fixed seed, presumably so generate_sample() is reproducible
)

In [9]:
# Evaluate every prior interface method on constant test vectors and draw one sample.
test_vector_1 = np.ones_like(mean_vector)
test_vector_2 = 2 * np.ones_like(mean_vector)
test_vector_3 = 3 * np.ones_like(mean_vector)
cost = bilaplace_prior.evaluate_cost(test_vector_1)
grad = bilaplace_prior.evaluate_gradient(test_vector_2)
hvp = bilaplace_prior.evaluate_hessian_vector_product(test_vector_3)
sample = bilaplace_prior.generate_sample()
print(cost)
print(grad)
print(hvp)
print(sample)

# Sanity checks. These assume the prior is a zero-mean Gaussian with
# cost(m) = 0.5 m^T Q m, gradient Q m and Hessian-vector product Q v, for the
# precision operator Q (assumption — consistent with the printed results).
# The relative tolerance allows for the iterative (CG) mass-matrix solves.
# Since test_vector_3 = 1.5 * test_vector_2, linearity gives hvp = 1.5 * grad.
np.testing.assert_allclose(hvp, 1.5 * grad, rtol=1e-5)
# Since grad = Q (2 * 1), we have 1^T grad = 2 * 1^T Q 1 = 4 * cost(1).
np.testing.assert_allclose(cost, 0.25 * np.dot(test_vector_1, grad), rtol=1e-5)
# A sample lives in the same space as the mean and must be finite.
assert np.shape(sample) == mean_vector.shape, "sample shape differs from mean"
assert np.all(np.isfinite(sample)), "sample contains non-finite values"

0.4999999999999977
[0.25       0.08333333 0.25       0.5        0.16666667 0.16666667
 0.25       0.25       0.08333333]
[0.375 0.125 0.375 0.75  0.25  0.25  0.375 0.375 0.125]
[-0.3997324  -0.46210551 -0.58363611 -0.55901791 -0.52177595 -0.70408941
 -0.48298646 -0.5273036  -0.45923362]
