### Checking r2 of synthetic functions

In [9]:
import torch
from mf_kmc.simulations.implementations.park.park import Park
from botorch.test_functions import AugmentedHartmann, AugmentedBranin
from sklearn.metrics import r2_score, root_mean_squared_error

torch.manual_seed(33);

In [48]:
samps_park = torch.rand(size=(100,4))
samps_branin = torch.rand(size=(100, 2)) 
samps_hartmann = torch.rand(size=(100, 6))

### Park


In [17]:
p = Park()

fids = [0.2, 0.3, 0.37, 0.4, 0.46, 0.5, 0.6]

for l in fids:

    s_lf = torch.cat((samps_park, torch.ones(samps_park.size()[0]).unsqueeze(1)*l), dim=1)
    s_hf = torch.cat((samps_park, torch.ones(samps_park.size()[0]).unsqueeze(1)*1.0), dim=1)
    s = p(s_lf).detach().numpy()
    s_true = p(s_hf).detach().numpy() 

    r2 = r2_score(s, s_true)
    mse = root_mean_squared_error(s, s_true)

    print(f"fid {l}-> r2:{r2} mse: {mse}")

fid 0.2-> r2:-0.18424156267148417 mse: 3.458242654800415
fid 0.3-> r2:0.19125956140869615 mse: 3.0259616374969482
fid 0.37-> r2:0.3933111370645136 mse: 2.723365545272827
fid 0.4-> r2:0.46710164703573753 mse: 2.593681812286377
fid 0.46-> r2:0.5946507555172562 mse: 2.33431339263916
fid 0.5-> r2:0.6664187371820292 mse: 2.1614015102386475
fid 0.6-> r2:0.8066593983389008 mse: 1.7291208505630493


### Branin

In [28]:
f = AugmentedBranin()

bounds = torch.tensor(f._bounds)[:2, :]

In [31]:
bounds[:, 1]

tensor([10., 15.])

In [65]:
f = AugmentedBranin()

samps = samps_branin * (bounds[:, 1] - bounds[:, 0]) + bounds[:, 0]

fids = [-0.29, 0.03, 0.3, 0.54]

for l in fids:

    s_lf = torch.cat((samps, torch.ones(samps.size()[0]).unsqueeze(1)*l), dim=1)
    s_hf = torch.cat((samps, torch.ones(samps.size()[0]).unsqueeze(1)*1.0), dim=1)
    s = f(s_lf).detach().numpy()
    s_true = f(s_hf).detach().numpy() 

    r2 = r2_score(s, s_true)
    mse = root_mean_squared_error(s, s_true)

    print(f"fid {l}-> r2:{r2} mse: {mse}")

fid -0.29-> r2:0.20694481205919268 mse: 107.81169128417969
fid 0.03-> r2:0.40611345513564845 mse: 73.32196807861328
fid 0.3-> r2:0.6149813599407541 mse: 48.40953063964844
fid 0.54-> r2:0.8055106761867619 mse: 29.3226318359375


In [67]:
p = AugmentedHartmann()

fids = [0.2, 0.3, 0.37, 0.4, 0.46, 0.5, 0.6]

for l in fids:

    s_lf = torch.cat((samps_hartmann, torch.ones(samps_hartmann.size()[0]).unsqueeze(1)*l), dim=1)
    s_hf = torch.cat((samps_hartmann, torch.ones(samps_hartmann.size()[0]).unsqueeze(1)*1.0), dim=1)
    s = p(s_lf).detach().numpy()
    s_true = p(s_hf).detach().numpy() 

    r2 = r2_score(s, s_true)
    mse = root_mean_squared_error(s, s_true)

    print(f"fid {l}-> r2:{r2} mse: {mse}")

fid 0.2-> r2:0.9997885319682924 mse: 0.0050696711987257
fid 0.3-> r2:0.9998383007244377 mse: 0.004435963463038206
fid 0.37-> r2:0.9998691402803723 mse: 0.003992372192442417
fid 0.4-> r2:0.9998813519467383 mse: 0.0038022585213184357
fid 0.46-> r2:0.999903968974162 mse: 0.00342203164473176
fid 0.5-> r2:0.9999177113804464 mse: 0.0031685426365584135
fid 0.6-> r2:0.9999474029011272 mse: 0.002534835832193494
