In [None]:
import os

import jax.numpy as jnp

import matplotlib.pyplot as plt

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

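## Kernel inversion timing: full vs. sparse

Average inversion times for the full and the sparse kernel are loaded from pre-computed `.npy` files. To check the assumed $\mathcal{O}(n^{\text{order}})$ scaling, a straight line is fitted to $t^{1/\text{order}}$ as a function of the number of datapoints $n$. A good linear fit means $t \approx (a + b\,n)^{\text{order}}$.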

In [None]:
# Assumed polynomial exponent of the inversion cost, t(n) ~ n**order.
# Timings are linearized as t**(1/order) before the straight-line fit below.
order: int = 3

In [None]:
def plot_data(full, data_dir="."):
    """Load the pre-computed inversion-timing benchmark for one kernel type.

    Despite its name, this function does not plot anything. It returns the
    (dataset size, average time) arrays that later cells fit and plot.

    Parameters
    ----------
    full : bool
        If truthy, load the ``full_time_*.npy`` files; otherwise load the
        ``sparse_time_*.npy`` files.
    data_dir : str, optional
        Directory containing the ``.npy`` files. Defaults to the current
        working directory, which matches the original relative paths.

    Returns
    -------
    n_datapoints : jax.Array
        Dataset sizes: 10..490 (step 10), 500..1000 (step 100),
        1200..2800 (step 200), concatenated in that order.
    avg_times : jax.Array
        The loaded timings, concatenated in the same order as ``n_datapoints``.

    Raises
    ------
    ValueError
        If the loaded timings do not have the same shape as the dataset-size
        grid, e.g. a file from a different benchmark run.
    """
    prefix = "full" if full else "sparse"

    # Each segment pairs a dataset-size grid with the suffix of the file that
    # holds its timings. The suffix mirrors the grid's start_stop_step.
    segments = (
        (jnp.arange(1, 50) * 10, "10_500_10"),
        (jnp.arange(500, 1001, 100), "500_1001_100"),
        (jnp.arange(1200, 3000, 200), "1200_3000_200"),
    )

    n_datapoints = jnp.hstack([grid for grid, _ in segments])
    avg_times = jnp.hstack([
        jnp.load(os.path.join(data_dir, f"{prefix}_time_{suffix}.npy"))
        for _, suffix in segments
    ])

    # Fail here with a clear message rather than later inside the regression fit.
    if avg_times.shape != n_datapoints.shape:
        raise ValueError(
            f"{prefix} timings have shape {avg_times.shape}, "
            f"expected {n_datapoints.shape} to match the dataset-size grid"
        )

    return n_datapoints, avg_times

In [None]:
# Full kernel: fit a straight line to t**(1/order) against the number of datapoints.
n_dp_full, avg_full = plot_data(True)
X_full = n_dp_full.reshape(-1, 1)   # sklearn expects a 2-D (n_samples, n_features) design matrix
y_full = avg_full ** (1 / order)    # linearized timings

lm = LinearRegression().fit(X_full, y_full)
line_full = lm.predict(X_full)

# R^2 of the fit, measured on the linearized scale
print(lm.score(X_full, y_full))

In [None]:

# Sparse kernel: same linearized straight-line fit as for the full kernel.
n_dp_sparse, avg_sparse = plot_data(False)
X_sparse = n_dp_sparse.reshape(-1, 1)   # sklearn expects a 2-D (n_samples, n_features) design matrix
y_sparse = avg_sparse ** (1 / order)    # linearized timings

lm = LinearRegression().fit(X_sparse, y_sparse)
line_sparse = lm.predict(X_sparse)

# R^2 of the fit, measured on the linearized scale
print(lm.score(X_sparse, y_sparse))

In [None]:
# Left panel: measured timings with the fitted O(n^order) curve.
# Right panel: the same data on the linearized scale t**(1/order), where the fit is a straight line.
fig, ax = plt.subplots(1, 2, figsize=(9, 4))

curves = {
    "full": (n_dp_full, avg_full, line_full),
    "sparse": (n_dp_sparse, avg_sparse, line_sparse),
}

# Labels are raw f-strings: "\m" in a plain string literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+). Doubled braces are literal braces.
# Every label is prefixed with the kernel type so the legend can tell the curves apart.
for kernel, (n_dp, avg, line) in curves.items():
    ax[0].plot(n_dp, avg, label=f"{kernel}: avg_time")
    ax[0].plot(n_dp, line**order, label=rf"{kernel}: $\mathcal{{O}}(n^{order})$")
    ax[1].plot(n_dp, avg**(1 / order), label=rf"{kernel}: avg_time$^{{(1/{order})}}$")
    ax[1].plot(n_dp, line, label=rf"{kernel}: $\mathcal{{O}}(n^{order})^{{(1/{order})}}$")

for axis in ax:
    axis.grid()
    axis.set_xlabel("num datapoints")
    axis.legend()

ax[0].set_ylabel("average inversion time [s]")
ax[1].set_ylabel(rf"(average inversion time [s])$^{{(1/{order})}}$")

fig.tight_layout()
plt.show()