# In [None]:
import copy
import numpy as np
import jax.numpy as jnp
from jax import jit
from jax import config
from jax.flatten_util import ravel_pytree
from adoptODE import train_adoptODE, simple_simulation

# Run all JAX computations in this notebook on the CPU backend.
config.update('jax_platform_name', 'cpu')

def lorenz96(**kwargs_sys):
    """Build the Lorenz-96 callbacks consumed by adoptODE.

    Keyword args (read from ``kwargs_sys``):
        vars: ordered variable names; the cyclic ring x_1..x_D follows this order.
        init: initial values, zipped with ``vars`` every time ``gen_y0`` is called
            (read lazily, so later changes to ``kwargs_sys['init']`` are seen).

    Returns:
        ``(eom, loss, gen_params, gen_y0, {})`` where
        eom        -- right-hand side dx_i/dt = x_{i-1} (x_{i+1} - x_{i-2}) - x_i + p
                      with cyclic indices and forcing ``params['p']``;
        loss       -- NaN-ignoring mean squared error between the flattened
                      ``ys`` and ``targets`` pytrees;
        gen_params -- returns three fresh empty dicts;
        gen_y0     -- dict mapping each variable name to its initial value.
    """
    names = kwargs_sys['vars']

    @jit
    def eom(y, t, params, iparams, exparams):
        state = jnp.array([y[name] for name in names])
        prev1 = jnp.roll(state, 1)    # x_{i-1}
        next1 = jnp.roll(state, -1)   # x_{i+1}
        prev2 = jnp.roll(state, 2)    # x_{i-2}
        rates = prev1 * (next1 - prev2) - state + params['p']
        return dict(zip(names, rates))

    @jit
    def loss(ys, params, iparams, exparams, targets):
        fit_vec, _ = ravel_pytree(ys)
        target_vec, _ = ravel_pytree(targets)
        residual = fit_vec - target_vec
        # nanmean skips entries whose target is NaN (e.g. unobserved variables).
        return jnp.nanmean(residual ** 2)

    def gen_params():
        return {}, {}, {}

    def gen_y0():
        return {name: value for name, value in zip(names, kwargs_sys['init'])}

    return eom, loss, gen_params, gen_y0, {}

class Prob_density_vars:
    """Initial-guess generator that samples from an empirical set of values.

    Each call draws one element of ``measured`` using the generator given at
    construction.

    Args:
        rng: a ``numpy.random.Generator`` used for every draw.
        measured: 1-D array-like of candidate values (in this notebook, the
            pooled time series of the measured variables).
    """

    def __init__(self, rng, measured):
        self.rng = rng
        self.measured = measured

    def __call__(self):
        """Return a length-1 array holding one value drawn from ``measured``."""
        # Use the stored generator rather than a module-level ``rng`` so the
        # sampler is self-contained and reproducible from its own seed.
        return self.rng.choice(self.measured, 1)

class Uniform:
    """Initial-guess generator drawing from a uniform distribution.

    Args:
        rng: a ``numpy.random.Generator`` used for every draw.
        range: two-element sequence ``(low, high)``; samples lie in the
            half-open interval ``[low, high)`` per ``Generator.uniform``.
            (The name shadows the builtin but is kept for API compatibility.)
    """

    def __init__(self, rng, range):
        self.rng = rng
        self.range = range

    def __call__(self):
        """Return a length-1 array with one sample from ``[low, high)``."""
        low, high = self.range[0], self.range[1]
        # Draw from the injected generator, not a module-level ``rng``.
        return np.array([self.rng.uniform(low, high)])


# In [None]:
# --- Experiment configuration ---
num_iter = 3               # rounds of the outer re-estimation loop
every = 3                  # observe x001, x004, ... (index % every == 0)
len_segs = 100             # samples in the training segment
gen_iguess = 'uniform'     # initial-guess method: 'uniform' or 'prob_density_vars'
iguess_range = [-1, 4]     # (low, high) for the 'uniform' method
threshold = 10**-2         # per-variable MSE below which a fit counts as good
D = 9                      # number of Lorenz-96 variables
N = 10_000                 # post-transient samples kept from the reference run
trans = 1_000              # leading transient samples discarded
dt = 0.01                  # time step of the evaluation grids
p = 8.17                   # forcing parameter
epochs = 3000
lr = 0.0                   # presumably the parameter learning rate; 0 keeps `p` fixed (confirm in adoptODE)
lr_y0 = 0.01               # presumably the learning rate for initial conditions
seed = 0

# One reproducible random stream: consumed first by the random initial state
# below, then by the initial-guess generator during training.
rng = np.random.default_rng(seed=seed)
# Variable names x001..x{D}; every `every`-th one (starting at x001) is observed.
vars = [f"x{k + 1:03d}" for k in range(D)]
vars_measured = [name for k, name in enumerate(vars) if k % every == 0]
vars_unmeasured = sorted(set(vars) - set(vars_measured))

# Setting up system and training properties.
# Every initial value starts unbounded; measured ones are pinned further down.
kwargs_sys = {'N_sys': 1, 'vars': vars, 'init': rng.random(D)}
upper_bound_y0 = dict.fromkeys(vars, jnp.inf)
lower_bound_y0 = dict.fromkeys(vars, -jnp.inf)
t_all = jnp.arange(0, (N + trans) * dt, dt)   # transient + reference run
t_evals = jnp.arange(0, len_segs * dt, dt)    # one training segment
kwargs_adoptODE = {'epochs': epochs, 'lr': lr, 'lr_y0': lr_y0}

# Generate the entire time series, then keep system 0 (presumably the only
# one, since N_sys=1) with the first `trans` samples dropped as transient.
true = simple_simulation(lorenz96, t_all, kwargs_sys, kwargs_adoptODE, params={'p': p})
true = np.array([true.ys[name][0][trans:] for name in vars])
# The training segment starts from the first post-transient state.
kwargs_sys['init'] = true[:, 0]

# Measured variables have known initial values: give them equal bounds.
for name in vars_measured:
    row = vars.index(name)
    upper_bound_y0[name] = true[row, 0]
    lower_bound_y0[name] = true[row, 0]

# Initial guess generation method: replace the method name stored in
# `gen_iguess` with a callable that returns a length-1 array of guesses.
if gen_iguess == "prob_density_vars":
    # Draw from the empirical distribution of the measured variables; rows
    # 0, every, 2*every, ... of `true` are exactly the measured variables.
    gen_iguess = Prob_density_vars(rng, true[::every, :].flatten())
elif gen_iguess == "uniform":
    # Draw from a uniform distribution over [iguess_range[0], iguess_range[1]).
    gen_iguess = Uniform(rng, iguess_range)
else:
    # Fail fast: otherwise the string would only blow up later as
    # "'str' object is not callable" inside the training loop.
    raise ValueError(
        f"Unknown gen_iguess method {gen_iguess!r}; "
        "expected 'prob_density_vars' or 'uniform'"
    )

# Generate training data: one segment from the post-transient state, with the
# true forcing `p` passed as both `params` and `params_train`.
dataset = simple_simulation(
    lorenz96, t_evals, kwargs_sys, kwargs_adoptODE,
    params={'p': p}, params_train={'p': p},
)
# Untouched copy of every variable's trajectory, used later for scoring.
ys_true = copy.deepcopy(dataset.ys)
# Hide the unmeasured variables by turning their data into NaN; presumably
# these are the loss targets, which lorenz96's nanmean-based loss then skips.
for hidden in vars_unmeasured:
    dataset.ys[hidden] = dataset.ys[hidden] * jnp.nan

def _vars_between(first, second):
    """Return the names strictly between `first` and `second` on the ring `vars`.

    Walks forward (increasing index) from `first` to `second`, wrapping past
    the end of `vars` when `second` precedes `first`.
    """
    start, stop = vars.index(first), vars.index(second)
    if start < stop:
        return vars[start + 1:stop]
    return vars[start + 1:] + vars[:stop]


# Iteratively update initial guesses of the unmeasured variables.
# upper_bound_y0 doubles as the record of "determined" hidden initial values:
# +inf means still unknown (fresh guess each round), finite means pinned.
iguess_histroy = []  # snapshot of upper_bound_y0 at the start of each round
for i in range(num_iter):
    iguess_histroy.append(copy.deepcopy(upper_bound_y0))

    # Seed the trainable initial state of every hidden variable.
    for v in vars_unmeasured:
        known = upper_bound_y0[v]
        dataset.y0_train[v] = gen_iguess() if known == jnp.inf else np.array([known])

    # Equal upper/lower bounds presumably pin a y0 during training (confirm in adoptODE).
    dataset.kwargs_adoptODE['upper_b_y0'] = upper_bound_y0
    dataset.kwargs_adoptODE['lower_b_y0'] = lower_bound_y0
    params_final, losses, errors, params_history = train_adoptODE(
        dataset, print_interval=100, save_interval=1)

    # Trajectory error of each measured variable against the simulated data.
    mse_measured = {v: np.mean((dataset.ys_sol[v] - ys_true[v])**2)
                    for v in vars_measured}

    # Pair each measured variable with the next one (cyclically). If both fit
    # well, accept the estimates of the hidden variables between them and pin
    # those initial values for subsequent rounds.
    for v1, v2 in zip(vars_measured, vars_measured[1:] + vars_measured[:1]):
        # Keep `<` (not `>=`) so NaN errors count as failures.
        if not all(mse_measured[v] < threshold for v in (v1, v2)):
            continue
        for v_rep in _vars_between(v1, v2):
            # [0, 0] presumably = system 0, first time point (estimated y0) — confirm shape.
            upper_bound_y0[v_rep] = dataset.ys_sol[v_rep][0, 0]
            lower_bound_y0[v_rep] = dataset.ys_sol[v_rep][0, 0]
