In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='-1'
import time
import jax 
import pickle
import matplotlib.pyplot as plt
from config import *

from phyOT.utils import *

In [None]:
def flush(u, tol=1e-8):
    """Replace near-zero entries of ``u`` with exact zeros.

    Used in this notebook to clean fields right before ``contourf`` so that
    tiny values are treated as exact zeros.

    Args:
        u: jax/numpy array (or scalar) of numbers.
        tol: magnitude threshold; entries with ``|u| < tol`` are set to 0.
            Defaults to 1e-8, the value previously hard-coded here.

    Returns:
        A jax array with the same shape as ``u``.
    """
    return jnp.where(jnp.abs(u) < tol, 0, u)

In [None]:
file = 'data/data.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open(file, 'rb') as f:
    data = pickle.load(f)

solutions = data['dataUs']
ell = 5.
# Lower/upper levels of the tanh map from the normalised field a/||a|| to z.
kappa_1, kappa_2 = 1., 2.

idx = [2,3,4]

for ix in idx:
    # Solution for sample `ix` with the observation locations overlaid.
    plt.contourf(data['grid'][0], data['grid'][1], solutions[ix], 50)
    plt.scatter(data['obs_locs'][:, 0], data['obs_locs'][:, 1], c='k', s=5)
    plt.colorbar()
    plt.xlabel(r'$x_{(1)}$')
    plt.ylabel(r'$x_{(2)}$')
    # One file per sample: a fixed name would be overwritten on every iteration.
    plt.savefig(f'data/plots/50Obs_sample_sln_{ix}.png')
    plt.show()

    # Stored z field for the same sample (was always sample 0 before).
    plt.contourf(data['grid'][0], data['grid'][1], flush(data['dataZs'][ix]), 50)
    plt.show()

    a_field = data['dataAs'][ix]
    # NOTE(review): `solver` is not defined in this notebook; presumably it is
    # provided by `from config import *` -- TODO confirm.
    dx = 1. / (solver.grid[0].shape[0] - 1)

    # Discrete L2 norm with cell area dx**2 (same spacing assumed in both directions).
    a_L2 = jnp.sqrt(jnp.sum(a_field**2. * dx**2.))
    print(f'a_L2: {a_L2}')

    a_norm = a_L2
    a_bar_field = a_field / a_norm

    # Smooth two-level map: tends to kappa_1 for very negative a_bar*ell and to
    # kappa_2 for very positive a_bar*ell (algebraically equal to the old form).
    z_field = kappa_1 + 0.5 * (kappa_2 - kappa_1) * (1. + jnp.tanh(a_bar_field * ell))

    # z reconstructed from a, to compare against the stored field plotted above.
    plt.contourf(data['grid'][0], data['grid'][1], flush(z_field), 50)
    plt.show()


In [None]:
# Pickled results whose parameter and loss histories are plotted in this cell.
file = 'data/model.pkl'

# NOTE(review): this star import runs AFTER `file` is assigned; if `config`
# exports a name `file`, it would silently replace the path above -- confirm.
from config import *
from jax.scipy.signal import convolve
def convolve_avg(array, window):
    """Box-filter moving average of `array` over `window` samples.

    Both the windowed sums and the per-position sample counts use
    ``mode='same'``. Near the edges the sum is therefore divided by the number
    of samples actually inside the window, so edge values are not biased toward
    zero. Intended for a 1-D `array` with an integer `window` (the kernel is
    ``jnp.ones(window)``).
    """
    box = jnp.ones(window)
    window_sums = convolve(array, box, mode='same')
    samples_in_window = convolve(jnp.ones_like(array), box, mode='same')
    return window_sums / samples_in_window

plt.rcParams.update({
    'axes.labelsize':   16,
    'axes.titlesize':   16,
    'xtick.labelsize' : 12,
    'ytick.labelsize' : 12,
})


path_open  = file
path_save  = "convergence"
# Moving-average window for the smoothed overlays; None disables them.
w_avg = None
lw = 1

# True parameter values the estimates are compared against (previously repeated
# as bare literals). Presumably they match `ell` and (kappa_1, kappa_2) used to
# generate the data -- TODO confirm.
TRUE_LAMBDA = 5.
TRUE_KAPPA_P = 2.
TRUE_KAPPA_M = 1.

# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open(path_open, 'rb') as f:
    model1 = pickle.load(f)

history = model1['aux']
# The stored parameter histories are exponentiated to obtain the plotted values
# (they appear to be kept as logs during optimisation).
lambda_vals_exp = jnp.exp(history['lambdas'])
exp_kappa_p = jnp.exp(history['kappas'][:, 1])
exp_kappa_m = jnp.exp(history['kappas'][:, 0])

# (trace, true value, truth line style, trace label, truth label) per parameter.
param_traces = [
    (lambda_vals_exp, TRUE_LAMBDA, '--', r'$\lambda$', r'true $\lambda$'),
    (exp_kappa_p, TRUE_KAPPA_P, ':', r'$\kappa^+$', r'true $\kappa^+$'),
    (exp_kappa_m, TRUE_KAPPA_M, '-.', r'$\kappa^-$', r'true $\kappa^{-}$'),
]

for trace, true_value, style, label, true_label in param_traces:
    plt.plot(trace, label=label)
    # Horizontal reference line at the true value, one point per iteration.
    plt.plot(jnp.full_like(trace, true_value), style, c='k', label=true_label)

if w_avg is not None:
    for trace, *_ in param_traces:
        plt.plot(convolve_avg(trace, w_avg), linewidth=lw, c='k')

plt.grid()
plt.ylim(-0.5, 13.)
plt.legend(loc='best')
plt.xlabel('Iterations')
plt.ylabel(r'$\lambda, \kappa^+, \kappa^-$')
plt.savefig(f'data/plots/{path_save}.pdf')
plt.show()

loss = model1['loss']
# Per-iteration loss terms; columns 0 and 1 are summed and shown as J_1
# (column 1 was previously plotted on its own as 'reg'), column 2 is J_2.
aux = model1['aux']['aux']
print(aux.shape)
plt.semilogy(aux[:, 0] + aux[:, 1], label='$J_1$')
plt.semilogy(aux[:, 2], label='$J_2$')
plt.grid()
plt.legend()
plt.savefig(f'data/plots/{path_save}_losses.pdf')
plt.show()

print('last kappa^- = ', exp_kappa_m[-1])
print('last kappa^+ = ', exp_kappa_p[-1])
print('last lambda = ', lambda_vals_exp[-1])

# Relative errors of the final estimates, in percent.
print('relative error km (%)', 100 * jnp.abs(TRUE_KAPPA_M - exp_kappa_m[-1]) / TRUE_KAPPA_M)
print('relative error kp (%)', 100 * jnp.abs(TRUE_KAPPA_P - exp_kappa_p[-1]) / TRUE_KAPPA_P)
print('relative error lambda (%)', 100 * jnp.abs(TRUE_LAMBDA - lambda_vals_exp[-1]) / TRUE_LAMBDA)

