In [1]:
import jax
import jax.numpy as jnp
from icosphere import icosphere
from discrete_exterior_calculus import DEC

# Switch JAX to double precision; this should happen before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Icosphere parameter `nu` and the resulting number of vertices:
#   nu:       1   2   3   4    5    6    7    8    9    10
#   vertices: 12, 42, 92, 162, 252, 362, 492, 642, 812, 1002
nu = 1
vertices, faces = icosphere(nu=nu)
n = len(vertices)  # number of mesh vertices
mesh = DEC.Mesh(vertices, faces)

# Index of the vertex with the largest ("most") / smallest ("least") coordinate,
# taken along the x, y and z axes (columns 0, 1, 2 of `vertices`).
xmost_point, ymost_point, zmost_point = (
    jnp.argmax(vertices[:, axis]) for axis in range(3)
)
xleast_point, yleast_point, zleast_point = (
    jnp.argmin(vertices[:, axis]) for axis in range(3)
)

In [2]:
from persistent_storage import (
    get_value,
    set_value,
    remove_value,
    wipe_db,
    experiment_setup,
    build_experiment_name,
)

import itertools
import probabilistic_solve_icosphere

# The PDE to benchmark; its settings are looked up in `experiment_setup`.
problem = "wave and tanh"

# Unpack the per-problem configuration: the database name used for caching,
# the prior / derivative / timestep values to sweep, and the display title.
data = experiment_setup[problem]
dbname, priors, derivatives, timesteps, problem_title = (
    data[key]
    for key in ("dbname", "priors", "derivatives", "timesteps", "problem_title")
)

# Set to False to reuse cached results from earlier runs instead of recomputing.
# True wipes the database and always re-solves the reference problem.
FORCE_RECOMPUTE = True

if FORCE_RECOMPUTE:
    # NOTE(review): presumably clears every cached entry in `dbname`, including
    # per-experiment results looked up later — confirm against persistent_storage.
    wipe_db(dbname)

diffrax_sol = get_value("diffrax_sol", None, filename=dbname)
if diffrax_sol is None or FORCE_RECOMPUTE:
    print("Calculating diffrax solution")

    from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Kvaerno5

    # Integration window and number of equally spaced output times.
    T0, T1, N_SAVE = 0, 10, 100

    # problem -> (second order in time?, weight of the tanh(L u) nonlinearity)
    PROBLEMS = {
        "heat": (False, 0.0),
        "heat and tanh": (False, 1.0),
        "heat small tanh": (False, 0.1),
        "wave": (True, 0.0),
        "wave and tanh": (True, 1.0),
    }
    # Fail loudly instead of hitting a NameError (or silently reusing a stale
    # `vector_field` left in the kernel) when the problem name is not handled.
    if problem not in PROBLEMS:
        raise ValueError(
            f"Unknown problem {problem!r}; expected one of {sorted(PROBLEMS)}"
        )
    is_wave, tanh_weight = PROBLEMS[problem]
    laplace = mesh.laplace_matrix

    def spatial_rhs(u):
        """Return L u + tanh_weight * tanh(L u), evaluating L u only once."""
        lap_u = laplace @ u
        if tanh_weight == 0.0:
            return lap_u
        return lap_u + tanh_weight * jnp.tanh(lap_u)

    if is_wave:

        def vector_field(t, y, args):
            # First-order form of u_tt = spatial_rhs(u): state y = [u, v], v = u_t.
            u, v = y[:n], y[n:]
            return jnp.concatenate([v, spatial_rhs(u)])

    else:

        def vector_field(t, y, args):
            # First-order in time: u_t = spatial_rhs(u).
            return spatial_rhs(y[:n])

    term = ODETerm(vector_field)
    solver = Kvaerno5()
    saveat = SaveAt(ts=jnp.linspace(T0, T1, N_SAVE))
    stepsize_controller = PIDController(rtol=1e-8, atol=1e-8)

    # Initial displacement: +2 at the vertex with the largest y coordinate,
    # -2 at the vertex with the smallest, zero elsewhere; zero initial velocity.
    u0 = jnp.zeros(n).at[ymost_point].set(2.0).at[yleast_point].set(-2.0)
    v0 = jnp.zeros(n)
    y0 = jnp.concatenate([u0, v0]) if is_wave else u0

    sol = diffeqsolve(
        term,
        solver,
        t0=T0,
        t1=T1,
        dt0=0.01,
        y0=y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        max_steps=50000,
    )

    # Reference trajectory of u at the vertex with the smallest z coordinate
    # (zleast_point < n, so for wave problems it indexes the u half of the state).
    diffrax_sol = sol.ys[:, zleast_point]
    set_value("diffrax_sol", diffrax_sol, filename=dbname)
    print("saved")
    set_value("diffrax_sol_steps", sol.stats["num_steps"].item(), filename=dbname)

# Every (prior, derivatives, timesteps) combination to benchmark.
product = list(itertools.product(priors, derivatives, timesteps))

Calculating diffrax solution
saved


In [3]:
import time
import numpy as np

import rich.progress
import rich


# Number of timed runs per experiment; the fastest wall-clock time is stored.
N_TIMING_REPEATS = 1

progress_bar = rich.progress.track(total=len(product), sequence=product)
for prior, q, timestep in progress_bar:
    experiment_name = build_experiment_name(prior, q, timestep)
    print(f"Running experiment: {experiment_name}")

    # Entries are stored as (means, stds, runtime, rmse); skip ones already cached.
    means, stds, runtime, rmse = get_value(
        experiment_name, [None] * 4, filename=dbname
    )
    if means is not None:
        continue

    fastest_time = float("inf")
    failed = False
    for _ in range(N_TIMING_REPEATS):
        start_time = time.perf_counter()
        try:
            means, stds = probabilistic_solve_icosphere.solve(
                isosphere_nu=nu,
                timesteps=timestep,
                derivatives=q,
                prior=prior,
                problem=problem,
            )
            means = means[:, zleast_point]
            stds = stds[:, zleast_point]
            # JAX dispatches work asynchronously; wait for the results so the
            # measured time covers the actual computation.
            jax.block_until_ready((means, stds))
        except Exception as err:
            print(f"Experiment {experiment_name} failed because of {err}")
            failed = True
            break
        fastest_time = min(fastest_time, time.perf_counter() - start_time)

    if failed:
        # Without a solution there is nothing to score or cache (previously this
        # fell through and crashed on `None - diffrax_sol`).
        continue

    rmse = jnp.sqrt(jnp.mean((means - diffrax_sol) ** 2))
    set_value(
        experiment_name,
        (
            means.astype(np.float32),
            stds.astype(np.float32),
            fastest_time,
            rmse,
        ),
        filename=dbname,
    )

Output()