In [None]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from olympus.datasets import Dataset
from olympus.emulators import Emulator

from olympus.models import BayesNeuralNet

In [2]:
# Hyperparameter-search results keyed by dataset name; get_best_scores reads
# res[name]['scores'] and get_hyperparams reads res[name]['params'].
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('best_scores.pkl', 'rb') as fh:
    res = pickle.load(fh)

In [3]:
# Datasets to build emulators for, listed family by family. The Suzuki
# datasets have two target labels each (see `dataset_targets`).
dataset_names = [
    name
    for family in (
        ('oer_plate_4098', 'oer_plate_3851', 'oer_plate_3860', 'oer_plate_3496'),
        ('p3ht', 'agnp'),
        ('thin_film', 'crossed_barrel', 'autoam'),
        ('suzuki_i', 'suzuki_ii', 'suzuki_iii', 'suzuki_iv'),
    )
    for name in family
]






# Per-dataset emulator settings: the BayesNeuralNet output activation and the
# feature/target transforms passed to Emulator. Every dataset uses the
# 'normalize' target transform. The training loop later stores predictions,
# scores and the trained emulator in each entry, so each entry must be its own dict.
dataset_params = {
    name: {'out_act': out_act, 'feature_transform': feat_tf, 'target_transform': 'normalize'}
    for name, out_act, feat_tf in [
        ('oer_plate_4098', 'sigmoid', 'identity'),
        ('oer_plate_3851', 'sigmoid', 'identity'),
        ('oer_plate_3860', 'sigmoid', 'identity'),
        ('oer_plate_3496', 'sigmoid', 'identity'),
        #
        ('p3ht', 'relu', 'standardize'),
        ('thin_film', 'relu', 'identity'),
        ('crossed_barrel', 'relu', 'standardize'),
        ('autoam', 'sigmoid', 'standardize'),
        ('agnp', 'sigmoid', 'standardize'),
        #
        ('suzuki_i', 'sigmoid', 'standardize'),
        ('suzuki_ii', 'sigmoid', 'standardize'),
        ('suzuki_iii', 'sigmoid', 'standardize'),
        ('suzuki_iv', 'sigmoid', 'standardize'),
    ]
}

# Index of the hyperparameter trial selected for each dataset. The training
# loop passes it to get_hyperparams, i.e. it is used as
# res[dataset_name]['params'][ix]. The trailing comments record that trial's
# R2 values (presumably taken from res[dataset_name]['scores'][ix] -- verify).
dataset_best_ixs = dict(
    [
        ('oer_plate_4098', 39),  # 'train_r2': 0.8319234403227831, 'test_r2': 0.8465640553597371
        ('oer_plate_3851', 41),  # 'train_r2': 0.7398952370543461, 'test_r2': 0.900529053900117
        ('oer_plate_3860', 48),  # 'train_r2': 0.9643666817683573, 'test_r2': 0.9565719423562296
        ('oer_plate_3496', 20),  # 'train_r2': 0.8847241332118894, 'test_r2': 0.9245734657163589
        ('p3ht', 26),            # 'train_r2': 0.8778774142940203, 'test_r2': 0.9673166550473478
        ('agnp', 22),            # 'train_r2': 0.9894920088114398, 'test_r2': 0.9829782150809259
        ('thin_film', 28),       # 'train_r2': 0.9472173771719822, 'test_r2': 0.9031088775578484
        ('crossed_barrel', 37),  # 'train_r2': 0.9646820277025927, 'test_r2': 0.9409044469628027
        ('autoam', 49),          # 'train_r2': 0.9970140068523364, 'test_r2': 0.9908067279613181
        ('suzuki_i', 29),        # 'train_r2': 0.9815854812719683, 'test_r2': 0.9717699198558013
        ('suzuki_ii', 34),       # 'train_r2': 0.9571388181603957, 'test_r2': 0.9481749209020203
        ('suzuki_iii', 3),       # 'train_r2': 0.9942279917944439, 'test_r2': 0.9977259271817904
        ('suzuki_iv', 15),       # 'train_r2': 0.9904239794801154, 'test_r2': 0.9869812427882629
    ]
)

In [5]:
def get_best_scores(res, dataset):
    """Pick the best hyperparameter-search trials for ``dataset`` under two criteria.

    Args:
        res: search results; ``res[dataset]['scores']`` must be an indexable
            sequence of dicts, each with at least ``'train_r2'`` and ``'test_r2'``.
        dataset: key into ``res``.

    Returns:
        Tuple ``(best_ix_test, best_scores_test, best_ix_sum, best_scores_sum)``:
        ``best_ix_test`` is the index with the highest test R2, and
        ``best_ix_sum`` is the index with the highest train R2 + test R2. The
        ``best_scores_*`` entries are the corresponding score dicts. Ties resolve
        to the first index, as with ``np.argmax``.

    Raises:
        ValueError: if no scores are recorded for ``dataset``.
    """
    scores = res[dataset]['scores']
    if len(scores) == 0:
        # np.argmax would raise ValueError here too, but with a less helpful message.
        raise ValueError(f"no scores recorded for dataset {dataset!r}")

    test_r2 = [score['test_r2'] for score in scores]
    sum_r2 = [score['test_r2'] + score['train_r2'] for score in scores]

    best_ix_test = np.argmax(test_r2)
    best_ix_sum = np.argmax(sum_r2)
    return best_ix_test, scores[best_ix_test], best_ix_sum, scores[best_ix_sum]


def get_hyperparams(res, dataset_name, best_ix):
    """Return the hyperparameter set recorded at position ``best_ix`` for ``dataset_name``."""
    param_list = res[dataset_name]['params']
    chosen = param_list[best_ix]
    return chosen

In [6]:
# for dataset in dataset_names:
#     print('DATASET : ', dataset)
#     best_ix_test, best_scores_test, best_ix_sum, best_scores_sum = get_best_scores(res, dataset)
#     print('best_ix_test : ', best_ix_test)
#     print('best_scores_test : ', best_scores_test)
#     print('best_ix_sum : ', best_ix_sum)
#     print('best_scores_sum : ', best_scores_sum)
#     print('\n\n')


In [8]:
# Train one BayesNeuralNet emulator per dataset with the hyperparameters chosen in
# `dataset_best_ixs`. Predict on that emulator's own train/test split and cache the
# predictions, true values, training scores and emulator in
# `dataset_params[dataset_name]`. Each emulator is also saved to disk.

NUM_PRED_SAMPLES = 50  # forwarded as `num_samples` to emulator.run

for dataset_name in dataset_names:
    ds_config = dataset_params[dataset_name]  # same dict object, updated below

    # Build the Dataset once and pass the *object* to the Emulator, so the
    # emulator trains on exactly the split evaluated below. Each
    # Dataset(kind=...) call draws a new random train/test split (compare ix0
    # and ix1 at the end of this notebook). Passing only the name presumably
    # makes the Emulator build its own, different split -- which would put
    # training points into the "test" predictions.
    d = Dataset(kind=dataset_name)
    hyperparams = get_hyperparams(res, dataset_name, dataset_best_ixs[dataset_name])

    model = BayesNeuralNet(**hyperparams, out_act=ds_config['out_act'])
    emulator = Emulator(
        dataset=d,
        model=model,
        feature_transform=ds_config['feature_transform'],
        target_transform=ds_config['target_transform'],
    )

    print('DATASET : ', dataset_name)
    print(emulator)

    scores = emulator.train()

    train_params = d.train_set_features.to_numpy()
    train_values = d.train_set_targets.to_numpy()
    test_params = d.test_set_features.to_numpy()
    test_values = d.test_set_targets.to_numpy()

    train_preds = emulator.run(train_params, num_samples=NUM_PRED_SAMPLES)
    test_preds = emulator.run(test_params, num_samples=NUM_PRED_SAMPLES)

    ds_config['train_preds'] = train_preds
    ds_config['test_preds'] = test_preds
    # Keep the matching ground truth so later plots need not rebuild a Dataset
    # (which would draw a different split).
    ds_config['train_values'] = train_values
    ds_config['test_values'] = test_values
    ds_config['scores'] = scores
    ds_config['emulator'] = emulator

    emulator.save(f'emulator_{dataset_name}_BayesNeuralNet')

DATASET :  oer_plate_4098
<Emulator (Dataset(kind=oer_plate_4098), model=
--> batch_size:    10
--> es_patience:   100
--> hidden_act:    leaky_relu
--> hidden_depth:  4
--> hidden_nodes:  44
--> kind:          BayesNeuralNet
--> learning_rate: 0.00247272270870916
--> max_epochs:    100000
--> out_act:       sigmoid
--> pred_int:      100
--> reg:           0.005747336344783915
--> scope:         model)>
[0;37m[INFO] >>> Training model on 80% of the dataset, testing on 20%...
[0m

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -1.347          0.233         -3.190          0.272 *
[0m[0;37m[INFO]             100          0.427          0.115          0.483          0.096 *
[0m[0;37m[INFO]             200          0.593          0.097          0.557          0.089 *
[0m[0;37m[INFO]             300          0.475          0.110          0.522          0.092
[0m[0;37m[INFO]             400          0.484          0.109          0.657          0.078 *
[0m[0;37m[INFO]             500          0.340          0.124          0.441          0.099
[0m[0;37m[INFO]             600          0.466          0.111          0.545          0.090
[0m[0;37m[INFO]             700          0.549          0.102          0.483          0.096
[0m[0;37m[INFO]             800          0.447          0.113          0.543          0.090
[0m[0;37m[INFO]             900          0.487    

[0m[0;37m[INFO]            8400          0.563          0.101          0.631          0.081
[0m[0;37m[INFO]            8500          0.600          0.096          0.698          0.073
[0m[0;37m[INFO]            8600          0.586          0.098          0.712          0.071
[0m[0;37m[INFO]            8700          0.553          0.102          0.648          0.079
[0m[0;37m[INFO]            8800          0.696          0.084          0.756          0.066
[0m[0;37m[INFO]            8900          0.602          0.096          0.684          0.075
[0m[0;37m[INFO]            9000          0.582          0.099          0.740          0.068
[0m[0;37m[INFO]            9100          0.502          0.108          0.704          0.072
[0m[0;37m[INFO]            9200          0.525          0.105          0.675          0.076
[0m[0;37m[INFO]            9300          0.510          0.107          0.692          0.074
[0m[0;37m[INFO]            9400          0.500          0.

[0m[0;37m[INFO]           17300          0.571          0.100          0.689          0.074
[0m[0;37m[INFO]           17400          0.562          0.101          0.660          0.077
[0m[0;37m[INFO]           17500          0.650          0.090          0.651          0.079
[0m[0;37m[INFO]           17600          0.671          0.087          0.787          0.061
[0m[0;37m[INFO]           17700          0.629          0.093          0.787          0.061
[0m[0;37m[INFO]           17800          0.671          0.087          0.679          0.075
[0m[0;37m[INFO]           17900          0.599          0.096          0.731          0.069
[0m[0;37m[INFO]           18000          0.652          0.090          0.747          0.067
[0m[0;37m[INFO]           18100          0.650          0.090          0.642          0.080
[0m[0;37m[INFO]           18200          0.608          0.095          0.748          0.067
[0m[0;37m[INFO]           18300          0.732          0.

[0m[0;37m[INFO]           26200          0.681          0.086          0.731          0.069
[0m[0;37m[INFO]           26300          0.583          0.098          0.611          0.083
[0m[0;37m[INFO]           26400          0.694          0.084          0.699          0.073
[0m[0;37m[INFO]           26500          0.651          0.090          0.716          0.071
[0m[0;37m[INFO]           26600          0.639          0.092          0.720          0.070
[0m[0;37m[INFO]           26700          0.674          0.087          0.665          0.077
[0m[0;37m[INFO]           26800          0.592          0.097          0.653          0.078
[0m[0;37m[INFO]           26900          0.617          0.094          0.729          0.069
[0m[0;37m[INFO]           27000          0.630          0.093          0.672          0.076
[0m[0;37m[INFO]           27100          0.574          0.099          0.674          0.076
[0m[0;37m[INFO]           27200          0.626          0.

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -0.577          0.165         -0.689          0.171 *
[0m[0;37m[INFO]             100         -0.010          0.132         -0.021          0.133 *
[0m[0;37m[INFO]             200          0.069          0.127          0.053          0.128 *
[0m[0;37m[INFO]             300          0.476          0.095          0.301          0.110 *
[0m[0;37m[INFO]             400          0.465          0.096          0.349          0.106 *
[0m[0;37m[INFO]             500          0.469          0.096          0.415          0.101 *
[0m[0;37m[INFO]             600          0.461          0.097          0.402          0.102
[0m[0;37m[INFO]             700          0.463          0.096          0.408          0.101
[0m[0;37m[INFO]             800          0.495          0.093          0.430          0.099 *
[0m[0;37m[INFO]             900          0.4

[0m[0;37m[INFO]            8400          0.550          0.088          0.539          0.089
[0m[0;37m[INFO]            8500          0.541          0.089          0.519          0.091
[0m[0;37m[INFO]            8600          0.542          0.089          0.561          0.087
[0m[0;37m[INFO]            8700          0.551          0.088          0.532          0.090
[0m[0;37m[INFO]            8800          0.563          0.087          0.574          0.086 *
[0m[0;37m[INFO]            8900          0.547          0.088          0.551          0.088
[0m[0;37m[INFO]            9000          0.585          0.085          0.559          0.087
[0m[0;37m[INFO]            9100          0.545          0.089          0.534          0.090
[0m[0;37m[INFO]            9200          0.540          0.089          0.545          0.089
[0m[0;37m[INFO]            9300          0.557          0.088          0.545          0.089
[0m[0;37m[INFO]            9400          0.563          

[0m[0;37m[INFO]           17100          0.622          0.081          0.587          0.084
[0m[0;37m[INFO]           17200          0.651          0.078          0.593          0.084
[0m[0;37m[INFO]           17300          0.659          0.077          0.589          0.084
[0m[0;37m[INFO]           17400          0.626          0.080          0.608          0.082
[0m[0;37m[INFO]           17500          0.635          0.079          0.601          0.083
[0m[0;37m[INFO]           17600          0.625          0.081          0.602          0.083
[0m[0;37m[INFO]           17700          0.613          0.082          0.587          0.084
[0m[0;37m[INFO]           17800          0.637          0.079          0.587          0.084
[0m[0;37m[INFO]           17900          0.620          0.081          0.595          0.084
[0m[0;37m[INFO]           18000          0.650          0.078          0.590          0.084
[0m[0;37m[INFO]           18100          0.630          0.

[0m[0;37m[INFO] Performance statistics based on original data:
[0m[0;37m[INFO] Train R2   Score: 0.7353
[0m[0;37m[INFO] Test  R2   Score: 0.6781
[0m[0;37m[INFO] Train RMSD Score: 0.0149
[0m[0;37m[INFO] Test  RMSD Score: 0.0164

[0mDATASET :  oer_plate_3860
<Emulator (Dataset(kind=oer_plate_3860), model=
--> batch_size:    30
--> es_patience:   100
--> hidden_act:    leaky_relu
--> hidden_depth:  3
--> hidden_nodes:  28
--> kind:          BayesNeuralNet
--> learning_rate: 0.0013492835719124958
--> max_epochs:    100000
--> out_act:       sigmoid
--> pred_int:      100
--> reg:           0.002340189417421784
--> scope:         model)>
[0;37m[INFO] >>> Training model on 80% of the dataset, testing on 20%...
[0m

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -1.683          0.263         -1.167          0.233 *
[0m[0;37m[INFO]             100          0.502          0.113          0.364          0.126 *
[0m[0;37m[INFO]             200          0.710          0.086          0.768          0.076 *
[0m[0;37m[INFO]             300          0.792          0.073          0.798          0.071 *
[0m[0;37m[INFO]             400          0.811          0.070          0.814          0.068 *
[0m[0;37m[INFO]             500          0.834          0.065          0.809          0.069
[0m[0;37m[INFO]             600          0.821          0.068          0.849          0.062 *
[0m[0;37m[INFO]             700          0.848          0.063          0.848          0.062
[0m[0;37m[INFO]             800          0.856          0.061          0.858          0.060 *
[0m[0;37m[INFO]             900          0.8

[0m[0;37m[INFO]            8400          0.950          0.036          0.934          0.041
[0m[0;37m[INFO]            8500          0.949          0.036          0.929          0.042
[0m[0;37m[INFO]            8600          0.927          0.043          0.926          0.043
[0m[0;37m[INFO]            8700          0.951          0.035          0.924          0.044
[0m[0;37m[INFO]            8800          0.939          0.040          0.930          0.042
[0m[0;37m[INFO]            8900          0.938          0.040          0.922          0.044
[0m[0;37m[INFO]            9000          0.935          0.041          0.931          0.042
[0m[0;37m[INFO]            9100          0.920          0.045          0.933          0.041
[0m[0;37m[INFO]            9200          0.932          0.042          0.921          0.045
[0m[0;37m[INFO]            9300          0.935          0.041          0.930          0.042
[0m[0;37m[INFO]            9400          0.939          0.

[0m[0;37m[INFO]           17200          0.937          0.040          0.940          0.039
[0m[0;37m[INFO]           17300          0.945          0.038          0.937          0.040
[0m[0;37m[INFO]           17400          0.950          0.036          0.936          0.040
[0m[0;37m[INFO]           17500          0.948          0.037          0.936          0.040
[0m[0;37m[INFO]           17600          0.937          0.040          0.943          0.038
[0m[0;37m[INFO]           17700          0.928          0.043          0.915          0.046
[0m[0;37m[INFO]           17800          0.940          0.039          0.939          0.039
[0m[0;37m[INFO]           17900          0.941          0.039          0.940          0.039
[0m[0;37m[INFO]           18000          0.946          0.037          0.935          0.040
[0m[0;37m[INFO]           18100          0.952          0.035          0.945          0.037
[0m[0;37m[INFO]           18200          0.949          0.

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -6.290          0.279         -4.592          0.246 *
[0m[0;37m[INFO]             100          0.313          0.086          0.477          0.075 *
[0m[0;37m[INFO]             200          0.445          0.077          0.628          0.063 *
[0m[0;37m[INFO]             300          0.437          0.078          0.554          0.070
[0m[0;37m[INFO]             400          0.478          0.075          0.632          0.063 *
[0m[0;37m[INFO]             500          0.508          0.072          0.545          0.070
[0m[0;37m[INFO]             600          0.530          0.071          0.579          0.068
[0m[0;37m[INFO]             700          0.488          0.074          0.547          0.070
[0m[0;37m[INFO]             800          0.515          0.072          0.589          0.067
[0m[0;37m[INFO]             900          0.542    

[0m[0;37m[INFO]            8500          0.489          0.074          0.663          0.060
[0m[0;37m[INFO]            8600          0.501          0.073          0.702          0.057
[0m[0;37m[INFO]            8700          0.572          0.068          0.640          0.062
[0m[0;37m[INFO]            8800          0.438          0.077          0.641          0.062
[0m[0;37m[INFO]            8900          0.611          0.064          0.731          0.054
[0m[0;37m[INFO]            9000          0.542          0.070          0.680          0.059
[0m[0;37m[INFO]            9100          0.536          0.070          0.768          0.050
[0m[0;37m[INFO]            9200          0.477          0.075          0.759          0.051
[0m[0;37m[INFO]            9300          0.451          0.077          0.717          0.055
[0m[0;37m[INFO]            9400          0.476          0.075          0.761          0.051
[0m[0;37m[INFO]            9500          0.563          0.

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -1.082          0.445         -1.006          0.142 *
[0m[0;37m[INFO]             100         -1.065          0.444         -0.968          0.140 *
[0m[0;37m[INFO]             200          0.614          0.192         -1.444          0.156
[0m[0;37m[INFO]             300          0.643          0.184         -2.648          0.191
[0m[0;37m[INFO]             400          0.636          0.186         -3.239          0.206
[0m[0;37m[INFO]             500          0.667          0.178         -2.692          0.192
[0m[0;37m[INFO]             600          0.649          0.183         -1.844          0.169
[0m[0;37m[INFO]             700          0.658          0.181         -3.575          0.214
[0m[0;37m[INFO]             800          0.682          0.174         -4.270          0.230
[0m[0;37m[INFO]             900          0.653        

[0m[0;37m[INFO]            8600          0.677          0.175         -1.573          0.160
[0m[0;37m[INFO]            8700          0.681          0.174         -1.642          0.163
[0m[0;37m[INFO]            8800          0.697          0.170         -1.804          0.167
[0m[0;37m[INFO]            8900          0.663          0.179         -1.332          0.153
[0m[0;37m[INFO]            9000          0.699          0.169         -2.049          0.175
[0m[0;37m[INFO]            9100          0.707          0.167         -2.397          0.184
[0m[0;37m[INFO]            9200          0.645          0.184         -1.384          0.154
[0m[0;37m[INFO]            9300          0.707          0.167         -2.088          0.176
[0m[0;37m[INFO]            9400          0.694          0.171         -2.102          0.176
[0m[0;37m[INFO]            9500          0.703          0.168         -1.875          0.170
[0m[0;37m[INFO]            9600          0.676          0.

[1;31m[ERROR] Lower bound of 9.999518096000001 provided for parameter `q_pva` is higher than minimum found in the data!
[0m[1;31m[ERROR] Lower bound of 0.498851653 provided for parameter `q_seed` is higher than minimum found in the data!
  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0         -0.224          0.281          0.010          0.220 *
[0m[0;37m[INFO]             100          0.674          0.145          0.449          0.164 *
[0m[0;37m[INFO]             200          0.793          0.116          0.621          0.136 *
[0m[0;37m[INFO]             300          0.813          0.110          0.726          0.116 *
[0m[0;37m[INFO]             400          0.800          0.114          0.779          0.104 *
[0m[0;37m[INFO]             500          0.863          0.094          0.690          0.123
[0m[0;37m[INFO]             600          0.847          0.099          0.728          0.115
[0m[0;37m[INFO]             700          0.723          0.134          0.612          0.138
[0m[0;37m[INFO]             800          0.897          0.082          0.720          0.117
[0m[0;37m[INFO]             900          0.835  

[0m[0;37m[INFO]            8500          0.973          0.042          0.921          0.062
[0m[0;37m[INFO]            8600          0.971          0.043          0.944          0.052
[0m[0;37m[INFO]            8700          0.967          0.046          0.925          0.060
[0m[0;37m[INFO]            8800          0.965          0.048          0.949          0.050 *
[0m[0;37m[INFO]            8900          0.968          0.045          0.946          0.052
[0m[0;37m[INFO]            9000          0.973          0.042          0.919          0.063
[0m[0;37m[INFO]            9100          0.969          0.045          0.952          0.049 *
[0m[0;37m[INFO]            9200          0.971          0.043          0.942          0.053
[0m[0;37m[INFO]            9300          0.969          0.045          0.946          0.051
[0m[0;37m[INFO]            9400          0.969          0.044          0.946          0.051
[0m[0;37m[INFO]            9500          0.969        

[0m[0;37m[INFO]           17200          0.978          0.038          0.961          0.043
[0m[0;37m[INFO]           17300          0.977          0.039          0.966          0.041
[0m[0;37m[INFO]           17400          0.975          0.040          0.959          0.045
[0m[0;37m[INFO]           17500          0.972          0.043          0.957          0.046
[0m[0;37m[INFO]           17600          0.977          0.039          0.957          0.046
[0m[0;37m[INFO]           17700          0.973          0.042          0.960          0.044
[0m[0;37m[INFO]           17800          0.976          0.040          0.958          0.045
[0m[0;37m[INFO]           17900          0.977          0.039          0.963          0.043
[0m[0;37m[INFO]           18000          0.972          0.042          0.965          0.041
[0m[0;37m[INFO]           18100          0.974          0.041          0.962          0.043
[0m[0;37m[INFO]           18200          0.977          0.

[0m

DATASET :  thin_film
<Emulator (Dataset(kind=thin_film), model=
--> batch_size:    50
--> es_patience:   100
--> hidden_act:    leaky_relu
--> hidden_depth:  5
--> hidden_nodes:  60
--> kind:          BayesNeuralNet
--> learning_rate: 0.0041221056953923575
--> max_epochs:    100000
--> out_act:       relu
--> pred_int:      100
--> reg:           0.05223118992361589
--> scope:         model)>
[0;37m[INFO] >>> Training model on 80% of the dataset, testing on 20%...
[0m

  trainable=trainable)
  trainable=trainable)


[0m[0;37m[INFO]           Epoch       Train R2     Train RMSD        Test R2      Test RMSD
[0m[0;37m[INFO]               0          0.008          0.200         -0.004          0.193 *
[0m[0;37m[INFO]             100          0.633          0.121          0.805          0.085 *
[0m[0;37m[INFO]             200          0.689          0.112          0.831          0.079 *
[0m[0;37m[INFO]             300          0.679          0.114          0.815          0.083
[0m[0;37m[INFO]             400          0.670          0.115          0.757          0.095
[0m[0;37m[INFO]             500          0.706          0.109          0.836          0.078 *
[0m[0;37m[INFO]             600          0.760          0.098          0.764          0.094
[0m[0;37m[INFO]             700          0.768          0.096          0.829          0.080
[0m[0;37m[INFO]             800          0.803          0.089          0.834          0.078
[0m[0;37m[INFO]             900          0.823    

KeyboardInterrupt: 

In [None]:
# Olympus reference color palette (hex strings). The parity plots below use
# "#75BBE1" for training points and "#EB0789" for test points, but write them
# as literals rather than reading them from this list.
_olympus_reference_colors = [
    "#08294C",
    "#75BBE1",
    "#D4E9F4",
    "#F2F2F2",
    "#F7A4D4",
    "#F75BB6",
    "#EB0789",
]

# Human-readable target label(s) per dataset, used in the parity-plot axis
# labels. Each Suzuki dataset has two target labels, stored as a list.
dataset_targets = dict(
    [
        (name, 'overpotential [V]')
        for name in ('oer_plate_4098', 'oer_plate_3851', 'oer_plate_3860', 'oer_plate_3496')
    ]
    + [
        ('p3ht', 'conductivity'),
        ('agnp', 'spectrum score'),
        ('thin_film', 'instability index'),
        ('crossed_barrel', 'mechanical toughness'),
        ('autoam', 'shape score'),
    ]
    + [
        (name, ['yield [%]', '[prod/cat]'])
        for name in ('suzuki_i', 'suzuki_ii', 'suzuki_iii', 'suzuki_iv')
    ]
)

In [None]:
# Parity plots (true vs. predicted targets) for the single-objective datasets.
# The Suzuki datasets have two target labels each (see `dataset_targets`) and
# are skipped here.

single_obj_names = [name for name in dataset_names if 'suzuki_' not in name]

n_cols = 5
n_rows = max(1, -(-len(single_obj_names) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 3 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, dataset_name in zip(axes, single_obj_names):
    # Prefer the emulator trained earlier in this notebook. Otherwise fall back
    # to Emulator(dataset=<name>, model=<name>), which presumably loads a
    # pre-trained emulator shipped with olympus -- TODO confirm.
    emulator = dataset_params.get(dataset_name, {}).get('emulator')
    if emulator is None:
        emulator = Emulator(dataset=dataset_name, model='BayesNeuralNet')
    print('DATASET : ', dataset_name)
    print(emulator)

    # Evaluate on the emulator's own Dataset when it exposes one. A freshly
    # built Dataset(kind=...) draws a different random train/test split
    # (compare ix0/ix1 at the end of the notebook), which would mix training
    # points into the "test" scatter.
    d = getattr(emulator, 'dataset', None)
    if d is None:
        d = Dataset(kind=dataset_name)

    train_params = d.train_set_features.to_numpy()
    train_values = d.train_set_targets.to_numpy()
    test_params = d.test_set_features.to_numpy()
    test_values = d.test_set_targets.to_numpy()

    train_preds = emulator.run(train_params, num_samples=50)
    test_preds = emulator.run(test_params, num_samples=50)

    # training data
    ax.plot(
        train_values,
        train_preds,
        c="#75BBE1",
        ls='',
        marker='o',
        markersize=4,
        label='train',
    )

    # test data
    ax.plot(
        test_values,
        test_preds,
        c="#EB0789",
        ls='',
        marker='o',
        markersize=4,
        label='test',
    )

    ax.set_title(dataset_name)
    ax.set_xlabel(f'true {dataset_targets[dataset_name]}')
    ax.set_ylabel(f'pred {dataset_targets[dataset_name]}')

if single_obj_names:
    axes[0].legend(loc='best')

# Hide grid cells left over after the last dataset.
for ax in axes[len(single_obj_names):]:
    ax.set_visible(False)

fig.tight_layout()
        

In [None]:
# Scores returned by emulator.train() for p3ht, as cached by the training loop above.
dataset_params['p3ht']['scores']

In [None]:
# Build the p3ht BayesNeuralNet emulator from names and show it. The variable
# was previously misspelled (`emualtor`), so print() showed whatever object an
# earlier cell had left in `emulator`.
emulator = Emulator(dataset='p3ht', model='BayesNeuralNet')
print(emulator)

In [9]:
# Instantiate the p3ht Dataset here and again in the next cell, then compare
# test_indices. The outputs below show that the two test splits differ.
d = Dataset(kind='p3ht')
ix0 = d.test_indices

In [10]:
# Second, independent instantiation of the same dataset -- compare with ix0.
d = Dataset(kind='p3ht')
ix1 = d.test_indices

In [11]:
# Test-set indices from the first Dataset instantiation.
ix0

array([158, 134,  32, 122,  74,  75, 119,  28,  17,  68, 163, 137, 160,
        62, 118,  45,  79,   6,  82, 133,   2, 177, 107, 173, 138,  31,
       168,  55, 171, 125, 110,  47, 142, 115, 147,  66])

In [12]:
# Test-set indices from the second instantiation; they differ from ix0.
ix1

array([ 88,  12, 162, 148,  77,  51,   3,  36,  86, 135, 130,  81,  11,
         6, 176, 103,  61, 154,  14,  62,  91,  30,  70,  29,  87, 150,
       106,  26,  72,  35,  44, 119,  19, 132,  37, 174])