In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='-1'
import jax 
import jax.numpy as jnp
import pickle
import matplotlib.pyplot as plt

from phyOT.utils import *

In [None]:
# Load the pickled results of the multi-trial sweep.
# NOTE: pickle.load can execute arbitrary code -- only open files you produced yourself.
with open('data/models_mt_2.pkl', 'rb') as pkl_file:
    models_mt = pickle.load(pkl_file)

# Accumulators for the sweep loop below: it appends one row per dataset size,
# each row holding one entry per batch size.
last_avg_lambdas, last_avg_kappa_p, last_avg_kappa_m = [], [], []
last_stddev_lambdas, last_stddev_kappa_p, last_stddev_kappa_m = [], [], []

# NOTE(review): earlier versions hard-coded the sweep axes, e.g.
#   batches = jnp.array([1, 10, 100, 1000, 10000]); dss = jnp.array([10, 100, 1000, 10000])
# They are now read from the pickle itself.
dss, batches = models_mt['n_data_list'], models_mt['batch_list']

# Number of repeated runs per (dataset size, batch size) pair.
T = len(models_mt['trial'])

# The config of the first run supplies `train_iters` for the whole sweep.
cfg = models_mt['run_0_0_0']['cfg']

# How many trailing training iterations enter the final mean / std.
n_last = 1

print(dss, batches, T)

def _stack_trials(series_per_trial, n_iters):
    """Stack one per-iteration series per trial into a (n_trials, n_iters) array.

    Each series is broadcast to length `n_iters`, so a series of the wrong
    length is rejected instead of being silently accepted. Building the array
    once avoids the full-buffer copy that every `.at[t, :].set(...)` performs
    outside of `jit`.
    """
    if not series_per_trial:
        # No trials: same empty buffer the pre-allocated version would leave.
        return jnp.zeros((0, n_iters))
    rows = [jnp.broadcast_to(series, (n_iters,)) for series in series_per_trial]
    # Cast to the default float dtype, matching a `jnp.zeros` buffer.
    return jnp.stack(rows).astype(float)


def _tail_mean_std(values, n_tail):
    """Return (mean, std) over the last `n_tail` iterations of all trials."""
    tail = values[:, -n_tail:]
    return jnp.mean(tail), jnp.std(tail)


for i in range(len(dss)):

    # Per-batch-size results for dataset size `i`.
    last_avg_lambdas_nb, last_avg_kappa_p_nb, last_avg_kappa_m_nb = [], [], []
    last_stddev_lambdas_nb, last_stddev_kappa_p_nb, last_stddev_kappa_m_nb = [], [], []

    for j in range(len(batches)):

        lambda_series, kappa_p_series, kappa_m_series = [], [], []

        for t in range(T):

            print('dataset number: ', i, 'batches: ', j, 'run:', t)
            model = models_mt[f'run_{i}_{j}_{t}']

            # `loss` is not used in this aggregation; it stays bound so the
            # name remains available to any later notebook cell.
            loss = model['loss']
            aux = model['aux']

            # Sorting along the last axis makes column 0 the smallest kappa and
            # column 1 the next one, independent of the order they are stored in.
            # NOTE(review): presumably there are exactly two kappas per
            # iteration and the parameters are stored in log space -- confirm.
            kappa_sort = jnp.sort(jnp.exp(aux['kappas']), axis=-1)

            lambda_series.append(jnp.exp(aux['lambdas']))
            kappa_p_series.append(kappa_sort[:, 1])
            kappa_m_series.append(kappa_sort[:, 0])

        # Shape (T, cfg.train_iters): one training trajectory per trial.
        lambda_vals = _stack_trials(lambda_series, cfg.train_iters)
        kappa_p_vals = _stack_trials(kappa_p_series, cfg.train_iters)
        kappa_m_vals = _stack_trials(kappa_m_series, cfg.train_iters)

        print(i, j)
        for values, avg_out, stddev_out in (
            (lambda_vals, last_avg_lambdas_nb, last_stddev_lambdas_nb),
            (kappa_p_vals, last_avg_kappa_p_nb, last_stddev_kappa_p_nb),
            (kappa_m_vals, last_avg_kappa_m_nb, last_stddev_kappa_m_nb),
        ):
            mean, stddev = _tail_mean_std(values, n_last)
            avg_out.append(mean)
            stddev_out.append(stddev)

    last_avg_lambdas.append(last_avg_lambdas_nb)
    last_avg_kappa_p.append(last_avg_kappa_p_nb)
    last_avg_kappa_m.append(last_avg_kappa_m_nb)

    last_stddev_lambdas.append(last_stddev_lambdas_nb)
    last_stddev_kappa_p.append(last_stddev_kappa_p_nb)
    last_stddev_kappa_m.append(last_stddev_kappa_m_nb)

# Final tables have shape (len(dss), len(batches)).
last_avg_lambdas = jnp.array(last_avg_lambdas)
last_avg_kappa_p = jnp.array(last_avg_kappa_p)
last_avg_kappa_m = jnp.array(last_avg_kappa_m)

last_stddev_lambdas = jnp.array(last_stddev_lambdas)
last_stddev_kappa_p = jnp.array(last_stddev_kappa_p)
last_stddev_kappa_m = jnp.array(last_stddev_kappa_m)


colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] 
print(batches)


def _plot_estimates(avg, stddev, reference, ylabel, out_path):
    """Plot one estimated parameter against `batches` and save the figure.

    Draws one error-bar curve per entry of `dss` (mean `avg[k]`, error bar
    `stddev[k]`) on a logarithmic x axis, plus a dashed black horizontal line
    at `reference`, then writes the figure to `out_path` and shows it.

    Args:
        avg: table of means, indexed as avg[dataset index][batch index].
        stddev: table of standard deviations, same layout as `avg`.
        reference: y value of the dashed reference line.
            NOTE(review): presumably the ground-truth parameter -- confirm.
        ylabel: y-axis label (may contain LaTeX).
        out_path: file the figure is saved to.
    """
    # Start from a fresh figure so that backends on which `plt.show()` does
    # not clear the canvas do not overlay consecutive plots.
    plt.figure()

    for k in range(len(dss)):
        plt.errorbar(
            batches,
            avg[k],
            yerr=stddev[k],
            c=colors[k % len(colors)],  # cycle instead of IndexError past 10 curves
            capsize=5,
            label=f'$N={dss[k]}$',
        )

    # Constant reference line; `semilogx` also switches the x axis to log scale.
    plt.semilogx(batches, [reference] * len(batches), '--k')

    plt.xlabel(r'$N_s$')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid()

    # `savefig` does not create missing directories.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path)
    plt.show()


_plot_estimates(last_avg_lambdas, last_stddev_lambdas, 8.,
                r'$\lambda$', 'data/plots/estimation_lambda_mt2.pdf')
_plot_estimates(last_avg_kappa_p, last_stddev_kappa_p, 2.,
                r'$\kappa^+$', 'data/plots/estimation_kappa_p_mt2.pdf')
_plot_estimates(last_avg_kappa_m, last_stddev_kappa_m, 1.,
                r'$\kappa^-$', 'data/plots/estimation_kappa_m_mt2.pdf')