In [1]:
import jax.numpy as np
from jax import random, grad, vmap, jit
from jax.experimental.ode import odeint
from jax.config import config


import matplotlib
import matplotlib.pyplot as plt
import numpy as onp


In [2]:
plt.rcParams.update(plt.rcParamsDefault)
plt.rc('font', family='serif')
plt.rcParams.update({
                      "text.usetex": True,
                      "font.family": "serif",
                     'text.latex.preamble': r'\usepackage{amsmath}',
                      'font.size': 16,
                      'lines.linewidth': 3,
                      'axes.labelsize': 16, 
                      'axes.titlesize': 16,
                      'xtick.labelsize': 16,
                      'ytick.labelsize': 16,
                      'legend.fontsize': 16,
                      'axes.linewidth': 2})

In [3]:

# Use double precision to generate data (due to GP sampling)
config.update("jax_enable_x64", True)

# Training data
N = 100
m = 100 # number of input sensors
P = 100   # number of output sensors
K = 10
Output_scales = np.linspace(-2, 2, K)  # K is the number of different output scales

y_test = np.load("y_test.npy")
s_test = np.load("s_test.npy")
s_pred = np.load("s_pred.npy")




In [4]:


s_pred_mu, s_pred_std = np.mean(s_pred, axis = 0)[:,None], np.std(s_pred, axis = 0)[:,None]
print(s_pred_mu.shape, s_pred_std.shape)



# Plot a sample test example
idx = 835
index = np.arange(idx*P,(idx+1)*P)
plt.figure()
for k in range(1, 20):
    plt.plot(y_test[index, :], s_pred[k,index], 'r--', lw=2)
plt.plot(y_test[index, :], s_pred[0,index], 'r--', lw=2, label = "Predicted sample")
plt.plot(y_test[index, :], s_test[index, :], 'b-', lw=2, label = "Exact")
plt.plot(y_test[index, :], s_pred_mu[index, :], 'k--', lw=2, label = "Predicted mean")
plt.legend(loc='upper right', frameon=False, prop={'size': 13})
plt.xlabel('y')
plt.ylabel('G(u)(y)')
plt.tight_layout()
plt.savefig('./Samples.png', dpi = 300)


# Compute the errors and the uncertainty
s_pred_mu, s_pred_std

N_test_total = s_pred_mu.shape[0] // P
N_test = N_test_total // K 

errors = onp.zeros((K, N_test))
uncertainty = onp.zeros((K, N_test))

for idx in range(N_test_total):
    id1 = idx // N_test
    id2 = idx - id1 * N_test
    index = np.arange(idx*P,(idx+1)*P)
    s_pred_sample = s_pred_mu[index,:]
    s_pred_uncertainty = s_pred_std[index,:]
    s_test_sample = s_test[index,:]

    # print(id1, id2, s_pred_sample.shape, s_test_sample.shape, s_pred_uncertainty.shape)

    errors[id1, id2] = np.linalg.norm(s_pred_sample - s_test_sample, 2) / np.linalg.norm(s_test_sample, 2) 
    uncertainty[id1, id2] = np.linalg.norm(s_pred_uncertainty, 2) / np.linalg.norm(s_test_sample, 2)

plt.figure()
plt.errorbar(Output_scales, errors.mean(axis = 1), yerr=errors.std(axis = 1), fmt='.k')
plt.errorbar(Output_scales, uncertainty.mean(axis = 1), fmt='r-')
plt.savefig('./error_vs_uncertainty.png', dpi = 300)


np.save("errors_normed.npy", errors)
np.save("uncertainty_normed.npy", uncertainty)


(100000, 1) (100000, 1)


In [5]:


# Plot different samples from different output scales
idxs = [30, 130, 230, 330, 430, 530, 630, 730, 830, 930]

for m in range(len(idxs)):

    # Plot a sample 130 example
    idx = idxs[m]
    index = np.arange(idx*P,(idx+1)*P)
    plt.figure()
    for k in range(1, 20):
        plt.plot(y_test[index, :], s_pred[k,index], 'r--', lw=2)
    plt.plot(y_test[index, :], s_pred[0,index], 'r--', lw=2, label = "Predicted sample")
    plt.plot(y_test[index, :], s_test[index, :], 'b-', lw=2, label = "Exact")
    plt.plot(y_test[index, :], s_pred_mu[index, :], 'k--', lw=2, label = "Predicted mean")
    plt.legend(loc='upper right', frameon=False, prop={'size': 13})
    plt.xlabel('y')
    plt.ylabel('G(u)(y)')
    plt.tight_layout()
    plt.savefig('./Samples' + str(idx) + '.png', dpi = 300)



In [2]:
errors = np.load("errors_normed.npy")

np.max(errors)



DeviceArray(0.04996318, dtype=float32)