-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
79 lines (60 loc) · 1.76 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import autograd.numpy as np
import sys
import toml
from get_dataset import *
import GPy
# Command-line entry: the only positional argument is the path to a TOML config.
argv = sys.argv[1:]
if not argv:
    # Fail with a readable message instead of an IndexError traceback.
    sys.exit('usage: %s CONFIG.toml' % sys.argv[0])
conf = toml.load(argv[0])
### Load data
# Required config keys: 'funct', 'num', 'bounds', 'bfgs_iter'
# (a missing key raises KeyError here).
name = conf['funct']            # identifier handed to get_funct()
funct = get_funct(name)
num = conf['num']               # passed to init_dataset(); presumably the training-set size - confirm
bounds = np.array(conf['bounds'])
dim = bounds.shape[0]           # first axis of bounds is used as the GP input dimension
bfgs_iter = conf['bfgs_iter']   # NOTE(review): read but not used anywhere below in this file
### GP training
# Global model: a single exact GP regression fitted on the whole initial dataset.
# init_dataset also returns the per-cluster samples used by the committee later.
dataset1, samples = init_dataset(funct, num, bounds)
train_x1, train_y1 = dataset1['train_x'].T, dataset1['train_y'].T
# Data-driven starting point for the optimizer: signal variance from the
# targets, one shared (isotropic) lengthscale from the spread of all inputs,
# and a noise variance of 1% of the target variance.
k = GPy.kern.RBF(dim,
                 variance=np.var(train_y1),
                 lengthscale=np.std(train_x1))
model1 = GPy.models.GPRegression(X=train_x1, Y=train_y1, kernel=k,
                                 noise_var=0.01 * np.var(train_y1))
model1.optimize()
### Test
# Evaluation set of nn points, shared by the global model and the committee.
nn = 200
testdata = get_test(funct, nn, bounds)
X_star, y_star = testdata['test_x'], testdata['test_y']
# GPy expects one input point per row, hence the transpose.
y_pred1, y_var1 = model1.predict(X_star.T)
### BCM
# Committee of local GP experts, one per cluster in `samples`. Each expert is
# fitted independently, and their predictions at X_star are merged by
# precision weighting:
#     1 / var = sum_i 1 / var_i
#     mean    = var * sum_i mean_i / var_i
# NOTE(review): the Bayesian Committee Machine (Tresp, 2000) also adds a
# (1 - n_clus) / prior_variance term to the precision sum. Without it this is
# the plain product-of-experts rule - confirm which combination is intended.
samples_x = samples['samples_x']
samples_y = samples['samples_y']
n_clus = samples_x.shape[0]
if n_clus == 0:
    # With no experts the precision sum stays 0 and the result would be inf/nan.
    sys.exit('no cluster samples available for the committee')
# Size the accumulators from the test inputs actually returned by get_test.
# X_star is transposed before predict(), so its columns are the test points.
n_test = X_star.shape[1]
sum_inv = np.zeros((n_test, 1))   # running sum of expert precisions 1/var_i
sum_rat = np.zeros((n_test, 1))   # running sum of precision-weighted means
for i in range(n_clus):
    train_x2 = samples_x[i].T
    train_y2 = samples_y[i].reshape(-1, 1)
    k = GPy.kern.RBF(dim)
    model = GPy.models.GPRegression(X=train_x2, Y=train_y2, kernel=k)
    # Same data-driven initialisation as the global model.
    model.kern.variance = np.var(train_y2)
    model.kern.lengthscale = np.std(train_x2)
    model.likelihood.variance = 0.01 * np.var(train_y2)
    model.optimize()
    y_pred, y_var = model.predict(X_star.T)
    sum_inv = sum_inv + 1.0 / y_var
    sum_rat = sum_rat + y_pred / y_var
y_var2 = 1.0 / sum_inv
y_pred2 = y_var2 * sum_rat
### Accuracy
def _rmse(pred, target):
    """Return the root-mean-square error between two same-length arrays.

    Both arguments are flattened first, so a (n, 1) prediction column can be
    compared with a target stored as (1, n), (n,) or (n, 1) without
    accidentally broadcasting into an (n, n) matrix.
    """
    diff = np.ravel(pred) - np.ravel(target)
    # Average over test points; a plain np.linalg.norm(diff) would be
    # sqrt(sum of squares), which grows with the number of test points.
    return np.sqrt(np.mean(diff ** 2))


rmse1 = _rmse(y_pred1, y_star)   # global GP
rmse2 = _rmse(y_pred2, y_star)   # committee of per-cluster experts
print('rmse1',rmse1)
print('rmse2',rmse2)