In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='-1'
import jax 
import jax.numpy as jnp
import pickle
import matplotlib.pyplot as plt

from phyOT.utils import *

In [None]:
# Load the sharp dataset and draw publication figures for its first `show_n` samples.
file = 'data/data1_sharp.pkl'
with open(file, 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    data = pickle.load(f)
show_n = 10

n_data = data['cfg'].Observation.n_data
print('n_data: ', n_data)
print(data['obs_locs'].shape)
print(data['dataYs'].shape)

# Larger fonts for the figures saved from this cell.
plt.rcParams.update(
    {
        'axes.labelsize': 20,
        'axes.titlesize': 20,
        'xtick.labelsize': 16,
        'ytick.labelsize': 16,
    }
)

sharp_grid = data['grid']
# Single spacing value, so the discrete L2 norms below assume a uniform grid.
h = sharp_grid[1] - sharp_grid[0]

# Figure 1: states u on the grid (lines) and observations y at the observation locations (dots).
sharp_ys_shown = data['dataYs'][:show_n]
# Presumably obs_locs is (n_obs, 1): repeating along axis 1 then transposing
# gives one row of x-locations per plotted sample — TODO confirm.
sharp_obs_x = jnp.repeat(data['obs_locs'], sharp_ys_shown.shape[0], axis=1).T
plt.plot(sharp_grid, data['dataUs'][:show_n].T)
plt.scatter(sharp_obs_x, sharp_ys_shown, s=5)
plt.xlabel(r'$x$')
plt.ylabel(r'$u, y$')
plt.grid()
plt.savefig('data/plots/sharp_data_us.pdf')
plt.show()

# Figure 2: each field a rescaled by its own discrete L2 norm (one column per sample).
sharp_as_shown = data['dataAs'][:show_n].T
sharp_as_norms = jnp.sqrt(jnp.sum(sharp_as_shown**2. * h, axis=0))
plt.plot(sharp_grid, sharp_as_shown / sharp_as_norms)
plt.grid()
plt.xlabel(r'$x$')
plt.ylabel(r'$\bar{a}$')
plt.savefig('data/plots/sharp_data_a_bar.pdf')
plt.show()

# Figure 3: the z field of the first sample, drawn as a step plot.
plt.step(sharp_grid, data['dataZs'][0,:])
plt.grid()
plt.xlabel(r'$x$')
plt.ylabel(r'$z$')
plt.savefig('data/plots/sharp_data_zs.pdf')
plt.show()

# Figure 4: a tanh-smoothed z built from the normalised first field.
# tanh(.) in (-1, 1) is mapped affinely onto (kappa_1, kappa_2); ell sets the transition sharpness.
a_L2 = jnp.sqrt(jnp.sum(data['dataAs'][0,:]**2. * h))
a_bar_field = data['dataAs'][0,:] / a_L2
ell = 10
kappa_1 = 1.
kappa_2 = 2.
z_field = jnp.tanh(a_bar_field * ell) * 0.5 * (kappa_2 - kappa_1) + kappa_1 + 0.5 * (kappa_2 - kappa_1)

plt.plot(sharp_grid, z_field, label='smoothed $z$')
plt.grid()
plt.xlabel(r'$x$')
plt.ylabel(r'$z$')
plt.savefig('data/plots/smooth_data_zs.pdf')
plt.show()

print(data['params'])

In [None]:
# Load the second dataset and take a quick, unsaved look at all of its samples.
file = 'data/data2.pkl'
with open(file, 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    data = pickle.load(f)

n_data = data['cfg'].Observation.n_data
print('n_data: ', n_data)

# Every stored state u, one line per sample.
plt.plot(data['grid'], data['dataUs'].T)
plt.grid()
plt.xlabel(r'$x$')
plt.ylabel(r'$u$')
plt.show()

# Every stored field a (un-normalised), one line per sample.
plt.plot(data['grid'], data['dataAs'].T)
plt.grid()
plt.xlabel(r'$x$')
plt.ylabel(r'$a$')
plt.show()

# The z field of the first sample, drawn as a step plot.
plt.step(data['grid'], data['dataZs'][0,:])
plt.grid()
plt.xlabel(r'$x$')
plt.ylabel(r'$z$')
plt.show()

print(data['cfg'])
print(data['params'])

In [None]:
# Smaller fonts for the remaining loss / convergence figures.
for rc_key, rc_value in (
    ('axes.labelsize', 16),
    ('axes.titlesize', 16),
    ('xtick.labelsize', 12),
    ('ytick.labelsize', 12),
):
    plt.rcParams[rc_key] = rc_value


In [None]:
# Adam fit on the smooth data: loss history, then convergence of lambda and kappa^+/kappa^-.
with open('data/model_st_smooth_adam.pkl', 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    model_smooth = pickle.load(f)

# Reference values: drawn as black reference lines and used for the relative errors.
lambda_true = 8.
kappa_p_true = 2.
kappa_m_true = 1.

loss = model_smooth['loss']
plt.semilogy(loss)
plt.grid()
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.savefig('data/plots/plot_loss_smooth_adam.pdf')
plt.show()

aux = model_smooth['aux']
lambda_vals = aux['lambdas']
# The convergence figure pairs column 1 with kappa^+ (reference 2) and column 0
# with kappa^- (reference 1). Bind the names the same way so that the printed
# labels agree with the saved figure. TODO confirm against the model's parameter order.
kappa_p = aux['kappas'][:, 1]
kappa_m = aux['kappas'][:, 0]

# The traces are exponentiated before use, i.e. they appear to be stored as logarithms.
lambda_hist = jnp.exp(lambda_vals)
kappa_p_hist = jnp.exp(kappa_p)
kappa_m_hist = jnp.exp(kappa_m)

plt.plot(lambda_hist, label=r'$\lambda$')
plt.axhline(lambda_true, linestyle='--', c='k', label=r'true $\lambda$')
plt.plot(kappa_p_hist, label=r'$\kappa^+$')
plt.axhline(kappa_p_true, linestyle=':', c='k', label=r'true $\kappa^+$')
plt.plot(kappa_m_hist, label=r'$\kappa^-$')
plt.axhline(kappa_m_true, linestyle='-.', c='k', label=r'true $\kappa^{-}$')
plt.grid()
plt.legend()
plt.xlabel('Iterations')
plt.ylabel(r'$\lambda, \kappa^+, \kappa^-$')
plt.savefig('data/plots/plot_convergence_smooth_adam.pdf')
plt.show()

print('time:', model_smooth['time'])

lambda_last = lambda_hist[-1]
kappa_p_last = kappa_p_hist[-1]
kappa_m_last = kappa_m_hist[-1]
print('last kappa^- = ', kappa_m_last)
print('last kappa^+ = ', kappa_p_last)
print('last lambda = ', lambda_last)

# Relative errors of the final iterate, in percent.
print('relative error km', 100*jnp.abs(kappa_m_true - kappa_m_last)/kappa_m_true )
print('relative error kp',  100*jnp.abs(kappa_p_true - kappa_p_last)/kappa_p_true )
print('relative error lambda',  100*jnp.abs(lambda_true - lambda_last)/lambda_true )

In [None]:
# Adam fit on the sharp data: loss history, then convergence of lambda and kappa^+/kappa^-.
with open('data/model_st_sharp_adam.pkl', 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    model_sharp = pickle.load(f)

# Reference values: drawn as black reference lines and used for the relative errors.
lambda_true = 8.
kappa_p_true = 2.
kappa_m_true = 1.

loss = model_sharp['loss']
plt.semilogy(loss)
plt.grid()
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.savefig('data/plots/plot_loss_sharp_adam.pdf')
plt.show()

aux = model_sharp['aux']
lambda_vals = aux['lambdas']
# The convergence figure pairs column 1 with kappa^+ (reference 2) and column 0
# with kappa^- (reference 1). Bind the names the same way so that the printed
# labels agree with the saved figure. TODO confirm against the model's parameter order.
kappa_p = aux['kappas'][:, 1]
kappa_m = aux['kappas'][:, 0]

# The traces are exponentiated before use, i.e. they appear to be stored as logarithms.
lambda_hist = jnp.exp(lambda_vals)
kappa_p_hist = jnp.exp(kappa_p)
kappa_m_hist = jnp.exp(kappa_m)

plt.plot(lambda_hist, label=r'$\lambda$')
plt.axhline(lambda_true, linestyle='--', c='k', label=r'true $\lambda$')
plt.plot(kappa_p_hist, label=r'$\kappa^+$')
plt.axhline(kappa_p_true, linestyle=':', c='k', label=r'true $\kappa^+$')
plt.plot(kappa_m_hist, label=r'$\kappa^-$')
plt.axhline(kappa_m_true, linestyle='-.', c='k', label=r'true $\kappa^{-}$')
plt.grid()
plt.legend()
plt.xlabel('Iterations')
plt.ylabel(r'$\lambda, \kappa^+, \kappa^-$')
plt.savefig('data/plots/plot_convergence_sharp_adam.pdf')
plt.show()

lambda_last = lambda_hist[-1]
kappa_p_last = kappa_p_hist[-1]
kappa_m_last = kappa_m_hist[-1]
print('last kappa^- = ', kappa_m_last)
print('last kappa^+ = ', kappa_p_last)
print('last lambda = ', lambda_last)

# Relative errors of the final iterate, in percent.
print('relative error km', 100*jnp.abs(kappa_m_true - kappa_m_last)/kappa_m_true )
print('relative error kp',  100*jnp.abs(kappa_p_true - kappa_p_last)/kappa_p_true )
print('relative error lambda',  100*jnp.abs(lambda_true - lambda_last)/lambda_true )