In [1]:
# Mount the user's Google Drive inside the Colab VM so notebook code can
# access Drive files under /content/gdrive.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [None]:
import torch as tc

def eigs_power(mat, v0=None, which='la', tau=0.01, it_time=2000, tol=1e-14):
    """
    Find one extremal eigenpair of a real symmetric matrix by power iteration
    on a matrix exponential (an "imaginary-time" projector).

    The power method is applied to rho = exp(+/- tau * M), where M is ``mat``
    for 'la'/'sa' and ``mat @ mat`` for 'lm'/'sm'. For a real symmetric
    ``mat``, rho is positive definite, and its dominant eigenvector is the
    eigenvector of ``mat`` selected by ``which``. The eigenvalue is recovered
    from the growth factor ||rho v|| of the converged unit vector v.

    :param mat: Input matrix (real symmetric, floating-point dtype)
    :param v0: Initial vector, default is a random vector on ``mat``'s
               device/dtype. It is normalized on a copy; the caller's tensor
               is never modified.
    :param which: Which eigenpair to compute:
                  'la' algebraically largest, 'sa' algebraically smallest,
                  'lm' largest in magnitude, 'sm' smallest in magnitude
                  (case-insensitive)
    :param tau: Projection step size; only its absolute value is used. Larger
                values converge faster, but matrix_exp may overflow when
                tau * |eigenvalue| (or tau * eigenvalue**2 for 'lm'/'sm')
                is large.
    :param it_time: Max number of iterations (at least one is always done,
                    since one application of rho is needed for the estimate)
    :param tol: Convergence threshold on the change of the normalized iterate
    :return: (eigenvalue, eigenvector) -- the eigenvalue as a 0-dim tensor
             and the corresponding unit-norm eigenvector
    :raises ValueError: if ``which`` is not one of 'la', 'sa', 'lm', 'sm'

    Note: for 'lm'/'sm', if both +lambda and -lambda are extremal, the iterate
    may converge to a mixture of their eigenvectors, which is then not an
    eigenvector of ``mat`` itself.
    """
    mode = which.lower()
    if mode not in ('la', 'sa', 'lm', 'sm'):
        raise ValueError(
            f"which must be one of 'la', 'sa', 'lm', 'sm'; got {which!r}")

    # Initialize vector on the same device/dtype as mat. Normalize
    # out-of-place so a caller-supplied v0 is left untouched.
    if v0 is None:
        v0 = tc.randn(mat.shape[1], dtype=mat.dtype, device=mat.device)
    v0 = v0 / v0.norm()

    # Build the projector whose dominant eigenvector is the target one.
    tau = abs(tau)
    if mode == 'la':
        rho = tc.matrix_exp(tau * mat)
    elif mode == 'sa':
        rho = tc.matrix_exp(-tau * mat)
    else:
        mat_sq = mat.matmul(mat)
        rho = tc.matrix_exp(tau * mat_sq if mode == 'lm' else -tau * mat_sq)

    v1 = v0
    lm = None
    for _ in range(max(it_time, 1)):
        v1 = rho.matmul(v0)
        lm = v1.norm()  # growth factor = dominant eigenvalue of rho (at convergence)
        v1 = v1 / lm
        if (v1 - v0).norm() < tol:
            break
        v0 = v1

    if mode == 'la':
        return tc.log(lm) / tau, v1
    if mode == 'sa':
        return -tc.log(lm) / tau, v1

    # Squaring loses the sign of the eigenvalue; recover it from the
    # Rayleigh quotient v^T A v. The clamp guards sqrt against tiny negative
    # round-off when the eigenvalue is ~0.
    sign = tc.sign(v1.dot(mat.matmul(v1)))
    sq = tc.log(lm) / tau if mode == 'lm' else -tc.log(lm) / tau
    return sign * tc.sqrt(tc.clamp(sq, min=0)), v1
