In [1]:
import torch
import numpy as np
from torchessian import hessian_matmul
from torchvision.models import resnet18

In [2]:
model = torch.nn.Linear(6, 1)
torch.manual_seed(2019)
v = torch.rand(sum(p.data.numel() for p in model.parameters()))
x = torch.rand(10, 6)
y = x.sum(1)
def loss_function(y_hat, y):
    return ((y_hat.view(-1) - y.view(-1)) * (y_hat.view(-1) - y.view(-1))).sum() / y.numel()
batch = x, y

In [41]:
def solve(A, desired_rank):
    # Calculate two more eigenvalues, but we only keep the largest desired_rank
    # one. Doing this to keep the result consistent with scipy.sparse.linalg.svds.

    n = A.shape[1]
    v_next = np.ones(n) / np.sqrt(n)
    v_prev = np.zeros(n)
    beta = np.zeros(desired_rank+1)
    beta[0] = 0
    alpha = np.zeros(desired_rank)

    # Since the disiredRank << size of matrix, so we keep
    # V in local memory for efficiency reason(It needs to be updated
    # for every iteration). 
    # If the case which V can't be fit in local memory occurs, 
    # you could turn it into spartan distributed array. 
    V = np.zeros((n, desired_rank))

    for i in range(desired_rank):
        v_next_expr = v_next.reshape(n, 1)
        w = np.dot(A, v_next_expr).reshape(n)

        w = w - alpha[i] * v_next - beta[i] * v_prev

        # Orthogonalize:
        for t in range(i):
            tmpa = np.dot(w, V[:, t])
            if tmpa == 0.0:
                continue
            w -= tmpa * V[:, t] 

        beta[i+1] = np.linalg.norm(w, 2) 
        v_prev = v_next
        v_next = w / beta[i+1]
        V[:, i] = v_prev

    # Create tridiag matrix with size (desired_rank X desired_rank)  
    tridiag = np.diag(alpha)
    for i in range(desired_rank-1):
        tridiag[i, i+1] = beta[i+1] 
        tridiag[i+1, i] = beta[i+1]

    return tridiag, V

In [42]:
H = torch.tensor([[0.6830, 0.4263, 1.0876],
        [0.4263, 0.5262, 0.8746],
        [1.0876, 0.8746, 2.0000]])

In [54]:
T, V = solve(H, 2)

In [55]:
V[:, 0].dot(V[:, 1])

0.9438668413799078

In [56]:
V

array([[0.57735027, 0.44971924],
       [0.57735027, 0.37401885],
       [0.57735027, 0.81108724]])

In [57]:
np.linalg.eig(H)

(array([2.9951546 , 0.04153168, 0.1725138 ], dtype=float32),
 array([[-0.45050886, -0.6901737 ,  0.56630564],
        [-0.36621034, -0.4356409 , -0.8222572 ],
        [-0.8142062 ,  0.57782114,  0.05648864]], dtype=float32))

In [47]:
np.linalg.eig(T)

(array([-2.99504798e+00,  2.99504798e+00, -4.29914051e-18]),
 array([[ 6.65869948e-01,  6.65869948e-01, -3.36503232e-01],
        [-7.07106781e-01,  7.07106781e-01,  4.52200731e-17],
        [ 2.37943717e-01,  2.37943717e-01,  9.41682311e-01]]))

In [58]:
T

array([[0.       , 2.8203837],
       [2.8203837, 0.       ]])

In [61]:
w, v = np.linalg.eig(T)

In [62]:
w

array([ 2.8203837, -2.8203837])

In [63]:
v

array([[ 0.70710678, -0.70710678],
       [ 0.70710678,  0.70710678]])

In [21]:
v.matmul(torch.diag(w[:, 0])).matmul(v.transpose(0, 1))

tensor([[-13263.8984,   2268.6606,   2745.7505,   6112.6572,  -5364.9922],
        [  2268.6604,   -387.8971,   -469.6662,  -1045.4684,    917.5088],
        [  2745.7505,   -469.6662,   -568.3350,  -1265.3640,   1110.5846],
        [  6112.6572,  -1045.4684,  -1265.3640,  -2816.9634,   2472.3718],
        [ -5364.9922,    917.5088,   1110.5846,   2472.3718,  -2169.8423]])

In [22]:
T.matmul(v).matmul(torch.diag(1/w[:, 0]))

tensor([[ 0.8310, -0.7659, -0.8550,  0.5809,  0.4943],
        [-0.1421,  0.2081,  0.2421,  0.4221,  0.5113],
        [-0.1720,  0.2601,  0.5437, -0.0197, -0.6229],
        [-0.3830,  0.4447, -0.5299,  0.1087, -0.2864],
        [ 0.3361, -0.3282, -0.1576, -0.6428,  0.1524]])