In [1]:
import jax
import jax.numpy as jnp
from icosphere import icosphere
from discrete_exterior_calculus import DEC

# Enable 64-bit floats globally; done before the mesh arrays below are created.
jax.config.update("jax_enable_x64", True)

# Icosphere resolution parameter `nu` and the resulting vertex counts:
#   nu       |  1   2   3    4    5    6    7    8    9    10
#   vertices | 12  42  92  162  252  362  492  642  812  1002
nu = 1
vertices, faces = icosphere(nu=nu)
n = len(vertices)
mesh = DEC.Mesh(vertices, faces)

# Index of the extreme vertex along each coordinate axis (column 0 = x, 1 = y, 2 = z).
zmost_point, zleast_point = jnp.argmax(vertices[:, 2]), jnp.argmin(vertices[:, 2])
xmost_point, xleast_point = jnp.argmax(vertices[:, 0]), jnp.argmin(vertices[:, 0])
ymost_point, yleast_point = jnp.argmax(vertices[:, 1]), jnp.argmin(vertices[:, 1])

In [2]:
from persistent_storage import (
    get_value,
    set_value,
    remove_value,
    wipe_db,
    experiment_setup,
    build_experiment_name,
    choice,
)

import itertools
import probabilistic_solve_icosphere

problem_name = choice

# Configuration of the selected experiment (see persistent_storage.experiment_setup).
data = experiment_setup[problem_name]
priors, prior_scales, derivatives, timesteps, problem_title, vf, order = (
    data["priors"],
    data["prior_scale"],
    data["derivatives"],
    data["timesteps"],
    data["problem_title"],
    data["vector_field"],
    data["order"],
)

# Every tested number of derivatives q must be at least the order of the ODE.
# Plain Python is enough here; building a JAX array for this check is unnecessary,
# and an empty list would otherwise fail with an unrelated reduction error.
if len(derivatives) == 0:
    raise ValueError(f"No derivatives configured for problem {problem_name!r}")
invalid_derivatives = [q for q in derivatives if q < order]
if invalid_derivatives:
    raise ValueError(
        f"Derivatives {derivatives} must be at least {order} for order {order} method "
        f"(offending values: {invalid_derivatives})"
    )

# NOTE(review): presumably clears every stored result for `problem_name` on each run,
# forcing all experiments below to be recomputed — confirm this is intended.
wipe_db(problem_name)


ValueError: Derivatives [1, 2, 3] must be at least 2 for order 2 method

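#### Reference `diffrax` solution

High-accuracy solution (Kvaerno5, `rtol = atol = 1e-8`) used as the ground truth for the RMSE of the probabilistic solver below.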

In [None]:
# High-accuracy reference solution computed with diffrax and cached in the persistent
# store. Set RECOMPUTE_DIFFRAX = True to recompute even when a cached copy exists.
RECOMPUTE_DIFFRAX = False
REF_T0, REF_T1 = 0, 10  # integration interval
REF_NUM_SAVES = 1500  # number of equally spaced saved time points (endpoints included)

diffrax_sol = get_value("diffrax_sol", None, filename=problem_name)
if diffrax_sol is None or RECOMPUTE_DIFFRAX:
    print("Calculating diffrax solution")

    from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Kvaerno5

    # Initial displacement: +1 at the vertex with the largest y coordinate,
    # -1 at the vertex with the smallest y coordinate, 0 elsewhere.
    u0 = jnp.zeros(n)
    u0 = u0.at[ymost_point].set(1.0)
    u0 = u0.at[yleast_point].set(-1.0)

    if order == 1:
        # First-order problem: du/dt = vf(u, L), with L = mesh.laplace_matrix.
        def vector_field(t, y, args):
            return vf(y, mesh.laplace_matrix)

        y0 = u0
    elif order == 2:
        # Second-order problem rewritten as a first-order system on y = [u, v],
        # v = du/dt:  d/dt [u, v] = [v, vf(u, v, L)].
        def vector_field(t, y, args):
            return jnp.concatenate(
                (y[n : 2 * n], vf(y[:n], y[n : 2 * n], mesh.laplace_matrix))
            )

        y0 = jnp.concatenate([u0, jnp.zeros(n)])  # zero initial velocity
    else:
        # Without this, `vector_field` / `y0` would be undefined (or stale from a
        # previous cell run) and the failure would surface far from its cause.
        raise ValueError(f"Unsupported ODE order {order}; expected 1 or 2")

    sol = diffeqsolve(
        ODETerm(vector_field),
        Kvaerno5(),
        t0=REF_T0,
        t1=REF_T1,
        dt0=0.01,
        y0=y0,
        saveat=SaveAt(ts=jnp.linspace(REF_T0, REF_T1, REF_NUM_SAVES, endpoint=True)),
        stepsize_controller=PIDController(rtol=1e-8, atol=1e-8),
        max_steps=50000,
    )

    # zleast_point < n, so for order 2 this column is the displacement u (not v) at the
    # vertex with the smallest z coordinate, over all saved time points.
    diffrax_sol = sol.ys[:, zleast_point]
    set_value("diffrax_sol", diffrax_sol, filename=problem_name)
    set_value("diffrax_sol_steps", sol.stats["num_steps"].item(), filename=problem_name)
    print("saved")


Calculating diffrax solution
saved


In [None]:
import time
import numpy as np
import tqdm

# Run every (prior, number of derivatives q, number of solution points) combination
# that is not already stored. For each, record the mean/std at the vertex with the
# smallest z coordinate, the fastest wall-clock runtime and the RMSE of the mean
# against the diffrax reference.
N_TIMING_RUNS = 1  # repetitions per experiment; the fastest one is recorded

product = list(itertools.product(priors, derivatives, timesteps))
progress = tqdm.tqdm(product)  # not named `iter`: that would shadow the builtin

for prior, q, timestep in progress:
    experiment_name = build_experiment_name(prior, q, timestep)
    progress.set_description(f"Running experiment: {experiment_name}")

    means, stds, runtime, rmse = get_value(
        experiment_name, [None] * 4, filename=problem_name
    )
    if means is not None:
        continue  # already computed and stored

    fastest_time = float("inf")
    failed = False
    for _ in range(N_TIMING_RUNS):
        start_time = time.perf_counter()
        means, stds = probabilistic_solve_icosphere.solve(
            isosphere_nu=nu,
            n_solution_points=timestep,
            derivatives=q,
            prior_type=prior,
            prior_scale=prior_scales[0],  # TODO
            vector_field=vf,
            order=order,
        )
        try:
            means = means[:, zleast_point]
            stds = stds[:, zleast_point]
            # JAX dispatches work asynchronously; wait for the results so the
            # measured time actually covers the computation.
            jax.block_until_ready((means, stds))
        except Exception as err:
            print(f"Experiment {experiment_name} failed because of {err}")
            failed = True
            break
        fastest_time = min(fastest_time, time.perf_counter() - start_time)

    if failed:
        # Do not compute an RMSE from / store a broken result; the experiment will be
        # retried on the next run.
        continue

    diff = means - diffrax_sol
    # NOTE(review): the last time point is excluded from the RMSE — confirm intended.
    rmse = jnp.sqrt(jnp.mean(diff[:-1] ** 2))
    set_value(
        experiment_name,
        (means.astype(np.float32), stds.astype(np.float32), fastest_time, rmse),
        filename=problem_name,
    )

Running experiment: wave_2_10:   0%|          | 0/60 [00:00<?, ?it/s]

laplace


Running experiment: wave_2_18:   2%|▏         | 1/60 [00:02<02:03,  2.09s/it]

laplace


Running experiment: wave_2_35:   3%|▎         | 2/60 [00:03<01:21,  1.40s/it]

laplace


Running experiment: wave_2_68:   5%|▌         | 3/60 [00:04<01:12,  1.27s/it]

laplace


Running experiment: wave_2_129:   7%|▋         | 4/60 [00:05<01:08,  1.22s/it]

laplace


Running experiment: wave_2_244:   8%|▊         | 5/60 [00:06<01:05,  1.19s/it]

laplace


Running experiment: wave_2_464:  10%|█         | 6/60 [00:07<01:10,  1.31s/it]

laplace


Running experiment: wave_2_879:  12%|█▏        | 7/60 [00:09<01:05,  1.24s/it]

laplace


Running experiment: wave_2_1668:  13%|█▎        | 8/60 [00:10<01:10,  1.36s/it]

laplace


Running experiment: wave_2_3162:  15%|█▌        | 9/60 [00:13<01:34,  1.86s/it]

laplace


Running experiment: wave_3_10:  17%|█▋        | 10/60 [00:19<02:31,  3.03s/it]  

laplace


Running experiment: wave_3_18:  18%|█▊        | 11/60 [00:21<02:11,  2.68s/it]

laplace


Running experiment: wave_3_35:  20%|██        | 12/60 [00:22<01:45,  2.20s/it]

laplace


Running experiment: wave_3_68:  22%|██▏       | 13/60 [00:23<01:27,  1.87s/it]

laplace


Running experiment: wave_3_129:  23%|██▎       | 14/60 [00:24<01:16,  1.66s/it]

laplace


Running experiment: wave_3_244:  25%|██▌       | 15/60 [00:25<01:11,  1.58s/it]

laplace


Running experiment: wave_3_464:  27%|██▋       | 16/60 [00:27<01:13,  1.67s/it]

laplace


Running experiment: wave_3_879:  28%|██▊       | 17/60 [00:30<01:22,  1.91s/it]

laplace


Running experiment: wave_3_1668:  30%|███       | 18/60 [00:35<01:58,  2.82s/it]

laplace


Running experiment: wave_3_3162:  32%|███▏      | 19/60 [00:46<03:45,  5.50s/it]

laplace


Running experiment: wave_4_10:  33%|███▎      | 20/60 [01:05<06:20,  9.51s/it]  

laplace


Running experiment: wave_4_18:  35%|███▌      | 21/60 [01:07<04:43,  7.27s/it]

laplace


Running experiment: wave_4_35:  37%|███▋      | 22/60 [01:09<03:27,  5.45s/it]

laplace


Running experiment: wave_4_68:  38%|███▊      | 23/60 [01:10<02:34,  4.18s/it]

laplace


Running experiment: wave_4_129:  40%|████      | 24/60 [01:12<02:03,  3.44s/it]

laplace


Running experiment: wave_4_244:  42%|████▏     | 25/60 [01:14<01:47,  3.08s/it]

laplace


Running experiment: wave_4_464:  43%|████▎     | 26/60 [01:17<01:45,  3.09s/it]

laplace


Running experiment: wave_4_879:  45%|████▌     | 27/60 [01:24<02:22,  4.30s/it]

laplace


Running experiment: wave_4_1668:  47%|████▋     | 28/60 [01:36<03:32,  6.63s/it]

laplace


Running experiment: wave_4_3162:  48%|████▊     | 29/60 [01:57<05:40, 10.97s/it]

laplace


Running experiment: iwp_2_10:  50%|█████     | 30/60 [02:38<09:54, 19.81s/it]   

laplace


Running experiment: iwp_2_18:  52%|█████▏    | 31/60 [02:38<06:46, 14.00s/it]

laplace


Running experiment: iwp_2_35:  53%|█████▎    | 32/60 [02:38<04:37,  9.92s/it]

laplace


Running experiment: iwp_2_68:  55%|█████▌    | 33/60 [02:39<03:11,  7.08s/it]

laplace


Running experiment: iwp_2_129:  57%|█████▋    | 34/60 [02:39<02:13,  5.14s/it]

laplace


Running experiment: iwp_2_244:  58%|█████▊    | 35/60 [02:40<01:33,  3.75s/it]

laplace


Running experiment: iwp_2_464:  60%|██████    | 36/60 [02:41<01:07,  2.81s/it]

laplace


Running experiment: iwp_2_879:  62%|██████▏   | 37/60 [02:42<00:51,  2.23s/it]

laplace


Running experiment: iwp_2_1668:  63%|██████▎   | 38/60 [02:43<00:44,  2.01s/it]

laplace


Running experiment: iwp_2_3162:  65%|██████▌   | 39/60 [02:46<00:45,  2.17s/it]

laplace


Running experiment: iwp_3_10:  67%|██████▋   | 40/60 [02:50<00:59,  2.99s/it]  

laplace


Running experiment: iwp_3_18:  68%|██████▊   | 41/60 [02:51<00:43,  2.27s/it]

laplace


Running experiment: iwp_3_35:  70%|███████   | 42/60 [02:52<00:31,  1.77s/it]

laplace


Running experiment: iwp_3_68:  72%|███████▏  | 43/60 [02:52<00:24,  1.41s/it]

laplace


Running experiment: iwp_3_129:  73%|███████▎  | 44/60 [02:53<00:18,  1.15s/it]

laplace


Running experiment: iwp_3_244:  75%|███████▌  | 45/60 [02:54<00:16,  1.09s/it]

laplace


Running experiment: iwp_3_464:  77%|███████▋  | 46/60 [02:55<00:17,  1.26s/it]

laplace


Running experiment: iwp_3_879:  78%|███████▊  | 47/60 [02:58<00:22,  1.70s/it]

laplace


Running experiment: iwp_3_1668:  80%|████████  | 48/60 [03:02<00:28,  2.39s/it]

laplace


Running experiment: iwp_3_3162:  82%|████████▏ | 49/60 [03:10<00:43,  4.00s/it]

laplace


Running experiment: iwp_4_10:  83%|████████▎ | 50/60 [03:24<01:09,  6.97s/it]  

laplace


Running experiment: iwp_4_18:  85%|████████▌ | 51/60 [03:25<00:46,  5.16s/it]

laplace


Running experiment: iwp_4_35:  87%|████████▋ | 52/60 [03:25<00:30,  3.80s/it]

laplace


Running experiment: iwp_4_68:  88%|████████▊ | 53/60 [03:26<00:20,  2.86s/it]

laplace


Running experiment: iwp_4_129:  90%|█████████ | 54/60 [03:28<00:14,  2.49s/it]

laplace


Running experiment: iwp_4_244:  92%|█████████▏| 55/60 [03:29<00:10,  2.07s/it]

laplace


Running experiment: iwp_4_464:  93%|█████████▎| 56/60 [03:32<00:10,  2.55s/it]

laplace


Running experiment: iwp_4_879:  95%|█████████▌| 57/60 [03:39<00:11,  3.75s/it]

laplace


Running experiment: iwp_4_1668:  97%|█████████▋| 58/60 [03:50<00:11,  5.95s/it]

laplace


Running experiment: iwp_4_3162:  98%|█████████▊| 59/60 [04:11<00:10, 10.54s/it]

laplace


Running experiment: iwp_4_3162: 100%|██████████| 60/60 [04:51<00:00,  4.86s/it]
