In [1]:
import jax
import jax.numpy as jnp
from icosphere import icosphere
from discrete_exterior_calculus import DEC

# Use 64-bit floats in JAX; this flag is meant to be set before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Icosphere refinement level nu and the resulting vertex count (10 * nu**2 + 2):
# nu:       1   2   3   4    5    6    7    8    9    10
# vertices: 12, 42, 92, 162, 252, 362, 492, 642, 812, 1002
nu = 1
vertices, faces = icosphere(nu=nu)
n = len(vertices)  # number of mesh vertices, i.e. the length of the state vector u
mesh = DEC.Mesh(vertices, faces)

# Indices of the vertices with the largest / smallest coordinate along each axis.
zmost_point, zleast_point = jnp.argmax(vertices[:, 2]), jnp.argmin(vertices[:, 2])
xmost_point, xleast_point = jnp.argmax(vertices[:, 0]), jnp.argmin(vertices[:, 0])
ymost_point, yleast_point = jnp.argmax(vertices[:, 1]), jnp.argmin(vertices[:, 1])

In [2]:
from persistent_storage import (
    get_value,
    set_value,
    remove_value,
    wipe_db,
    experiment_setup,
    build_experiment_name,
    choice,
)

import itertools
import probabilistic_solve_icosphere

problem_name = choice

# Pull the settings of the selected experiment out of its configuration dict.
data = experiment_setup[problem_name]
(
    priors,
    prior_scales,
    derivatives,
    timesteps,
    problem_title,
    vf,
    order,
) = (
    data[key]
    for key in (
        "priors",
        "prior_scale",
        "derivatives",
        "timesteps",
        "problem_title",
        "vector_field",
        "order",
    )
)

# Every derivative count q in the sweep must be at least the order of the ODE.
if jnp.asarray(derivatives).min() < order:
    raise ValueError(
        f"Derivatives {derivatives} must be at least {order} for order {order} method"
    )

# NOTE(review): runs unconditionally on every execution; if wipe_db clears all stored
# results for this problem, the cache lookups in the cells below will always miss — confirm.
wipe_db(problem_name)


#### Reference `diffrax` solution:

In [3]:
# Set to False to reuse a previously stored reference solution instead of re-solving.
RECOMPUTE_REFERENCE = True

# Time span and output grid of the reference solve.
REFERENCE_T0, REFERENCE_T1, REFERENCE_N_SAVE = 0, 10, 100

diffrax_sol = get_value("diffrax_sol", None, filename=problem_name)
if diffrax_sol is None or RECOMPUTE_REFERENCE:
    print("Calculating diffrax solution")

    from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Kvaerno5

    if order == 1:

        def vector_field(t, y, args):
            """First-order system: du/dt = vf(u, laplace_matrix)."""
            return vf(y, mesh.laplace_matrix)

    elif order == 2:

        def vector_field(t, y, args):
            """Second-order system written in first-order form on y = [u, du/dt]."""
            u, du = y[:n], y[n : 2 * n]
            return jnp.concatenate((du, vf(u, du, mesh.laplace_matrix)))

    else:
        # Without this, `vector_field` / `y0` would be unbound and fail later with NameError.
        raise ValueError(f"Unsupported order {order}; expected 1 or 2")

    # Initial displacement: +1 at the vertex with the largest y, -1 at the smallest y.
    u0 = jnp.zeros(n).at[ymost_point].set(1.0).at[yleast_point].set(-1.0)
    # Second-order problems additionally start with zero velocity.
    y0 = u0 if order == 1 else jnp.concatenate([u0, jnp.zeros(n)])

    # Solve the system
    sol = diffeqsolve(
        ODETerm(vector_field),
        Kvaerno5(),
        t0=REFERENCE_T0,
        t1=REFERENCE_T1,
        dt0=0.01,
        y0=y0,
        saveat=SaveAt(ts=jnp.linspace(REFERENCE_T0, REFERENCE_T1, REFERENCE_N_SAVE)),
        stepsize_controller=PIDController(rtol=1e-8, atol=1e-8),
        max_steps=50000,
    )

    # Displacement over time at the vertex with the smallest z-coordinate.
    diffrax_sol = sol.ys[:, zleast_point]
    set_value("diffrax_sol", diffrax_sol, filename=problem_name)
    print("saved")
    set_value("diffrax_sol_steps", sol.stats["num_steps"].item(), filename=problem_name)


Calculating diffrax solution
saved


In [4]:
import time
import numpy as np
import tqdm

# Number of repeated solves per configuration; only the fastest wall-clock time is kept.
N_TIMING_REPEATS = 1

product = list(itertools.product(priors, derivatives, timesteps))

# Named `progress` rather than `iter` so the builtin `iter` is not shadowed.
progress = tqdm.tqdm(product)

for prior, q, timestep in progress:
    experiment_name = build_experiment_name(prior, q, timestep)
    progress.set_description(f"Running experiment: {experiment_name}")

    means, stds, runtime, rmse = get_value(
        experiment_name, [None] * 4, filename=problem_name
    )
    if means is not None:
        # A stored result already exists for this configuration.
        continue

    fastest_time = float("inf")
    for _ in range(N_TIMING_REPEATS):
        start_time = time.perf_counter()
        means, stds = probabilistic_solve_icosphere.solve(
            isosphere_nu=nu,
            n_solution_points=timestep,
            derivatives=q,
            prior_type=prior,
            prior_scale=prior_scales[0],  # TODO: only the first prior scale is used
            vector_field=vf,
            order=order,
        )
        # JAX dispatches work asynchronously; wait for the results so the measured
        # time covers the actual computation rather than just the dispatch.
        means, stds = jax.block_until_ready((means, stds))
        fastest_time = min(fastest_time, time.perf_counter() - start_time)

    try:
        # Keep only the trajectory at the same vertex as the diffrax reference.
        means = means[:, zleast_point]
        stds = stds[:, zleast_point]
    except Exception as err:
        # Skip (and do not store) this configuration instead of computing an
        # RMSE from unsliced / invalid solver output.
        print(f"Experiment {experiment_name} failed because of {err}")
        continue

    diff = means - diffrax_sol
    # NOTE(review): the final time point is excluded from the RMSE — confirm this is intended.
    rmse = jnp.sqrt(jnp.mean(diff[:-1] ** 2))
    set_value(
        experiment_name,
        (
            means.astype(np.float32),
            stds.astype(np.float32),
            fastest_time,
            rmse,
        ),
        filename=problem_name,
    )

Running experiment: iwp_4_3162: 100%|██████████| 60/60 [01:10<00:00,  1.17s/it] 
