In [9]:
#import jax and other libraries for computation
import jax.numpy as jnp
from jax import jit
from jax.scipy.signal import convolve2d
from jax.flatten_util import ravel_pytree
from jax.experimental.ode import odeint
from jax import tree_util
import jax.random as random
import numpy as np
# import AdoptODE
from adoptODE import train_adoptODE, simple_simulation, dataset_adoptODE
#import the MSD mechanics
from HelperAndMechanics import *
import h5py
# Load from HDF5
with h5py.File('../data/SpringMassModel/MechanicalData/data_eta_05_uvx.h5', 'r') as f:
    v = f['v'][:100]
    u = f['u'][:100]
    T = f['T'][:100]
    x = f['x'][:100]
    f.close()

def define_MSD_sim(**kwargs_sys):

    N_sys = kwargs_sys['N_sys']

    def gen_params():
        
        return {key:value + kwargs_sys['par_tol']*value*np.random.uniform(-1.0, 1.0) for key,value in kwargs_sys['params_true'].items()}, {}, {}
    
    def gen_y0():
        return {'u':kwargs_sys['u0'],'v':kwargs_sys['v0'],'T':kwargs_sys['T0'],'x':kwargs_sys['x0'],'x_dot':kwargs_sys['x_dot0']}
    @jit
    def kernel(spacing):
        kernel = np.array([[1, 4, 1], [4, -20.0, 4], [1, 4, 1]]) / (spacing* spacing * 6)
        return kernel
    @jit
    def laplace(f,params):  #laplace of scalar
        f_ext = jnp.concatenate((f[0:1], f, f[-1:]), axis=0)
        f_ext = jnp.concatenate((f_ext[:, 0:1], f_ext, f_ext[:, -1:]), axis=1)
        return convolve2d(f_ext, kernel(params['spacing']), mode='valid')
    @jit
    def epsilon(u,v,rp):
        return rp['epsilon_0']+rp['mu_1']*v/(u+rp['mu_2'])
    @jit
    def epsilon_T(u):
        return 1 - 0.9*jnp.exp(-jnp.exp(-30*(jnp.abs(u) - 0.1)))
    @jit
    def eom(y, t, params, iparams, exparams):
            
            par=params
            u=y['u']
            v=y['v']
            T=y['T']
            x=y['x']
            x_dot=y['x_dot']

            dudt = par['D']*laplace(u,par)-(par['k'])*u*(u-par['a'])*(u-1) - u*v
            dvdt = epsilon(u,v,par)*(-v-(par['k'])*u*(u-par['a']-1))
            dTdt = epsilon_T(u)*(par['k_T']*jnp.abs(u)-T)
            dx_dotdt = 1/par['m'] *  (force_field_active(x,T,par) + force_field_passive(x,par) + force_field_struct(x,T,par) - x_dot * par['c_damp'])
            dxdt = x_dot

            return {'u':dudt, 'v':dvdt, 'T':dTdt, 'x':zero_out_edges(dxdt), 'x_dot':zero_out_edges(dx_dotdt)}
    @jit
    def loss(ys, params, iparams, exparams, targets):
        # u = ys['u']
        # u_target = targets['u']
        pad = 10
        x = ys['x'][:,:,pad:-pad,pad:-pad]
        x_target = targets['x'][:,:,pad:-pad,pad:-pad]
        x_dot = ys['x_dot'][:,:,pad:-pad,pad:-pad]
        x_dot_target = targets['x_dot'][:,:,pad:-pad,pad:-pad]
        
        return  jnp.nanmean((x - x_target)**2 + (x_dot-x_dot_target)**2)#jnp.nanmean((u - u_target)**2) +
            
    return eom, loss, gen_params, gen_y0, {}

def define_MSD_fit(**kwargs_sys):

    N_sys = kwargs_sys['N_sys']

    def gen_params():
        
        return {key:value + kwargs_sys['par_tol']*value*np.random.uniform(-1.0, 1.0) for key,value in kwargs_sys['params_true'].items()}, {}, {}
    
    def gen_y0():
        return {'u':kwargs_sys['u0'],'v':kwargs_sys['v0'],'T':kwargs_sys['T0'],'x':kwargs_sys['x0'],'x_dot':kwargs_sys['x_dot0']}
    @jit
    def kernel(spacing):
        kernel = np.array([[1, 4, 1], [4, -20.0, 4], [1, 4, 1]]) / (spacing* spacing * 6)
        return kernel
    @jit
    def laplace(f,params):  #laplace of scalar
        f_ext = jnp.concatenate((f[0:1], f, f[-1:]), axis=0)
        f_ext = jnp.concatenate((f_ext[:, 0:1], f_ext, f_ext[:, -1:]), axis=1)
        return convolve2d(f_ext, kernel(params['spacing']), mode='valid')
    @jit
    def epsilon(u,v,rp):
        return rp['epsilon_0']+rp['mu_1']*v/(u+rp['mu_2'])
    @jit
    def epsilon_T(u):
        return 1 - 0.9*jnp.exp(-jnp.exp(-30*(jnp.abs(u) - 0.1)))
    @jit
    def eom(y, t, params, iparams, exparams):
            
            par=params
            u=y['u']
            v=y['v']
            T=y['T']
            x=y['x']
            x_dot=y['x_dot']

            dudt = par['D']*laplace(u,par)-(par['k'])*u*(u-par['a'])*(u-1) - u*v
            dvdt = epsilon(u,v,par)*(-v-(par['k'])*u*(u-par['a']-1))
            dTdt = epsilon_T(u)*(par['k_T']*jnp.abs(u)-T)
            dx_dotdt = 1/par['m'] *  (force_field_active(x,T,par) + force_field_passive(x,par) + force_field_struct(x,T,par) - x_dot * par['c_damp'])
            dxdt = x_dot

            return {'u':dudt, 'v':dvdt, 'T':dTdt, 'x':zero_out_edges(dxdt), 'x_dot':zero_out_edges(dx_dotdt)}
    @jit
    def loss(ys, params, iparams, exparams, targets):
        # u = ys['u']
        # u_target = targets['u']
        pad = 10
        x = ys['x'][:,:,pad:-pad,pad:-pad]
        x_target = targets['x'][:,:,pad:-pad,pad:-pad]
        x_dot = ys['x_dot'][:,:,pad:-pad,pad:-pad]
        x_dot_target = targets['x_dot'][:,:,pad:-pad,pad:-pad]
        
        return  jnp.nanmean((x - x_target)**2 + (x_dot-x_dot_target)**2)#jnp.nanmean((u - u_target)**2) +
            
    return eom, loss, gen_params, None , {}

def get_ini_sim(u,v,T,x,x_dot,sim_indx,size = 100,pad = 10,delta_t_e = 0.08,sampling_rate = 10,length = 30):
    '''gets the initial conditions for the simulation
    if sim_indx=0 then the initial conditions are taken from raw data
    if sim_indx=1 then the initial conditions are taken from the last state of the recent simulation'''
    t_evals = np.linspace(0, delta_t_e*sampling_rate*length, length)
    if sim_indx == 0:
        #define grid and initial conditions
        u_fit = u[:length*sampling_rate][::sampling_rate,:,:]
        T_fit = T[:length*sampling_rate][::sampling_rate,:,:]
        v_fit = v[:length*sampling_rate][::sampling_rate,:,:]
        x_fit = x[:length*sampling_rate][::sampling_rate,:,:,:]
        x_dot = np.gradient(x,axis = 0)/delta_t_e
        x_dot_fit = x_dot[:length*sampling_rate][::sampling_rate,:,:,:]
        #select input state
        u0 = u_fit[0]
        v0 = v_fit[0]
        T0 = T_fit[0]
        x0 = x_fit[0]
        x_dot0 = x_dot_fit[0]
    else:
        u0 = u[-1]
        v0 = v[-1]
        T0 = T[-1]
        x0 = x[-1]
        x_dot0 = x_dot[-1]
        
    return u0, v0, T0, x0, x_dot0, t_evals

def get_ini_fit(u_sol,v_sol,T_sol,x,x_dot,sim_indx,size = 100,pad = 10,delta_t_e = 0.08,sampling_rate = 10,length=30):
    '''gets the initial conditions for the simulation
    if sim_indx=0 then u,v,T are initialized as resting stante (0)
    if sim_indx=1 then the u,v,T are taken from the last state of the recent fit'''
    t_evals = np.linspace(0, delta_t_e*sampling_rate*length, length)
    if sim_indx == 0:
        #select input state
        # dA = np.abs(compute_dA(x, 1)[:,pad:-pad,pad:-pad])
        u0 = jnp.full((1,length,100,100),.5)#dA[0]/np.max(dA[0])*np.max(u)
        v0 = jnp.full((1,length,100,100),0*np.max(v))#dA[0]/np.max(dA[0])*np.max(v)
        T0 = jnp.full((1,length,100,100),0*np.max(T))#dA[0]/np.max(dA[0])*np.max(T)
        
    else:
        u0 = u_sol[:,-1]
        u0 = np.broadcast_to(u0, (1,length, 100, 100))
        v0 = v_sol[:,-1]
        v0 = np.broadcast_to(v0, (1,length, 100, 100))
        T0 = T_sol[:,-1]
        T0 = np.broadcast_to(T0, (1,length, 100, 100))
    print(u0.shape,v0.shape,T0.shape,x.shape,x_dot.shape)
    return u0, v0, T0, x, x_dot, t_evals


In [10]:
"""
    Reads in necessary parameters from config.ini
"""
N,size,params = read_config(['D','a','k','epsilon_0','mu_1','mu_2','k_T','delta_t_e'
                             ,'k_T','k_ij','k_ij_pad','k_j','k_a','k_a_pad','c_a','m','c_damp',
                             'n_0','l_0','spacing'],mode = 'chaos')
N = T.shape[0]
keys =['D','a','k','epsilon_0','mu_1','mu_2','k_T','delta_t_e'
        ,'k_T','k_ij','k_ij_pad','k_j','k_a','k_a_pad','c_a','m','c_damp',
        'n_0','l_0','spacing']
pad = 10
tol = .8
params_true = dict(zip(keys,params))
params_low = {key: value - value*tol for key, value in params_true.items()}
params_high = {key: value + value*tol for key, value in params_true.items()}
params_high['k_ij_pad'],params_low['k_ij_pad'] = params_true['k_ij_pad']  ,params_true['k_ij_pad']
params_high['k_a_pad'],params_low['k_a_pad'] = params_true['k_a_pad']  ,params_true['k_a_pad']
x_dot = np.gradient(x, axis=0) / params_true['delta_t_e']
sol_list = []
sim_indx =0
length = 10
sampling_rate = 20

u0,v0,T0,x0,x_dot0,t_evals = get_ini_sim(u,v,T,x,x_dot,sim_indx,length = length,sampling_rate=sampling_rate)
kwargs_sys = {'size': 100,
            'spacing': 1,
            'N_sys': 1,
            'par_tol': 0,
            'params_true': params_true,
            'u0': u0,'v0': v0,'T0': T0,'x0': x0,'x_dot0': x_dot0}
kwargs_adoptODE = {'epochs': 10,'N_backups': 1,'lr': 1e-3,
                'lower_b_y0':{'u':u0,'v':v0,'T':T0,'x':x0,'x_dot':x_dot0},
                    'upper_b_y0':{'u':u0,'v':v0,'T':T0,'x':x0,'x_dot':x_dot0}}
# Setting up a dataset via simulation
Simulation_MSD = simple_simulation(define_MSD_sim,
                            t_evals,
                            kwargs_sys,
                            kwargs_adoptODE)

In [11]:
u0,v0,T0,x0,x_dot0,t_evals = get_ini_fit(Simulation_MSD.ys['u'],Simulation_MSD.ys['v'],
                                        Simulation_MSD.ys['T'], Simulation_MSD.ys['x'], Simulation_MSD.ys['x_dot'],
                                        sim_indx,length=length,sampling_rate=sampling_rate)
targets = {'u':u0,'v':v0,'T':T0,'x':x0,'x_dot':x_dot0}
        #    'x':Simulation_MSD.ys['x'][0].reshape(1,length,2,size+2*pad+1,size+2*pad+1),'x_dot':Simulation_MSD.ys['x_dot'][0].reshape(1,length,2,size+2*pad+1,size+2*pad+1)}
kwargs_sys = {'size': 100,
            'N_sys': 1,
            'par_tol': tol,
            'params_true': params_true}
kwargs_adoptODE = {'epochs': 50,'N_backups': 3,'lr': 1e-3,'lower_b': params_low,'upper_b': params_high,
                'lr_y0':2e-3,
                'lower_b_y0':{'u':0.,'v':0.,'T':0.,'x':x0[0],'x_dot':x_dot0[0]},
                'upper_b_y0':{'u':1.,'v':np.max(v),'T':np.max(T),'x':x0[0],'x_dot':x_dot0[0]}}
dataset_MSD = dataset_adoptODE(define_MSD_fit,
                                targets,
                                t_evals, 
                                kwargs_sys,
                                kwargs_adoptODE,
                                true_params=params_true)

(1, 10, 100, 100) (1, 10, 100, 100) (1, 10, 100, 100) (1, 10, 2, 121, 121) (1, 10, 2, 121, 121)


In [12]:
_ = train_adoptODE(dataset_MSD)
sol_list.append(dataset_MSD.params_train)
print('Found params: ', dataset_MSD.params_train)

Epoch 000:  Loss: 7.1e-03,  Params Err.: 1.5e+01, y0 error: nan, Params Norm: 3.0e+01, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 


ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (8 of them) had size 1, e.g. axis 0 of argument y0['T'] of type float32[1,100,100];
  * some axes (2 of them) had size 10, e.g. axis 0 of argument y0['x'] of type float32[10,2,121,121]

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# Simulated data (replace with actual data)
data1 = Simulation_MSD.ys['u'][0]  # Heatmap 1
data2 = dataset_MSD.ys_sol['u'][0]  # Heatmap 2
time_steps = data1.shape[0]  # Change based on your dataset

# Function to update the heatmaps
def update_plot(t):
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    im1 = axes[0].imshow(data1[t], cmap="coolwarm", aspect="auto")
    axes[0].set_title("Ground truth")
    fig.colorbar(im1, ax=axes[0])

    im2 = axes[1].imshow(data2[t], cmap="coolwarm", aspect="auto")
    axes[1].set_title("Reconstruction")
    fig.colorbar(im2, ax=axes[1])

    plt.show()

# Create an interactive slider
slider = widgets.IntSlider(min=0, max=time_steps-1, step=1, value=0, description="Time")

# Display the interactive widget
widgets.interactive(update_plot, t=slider)

interactive(children=(IntSlider(value=0, description='Time', max=9), Output()), _dom_classes=('widget-interact…

## make new simulation

In [None]:
sim_indx = 1
u0,v0,T0,x0,x_dot0,t_evals = get_ini_sim(dataset_MSD.ys['u'][0],dataset_MSD.ys['v'][0],dataset_MSD.ys['T'][0],dataset_MSD.ys['x'][0],dataset_MSD.ys['x_dot'][0],sim_indx,length=length,sampling_rate=sampling_rate)
kwargs_sys = {'size': 100,
        'spacing': 1,
        'N_sys': 1,
        'par_tol': 0,
        'params_true': params_true,
        'u0': u0,'v0': v0,'T0': T0,'x0': x0,'x_dot0': x_dot0}
kwargs_adoptODE = {'epochs': 10,'N_backups': 1,'lr': 1e-3,
                'lower_b_y0':{'u':u0,'v':v0,'T':T0,'x':x0,'x_dot':x_dot0},
                    'upper_b_y0':{'u':u0,'v':v0,'T':T0,'x':x0,'x_dot':x_dot0}}
# Setting up a dataset via simulation
Simulation_MSD = simple_simulation(define_MSD_sim,
                            t_evals,
                            kwargs_sys,
                            kwargs_adoptODE)

In [None]:
u0,v0,T0,x0,x_dot0,t_evals = get_ini_fit(dataset_MSD.ys_sol['u'][0],dataset_MSD.ys_sol['v'][0], dataset_MSD.ys_sol['T'][0], dataset_MSD.ys['x'][0],
                                        dataset_MSD.ys['x_dot'][0], sim_indx,length=length,sampling_rate=sampling_rate)
targets = {'u':u0,'v':v0,'T':T0,'x':x0,'x_dot':x_dot0}
# {'u':u0.reshape(1,length,100,100),'v':v0.reshape(1,length,100,100),'T':T0.reshape(1,length,100,100),'x':x_sim.reshape(1,length,2,size+2*pad+1,size+2*pad+1),'x_dot':x_dot_sim.reshape(1,length,2,size+2*pad+1,size+2*pad+1)}
kwargs_sys = {'size': 100,
        'N_sys': 1,
        'par_tol': tol,
        'params_true': params_true,
        'u0': u0,'v0': v0,'T0': T0,'x0': x0,'x_dot0': x_dot0}
kwargs_adoptODE = {'epochs': 50,'N_backups': 2,'lr': 1e-3,'lower_b': params_low,'upper_b': params_high,
                'lr_y0':2e-3,
                'lower_b_y0':{'u':0.,'v':0.,'T':0.,'x':x0,'x_dot':x_dot0},
                'upper_b_y0':{'u':np.max(u),'v':np.max(v),'T':np.max(T),'x':x0,'x_dot':x_dot0}}
dataset_MSD = dataset_adoptODE(define_MSD_fit,
                            targets,
                            t_evals, 
                            kwargs_sys,
                            kwargs_adoptODE,
                            true_params=params_true)
_ = train_adoptODE(dataset_MSD)
print(dataset_MSD.params_train)

sol_list.append(dataset_MSD.params_train)
