-
Notifications
You must be signed in to change notification settings - Fork 0
/
u.py
99 lines (80 loc) · 2.78 KB
/
u.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# lsqfitgp/examples/u.py
#
# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
"""
EXAMPLE U.
Where we infer the temporal scale of a process assuming
another process is correlated with its derivative.
"""
import lsqfitgp as lgp
from matplotlib import pyplot as plt
import numpy as np
import gvar
# Derivative order (w.r.t. 'time') of the quantity that the data observe.
data_deriv = 1

# Observation points: 10 equispaced times, all carrying label 1.
time = np.linspace(-5, 5, 10)
x = np.empty(len(time), dtype=[('time', float), ('label', int)])
x['time'] = time
x['label'] = 1

# Synthetic observations: cos(time) plus Gaussian noise with a known
# standard deviation, which is also used as the data uncertainty.
data_error = 0.05
data_mean = np.cos(time) + data_error * np.random.randn(*time.shape)
data = gvar.gvar(data_mean, np.full(data_mean.shape, data_error))

# Fixed length scale of the kernel along the 'label' field; print the
# kernel value between label 0 and label 1 (the cross-label correlation).
label_scale = 5
corr = lgp.ExpQuad(scale=label_scale)(0, 1)
print(f'corr = {corr:.3g}')
def makegp(params):
    """
    Build the Gaussian process for a given set of hyperparameters.

    The kernel is the product of an ExpQuad on the 'time' field, with scale
    ``params['time_scale']``, and an ExpQuad on the 'label' field, with the
    fixed scale ``label_scale``.

    Two sets of points are added:

    - ``'data'``: the observation points ``x``, differentiated
      ``data_deriv`` times with respect to 'time';
    - ``'fixed_point'``: a single point at time 0 with label 0.

    Parameters
    ----------
    params : mapping
        Must provide the key ``'time_scale'``.

    Returns
    -------
    The ``lgp.GP`` object with both point sets added.
    """
    kernel = (
        lgp.ExpQuad(scale=params['time_scale'], dim='time')
        * lgp.ExpQuad(scale=label_scale, dim='label')
    )
    gp = lgp.GP(kernel)
    gp = gp.addx(x, 'data', deriv=(data_deriv, 'time'))
    origin = np.array([(0, 0)], dtype=x.dtype)
    gp = gp.addx(origin, 'fixed_point')
    return gp
# Prior on the time scale, specified on its logarithm; the fitted
# parameters are then read back under the key 'time_scale'.
prior = {'log(time_scale)': gvar.log(gvar.gvar(3, 2))}

# Observed values for each point set added in `makegp`; the fixed point
# gets value 0 with a very large (1e2) uncertainty.
datadict = {
    'data': data,
    'fixed_point': [gvar.gvar(0, 1e2)],
}

# Empirical Bayes fit of the hyperparameters (raises=False is passed
# through to lgp.empbayes_fit unchanged).
params = lgp.empbayes_fit(
    prior, makegp, datadict, raises=False,
).p
print('time_scale:', params['time_scale'])
# Prediction grid, wider than the data range, evaluated on both labels:
# row 0 of `xpred` has label 0, row 1 has label 1 (the label of the data).
time_pred = np.linspace(-10, 10, 100)
xpred = np.empty((2, len(time_pred)), dtype=x.dtype)
xpred['time'] = time_pred
xpred['label'][0] = 0
xpred['label'][1] = 1

# Derivative order predicted under each key (keys coincide with labels).
# Label 1 is predicted with the same derivative order as the data, so the
# data can be overplotted on it for any value of `data_deriv`.
pred_deriv = {0: 0, 1: data_deriv}

gp = (makegp(gvar.mean(params))
    .addx(xpred[0], 0)
    .addx(xpred[1], 1, deriv=(pred_deriv[1], 'time'))
)
pred = gp.predfromdata(datadict, [0, 1])

fig, ax = plt.subplots(num='u', clear=True)

# Shaded mean +/- 1 sdev band for each prediction key; remember the band
# color so samples and data for the same key are drawn in matching color.
colors = dict()
for key in pred:
    m = gvar.mean(pred[key])
    s = gvar.sdev(pred[key])
    polys = ax.fill_between(
        time_pred, m - s, m + s, alpha=0.5,
        label=f'deriv {pred_deriv[key]}',
    )
    colors[key] = polys.get_facecolor()[0]

# Overlay three random samples of the predictions drawn with gvar.raniter.
for sample in gvar.raniter(pred, 3):
    for key in pred:
        ax.plot(time_pred, sample[key], color=colors[key])

# The data live on label 1, i.e. under prediction key 1.
ax.errorbar(
    time, gvar.mean(data), yerr=gvar.sdev(data),
    fmt='.', color=colors[1], alpha=1, label='data',
)
ax.legend(loc='best')
ax.set_xlabel('time')
fig.show()