# Multi-Fidelity Deep Gaussian process benchmark

This notebook replicates the benchmark experiments from the paper:

[Deep Gaussian Processes for Multi-fidelity Modeling (Kurt Cutajar, Mark Pullin, Andreas Damianou, Neil Lawrence, Javier González)](https://arxiv.org/abs/1903.07320)

Note that the code for one of the benchmark models in the paper, "Deep Multi-fidelity Gaussian process", is not publically available and so does not appear in this notebook.

In [1]:
from prettytable import PrettyTable
import numpy as np
import scipy.stats
from sklearn.metrics import mean_squared_error, r2_score
import emukit.examples.multi_fidelity_dgp

from emukit.examples.multi_fidelity_dgp.baseline_model_wrappers import LinearAutoRegressiveModel, NonLinearAutoRegressiveModel, HighFidelityGp

from emukit.core import ContinuousParameter, ParameterSpace
from emukit.experimental_design import LatinDesign
from emukit.examples.multi_fidelity_dgp.multi_fidelity_deep_gp import MultiFidelityDeepGP

from emukit.test_functions.multi_fidelity import (multi_fidelity_borehole_function, multi_fidelity_branin_function,
                                                  multi_fidelity_park_function, multi_fidelity_hartmann_3d,
                                                  multi_fidelity_currin_function)

# Parameters for different benchmark functions

In [2]:
from collections import namedtuple

Function = namedtuple('Function', ['name', 'y_scale', 'noise_level', 'do_x_scaling', 'num_data', 'fcn'])

borehole = Function(name='borehole', y_scale=100, noise_level=[0.05, 0.1], do_x_scaling=True, num_data=[60, 5], 
                    fcn=multi_fidelity_borehole_function)
branin = Function(name='branin', y_scale=1, noise_level=[0., 0., 0.], do_x_scaling=False, num_data=[80, 30, 10], 
                    fcn=multi_fidelity_branin_function)
currin = Function(name='currin', y_scale=1, noise_level=[0., 0.], do_x_scaling=False, num_data=[12, 5], 
                    fcn=multi_fidelity_currin_function)
park = Function(name='park', y_scale=1, noise_level=[0., 0.], do_x_scaling=False, num_data=[30, 5], 
                    fcn=multi_fidelity_park_function)
hartmann_3d = Function(name='hartmann', y_scale=1, noise_level=[0., 0., 0.], do_x_scaling=False, num_data=[80, 40, 20], 
                    fcn=multi_fidelity_hartmann_3d)

In [3]:
# Function to repeat test across different random seeds.

def do_benchmark(fcn_tuple):
    metrics = dict()

    # Some random seeds to use
    seeds = [123, 184, 202, 289, 732]

    for i, seed in enumerate(seeds):
        run_name = str(seed) + str(fcn_tuple.num_data)
        metrics[run_name] = test_function(fcn_tuple, seed)
        print('After ' + str(i+1) + ' runs of ' + fcn_tuple.name)
        print_metrics(metrics)

    return metrics

In [4]:
# Print metrics as table 
def print_metrics(metrics):
    model_names = list(list(metrics.values())[0].keys())
    metric_names = ['r2', 'mnll', 'rmse']
    table = PrettyTable(['model'] + metric_names)

    for name in model_names:
        mean = []
        for metric_name in metric_names:
            mean.append(np.mean([metric[name][metric_name] for metric in metrics.values()]))
        table.add_row([name] + mean)

    print(table)

In [5]:
def test_function(fcn, seed):
    np.random.seed(seed)

    x_test, y_test, X, Y = generate_data(fcn, 1000)

    mf_dgp_fix_lf_mean = MultiFidelityDeepGP(X, Y, n_iter=5000)
    mf_dgp_fix_lf_mean.name = 'mf_dgp_fix_lf_mean'

    models = [HighFidelityGp(X, Y), LinearAutoRegressiveModel(X, Y), NonLinearAutoRegressiveModel(X, Y), mf_dgp_fix_lf_mean]
    return benchmark_models(models, x_test, y_test)

In [6]:
def benchmark_models(models, x_test, y_test):
    metrics = dict()
    for model in models:
        model.optimize()
        y_mean, y_var = model.predict(x_test)
        metrics[model.name] = calculate_metrics(y_test, y_mean, y_var)
        print('+ ######################## +')
        print(model.name, 'r2', metrics[model.name]['r2'])
        print('+ ######################## + ')
    return metrics

In [7]:
def generate_data(fcn_tuple, n_test_points):
    """
    Generates train and test data for
    """
    
    do_x_scaling = fcn_tuple.do_x_scaling
    fcn, space = fcn_tuple.fcn()
    
    # Generate training data
    new_space = ParameterSpace(space._parameters[:-1])
    latin = LatinDesign(new_space)
    X = [latin.get_samples(n) for n in fcn_tuple.num_data]
    
    # Scale X if required
    if do_x_scaling:
        scalings = X[0].std(axis=0)
    else:
        scalings = np.ones(X[0].shape[1])
        
    for x in X:
        x /= scalings
    
    Y = []
    for i, x in enumerate(X):
        Y.append(fcn.f[i](x * scalings))
    
    y_scale = fcn_tuple.y_scale
    
    # scale y and add noise if required
    noise_levels = fcn_tuple.noise_level
    if any([n > 0 for n in noise_levels]):
        for y, std_noise in zip(Y, noise_levels):
            y /= y_scale + std_noise * np.random.randn(y.shape[0], 1)
    
    # Generate test data
    x_test = latin.get_samples(n_test_points)
    x_test /= scalings
    y_test = fcn.f[-1](x_test * scalings)
    y_test /= y_scale

    i_highest_fidelity = (len(fcn_tuple.num_data) - 1) * np.ones((x_test.shape[0], 1))
    x_test = np.concatenate([x_test, i_highest_fidelity], axis=1)
    print(X[1].shape)
    return x_test, y_test, X, Y

In [8]:
def calculate_metrics(y_test, y_mean_prediction, y_var_prediction):
    # R2
    r2 = r2_score(y_test, y_mean_prediction)
    # RMSE
    rmse = np.sqrt(mean_squared_error(y_test, y_mean_prediction))
    # Test log likelihood
    mnll = -np.sum(scipy.stats.norm.logpdf(y_test, loc=y_mean_prediction, scale=np.sqrt(y_var_prediction)))/len(y_test)
    return {'r2': r2, 'rmse': rmse, 'mnll': mnll}

In [9]:
metrics = []
metrics.append(do_benchmark(borehole))

(5, 8)
Optimization restart 1/10, f = 2.6780050188822546
Optimization restart 2/10, f = 1.0935264476928328
Optimization restart 3/10, f = 6.857006533137321
Optimization restart 4/10, f = 2.6779984957092537
Optimization restart 5/10, f = 6.857006533050038
Optimization restart 6/10, f = 2.6779982188829115
Optimization restart 7/10, f = 2.524868617934755
Optimization restart 8/10, f = 6.8570065331075085
Optimization restart 9/10, f = 6.857006533126267
Optimization restart 10/10, f = 6.857006511305246
+ ######################## +
hf_gp r2 0.853075000266343
+ ######################## + 




Optimization restart 1/10, f = -156.8285538162658
Optimization restart 2/10, f = -179.38036874969842
Optimization restart 3/10, f = -163.36705147772054
Optimization restart 4/10, f = -182.28237823155325
Optimization restart 5/10, f = -176.24304083993212




Optimization restart 6/10, f = -119.99978816309434
Optimization restart 7/10, f = -175.37409704195088
Optimization restart 8/10, f = -179.61244375281075
Optimization restart 9/10, f = -182.86672215283113
Optimization restart 10/10, f = -161.18706361447067




Optimization restart 1/10, f = -186.38395209569308
Optimization restart 2/10, f = -147.3071062351859
Optimization restart 3/10, f = -169.85954282860385
Optimization restart 4/10, f = -179.35403697892826
Optimization restart 5/10, f = -183.22412818655513
Optimization restart 6/10, f = -137.07203845602123
Optimization restart 7/10, f = -178.09572141639865
Optimization restart 8/10, f = -171.86507595255608
Optimization restart 9/10, f = -176.29583805979854
Optimization restart 10/10, f = -43.10639285241553
+ ######################## +
ar1 r2 0.9998972497557327
+ ######################## + 
Optimization restart 1/10, f = -136.9370331901252
Optimization restart 2/10, f = -136.93703386790318
Optimization restart 3/10, f = -136.93703278467692
Optimization restart 4/10, f = -136.93703641182705
Optimization restart 5/10, f = -136.93703548981853
Optimization restart 6/10, f = -136.93703389359064
Optimization restart 7/10, f = -136.9370235017134
Optimization restart 8/10, f = -136.93703798785725




Optimization restart 2/10, f = 2.5248676236134466
Optimization restart 3/10, f = 2.524868325063706
Optimization restart 4/10, f = 6.857006533146583
Optimization restart 5/10, f = 2.5248697071171575
Optimization restart 6/10, f = 6.857006532504605
Optimization restart 7/10, f = 1.0935260099510984
Optimization restart 8/10, f = -7.649649343166953




Optimization restart 9/10, f = -7.6497521754149265
Optimization restart 10/10, f = 6.857006532978832
+ ######################## +
nargp r2 0.9985789738974397
+ ######################## + 




+ ######################## +
mf_dgp_fix_lf_mean r2 0.9987049708753993
+ ######################## + 
After 1 runs of borehole
+--------------------+--------------------+---------------------+----------------------+
|       model        |         r2         |         mnll        |         rmse         |
+--------------------+--------------------+---------------------+----------------------+
|       hf_gp        | 0.853075000266343  |  96.14594484403838  | 0.17314327654728623  |
|        ar1         | 0.9998972497557327 | -3.9025995002940634 | 0.004578774062605934 |
|       nargp        | 0.9985789738974397 | -2.7349785645424847 | 0.017027810574982254 |
| mf_dgp_fix_lf_mean | 0.9987049708753993 |  -2.031759496425357 | 0.016255395866799045 |
+--------------------+--------------------+---------------------+----------------------+
(5, 8)
Optimization restart 1/10, f = 0.33646433185789526
Optimization restart 2/10, f = 6.095755725618849
Optimization restart 3/10, f = 6.095755513850342
Optimiz



Optimization restart 1/10, f = -175.7017231933433
Optimization restart 2/10, f = -179.72025414287847
Optimization restart 3/10, f = -172.54208731026495
Optimization restart 4/10, f = -171.370638232173




Optimization restart 5/10, f = -137.27491107308924
Optimization restart 6/10, f = -127.25252791765338
Optimization restart 7/10, f = -183.13166466037566
Optimization restart 8/10, f = -169.60249271488678
Optimization restart 9/10, f = 13.886584113044236
Optimization restart 10/10, f = -139.55901389884826




Optimization restart 1/10, f = -188.78584087902138
Optimization restart 2/10, f = -165.9255045776461
Optimization restart 3/10, f = -153.37076733339003
Optimization restart 4/10, f = -188.9609216585933
Optimization restart 5/10, f = 13.886284349594561
Optimization restart 6/10, f = -187.42461585981636
Optimization restart 7/10, f = 13.886287725923232
Optimization restart 8/10, f = -185.04689075322142
Optimization restart 9/10, f = -153.69295765749064
Optimization restart 10/10, f = -154.9567659155038
+ ######################## +
ar1 r2 0.9998843925649263
+ ######################## + 
Optimization restart 1/10, f = -130.92876017188414
Optimization restart 2/10, f = -130.92876657276398
Optimization restart 3/10, f = -130.92876767135255
Optimization restart 4/10, f = -130.92876559595632
Optimization restart 5/10, f = -130.9287654318012
Optimization restart 6/10, f = -130.9287645952371
Optimization restart 7/10, f = -130.9287645595528
Optimization restart 8/10, f = -130.92876255847608
Opti



Optimization restart 5/10, f = -9.375357877143411
Optimization restart 6/10, f = 6.0957556083365665
Optimization restart 7/10, f = 6.095755725581395
Optimization restart 8/10, f = -7.7623016245230865
Optimization restart 9/10, f = 6.09575307745472
Optimization restart 10/10, f = 2.47841213380755
+ ######################## +
nargp r2 0.9995643044326001
+ ######################## + 




+ ######################## +
mf_dgp_fix_lf_mean r2 0.9986658368962823
+ ######################## + 
After 2 runs of borehole
+--------------------+--------------------+---------------------+----------------------+
|       model        |         r2         |         mnll        |         rmse         |
+--------------------+--------------------+---------------------+----------------------+
|       hf_gp        | 0.8071332592664457 |  160.3940242445727  | 0.19708450539643602  |
|        ar1         | 0.9998908211603295 | -3.6043566191251193 | 0.004720923879911397 |
|       nargp        | 0.9990716391650198 | -2.9429921739227374 | 0.013234313348665924 |
| mf_dgp_fix_lf_mean | 0.9986854038858408 | -1.9802407338682442 | 0.01638793355339148  |
+--------------------+--------------------+---------------------+----------------------+
(5, 8)
Optimization restart 1/10, f = 1.8779979043393498
Optimization restart 2/10, f = 6.399411289184439
Optimization restart 3/10, f = 6.399411289183076
Optimiza



Optimization restart 1/10, f = -168.50921710500506




Optimization restart 2/10, f = -161.62602033740683
Optimization restart 3/10, f = -168.75166115236362
Optimization restart 4/10, f = -169.01421767386614
Optimization restart 5/10, f = -187.61243482187768
Optimization restart 6/10, f = -176.795221077795
Optimization restart 7/10, f = -175.71193696883745
Optimization restart 8/10, f = 26.180066864086797
Optimization restart 9/10, f = -182.1766033778814
Optimization restart 10/10, f = -168.87254062558475




Optimization restart 1/10, f = -189.16556929055702
Optimization restart 2/10, f = -166.75937587932424
Optimization restart 3/10, f = -176.2901213945462
Optimization restart 4/10, f = -174.91520529093754
Optimization restart 5/10, f = -147.00795492820075
Optimization restart 6/10, f = -184.356586721933
Optimization restart 7/10, f = -176.88761909732565
Optimization restart 8/10, f = -186.1073368257119
Optimization restart 9/10, f = -187.07412628413127
Optimization restart 10/10, f = -132.06991924414854
+ ######################## +
ar1 r2 0.9998895199793003
+ ######################## + 
Optimization restart 1/10, f = -137.3542531386492
Optimization restart 2/10, f = -137.35425221332142
Optimization restart 3/10, f = -137.3542551400852
Optimization restart 4/10, f = -137.35425127731895
Optimization restart 5/10, f = -137.35425358580272
Optimization restart 6/10, f = -137.35425074586522
Optimization restart 7/10, f = -137.3542530367982
Optimization restart 8/10, f = -137.35425178905263
Opt



Optimization restart 7/10, f = -8.271854167697615
Optimization restart 8/10, f = 6.3994112869529465
Optimization restart 9/10, f = 6.3994112646066394




Optimization restart 10/10, f = -7.398659437770409
+ ######################## +
nargp r2 0.9990111982546624
+ ######################## + 




+ ######################## +
mf_dgp_fix_lf_mean r2 0.9989347077881694
+ ######################## + 
After 3 runs of borehole
+--------------------+---------------------+---------------------+----------------------+
|       model        |          r2         |         mnll        |         rmse         |
+--------------------+---------------------+---------------------+----------------------+
|       hf_gp        | 0.18975752749105101 |  291.4859828798956  |  0.3432502967362628  |
|        ar1         |  0.9998903874333198 | -3.7475329752624433 | 0.004704488747757613 |
|       nargp        |  0.9990514921949006 | -2.9134123986447107 | 0.013481512355637095 |
| mf_dgp_fix_lf_mean |  0.9987685051866171 |  -1.973888364544819 | 0.015760758174585637 |
+--------------------+---------------------+---------------------+----------------------+
(5, 8)
Optimization restart 1/10, f = -0.23815772080697117
Optimization restart 2/10, f = 6.722265352396027
Optimization restart 3/10, f = 6.72226535238935



Optimization restart 1/10, f = -177.50715700135964
Optimization restart 2/10, f = -182.35137908524368
Optimization restart 3/10, f = -171.81106134555398
Optimization restart 4/10, f = -174.44659814920016
Optimization restart 5/10, f = -157.11188901523144
Optimization restart 6/10, f = -175.5289781791669
Optimization restart 7/10, f = -167.18342777251877
Optimization restart 8/10, f = -176.49503972337988
Optimization restart 9/10, f = -168.89886437947865
Optimization restart 10/10, f = -169.4018849717763
Optimization restart 1/10, f = -182.35137944948804




Optimization restart 2/10, f = -152.91855664270662
Optimization restart 3/10, f = -164.92761180484882
Optimization restart 4/10, f = -177.5966665426517
Optimization restart 5/10, f = -166.72400186170802
Optimization restart 6/10, f = -179.09841439538982
Optimization restart 7/10, f = -168.8735610114593




Optimization restart 8/10, f = -152.4452965078218
Optimization restart 9/10, f = -163.93108712632892
Optimization restart 10/10, f = -178.6932816478879
+ ######################## +
ar1 r2 0.9998995831570738
+ ######################## + 
Optimization restart 1/10, f = -129.41513361638297




Optimization restart 2/10, f = -136.08333505094856
Optimization restart 3/10, f = -130.6522278066674
Optimization restart 4/10, f = -136.08334209462464




Optimization restart 5/10, f = -135.81131203970205
Optimization restart 6/10, f = -136.0833480173486
Optimization restart 7/10, f = -136.08333939381376
Optimization restart 8/10, f = -136.0833458248548
Optimization restart 9/10, f = -136.08334704902
Optimization restart 10/10, f = -136.0833460676893




Optimization restart 1/10, f = -6.992756964026395
Optimization restart 2/10, f = 1.3627646820082702
Optimization restart 3/10, f = 6.7222653512107735
Optimization restart 4/10, f = 6.722265252631112
Optimization restart 5/10, f = -0.3415159407557442
Optimization restart 6/10, f = 6.722265294805144
Optimization restart 7/10, f = 6.72226534710218
Optimization restart 8/10, f = 6.722265347988294
Optimization restart 9/10, f = 1.362767090557449
Optimization restart 10/10, f = 6.7222653404306945
+ ######################## +
nargp r2 0.9985701574364635
+ ######################## + 




+ ######################## +
mf_dgp_fix_lf_mean r2 0.9991096860342353
+ ######################## + 
After 4 runs of borehole
+--------------------+--------------------+---------------------+----------------------+
|       model        |         r2         |         mnll        |         rmse         |
+--------------------+--------------------+---------------------+----------------------+
|       hf_gp        | 0.2321361381294123 |  220.03007452734528 |  0.3484158039203253  |
|        ar1         | 0.9998926863642583 | -3.8390772703338856 | 0.004667312717971484 |
|       nargp        | 0.9989311585052914 | -2.8651052224341105 | 0.014408912703090446 |
| mf_dgp_fix_lf_mean | 0.9988538003985217 | -1.9679598894890824 | 0.01521190997933396  |
+--------------------+--------------------+---------------------+----------------------+
(5, 8)
Optimization restart 1/10, f = 0.9907906809758065
Optimization restart 2/10, f = 0.9724527878094236
Optimization restart 3/10, f = 6.259344442108942
Optimiz



Optimization restart 6/10, f = 0.7803455588388024
Optimization restart 7/10, f = 6.259344430279247
Optimization restart 8/10, f = 1.6758547791956753
Optimization restart 9/10, f = 6.259344442214996
Optimization restart 10/10, f = 6.259344442215002
+ ######################## +
hf_gp r2 0.8481594974914033
+ ######################## + 




Optimization restart 1/10, f = -181.31134943751303
Optimization restart 2/10, f = -189.4348946936675
Optimization restart 3/10, f = -188.09083252468946




Optimization restart 4/10, f = -188.61780048223127
Optimization restart 5/10, f = -150.27283685249677
Optimization restart 6/10, f = -188.24072325890927
Optimization restart 7/10, f = -180.75341908927533
Optimization restart 8/10, f = -184.0017011293242
Optimization restart 9/10, f = -187.53274566935198
Optimization restart 10/10, f = -186.31852219986737




Optimization restart 1/10, f = -191.13771457815267
Optimization restart 2/10, f = -145.64378159530622
Optimization restart 3/10, f = -129.13467425198428
Optimization restart 4/10, f = -185.7379559753191
Optimization restart 5/10, f = -167.21016944745378
Optimization restart 6/10, f = -182.9160787575003
Optimization restart 7/10, f = -186.7798129508015
Optimization restart 8/10, f = -184.86279072240546
Optimization restart 9/10, f = -107.9576409924531
Optimization restart 10/10, f = -163.6479869895175
+ ######################## +
ar1 r2 0.9999315069185383
+ ######################## + 
Optimization restart 1/10, f = -136.62855225569234
Optimization restart 2/10, f = -136.6285539897515
Optimization restart 3/10, f = -136.6285520436615
Optimization restart 4/10, f = -136.62854176665266
Optimization restart 5/10, f = -136.6285532370645
Optimization restart 6/10, f = -136.62854672011755
Optimization restart 7/10, f = -136.62855250724243
Optimization restart 8/10, f = -136.62855333575612
Opti



Optimization restart 4/10, f = -8.38410054022058
Optimization restart 5/10, f = 6.259340118373446
Optimization restart 6/10, f = 6.259344442227789
Optimization restart 7/10, f = 6.259344433821681
Optimization restart 8/10, f = 6.259341465731598
Optimization restart 9/10, f = 6.259341904715415
Optimization restart 10/10, f = 0.7803571265329232
+ ######################## +
nargp r2 0.9987292088929406
+ ######################## + 




+ ######################## +
mf_dgp_fix_lf_mean r2 0.9980699864341062
+ ######################## + 
After 5 runs of borehole
+--------------------+--------------------+---------------------+----------------------+
|       model        |         r2         |         mnll        |         rmse         |
+--------------------+--------------------+---------------------+----------------------+
|       hf_gp        | 0.3553408100018105 |  322.9443044652573  |  0.3135555055193182  |
|        ar1         | 0.9999004504751143 | -3.9483659816140753 | 0.004473445236622577 |
|       nargp        | 0.9988907685828211 | -2.8525534323669226 | 0.01471285112951268  |
| mf_dgp_fix_lf_mean | 0.9986970376056385 | -1.9576738453062468 | 0.016095532828221362 |
+--------------------+--------------------+---------------------+----------------------+


In [10]:
for (metric) in zip(metrics):
    print(fcn_name)
    print_metrics(metric[0])

NameError: name 'fcn_name' is not defined