In [None]:
import os
# Hide all CUDA devices so JAX runs on the CPU. Set before importing jax so
# the setting is already in place when JAX initializes its backends.
os.environ['CUDA_VISIBLE_DEVICES']='-1'
import jax 
import jax.numpy as jnp
import pickle
import matplotlib.pyplot as plt

# NOTE(review): star import — no name from phyEBM.utils is referenced in the
# cells shown here; consider explicit imports so readers can tell where names
# come from.
from phyEBM.utils import *

In [None]:
# Load the pickled dataset, print its basic sizes and plot a few samples.
# NOTE(security): pickle.load can execute arbitrary code; only load files
# produced by this project.
file = 'data/data.pkl'
plots_dir = 'data/plots'
# savefig raises FileNotFoundError when the target directory is missing,
# e.g. on a fresh checkout, so create it up front.
os.makedirs(plots_dir, exist_ok=True)

with open(file, 'rb') as f:
    data = pickle.load(f)

show_n = 5  # number of samples drawn in each figure below

n_data = data['cfg'].Observation.n_data
print('n_data: ', n_data)
print(data['obs_locs'].shape)
print(data['dataYs'].shape)

# Larger fonts for the saved PDF figures; must be set before the figures
# below are created.
plt.rcParams.update({
    'axes.labelsize': 20,
    'axes.titlesize': 20,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})

# --- u on the grid, with the observations y at obs_locs. ---
ys_show = data['dataYs'][:show_n]
# Repeat the observation locations once per shown sample and transpose, so
# the x array passed to scatter lines up with ys_show.
# NOTE(review): assumes obs_locs is 2-D with a singleton second axis — confirm.
xs_show = jnp.repeat(data['obs_locs'], ys_show.shape[0], axis=1).T

fig, ax = plt.subplots()
ax.plot(data['grid'], data['dataUs'][:show_n].T)
ax.scatter(xs_show, ys_show, s=5)
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$u, y$')
ax.grid()
fig.savefig(os.path.join(plots_dir, 'sharp_data_us.pdf'))
plt.show()

# --- a, each sample scaled to unit discrete L2 norm. ---
h = data['grid'][1] - data['grid'][0]  # grid spacing; assumes a uniform grid
a_show = data['dataAs'][:show_n].T  # one column per shown sample
# Riemann-sum approximation of each column's L2 norm: sqrt(h * sum(a**2)).
a_norm = jnp.sqrt(jnp.sum(a_show**2. * h, axis=0))

fig, ax = plt.subplots()
ax.plot(data['grid'], a_show / a_norm)
ax.grid()
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$\bar{a}$')
fig.savefig(os.path.join(plots_dir, 'sharp_data_a_bar.pdf'))
plt.show()

# --- First z sample on the grid. ---
fig, ax = plt.subplots()
ax.plot(data['grid'], data['dataZs'][0, :])
ax.grid()
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$z$')
fig.savefig(os.path.join(plots_dir, 'sharp_data_zs.pdf'))
plt.show()

print(data['params'])

In [None]:
# Load the trained-model results and plot the training diagnostics.
# NOTE(security): pickle.load can execute arbitrary code; only load files
# produced by this project.
with open('data/model.pkl', 'rb') as f:
    model = pickle.load(f)

plots_dir = 'data/plots'
# savefig raises FileNotFoundError when the target directory is missing.
os.makedirs(plots_dir, exist_ok=True)

# --- Training loss (log scale). ---
loss = model['loss']
fig, ax = plt.subplots()
ax.semilogy(loss)
ax.grid()
ax.set_xlabel('Iterations')
ax.set_ylabel('Loss')
fig.savefig(os.path.join(plots_dir, 'plot_model_loss.pdf'))
plt.show()

# NOTE(review): `lw`, `w_avg` and `sigma_val` are not used in this cell; they
# are kept only in case later cells read them.
lw = 1
w_avg = 5_000

# The recorded parameter trajectories are exponentiated before plotting, so
# model['aux'] presumably stores them in log space.
aux = model['aux']
sigma_val = jnp.exp(aux['sigmas'])
ell_val = jnp.exp(aux['ells'])
nu_val = jnp.exp(aux['nus'])

# Reference ("true") values drawn as horizontal lines.
# NOTE(review): hard-coded here; presumably they match data['params'] — confirm.
true_nu = 1.5
true_ell = 0.5

# --- Convergence of nu and ell towards their reference values. ---
fig, ax = plt.subplots()
ax.plot(nu_val, c='tab:blue', label=r'$\nu$')
ax.axhline(true_nu, linestyle='--', c='k', label=r'true $\nu$')
ax.plot(ell_val, c='tab:orange', label=r'$\ell$')
ax.axhline(true_ell, linestyle=':', c='k', label=r'true $\ell$')
ax.grid()
ax.set_xlabel('Iterations')
ax.set_ylabel(r'$\nu, \ell$')
ax.legend(loc='best')
fig.savefig(os.path.join(plots_dir, 'plot_convergence_model.pdf'))
plt.show()
