In [52]:
import jax.numpy as jnp
from jax import jit 
from jax.flatten_util import ravel_pytree

import numpy as np
import scipy
import matplotlib.pyplot as plt
import interpax

from adoptODE import train_adoptODE, simple_simulation, dataset_adoptODE
from mechanics import *
from data_reading import *
import progressbar

In [53]:
l_a = 0.8876
t=0
true_params = {'k_g':1, 'l_g':1,'eta':.5,'k_a':1,'k_p':1,'l_p':1}

N = 0
size = 100
mode = "chaos"

"""
    Reads in necessary parameters from config.ini
"""

print("Preparing data...")
N,size,[] = read_config([])

print("Reading data...")
x_temp = read_vector("../data/SpringMassModel/x.csv",(N,2,size+1,size+1))
x_cm_temp = read_vector("../data/SpringMassModel/x_cm.csv",(N,2,size+1,size+1))
T = read_scalar("../data/SpringMassModel/T.csv",(1,N,size,size))[0]

  2% |#                                                                       |

Preparing data...
Reading data...
shape of data:  (2000, 2, 101, 101)


100% |########################################################################|
  3% |##                                                                      |

shape of data:  (2000, 2, 101, 101)


100% |########################################################################|
  2% |#                                                                       |

shape of data:  (1, 2000, 100, 100)


100% |########################################################################|


In [54]:
def define_system(**kwargs_sys):

    #bounds for parameters
    nu_min, nu_max = kwargs_sys['nu_min'], kwargs_sys['nu_max']
    m_min, m_max = kwargs_sys['m_min'], kwargs_sys['m_max']
    l_g_min, l_g_max = kwargs_sys['l_g_min'], kwargs_sys['l_g_max']
    l_p_min, l_p_max = kwargs_sys['l_p_min'], kwargs_sys['l_p_max']
    k_g_min, k_g_max = kwargs_sys['k_g_min'], kwargs_sys['k_g_max']
    k_a_min, k_a_max = kwargs_sys['k_a_min'], kwargs_sys['k_a_max']
    k_p_min, k_p_max = kwargs_sys['k_p_min'], kwargs_sys['k_p_max']
    eta_min, eta_max = kwargs_sys['eta_min'], kwargs_sys['eta_max']

    # Interpolated params and coresponding time ,
    x_cm_arr = kwargs_sys['x_cm']
    x_j_arr = kwargs_sys['x_j']
    l_a_arr = kwargs_sys['l_a']
    

    def gen_y0():

        #takes initial conditions from kwargs(data)
        x1_0 = kwargs_sys['x1_0']
        x2_0 = kwargs_sys['x2_0']
        y1_0 = kwargs_sys['y1_0']
        y2_0 = kwargs_sys['y2_0']

        return {'x1':x1_0, 'x2':x2_0, 'y1':y1_0, 'y2':y2_0}

    def gen_params():

        nu = nu_min + (nu_max - nu_min) * np.random.rand()
        m = m_min + (m_max - m_min) * np.random.rand()

        l_g = l_g_min + (l_g_max - l_g_min) * np.random.rand()
        l_p = l_p_min + (l_p_max - l_p_min) * np.random.rand()

        k_g = k_g_min + (k_g_max - k_g_min) * np.random.rand()
        k_a = k_a_min + (k_a_max - k_a_min) * np.random.rand()
        k_p = k_p_min + (k_p_max - k_p_min) * np.random.rand()
        
        eta = eta_min + (eta_max - eta_min) * np.random.rand()

        return {'nu':nu,'m':m,'l_g':l_g,'l_p':l_p,'k_g':k_g, 'k_a':k_a,'k_p':k_p, 'eta':eta}, {}, {}

        
    @jit
    def eom(xy, t, params, iparams, exparams):
        x = jnp.array([xy['x1'], xy['x2']])
        # get interpolated parameters at corresponding time
        x_cm = t_to_value_4p(x_cm_arr,t_interp,t,N_interp)
        x_j = t_to_value_4p(x_j_arr,t_interp,t,N_interp)
        l_a = t_to_value_1p(l_a_arr,t_interp,t,N_interp)

        #initialize total force
        f = total_force(x, x_j, x_cm, l_a, t, params)

        #initialize eom
        dx1 = xy['y1']
        dx2 = xy['y2']
        dy1 = 1/params['m'] * (f[1] - params['nu'] * xy['y1'])
        dy2 = 1/params['m'] * (f[2] - params['nu'] * xy['y2'])

        return {'x1':dx1, 'x2':dx2, 'y1':dy1, 'y2':dy2}

    @jit
    def loss(xy, params, iparams, exparams, targets):
        
        x1 = xy['x1']
        x2 = xy['x2']
        t_x1 = targets['x1']
        t_x2 = targets['x2']
        return jnp.mean((x1-t_x1)**2 + (x2-t_x2)**2)

    return eom, loss, gen_params, gen_y0, {}

In [76]:
N,size,ls = read_config(["l_0","c_a","k_ij","k_j","k_a","m","c_damp","n_0","delta_t_m","it_m"])
l_0, c_a, k_g0, k_p0, k_a0, m0, nu0, eta0, delta_t_m, it_m = ls
l_a0,l_p0,l_g0 = l_0, l_0, l_0

delta_t = delta_t_m * it_m
t_evals = jnp.linspace(0,2000*delta_t,2000)
N_interp = 50

x_i,x_j,x_cm,l_a = shape_input_for_adoptode(x_temp, x_cm_temp,T,50,50) #the last two variables define the cell in the grid 
t_interp, x_cm_interp = interpolate_x(x_cm,t_eval,N_interp)
t_interp, x_j_interp = interpolate_x(x_j,t_eval,N_interp)
t_interp, l_a_interp = interpolate_scalar(l_a,t_eval,N_interp)

In [77]:
kwargs_sys = { 
    'nu_min': nu0 - nu0 * 0.1,'nu_max': nu0 + nu0 * 0.1,
    'm_min': m0 - m0 * 0.1,'m_max' : m0 + m0 * 0.1,
    'l_g_min': l_g0 - l_g0 * 0.1,'l_g_max': l_g0 + l_g0 * 0.1,
    'l_p_min': l_p0 - l_p0 * 0.1,'l_p_max': l_p0 + l_p0 * 0.1,
    'k_g_min': k_g0 - k_g0 * 0.1,'k_g_max': k_g0 + k_g0 * 0.1,
    'k_p_min': k_p0 - k_p0 * 0.1,'k_p_max': k_p0 + k_p0 * 0.1,
    'k_a_min': k_a0 - k_a0 * 0.1,'k_a_max': k_a0 + k_a0 * 0.1,
    'eta_min': eta0 - eta0 * 0.1,'eta_max': eta0 + eta0 * 0.1,
    't_interp': t_interp,
    'N_interp': N_interp,
    'x_cm':x_cm_interp,
    'x_j':x_j_interp,
    'l_a':l_a_interp,
    'x1_0':x_i[0,0],
    'x2_0':x_i[0,1],
    'y1_0':(x_i[1,0]-x_i[0,0])/delta_t,
    'y2_0':(x_i[1,1]-x_i[0,1])/delta_t,
    'N_sys': 1
}
kwargs_adoptODE = {'lr':3e-2, 'epochs':200,'N_backups':5}
dataset = simple_simulation(define_system,
                                t_evals,
                                kwargs_sys,
                                kwargs_adoptODE)


In [78]:
print('The true parameters used to generate the data: ', dataset.params)
print('The initial gues of parameters for the recovery: ', dataset.params_train )

The true parameters used to generate the data:  {'nu': 14.097883056815917, 'm': 0.9028675597243122, 'l_g': 1.0576632539674002, 'l_p': 0.9799187570086381, 'k_g': 13.111669781834362, 'k_a': 9.231034699741615, 'k_p': 1.8033307539106551, 'eta': 0.5093111325918764}
The initial gues of parameters for the recovery:  {'nu': 14.97593783967928, 'm': 1.0560865405432396, 'l_g': 0.9896920767397941, 'l_p': 1.0556434480443577, 'k_g': 12.330466232269101, 'k_a': 8.805897644297735, 'k_p': 2.1763872455295767, 'eta': 0.48305348762893174}


In [79]:
_ = train_adoptODE(dataset)
print('True params: ', dataset.params)
print('Found params: ', dataset.params_train)

Epoch 000:  Loss: 3.6e-01,  Params Err.: 1.3e+00, y0 error: 0.0e+00, Params Norm: 2.1e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 020:  Loss: 1.1e-02,  Params Err.: 1.2e+00, y0 error: 0.0e+00, Params Norm: 2.1e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 040:  Loss: 1.1e-02,  Params Err.: 1.1e+00, y0 error: 0.0e+00, Params Norm: 2.1e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 060:  Loss: 7.2e-03,  Params Err.: 1.1e+00, y0 error: 0.0e+00, Params Norm: 2.1e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 080:  Loss: 5.7e-03,  Params Err.: 1.0e+00, y0 error: 0.0e+00, Params Norm: 2.1e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 100:  Loss: 4.6e-03,  Params Err.: 9.5e-01, y0 error: 0.0e+00, Params Norm: 2.1e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 120:  Loss: 4.0e-03,  Params Err.: 9.1e-01, y0 error: 0.0e+00, Params Norm: 2.1e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 140:  Loss: 3.4e-03,  Params

In [80]:
nan_array = jnp.full((1, 2000), jnp.nan)
kwargs_adoptODE = {'lr':3e-2, 'epochs':200,'N_backups':5}
targets = {"x1":x_i[:,0].reshape((1,2000)),'x2':x_i[:,1].reshape((1,2000)),'y1':nan_array,'y2':nan_array}
dataset2 = dataset_adoptODE(define_system,
                                targets,
                                t_evals,
                                kwargs_sys,
                                kwargs_adoptODE)

In [81]:
_ = train_adoptODE(dataset2)

print('Found params: ', dataset2.params_train)

Epoch 000:  Loss: nan,  Params Err.: nan, y0 error: nan, Params Norm: 2.1e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 


Exception: Gradients resulted to nans. Maybe try the back_check function to see is your backward pass is instable. In that case it can help to increase the number of Backups ('N_backups') used in between time points.

In [67]:
k_a0

9.0