In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns

from eikonax import logging, solver, tensorfield, utilities

# Apply seaborn's "ticks" theme globally, so every matplotlib figure below uses it.
sns.set_theme(style="ticks")

In [2]:
# --- JAX debugging switches -------------------------------------------------
# Passed to `jax.disable_jit` / `jax.log_compiles` around every JAX call below.
disable_jit = True
log_compiles = False

# --- Test mesh ----------------------------------------------------------------
# Arguments for `utilities.create_test_mesh`; num_points is presumably the
# number of points per axis (the recorded output shows vertices 0..8 for 3).
mesh_bounds = (-1, 1)
num_points = 3
dimension = 2

# A single initial site: vertex 0 is seeded with value 0.
initial_sites = solver.InitialSites(inds=jnp.array([0]), values=jnp.array([0]))

# --- Forward solver -------------------------------------------------------
solver_data = solver.SolverData(
    loop_type="nonjitted_while",
    max_num_iterations=200,
    tolerance=1e-3,
    max_value=1000,
    softmin_order=5,
    drelu_order=10,
    drelu_cutoff=10,
    log_interval=1,
)

# --- Derivative computation -------------------------------------------------
# NOTE(review): drelu_cutoff is 1 here but 10 for the solver above; confirm the
# difference is intentional.
derivator_data = solver.DerivatorData(
    softmin_order=5,
    drelu_order=10,
    drelu_cutoff=1,
)

# --- Logging ------------------------------------------------------------------
logger_data = logging.LoggerSettings(
    log_to_console=True,
    logfile_path=None,
)

In [None]:
# Build the structured test mesh and the per-vertex adjacency data the solver
# consumes.
vertices, simplices = utilities.create_test_mesh(mesh_bounds, num_points)
num_vertices = vertices.shape[0]
num_simplices = simplices.shape[0]
adjacent_vertex_inds = utilities.get_adjacent_vertex_data(simplices, num_vertices)

# Constant float32 parameter vector with num_simplices * dimension entries
# (presumably `dimension` values per simplex -- confirm against tensorfield).
parameter_vector = jnp.ones(num_simplices * dimension, dtype=jnp.float32)
field_data = tensorfield.LinearTensorFieldData(
    dimension=dimension,
    num_simplices=num_simplices,
)
with jax.disable_jit(disable_jit), jax.log_compiles(log_compiles):
    tensor_field = tensorfield.LinearTensorField.assemble_field(
        parameter_vector, field_data
    )

# Run the solver from the configured initial sites.
mesh_data = solver.MeshData(
    vertices=vertices,
    adjacent_vertex_inds=adjacent_vertex_inds,
)
logger = logging.Logger(logger_data)
with jax.disable_jit(disable_jit), jax.log_compiles(log_compiles):
    solution = solver.Solver.run(
        tensor_field, initial_sites, mesh_data, solver_data, logger
    )

In [None]:
# Draw the test mesh, colour each vertex by its computed solution value and
# label it with its vertex index, so the indices printed further down can be
# located on the mesh.
fig, ax = plt.subplots(figsize=(5, 4), layout="constrained")
ax.triplot(vertices[:, 0], vertices[:, 1], simplices)
scatter_plot = ax.scatter(vertices[:, 0], vertices[:, 1], c=solution.values)
# Attach the colorbar to this figure/axes explicitly rather than relying on
# pyplot's implicit "current axes" state.
fig.colorbar(scatter_plot, ax=ax, label="solution value")
for i, vertex in enumerate(vertices):
    ax.annotate(str(i), (vertex[0], vertex[1]))
ax.set(title="Solver solution on test mesh", xlabel="$x_1$", ylabel="$x_2$")

In [5]:
# Partial derivatives at the converged solution. The first element is indexed
# in the next cell with the same layout as adjacent_vertex_inds[:, :, 1:3]
# (presumably d(vertex value)/d(neighbour value)); the second is presumably the
# derivative with respect to the tensor-field parameters -- confirm.
with jax.disable_jit(disable_jit), jax.log_compiles(log_compiles):
    derivative_pair = solver.Derivator.compute_partial_derivatives(
        solution.values,
        tensor_field,
        mesh_data,
        derivator_data,
    )
partial_solution, partial_parameter = derivative_pair

In [13]:
# Slot 0, column 0 of the adjacency data is taken as the vertex each row
# belongs to; columns 1:3 presumably hold the two other vertices of each
# adjacent simplex (consistent with the recorded output) -- confirm.
current_inds = mesh_data.adjacent_vertex_inds[:, 0, 0]
adjacent_inds = mesh_data.adjacent_vertex_inds[:, :, 1:3]

# Despite its name, this is the tuple of per-axis index arrays returned by
# jnp.nonzero, not a boolean mask; it is used for fancy indexing below.
nonzero_mask = jnp.nonzero(partial_solution)
row_inds = nonzero_mask[0]

# For every nonzero partial derivative, print: the vertex it belongs to, the
# neighbouring vertex it refers to, and its value.
# NOTE(review): vertex 0 is the only initial site (`initial_sites`), yet the
# recorded output lists nonzero entries for it; confirm whether the derivator
# is expected to zero the rows of fixed initial sites.
for printed_array in (
    current_inds[row_inds],
    adjacent_inds[nonzero_mask],
    partial_solution[nonzero_mask],
):
    print(printed_array)

[0 0 1 2 3 4 4 5 5 6 7 7 8 8]
[3 1 0 1 0 1 3 1 1 3 3 3 5 7]
[0.5000046 0.4999954 1.        1.        1.        0.5000046 0.4999954
 0.5       0.5       1.        0.5       0.5       0.5000046 0.4999954]
