In [1]:
from typing import Tuple
import numpy as np
import deepgp
import matplotlib.pyplot as plt

import GPy
from GPy.models import GPRegression
from emukit.test_functions import forrester_function
from emukit.core.initial_designs import RandomDesign
from emukit.model_wrappers import GPyModelWrapper
from emukit.bayesian_optimization.acquisitions import ExpectedImprovement, NegativeLowerConfidenceBound, ProbabilityOfImprovement
from emukit.core.optimization import GradientAcquisitionOptimizer
from emukit.core.initial_designs import RandomDesign
from emukit.core import ParameterSpace, ContinuousParameter
from emukit.sensitivity.monte_carlo import MonteCarloSensitivity
from gpflow.kernels import RBF, White, Linear
from tqdm import tqdm

from simulator import MainSimulator, TinySimulator
from world import DebugInfo
from pprint import pprint

# Notebook-level simulator instance; the target functions in this notebook call
# main_simulator.run(...) on it.
main_simulator = MainSimulator()

In [2]:
def target_function_list(X):
    """Run the simulator once per row of ``X`` and collect the survival times.

    Parameters
    ----------
    X : iterable of sequences
        Each row supplies the mutation rates in the order
        (size, speed, vision, aggression).

    Returns
    -------
    numpy.ndarray
        Column vector of shape ``(len(X), 1)``; entry ``k`` is the
        ``days_survived`` value returned by the simulator for row ``k``.
    """
    # Collect results in a list and convert once, instead of growing an array
    # with np.append from a dummy first row that has to be sliced off again.
    survived = []
    for x in X:
        mutation_rates = {
            "size": x[0],
            "speed": x[1],
            "vision": x[2],
            "aggression": x[3]
        }
        days_survived, _log = main_simulator.run(mutation_rates, debug_info=DebugInfo(
            period=10, should_display_day=True, should_display_grid=False, should_display_traits=False), max_days=10000)
        survived.append(days_survived)
    # reshape(-1, 1) keeps the (n, 1) layout, including (0, 1) for empty input.
    return np.array(survived).reshape(-1, 1)

In [3]:
# Sanity Checks with experimental parameters
# All four mutation rates are zero here; the simulation is repeated ten times and
# the days_survived value of every repetition is collected in days_log.

mutation_rates = {
    "size": 0,
    "speed": 0,
    "vision": 0,
    "aggression": 0
}

days_log = []
for i in tqdm(range(10)):
    # A new simulator is created for every repetition. This also rebinds the
    # notebook-level `main_simulator` name that the target functions read.
    main_simulator = MainSimulator()
    days_survived, log = main_simulator.run(mutation_rates, debug_info=DebugInfo(
        period=1, should_display_day=True, should_display_grid=False, should_display_traits=False, should_display_population=True), max_days=10000)
    days_log.append(days_survived)
    print(days_survived)
# Bare expression: the collected survival times become the cell's output.
days_log

  0%|                                                                                                                                                                                                                                                                                                | 0/10 [00:00<?, ?it/s]

0.9920066674924358
Day number: 1
Population: 595
0.9685294927600133
Day number: 2
Population: 460
0.9310165005141857
Day number: 3
Population: 549
0.881690868822638
Day number: 4
Population: 666
0.8233021939280022
Day number: 5
Population: 848
0.7588418252611829
Day number: 6
Population: 1057
0.69126675303134
Day number: 7
Population: 1342
0.6232689069636975
Day number: 8
Population: 1668
0.5571129681135137
Day number: 9
Population: 2048
0.4945501940025
Day number: 10
Population: 2471
0.4368021850861846
Day number: 11
Population: 2869
0.3845994180915832
Day number: 12
Population: 3133
0.33825543833697597
Day number: 13
Population: 3257
0.2977581124277974
Day number: 14
Population: 3222
0.262862784101325
Day number: 15
Population: 3041
0.23317690496603324
Day number: 16
Population: 2782
0.20823040227949466
Day number: 17
Population: 2457
0.18752988270433033
Day number: 18
Population: 2106
0.17059743207865904
Day number: 19
Population: 1700
0.15699629283049923
Day number: 20
Population: 

 10%|████████████████████████████                                                                                                                                                                                                                                                            | 1/10 [00:29<04:24, 29.34s/it]

Day number: 1215
Population: 615
0.009851893665381424
Day number: 1216
Population: 600
0.00806677103305752
Day number: 1217
Population: 591
0.0067228388394057845
Day number: 1218
Population: 592
0.0057136930489427
Day number: 1219
Population: 570
0.0049607302805756434
Day number: 1220
Population: 532
0.004406529417728263
Day number: 1221
Population: 508
0.004009806337693135
Day number: 1222
Population: 464
0.0037416823011836523
Day number: 1223
Population: 416
0.003583026650524563
Day number: 1224
Population: 369
0.0035226774101574277
Day number: 1225
Population: 328
0.00355639542423669
Day number: 1226
Population: 296
0.0036864618262427945
Day number: 1227
Population: 249
0.00392188226380314
Day number: 1228
Population: 216
0.004279214128467565
Day number: 1229
Population: 184
0.0047840856400907525
Day number: 1230
Population: 151
0.005473528078636363
Day number: 1231
Population: 134
0.006399292882794254
Day number: 1232
Population: 117
0.00763236822645264
Day number: 1233
Population:

 20%|████████████████████████████████████████████████████████                                                                                                                                                                                                                                | 2/10 [00:53<03:30, 26.25s/it]

0.024841540557264487
Day number: 818
Population: 246
0.0215967631057247
Day number: 819
Population: 179
0.019118189995404045
Day number: 820
Population: 142
0.01725751568157285
Day number: 821
Population: 122
0.015904125265057866
Day number: 822
Population: 102
0.014978266390611922
Day number: 823
Population: 80
0.014425850321230956
Day number: 824
Population: 69
0.014214791206736997
Day number: 825
Population: 55
0.014332777497321113
Day number: 826
Population: 44
0.014786393084743278
Day number: 827
Population: 40
0.015601549741935046
Day number: 828
Population: 32
0.01682524329798734
Day number: 829
Population: 23
0.018528692252419238
Day number: 830
Population: 17
0.02081194676235059
Day number: 831
Population: 15
0.02381005189709621
Day number: 832
Population: 9
0.027700789334653794
Day number: 833
Population: 6
0.03271387642074647
Day number: 834
Population: 6
0.039141234562344626
Day number: 835
Population: 4
0.0473475132859795
Day number: 836
Population: 1
0.05777944632065562
D

 30%|████████████████████████████████████████████████████████████████████████████████████                                                                                                                                                                                                    | 3/10 [01:20<03:04, 26.42s/it]

Population: 148
0.015778409209154586
Day number: 919
Population: 127
0.01389997355632811
Day number: 920
Population: 104
0.01249684451457387
Day number: 921
Population: 81
0.011480364594640313
Day number: 922
Population: 57
0.010787116784275245
Day number: 923
Population: 44
0.010374354075609113
Day number: 924
Population: 38
0.010216827779310102
Day number: 925
Population: 34
0.010304876001688584
Day number: 926
Population: 31
0.01064367515039901
Day number: 927
Population: 34
0.011253611893326979
Day number: 928
Population: 31
0.012171791652495668
Day number: 929
Population: 23
0.013454755271869241
Day number: 930
Population: 17
0.015182519853345242
Day number: 931
Population: 11
0.017464081271856283
Day number: 932
Population: 7
0.020444496882137418
Day number: 933
Population: 4
0.024313581724320656
Day number: 934
Population: 1
0.029316065904153078
Day number: 935
Population: 2
0.03576273464019662
Day number: 936
Population: 3
0.04404156782720593
Day number: 937
Population: 2
0.054

 40%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                                                                        | 4/10 [01:48<02:42, 27.05s/it]

1038
0.9920066674924358
Day number: 1
Population: 584
0.9685294927600133
Day number: 2
Population: 426
0.9310165005141857
Day number: 3
Population: 537
0.881690868822638
Day number: 4
Population: 660
0.8233021939280022
Day number: 5
Population: 845
0.7588418252611829
Day number: 6
Population: 1016
0.69126675303134
Day number: 7
Population: 1253
0.6232689069636975
Day number: 8
Population: 1536
0.5571129681135137
Day number: 9
Population: 1863
0.4945501940025
Day number: 10
Population: 2238
0.4368021850861846
Day number: 11
Population: 2615
0.3845994180915832
Day number: 12
Population: 2956
0.33825543833697597
Day number: 13
Population: 3195
0.2977581124277974
Day number: 14
Population: 3276
0.262862784101325
Day number: 15
Population: 3169
0.23317690496603324
Day number: 16
Population: 2902
0.20823040227949466
Day number: 17
Population: 2591
0.18752988270433033
Day number: 18
Population: 2246
0.17059743207865904
Day number: 19
Population: 1883
0.15699629283049923
Day number: 20
Populat

 50%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                                            | 5/10 [02:19<02:23, 28.79s/it]

Day number: 1223
Population: 486
0.003583026650524563
Day number: 1224
Population: 459
0.0035226774101574277
Day number: 1225
Population: 418
0.00355639542423669
Day number: 1226
Population: 372
0.0036864618262427945
Day number: 1227
Population: 325
0.00392188226380314
Day number: 1228
Population: 253
0.004279214128467565
Day number: 1229
Population: 193
0.0047840856400907525
Day number: 1230
Population: 143
0.005473528078636363
Day number: 1231
Population: 94
0.006399292882794254
Day number: 1232
Population: 56
0.00763236822645264
Day number: 1233
Population: 29
0.009268933716345878
Day number: 1234
Population: 15
0.011437977465564714
Day number: 1235
Population: 5
0.014310717069861013
Day number: 1236
Population: 2
0.018111774176149672
Day number: 1237
Population: 0
1237
0.9920066674924358
Day number: 1
Population: 589
0.9685294927600133
Day number: 2
Population: 434
0.9310165005141857
Day number: 3
Population: 545
0.881690868822638
Day number: 4
Population: 647
0.8233021939280022
Da

 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                | 6/10 [02:48<01:55, 28.80s/it]

1033
0.9920066674924358
Day number: 1
Population: 578
0.9685294927600133
Day number: 2
Population: 419
0.9310165005141857
Day number: 3
Population: 555
0.881690868822638
Day number: 4
Population: 704
0.8233021939280022
Day number: 5
Population: 913
0.7588418252611829
Day number: 6
Population: 1116
0.69126675303134
Day number: 7
Population: 1369
0.6232689069636975
Day number: 8
Population: 1713
0.5571129681135137
Day number: 9
Population: 2068
0.4945501940025
Day number: 10
Population: 2462
0.4368021850861846
Day number: 11
Population: 2785
0.3845994180915832
Day number: 12
Population: 3030
0.33825543833697597
Day number: 13
Population: 3178
0.2977581124277974
Day number: 14
Population: 3142
0.262862784101325
Day number: 15
Population: 2928
0.23317690496603324
Day number: 16
Population: 2673
0.20823040227949466
Day number: 17
Population: 2373
0.18752988270433033
Day number: 18
Population: 2034
0.17059743207865904
Day number: 19
Population: 1697
0.15699629283049923
Day number: 20
Populat

 70%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                    | 7/10 [03:20<01:29, 29.86s/it]

Day number: 1117
Population: 87
0.009495146001897556
Day number: 1118
Population: 66
0.008115725243959562
Day number: 1119
Population: 46
0.007080491320441759
Day number: 1120
Population: 26
0.006314799485675519
Day number: 1121
Population: 15
0.005764511179474178
Day number: 1122
Population: 13
0.0053914807660834846
Day number: 1123
Population: 8
0.005170295612795855
Day number: 1124
Population: 6
0.005086069231012701
Day number: 1125
Population: 4
0.005133134125357113
Day number: 1126
Population: 4
0.00531453605895018
Day number: 1127
Population: 4
0.005642289158773332
Day number: 1128
Population: 3
0.006138409293775964
Day number: 1129
Population: 3
0.0068368002515461915
Day number: 1130
Population: 4
0.00778612155819533
Day number: 1131
Population: 4
0.00905381396426725
Day number: 1132
Population: 4
0.010731489690756582
Day number: 1133
Population: 3
0.012941893598003238
Day number: 1134
Population: 2
0.01584758346550148
Day number: 1135
Population: 1
0.019661327201852676
Day numb

 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                        | 8/10 [03:51<01:00, 30.01s/it]

1034
0.9920066674924358
Day number: 1
Population: 590
0.9685294927600133
Day number: 2
Population: 433
0.9310165005141857
Day number: 3
Population: 520
0.881690868822638
Day number: 4
Population: 615
0.8233021939280022
Day number: 5
Population: 802
0.7588418252611829
Day number: 6
Population: 975
0.69126675303134
Day number: 7
Population: 1201
0.6232689069636975
Day number: 8
Population: 1476
0.5571129681135137
Day number: 9
Population: 1788
0.4945501940025
Day number: 10
Population: 2138
0.4368021850861846
Day number: 11
Population: 2506
0.3845994180915832
Day number: 12
Population: 2829
0.33825543833697597
Day number: 13
Population: 3047
0.2977581124277974
Day number: 14
Population: 3169
0.262862784101325
Day number: 15
Population: 3081
0.23317690496603324
Day number: 16
Population: 2859
0.20823040227949466
Day number: 17
Population: 2576
0.18752988270433033
Day number: 18
Population: 2191
0.17059743207865904
Day number: 19
Population: 1801
0.15699629283049923
Day number: 20
Populati

 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                            | 9/10 [04:19<00:29, 29.58s/it]

Day number: 920
Population: 390
0.01249684451457387
Day number: 921
Population: 331
0.011480364594640313
Day number: 922
Population: 289
0.010787116784275245
Day number: 923
Population: 252
0.010374354075609113
Day number: 924
Population: 223
0.010216827779310102
Day number: 925
Population: 201
0.010304876001688584
Day number: 926
Population: 190
0.01064367515039901
Day number: 927
Population: 154
0.011253611893326979
Day number: 928
Population: 128
0.012171791652495668
Day number: 929
Population: 102
0.013454755271869241
Day number: 930
Population: 81
0.015182519853345242
Day number: 931
Population: 64
0.017464081271856283
Day number: 932
Population: 51
0.020444496882137418
Day number: 933
Population: 38
0.024313581724320656
Day number: 934
Population: 25
0.029316065904153078
Day number: 935
Population: 14
0.03576273464019662
Day number: 936
Population: 14
0.04404156782720593
Day number: 937
Population: 10
0.054627193398119195
Day number: 938
Population: 7
0.06808609310166014
Day numb

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [04:52<00:00, 29.21s/it]

Day number: 1221
Population: 378
0.004009806337693135
Day number: 1222
Population: 331
0.0037416823011836523
Day number: 1223
Population: 277
0.003583026650524563
Day number: 1224
Population: 230
0.0035226774101574277
Day number: 1225
Population: 195
0.00355639542423669
Day number: 1226
Population: 153
0.0036864618262427945
Day number: 1227
Population: 123
0.00392188226380314
Day number: 1228
Population: 88
0.004279214128467565
Day number: 1229
Population: 63
0.0047840856400907525
Day number: 1230
Population: 47
0.005473528078636363
Day number: 1231
Population: 26
0.006399292882794254
Day number: 1232
Population: 17
0.00763236822645264
Day number: 1233
Population: 8
0.009268933716345878
Day number: 1234
Population: 6
0.011437977465564714
Day number: 1235
Population: 3
0.014310717069861013
Day number: 1236
Population: 2
0.018111774176149672
Day number: 1237
Population: 0
1237





[1241, 838, 947, 1038, 1237, 1033, 1138, 1034, 949, 1237]

In [28]:
# Search space: one continuous mutation rate per trait, each restricted to [0, 1].
space = ParameterSpace([
    ContinuousParameter(trait, 0, 1)
    for trait in ('size', 'speed', 'vision', 'aggression')
])

# Draw the initial design at random and evaluate the simulator on every point.
num_data_points = 100
design = RandomDesign(space)
X = design.get_samples(num_data_points)
Y = target_function_list(X)

11.14022935152764
12.275998204157476
13.402863662543043
14.516417968967385
15.612305898749053
16.686241948324202
17.73402724817131
18.751566133830874
19.73488230962194
20.68013454126452
Day number: 10
21.583631815476412
22.441847906716394
23.251435293585406
24.00923836996421
24.71230589874905
25.357902659036274
25.943520240789546
26.466886944388353
26.925976745988525
27.31901729331276
Day number: 20
27.64449690031536
27.901170513116394
28.0880646236606
28.204481111708887
28.25
28.22448111170889
28.128064623660602
27.961170513116397
27.72449690031536
27.419017293312766
Day number: 30
27.045976745988526
26.60688694438835
26.103520240789543
25.53790265903627
24.912305898749054
24.229238369964207
23.491435293585408
22.7018479067164
21.863631815476417
20.98013454126452
Day number: 40
20.054882309621945
19.091566133830874
18.09402724817131
17.06624194832421
16.012305898749055
14.936417968967394
13.842863662543044
12.735998204157482
11.620229351527644
10.500000000000002
Day number: 50
9.37977

In [33]:
# Put inputs and outputs side by side: the columns of X followed by the column(s) of Y.
results = np.concatenate((X, Y), axis=1)

In [34]:
# Persist the evaluated design as CSV, one row of `results` per line.
np.savetxt('testing_data.csv', results, delimiter=',', fmt='%s')

In [None]:
def plot_prediction(X,Y,x_plot,mu_plot,var_plot,axis):
    """Draw observations, the model mean and 1/2/3-sigma bands on ``axis``.

    ``x_plot``, ``mu_plot`` and ``var_plot`` are indexed as 2-D arrays; only
    their first column is used. ``var_plot`` is treated as a variance: its
    square root sets the width of the shaded bands.
    """
    grid = x_plot[:, 0]
    mean = mu_plot[:, 0]
    std = np.sqrt(var_plot)[:, 0]

    axis.plot(X, Y, "ro", markersize=10, label="Observations")
    axis.plot(grid, mean, "C0", label="Model")

    # Wider bands are drawn progressively more transparent.
    for width, opacity in ((1, 0.6), (2, 0.4), (3, 0.2)):
        axis.fill_between(grid,
                          mean + width * std,
                          mean - width * std, color="C0", alpha=opacity)

    axis.legend(loc=2, prop={'size': 10})
    axis.set(xlabel=r"$x$", ylabel=r"$f(x)$")
    axis.grid(True)

In [None]:
def plot_acquisition_functions(x_plot, ei_plot, nlcb_plot, pi_plot, x_new, axis):
    """Plot the EI, NLCB and PI acquisition curves, each rescaled to [0, 1].

    Parameters
    ----------
    x_plot : array-like
        Locations at which the acquisition functions were evaluated.
    ei_plot, nlcb_plot, pi_plot : array-like
        Acquisition values at ``x_plot``.
    x_new : scalar
        Location of the next query; marked with a dashed vertical line.
    axis : matplotlib Axes-like object
        Target to draw on.
    """
    def _rescale(values):
        # Min-max normalise so the three curves share one vertical scale.
        values = np.asarray(values, dtype=float)
        lowest = np.min(values)
        span = np.max(values) - lowest
        if span == 0:
            # A constant curve would otherwise become 0/0 = NaN and vanish
            # from the plot (with a runtime warning); draw it flat at zero.
            return np.zeros_like(values)
        return (values - lowest) / span

    axis.plot(x_plot, _rescale(ei_plot), "green", label="EI")
    axis.plot(x_plot, _rescale(nlcb_plot), "purple", label="NLCB")
    axis.plot(x_plot, _rescale(pi_plot), "darkorange", label="PI")

    axis.axvline(x_new, color="red", label="x_next", linestyle="--")
    axis.legend(loc=1, prop={'size': 10})
    axis.set(xlabel=r"$x$", ylabel=r"$f(x)$")
    axis.grid(True)

In [None]:
# Initial data and GP surrogate for the 1-D "speed" experiment.
# NOTE(review): target_speed_function is defined in a later cell of this notebook;
# that cell has to be run before this one.
x_plot = np.linspace(0, 10, 1000)[:, None]
X = np.array([[0],[5], [10]])

# One simulator run per initial design point. The original loop referenced an
# undefined name (Y_init) and appended a scalar to a 2-D array with axis=0,
# which NumPy rejects; build the (n, 1) target column directly instead.
# Each row of X holds a single value, passed to the simulator as a scalar.
Y = np.array([[target_speed_function(x[0])] for x in X], dtype=float)

speed_model = GPRegression(X, Y, GPy.kern.RBF(1, lengthscale=1, variance=100), noise_var=1)
emukit_speed_model = GPyModelWrapper(speed_model)

# Three acquisition functions over the same surrogate, for side-by-side comparison.
ei_acquisition = ExpectedImprovement(emukit_speed_model)
nlcb_acquisition = NegativeLowerConfidenceBound(emukit_speed_model)
pi_acquisition = ProbabilityOfImprovement(emukit_speed_model)

In [None]:

# Plot the initial GP fit. plot_prediction calls Axes methods such as .set(...),
# which the pyplot module itself does not offer, so draw on an explicit Axes
# rather than passing `plt`.
mu_plot, var_plot = emukit_speed_model.predict(x_plot)
fig, ax = plt.subplots()
plot_prediction(X,Y,x_plot,mu_plot,var_plot,ax)
plt.show()

In [None]:
# Bayesian-optimisation loop over the 1-D speed parameter: each iteration plots the
# current surrogate and acquisition curves, queries the simulator at the NLCB
# optimum and refits the model on the enlarged data set.
iterations = 20
figure, axis = plt.subplots(iterations, 2, figsize=(10, iterations*3))

# The search space never changes, so the optimizer is built once, not per iteration.
optimizer = GradientAcquisitionOptimizer(ParameterSpace([ContinuousParameter('x1', 0, 10)]))

for i in tqdm(range(iterations)):
    mu_plot, var_plot = emukit_speed_model.predict(x_plot)
    plot_prediction(X,Y,x_plot,mu_plot,var_plot,axis[i,0])

    ei_plot = ei_acquisition.evaluate(x_plot)
    nlcb_plot = nlcb_acquisition.evaluate(x_plot)
    pi_plot = pi_acquisition.evaluate(x_plot)

    x_new, _ = optimizer.optimize(nlcb_acquisition)
    print("Next position to query:", x_new)
    # x_new is indexed as a 2-D array (one row, one column); the marker line
    # needs the scalar position.
    plot_acquisition_functions(x_plot, ei_plot, nlcb_plot, pi_plot, x_new[0, 0], axis[i,1])

    # The target function returns a single value; wrap it as a (1, 1) array so it
    # can be appended to the (n, 1) target column along axis 0.
    y_new = np.array([[target_speed_function(x_new[0, 0])]], dtype=float)
    X = np.append(X, x_new, axis=0)
    Y = np.append(Y, y_new, axis=0)
    emukit_speed_model.set_data(X, Y)

plt.show()

In [None]:
# Two hand-written training points with four input features each,
# and one target value per point stored as a column.
X_train = np.array([[110, 200, 200, 400],
                    [300, 252, 300, 400]])
Y_train = np.array([100, 200]).reshape(-1, 1)

In [None]:
# DGP using deepgp library
# Builds a DeepGP model on X_train / Y_train and optimises it in two stages:
# first with every layer's noise variance fixed, then with it released.
Q = 5
# NOTE(review): despite its name, num_layers is used below as the middle entry of
# the dimension list given to DeepGP — presumably a hidden-layer width; confirm.
num_layers = 1
# NOTE(review): kern1 is created but not passed to DeepGP in this cell (only kern2 is).
kern1 = GPy.kern.RBF(Q,ARD=True) + GPy.kern.Bias(Q)
kern2 = GPy.kern.RBF(X_train.shape[1],ARD=True) + GPy.kern.Bias(X_train.shape[1])
num_inducing = 4 # Number of inducing points to use for sparsification
back_constraint = False # Whether to use back-constraint for variational posterior
encoder_dims=[[300],[150]] # Dimensions of the MLP back-constraint if set to true

# NOTE(review): X_train is passed positionally before Y_train, and the dimension
# list starts with the input width. TODO confirm this matches the argument order
# the deepgp.DeepGP constructor expects.
dgp_model = deepgp.DeepGP([X_train.shape[1], num_layers, Y_train.shape[1]], X_train, Y_train, kernels=[kern2,None], num_inducing=num_inducing, back_constraint=back_constraint, encoder_dims=encoder_dims)

# Stage 1: set each layer's Gaussian noise variance to 1% of the variance of that
# layer's outputs (layer 0 uses Y directly, other layers use Y.mean) and fix it.
for i in range(len(dgp_model.layers)):
    output_var = dgp_model.layers[i].Y.var() if i==0 else dgp_model.layers[i].Y.mean.var()
    dgp_model.layers[i].Gaussian_noise.variance = output_var*0.01
    dgp_model.layers[i].Gaussian_noise.variance.fix()

dgp_model.optimize(max_iters=800, messages=True)
# Stage 2: release the noise variances and continue optimising.
for i in range(len(dgp_model.layers)):
    dgp_model.layers[i].Gaussian_noise.variance.unfix()
dgp_model.optimize(max_iters=1200, messages=True)

In [None]:
# Show the trained DGP model object as the cell output.
display(dgp_model)

In [None]:
# x_plot: 1000 evenly spaced values from 0 to 1000, stored as a (1000, 1) column.
x_plot = np.linspace(0, 1000, 1000)[:, None]
# NOTE(review): stacking four (1000, 1) arrays along a new last axis yields shape
# (1000, 1, 4), not (1000, 4); x_new is also not used in this cell. If a 4-column
# input grid was intended, this needs a different construction — confirm.
x_new = np.stack((x_plot,x_plot,x_plot,x_plot),axis = -1)
# Query the DGP at the single point (10, 10, 10, 10).
Y_pred = dgp_model.predict(np.array([[10,10,10,10]]))

In [None]:
def _run_simulation(size=0, speed=0, vision=0, aggression=0):
    """Run the shared simulator with the given mutation rates.

    Every rate defaults to 0, so callers only name the trait(s) they vary.
    Returns the ``days_survived`` value reported by ``main_simulator.run``;
    the accompanying log is discarded.
    """
    mutation_rates = {
        "size": size,
        "speed": speed,
        "vision": vision,
        "aggression": aggression
    }
    days_survived, _log = main_simulator.run(mutation_rates, debug_info=DebugInfo(
        period=10, should_display_day=True, should_display_grid=False, should_display_traits=False), max_days=10000)
    return days_survived


def target_size_function(x):
    """Days survived when only the size mutation rate is ``x`` (others 0)."""
    return _run_simulation(size=x)


def target_speed_function(x):
    """Days survived when only the speed mutation rate is ``x`` (others 0)."""
    return _run_simulation(speed=x)


def target_vision_function(x):
    """Days survived when only the vision mutation rate is ``x`` (others 0)."""
    return _run_simulation(vision=x)


def target_aggression_function(x):
    """Days survived when only the aggression mutation rate is ``x`` (others 0)."""
    return _run_simulation(aggression=x)


def target_function(X):
    """Days survived for the rates ``X = (size, speed, vision, aggression)``."""
    return _run_simulation(size=X[0], speed=X[1], vision=X[2], aggression=X[3])

In [None]:
# Dense plotting grid on [0, 20], stored as a (1000, 1) column.
x_plot = np.linspace(0, 20, 1000).reshape(-1, 1)

# Every trait starts from the same three design points, but each trait gets its
# own array object so they can later be extended independently.
X_size, X_speed, X_vision, X_aggression = (np.array([0, 1, 20]) for _ in range(4))

In [None]:
# Evaluate the simulator at every initial size design point (1-D float array).
Y_size = np.array([target_size_function(x) for x in X_size], dtype=float)

In [None]:
# NOTE(review): this displays Y_aggression, but the cell that assigns it appears
# further down in the notebook, so on a fresh top-to-bottom run the name is not
# yet defined at this point.
Y_aggression

In [None]:
# Evaluate the simulator at every initial speed design point (1-D float array).
Y_speed = np.array([target_speed_function(x) for x in X_speed], dtype=float)

In [None]:
# Evaluate the simulator at every initial vision design point (1-D float array).
Y_vision = np.array([target_vision_function(x) for x in X_vision], dtype=float)

In [None]:
# Evaluate the simulator at every initial aggression design point (1-D float array).
Y_aggression = np.array([target_aggression_function(x) for x in X_aggression], dtype=float)

In [None]:
def _make_trait_model(inputs, targets):
    """Build a 1-D GP regression model for a single trait.

    The inputs and targets are reshaped to column form, because the per-trait
    data in this notebook is stored as flat 1-D arrays.
    """
    return GPRegression(np.reshape(inputs, (-1, 1)), np.reshape(targets, (-1, 1)),
                        GPy.kern.RBF(1, lengthscale=1, variance=100), noise_var=1)


# One surrogate per trait, each fitted on that trait's OWN observations
# (previously all four were fitted on the size data).
size_model = _make_trait_model(X_size, Y_size)
speed_model = _make_trait_model(X_speed, Y_speed)
vision_model = _make_trait_model(X_vision, Y_vision)
aggression_model = _make_trait_model(X_aggression, Y_aggression)

emukit_size_model = GPyModelWrapper(size_model)
emukit_speed_model = GPyModelWrapper(speed_model)
emukit_vision_model = GPyModelWrapper(vision_model)
# Fixed the misspelled name `agression_model`, which raised a NameError.
emukit_aggression_model = GPyModelWrapper(aggression_model)

# Expected improvement, negative lower confidence bound and probability of
# improvement for every trait model.
size_ei_acquisition = ExpectedImprovement(emukit_size_model)
size_nlcb_acquisition = NegativeLowerConfidenceBound(emukit_size_model)
size_pi_acquisition = ProbabilityOfImprovement(emukit_size_model)

speed_ei_acquisition = ExpectedImprovement(emukit_speed_model)
speed_nlcb_acquisition = NegativeLowerConfidenceBound(emukit_speed_model)
speed_pi_acquisition = ProbabilityOfImprovement(emukit_speed_model)

vision_ei_acquisition = ExpectedImprovement(emukit_vision_model)
vision_nlcb_acquisition = NegativeLowerConfidenceBound(emukit_vision_model)
vision_pi_acquisition = ProbabilityOfImprovement(emukit_vision_model)

aggression_ei_acquisition = ExpectedImprovement(emukit_aggression_model)
aggression_nlcb_acquisition = NegativeLowerConfidenceBound(emukit_aggression_model)
aggression_pi_acquisition = ProbabilityOfImprovement(emukit_aggression_model)

In [None]:
# Training set for the DGP. Each row has eight entries: the four trait inputs
# (size, speed, vision, aggression) followed by the four per-trait simulator
# outputs taken from Y_size / Y_speed / Y_vision / Y_aggression.
# The first row sets every trait to 1; the other rows set exactly one trait to 1.
# NOTE(review): indices 0 and 1 of the Y_* arrays are assumed to correspond to
# trait values 0 and 1 in the matching X_* arrays — confirm against those cells.
X_train = np.array([[1,1,1,1, Y_size[1], Y_speed[1], Y_vision[1], Y_aggression[1]],
                   [1,0,0,0, Y_size[1], Y_speed[0], Y_vision[0], Y_aggression[0]],
                   [0,1,0,0, Y_size[0], Y_speed[1], Y_vision[0], Y_aggression[0]],
                   [0,0,1,0, Y_size[0], Y_speed[0], Y_vision[1], Y_aggression[0]],
                   [0,0,0,1, Y_size[0], Y_speed[0], Y_vision[0], Y_aggression[1]]])
# The first target comes from a new call to target_function with all traits at 1;
# the remaining targets reuse the already computed per-trait outputs.
Y_train = np.array([[target_function([1,1,1,1])],[Y_size[1]],[Y_speed[1]],[Y_vision[1]],[Y_aggression[1]]])
Q = 5
num_layers = 1
# NOTE(review): kern1 is not referenced again within this cell.
kern1 = GPy.kern.RBF(Q,ARD=True) + GPy.kern.Bias(Q)
kern2 = GPy.kern.RBF(X_train.shape[1],ARD=True) + GPy.kern.Bias(X_train.shape[1])
num_inducing = 4 # Number of inducing points to use for sparsification
back_constraint = False # Whether to use back-constraint for variational posterior
encoder_dims=[[300],[150]] # Dimensions of the MLP back-constraint if set to true

In [None]:
def upper_confidence_bound(y_pred, y_std, beta):
    """Return the upper confidence bound ``y_pred + beta * y_std``."""
    return y_pred + beta * y_std

# Notebook-level weight on the standard-deviation term.
beta = 2.0

In [None]:
# Per-trait Bayesian optimisation feeding a DGP: every iteration queries each
# trait's own GP surrogate for its next point, runs the simulator there, runs one
# joint simulation at the combined point, and retrains the DGP on all rows so far.
iterations = 20
# NOTE(review): this figure is created but nothing in the loop draws on it.
figure, axis = plt.subplots(iterations, 2, figsize=(10, iterations*3))

# The search interval is the same for every trait and iteration, so the
# optimizers are created once instead of on every pass through the loop.
size_optimizer = GradientAcquisitionOptimizer(ParameterSpace([ContinuousParameter('x1', 0, 20)]))
speed_optimizer = GradientAcquisitionOptimizer(ParameterSpace([ContinuousParameter('x1', 0, 20)]))
vision_optimizer = GradientAcquisitionOptimizer(ParameterSpace([ContinuousParameter('x1', 0, 20)]))
aggression_optimizer = GradientAcquisitionOptimizer(ParameterSpace([ContinuousParameter('x1', 0, 20)]))

for i in tqdm(range(iterations)):
    mu_speed_plot, var_speed_plot = emukit_speed_model.predict(x_plot)
    ei_speed_plot = speed_ei_acquisition.evaluate(x_plot)
    nlcb_speed_plot = speed_nlcb_acquisition.evaluate(x_plot)
    pi_speed_plot = speed_pi_acquisition.evaluate(x_plot)

    x_size_new, _ = size_optimizer.optimize(size_nlcb_acquisition)
    x_speed_new, _ = speed_optimizer.optimize(speed_nlcb_acquisition)
    x_vision_new, _ = vision_optimizer.optimize(vision_nlcb_acquisition)
    # Fixed misspelled names (agression_nlcb_acquisition / x_agression_new).
    x_aggression_new, _ = aggression_optimizer.optimize(aggression_nlcb_acquisition)

    print("Next position to query:", x_size_new, x_speed_new, x_vision_new, x_aggression_new)

    # Take the scalar out of each optimizer result (indexed as one row, one
    # column) so it can be mixed with plain numbers in the rows built below.
    x_size = float(x_size_new[0, 0])
    x_speed = float(x_speed_new[0, 0])
    x_vision = float(x_vision_new[0, 0])
    x_aggression = float(x_aggression_new[0, 0])

    # The per-trait data is kept as flat 1-D arrays (np.append without `axis`)
    # and reshaped to columns only when handed to the model.
    y_size_new = target_size_function(x_size)
    X_size = np.append(X_size, x_size)
    Y_size = np.append(Y_size, y_size_new)
    emukit_size_model.set_data(X_size.reshape(-1, 1), Y_size.reshape(-1, 1))
    X_train = np.append(X_train,[[x_size,0,0,0,y_size_new, Y_speed[0], Y_vision[0], Y_aggression[0]]], axis=0)
    # axis=0 keeps Y_train as a column; without it np.append flattens the array
    # and Y_train.shape[1] below no longer exists.
    Y_train = np.append(Y_train,[[y_size_new]], axis=0)

    y_speed_new = target_speed_function(x_speed)
    X_speed = np.append(X_speed, x_speed)
    Y_speed = np.append(Y_speed, y_speed_new)
    emukit_speed_model.set_data(X_speed.reshape(-1, 1), Y_speed.reshape(-1, 1))
    X_train = np.append(X_train,[[0,x_speed,0,0,Y_size[0], y_speed_new, Y_vision[0], Y_aggression[0]]], axis=0)
    Y_train = np.append(Y_train,[[y_speed_new]], axis=0)

    y_vision_new = target_vision_function(x_vision)
    X_vision = np.append(X_vision, x_vision)
    Y_vision = np.append(Y_vision, y_vision_new)
    emukit_vision_model.set_data(X_vision.reshape(-1, 1), Y_vision.reshape(-1, 1))
    X_train = np.append(X_train,[[0,0,x_vision,0,Y_size[0], Y_speed[0], y_vision_new, Y_aggression[0]]], axis=0)
    Y_train = np.append(Y_train,[[y_vision_new]], axis=0)

    y_aggression_new = target_aggression_function(x_aggression)
    # Previously the speed query point was appended here by mistake.
    X_aggression = np.append(X_aggression, x_aggression)
    Y_aggression = np.append(Y_aggression, y_aggression_new)
    emukit_aggression_model.set_data(X_aggression.reshape(-1, 1), Y_aggression.reshape(-1, 1))
    X_train = np.append(X_train,[[0,0,0,x_aggression,Y_size[0], Y_speed[0], Y_vision[0], y_aggression_new]], axis=0)
    Y_train = np.append(Y_train,[[y_aggression_new]], axis=0)

    # Joint evaluation at the combined query point.
    X_train = np.append(X_train,[[x_size,x_speed,x_vision,x_aggression,y_size_new,y_speed_new,y_vision_new,y_aggression_new]], axis=0)
    Y_train = np.append(Y_train,[[target_function([x_size,x_speed,x_vision,x_aggression])]], axis=0)

    # NOTE(review): argument order and dimension list are kept as in the original
    # cell; confirm them against the deepgp.DeepGP constructor.
    dgp_model = deepgp.DeepGP([X_train.shape[1], num_layers, Y_train.shape[1]], X_train, Y_train, kernels=[kern2,None], num_inducing=num_inducing, back_constraint=back_constraint, encoder_dims=encoder_dims)

    # Stage 1: noise variance pinned to 1% of each layer's output variance.
    # A separate loop variable is used so the outer iteration counter `i` is
    # no longer overwritten.
    for layer_index, layer in enumerate(dgp_model.layers):
        output_var = layer.Y.var() if layer_index == 0 else layer.Y.mean.var()
        layer.Gaussian_noise.variance = output_var*0.01
        layer.Gaussian_noise.variance.fix()

    dgp_model.optimize(max_iters=800, messages=True)
    # Stage 2: release the noise variances and continue optimising.
    for layer in dgp_model.layers:
        layer.Gaussian_noise.variance.unfix()
    dgp_model.optimize(max_iters=1200, messages=True)

In [None]:
# Initial data and GP surrogate for a 1-D experiment on the speed mutation rate.
x_plot = np.linspace(0, 10, 1000)[:, None]
X = np.array([[0],[5], [10]])

# One simulator run per initial design point. The target function returns a single
# value, which cannot be appended to a 2-D array along axis 0, so the (n, 1)
# target column is built directly. Each row of X holds one value, passed on as a
# scalar.
Y = np.array([[target_speed_function(x[0])] for x in X], dtype=float)

model = GPRegression(X, Y, GPy.kern.RBF(1, lengthscale=1, variance=100), noise_var=1)
emukit_model = GPyModelWrapper(model)

# Three acquisition functions over the same surrogate.
ei_acquisition = ExpectedImprovement(emukit_model)
nlcb_acquisition = NegativeLowerConfidenceBound(emukit_model)
pi_acquisition = ProbabilityOfImprovement(emukit_model)

In [None]:
def target_function_list(X):
    """Run the simulator once per row of ``X`` and collect the survival times.

    NOTE(review): a function with this name is also defined in an earlier cell
    of the notebook; consider keeping a single definition.

    Parameters
    ----------
    X : iterable of sequences
        Each row supplies the mutation rates in the order
        (size, speed, vision, aggression).

    Returns
    -------
    numpy.ndarray
        Column vector of shape ``(len(X), 1)``; entry ``k`` is the
        ``days_survived`` value returned by the simulator for row ``k``.
    """
    # Collect results in a list and convert once, instead of growing an array
    # with np.append from a dummy first row that has to be sliced off again.
    survived = []
    for x in X:
        mutation_rates = {
            "size": x[0],
            "speed": x[1],
            "vision": x[2],
            "aggression": x[3]
        }
        days_survived, _log = main_simulator.run(mutation_rates, debug_info=DebugInfo(
            period=10, should_display_day=True, should_display_grid=False, should_display_traits=False), max_days=10000)
        survived.append(days_survived)
    # reshape(-1, 1) keeps the (n, 1) layout, including (0, 1) for empty input.
    return np.array(survived).reshape(-1, 1)

In [None]:
from emukit.core.initial_designs import RandomDesign
from emukit.core import ParameterSpace, ContinuousParameter

# Search space: one continuous mutation rate per trait, each restricted to [0, 20].
space = ParameterSpace([
    ContinuousParameter(trait, 0, 20)
    for trait in ('size', 'speed', 'vision', 'aggression')
])

# Draw a small random initial design and evaluate the simulator on it.
num_data_points = 5
design = RandomDesign(space)
X = design.get_samples(num_data_points)
Y = target_function_list(X)

# GP regression model on the initial data (GPy defaults), wrapped for Emukit.
model_gpy = GPRegression(X, Y)
model_emukit = GPyModelWrapper(model_gpy)

In [None]:
# Build the three acquisition functions on the wrapped 4-D model, in the order
# expected improvement, negative lower confidence bound, probability of improvement.
ei_acquisition, nlcb_acquisition, pi_acquisition = (
    acquisition_cls(model=model_emukit)
    for acquisition_cls in (ExpectedImprovement,
                            NegativeLowerConfidenceBound,
                            ProbabilityOfImprovement)
)

In [None]:
# Bayesian-optimisation loop over all four traits. The surrogate and acquisition
# curves are plotted along ONE trait (index `plot`) with the other traits held at 0.
iterations = 20
figure, axis = plt.subplots(iterations, 2, figsize=(10, iterations*3))
# Control along which trait is the function plotted
plot = 0
x_linear = np.linspace(0, 20, 1000)[:, None]
x_zeros = np.zeros_like(x_linear)

# Build the 4-column evaluation grid once, before the loop. Extending x_plot
# inside the loop added three more columns on every iteration, so the model
# input was no longer 4-dimensional from the second iteration on.
x_plot = np.zeros((x_linear.shape[0], 4))
x_plot[:, plot] = x_linear[:, 0]

# The search space is fixed, so the optimizer is created once.
optimizer = GradientAcquisitionOptimizer(ParameterSpace([ContinuousParameter('size', 0, 20),
                                                         ContinuousParameter('speed', 0, 20),
                                                         ContinuousParameter('vision', 0, 20),
                                                         ContinuousParameter('aggression', 0, 20)]))

for i in tqdm(range(iterations)):
    mu_plot, var_plot = model_emukit.predict(x_plot)
    # Show the observations by their coordinate along the plotted trait only;
    # passing the full 4-column X would draw every trait column against Y.
    plot_prediction(X[:, [plot]],Y,x_linear,mu_plot,var_plot,axis[i,0])

    ei_plot = ei_acquisition.evaluate(x_plot)
    nlcb_plot = nlcb_acquisition.evaluate(x_plot)
    pi_plot = pi_acquisition.evaluate(x_plot)

    x_new, _ = optimizer.optimize(nlcb_acquisition)
    # Mark the next query by its coordinate along the plotted trait.
    plot_acquisition_functions(x_linear, ei_plot, nlcb_plot, pi_plot, x_new[0][plot], axis[i,1])
    print("Next position to query:", x_new)

    # Evaluate the simulator at the new point and refit the surrogate.
    y_new = target_function_list(x_new)
    X = np.append(X, x_new, axis=0)
    Y = np.append(Y, y_new, axis=0)
    model_emukit.set_data(X, Y)

plt.show()