In [1]:
import json
import jax
import jax.numpy as jnp
import surrojax_gp
import oed_toolbox

In [2]:
def load_beam_data(file_dir):
    """Load a beam training set from a JSON file.

    Two key layouts are recognised and tried in this order:

    * inputs ``'E'``, ``'y_rot'``      -> output ``'end_disp'``
    * inputs ``'E'``, ``'Beam Angle'`` -> output ``'End Displacement'``

    Args:
        file_dir: Path to the JSON file (a file, despite the parameter name).

    Returns:
        dict with
          ``'x'``: the input columns stacked along the last axis, shape ``(n, 2)``;
          ``'y'``: the output reshaped to a column vector, shape ``(n, 1)``.

    Raises:
        KeyError: if the file contains all keys of neither layout.
    """
    # (input column keys, output key) for each known layout, in priority order.
    schemas = (
        (('E', 'y_rot'), 'end_disp'),
        (('E', 'Beam Angle'), 'End Displacement'),
    )
    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(file_dir, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for input_keys, output_key in schemas:
        if all(key in data for key in (*input_keys, output_key)):
            x = jnp.stack([jnp.array(data[key]) for key in input_keys], axis=-1)
            y = jnp.array(data[output_key]).reshape(-1, 1)
            return {'x': x, 'y': y}
    # Name every layout that was tried, not just the last missing key.
    raise KeyError(
        f"{file_dir!r} matches no known beam-data schema; expected keys "
        f"{[[*inputs, output] for inputs, output in schemas]}"
    )

def scale_data(data):
    """Min-max scale every array in ``data`` to [0, 1], column by column.

    Args:
        data: dict mapping names to arrays. Each array is scaled along axis 0,
            so every column is scaled independently.

    Returns:
        A new dict with the same keys holding the scaled arrays. The input
        dict is not modified. The per-column min/range used for scaling are
        not returned, so scaled values cannot be mapped back from this result.

    Notes:
        A column whose min equals its max (a constant column) is mapped to 0
        instead of the NaN that the 0/0 division would otherwise produce.
    """
    scaled = {}
    for key, val in data.items():
        col_min = jnp.min(val, axis=0)
        col_range = jnp.max(val, axis=0) - col_min
        # Constant columns would give 0/0 = NaN and poison downstream GP fitting.
        col_range = jnp.where(col_range == 0, 1, col_range)
        scaled[key] = (val - col_min) / col_range
    return scaled

# Load both beam training sets and rescale each one to [0, 1] column-wise.
linear_data, nonlinear_data = (
    scale_data(load_beam_data(path))
    for path in ('linear_beam_train.json', 'nonlinear_beam_train.json')
)



In [3]:
def kernel(x_1, x_2, params):
    """Squared-exponential kernel with one length parameter per input dimension.

        k(x_1, x_2) = const * exp(-0.5 * sum_i (x_1[i] - x_2[i])**2 / length_i)

    Each ``length_i`` divides the *squared* difference directly, so it acts as
    a squared length-scale (the l_i**2 of the usual ARD form), not l_i.

    Args:
        x_1, x_2: single input points, 1-D arrays of equal length ``d``.
        params: dict with ``"const"`` and ``"length_0"`` ... ``"length_{d-1}"``.

    Returns:
        The scalar kernel value.
    """
    # The dimension is read from the input instead of being hard-coded to 2.
    num_dims = jnp.shape(x_1)[-1]
    lengths = jnp.array([params[f"length_{i}"] for i in range(num_dims)])
    diff = x_1 - x_2
    # Elementwise division equals diff.T @ diag(1/lengths) @ diff without
    # materialising a d x d matrix.
    ln_k_d = -0.5 * jnp.sum(diff**2 / lengths)
    return params["const"] * jnp.exp(ln_k_d)
# Box constraints on the kernel hyperparameters: both length parameters share
# the same bounds; the amplitude "const" gets a wider upper bound.
constraints = {
    **{f"length_{i}": {">": 10**-1, "<": 10**1} for i in range(2)},
    "const": {">": 10**-1, "<": 10**2},
}
# Fit one GP surrogate per training set, with the same kernel and constraints.
linear_gp, nonlinear_gp = (
    surrojax_gp.create_gp(kernel, dataset['x'], dataset['y'], constraints)
    for dataset in (linear_data, nonlinear_data)
)

-93.08633
-133.60971
-95.6327
-156.34497
-240.05603
-195.32083
-245.65036
-172.73447
-256.8095
-261.1317
-263.31085
-264.59674
-265.4391
-263.36307
-266.41797
-266.0668
-267.96014
-267.65726
-264.6313
-264.3997
-267.65726
-267.65726
-268.13397
-268.13397
-268.13397
-265.8656
-268.13397
-268.13397
-267.3597
-268.13397
-268.13397
-265.55286
-268.13397
-268.13397
-266.26984
-268.13397
-267.30255
-267.0605
-266.83017
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
-268.13397
      fun: array(-268.13397217)
 hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 0.07310563,  5.18304443, -4.95946884])
  message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 55
      nit: 11
     njev: 55
   status: 0
  success: True
        x: array([44.64443157,  0.63340229,  0.70003374])
-141.74521
-134.26138
897.66437
-141.81198
-141.58661
-141.81198
-141.8

In [4]:
# Wrap each fitted GP surrogate as an oed_toolbox model.
linear_beam, nonlinear_beam = (
    oed_toolbox.models.Model.from_surrojax_gp(gp) for gp in (linear_gp, nonlinear_gp)
)

In [22]:
# Observation-noise covariance and Gaussian prior; everything here is 1 x 1
# (covariances) or length 1 (mean).
noise_cov = jnp.eye(1) * 0.1
prior_cov = jnp.eye(1) * 0.2
prior_mean = jnp.full((1,), 0.35)

In [23]:
# APE loss for the nonlinear beam using a Laplace approximation (MAP found by
# gradient descent), optimised over the design with Adam from d_0 = 0.5.
# NOTE(review): the saved run below was interrupted (KeyboardInterrupt); over
# its first 19 iterations the loss oscillates between about -0.267 and -0.294.
minimizer = oed_toolbox.optim.gradient_descent_for_map()
ape = oed_toolbox.losses.APE.using_laplace_approximation(
    nonlinear_beam, minimizer, prior_mean, prior_cov, noise_cov,
    # TODO: `use_reparameterisation=True` was left commented out here; decide
    # whether to enable it and record the outcome.
)
adam = oed_toolbox.optim.adam_for_oed_loss(lr=1e-2)
d_0 = jnp.array([0.5])
num_samples = 100
rng = 19  # integer passed as the optimiser's `rng` argument (presumably a seed — confirm)
adam(ape, d_0, num_samples, rng, verbose=True)

Iteration 1: Loss = [-0.2935892], x = [0.51]
Iteration 2: Loss = [-0.28550205], x = [0.5194105]
Iteration 3: Loss = [-0.27736187], x = [0.5269654]
Iteration 4: Loss = [-0.27142977], x = [0.5313043]
Iteration 5: Loss = [-0.26718264], x = [0.53218794]
Iteration 6: Loss = [-0.26655979], x = [0.5307567]
Iteration 7: Loss = [-0.26745393], x = [0.52776676]
Iteration 8: Loss = [-0.27070768], x = [0.52403164]
Iteration 9: Loss = [-0.27378957], x = [0.52025443]
Iteration 10: Loss = [-0.27695231], x = [0.51696104]
Iteration 11: Loss = [-0.27966165], x = [0.5146551]
Iteration 12: Loss = [-0.28156678], x = [0.51355875]
Iteration 13: Loss = [-0.28197101], x = [0.51372266]
Iteration 14: Loss = [-0.28241868], x = [0.51486605]
Iteration 15: Loss = [-0.28174064], x = [0.5167316]
Iteration 16: Loss = [-0.280432], x = [0.5189811]
Iteration 17: Loss = [-0.27819752], x = [0.5212678]
Iteration 18: Loss = [-0.27523239], x = [0.5232789]
Iteration 19: Loss = [-0.2735789], x = [0.52472943]
Iteration 20: Loss = 

KeyboardInterrupt: 

In [26]:
# D-optimal loss built from the Fisher information of the nonlinear beam under
# constant Gaussian noise, evaluated at theta_0 and optimised with Adam from d_0.
# NOTE(review): in the saved output d grows by ~0.01 (= lr) every iteration and
# the loss keeps decreasing, so the optimum may sit at or beyond the edge of the
# scaled design range — confirm. theta_0 = 0.01 also differs from prior_mean
# (0.35) — confirm this evaluation point is intended.
adam = oed_toolbox.optim.adam_for_oed_loss(lr=1e-2)
d_0, theta_0 = jnp.array([0.5]), 0.01
num_samples, rng = 100, 19

likelihood = oed_toolbox.distributions.Likelihood.from_model_plus_constant_gaussian_noise(
    nonlinear_beam, noise_cov
)
fisher_info = oed_toolbox.covariances.FisherInformation(
    likelihood, use_reparameterisation=True
)
d_opt = oed_toolbox.losses.D_Optimal(fisher_info)

adam(d_opt, d_0, num_samples, rng, verbose=True, args=(theta_0,))

Iteration 1: Loss = -34.82984924316406, x = [0.51]
Iteration 2: Loss = -36.699886322021484, x = [0.5200062]
Iteration 3: Loss = -38.607887268066406, x = [0.5300216]
Iteration 4: Loss = -40.5485725402832, x = [0.5400486]
Iteration 5: Loss = -42.49846267700195, x = [0.5500881]
Iteration 6: Loss = -44.57035827636719, x = [0.56014067]
Iteration 7: Loss = -46.59259796142578, x = [0.570205]
Iteration 8: Loss = -48.6211051940918, x = [0.580278]
Iteration 9: Loss = -50.66379928588867, x = [0.59035516]
Iteration 10: Loss = -52.697540283203125, x = [0.60042995]
Iteration 11: Loss = -54.70115661621094, x = [0.61049324]
Iteration 12: Loss = -56.69696807861328, x = [0.62053376]
Iteration 13: Loss = -58.57734298706055, x = [0.6305366]
Iteration 14: Loss = -60.39044952392578, x = [0.6404834]
Iteration 15: Loss = -62.12944793701172, x = [0.6503518]
Iteration 16: Loss = -63.73953628540039, x = [0.6601151]
Iteration 17: Loss = -65.22236633300781, x = [0.6697416]
Iteration 18: Loss = -66.57312774658203, 

DeviceArray([0.7308018], dtype=float32)

In [30]:
# D-optimal loss on the *predictive* covariance of the linear beam, evaluated
# at theta_0 and optimised with Adam from d_0 = 0.8 using 10 samples.
# NOTE(review): in the saved output the loss stays at -0.2128937 (to ~1e-8)
# while d moves from 0.81 to 0.86, so this objective looks flat in d for this
# model. The returned design (0.8097946) equals the first iterate. Confirm the
# criterion is informative before relying on the result.
adam = oed_toolbox.optim.adam_for_oed_loss(lr=1e-2)
d_0, theta_0 = jnp.array([0.8]), 0.01
num_samples, rng = 10, 19

model = linear_beam
likelihood = oed_toolbox.distributions.Likelihood.from_model_plus_constant_gaussian_noise(
    model, noise_cov
)
fisher_info = oed_toolbox.covariances.FisherInformation(
    likelihood, use_reparameterisation=True
)
pred_cov = oed_toolbox.covariances.PredictiveCovariance(model, fisher_info)
d_opt = oed_toolbox.losses.D_Optimal(pred_cov)

adam(d_opt, d_0, num_samples, rng, verbose=True, args=(theta_0,))

Iteration 1: Loss = -0.21289372444152832, x = [0.8097946]
Iteration 2: Loss = -0.21289367973804474, x = [0.8188755]
Iteration 3: Loss = -0.21289370954036713, x = [0.8281681]
Iteration 4: Loss = -0.21289370954036713, x = [0.8366558]
Iteration 5: Loss = -0.21289367973804474, x = [0.8438047]
Iteration 6: Loss = -0.21289369463920593, x = [0.84580326]
Iteration 7: Loss = -0.21289369463920593, x = [0.84886634]
Iteration 8: Loss = -0.21289366483688354, x = [0.85156095]
Iteration 9: Loss = -0.21289366483688354, x = [0.85394835]
Iteration 10: Loss = -0.21289369463920593, x = [0.85472083]
Iteration 11: Loss = -0.21289370954036713, x = [0.85541195]
Iteration 12: Loss = -0.21289367973804474, x = [0.8560323]
Iteration 13: Loss = -0.21289366483688354, x = [0.85779417]
Iteration 14: Loss = -0.21289366483688354, x = [0.85814667]
Iteration 15: Loss = -0.21289369463920593, x = [0.8596114]
Iteration 16: Loss = -0.21289367973804474, x = [0.8620293]
Iteration 17: Loss = -0.21289367973804474, x = [0.8642205

DeviceArray([0.8097946], dtype=float32)