In [None]:
import dolfinx as dlx
import numpy as np
from mpi4py import MPI
from petsc4py import PETSc

from ls_prior import components, fem, prior

In [None]:
# Parameters forwarded to FEMHandler.generate_forms. Presumably kappa controls
# the correlation length and tau the overall scaling of the SPDE operator
# (TODO confirm in ls_prior.fem). nx, ny are the mesh subdivisions along x and y.
kappa, tau = 1.0, 1.0
nx, ny = 2, 2
mpi_communicator = MPI.COMM_WORLD

# Mass-matrix solver: conjugate gradients with a Jacobi (diagonal) preconditioner.
cg_solver_settings = components.InverseMatrixSolverSettings(
    solver_type=PETSc.KSP.Type.CG,
    preconditioner_type=PETSc.PC.Type.JACOBI,
    relative_tolerance=1e-6,
    absolute_tolerance=1e-8,
    max_num_iterations=10000,
)

# SPDE-matrix solver: conjugate gradients preconditioned by algebraic multigrid.
# A KSP of type PREONLY would apply the GAMG preconditioner exactly once and
# silently ignore the tolerances and iteration cap below, yielding only an
# approximate inverse. Using CG makes those settings effective.
# NOTE(review): CG assumes the SPDE matrix is symmetric positive definite —
# confirm against FEMHandler.generate_forms.
amg_solver_settings = components.InverseMatrixSolverSettings(
    solver_type=PETSc.KSP.Type.CG,
    preconditioner_type=PETSc.PC.Type.GAMG,
    relative_tolerance=1e-6,
    absolute_tolerance=1e-8,
    max_num_iterations=10000,
)

In [None]:
# Unit-square domain [0, 1] x [0, 1] with nx x ny subdivisions and triangular cells,
# distributed over the chosen MPI communicator.
lower_left_corner = np.array([0, 0])
upper_right_corner = np.array([1, 1])
mesh = dlx.mesh.create_rectangle(
    mpi_communicator,
    [lower_left_corner, upper_right_corner],
    [nx, ny],
    dlx.mesh.CellType.triangle,
)

In [None]:
# Finite element handler on continuous piecewise-linear ("CG", 1) elements.
fem_handler = fem.FEMHandler(mesh, ("CG", 1))

# Variational forms for the mass matrix and the SPDE operator, assembled into
# PETSc matrices in that order (mass first, then SPDE).
mass_matrix_form, spde_matrix_form = fem_handler.generate_forms(kappa, tau)
mass_matrix, spde_matrix = (
    fem_handler.assemble_matrix(matrix_form)
    for matrix_form in (mass_matrix_form, spde_matrix_form)
)

# Block factorization of the mass matrix. Its assembled factor feeds the
# sampling operator (BilaplacianCovarianceFactor).
mass_matrix_factorization = fem.FEMMatrixBlockFactorization(
    mesh,
    fem_handler.function_space,
    mass_matrix_form,
)
mass_matrix_factor = mass_matrix_factorization.assemble()

In [None]:
# Linear solvers acting as matrix inverses:
#   - mass matrix  -> Jacobi-preconditioned CG settings
#   - SPDE matrix  -> AMG-preconditioned settings
mass_matrix_inverse, spde_matrix_inverse = (
    components.InverseMatrixSolver(cg_solver_settings, mass_matrix, mpi_communicator),
    components.InverseMatrixSolver(amg_solver_settings, spde_matrix, mpi_communicator),
)

In [None]:
# Operators defining the Bilaplacian prior, built from the SPDE matrix (A), the
# mass matrix (M), its factor, and the solvers above.
# NOTE(review): judging only by the arguments, these presumably represent
# precision ~ A M^{-1} A, covariance ~ A^{-1} M A^{-1}, and a sampling factor
# combining the mass-matrix factor with A^{-1} — confirm in ls_prior.components.
precision_operator = components.BilaplacianPrecision(
    spde_matrix,
    mass_matrix_inverse,
)
covariance_operator = components.BilaplacianCovariance(
    mass_matrix,
    spde_matrix_inverse,
)
sampling_factor = components.BilaplacianCovarianceFactor(
    mass_matrix_factor,
    spde_matrix_inverse,
)

In [None]:
# Number of degrees of freedom owned by this process. The index map counts
# blocks (nodes), so multiply by the dofmap block size to get actual DOFs. For
# the scalar ("CG", 1) space used here the block size is 1, but this keeps the
# vector length correct if the space is ever made vector-valued.
# NOTE(review): with more than one MPI rank this is the *local* owned count,
# not the global one — confirm that prior.Prior expects rank-local vectors.
dofmap = fem_handler.function_space.dofmap
num_dofs = dofmap.index_map.size_local * dofmap.index_map_bs

# Zero prior mean.
mean_vector = np.zeros(num_dofs, dtype=np.float64)
bilaplace_prior = prior.Prior(
    mean_vector,
    precision_operator,
    covariance_operator,
    sampling_factor,
    dimension=num_dofs,
    seed=0,
)

In [None]:
# Smoke test of the prior's evaluation methods on a random input vector.
# Use a seeded generator instead of the unseeded global np.random state, so the
# test vector (and therefore cost/grad/hvp) is reproducible on Restart & Run All.
# The prior itself is constructed with seed=0, presumably for generate_sample.
test_rng = np.random.default_rng(seed=0)
test_vector = test_rng.random(bilaplace_prior.dimension)
cost = bilaplace_prior.evaluate_cost(test_vector)
grad = bilaplace_prior.evaluate_gradient(test_vector)
hvp = bilaplace_prior.evaluate_hessian_vector_product(test_vector)
sample = bilaplace_prior.generate_sample()