In [None]:
import jax.numpy as jnp
from jax import jit
from jax.flatten_util import ravel_pytree
from jax import config
import numpy as np
import pandas as pd
from adoptODE import train_adoptODE, simple_simulation
import matplotlib.pyplot as plt
import copy
import os
import json
config.update('jax_platform_name', 'cpu')
# config.update('jax_platform_name', 'gpu')

In [None]:
def lorenz96(**kwargs_sys):
    """Assemble the Lorenz-96 building blocks in the form this notebook hands to adoptODE.

    Keyword Args:
        vars: sequence of state-variable names; fixes the order of the state vector.
        init: initial values, one per entry of ``vars`` (read when ``gen_y0`` is called).

    Returns:
        A 5-tuple ``(eom, loss, gen_params, gen_y0, {})``:
        the jitted right-hand side, the jitted loss, a generator of (empty)
        parameter dicts, a generator of the initial state dict, and an empty dict.
    """
    names = kwargs_sys['vars']

    @jit
    def eom(y, t, params, iparams, exparams):
        # dx_i/dt = x_{i-1} * (x_{i+1} - x_{i-2}) - x_i + p, with cyclic indices.
        forcing = params['p']
        state = jnp.array([y[name] for name in names])
        behind_one = jnp.roll(state, 1)    # x_{i-1}
        ahead_one = jnp.roll(state, -1)    # x_{i+1}
        behind_two = jnp.roll(state, 2)    # x_{i-2}
        rate = behind_one * (ahead_one - behind_two) - state + forcing
        return dict(zip(names, rate))

    @jit
    def loss(ys, params, iparams, exparams, targets):
        # Mean squared difference over all flattened entries; NaN entries
        # (in either tree) are skipped by nanmean.
        fit_vec, _ = ravel_pytree(ys)
        target_vec, _ = ravel_pytree(targets)
        residual = fit_vec - target_vec
        return jnp.nanmean(residual**2)

    def gen_params():
        # Three empty dicts: nothing is generated here.
        return {}, {}, {}

    def gen_y0():
        # Pair each variable name with its entry of the supplied initial state.
        start = kwargs_sys['init']
        return dict(zip(names, start))

    return eom, loss, gen_params, gen_y0, {}

In [None]:
# --- State variables and which of them are observed ---
D = 9                                         # number of state variables
every = 3                                     # every `every`-th variable counts as measured
vars = [f"x{i + 1:02d}" for i in range(D)]    # 'x01', 'x02', ...
vars_measured = vars[::every]

# --- Time grid ---
dt = 0.01        # spacing of the time grids
N = 10000        # number of samples kept for estimation
trans = 1000     # leading samples of the reference run that are dropped
len_seg = 100    # samples per fitted segment

# --- Model parameter ---
p = 8.17         # passed to the simulation as params['p']

# --- Fitting ---
threshold = 10**-2    # a segment is accepted once its measured-variable MSE is below this
interval = [-1, 4]    # range for random initial guesses of the unmeasured variables
epochs = 3000
lr = 0                # forwarded as 'lr'; presumably leaves p untrained - confirm in adoptODE docs
lr_y0 = 0.01          # forwarded as 'lr_y0'
seed = 0              # seed of the NumPy random generator

# Seeded generator: used for the initial state here and for the random
# initial guesses drawn inside the fitting loop.
rng = np.random.default_rng(seed=seed)
kwargs_sys = {'N_sys': 1, 'vars': vars, 'init': rng.random(D)}


# Setting up system and training properties
num_seg = N // len_seg                      # number of segments to fit
t_all = jnp.arange(0, (N+trans)*dt, dt)     # time grid of the long reference run
t_evals = jnp.arange(0, len_seg*dt, dt)     # time grid of a single segment
kwargs_adoptODE = {
    'epochs': epochs,
    'lr': lr,
    'lr_y0': lr_y0,
}


# Output directory: results/every<every>.
# makedirs(exist_ok=True) creates the missing "results" parent and does not
# fail when the notebook is re-run, unlike os.mkdir.
name = "every"+str(every)
dir = os.path.join("results", name)
os.makedirs(dir, exist_ok=True)


# Result buffers: one column per sample / one entry per segment.
estimated = np.empty((D, N))
mse_all = np.empty(num_seg)
mse_measured = np.empty(num_seg)
counts = np.empty(num_seg)      # rejected attempts before each segment was accepted
count = 0

# Reference trajectory: simulate once over t_all, then keep, for every variable
# (in `vars` order), entry 0 of the leading axis with the first `trans`
# samples removed (presumably the transient - confirm).
reference_run = simple_simulation(lorenz96, t_all, kwargs_sys, kwargs_adoptODE, params={'p': p})
X = np.array([reference_run.ys[v][0][trans:] for v in vars])


# Segment-by-segment estimation of the unmeasured variables.
# Each segment is fitted until the MSE on the measured variables drops below
# `threshold`; a rejected attempt is repeated with fresh random initial guesses.
# NOTE(review): there is no cap on retries, so a segment that never reaches
# the threshold keeps this loop running indefinitely.
i = 0
init_flag = 'rand'
ys_sol = None   # solution of the last accepted segment; set before 'end' is ever used
unmeasured = sorted(set(vars) - set(vars_measured))
while i < num_seg:
    # Ground-truth segment starting from the reference state at the segment start.
    kwargs_sys['init'] = X[:, len_seg*i]
    dataset = simple_simulation(lorenz96, t_evals, kwargs_sys, kwargs_adoptODE, params={'p': p}, params_train={'p': p})
    ys_true = copy.deepcopy(dataset.ys)

    for v in unmeasured:
        # Hide the unmeasured targets (all-NaN) so the nanmean loss skips them.
        dataset.ys[v] = dataset.ys[v]*jnp.nan
        if init_flag == 'rand':
            # Random initial guess drawn from `interval`.
            dataset.y0_train[v] = np.array([rng.uniform(interval[0], interval[1])])
        elif init_flag == 'end':
            # Warm start: final sample of the last accepted segment's solution.
            dataset.y0_train[v] = np.array([ys_sol[v][0, -1]])

    params_final, losses, errors, params_history = train_adoptODE(dataset, print_interval=100, save_interval=1)

    fitted_measured = np.array([dataset.ys_sol[v].flatten() for v in vars_measured])
    true_measured = np.array([ys_true[v].flatten() for v in vars_measured])
    mse_measured_i = np.mean((fitted_measured - true_measured)**2)

    if mse_measured_i < threshold:
        # Accept: store the segment and warm-start the next one.
        init_flag = 'end'
        ys_sol = copy.deepcopy(dataset.ys_sol)
        # Index rows by `vars` explicitly so they line up with the rows of X,
        # instead of relying on the iteration order of the ys_sol dict.
        estimated[:, i*len_seg:(i+1)*len_seg] = np.array([ys_sol[v] for v in vars])[:, 0, :]
        mse_all[i] = np.mean((ravel_pytree(dataset.ys_sol)[0] - ravel_pytree(ys_true)[0])**2)
        mse_measured[i] = mse_measured_i
        counts[i] = count
        count = 0
        i += 1
    else:
        # Reject: retry the same segment with new random guesses.
        init_flag = 'rand'
        count += 1


# Write every result array as a header-less, index-less CSV in the output directory.
csv_outputs = {
    "estimated.csv": estimated,
    "mse_all.csv": mse_all,
    "mse_measured.csv": mse_measured,
    "true.csv": X,
    "counts.csv": counts,
}
for csv_name, values in csv_outputs.items():
    pd.DataFrame(values).to_csv(os.path.join(dir, csv_name), header=False, index=False)


# Record the run configuration next to the results.
params = {
    'every': every,
    'threshold': threshold,
    'len_seg': len_seg,
    'interval': interval,
    'D': D,
    'N': N,
    'trans': trans,
    'dt': dt,
    'vars': vars,
    'vars_measured': vars_measured,
    'p': p,
    'epochs': epochs,
    'lr': lr,
    'lr_y0': lr_y0,
    'seed': seed,
}


with open(os.path.join(dir, 'params.json'), 'w') as f:
    json.dump(params, f, indent=4)