In [14]:
import torch
from torchessian import hessian_matmul
from torchvision.models import resnet18

In [15]:
model = torch.nn.Linear(6, 1)
torch.manual_seed(2019)
v = torch.rand(sum(p.data.numel() for p in model.parameters()))
x = torch.rand(10, 6)
y = x.sum(1)
def loss_function(y_hat, y):
    return ((y_hat.view(-1) - y.view(-1)) * (y_hat.view(-1) - y.view(-1))).sum() / y.numel()
batch = x, y

In [16]:
def lanczos(model, loss_function, v, batch, m):
    V_ = [v]
    for _ in range(m - 1):
        v_ = hessian_matmul(model, loss_function, v, batch)
        V_.append(v)

    V_ = torch.cat(list(v.unsqueeze(0) for v in V_), 0) # [m, n]
    V_ = V_.transpose(0, 1) # [n, m]
    # Check if n > m
    V, _ = torch.qr(V_) # [n, m]
    
    Hv = []
    for i in range(m):
        v = V[:, i] # [n]
        Hv.append(hessian_matmul(model, loss_function, v, batch))
    
    Hv = torch.cat(list(v.unsqueeze(0) for v in Hv), 0) # [m, n]
    T = V.transpose(0, 1).matmul(Hv.transpose(0, 1)) # VHV
    
    return T

In [17]:
T = lanczos(model, loss_function, v, batch, 5)

In [18]:
T

tensor([[-40179.0859, -50228.6133, -60274.3789, -70320.8203, -80363.6406],
        [  6872.0557,   8591.0469,  10309.0186,  12027.3701,  13744.9531],
        [  8317.8086,  10398.2607,  12478.0254,  14557.7568,  16636.7188],
        [ 18516.6230,  23148.0117,  27777.5957,  32407.5391,  37035.6484],
        [-16251.3496, -20315.9902, -24379.2031, -28442.7109, -32504.4727]])

In [19]:
w, v = torch.eig(T, eigenvectors=True)

In [20]:
w

tensor([[-1.9208e+04,  0.0000e+00],
        [ 6.8007e-01,  0.0000e+00],
        [ 2.2778e-02,  0.0000e+00],
        [ 2.4571e-01,  0.0000e+00],
        [ 1.6875e-01,  0.0000e+00]])

In [21]:
v.matmul(torch.diag(w[:, 0])).matmul(v.transpose(0, 1))

tensor([[-13263.8984,   2268.6606,   2745.7505,   6112.6572,  -5364.9922],
        [  2268.6604,   -387.8971,   -469.6662,  -1045.4684,    917.5088],
        [  2745.7505,   -469.6662,   -568.3350,  -1265.3640,   1110.5846],
        [  6112.6572,  -1045.4684,  -1265.3640,  -2816.9634,   2472.3718],
        [ -5364.9922,    917.5088,   1110.5846,   2472.3718,  -2169.8423]])

In [22]:
T.matmul(v).matmul(torch.diag(1/w[:, 0]))

tensor([[ 0.8310, -0.7659, -0.8550,  0.5809,  0.4943],
        [-0.1421,  0.2081,  0.2421,  0.4221,  0.5113],
        [-0.1720,  0.2601,  0.5437, -0.0197, -0.6229],
        [-0.3830,  0.4447, -0.5299,  0.1087, -0.2864],
        [ 0.3361, -0.3282, -0.1576, -0.6428,  0.1524]])