-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpsgd_with_finite_precision_arithmetic.py
83 lines (76 loc) · 3.03 KB
/
psgd_with_finite_precision_arithmetic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
This example studies preconditioner fitting with finite-precision arithmetic.
The PSGD-Kron-QEP version is numerically stable in single precision.
However, the PSGD-Kron-EQ version is less stable because it needs a triangular solver to update Q.
"""
import sys
import matplotlib.pyplot as plt
import torch
import opt_einsum
sys.path.append("..")
from psgd import (init_kron,
precond_grad_kron_whiten_qep,
update_precond_kron_newton_qep,
precond_grad_kron)
# Experiment size: number of fitting steps per run, and the side length N of
# each of the five matrix factors (the probed tensors have N**5 elements).
num_iterations = 2000
N = 10

# Draw the five random N x N factors in double precision, scaled by 1/sqrt(N).
# The generator is consumed in order, so the draws happen as H1, H2, ..., H5.
H1, H2, H3, H4, H5 = (
    torch.randn(N, N, dtype=torch.float64) / N**0.5 for _ in range(5)
)

# Report the squared product of the factors' condition numbers. This is done
# before the factors are replaced by F @ F.T below, hence the squaring.
cond_product = torch.linalg.cond(H1)
for factor in (H2, H3, H4, H5):
    cond_product = cond_product * torch.linalg.cond(factor)
print(f"Cond(H): {cond_product**2}")

# Replace every factor F by the symmetric matrix F @ F.T.
H1, H2, H3, H4, H5 = (F @ F.T for F in (H1, H2, H3, H4, H5))
"""
Preconditioner fitting with Hessian-vector products (hvp)
"""
# Open a dedicated figure for this experiment. Without it, pyplot draws on
# whatever figure happens to be current, so on backends where plt.show()
# returns without closing the window, curves of separate experiments would
# pile onto the same axes.
plt.figure()
legends = []
for dtype in [torch.float64, torch.float32]:
    # Cast into loop-local copies instead of rebinding H1..H5: overwriting the
    # float64 masters with float32 copies would leave any later float64 run
    # working on matrices that were already rounded to single precision.
    factors = [H.to(dtype) for H in (H1, H2, H3, H4, H5)]
    # Tensor handed to init_kron; V is redrawn on every iteration below.
    V = torch.randn(N, N, N, N, N, dtype=dtype)
    QL, exprs = init_kron(V, 1.0, float("inf"), float("inf"))

    errs = []
    for _ in range(num_iterations):
        # Random probe V and its image G: mode k of V is contracted with the
        # k-th factor.
        V = torch.randn(N, N, N, N, N, dtype=dtype)
        G = opt_einsum.contract("aA,bB,cC,dD,eE, ABCDE->abcde", *factors, V)
        # NOTE(review): the return value is unused, so QL is presumably
        # updated in place -- confirm against psgd.
        update_precond_kron_newton_qep(QL, exprs, V, G, lr=0.01)
        precond_grad = precond_grad_kron(QL, exprs, G)
        # Mean squared deviation of the preconditioned G from V; a perfectly
        # fitted preconditioner would map G back to V.
        errs.append(torch.mean((precond_grad - V)**2).item())
    plt.semilogy(errs)
    legends.append(str(dtype)[6:])  # drop the leading "torch."
plt.ylabel(r"$\|Pg-H^{-1}g\|^2/N^5$")
plt.xlabel("Iterations")
plt.legend(legends)
plt.tick_params(axis='y', which='both', labelleft=False, labelright=True)
plt.title("Preconditioner fitting with hvps")
plt.savefig("psgd_hvp_with_fpa.svg")
plt.show()
"""
Preconditioner fitting with gradients (gradient whitening)
"""
# Open a dedicated figure for this experiment. On backends where plt.show()
# returns without closing the window (e.g. non-interactive ones), the
# previously used figure is still current; plotting without a new figure
# would add these curves, the legend and the title to that old axes and the
# saved SVG would mix unrelated results.
plt.figure()
legends = []
for dtype in [torch.float64, torch.float32]:
    # Cast into loop-local copies so the module-level factors H1..H5 are not
    # overwritten by lower-precision versions of themselves.
    factors = [H.to(dtype) for H in (H1, H2, H3, H4, H5)]
    # Tensor handed to init_kron; V is redrawn on every iteration below.
    V = torch.randn(N, N, N, N, N, dtype=dtype)
    QL, exprs = init_kron(V, 1.0, float("inf"), float("inf"))

    errs = []
    for _ in range(num_iterations):
        # Random probe V and its image G: mode k of V is contracted with the
        # k-th factor.
        V = torch.randn(N, N, N, N, N, dtype=dtype)
        G = opt_einsum.contract("aA,bB,cC,dD,eE, ABCDE->abcde", *factors, V)
        # NOTE(review): only G is passed in and no separate update call is
        # made, so this call presumably both updates QL and returns the
        # preconditioned gradient -- confirm against psgd.
        precond_grad = precond_grad_kron_whiten_qep(QL, exprs, G, lr=0.01)
        # Mean squared deviation of the preconditioned G from V.
        errs.append(torch.mean((precond_grad - V)**2).item())
    plt.semilogy(errs)
    legends.append(str(dtype)[6:])  # drop the leading "torch."
plt.ylabel(r"$\|Pg-H^{-1}g\|^2/N^5$")
plt.xlabel("Iterations")
plt.legend(legends)
plt.tick_params(axis='y', which='both', labelleft=False, labelright=True)
plt.title("Preconditioner fitting with gradients")
plt.savefig("psgd_whitening_with_fpa.svg")
plt.show()