In [1]:
import torch
from torchessian import hessian_matmul
from torchvision.models import resnet18
from itertools import product

In [2]:
n = 2
model = torch.nn.Linear(n, 1)
# torch.manual_seed(2019)
v = torch.rand(sum(p.data.numel() for p in model.parameters()))
x = torch.rand(10, n)
y = x.sum(1)
def loss_function(y_hat, y):
    return ((y_hat.view(-1) - y.view(-1)) * (y_hat.view(-1) - y.view(-1))).sum() / y.numel()
batch = x, y
m = 3
# Exact Hessian

h11 = 2 * (x[:, 0] * x[:, 0]).sum() / x.size(0)
h21 = h12 = 2 * (x[:, 0] * x[:, 1]).sum() / x.size(0)
h31 = h13 = 2 * x[:, 0].sum() / x.size(0)

h22 = 2 * (x[:, 1] * x[:, 1]).sum() / x.size(0)
h32 = h23 = 2 * x[:, 1].sum() / x.size(0)

h33 = 2

H = torch.zeros(3, 3)

for i, j in product(range(3), range(3)):
    exec("H[%d, %d] = h%d%d" % (i, j, i+1, j+1))
    
print(H)

tensor([[0.9112, 0.5456, 1.2491],
        [0.5456, 0.5437, 0.9178],
        [1.2491, 0.9178, 2.0000]])


In [7]:
def lanczos(model, loss_function, batch, m):
    global H
    n = sum(p.data.numel() for p in model.parameters())
    v = torch.ones(n)
    v /= torch.norm(v)
    w = hessian_matmul(model, loss_function, v, batch)
    alpha = []
    alpha.append(w.dot(v))
    w -= alpha[0] * v
    
    V = [v]
    beta = []
    
    for i in range(1, m):
        b = torch.norm(w)
        beta.append(b)
        if b > 0:
            v = w / b
        else:
            done = False
            k = 0
            while not done:
                k += 1
                v = torch.rand(n)
                
                for v_ in V:
                    v -= v.dot(v_) * v_
                
                done = torch.norm(n) > 0
                if k > n * 10:
                    raise Exception("Can't find orthogonal vector")
            
            v /= np.linalg.norm(n)
                
        for v_ in V:
            v -= v.dot(v_) * v_
            v /= torch.norm(v)
               
        V.append(v)
        w = hessian_matmul(model, loss_function, torch.Tensor(v), batch)
        alpha.append(w.dot(v))
        w = w - alpha[-1] * V[-1] - beta[-1] * V[-2]

    T = torch.diag(torch.Tensor(alpha))
    for i in range(m - 1):
        T[i, i + 1] = beta[i] 
        T[i + 1, i] = beta[i]

    V = torch.cat(list(v.unsqueeze(0) for v in V), 0)
    return T, V

In [8]:
T, V = lanczos(model, loss_function, batch, m)

In [10]:
v = torch.eig(T)[0]
v.sort()
v

tensor([[3.2403, 0.0000],
        [0.0619, 0.0000],
        [0.1530, 0.0000]])

In [12]:
v = torch.eig(H)[0]
v.sort()
v

tensor([[3.2403, 0.0000],
        [0.1529, 0.0000],
        [0.0618, 0.0000]])

In [13]:
for i in range(m - 1):
    print("{:.3f}".format(V[i, :].dot(V[m - 1, :])))

-0.000
0.000
