# Benchmark of GKTL on the Ornstein-Uhlenbeck process

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib widget
matplotlib.rc('font', size=18)
import uncertainties as unc
import scipy.special as ss

from tqdm.notebook import tqdm

import sys
sys.path.append('../')
import general_purpose.utilities as ut
import general_purpose.uplotlib as uplt

from importlib import reload

import rea.reconstruct as rec

In [None]:
def get_run(folder, iteration=None):
    """Load, or reconstruct, the ensemble of one rare-event-algorithm run.

    Parameters
    ----------
    folder : str
        Run directory. Iteration subfolders are named ``i`` followed by an integer,
        formatted as ``i{iteration:04d}`` (``i0000``, ``i0001``, ...).
    iteration : int, optional
        Iteration to reconstruct with ``rec.reconstruct``; the result is not cached.
        If None, ``{folder}/reconstructed.json`` is loaded when it exists. Otherwise the
        highest-numbered iteration subfolder is reconstructed and the result is written
        to ``reconstructed.json`` for next time.

    Returns
    -------
    dict
        ``{'folder': folder, 'rec_dict': <reconstructed ensemble dict>}``.

    Raises
    ------
    FileNotFoundError
        If ``iteration`` is None, there is no ``reconstructed.json`` and ``folder``
        contains no iteration subfolders.
    """
    import os  # `os` is not imported at the top of the notebook

    if iteration is not None:
        return {'folder': folder, 'rec_dict': rec.reconstruct(f'{folder}/i{iteration:04d}')}

    try:
        return {'folder': folder, 'rec_dict': ut.json2dict(f'{folder}/reconstructed.json')}
    except FileNotFoundError:
        pass  # no cache yet: reconstruct from the last iteration below

    # Detect the last iteration. Only names of the form 'i<digits>' count, so other
    # entries that merely start with 'i' (e.g. 'info.json') cannot break int().
    iterations = [int(name[1:]) for name in os.listdir(folder)
                  if name.startswith('i') and name[1:].isdecimal()]
    if not iterations:
        raise FileNotFoundError(
            f'{folder}: no reconstructed.json and no iteration subfolders (i0000, i0001, ...)')
    iteration = max(iterations)

    rec_dict = rec.reconstruct(f'{folder}/i{iteration:04d}')
    ut.dict2json(rec_dict, f'{folder}/reconstructed.json')
    return {'folder': folder, 'rec_dict': rec_dict}

def compute_trajectories(run):
    """Stitch together the full trajectory of every member of a reconstructed run.

    For each member, the segment files ``{folder}/{subfolder}/{ancestor}-traj.npy`` are
    loaded, pairing ``run['rec_dict']['folders']`` with the member's ``'ancestry'`` by
    position. The segments are then concatenated along axis 0.

    The resulting ``{member_name: trajectory}`` dict is stored in ``run['traj']`` and
    also returned.
    """
    base = run['folder']
    reconstruction = run['rec_dict']

    stitched = {}
    for member_name, member in tqdm(reconstruction['members'].items()):
        segments = [
            np.load(f'{base}/{subdir}/{ancestor}-traj.npy')
            for subdir, ancestor in zip(reconstruction['folders'], member['ancestry'])
        ]
        stitched[member_name] = np.concatenate(segments)

    run['traj'] = stitched
    return stitched

def plot_traj(run, start=0, **kwargs):
    """Draw every member trajectory of ``run`` on the current pyplot axes.

    Horizontal axis: column 0, shifted so that the very first sample sits at zero.
    Vertical axis: column 1, negated. Samples before index ``start`` are not drawn;
    the horizontal shift still uses sample 0.

    Trajectories are built with ``compute_trajectories`` if ``run`` has none yet.
    A ``label`` keyword is attached to the first line only, so a legend shows a single
    entry. All other keyword arguments are forwarded to ``plt.plot``.
    """
    if 'traj' not in run:
        compute_trajectories(run)

    first_label = kwargs.pop('label', None)
    for idx, traj in enumerate(run['traj'].values()):
        x_values = traj[start:, 0] - traj[0, 0]
        y_values = -traj[start:, 1]
        plt.plot(x_values, y_values, label=first_label if idx == 0 else None, **kwargs)


def expectation(run, func, **kwargs):
    """Estimate the expectation of ``func`` from the weighted members of a run.

    Each member contributes ``func(obs, **kwargs) * weight``. Here ``obs`` comes from
    ``run['observables']`` and ``weight`` from ``run['rec_dict']['members'][...]['weight']``.
    The two sequences are paired by position. The estimate is the plain average of these
    contributions over all members.

    NOTE(review): no code visible in this notebook fills ``run['observables']``.
    ``get_run`` sets 'folder'/'rec_dict' and ``compute_trajectories`` sets 'traj'.
    It must be set beforehand, with one entry per member, in the same order as
    ``run['rec_dict']['members']``.

    Returns
    -------
    uncertainties ufloat
        The mean of the weighted contributions, with its standard error as uncertainty.
    """
    _f = np.array([func(x, **kwargs) for x in run['observables'].values()])
    _w = np.array([r['weight'] for r in run['rec_dict']['members'].values()])
    samples = _f * _w

    mean = np.mean(samples)
    # Standard error of the mean: sqrt(variance / N). The previous formula,
    # sqrt(mean(samples**2) / N), used the raw second moment instead of the variance.
    # It therefore overestimated the error (by sqrt(2) for an unweighted indicator with
    # probability 1/2).
    std = np.std(samples) / np.sqrt(len(_w))

    return unc.ufloat(mean, std)

In [None]:
def f(x, a):
    """Indicator of the event ``x >= a``, returned as a float (1.0 or 0.0)."""
    return float(x>=a)

# Parameters entering the OU variance formula below (dX = -lam*X dt + sig dW).
lam = 1
sig = 1
def ou_var(t):
    """Variance at time ``t`` of the OU process started from a deterministic point."""
    return sig**2/(2*lam)*(1 - np.exp(-2*lam*t))

def cum_gaus(x):
    """Standard normal CDF.

    Written with ``erfc`` rather than ``0.5*(1 + erf(x/sqrt(2)))``. For very negative
    ``x`` the erf form cancels catastrophically (``1 + erf -> 1 - 1``) and returns exactly 0
    (e.g. at x = -10). The erfc form stays accurate far into the lower tail, which matters
    for the log-scale probability plots.
    """
    return 0.5*ss.erfc(-x/np.sqrt(2))

def overcoming_prob(x, v):
    """Probability that a centred Gaussian of variance ``v`` is >= ``x``."""
    return cum_gaus(-x/np.sqrt(v))

# OU variance at t = 2 (the time span also used for the reference curves below);
# presumably the final time T of the runs -- confirm against the run configuration.
v = ou_var(2)

## Get the runs

In [None]:
# Six independent realizations of each experiment, loaded (or reconstructed and cached)
# by get_run. `nens__999` in the folder names is presumably the ensemble size N.

# Control experiment, k = 0 (labelled 'control' in the plots).
c_runs = [
    get_run(f'./__test__old/c{i}--k__0--nens__999--T__10/')
    for i in tqdm(range(6))
]
# Alternative control set (nens = 100):
# c_runs = [get_run(f'./__test__old/c{i}--k__0--nens__100--T__10/') for i in range(6)]

# Rare-event algorithm, k = 4 (labelled 'REA' in the plots).
runs = [
    get_run(f'./__test__old/f{i}--k__4--nens__999--T__10')
    for i in tqdm(range(6))
]
# Alternative rare-event sets:
# runs = [get_run(f'./__test__old/f{i}--k__2--nens__999--T__10') for i in range(8)]
# runs = [get_run(f'./__test__old/f{i}--k__4--nens__100--T__10') for i in range(14)]

## Plot trajectories

In [None]:
# First realization of each experiment; stitch the member trajectories of both.
# compute_trajectories also stores them in run['traj'] / c_run['traj']. Its return values
# are assigned to `_` so that the cell does not display the returned dicts.
run, c_run = runs[0], c_runs[0]
_ = compute_trajectories(run)
_ = compute_trajectories(c_run)

In [None]:
# Figure 1: all member trajectories of the first rare-event run, with the curve
# -2*sqrt(ou_var(t)) (minus two OU standard deviations) as a reference.
plt.close(1)  # re-running the cell recreates figure 1 instead of stacking figures
fig,ax = plt.subplots(num=1, figsize=(9,6))

plot_traj(run, color='gray', alpha=0.2)

# `t` is also used by the side-histogram figure cell further down.
t = np.linspace(0,2,100)
plt.plot(t, -2*np.sqrt(ou_var(t)), color='black', linestyle='dashed')

fig.tight_layout()

In [None]:
from matplotlib.gridspec import GridSpec

In [None]:
# Quick check: value of column 0 (the horizontal 'time' coordinate in plot_traj)
# at the second stored sample of control member 'r001'.
c_run['traj']['r001'][1,0]

In [None]:
# Figure 4: rare-event trajectories (left, 4 of 6 grid columns) with +/- 3 OU standard
# deviations, and sideways histograms of the end points (right) on the same vertical range.
plt.close(4)
fig = plt.figure(num=4, figsize=(10,6))
gs = GridSpec(4,6)
ax_plot = fig.add_subplot(gs[:,0:4])

# plot_traj(c_run, alpha=0.2, color='blue')
plot_traj(run, alpha=0.1, color='black')
# `t` comes from the figure-1 cell.
plt.plot(t, -3*np.sqrt(ou_var(t)), color='red', linestyle='dashed')
plt.plot(t, 3*np.sqrt(ou_var(t)), color='red', linestyle='dashed')

plt.xlabel('time')
plt.ylabel('$X$')



ax_hist = fig.add_subplot(gs[:,4:])

# 30 edges (29 bins) spanning the vertical range of the trajectory panel, so both panels align.
bin_edges = np.linspace(ax_plot.get_ylim()[0], ax_plot.get_ylim()[1], 30)

# Histogram of the control run (k=0) end points: the last sample of the last column,
# negated like the vertical axis of plot_traj (presumably column 1 is X -- confirm).
# Drawn sideways: density on the horizontal axis, value on the vertical axis.
c_pts = -np.array([c[-1,-1] for c in c_run['traj'].values()])
hist, bin_edges = np.histogram(c_pts, bins=bin_edges, density=True)
ax_hist.plot(hist, 0.5*(bin_edges[:-1] + bin_edges[1:]), color='blue')

# Histogram of the rare-event run end points (unweighted: no member weights are applied).
pts = -np.array([c[-1,-1] for c in run['traj'].values()])
hist, bin_edges = np.histogram(pts, bins=bin_edges, density=True)
ax_hist.plot(hist, 0.5*(bin_edges[:-1] + bin_edges[1:]), color='black')

ax_hist.set_ylim(*ax_plot.get_ylim())
ax_hist.set_yticklabels([])

fig.tight_layout()

In [None]:
# Saves whatever `fig` currently is: run directly after the side-histogram cell.
fig.savefig('../../download/ou-side-hist-N1000.pdf')

## Probabilities

In [None]:
# Grid of thresholds at which the tail probability is estimated.
a = np.linspace(0, 4, 101)

# Theoretical value: probability that a centred Gaussian of variance v exceeds each threshold.
e_th = overcoming_prob(a, v)

# One array of ufloats per realization: estimate of E[f(x, a)] for every threshold,
# for the rare-event runs (es) and the control runs (c_es).
es = [np.array([expectation(r, f, a=threshold) for threshold in a]) for r in runs]
c_es = [np.array([expectation(r, f, a=threshold) for threshold in a]) for r in c_runs]

In [None]:
# Summary statistics across realizations. Arrays below have one row per realization
# and one column per threshold in `a`; the c_* versions are for the control runs.

# Point estimates of each realization.
n_es = np.vstack([uplt.nominal_value(e) for e in es])
c_n_es = np.vstack([uplt.nominal_value(e) for e in c_es])

# Uncertainty each realization reports for its own estimate.
s_es = np.vstack([uplt.std_dev(e) for e in es])
c_s_es = np.vstack([uplt.std_dev(e) for e in c_es])

# Average over realizations (uplt.avg; presumably a ufloat array of mean +/- uncertainty -- see uplt).
m_es = uplt.avg(n_es, axis=0)
c_m_es = uplt.avg(c_n_es, axis=0)

# Empirical relative error: RMS relative deviation of single-realization estimates from their average.
relative_error = np.sqrt(np.mean((n_es/uplt.nominal_value(m_es) - 1)**2, axis=0))
c_relative_error = np.sqrt(np.mean((c_n_es/uplt.nominal_value(c_m_es) - 1)**2, axis=0))

# Relative error as estimated by the realizations themselves, averaged over realizations.
# Thresholds where an estimate is 0 can produce nan/inf here (the last cell uses np.nanmean).
est_rel_err = uplt.avg(s_es/n_es, axis=0)
c_est_rel_err = uplt.avg(c_s_es/c_n_es, axis=0)

# RMS relative deviation from the theoretical value e_th. Despite the name, this includes
# the realization-to-realization spread, not only the systematic bias.
ms_bias = np.sqrt(np.mean((n_es/e_th - 1)**2, axis=0))
c_ms_bias = np.sqrt(np.mean((c_n_es/e_th - 1)**2, axis=0))

In [None]:
# Figure 2: averaged tail-probability estimates (with error bands) of the rare-event runs
# and of the control runs, against the theoretical value, on a log scale.
from matplotlib.ticker import FuncFormatter

plt.close(2)
fig,ax = plt.subplots(num=2, figsize=(9,6))

uplt.errorband(a, m_es, color='black', label='REA: $k=4$')
uplt.errorband(a, c_m_es, alpha=0.5, color='blue', label=r'control: $k=0$')

plt.semilogy(a, e_th, color='red', label='theoretical', linestyle='dashed')

# for e in es:
#     uplt.plot(a,e, alpha=0.5)
# plt.plot(a,uplt.std_dev(e)/uplt.nominal_value(e))

plt.xlabel(r'$a$')
# Label each tick with the negated threshold. A formatter is used instead of
# set_xticklabels on auto-placed ticks, so labels stay attached to their ticks when the
# interactive (widget) figure is zoomed or panned. 0 is printed as '0' rather than '-0'.
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: f'{-x:g}' if x else '0'))

plt.ylabel(r'$\mathbb{P}\left( X(T) \leq a \right)$')

# plt.title(r'$N = 1000,\, k=4$')

plt.legend()

fig.tight_layout()

In [None]:
# Saves whatever `fig` currently is: run directly after the figure-2 (error band) cell.
fig.savefig('../../download/ou-prob-N1000.pdf')

In [None]:
# Figure 2 (variant, reuses figure number 2): averaged estimates plus every individual
# realization (faint lines), against the theoretical tail probability, on a log scale.
from matplotlib.ticker import FuncFormatter

plt.close(2)
fig,ax = plt.subplots(num=2, figsize=(9,6))


plt.plot(a, uplt.nominal_value(m_es), color='black', label='REA: $k=4$')
plt.plot(a, uplt.nominal_value(c_m_es), color='blue', label=r'control: $k=0$')

plt.plot(a, n_es.T, color='black', alpha=0.2)
plt.plot(a, c_n_es.T, color='blue', alpha=0.2)

plt.semilogy(a, e_th, color='red', label='theoretical', linestyle='dashed')


# for e in es:
#     uplt.plot(a,e, alpha=0.5)
# plt.plot(a,uplt.std_dev(e)/uplt.nominal_value(e))

plt.xlabel(r'$a$')
# Label each tick with the negated threshold. A formatter is used instead of
# set_xticklabels on auto-placed ticks, so labels stay attached to their ticks when the
# interactive (widget) figure is zoomed or panned. 0 is printed as '0' rather than '-0'.
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: f'{-x:g}' if x else '0'))

plt.ylabel(r'$\mathbb{P}\left( X(T) \leq a \right)$')

# plt.title(r'$N = 1000,\, k=4$')

plt.legend()

fig.tight_layout()

### Relative error

In [None]:
# Figure 3: relative errors as functions of the threshold a.
# Solid lines: spread between independent realizations; dashed: RMS relative deviation from theory.
plt.close(3)
fig,ax = plt.subplots(num=3, figsize=(9,6))

plt.plot(a, relative_error, color='red', label='$k=4$') # error between different realizations of the algorithm
plt.plot(a, ms_bias, color='red', linestyle='dashed') # RMS relative deviation from the theoretical value

plt.plot(a, c_relative_error, color='blue', label='$k=0$') # error between different realizations of the algorithm
plt.plot(a, c_ms_bias, color='blue', linestyle='dashed') # RMS relative deviation from the theoretical value

# for e in es:
#     plt.plot(a, uplt.std_dev(e)/uplt.nominal_value(e)) # error estimated by each realization

plt.legend()
plt.xlabel('$a$')
plt.ylabel('relative error')
fig.tight_layout()

In [None]:
# Scratch figure: averaged rare-event estimate m_es against the theoretical tail
# probability, on linear axes. It uses figure number 6 because number 4 belongs to the
# side-histogram figure, and plt.close(4) here would discard that figure.
plt.close(6)
fig,ax = plt.subplots(num=6, figsize=(9,6))

uplt.plot(a, m_es)
plt.plot(a, e_th, color='black')

fig.tight_layout()

In [None]:
# Figure 5: as figure 3, plus the relative error that the realizations estimate for
# themselves (dotted error bands). The y axis is clipped to [0, 1].
plt.close(5)
fig,ax = plt.subplots(num=5, figsize=(9,6))

# The rare-event runs loaded in this notebook use k=4 (folders f{i}--k__4--...), so label them k=4.
plt.plot(a, relative_error, color='red', label='$k=4$') # error between different realizations of the algorithm
plt.plot(a, ms_bias, color='red', linestyle='dashed') # RMS relative deviation from the theoretical value
uplt.errorband(a, est_rel_err, color='red', linestyle='dotted') # self-estimated relative error, averaged over realizations

plt.plot(a, c_relative_error, color='blue', label='$k=0$') # error between different realizations of the algorithm
plt.plot(a, c_ms_bias, color='blue', linestyle='dashed') # RMS relative deviation from the theoretical value
uplt.errorband(a, c_est_rel_err, color='blue', linestyle='dotted') # self-estimated relative error, averaged over realizations

# for e in es:
#     plt.plot(a, uplt.std_dev(e)/uplt.nominal_value(e)) # error estimated by each realization

plt.ylim(0,1)
plt.legend()
plt.xlabel('$a$')
plt.ylabel('relative error')
fig.tight_layout()

In [None]:
# Mean over thresholds (nan entries ignored) of the ratio between the empirical relative error
# across realizations and the relative error the realizations estimate for themselves.
# A value near 1 would indicate that the per-run error estimates are calibrated.
np.nanmean(relative_error/est_rel_err)