In [1]:
import jax.numpy as jnp
from jax import jit
from jax.flatten_util import ravel_pytree

import numpy as np
import scipy
import matplotlib.pyplot as plt

from adoptODE import train_adoptODE, simple_simulation
from mechanics import *
from data_reading import *
import progressbar

In [3]:
l_a = 0.8876
t = 0
params = {'k_g': 1, 'l_g': 1, 'eta': .5, 'k_a': 1, 'k_p': 1, 'l_p': 1}
mode = "chaos"

# Directory holding the pre-computed spring-mass model output files.
DATA_DIR = "../data/SpringMassModel"

# Read the grid parameters N and size from config.ini (no extra keys requested,
# so the third return value is not needed).
print("Preparing data...")
N, size, _ = read_config([])

print("Reading data...")
x = read_vector(f"{DATA_DIR}/x.csv", (N, 2, size + 1, size + 1))
x_cm = read_vector(f"{DATA_DIR}/x_cm.csv", (N, 2, size + 1, size + 1))
# read_scalar is asked for a leading axis of length 1; drop it.
T = read_scalar(f"{DATA_DIR}/T.csv", (1, N, size, size))[0]

  3% |##                                                                      |

Preparing data...
Reading data...
shape of data:  (2000, 2, 101, 101)


100% |########################################################################|
  3% |##                                                                      |

shape of data:  (2000, 2, 101, 101)


100% |########################################################################|
  2% |#                                                                       |

shape of data:  (1, 2000, 100, 100)


100% |########################################################################|


In [4]:
# Re-read N and size together with the config values l_0 and c_a.
N, size, ls = read_config(["l_0", "c_a"])
l_a0, c_a = ls
# NOTE(review): (50, 50) presumably selects the grid point to fit — TODO confirm.
x_i, x_j, x_cm_e, l_a_i = shape_input_for_adoptode(x, x_cm, T, 50, 50)

# Time step between stored frames (assumed from the original hard-coded value).
DT = 0.016
# N evaluation times from 0 to N*DT inclusive (same as the original for N == 2000).
# NOTE(review): with endpoint included the spacing is N*DT/(N-1), not DT; if the
# frames are exactly DT apart, (N - 1) * DT is the correct upper bound — confirm.
t_eval = jnp.linspace(0, N * DT, N)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [10]:
def define_system(**kwargs_sys):
    """Build the adoptODE system callables for one mass point.

    Keyword Args:
        x_cm: data passed to ``interpolate_spline`` and then to ``total_force``
            as the centre-of-mass input (presumably a time series — TODO confirm).
        x_j: data for the neighbouring points, interpolated the same way.
        l_a: data for the active lengths, interpolated the same way.

    Returns:
        tuple: ``(eom, loss, gen_params, gen_y0, {})``.

    NOTE(review): ``x0``, ``z0``, ``d_0``, ``d_max``, ``f_0``, ``f_max``, ``m``,
    ``nu``, ``interpolate_spline`` and ``total_force`` are not defined in this
    cell; they must be provided by the star imports or another cell.
    """

    def gen_y0():
        """Return the initial state.

        NOTE(review): the keys 'x'/'z' do not match the state keys
        ('x1', 'x2', 'y1', 'y2') that ``eom`` reads and returns — confirm.
        """
        return {'x': x0, 'z': z0}

    def gen_params():
        """Draw random parameters ``d`` and ``f`` (uniform offsets from d_0 / f_0).

        Returns:
            (params, iparams, exparams) with the last two empty.
        """
        d = d_0 + np.random.rand() * d_max
        f = f_0 + np.random.rand() * f_max
        return {'d': d, 'f': f}, {}, {}

    @jit
    def eom(xy, t, params, iparams, exparams):
        """Right-hand side of the first-order system.

        ``x1``/``x2`` are positions whose time derivatives are ``y1``/``y2``;
        ``y1``/``y2`` relax under the force with damping ``nu`` and mass ``m``.
        """
        x_cm = kwargs_sys['x_cm']
        x_j = kwargs_sys['x_j']
        l_a = kwargs_sys['l_a']
        x = jnp.array([xy['x1'], xy['x2']])

        # Interpolating splines make it more efficient. Always evaluate them:
        # under @jit, `t` is a tracer and cannot drive a Python `if`, and a
        # skipped branch would leave the *_int names undefined.
        x_j_int = interpolate_spline(x_j, t)
        x_cm_int = interpolate_spline(x_cm, t)
        l_a_int = interpolate_spline(l_a, t)

        f = total_force(x, x_j_int, x_cm_int, l_a_int, t, params)

        dx1 = xy['y1']
        dx2 = xy['y2']
        # Assumes f has one component per entry of x (0 -> x1, 1 -> x2) — TODO
        # confirm. JAX clamps out-of-range indices, so f[2] would silently
        # reuse f[1] instead of raising.
        dy1 = 1 / m * (f[0] - nu * xy['y1'])
        dy2 = 1 / m * (f[1] - nu * xy['y2'])
        return {'x1': dx1, 'x2': dx2, 'y1': dy1, 'y2': dy2}

    @jit
    def loss(xy, params, iparams,
                    exparams, targets):
        """Mean over samples of the squared position error in both coordinates."""
        x1 = xy['x1']
        x2 = xy['x2']
        t_x1 = targets['x1']
        t_x2 = targets['x2']
        return jnp.mean((x1 - t_x1)**2 + (x2 - t_x2)**2)

    return eom, loss, gen_params, gen_y0, {}