In [None]:
import flamp
import numpy as np
import numpy.linalg as lin
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from scipy import sparse
from tqdm import tqdm

import matrix_functions as mf

flamp.set_dps(50)  # set flamp's working precision to 50 decimal digits for the multiprecision computations in later cells

In [None]:
# Test problem: a dim x dim diagonal matrix A whose eigenvalues are the entries
# of a_diag, running from lambda_min up to lambda_max = kappa * lambda_min
# (presumably evenly spaced, as numpy.linspace would be -- TODO confirm
# mf.utils.linspace). The endpoints are multiprecision mpfr values.
dim = 20
lambda_min = flamp.gmpy2.mpfr(1.0)
kappa = flamp.gmpy2.mpfr(1000.0)  # ratio lambda_max / lambda_min
lambda_max = lambda_min * kappa

a_diag = mf.utils.linspace(lambda_min, lambda_max, num=dim)  # diagonal (= eigenvalues) of A
A = mf.DiagonalMatrix(a_diag)

In [None]:
def ritz_sequence(A, x):
    """Return the Ritz values of ``A`` after each Lanczos step started from ``x``.

    Runs ``mf.lanczos`` once (with ``reorthogonalize=True``). Then, for every
    step count ``i = 1, ..., k``, it takes the eigenvalues of the leading
    ``i x i`` tridiagonal matrix with diagonal ``a[:i]`` and off-diagonal
    ``b[:i-1]``.

    The number of steps ``k`` comes from the length of the Lanczos diagonal
    ``a``, not from the notebook-global ``dim``. The function therefore works
    for any matrix size, and presumably also stays consistent if
    ``mf.lanczos`` returns fewer coefficients than the matrix dimension.

    Parameters
    ----------
    A : matrix-like accepted by ``mf.lanczos`` (an ``mf.DiagonalMatrix`` in
        this notebook).
    x : starting vector for Lanczos.

    Returns
    -------
    list
        Element ``i - 1`` holds the ``i`` Ritz values for step ``i``, as
        returned by ``mf.utils.eigh_tridiagonal``.
    """
    _, (a, b) = mf.lanczos(A, x, reorthogonalize=True)
    # Named so it no longer shadows the function itself.
    ritz_values_per_step = []
    for num_steps in range(1, len(a) + 1):
        step_ritz, _ = mf.utils.eigh_tridiagonal(a[:num_steps], b[:num_steps - 1])
        ritz_values_per_step.append(step_ritz)
    return ritz_values_per_step

In [None]:
# Ritz-value sequences for three Lanczos starting vectors. The two `mu`
# variants pass mf.start_vec one shift per gap between consecutive entries of
# a_diag; presumably these are target locations for the Ritz values -- TODO
# confirm against mf.start_vec.
min_eig_gap = np.diff(a_diag).min()  # smallest difference between consecutive entries of a_diag

lhs_mu = a_diag[:-1] + min_eig_gap/50  # each shift sits just to the right of a_diag[k]
lhs = ritz_sequence(A, mf.start_vec(a_diag, lhs_mu))

mid_mu = a_diag[:-1] + min_eig_gap/2  # each shift is half the smallest gap right of a_diag[k]
mid = ritz_sequence(A, mf.start_vec(a_diag, mid_mu))

ones = ritz_sequence(A, flamp.ones(dim))  # all-ones starting vector, for comparison

In [None]:
# Ritz values at every Lanczos step for the three starting vectors, drawn over
# the eigenvalues of A (the entries of a_diag, black vertical lines). The
# points at height i are the i + 1 Ritz values from step i + 1.
fig, ax = plt.subplots()
ax.vlines(a_diag.astype(float), 0, dim, colors='k')

for i in range(dim):
    # Label only the first row, so each series appears exactly once in the legend.
    first_row = i == 0
    ax.scatter(mid[i], np.full_like(mid[i], i), c='blue', alpha=0.7, s=5,
               label='mu = mid' if first_row else '_nolegend_')
    ax.scatter(ones[i], np.full_like(ones[i], i), c='orange', alpha=0.7, s=5,
               label='b = 1' if first_row else '_nolegend_')
    ax.scatter(lhs[i], np.full_like(lhs[i], i), c='green', alpha=0.7, s=5,
               label='mu = LHS' if first_row else '_nolegend_')

ax.set(xlabel='value',
       ylabel='row i: Ritz values after i + 1 Lanczos steps',
       title='Ritz values per Lanczos step vs. eigenvalues of A')
ax.legend();


In [None]:
# Zoom in on one intermediate step: the Ritz values after i + 1 Lanczos steps
# (y-axis) against their position in the array returned by
# mf.utils.eigh_tridiagonal (x-axis; presumably ascending order), one curve
# per starting vector.
i = dim // 2
fig, ax = plt.subplots()
# Labels are attached to each curve, so the legend cannot drift out of sync with the plot order.
ax.plot(mid[i], label="mu = mid")
ax.plot(ones[i], label="b = 1")
ax.plot(lhs[i], label="mu = LHS")
ax.set(xlabel="Ritz value index", ylabel="Ritz value",
       title=f"Ritz values after {i + 1} Lanczos steps")
ax.legend();

In [None]:
# One curve per Lanczos step: the Ritz values for the mu = LHS starting vector
# plotted against their index.
for i, lhs_step_ritz in enumerate(lhs):
    plt.plot(lhs_step_ritz)

In [None]:
# One curve per Lanczos step: the Ritz values for the all-ones starting vector
# plotted against their index.
for i, ones_step_ritz in enumerate(ones):
    plt.plot(ones_step_ritz)