In [1]:
import torch
import numpy as np
from torchessian import hessian_matmul
from torchvision.models import resnet18

In [156]:
# Seed first: torch.nn.Linear draws its initial weights from the global RNG,
# so seeding after constructing the model would leave the weights unreproducible.
torch.manual_seed(2019)

n = 20  # number of input features
model = torch.nn.Linear(n, 1)

# One random entry per scalar model parameter (weights and bias flattened together).
v = torch.rand(sum(p.numel() for p in model.parameters()))

# Ten random samples; the target is the sum of each sample's features.
x = torch.rand(10, n)
y = x.sum(1)
def loss_function(y_hat, y):
    """Mean squared error between ``y_hat`` and ``y``.

    Both tensors are flattened with ``view(-1)`` before subtracting, and the
    summed squared difference is divided by the number of elements in ``y``.
    """
    residual = y_hat.view(-1) - y.view(-1)
    squared_error = residual * residual
    return squared_error.sum() / y.numel()
# Bundle inputs and targets into a single (inputs, targets) pair.
batch = (x, y)
# Step count used by later cells.
m = 3

In [245]:
def _random_unit_vector_orthogonal_to(basis, n, tol=1e-12):
    """Draw a random unit vector of length ``n`` orthogonal to every vector in ``basis``.

    Up to ``10 * n`` candidates are tried. A candidate is rejected when what
    remains after projecting out ``basis`` has norm <= ``tol``.
    """
    for _ in range(10 * n):
        candidate = np.random.rand(n)
        for u in basis:
            candidate -= candidate.dot(u) * u
        # Check the norm of the *vector*; a vanishing remainder cannot be normalised.
        remainder = np.linalg.norm(candidate)
        if remainder > tol:
            return candidate / remainder
    raise RuntimeError("Can't find orthogonal vector")


def solve(A, m):
    """Run ``m`` steps of the Lanczos iteration on ``A``.

    Starting from the normalised all-ones vector, builds a list of
    orthonormal vectors and the tridiagonal matrix ``T`` whose diagonal holds
    the ``alpha`` coefficients and whose two off-diagonals hold the ``beta``
    coefficients. Every new vector is re-orthogonalised against all previous
    ones.

    Parameters
    ----------
    A : numpy.ndarray, shape (n, n)
        Matrix to tridiagonalise. ``T`` is built symmetric, so ``A`` is
        expected to be symmetric.
    m : int
        Number of basis vectors to build.

    Returns
    -------
    T : numpy.ndarray, shape (m, m)
        Symmetric tridiagonal matrix of the recurrence coefficients.
    V : numpy.ndarray, shape (m, n)
        The basis vectors stacked as rows.

    Raises
    ------
    RuntimeError
        If the recurrence breaks down (zero residual) and no random
        replacement vector with a non-negligible component orthogonal to the
        current basis can be found.
    """
    n = A.shape[0]
    v = np.ones(n) / np.sqrt(n)
    w = A.dot(v)
    alpha = [w.dot(v)]
    w = w - alpha[0] * v

    V = [v]
    beta = []

    for _ in range(1, m):
        b = np.linalg.norm(w)
        beta.append(b)
        if b > 0:
            v = w / b
        else:
            # Breakdown: the residual vanished, so restart with a fresh
            # direction orthogonal to everything built so far.
            v = _random_unit_vector_orthogonal_to(V, n)

        # Full re-orthogonalisation against all previous vectors, then renormalise.
        for u in V:
            v -= v.dot(u) * u
        v /= np.linalg.norm(v)

        V.append(v)
        w = A.dot(v)
        alpha.append(w.dot(v))
        # Three-term recurrence: remove the components along the last two vectors.
        w = w - alpha[-1] * V[-1] - beta[-1] * V[-2]

    T = np.diag(alpha)
    for i, b in enumerate(beta):
        T[i, i + 1] = b
        T[i + 1, i] = b

    return T, np.array(V)

In [246]:
# Seed NumPy's global RNG so the random test matrix, and every result derived
# from it below, is reproducible across kernel restarts.
np.random.seed(2019)

n = 10

# Random symmetric n x n matrix: fill row i from the diagonal rightwards,
# then mirror that row into column i.
H = np.zeros((n, n))
for i in range(n):
    H[i, i:] = np.random.rand(n - i)
    H[i:, i] = H[i, i:]

In [251]:
# Number of steps passed to solve(); T and V are its two return values.
m = 9
T, V = solve(H, m=m)

In [252]:
# solve() builds T with identical upper and lower off-diagonals, so it is
# symmetric: use the symmetric eigenvalue routine, which returns real
# eigenvalues already sorted in ascending order.
v = np.linalg.eigvalsh(T)
v

array([-1.36059246, -1.16382719, -0.5634516 , -0.27759082,  0.11037997,
        0.43755372,  0.98927396,  1.25766897,  4.59069595])

In [253]:
# H is symmetric by construction, so use the symmetric eigenvalue routine:
# it returns real eigenvalues already sorted in ascending order.
v = np.linalg.eigvalsh(H)
v

array([-1.36065533, -1.16404016, -0.57419454, -0.30143077,  0.04803424,
        0.33196491,  0.49455316,  0.99003085,  1.25804118,  4.59069595])

In [254]:
# Inner product of row m-1 of V with each earlier row, printed to 3 decimals.
last_vector = V[m - 1, :]
for row in V[:m - 1]:
    print("{:.3f}".format(row.dot(last_vector)))

-0.000
-0.000
-0.000
-0.000
-0.000
-0.000
-0.000
0.000


In [111]:
# Three random vectors of length 3, drawn in the order a, b, c.
a, b, c = (np.random.rand(3) for _ in range(3))
# Only a is scaled to unit length here.
a /= np.linalg.norm(a)

In [112]:
# Gram-Schmidt step: remove b's component along a, then rescale b to unit length.
overlap = b.dot(a)
b -= overlap * a
b /= np.linalg.norm(b)

In [113]:
# Remove from c its component along a, then along b; c is left un-normalised.
for direction in (a, b):
    c -= c.dot(direction) * direction

In [114]:
# Pairwise inner products a.b, a.c and c.b, one per line.
print(a.dot(b), a.dot(c), c.dot(b), sep="\n")

-5.551115123125783e-17
1.0408340855860843e-16
-1.3877787807814457e-17
