In [None]:
import numpy as np
import numpy.linalg as lin
from matplotlib import pyplot as plt
import pandas as pd
from scipy import sparse
import seaborn as sns

from matrix_functions import cheb_interpolation, diagonal_fa, lanczos, lanczos_fa_multi_k

In [None]:
# Problem setup: a diagonal test matrix A = diag(1, 2, ..., dim), so its
# spectrum is known exactly and f(A) x can be computed entrywise.
dim = 100
rng = np.random.default_rng(42)  # fixed seed so the random start vector is reproducible
a_diag = np.arange(1, dim + 1)
A = sparse.diags(a_diag)  # sparse matrix with a_diag on the main diagonal (offset 0)
lambda_min = a_diag.min()
lambda_max = a_diag.max()
# Condition number of A: its eigenvalues are the positive entries 1..dim of a_diag.
kappa = np.abs(lambda_max) / np.abs(lambda_min)


def f(x):
    """Target function f(x) = 1 / x**2, applied elementwise to NumPy arrays."""
    return 1 / (x**2)

In [None]:
# Random start vector (drawn from the seeded `rng`), the reference value of
# f(A) x computed from the diagonal entries of A, and the first value returned
# by `lanczos` (presumably a basis of the Krylov subspace with orthonormal
# columns, given reorthogonalize=True — TODO confirm against matrix_functions).
x = rng.standard_normal(dim)
ground_truth = diagonal_fa(f, a_diag, x)
krylov_basis = lanczos(A, x, reorthogonalize=True)[0]

In [None]:
# Sweep the Krylov dimension k and record, for each k:
#   * the 2-norm error of the Lanczos-FA approximation to f(A) x,
#   * the error of the least-squares best approximation to f(A) x from the
#     span of the first k columns of krylov_basis,
#   * our bound: kappa**2 times that best-approximation error,
#   * the Chebyshev bound 2 * ||x|| * max |f(z) - p(z)| over
#     [lambda_min, lambda_max], with the max estimated on a uniform grid.
N_GRID = 500  # number of grid points used to estimate the sup-norm error
ks = list(range(1, dim + 1))
lanczos_errors = []
krylov_errors = []
our_bound = []
cheb_interpolant_errors = []

# Neither the evaluation grid nor ||x|| depends on k, so compute them once.
xx = np.linspace(lambda_min, lambda_max, num=N_GRID)
x_norm = lin.norm(x)

for k, lanczos_estimate in zip(ks, lanczos_fa_multi_k(f, A, x, ks=ks)):
    lanczos_errors.append(lin.norm(lanczos_estimate - ground_truth))

    basis_k = krylov_basis[:, :k]
    coeffs = lin.lstsq(basis_k, ground_truth, rcond=None)[0]
    krylov_error = lin.norm(basis_k @ coeffs - ground_truth)
    krylov_errors.append(krylov_error)
    our_bound.append((kappa ** 2) * krylov_error)

    # TODO(review): confirm the meaning of cheb_interpolation's first argument.
    # Assuming the first k columns of krylov_basis span
    # {x, A x, ..., A^(k-1) x}, the Chebyshev bound is only valid for an
    # interpolant of degree <= k - 1. If `cheb_interpolation(n, ...)` builds a
    # degree-n polynomial this should be `k - 1`; if n is the number of
    # interpolation nodes, `k` is correct.
    cheb_interpolant = cheb_interpolation(k, f, lambda_min, lambda_max)
    sup_error = max(np.abs(f(z) - cheb_interpolant(z)) for z in xx)
    cheb_interpolant_errors.append(2 * x_norm * sup_error)

In [None]:
# Collect the per-k error curves into one table and plot them on a log scale.
results = pd.DataFrame({
    "Number of matrix products": ks,
    "Error of Lanczos-FA": lanczos_errors,
    "Error of Krylov subspace": krylov_errors,
    "Our bound": our_bound,
    "Error of Chebyshev interpolant *2||x||": cheb_interpolant_errors
})
# Long format (one row per (k, curve) pair) so seaborn can colour/style by curve.
results_long = pd.melt(results, id_vars=["Number of matrix products"], value_name="value")

fig, ax = plt.subplots()
sns.lineplot(x="Number of matrix products", y="value", hue="variable",
             style="variable", data=results_long, ax=ax)
ax.set_yscale('log')
ax.set_ylabel("Error or bound (2-norm, log scale)")
ax.set_title("Lanczos-FA error and bounds vs. number of matrix products");

In [None]:
def my_f(z):
    """Non-smooth test function with a kink at z = 10.

    The input is first mapped affinely via t = (z - 10) / 2, which sends
    [8, 12] onto [-1, 1]; the result is |t| + t/2 - t**2. Works elementwise
    on NumPy arrays as well as on scalars.
    """
    t = (z - 10) / 2
    kink = np.abs(t)
    return kink + t / 2 - t**2

# Visual check of `cheb_interpolation` on the non-smooth my_f over [a, b].
# NOTE(review): with first argument 1, a straight-line interpolant would
# suggest the argument is the polynomial degree, while a constant would suggest
# it is the number of nodes. This matters for whether the error sweep should
# pass k or k - 1.
a = 8
b = 12
interp = cheb_interpolation(1, my_f, a, b)

xx = np.linspace(a, b, num=500)
fig, ax = plt.subplots()
ax.plot(xx, my_f(xx), label="my_f")
ax.plot(xx, [interp(z) for z in xx], label="cheb_interpolation(1, my_f, a, b)")
ax.set_xlabel("z")
ax.set_title("Chebyshev interpolation of my_f on [a, b]")
ax.legend();