In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='-1'
import time
import jax 
import jax.numpy as jnp
import pickle
import matplotlib.pyplot as plt
from config import *
from phyOT.utils import *
from phyOT.utils import *

In [None]:
def flush(u, tol=1e-8):
    """Zero out entries of `u` whose magnitude is below `tol`.

    Applied to fields right before contour plotting (presumably to keep
    round-off noise around zero from producing spurious contour levels).

    Args:
        u: array-like field (anything `jnp.abs` accepts).
        tol: magnitude threshold; entries with ``|u| < tol`` are set to 0.
            Defaults to 1e-8, the value that used to be hard-coded.

    Returns:
        A jax array with the same shape as `u`, with small entries replaced by 0.
    """
    return jnp.where(jnp.abs(u) < tol, 0, u)

In [None]:
# Load the synthetic dataset and, for a few samples, plot: the solution with the
# observation locations, the stored z field, and a smooth z field rebuilt from `dataAs`.
file = 'data/data.pkl'
PLOT_DIR = 'data/plots'
os.makedirs(PLOT_DIR, exist_ok=True)  # savefig raises if the output directory is missing

# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(file, 'rb') as f:
    data = pickle.load(f)

ell = 5                    # sharpness factor inside the tanh below
idx = [2, 3, 4, 5]         # sample indices to plot
kappa_1, kappa_2 = 1., 2.  # lower / upper levels of the smooth z field
grid_x, grid_y = data['grid'][0], data['grid'][1]

# Grid spacing taken from the grid the fields actually live on (the smooth z field
# derived from `dataAs` is plotted on `data['grid']`). The previous code used
# `solver.grid`, which is not defined in this notebook.
# NOTE(review): assumes a uniform grid on the unit square — TODO confirm.
dx = 1. / (grid_x.shape[0] - 1)


def _plot_field(field, filename, show_obs=False):
    """Contour-plot `field` on the data grid, save it under PLOT_DIR, and show it."""
    plt.contourf(grid_x, grid_y, flush(field), 50)
    plt.colorbar()
    if show_obs:
        plt.scatter(data['obs_locs'][:, 0], data['obs_locs'][:, 1], c='k', s=5)
    plt.xlabel(r'$x_{(1)}$')
    plt.ylabel(r'$x_{(2)}$')
    plt.savefig(os.path.join(PLOT_DIR, filename))
    plt.show()


for ix in idx:
    _plot_field(data['dataUs'][ix], f'sln_obs_{ix}.png', show_obs=True)
    _plot_field(data['dataZs'][ix], f'z_field_{ix}.png')

    # Normalise the `a` field by its discrete L2 norm, then map it smoothly
    # into [kappa_1, kappa_2] with a tanh.
    a_field = data['dataAs'][ix]
    a_L2 = jnp.sqrt(jnp.sum(a_field**2. * dx**2.))
    print(f'a_L2: {a_L2}')
    a_bar_field = a_field / a_L2
    z_field = kappa_1 + 0.5 * (kappa_2 - kappa_1) * (1. + jnp.tanh(a_bar_field * ell))
    _plot_field(z_field, f'z_field_smooth_{ix}.png')


In [None]:
# True parameter values, used for the reference lines and the relative errors below.
LAMBDA_TRUE = 5.
KAPPA_PLUS_TRUE = 2.
KAPPA_MINUS_TRUE = 1.
os.makedirs('data/plots', exist_ok=True)  # savefig raises if the output directory is missing

# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('data/model.pkl', 'rb') as f:
    model_smooth = pickle.load(f)

# Training loss history.
loss = model_smooth['loss']
plt.semilogy(loss)
plt.grid()
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.savefig('data/plots/plot_loss.pdf')
plt.show()
lw = 1
lw_big = 2
w_avg = 100

# Parameter histories are stored as logs; exponentiate to get the physical values.
aux = model_smooth['aux']
lambda_vals = jnp.exp(aux['lambdas'])
# NOTE(review): throughout this cell `kappa_m` (column 1) is compared against the
# true kappa^+ and `kappa_p` (column 0) against the true kappa^-, despite the variable
# names. Confirm the column order against the model definition.
kappa_p = jnp.exp(aux['kappas'][:, 0])
kappa_m = jnp.exp(aux['kappas'][:, 1])

plt.plot(lambda_vals, linewidth=lw_big, label=r'$\lambda$')
plt.plot(jnp.full_like(lambda_vals, LAMBDA_TRUE), '--', c='k', label=r'true $\lambda$')

plt.plot(kappa_m, linewidth=lw_big, label=r'$\kappa^+$')
plt.plot(jnp.full_like(kappa_m, KAPPA_PLUS_TRUE), ':', c='k', label=r'true $\kappa^+$')
plt.plot(kappa_p, linewidth=lw_big, label=r'$\kappa^-$')
plt.plot(jnp.full_like(kappa_p, KAPPA_MINUS_TRUE), '-.', c='k', label=r'true $\kappa^{-}$')
plt.grid()
plt.xlabel('Iterations')
plt.ylabel(r'$\lambda, \kappa^+, \kappa^-$')
plt.legend(loc='best')

plt.savefig('data/plots/plot_convergence.pdf')
plt.show()

# Final iterates, labelled with the same kappa convention as the plot legend above.
print('last kappa^- = ', kappa_p[-1])
print('last kappa^+ = ', kappa_m[-1])
print('last lambda = ', lambda_vals[-1])

# Relative errors (in percent) of the final iterates w.r.t. the true values.
print('relative error kp', 100 * jnp.abs(KAPPA_PLUS_TRUE - kappa_m[-1]) / KAPPA_PLUS_TRUE)
print('relative error km', 100 * jnp.abs(KAPPA_MINUS_TRUE - kappa_p[-1]) / KAPPA_MINUS_TRUE)
print('relative error lambda', 100 * jnp.abs(LAMBDA_TRUE - lambda_vals[-1]) / LAMBDA_TRUE)