In [None]:
import jax.numpy as jnp

import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression

In [None]:
def forward(x):
    """Signed cube root of ``x`` (element-wise), defined for negative inputs.

    ``jnp.cbrt`` keeps the sign of negative values itself. It is also exact on
    perfect cubes, unlike ``sign(x) * abs(x) ** (1/3)``.
    """
    return jnp.cbrt(x)


def inverse(x):
    """Inverse of :func:`forward`: the element-wise cube ``x**3``."""
    # NOTE(review): this forward/inverse pair looks intended for matplotlib's
    # FuncScale (imported below), but neither is used in this notebook yet.
    return x**3

from matplotlib.scale import FuncScale

## Compare the full model with different levels of sparsification

In [None]:
# Fit a straight line in log-log space to the tail of the full-model timings
# (points from index `start_ind` onward). The slope is the empirical scaling
# exponent k in t ~ C * N**k.
start_ind = 5

full = jnp.load("./data/full_time_[128, 256, 512, 1024, 2048, 4096, 8192].npy")

# Rows as used by the plotting cells: sizes, mean runtime, runtime spread.
# NOTE(review): row layout inferred from variable names; confirm against the
# script that wrote the .npy file.
dp = full[0]
t_mean = full[1]
t_std = full[2]

lm = LinearRegression()
log_dp = jnp.log(dp[start_ind:])
log_t_full = jnp.log(t_mean[start_ind:])

lm.fit(log_dp.reshape(-1, 1), log_t_full.reshape(-1, 1))
# coef_ is the fitted exponent k; intercept_ is log(C) (natural log).
print(lm.coef_, lm.intercept_)


def line(x, model=lm, offset=0.3):
    """Evaluate the log-log fit at sizes ``x`` as a reference trend line.

    Parameters
    ----------
    x : array-like
        Sizes at which to evaluate the fit (positive values).
    model : fitted LinearRegression, optional
        Bound at definition time, so later cells that rebind the global ``lm``
        do not change what this function evaluates.
    offset : float, optional
        Amount subtracted in natural-log space before exponentiating. The
        default 0.3 scales the line by exp(-0.3) ~= 0.74, presumably to draw
        it just below the data.

    Returns
    -------
    Array of shape (len(x), 1) with the predicted runtimes.
    """
    log_x = jnp.log(x).reshape(-1, 1)
    return jnp.exp(model.predict(log_x) - offset)


line_full = line(dp[2:]).reshape(-1)

In [None]:
# Full-model runtime vs. N on log-log axes. The shaded band spans
# t_mean - t_std .. t_mean + t_std. The dashed line is the trend from the fit
# cell. The label states the expected complexity; the printed slope is the
# measured exponent.
fig, ax = plt.subplots()
ax.plot(dp, t_mean, "b", lw=0.5, marker="x", label="full model")
ax.fill_between(dp, t_mean - t_std, t_mean + t_std, color="b", alpha=0.2)
# Raw string: "\m" is an invalid escape sequence in a normal string literal.
ax.plot(dp[2:], line_full, "k--", label=r"$\mathcal{O}(N^3)$")

ax.set_xlabel("N number of training points")
ax.set_ylabel("runtime prior creation and inversion")
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid()
ax.legend()
plt.show()

In [None]:
# Same log-log fit for the sparse model at a fixed number of reference points
# (M = 512 according to the file name). The first two sizes are dropped before
# both fitting and plotting.
start_ind = 5

sparse = jnp.load("./data/sparse_time_[128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]_ref512.npy")

dp = sparse[0, 2:]
t_mean = sparse[1, 2:]
t_std = sparse[2, 2:]

lm = LinearRegression()
log_dp = jnp.log(dp[start_ind:])
log_t_sparse = jnp.log(t_mean[start_ind:])

lm.fit(log_dp.reshape(-1, 1), log_t_sparse.reshape(-1, 1))
# coef_ is the fitted exponent k in t ~ C * N**k; intercept_ is log(C).
print(lm.coef_, lm.intercept_)


def line(x, model=lm, offset=0.3):
    """Evaluate the log-log fit at sizes ``x`` as a reference trend line.

    ``model`` is bound at definition time, so rebinding the global ``lm`` in a
    later cell does not change this function. ``offset`` is subtracted in
    natural-log space (the default 0.3 scales the line by exp(-0.3) ~= 0.74).
    Returns an array of shape (len(x), 1).
    """
    log_x = jnp.log(x).reshape(-1, 1)
    return jnp.exp(model.predict(log_x) - offset)


line_sparse = line(dp[2:]).reshape(-1)

In [None]:
# Sparse-model runtime (fixed M = 512) vs. N on log-log axes, with a
# t_mean +/- t_std band and the fitted trend line from the fit cell.
fig, ax = plt.subplots()
ax.plot(dp, t_mean, "r--", marker="x", label="M = 512")
ax.fill_between(dp, t_mean - t_std, t_mean + t_std, color="r", alpha=0.2)
# Raw string: "\m" is an invalid escape sequence in a normal string literal.
ax.plot(dp[2:], line_sparse, "k--", label=r"$\mathcal{O}(N)$")

ax.set_xlabel("N number of training points")
ax.set_ylabel("runtime prior creation and inversion")
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid()
ax.legend()
plt.show()

In [None]:
# Log-log fit for the sparse model with a fixed number of training points
# (N = 16384 according to the file name) while the number of reference
# points M varies.
start_ind = 3

sparse_fixed = jnp.load("./data/sparse_time_[512, 1024, 2048, 4096, 8192, 16384]_max16384.npy")

dp = sparse_fixed[0]
t_mean = sparse_fixed[1]
t_std = sparse_fixed[2]

lm = LinearRegression()
log_dp = jnp.log(dp[start_ind:])
log_t_sparse_fixed = jnp.log(t_mean[start_ind:])

lm.fit(log_dp.reshape(-1, 1), log_t_sparse_fixed.reshape(-1, 1))
# coef_ is the fitted exponent k in t ~ C * M**k; intercept_ is log(C).
print(lm.coef_, lm.intercept_)


def line(x, model=lm, offset=0.3):
    """Evaluate the log-log fit at sizes ``x`` as a reference trend line.

    ``model`` is bound at definition time, so rebinding the global ``lm`` in a
    later cell does not change this function. ``offset`` is subtracted in
    natural-log space (the default 0.3 scales the line by exp(-0.3) ~= 0.74).
    Returns an array of shape (len(x), 1).
    """
    log_x = jnp.log(x).reshape(-1, 1)
    return jnp.exp(model.predict(log_x) - offset)


# Evaluated over every M value (dp[0:] is just dp).
line_sparse_fixed = line(dp).reshape(-1)

In [None]:
# Sparse-model runtime at fixed N = 16384 vs. the number of reference points M,
# with a t_mean +/- t_std band and the fitted trend line from the fit cell.
fig, ax = plt.subplots()
ax.plot(dp, t_mean, "g--", marker="x", label="$N=16384$")
ax.fill_between(dp, t_mean - t_std, t_mean + t_std, color="g", alpha=0.2)
# N is fixed here and the x-axis is M. The complexity label is therefore in M,
# matching the same curve in the combined comparison plot. Raw string because
# "\m" is an invalid escape.
ax.plot(dp, line_sparse_fixed, "k--", label=r"$\mathcal{O}(M^2)$")

ax.set_xlabel("M number of reference points")
ax.set_ylabel("runtime prior creation and inversion")
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid()
ax.legend()
plt.show()

In [None]:
# Alternative normalisation (shift each curve so its minimum sits at 1e-4):
# full_m = full[1] - jnp.min(full[1]) + 1e-4
# sparse_m = sparse[1,2:] - jnp.min(sparse[1,2:]) + 1e-4
# sparse_fixed_m = sparse_fixed[1] - jnp.min(sparse_fixed[1]) + 1e-4

# Scale each mean-runtime curve to unit L2 norm, so curves of very different
# magnitude can be compared by shape. The sparse curve drops its first two
# sizes, matching the plots. These names are only referenced by commented-out
# code in this notebook.
full_m, sparse_m, sparse_fixed_m = [
    curve / jnp.linalg.norm(curve)
    for curve in (full[1], sparse[1, 2:], sparse_fixed[1])
]

In [None]:
# Overlay all three scaling experiments on one log-log plot. Each experiment
# is drawn as its mean runtime with a +/- std band. The thinner dashed lines
# are the trend lines computed in the fit cells.
fig, ax = plt.subplots()

ax.plot(full[0], full[1], "b--", marker="x", label="full model")
ax.fill_between(full[0], full[1] - full[2], full[1] + full[2], color="b", alpha=0.2)
# line_full was evaluated at full[0][2:].
ax.plot(full[0, 2:], line_full, "c--", lw=0.8, label=r"$\mathcal{O}(N^3)$")

# The first two sparse sizes are dropped, as in the sparse fit cell.
ax.plot(sparse[0, 2:], sparse[1, 2:], "r--", marker="x", label="sparse/fixed M")
ax.fill_between(sparse[0, 2:], sparse[1, 2:] - sparse[2, 2:], sparse[1, 2:] + sparse[2, 2:], color="r", alpha=0.2)
# line_sparse was evaluated at sparse[0, 2:][2:], i.e. sparse[0, 4:].
ax.plot(sparse[0, 4:], line_sparse, "m--", lw=0.8, label=r"$\mathcal{O}(N)$")

ax.plot(sparse_fixed[0], sparse_fixed[1], "g--", marker="x", label="sparse/fixed N")
ax.fill_between(sparse_fixed[0], sparse_fixed[1] - sparse_fixed[2], sparse_fixed[1] + sparse_fixed[2], color="g", alpha=0.2)
ax.plot(sparse_fixed[0], line_sparse_fixed, "y--", lw=0.8, label=r"$\mathcal{O}(M^2)$")

# Alternative view using the unit-norm curves (full_m, sparse_m, sparse_fixed_m).
# NOTE(review): the std bands below are NOT rescaled by the same norm as the
# means, so they would be on a different scale.
# ax.plot(full[0], full_m, "b--", marker="x")
# ax.fill_between(full[0], full_m-full[2], full_m+full[2], color="b", alpha=0.2)
#
# ax.plot(sparse[0,2:], sparse_m, "r--", marker="x")
# ax.fill_between(sparse[0,2:], sparse_m-sparse[2,2:], sparse_m+sparse[2,2:], color="r", alpha=0.2)
#
# ax.plot(sparse_fixed[0], sparse_fixed_m, "g--", marker="x")
# ax.fill_between(sparse_fixed[0], sparse_fixed_m-sparse_fixed[2], sparse_fixed_m+sparse_fixed[2], color="g", alpha=0.2)

ax.set_xlabel("number of points (N training points or M reference points)")
ax.set_ylabel("runtime prior creation and inversion")
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid()
ax.legend()
plt.show()