In [None]:
import copy
import numpy as np
import jax.numpy as jnp
from jax import jit
from jax import config
from jax.flatten_util import ravel_pytree
from adoptODE import dataset_adoptODE, train_adoptODE, simple_simulation

# Run all JAX computations on the CPU backend.
config.update('jax_platform_name', 'cpu')

def lorenz96(**kwargs_sys):
    """Build the adoptODE callables for the Lorenz-96 system.

    Keyword Args:
        vars: Ordered variable names x_1..x_D. Neighbours in the equations
            of motion are taken cyclically in this order.
        vars_local: Names of the variables compared against the targets in
            ``loss``. ``None`` (or absent) selects every name in ``vars``.
        init: Initial values, one per entry of ``vars``; read when
            ``gen_y0`` is called.

    Returns:
        ``(eom, loss, gen_params, gen_y0, {})``.
    """
    var_names = kwargs_sys['vars']
    vars_local = kwargs_sys.get('vars_local')
    # Snapshot the loss variables once (an iterator would be exhausted by the
    # first trace). With ``None`` fall back to all variables: entries whose
    # target is NaN give NaN residuals, which ``nanmean`` ignores.
    local_keys = list(var_names if vars_local is None else vars_local)

    @jit
    def eom(y, t, params, iparams, exparams):
        """Return dx_i/dt = x_{i-1} * (x_{i+1} - x_{i-2}) - x_i + p (cyclic i)."""
        p = params['p']
        x = jnp.array([y[v] for v in var_names])
        # jnp.roll(x, k)[i] == x[(i - k) % D]
        dx = jnp.roll(x, 1) * (jnp.roll(x, -1) - jnp.roll(x, 2)) - x + p
        return dict(zip(var_names, dx))

    @jit
    def loss(ys, params, iparams, exparams, targets):
        """Mean squared error over the local variables, ignoring NaN residuals."""
        ys_local = {key: ys[key] for key in local_keys}
        targets_local = {key: targets[key] for key in local_keys}
        # Both dicts have the same keys, so they flatten in the same order.
        flat_fit = ravel_pytree(ys_local)[0]
        flat_target = ravel_pytree(targets_local)[0]
        return jnp.nanmean((flat_fit - flat_target) ** 2)

    def gen_params():
        # Three empty parameter dicts; in this notebook 'p' is passed
        # explicitly via params=... instead.
        return {}, {}, {}

    def gen_y0():
        """Map each variable name to its entry of ``kwargs_sys['init']``."""
        return dict(zip(var_names, kwargs_sys['init']))

    return eom, loss, gen_params, gen_y0, {}

class Prob_density_vars:
    """Draw initial guesses from an empirical sample of observed values.

    Each call picks one element of ``measured`` uniformly at random, i.e. it
    samples the empirical distribution of the supplied data.
    """

    def __init__(self, rng, measured):
        """Store the sampler state.

        Args:
            rng: NumPy random ``Generator`` used for every draw.
            measured: Values to sample from (passed to ``rng.choice``).
        """
        self.rng = rng
        self.measured = measured

    def __call__(self):
        """Return a length-1 array holding one randomly chosen value."""
        # Use the generator given at construction, not a module-level `rng`.
        return self.rng.choice(self.measured, 1)

class Uniform:
    """Draw initial guesses from a continuous uniform distribution."""

    def __init__(self, rng, range):
        """Store the sampler state.

        Args:
            rng: NumPy random ``Generator`` used for every draw.
            range: ``(low, high)`` bounds of the distribution. The name shadows
                the builtin but is kept for keyword-argument compatibility.
        """
        self.rng = rng
        self.range = range

    def __call__(self):
        """Return a length-1 array with one sample from ``[low, high)``."""
        low, high = self.range[0], self.range[1]
        # Use the generator given at construction, not a module-level `rng`.
        return np.array([self.rng.uniform(low, high)])


In [None]:
num_iguesses = 10               # random initial guesses per pair of measured variables
D = 120                         # number of Lorenz-96 variables
every = 6                       # every `every`-th variable (x001, x007, ...) is measured
gen_iguess_method = 'uniform'   # 'uniform' or 'prob_density_vars'
iguess_range = [-1, 4]          # bounds for the 'uniform' initial guesses

N = 10000                       # time steps kept after the transient
len_segs = 100                  # time steps in the training segment
trans = 1000                    # transient time steps discarded
dt = 0.01                       # time step
p = 8.17                        # constant term p in the Lorenz-96 equations
epochs = 3000
lr = 0.0                        # adoptODE 'lr' (presumably for params; 0.0 keeps p fixed)
lr_y0 = 0.01                    # adoptODE 'lr_y0' (presumably for the initial state)
seed = 0

rng = np.random.default_rng(seed=seed)
vars = ['x' + str(i + 1).zfill(3) for i in range(D)]
vars_measured = vars[::every]
vars_unmeasured = sorted(set(vars) - set(vars_measured))
# Cyclic pairs of neighbouring measured variables (the last pairs with the
# first). Materialized as a list so it can be iterated more than once.
vars_local_list = list(zip(vars_measured, vars_measured[1:] + vars_measured[:1]))

# Setting up system and training properties
kwargs_sys = {'N_sys': 1, 'vars': vars, 'init': rng.random(D), 'vars_local': None}
# Build the time grids from integer step counts: a float-step
# arange(0, T, dt) can gain or lose a sample through rounding, which would
# shift the `trans` slice below.
t_all = jnp.arange(N + trans) * dt
t_evals = jnp.arange(len_segs) * dt
kwargs_adoptODE = {'epochs': epochs, 'lr': lr, 'lr_y0': lr_y0}

# Generate entire time series
true = simple_simulation(lorenz96, t_all, kwargs_sys, kwargs_adoptODE, params={'p': p})
# Stack the post-transient trajectories, row k = vars[k]
# (assumes true.ys[v] is (N_sys, time); [0] selects the single system).
true = np.array([true.ys[v][0][trans:] for v in vars])
# The first post-transient state becomes the initial condition for training.
kwargs_sys['init'] = true[:, 0]

# Initial guess generation method
if gen_iguess_method == "prob_density_vars":
    # Draw from the empirical distribution of the measured variables
    # (rows 0, every, 2*every, ... of `true` are the measured ones).
    gen_iguess = Prob_density_vars(rng, true[::every, :].flatten())
elif gen_iguess_method == "uniform":
    # Draw from a uniform distribution over iguess_range.
    gen_iguess = Uniform(rng, iguess_range)
else:
    # Fail here rather than with a NameError on `gen_iguess` inside the loop.
    raise ValueError(f"Unknown gen_iguess_method: {gen_iguess_method!r}")

# Generate training data: simulate one segment of `len_segs` steps starting
# from the post-transient state, and keep an untouched copy of it.
dataset = simple_simulation(
    lorenz96,
    t_evals,
    kwargs_sys,
    kwargs_adoptODE,
    params={'p': p},
    params_train={'p': p},
)
ys_true = copy.deepcopy(dataset.ys)

# Hide the unmeasured variables: multiplying by NaN turns every sample into NaN.
for v in vars_unmeasured:
    dataset.ys[v] = jnp.nan * dataset.ys[v]


# One cyclic pair per measured variable. len() is exact; int(D/every)
# undercounts whenever D is not a multiple of `every`.
n_measured = len(vars_measured)
# mse_measured[i, j]: MSE on pair i's variables for initial guess j.
# estimated_init[:, i*num_iguesses + j]: fitted initial state, rows in `vars` order.
mse_measured = np.zeros((n_measured, num_iguesses))
estimated_init = np.zeros((D, n_measured * num_iguesses))
measured_set = set(vars_measured)
# Build every pair's dataset from the same masked observations and true
# initial state, rather than from the previous iteration's dataset object.
ys_obs, y0_obs = dataset.ys, dataset.y0
for i, vars_local in enumerate(vars_local_list):
    # Estimate unmeasured variables between vars_local[0] and vars_local[1]
    kwargs_sys['vars_local'] = vars_local
    dataset = dataset_adoptODE(lorenz96, ys_obs, t_evals, kwargs_sys, kwargs_adoptODE, true_y0=y0_obs, params_train={'p': p}, true_params={'p': p})
    for j in range(num_iguesses):
        # Measured variables start at their true values, the rest at random guesses.
        for v in vars:
            if v in measured_set:
                dataset.y0_train[v] = dataset.y0[v]
            else:
                dataset.y0_train[v] = gen_iguess()

        params_final, losses, errors, params_history = train_adoptODE(dataset, print_interval=100, save_interval=1)
        fit = np.array([dataset.ys_sol[v].flatten() for v in vars_local])
        ref = np.array([ys_true[v].flatten() for v in vars_local])
        mse_measured[i, j] = np.mean((fit - ref) ** 2)
        # Concatenate explicitly in `vars` order; ravel_pytree would use sorted
        # key order, which differs from `vars` once names exceed three digits.
        estimated_init[:, i * num_iguesses + j] = np.concatenate(
            [np.ravel(dataset.y0_train[v]) for v in vars]
        )
