In [104]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from scipy.linalg import solve, eigvals
from scipy.stats import wishart
from numpy.linalg import cond

In [105]:
from hessianfree.utils import generate_pd_matrix, draw_surface, get_z, draw_surface_3d
from hessianfree.cg import pcg

In [106]:
plt.rcParams.update({"figure.figsize": (8, 5), "font.size": 16})

## Problem Formulation

Generate Hessian (PSD matrix)

In [124]:
A = generate_pd_matrix(2)

Generate Gradient (1d vector)

In [116]:
b = torch.rand(2, 1)

## Visualization

### Setup

In [117]:
points, directions, alphas = [], [], []
def callback(x, d, a):
    points.append(x.squeeze().tolist())
    directions.append(d.squeeze().tolist())
    alphas.append(a.squeeze().item())

In [118]:
pcg_solution, info = pcg(A, b, callback=callback)

In [119]:
solution = torch.as_tensor(solve(A, b))

In [120]:
torch.allclose(solution, pcg_solution)

True

In [121]:
solution - pcg_solution

tensor([[-3.7253e-09],
        [ 0.0000e+00]])

### Visualization

In [122]:
%matplotlib widget
draw_surface(A, b, pcg_solution, points, directions, alphas)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [22]:
draw_surface_3d(A, b)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Speed Comparison

In [7]:
import timeit

In [8]:
pcg_times = []
solve_times = []

In [9]:
max_size=1000

In [10]:
for size in range(2, max_size):
    setup = f"import torch\n\
from hessianfree.utils import generate_pd_matrix\n\
from hessianfree.cg import pcg\n\
A = generate_pd_matrix({size})\n\
b = torch.rand({size}, 1)"
    pcg_times.append(timeit.Timer(stmt="pcg(A, b)", setup=setup).repeat(7, 100))
    solve_times.append(timeit.Timer(stmt="torch.solve(b, A)", setup=setup).repeat(7, 100))

In [85]:
times = {"pcg": pcg_times, "solve": solve_times}

In [93]:
data = [
    {
        "time": time,
        "type": t,
        "run": r,
        "size": s,
    } for t in ["pcg", "solve"] for s in range(18) for r, time in enumerate(times[t][s])
]

In [95]:
data = pd.DataFrame(data)

In [96]:
data.head(2)

Unnamed: 0,time,type,run,size
0,0.035391,pcg,0,0
1,0.038253,pcg,1,0


In [101]:
%matplotlib widget
sns.set_style("darkgrid")
sns.catplot(x="size", y="time", data=data, hue="type", kind="swarm")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<seaborn.axisgrid.FacetGrid at 0x13c2f1110>

In [57]:
setup = """
import torch
from hessianfree.utils import generate_pd_matrix
from hessianfree.cg import pcg
A = generate_pd_matrix(9)
b = torch.rand(9, 1)
"""
print(min(timeit.Timer(stmt="pcg(A, b)", setup=setup).repeat(10, 100)))
print(min(timeit.Timer(stmt="torch.solve(b, A)", setup=setup).repeat(10, 100)))

0.05458921800004646
0.0013787340001272241


In [103]:
import torch
from hessianfree.utils import generate_pd_matrix
from hessianfree.cg import pcg
A = generate_pd_matrix(1000)
b = torch.rand(1000, 1)

pcg(A, b)

(tensor([[-7.6950e-01],
         [-1.2026e-01],
         [ 7.1309e-02],
         [-4.7446e-01],
         [-1.1632e+00],
         [-4.3829e-01],
         [ 6.2798e-01],
         [-4.3327e-01],
         [ 1.2831e+00],
         [-8.8527e-02],
         [-3.8648e-01],
         [ 1.7995e-01],
         [ 6.2506e-01],
         [-6.4623e-01],
         [-2.3241e-02],
         [ 3.9919e-01],
         [-5.2401e-01],
         [ 7.8955e-01],
         [-6.6183e-01],
         [ 3.6323e-01],
         [-3.3732e-01],
         [-4.8923e-01],
         [-1.5013e-01],
         [-1.1909e+00],
         [ 1.0442e+00],
         [ 1.0128e+00],
         [-1.0990e+00],
         [ 1.2023e+00],
         [-3.1026e-01],
         [-2.3711e-01],
         [ 1.0081e-01],
         [-5.6993e-01],
         [-6.2265e-01],
         [-5.8564e-01],
         [-8.9926e-01],
         [ 6.3891e-01],
         [-8.2698e-01],
         [ 1.7494e-01],
         [ 3.4436e-01],
         [-1.1335e-01],
         [-1.8798e-01],
         [-7.345