In [1]:
#import jax and other libraries for computation
import jax.numpy as jnp
from jax import jit
from jax.scipy.signal import convolve2d
from jax.flatten_util import ravel_pytree
from jax.experimental.ode import odeint
from jax import tree_util
import jax.random as random
import numpy as np
#for visulization
import os
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
# Set Palatino as the default font
font = {'family': 'serif', 'serif': ['Palatino'], 'size': 20}
plt.rc('font', **font)
plt.rc('text', usetex=True)
# import AdoptODE
from adoptODE import train_adoptODE, simple_simulation, dataset_adoptODE
#import the MSD mechanics
from HelperAndMechanics import *
import h5py

In [None]:
# Load from HDF5
with h5py.File('../data/SpringMassModel/MechanicalData/data_eta_05_uvx.h5', 'r') as f:
    v = f['v'][:1000]
    u = f['u'][:1000]
    T = f['T'][:1000]
    x = f['x'][:1000]
    f.close()
print('max_v = ',np.max(v),'\nmax_u =',np.max(u),' \nmax_T = ',np.max(T))
N = T.shape[0]


max_v =  2.14554 
max_u = 0.988905  
max_T =  2.71647


In [119]:
def define_NavierCauchy(**kwargs_sys):
    N_sys = kwargs_sys['N_sys']

    def gen_params():
        return {key:value + kwargs_sys['par_tol']*value*np.random.uniform(-1.0, 1.0) for key,value in kwargs_sys['params_true'].items()}, {}, {}
    
    def gen_y0():
        return {'u':kwargs_sys['u0'],'v':kwargs_sys['v0'],'T':kwargs_sys['T0'],'U':kwargs_sys['U0'],'V':kwargs_sys['V0']}
    
    # electric model
    @jit
    def kernel(spacing):
        kernel = np.array([[1, 4, 1], [4, -20.0, 4], [1, 4, 1]]) / (spacing* spacing * 6)
        return kernel
    @jit
    def laplace(f,spacing):  #laplace of scalar
        f_ext = jnp.concatenate((f[0:1], f, f[-1:]), axis=0)
        f_ext = jnp.concatenate((f_ext[:, 0:1], f_ext, f_ext[:, -1:]), axis=1)
        return convolve2d(f_ext, kernel(spacing), mode='valid')
    @jit
    def epsilon(u,v,params):
        return params['epsilon_0']+params['mu_1']*v/(u+params['mu_2'])
    @jit
    def epsilon_T(u):
        return 1 - 0.9*jnp.exp(-jnp.exp(-30*(jnp.abs(u) - 0.1)))
    
    # # mechanic model
    # @jit
    # def stress_tensor(U,params):
    #     print('test')
    #     div_U = divergence(U)
    #     stress = params['lmbda'] * jnp.expand_dims(div_U, -1) * jnp.eye(2) + 2 * params['mu'] * laplace(U,params['dx'])
    #     print('test')
    #     return stress
    
    # # Define Laplacian operator
    # @jit
    # def gradient(T, dx):
    #     """Compute the gradient of a scalar field T(x, y) in 2D."""
    #     grad_x = jnp.gradient(T, axis=0) / dx  
    #     grad_y = jnp.gradient(T, axis=1) / dx  
    #     return jnp.stack([grad_x, grad_y], axis=0)  # Shape: (N, N, 2)
    # @jit
    # def divergence(U):
    #     return jnp.gradient(U[0], axis=0) + jnp.gradient(U[1], axis=1)
    
    # @jit
    # def laplacian(U,dx):
    #     return (
    #         -4 * U
    #         + jnp.roll(U, shift=1, axis=0) + jnp.roll(U, shift=-1, axis=0)
    #         + jnp.roll(U, shift=1, axis=1) + jnp.roll(U, shift=-1, axis=1)
    #     ) / dx**2
    # @jit
    # def stress_tensor(U, params):
    #     lmbda, mu, dx = params["lmbda"], params["mu"], params["dx"]
    #     div_U = divergence(U)
    #     lap_U = laplacian(U, dx)
    #     stress = jnp.zeros((2, 2, U.shape[1], U.shape[2]))
    #     stress = stress.at[jnp.diag_indices(2)].set(lmbda * div_U)
    #     stress += 2 * mu * jnp.expand_dims(lap_U, axis=1)
    #     return stress
    

    @jit
    def gradient(T, dx):
        """Compute the gradient of a scalar field T(x, y) in 2D."""
        grad_x = jnp.gradient(T, axis=0) / dx  
        grad_y = jnp.gradient(T, axis=1) / dx  
        return jnp.stack([grad_x, grad_y], axis=-1)  # Shape: (N, N, 2)
    # Define Laplacian operator
    def laplacian(U,dx):
        return (
            -4 * U
            + jnp.roll(U, shift=1, axis=0) + jnp.roll(U, shift=-1, axis=0)
            + jnp.roll(U, shift=1, axis=1) + jnp.roll(U, shift=-1, axis=1)
        ) / dx**2

    # Compute divergence
    def divergence(U,dx):
        return jnp.gradient(U[..., 0], axis=0)/dx + jnp.gradient(U[..., 1], axis=1)/dx

    # Compute stress tensor
    def stress(U, params):

        div_U = divergence(U,params['dx'])
        print(jnp.reshape(div_U, (100,100,1,1))*jnp.reshape(jnp.eye(2),(1,1,2,2)).shape )
        stress_temp = params['lmbda'] * jnp.expand_dims(div_U, -1) * jnp.eye(2) + 2 * params['mu'] * laplacian(U)
        print('teststress')
        return stress_temp
    #Full Model
    @jit
    def eom(y, t, params, iparams, exparams):
        par=params
        u=y['u']
        v=y['v']
        T=y['T']
        U=y['U']
        V=y['V']

        dudt = par['D']*laplace(u,par['spacing'])-(par['k'])*u*(u-par['a'])*(u-1) - u*v
        dvdt = epsilon(u,v,par)*(-v-(par['k'])*u*(u-par['a']-1))
        dTdt = epsilon_T(u)*(par['k_T']*jnp.abs(u)-T)
        dU = V # Update displacement
        print('test')
        dV = (1 / par['rho']) * (stress(U,par) + gradient(T,par['dx']))  # Update velocity
        print('test')
        return {'u':dudt, 'v':dvdt, 'T':dTdt, 'U':zero_out_edges(dU), 'V':zero_out_edges(dV)}
    @jit
    def loss(ys, params, iparams, exparams, targets):
        # u = ys['u']
        # u_target = targets['u']
        U = ys['U']
        U_target = targets['U']
        V = ys['V']
        V_target = targets['V']
        
        return  jnp.nanmean((U - U_target)**2 + (V-V_target)**2) #jnp.nanmean((u - u_target)**2) +
            
    return eom, loss, gen_params, gen_y0, {}


In [120]:
jnp.eye(2)

Array([[1., 0.],
       [0., 1.]], dtype=float32)

In [121]:

"""
    Reads in necessary parameters from config.ini
"""

N,size,params_electric = read_config(['D','a','k','epsilon_0','mu_1','mu_2','k_T','delta_t_e','spacing'],mode = 'chaos')
keys_electric =['D','a','k','epsilon_0','mu_1','mu_2','k_T','delta_t_e','spacing']

params_true_mechanic = {'lmbda': 1.0, 'mu': 1.0, 'rho': 1.0, 'dx': 1.0}

tol = 0
params_true = dict(zip(keys_electric,params_electric))|params_true_mechanic

params_low = {key: value - value*tol for key, value in params_true.items()}
params_high = {key: value + value*tol for key, value in params_true.items()}

length,sampling_rate = 5,20

u0,v0,T0,U0,V0,t_evals = u[0],v[0],T[0],jnp.zeros((100,100,2)),jnp.zeros((100,100,2)),np.linspace(0, params_true['delta_t_e']*sampling_rate*length, length)

kwargs_sys = {'size': 100,
              'spacing': 1,
              'N_sys': 1,
              'par_tol': 0,
              'params_true': params_true,
              'u0': u0,'v0': v0,'T0': T0,'U0': U0,'V0': V0}
kwargs_adoptODE = {'epochs': 10,'N_backups': 1,'lr': 1e-3}

In [122]:
# Setting up a dataset via simulation
Simulation_MSD = simple_simulation(define_NavierCauchy,
                            t_evals,
                            kwargs_sys,
                            kwargs_adoptODE)

test


TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'tuple'