-
Notifications
You must be signed in to change notification settings - Fork 0
/
w.py
93 lines (72 loc) · 2.46 KB
/
w.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# lsqfitgp/examples/w.py
#
# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
"""
EXAMPLE W.
Manually split a process as a sum of two processes, using an additional
index in the input space.
"""
import numpy as np
from matplotlib import pyplot as plt
import gvar
import lsqfitgp as lgp
# Observation times: the integers 0, 1, ..., 29.
time = np.arange(0, 30)
# Prediction grid: 200 evenly spaced points from -30 to 60, extending beyond
# the observed range on both sides.
time_pred = np.linspace(-30, 60, num=200)
def makex(time, comp):
    """
    Build the structured input array for one component of the process.

    Parameters
    ----------
    time : array_like
        Sequence of time points; its length sets the size of the output.
    comp : str
        Component label written into the 'comp' field of every entry. The
        field has dtype 'U8', so labels longer than 8 characters are
        silently truncated by numpy.

    Returns
    -------
    x : numpy structured array
        Array of length ``len(time)`` with fields 'time' (float) and
        'comp' (unicode, up to 8 characters).
    """
    record_dtype = np.dtype([('time', float), ('comp', 'U8')])
    out = np.empty(len(time), dtype=record_dtype)
    out['time'] = time
    out['comp'] = comp
    return out
# Two squared-exponential kernels acting on the 'time' field: a fast-varying
# one (scale 1) for the "short" component and a slow one (scale 10) for the
# "long" component.
kshort = lgp.ExpQuad(scale=1, dim='time')
klong = lgp.ExpQuad(scale=10, dim='time')


def _is_short(x):
    """Return True where the 'comp' field of the input equals 'short'."""
    return x['comp'] == 'short'


# Combine the two kernels with the 'cond' operation, switching on _is_short.
# NOTE(review): presumably points of different components are treated as
# independent (zero cross-covariance); confirm against the lsqfitgp docs.
kernel = kshort.linop('cond', klong, _is_short)
gp = lgp.GP(kernel)
def addcomps(gp, key, time):
    """
    Register both components at the given times, plus their weighted sum.

    Parameters
    ----------
    gp : lgp.GP
        The Gaussian process object to extend.
    key : str
        Base name. The components are added under ``key + 'short'`` and
        ``key + 'long'``, and their combination under ``key``.
    time : array_like
        Time points at which both components are evaluated.

    Returns
    -------
    gp : lgp.GP
        The object returned by the final ``addtransf`` call of the
        ``addx``/``addx``/``addtransf`` chain.
    """
    shortkey = key + 'short'
    longkey = key + 'long'
    gp = gp.addx(makex(time, 'short'), shortkey)
    gp = gp.addx(makex(time, 'long'), longkey)
    # The combined process is 0.3 * short + 1 * long.
    return gp.addtransf({shortkey: 0.3, longkey: 1}, key)
# Register the observed ('data*') and predicted ('pred*') points.
gp = addcomps(gp, 'data', time)
gp = addcomps(gp, 'pred', time_pred)

print('generate data...')
# Draw one joint sample of the sum and of both components from the prior.
# Only the sum is used as data; the components are kept to compare against
# the reconstruction.
prior = gp.prior(['data', 'datashort', 'datalong'])
data = gvar.sample(prior)

print('prediction...')
# Condition on the sum alone, then predict the sum and each component.
pred = gp.predfromdata({'data': data['data']}, ['pred', 'predshort', 'predlong'])

print('sample posterior...')
mean = gvar.mean(pred)
sdev = gvar.sdev(pred)
samples = list(gvar.raniter(pred, 1))

print('figure...')
fig, axs = plt.subplots(3, 1, num='w', clear=True, figsize=[6, 7], layout='constrained')
for ax, comp in zip(axs, ['', 'short', 'long']):
    key = 'pred' + comp
    m = mean[key]
    s = sdev[key]
    # One-sigma posterior band.
    ax.fill_between(time_pred, m - s, m + s, alpha=0.3, color='b')
    for sample in samples:
        ax.plot(time_pred, sample[key], alpha=0.2, color='b')
    # The sampled "true" values at the observed times.
    ax.plot(time, data['data' + comp], '.k')

# The top panel shows the weighted sum built in addcomps (0.3 * short + long),
# not the plain sum.
axs[0].set_ylabel('0.3 A + B')
axs[1].set_ylabel('A')
axs[2].set_ylabel('B')
axs[-1].set_xlabel('time')

fig.show()