In [None]:
import os
# Hide all GPUs from the CUDA runtime so JAX falls back to CPU; this has to be
# set before jax is imported/initialised to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import jax
import jax.numpy as jnp
import pickle
import matplotlib.pyplot as plt

from phyEBM.utils import *

In [None]:
# Select which synthetic dataset to inspect ('smooth' or 'sharp'). The same tag
# is used in the saved figure names so the figures always match the data loaded.
dataset = 'smooth'
file = f'data/data1_{dataset}.pkl'
plot_dir = 'data/plots'
os.makedirs(plot_dir, exist_ok=True)  # savefig does not create missing directories

# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(file, 'rb') as f:
    data = pickle.load(f)
show_n = 10  # number of samples drawn in each plot

n_data = data['cfg'].Observation.n_data
print('n_data: ', n_data)
print(data['obs_locs'].shape)
print(data['dataYs'].shape)

plt.rcParams.update({
    'axes.labelsize':   20,
    'axes.titlesize':   20,
    'xtick.labelsize' : 16,
    'ytick.labelsize' : 16,
          })

# Fields u on the grid, with the observations y scattered at the observation
# locations (obs_locs is repeated once per plotted sample).
n_shown = data['dataYs'][:show_n].shape[0]
plt.plot(data['grid'], data['dataUs'][:show_n].T)
plt.scatter(jnp.repeat(data['obs_locs'], n_shown, axis=1).T, data['dataYs'][:show_n], s=5)
plt.xlabel(r'$x$')
plt.ylabel(r'$u, y$')
plt.grid()
plt.savefig(os.path.join(plot_dir, f'{dataset}_data_us.pdf'))
plt.show()

# Normalise each sample a to unit discrete L2 norm, sqrt(sum(a^2 * h)).
# Assumes a uniform grid: the spacing h is taken from the first two points.
h = data['grid'][1] - data['grid'][0]
a_shown = data['dataAs'][:show_n].T
plt.plot(data['grid'], a_shown / jnp.sqrt(jnp.sum(a_shown**2. * h, axis=0)))
plt.grid()
plt.xlabel(r'$x$')
plt.ylabel(r'$\bar{a}$')
plt.savefig(os.path.join(plot_dir, f'{dataset}_data_a_bar.pdf'))
plt.show()

# Only the first z sample is shown.
plt.plot(data['grid'], data['dataZs'][0,:])

plt.grid()
plt.xlabel(r'$x$')
plt.ylabel(r'$z$')
plt.savefig(os.path.join(plot_dir, f'{dataset}_data_zs.pdf'))
plt.show()

print(data['params'])

In [None]:
# Star import kept as-is; presumably required so the pickled model can resolve
# classes/config objects defined in `config` -- TODO confirm and make explicit.
from config import *

# Load the trained model; its 'aux' entry (optimisation history) is read below.
# NOTE: rebinds the global `file`, and pickle.load must only see trusted files.
file = 'data/model.pkl'
with open(file, 'rb') as f:
    model1 = pickle.load(f)


from jax.scipy.signal import convolve
def convolve_avg(array, window):
    """Centered moving average with edge correction.

    Args:
        array: values to smooth (the kernel built here is 1-D, so presumably a
            1-D array -- TODO confirm for other ranks).
        window: integer length of the averaging window.

    Returns:
        Array of the same length as ``array`` (``mode='same'``), where each
        entry is the mean of the samples that actually fall inside the window.
    """
    weights = jnp.ones(window)
    window_sums = convolve(array, weights, mode='same')
    # The convolution zero-pads at the boundaries; convolving a vector of ones
    # counts how many real samples each window covered, so dividing by it
    # averages over the truncated window instead of biasing edges toward 0.
    window_counts = convolve(jnp.ones_like(array), weights, mode='same')
    return window_sums / window_counts

plt.rcParams.update({
    'axes.labelsize':   16,
    'axes.titlesize':   16,
    'xtick.labelsize' : 12,
    'ytick.labelsize' : 12,
          })

# Reference ("true") parameter values: drawn as black lines and used for the
# relative errors below, so they are defined once here.
# NOTE(review): hard-coded -- presumably the values the data was generated with
# (compare data['params']); confirm they still match.
TRUE_LAMBDA = 8.
TRUE_KAPPA_P = 2.
TRUE_KAPPA_M = 1.

# The stored histories are exponentiated before use (presumably kept in log
# space). Column 1 of 'kappas' is plotted as kappa^+, column 0 as kappa^-.
aux = model1['aux']
lambda_vals_exp = jnp.exp(aux['lambdas'])
exp_kappa_p = jnp.exp(aux['kappas'][:,1])
exp_kappa_m = jnp.exp(aux['kappas'][:,0])

# (estimate history, true value, line style of reference line, TeX symbol)
curves = (
    (lambda_vals_exp, TRUE_LAMBDA, '--', r'\lambda'),
    (exp_kappa_p, TRUE_KAPPA_P, ':', r'\kappa^+'),
    (exp_kappa_m, TRUE_KAPPA_M, '-.', r'\kappa^{-}'),
)
for estimate, true_val, ref_style, symbol in curves:
    plt.plot(estimate, label=f'${symbol}$')
    # Constant reference line over the same iteration range as the estimate.
    plt.plot(jnp.full_like(estimate, true_val), ref_style, c='k', label=f'true ${symbol}$')

plt.grid()
plt.ylim(-0.8, 18)
plt.legend(loc='upper right')
plt.xlabel('Iterations')
plt.ylabel(r'$\lambda, \kappa^+, \kappa^-$')
plt.savefig('data/plots/plot_convergence.pdf')
plt.show()

print('last kappa^- = ', exp_kappa_m[-1])
print('last kappa^+ = ', exp_kappa_p[-1])
print('last lambda = ', lambda_vals_exp[-1])

# Relative error (in %) of the final iterate with respect to the reference value.
for name, true_val, estimate in (('km', TRUE_KAPPA_M, exp_kappa_m[-1]),
                                 ('kp', TRUE_KAPPA_P, exp_kappa_p[-1]),
                                 ('lambda', TRUE_LAMBDA, lambda_vals_exp[-1])):
    print(f'relative error {name}', 100 * jnp.abs(true_val - estimate) / abs(true_val))


In [None]:
# Per-iteration loss history of the trained model. Use an explicit path rather
# than the global `file`, which other cells rebind (e.g. to a dataset pickle),
# so re-running cells out of order cannot load the wrong file.
model_path = 'data/model.pkl'
with open(model_path, 'rb') as f:
    model1 = pickle.load(f)
aux = model1['aux']['aux']
print(aux.shape)
# NOTE(review): columns 0 and 2 are plotted as J_1 and J_2; column 1 is not
# plotted -- confirm this is intended.
plt.semilogy(aux[:,0], label=r'$J_1$')
plt.semilogy(aux[:,2], c='tab:orange', label=r'$J_2$')
plt.legend()
plt.xlabel(r'Iterations')
plt.ylabel('Loss')
plt.grid()
plt.savefig('data/plots/plot_losses.pdf')
plt.show()