In [9]:
import numpy as np
import scipy.linalg as sla
import scipy.sparse.linalg as spla
import matplotlib.pyplot as plt

In [17]:
N = 500
a_reg = 2.e0

tmp = np.random.randn(N,N)
# Hd_tilde = sla.sqrtm(tmp.T @ tmp)
Hd_tilde = 0.5 * (tmp.T + tmp)

tmp2 = np.random.randn(N,N)
Hr0 = sla.sqrtm(tmp.T @ tmp)
Hr = a_reg * Hr0

H_tilde = Hd_tilde + Hr
ee, P = sla.eigh(H_tilde, Hr)

print(ee)

[-5.16079650e+00 -2.35662023e+00 -2.22846821e+00 -1.74551760e+00
 -1.07128117e+00 -7.31605834e-01 -7.13271968e-01 -6.21994788e-01
 -5.33669419e-01 -4.59146340e-01 -3.16005958e-01 -2.86602158e-01
 -2.70946301e-01 -2.34156999e-01 -1.84584327e-01 -1.27982875e-01
 -8.58325855e-02 -7.56706073e-02 -3.86213303e-02 -3.96667872e-03
  2.03038518e-02  5.88930642e-02  6.22636167e-02  7.87304841e-02
  8.98375352e-02  1.16103466e-01  1.25342643e-01  1.40811374e-01
  1.48542321e-01  1.67787323e-01  1.80479218e-01  1.93189656e-01
  1.99487181e-01  2.02250863e-01  2.11972527e-01  2.21631586e-01
  2.33537946e-01  2.42763884e-01  2.51458038e-01  2.67004699e-01
  2.71830604e-01  2.83192882e-01  2.84208066e-01  2.91623340e-01
  2.98675000e-01  3.00611362e-01  3.07853891e-01  3.14611389e-01
  3.19038394e-01  3.24630254e-01  3.28338854e-01  3.32622828e-01
  3.37625048e-01  3.43451993e-01  3.45246666e-01  3.49553797e-01
  3.58938656e-01  3.63176390e-01  3.66310296e-01  3.71000518e-01
  3.74825716e-01  3.79585

In [18]:
iH_tilde = np.linalg.inv(H_tilde)
iHr = np.linalg.inv(Hr)

iee, iP = sla.eigh(iH_tilde, iHr)

print(iee)

[-2.52100075e+02 -2.58924276e+01 -1.32151708e+01 -1.16505869e+01
 -7.81354538e+00 -5.41757807e+00 -4.27063895e+00 -3.69076823e+00
 -3.48915726e+00 -3.16449729e+00 -2.17795486e+00 -1.87381919e+00
 -1.60773051e+00 -1.40198977e+00 -1.36685624e+00 -9.33461755e-01
 -5.72895971e-01 -4.48738732e-01 -4.24336508e-01 -1.93768539e-01
  1.40335221e-01  2.29505281e-01  2.34883142e-01  2.65458039e-01
  3.24841042e-01  3.65237221e-01  3.68725562e-01  3.76324898e-01
  3.92265020e-01  4.09638117e-01  4.28537711e-01  4.32647300e-01
  4.41181863e-01  4.42935837e-01  4.62065756e-01  4.70601432e-01
  4.76030488e-01  4.81806975e-01  4.90351639e-01  4.99085143e-01
  5.03995612e-01  5.08261619e-01  5.15305555e-01  5.17052017e-01
  5.26576390e-01  5.28656229e-01  5.37694013e-01  5.39545994e-01
  5.40435320e-01  5.45397032e-01  5.48224298e-01  5.51530782e-01
  5.55113074e-01  5.58258967e-01  5.58984897e-01  5.62504603e-01
  5.64453837e-01  5.65332681e-01  5.70662910e-01  5.72764692e-01
  5.77014292e-01  5.77471

In [20]:
H_tilde_linop = spla.LinearOperator((N,N), matvec=lambda x: H_tilde @ x) 
iH_tilde_linop = spla.LinearOperator((N,N), matvec=lambda x: iH_tilde @ x) 
Hr_linop = spla.LinearOperator((N,N), matvec=lambda x: Hr @ x)
iHr_linop = spla.LinearOperator((N,N), matvec = lambda x: iHr @ x)

ee_min, P_min = spla.eigsh(H_tilde_linop, 5, M=Hr_linop, Minv=iHr_linop, which='SA')
print(ee_min)

min_eig = ee_min[0]

[-5.1607965  -2.35662023 -2.22846821 -1.7455176  -1.07128117]


In [30]:
tau = 0.75
mu = -2.0*min_eig

H_tilde_shifted = H_tilde + mu * Hr
Hr_shifted = Hr + mu * Hr

ee_shifted, P_shifted = sla.eigh(H_tilde_shifted, Hr_shifted)
print(ee_shifted)

bad_inds = (ee_shifted < tau)
bad_ee = ee_shifted[bad_inds]
bad_P = P_shifted[:,bad_inds]

[0.4558366  0.7035205  0.71483976 0.75749724 0.81705038 0.84705281
 0.84867218 0.8567344  0.8645359  0.87111828 0.88376141 0.88635856
 0.88774139 0.89099087 0.89536947 0.90036889 0.90409189 0.90498947
 0.90826191 0.91132284 0.91346658 0.91687504 0.91717275 0.91862722
 0.91960827 0.92192825 0.92274432 0.92411062 0.92479347 0.92649332
 0.92761436 0.92873703 0.92929327 0.92953738 0.93039606 0.93124921
 0.93230086 0.93311576 0.93388369 0.93525688 0.93568313 0.93668673
 0.93677639 0.93743136 0.93805421 0.93822524 0.93886495 0.93946182
 0.93985285 0.94034676 0.94067433 0.94105271 0.94149454 0.94200922
 0.94216774 0.94254817 0.94337711 0.94375141 0.94402822 0.94444249
 0.94478036 0.9452008  0.94540038 0.94558614 0.9459951  0.94628746
 0.94679555 0.94713887 0.94739428 0.94751128 0.9478109  0.94837703
 0.94844865 0.94879567 0.94896509 0.94929561 0.94932846 0.94946082
 0.94991767 0.9501236  0.9503238  0.95067352 0.95073439 0.95103497
 0.95135547 0.95151983 0.95176196 0.95225125 0.95238668 0.9525

In [225]:
import typing as typ


def eig_coeff(eig: float, # bad eigenvalue A*u = eig*B*u
              tau: float, # tolerance in (0,1)
              method='zero',
             ) -> float: # correction coefficient
    if method == 'thresh': # eig + 1 + c = tau
        c = tau - eig - 1.0
    elif method == 'zero': # eig + 1 + c = 1
        c = -eig
    elif method == 'flip': # eig + 1 + c = 1 + |eig|
        c = np.abs(eig) - eig
    return c
    

def bad_generalized_eigs(apply_A: typ.Callable[[np.ndarray], np.ndarray], # shape=(N,N)
                         apply_B: typ.Callable[[np.ndarray], np.ndarray], # shape=(N,N)
                         solve_B: typ.Callable[[np.ndarray], np.ndarray], # shape=(N,N)
                         N: int,
                         tau: float=0.5, # tolerance in (0,1)
                         chunk_size:int=5,
                         max_rank: int=50,
                         display: bool=False,
                        ) -> typ.Tuple[np.ndarray, # U=[u0, u1, ...]: shape=(N,k)
                                       np.ndarray, # V=[B@u0, B@u1, ...]: shape=(N,k)
                                       np.ndarray]: # eigs=[e0, e1, ...]: shape=(N,))
    '''Finds geigs A@ui=ei*B@ui such that ui.T@(A + B)@ui < tau.
    
    A is symmetric indefinite
    B is symmetric positive definite
    A.shape=B.shape=(N,N)
    apply_A(x) = A @ x
    apply_B(x) = B @ x
    solve_B(apply_B(x)) = apply_B(solve_B(x)) = x
    '''
    Bop = spla.LinearOperator((N,N), matvec=apply_B)
    iBop = spla.LinearOperator((N,N), matvec=solve_B)
    uu: typ.List[np.ndarray] = [] 
    vv: typ.List[np.ndarray] = []
    ee: typ.List[float] = []
    while len(ee) < max_rank:
        if ee:
            cc = np.array([eig_coeff(ei, tau, method='flip') for ei in ee])
            VT = np.array(vv)
            extra_term = lambda x: VT.T @ (cc * (VT @ x.reshape(-1)))
        else:
            extra_term = lambda x: 0.0 * x
            
        def apply_modified_A(x: np.ndarray) -> np.ndarray:
            return apply_A(x) + extra_term(x)
        modified_A_linop = spla.LinearOperator((N,N), matvec=apply_modified_A)
        
        if display:
            print('computing chunk, chunk_size=', chunk_size)
        ee_chunk, U_chunk = spla.eigsh(modified_A_linop, k=chunk_size, 
                                       M=Bop, Minv=iBop, which='SA')
        done=False
        for ii in range(len(ee_chunk)):
            if display:
                print('eig=', ee_chunk[ii], ', tau-1.0=', tau-1.0)
            if ee_chunk[ii] < tau-1.0:
                ee.append(ee_chunk[ii])
                uu.append(U_chunk[:,ii])
                vv.append(apply_B(U_chunk[:,ii]))
            if ee_chunk[ii] >= tau-1.0:
#                 if display:
#                     print('Done.')
                done = True
        if done:
            break
    
    U = np.array(vv).T
    V = np.array(vv).T
    eigs = np.array(ee)
    return U, V, eigs

def bad_eig_correction(apply_A: typ.Callable[[np.ndarray], np.ndarray], # shape=(N,N)
                       apply_B: typ.Callable[[np.ndarray], np.ndarray], # shape=(N,N)
                       solve_B: typ.Callable[[np.ndarray], np.ndarray], # shape=(N,N)
                       N: int,
                       tau: float=0.5, # tolerance in (0,1)
                       chunk_size:int=5,
                       max_rank: int=50,
                       display: bool=False,
                       method='flip',
                      ) -> typ.Tuple[np.ndarray, # V: shape=(N,k)
                                     np.ndarray]: # cc: shape=(N,))
    '''Finds low rank correction for A+B to reduce impact of indefiniteness of A.
    
    A is symmetric indefinite
    B is symmetric positive definite
    A.shape=B.shape=(N,N)
    apply_A(x) = A @ x
    apply_B(x) = B @ x
    solve_B(apply_B(x)) = apply_B(solve_B(x)) = x
    
    Bad eigenpairs (ui, ei) of A @ ui = ei B @ ui satisfy
        ui.T @ (A + B) @ ui < tau * ui.T @ B @ ui = tau
        
    We form low rank correction
        A + B + V * diag(cc) * V.T
    such that:
    1) V * diag(cc) * V.T is diagonalized by the generalized eigenvectors of (A,B)
    2) For all good (ui, ei), we have
        ui.T @ (A + B + V * diag(cc) * V.T) @ ui = ei
    3) For all bad (ui, ei), we have:
        ui.T @ (A + B + V * diag(cc) * V.T) @ ui = tau    (method='thresh')
    or:
        ui.T @ (A + B + V * diag(cc) * V.T) @ ui = 1.0  (method='zero')
    or:
        ui.T @ (A + B + V * diag(cc) * V.T) @ ui = |ei|   (method='flip')
        
    In:
        import numpy as np
        import scipy.linalg as sla
        
        N = 500
        alpha = 2.0
        tau = 0.75

        tmp = np.random.randn(N,N)
        A = 0.5 * (tmp.T + tmp)

        tmp2 = np.random.randn(N,N)
        B = sla.sqrtm(tmp.T @ tmp)

        # Flip bad geigs with brute force:
        ee, P = sla.eigh(A + alpha*B, alpha*B)
        ee_flip = np.zeros(len(ee))
        for ii in range(len(ee)):
            if ee[ii] < tau:
                ee_flip[ii] = 1.0 + np.abs(ee[ii])
            else:
                ee_flip[ii] = ee[ii]
        M_flip_brute = np.linalg.inv(P.T) @ np.diag(ee_flip) @ np.linalg.inv(P)
                
        # Flip bad geigs using bad_eig_correction():
        apply_A = lambda x: A @ x
        apply_B = lambda x: B @ x
        Binv = np.linalg.inv(B) # yeah yeah..
        solve_B = lambda x: Binv @ x
        
        V, cc = bad_eig_correction(apply_A, apply_B, solve_B, N, alpha, tau=tau)
        
        # Check with brute force:
        M_flip = A + alpha*B + V @ np.diag(cc) @ V.T
        err = np.linalg.norm(M_flip - M_flip_brute) / np.linalg.norm(M_flip_brute)
        print('err=', err)
        
        
    '''
    U, V, eigs = bad_generalized_eigs(apply_A, apply_B, solve_B,
                                      N, tau=tau, chunk_size=chunk_size, 
                                      max_rank=max_rank, display=display)
    
    cc = np.array([eig_coeff(ei, tau, method=method) for ei in list(eigs)])
    return V, cc

In [237]:
import numpy as np
import scipy.linalg as sla

N = 500
tau = 0.75

tmp = np.random.randn(N,N)
A = 0.5 * (tmp.T + tmp) + 20.0*np.eye(N)

tmp2 = np.random.randn(N,N)
B = sla.sqrtm(tmp2.T @ tmp2)

# Flip bad geigs with brute force:
ee, P = sla.eigh(A, B)
ee_flip = np.zeros(len(ee))
for ii in range(len(ee)):
    if ee[ii] + 1.0 < tau:
        ee_flip[ii] = 1.0 + np.abs(ee[ii])
    else:
        ee_flip[ii] = 1.0 + ee[ii]

# Flip bad geigs using bad_eig_correction():
apply_A = lambda x: A @ x
apply_B = lambda x: B @ x
Binv = np.linalg.inv(B) # yeah yeah..
solve_B = lambda x: Binv @ x

U, V, eigs = bad_generalized_eigs(apply_A, apply_B, solve_B, N, tau=tau, display=True)

V, cc = bad_eig_correction(apply_A, apply_B, solve_B, N, tau=tau, display=False)

# Check correctness:
M_flip = A + B + V @ np.diag(cc) @ V.T
rayleigh = P.T @ M_flip @ P
err_offdiagonal = np.linalg.norm(rayleigh - np.diag(np.diag(rayleigh)))
print('err_offdiagonal=', err_offdiagonal)

ee_flip2 = np.diag(rayleigh)
err_eigs = np.linalg.norm(ee_flip2 - ee_flip)
print('err_eigs=', err_eigs)

# err_diagonal = np.linalg.norm(M_flip - M_flip_brute) / np.linalg.norm(M_flip_brute)
# print('err=', err)

computing chunk, chunk_size= 5
eig= -0.8403648401752911 , tau-1.0= -0.25
eig= -0.8216257754680453 , tau-1.0= -0.25
eig= -0.7888672956716711 , tau-1.0= -0.25
eig= -0.7303216090860766 , tau-1.0= -0.25
eig= -0.7007747213396004 , tau-1.0= -0.25
computing chunk, chunk_size= 5
eig= -0.6978371707375703 , tau-1.0= -0.25
eig= -0.6353377188854781 , tau-1.0= -0.25
eig= -0.6230393718542241 , tau-1.0= -0.25
eig= -0.576770390770704 , tau-1.0= -0.25
eig= -0.5607611744137356 , tau-1.0= -0.25
computing chunk, chunk_size= 5
eig= -0.5594335299350824 , tau-1.0= -0.25
eig= -0.5167520589074648 , tau-1.0= -0.25
eig= -0.5097363273123539 , tau-1.0= -0.25
eig= -0.4851362966154258 , tau-1.0= -0.25
eig= -0.4629000322908032 , tau-1.0= -0.25
computing chunk, chunk_size= 5
eig= -0.4600921160775987 , tau-1.0= -0.25
eig= -0.4439135261290584 , tau-1.0= -0.25
eig= -0.4406221678391628 , tau-1.0= -0.25
eig= -0.41971494211581273 , tau-1.0= -0.25
eig= -0.40219055272378507 , tau-1.0= -0.25
computing chunk, chunk_size= 5
eig=

In [217]:
rayleigh = P.T @ M_flip @ P
np.linalg.norm(rayleigh - np.diag(np.diag(rayleigh)))

ee_flip2 = np.diag(rayleigh)

In [233]:
eigs

array([-0.87907527, -0.81908531, -0.78300675, -0.7708221 , -0.71869382,
       -0.67275632, -0.65982034, -0.62750092, -0.58528401, -0.56200318,
       -0.55693451, -0.51894607, -0.50799192, -0.49655246, -0.47663348,
       -0.47081452, -0.43481611, -0.42744686, -0.4204752 , -0.3996964 ,
       -0.38361491, -0.37813569, -0.35693317, -0.35315913, -0.34064827,
       -0.32984809, -0.30851962, -0.29451704, -0.28801111, -0.26645417,
       -0.25816187, -0.2535858 ])

In [232]:
ee

array([-8.79075269e-01, -8.19085306e-01, -7.83006750e-01, -7.70822098e-01,
       -7.18693819e-01, -6.72756315e-01, -6.59820344e-01, -6.27500919e-01,
       -5.85284011e-01, -5.62003179e-01, -5.56934507e-01, -5.18946071e-01,
       -5.07991919e-01, -4.96552464e-01, -4.76633485e-01, -4.70814523e-01,
       -4.34816107e-01, -4.27446859e-01, -4.20475195e-01, -3.99696403e-01,
       -3.83614905e-01, -3.78135686e-01, -3.56933170e-01, -3.53159129e-01,
       -3.40648271e-01, -3.29848093e-01, -3.08519618e-01, -2.94517039e-01,
       -2.88011107e-01, -2.66454167e-01, -2.58161868e-01, -2.53585802e-01,
       -2.45912530e-01, -2.36599540e-01, -2.20593420e-01, -2.09712673e-01,
       -2.03851788e-01, -1.99307142e-01, -1.92131704e-01, -1.78430410e-01,
       -1.69305378e-01, -1.63512924e-01, -1.50320074e-01, -1.33817243e-01,
       -1.26428711e-01, -1.21042441e-01, -1.13484743e-01, -1.06645438e-01,
       -1.01333269e-01, -8.81408262e-02, -7.84798363e-02, -7.04867410e-02,
       -6.51135043e-02, -

In [234]:
ee_flip

array([ 1.87907527e+00,  1.81908531e+00,  1.78300675e+00,  1.77082210e+00,
        1.71869382e+00,  1.67275632e+00,  1.65982034e+00,  1.62750092e+00,
        1.58528401e+00,  1.56200318e+00,  1.55693451e+00,  1.51894607e+00,
        1.50799192e+00,  1.49655246e+00,  1.47663348e+00,  1.47081452e+00,
        1.43481611e+00,  1.42744686e+00,  1.42047520e+00,  1.39969640e+00,
        1.38361491e+00,  1.37813569e+00,  1.35693317e+00,  1.35315913e+00,
        1.34064827e+00,  1.32984809e+00,  1.30851962e+00,  1.29451704e+00,
        1.28801111e+00,  1.26645417e+00,  1.25816187e+00,  1.25358580e+00,
       -2.45912530e-01, -2.36599540e-01, -2.20593420e-01, -2.09712673e-01,
       -2.03851788e-01, -1.99307142e-01, -1.92131704e-01, -1.78430410e-01,
       -1.69305378e-01, -1.63512924e-01, -1.50320074e-01, -1.33817243e-01,
       -1.26428711e-01, -1.21042441e-01, -1.13484743e-01, -1.06645438e-01,
       -1.01333269e-01, -8.81408262e-02, -7.84798363e-02, -7.04867410e-02,
       -6.51135043e-02, -

In [235]:
ee_flip2

array([  1.87907527,   1.81908531,   1.78300675,   1.7708221 ,
         1.71869382,   1.67275632,   1.65982034,   1.62750092,
         1.58528401,   1.56200318,   1.55693451,   1.51894607,
         1.50799192,   1.49655246,   1.47663348,   1.47081452,
         1.43481611,   1.42744686,   1.4204752 ,   1.3996964 ,
         1.38361491,   1.37813569,   1.35693317,   1.35315913,
         1.34064827,   1.32984809,   1.30851962,   1.29451704,
         1.28801111,   1.26645417,   1.25816187,   1.2535858 ,
         0.75408747,   0.76340046,   0.77940658,   0.79028733,
         0.79614821,   0.80069286,   0.8078683 ,   0.82156959,
         0.83069462,   0.83648708,   0.84967993,   0.86618276,
         0.87357129,   0.87895756,   0.88651526,   0.89335456,
         0.89866673,   0.91185917,   0.92152016,   0.92951326,
         0.9348865 ,   0.94203667,   0.94603263,   0.94892925,
         0.95600745,   0.96160721,   0.97212644,   0.98085329,
         0.98313987,   0.99791475,   1.00331161,   1.00

In [236]:
ee_flip - ee_flip2

array([ 2.62123656e-12,  2.50444110e-12,  3.72590847e-13,  8.72191208e-13,
        3.04023473e-12, -5.84865489e-13,  9.48796597e-13,  3.11617399e-12,
        4.33253433e-12, -4.81414908e-12,  2.33679742e-12,  8.30269187e-12,
        7.97784061e-12,  5.02953235e-12,  3.99014155e-13, -5.16542364e-12,
       -4.91739982e-12,  2.51154653e-12,  8.51096971e-13, -5.53357360e-12,
        1.41642253e-12,  3.52451401e-12,  7.41851025e-13,  3.04201109e-12,
       -1.92046379e-12, -3.05666603e-12, -7.31414929e-13,  1.01607611e-12,
        2.64988032e-12,  1.07447384e-12,  4.91362506e-12,  4.74220663e-12,
       -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00,
       -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00,
       -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00,
       -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00,
       -1.00000000e+00, -1.00000000e+00, -1.00000000e+00, -1.00000000e+00,
       -1.00000000e+00, -

In [58]:
ee_flip

array([ 5.76269524,  4.02099107,  3.30868079,  3.05089916,  2.8303452 ,
        2.77030345,  2.68300221,  2.59145597, -0.45629886, -0.36375334,
       -0.31233329, -0.25420954, -0.17939885, -0.15828811, -0.1305535 ,
       -0.08089485, -0.06385077, -0.05247395, -0.03615887,  0.00848004,
        0.0380448 ,  0.04854886,  0.07662666,  0.07822524,  0.1105669 ,
        0.11345388,  0.13030788,  0.13553619,  0.13665826,  0.16207286,
        0.17543267,  0.18690205,  0.19674866,  0.20260235,  0.2188642 ,
        0.22362788,  0.24295704,  0.24948852,  0.25635967,  0.27439377,
        0.27605674,  0.28135781,  0.28936293,  0.29195477,  0.30211809,
        0.30584791,  0.31290638,  0.32139239,  0.32647403,  0.33068277,
        0.3344006 ,  0.33986124,  0.34157642,  0.34714057,  0.3537887 ,
        0.358687  ,  0.36788259,  0.36857989,  0.37136131,  0.377819  ,
        0.38183903,  0.38299257,  0.38435163,  0.39058641,  0.39189279,
        0.39685758,  0.40007501,  0.40156134,  0.40507598,  0.41

In [54]:
iHr0 = np.linalg.inv(Hr0)
    
apply_Hd_tilde = lambda x: Hd_tilde @ x
apply_Hr0 = lambda x: Hr0 @ x
solve_Hr0 = lambda x: iHr0 @ x
    
U, V, eigs = bad_generalized_eigs(apply_Hd_tilde,
                                  apply_Hr0, solve_Hr0,
                                  N, a_reg, tau=tau,
                                  chunk_size=5, max_rank=50, 
                                  display=True)

print(eigs)

computing chunk, chunk_size= 5
eig= -10.321593005557636 , alpha*(tau-1.0)= -0.5
eig= -4.713240466598481 , alpha*(tau-1.0)= -0.5
eig= -4.4569364213361125 , alpha*(tau-1.0)= -0.5
eig= -3.49103520133542 , alpha*(tau-1.0)= -0.5
eig= -2.1425623385262003 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 5
eig= -1.4632116688951808 , alpha*(tau-1.0)= -0.5
eig= -1.4265439365838337 , alpha*(tau-1.0)= -0.5
eig= -1.243989576609168 , alpha*(tau-1.0)= -0.5
eig= -1.067338838923188 , alpha*(tau-1.0)= -0.5
eig= -0.918292679628947 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 5
eig= -0.632011916470587 , alpha*(tau-1.0)= -0.5
eig= -0.5732043164471803 , alpha*(tau-1.0)= -0.5
eig= -0.5418926027966902 , alpha*(tau-1.0)= -0.5
eig= -0.4683139974478157 , alpha*(tau-1.0)= -0.5
eig= -0.369168653148221 , alpha*(tau-1.0)= -0.5
[-10.32159301  -4.71324047  -4.45693642  -3.4910352   -2.14256234
  -1.46321167  -1.42654394  -1.24398958  -1.06733884  -0.91829268
  -0.63201192  -0.57320432  -0.5418926 ]


In [42]:
def low_rank_negative_eigenvalue_correction(apply_A: typ.Callable[[np.ndarray], np.ndarray],
                                            apply_B: typ.Callable[[np.ndarray], np.ndarray], 
                                            solve_B: typ.Callable[[np.ndarray], np.ndarray],
                                            N: int,
                                            alpha: float,
                                            tau: float=0.5,
                                            chunk_size:int=5,
                                            max_rank: int=50,
                                            display: bool=True,
                                            adjustment_type='reg',
                                           ) -> typ.Tuple[np.ndarray, # shape=(N,k)
                                                          np.ndarray]: # shape=(N,)
    '''Finds low rank factors and weights to make A + alpha*B >= tau*B.
    
    A is symmetric indefinite
    B is symmetric positive definite
    A.shape=B.shape=(N,N)
    apply_A(x) = A @ x
    apply_B(x) = B @ x
    solve_B(apply_B(x)) = apply_B(solve_B(x)) = x
    '''
    Bop = spla.LinearOperator((N,N), matvec=apply_B)
    iBop = spla.LinearOperator((N,N), matvec=solve_B)
    vv: typ.List[np.ndarray] = []
    cc: typ.List[float] = []
    while len(cc) < max_rank:
        if cc:
            VT = np.array(vv)
            c = np.array(cc)
            correction = lambda x: VT.T @ (c * (VT @ x.reshape(-1)))
        else:
            correction = lambda x: 0.0 * x
            
        def apply_M(x: np.ndarray) -> np.ndarray:
            return apply_A(x) + alpha*apply_B(x) + correction(x)
        M_linop = spla.LinearOperator((N,N), matvec=apply_M)
        
        if display:
            print('computing chunk, chunk_size=', chunk_size)
        ee, U = spla.eigsh(M_linop, k=chunk_size, M=Bop, Minv=iBop, which='SA')
        good=False
        for ii in range(len(ee)):
            if display:
                print('ee[ii]=', ee[ii], ', alpha*(tau-1.0)=', alpha*(tau-1.0))
            if ee[ii] < alpha*(tau-1.0):
#                 cc.append(alpha*(tau-1.0) - ee[ii])
                cc.append(-ee[ii])
                vv.append(apply_B(U[:,ii]))
            if np.any(ee[ii] >= alpha*(tau-1.0)):
                good=True
        if good:
            break
    
    V = np.array(vv).T
    c = np.array(cc)
    return V, c
    

iHr0 = np.linalg.inv(Hr0)
    
apply_Hd_tilde = lambda x: Hd_tilde @ x
apply_Hr0 = lambda x: Hr0 @ x
solve_Hr0 = lambda x: iHr0 @ x
    
V, c = low_rank_negative_eigenvalue_correction(apply_Hd_tilde,
                                               apply_Hr0, solve_Hr0,
                                               N, a_reg, tau=tau,
                                               chunk_size=5, max_rank=50, 
                                               display=True)

computing chunk, chunk_size= 5
ee[ii]= -10.321593005552296 , alpha*(tau-1.0)= -0.5
ee[ii]= -4.713240466602076 , alpha*(tau-1.0)= -0.5
ee[ii]= -4.4569364213366525 , alpha*(tau-1.0)= -0.5
ee[ii]= -3.4910352013367665 , alpha*(tau-1.0)= -0.5
ee[ii]= -2.1425623385274486 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 5
ee[ii]= -1.4632116688948853 , alpha*(tau-1.0)= -0.5
ee[ii]= -1.426543936585123 , alpha*(tau-1.0)= -0.5
ee[ii]= -1.2439895766106515 , alpha*(tau-1.0)= -0.5
ee[ii]= -1.0673388389243132 , alpha*(tau-1.0)= -0.5
ee[ii]= -0.9182926796288587 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 5
ee[ii]= -0.6320119164724999 , alpha*(tau-1.0)= -0.5
ee[ii]= -0.5732043164475709 , alpha*(tau-1.0)= -0.5
ee[ii]= -0.5418926027954649 , alpha*(tau-1.0)= -0.5
ee[ii]= -0.4683139974487207 , alpha*(tau-1.0)= -0.5
ee[ii]= -0.36916865314763725 , alpha*(tau-1.0)= -0.5


In [43]:
V, c = low_rank_negative_eigenvalue_correction(apply_Hd_tilde,
                                               apply_Hr0, solve_Hr0,
                                               N, a_reg, tau=tau,
                                               chunk_size=1, max_rank=50, 
                                               display=True)

computing chunk, chunk_size= 1
ee[ii]= -10.321593005617997 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -4.713240466591016 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -4.4569364213403775 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -3.4910352013361905 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -2.1425623385294457 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -1.4632116688947825 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -1.4265439365842303 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -1.243989576608658 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -1.0673388389219927 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -0.9182926796288576 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -0.6320119164702502 , alpha*(tau-1.0)= -0.5
computing chunk, chunk_size= 1
ee[ii]= -0.5732043164475763 , alpha*(tau-1.0)= -0.5
comput