In [13]:
#import jax and other libraries for computation
import jax.numpy as jnp
from jax import jit
from jax.scipy.signal import convolve2d
from jax.flatten_util import ravel_pytree
from jax.experimental.ode import odeint
from jax import tree_util
import jax.random as random
import numpy as np
#for visulization
import os
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
# Set Palatino as the default font
font = {'family': 'serif', 'serif': ['Palatino'], 'size': 20}
plt.rc('font', **font)
plt.rc('text', usetex=True)
# import AdoptODE
from adoptODE import train_adoptODE, simple_simulation, dataset_adoptODE
#import the MSD mechanics
from HelperAndMechanics import *
import h5py

In [33]:

def list_h5_structure(file_path):
    '''Recursively lists all groups, datasets, and attributes in an HDF5 file'''
    
    def print_structure(name, obj):
        indent = '  ' * name.count('/')
        if isinstance(obj, h5py.Group):
            print(f"{indent}📂 Group: {name}")
            # for key, value in obj.attrs.items():
            #     print(f"{indent}  └── 🏷️  Attribute: {key} = {value}")
        # elif isinstance(obj, h5py.Dataset):
            # print(f"{indent}📄 Dataset: {name} - shape: {obj.shape}, dtype: {obj.dtype}")
    
    with h5py.File(file_path, 'r') as f:
        print(f"📁 HDF5 File: {file_path}")
        f.visititems(print_structure)

def define_MSD_BOCF(**kwargs_sys):
  disc_x, disc_y = kwargs_sys['disc_x'], kwargs_sys['disc_y']
  dx, dy = kwargs_sys['len_x'] / disc_x, kwargs_sys['len_y'] / disc_y
  N_sys = kwargs_sys['N_sys']

  def d_dx(f):
    return jnp.concatenate((jnp.zeros(
        (1, disc_x)), f[1:] - f[:-1], jnp.zeros((1, disc_x))),
                           axis=0)

  def d_dy(f):
    return jnp.concatenate((jnp.zeros(
        (disc_y, 1)), f[:, 1:] - f[:, :-1], jnp.zeros((disc_y, 1))),
                           axis=1)

  def gradient(f):  #gradient of scalar field
    gx = d_dx(f)
    gy = d_dy(f)
    return jnp.stack((gx, gy), axis=2)

  # stencil5 = np.array([1/4,1/2,-1.5,1/2,1/4])
  # distr5 = np.array([0.05,0.1,0.7,0.1,0.05])
  # kernel = np.outer(distr5, stencil5/dz**2) + np.outer(stencil5/dx**2, distr5)
  kernel = np.array([[1, 4, 1], [4, -20.0, 4], [1, 4, 1]]) / (
      dx * dy * 6)  #np.array([[0,1,0],[1,-4,1.],[0,1,0]])/(dx*dy)

  def laplace(f):  #laplace of scalar
    f_ext = jnp.concatenate((f[0:1], f, f[-1:]), axis=0)
    f_ext = jnp.concatenate((f_ext[:, 0:1], f_ext, f_ext[:, -1:]), axis=1)
    return convolve2d(f_ext, kernel, mode='valid')

  H = lambda x: jnp.heaviside(x, 0)

  missing_params = {}
  for key in kwargs_sys.keys():
    if not key in [
        'disc_x', 'disc_y', 'len_x', 'len_y', 'N_sys', 'Params_BOCF',
        'puls_amp', 'puls_size', 'puls_reps', 'puls_dist', 'puls_num'
    ] + kwargs_sys['Params_BOCF']:
      missing_params[key] = kwargs_sys[key]

  def tau_v_minus(u, ap):
    return (1 - H(u - ap['theta_v_minus'])) * ap['tau_v_minus1'] + H(
        u - ap['theta_v_minus']) * ap['tau_v_minus2']

  def tau_w_minus(u, ap):
    return ap['tau_w_minus1'] + (ap['tau_w_minus2'] - ap['tau_w_minus1']) * (
        1 + jnp.tanh(ap['k_w_minus'] * (u - ap['u_w_minus']))) / 2

  def tau_so(u, ap):
    return ap['tau_so1'] + (ap['tau_so2'] - ap['tau_so1']) * (
        1 + jnp.tanh(ap['k_so'] * (u - ap['u_so']))) / 2

  def tau_s(u, ap):
    return (1 - H(u - ap['theta_w'])
            ) * ap['tau_s1'] + H(u - ap['theta_w']) * ap['tau_s2']

  def tau_o(u, ap):
    return (1 - H(u - ap['theta_o'])
            ) * ap['tau_o1'] + H(u - ap['theta_o']) * ap['tau_o2']

  def v_inf(u, ap):
    return (1 - H(u - ap['theta_v_minus']))

  def w_inf(u, ap):
    return (1 - H(u - ap['theta_o'])) * (
        1 - u / ap['tau_w_inf']) + H(u - ap['theta_o']) * ap['w_inf_star']

  # D = kwargs_sys['D']
  u0, v0, w0, s0 = kwargs_sys['u0'], kwargs_sys['v0'], kwargs_sys[
      'w0'], kwargs_sys['s0']
  puls_amp, puls_size, puls_num, puls_dist, puls_reps = kwargs_sys[
      'puls_amp'], kwargs_sys['puls_size'], kwargs_sys['puls_num'], kwargs_sys[
          'puls_dist'], kwargs_sys['puls_reps']

  def J_fi(y, ap):
    return -(y['v'] * H(y['u'] - ap['theta_v']) * (y['u'] - ap['theta_v']) *
             (ap['u_u'] - y['u'])) / ap['tau_fi']

  def J_so(y, ap):
    return (y['u'] - ap['u_o']) * (1 - H(y['u'] - ap['theta_w'])) / tau_o(
        y['u'], ap) + H(y['u'] - ap['theta_w']) / tau_so(y['u'], ap)

  def J_si(y, ap):
    return -H(y['u'] - ap['theta_w']) * y['w'] * y['s'] / ap['tau_si']

  def du_dt(y, ap):
    return ap['D'] * laplace(y['u']) - (J_fi(y, ap) + J_so(y, ap) +
                                        J_si(y, ap))

  def dv_dt(y, ap):
    return (1 - H(y['u'] - ap['theta_v'])) * (
        v_inf(y['u'], ap) - y['v']) / tau_v_minus(
            y['u'], ap) - H(y['u'] - ap['theta_v']) * y['v'] / ap['tau_v_plus']

  def dw_dt(y, ap):
    return (1 - H(y['u'] - ap['theta_w'])) * (
        w_inf(y['u'], ap) - y['w']) / tau_w_minus(
            y['u'], ap) - H(y['u'] - ap['theta_w']) * y['w'] / ap['tau_w_plus']

  def ds_dt(y, ap):
    return (
        (1 + jnp.tanh(ap['k_s'] *
                      (y['u'] - ap['u_s']))) / 2 - y['s']) / tau_s(y['u'], ap)
                      
  def rescale_params(params):
    params_scaled = {}
    
    for key in params.keys():
        if key in kwargs_sys['Params_BOCF']:
            if kwargs_sys[key]<1:
                params_scaled[key] = params[key]
            else:
                params_scaled[key] = 10**params[key]
        
    return params_scaled
  #for MSD	   
  @jit
  def epsilon_T(u):
    return 1 - 0.9*jnp.exp(-jnp.exp(-30*(jnp.abs(u) - 0.1)))
  @jit
  def eom(y, t, params, iparams, exparams):
    all_params = {**rescale_params(params), **missing_params}
    dudt = du_dt(y, all_params)
    dvdt = dv_dt(y, all_params)
    dwdt = dw_dt(y, all_params)
    dsdt = ds_dt(y, all_params)
    dTdt = 1/12.9*epsilon_T(y['u'])*(params['k_T']*jnp.abs(y['u'])-y['T'])
    dx_dotdt = 1/12.9*1/params['m'] *  (force_field_active(y['x'],y['T'],params) + force_field_passive(y['x'],params) + force_field_struct(y['x'],y['T'],params) - y['x_dot'] * params['c_damp'])
    dxdt = 1/12.9*y['x_dot']
    return {'u': dudt, 'v': dvdt, 'w': dwdt, 's': dsdt,'T':dTdt,'x':dxdt,'x_dot':dx_dotdt}

  @jit
  def loss(ys, params, iparams, exparams, targets):
    flat_fit = ravel_pytree(ys)[0]
    flat_target = ravel_pytree(targets)[0]
    return jnp.nanmean((flat_fit - flat_target)**2)

  def gen_params():
    params = {}
    for key in kwargs_sys['Params_BOCF']:
      if kwargs_sys[key]<1:
        params[key] = kwargs_sys[key] * (0.5 + np.random.rand())
      else:
        params[key] = np.log10(kwargs_sys[key] * (0.5 + np.random.rand()))
    for key in kwargs_sys['Params_MSD']:
      params[key] = kwargs_sys[key] 
    return params, {}, {}

  def gen_y0():
    u = u0 * jnp.ones((disc_x, disc_y))
    v = v0 * jnp.ones((disc_x, disc_y))
    w = w0 * jnp.ones((disc_x, disc_y))
    s = s0 * jnp.ones((disc_x, disc_y))
    #initialize the mechanical part
    size_mech = disc_x + 2* kwargs_sys['pad'] + 1
    x_vals = np.linspace(0, size_mech-1,size_mech)
    z_vals = np.linspace(0, size_mech-1,size_mech)
    # Generate meshgrid for x and z
    x_grid, z_grid = np.meshgrid(x_vals, z_vals)
    xy_grid = jnp.array([x_grid, z_grid])
    
    y = {'u': u, 'v': v, 'w': w, 's': s,'T':jnp.zeros((disc_x, disc_y)),'x':xy_grid,'x_dot':jnp.zeros(xy_grid.shape)}
    params_true = {}
    for key in kwargs_sys['Params_BOCF']:
      if kwargs_sys[key]<1:
        params_true[key] = kwargs_sys[key]
      else:
        params_true[key] = np.log10(kwargs_sys[key])
    for key in kwargs_sys['Params_MSD']:
      params_true[key] = kwargs_sys[key] 

    solver = jit(lambda y: odeint(eom,
                                  y,
                                  np.array([0.0, puls_dist]),
                                  params_true, {}, {},
                                  atol=1e-4,
                                  rtol=1e-4))
    for i in range(puls_reps):
      mask = np.zeros((disc_x, disc_y))
      pos = np.round(
          np.random.rand(puls_num, 2) *
          np.array([disc_x - puls_size - 1, disc_y - puls_size - 1
                    ])[np.newaxis]).astype(int)
      for p in pos:
        mask[p[0]:p[0] + puls_size,
             p[1]:p[1] + puls_size] = puls_amp * np.ones(
                 (puls_size, puls_size))
      y['u'] = y['u'] + mask
      sol = solver(y)
      y = tree_util.tree_map(lambda x: x[-1], sol)
    return y

  return eom, loss, gen_params, gen_y0, {}

In [15]:
# Defining standard parameters:
# This is a smaller system with adjusted parameters to allow spiral wave to occure in the smaller simulation domain.
# For the full code with larger simulation domain is at the end of this notebook.

kwargs_sys = {
    'disc_x': 128,#512
    'disc_y': 128,#512
    'len_x': 128,
    'len_y': 128,
    'N_sys': 1,
    'u_o': 0,
    'u_u': 1.58,
    'theta_v': 0.3,
    'theta_w': 0.015,
    'theta_v_minus': 0.015,
    'theta_o': 0.006,
    'tau_v_minus1': 60,
    'tau_v_minus2': 1150,
    'tau_v_plus': 1.4506,
    'tau_w_minus1': 70,
    'tau_w_minus2': 20,
    'k_w_minus': 65,
    'u_w_minus': 0.03,
    'tau_w_plus': 280,
    'tau_fi': 0.11,
    'tau_o1': 6,
    'tau_o2': 6,
    'tau_so1': 43,
    'tau_so2': 0.2,
    'k_so': 2,
    'u_so': 0.65,
    'tau_s1': 2.7342,
    'tau_s2': 3,
    'k_s': 2.0994,
    'u_s': 0.9087,
    'tau_si': 2.8723,
    'tau_w_inf': 0.07,
    'w_inf_star': 0.94,
    'u0': 0,
    'v0': 1,
    'w0': 1,
    's0': 0,
    'D': 0.1,
    'Params_BOCF': ['u_s', 'u_so', 'k_so', 'tau_so1', 'tau_si', 'k_s', 'u_u', 'theta_v', 'D', 'tau_fi'], # These are the Parameters forgotten and recovered by training!
    'Params_MSD': ['k_T','k_ij','k_ij_pad','k_j','k_a','k_a_pad','c_a','m','c_damp','n_0','l_0','spacing'],
    'pad':10,
    'puls_amp': 1.0,   #
    'puls_size': 3,    #
    'puls_reps': 30,   # These numbers define the initialization procedure
    'puls_dist': 150,  #
    'puls_num': 20     #
}
keys_MSD =['k_T','k_ij','k_ij_pad','k_j','k_a','k_a_pad','c_a','m','c_damp','n_0','l_0','spacing']
N,size,params_true_MSD = read_config(['k_T','k_ij','k_ij_pad','k_j','k_a','k_a_pad','c_a','m','c_damp','n_0','l_0','spacing'],mode = 'chaos')
params_true_MSD = dict(zip(keys_MSD,params_true_MSD))

for key in params_true_MSD.keys():
    kwargs_sys[key] = params_true_MSD[key]

In [19]:
# Following lines set up parameters and boundaries according to the parameters selected for training with 'Params_BOCF'
params_true = {}
lower_bound = {}
for key in kwargs_sys['Params_BOCF']:
  if kwargs_sys[key]<1:
    params_true[key] = kwargs_sys[key]
    lower_bound[key] = 0
  else:
    lower_bound[key] = -2
    params_true[key] = np.log10(kwargs_sys[key])
for key in kwargs_sys['Params_MSD']:
  params_true[key] = params_true_MSD[key]
  lower_bound[key] = 0

print(params_true)
# Setting up simulation domain
N_times = 100
t_evals = jnp.linspace(0, 50, N_times)

# Specifying training properties
reset_every = 34
t_reset_idcs = tuple([
    reset_every * i
    for i in range(int(np.ceil((len(t_evals) - 1) / reset_every)))
])
kwargs_adoptODE = {
    'epochs': 100,
    'lr': 1e-2,
    't_reset_idcs': t_reset_idcs,
    'N_backups': 1,
    'lower_b': lower_bound,
    'atol': 1e-5,
    'rtol': 1e-5,
    'lr_decay':0.99
}

{'u_s': 0.9087, 'u_so': 0.65, 'k_so': np.float64(0.3010299956639812), 'tau_so1': np.float64(1.6334684555795866), 'tau_si': np.float64(0.4582297982235676), 'k_s': np.float64(0.3220951928665501), 'u_u': np.float64(0.19865708695442263), 'theta_v': 0.3, 'D': 0.1, 'tau_fi': 0.11, 'k_T': 3.0, 'k_ij': 13.0, 'k_ij_pad': 23.0, 'k_j': 2.0, 'k_a': 9.0, 'k_a_pad': 23.0, 'c_a': 10.0, 'm': 1.0, 'c_damp': 15.0, 'n_0': 0.5, 'l_0': 1.0, 'spacing': 1.0}


In [20]:
# Setting up a dataset via simulation
dataset_BOCF = simple_simulation(define_MSD_BOCF,
                            t_evals,
                            kwargs_sys,
                            kwargs_adoptODE,
                            params=params_true)

In [None]:
pad = 10
dA = compute_dA(dataset_BOCF.ys['x'][0],1)
# Function to update the plot
def update_plot(frame):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))  # Create 2 side-by-side subplots
    
    # Plot dA_fit
    im1 = axes[0].matshow(dataset_BOCF.ys['u'][0,frame], cmap='coolwarm')
    # im1 = axes[0].matshow(dA_rec[frame, pad:-pad, pad:-pad], cmap='coolwarm', vmin=vmin, vmax=vmax)
    axes[0].set_title(f"reconstruction")
    
    # Plot dA_sim
    im2 = axes[1].matshow(dA[frame,pad:-pad,pad:-pad], cmap='coolwarm')
    # im2 = axes[1].matshow(dA_sim[frame, pad:-pad, pad:-pad], cmap='coolwarm', vmin=vmin, vmax=vmax)
    axes[1].set_title(f"simulation")
    
    # Add colorbars
    # fig.colorbar(im1, ax=axes[0])
    # fig.colorbar(im2, ax=axes[1])

    plt.show()

# Create interactive slider
frame_slider = widgets.IntSlider(min=0, max=dataset_BOCF.ys['u'][0,:].shape[0]-1, step=1, value=0, description="Frame")

# Use interactive_output instead of interactive
out = widgets.interactive_output(update_plot, {'frame': frame_slider})

# Display slider and output
display(frame_slider, out)

IntSlider(value=0, description='Frame', max=99)

Output()

In [22]:
def define_AP(**kwargs_sys):
  disc_x, disc_y = kwargs_sys['disc_x'], kwargs_sys['disc_y']
  dx, dy = kwargs_sys['len_x'] / disc_x, kwargs_sys['len_y'] / disc_y
  N_sys = kwargs_sys['N_sys']

  def d_dx(f):
    return jnp.concatenate((jnp.zeros(
        (1, disc_x)), f[1:] - f[:-1], jnp.zeros((1, disc_x))),
                           axis=0)

  def d_dy(f):
    return jnp.concatenate((jnp.zeros(
        (disc_y, 1)), f[:, 1:] - f[:, :-1], jnp.zeros((disc_y, 1))),
                           axis=1)

  kernel = np.array([[1, 4, 1], [4, -20.0, 4], [1, 4, 1]]) / (
      dx * dy * 6)  #np.array([[0,1,0],[1,-4,1.],[0,1,0]])/(dx*dy)

  def laplace(f):  #laplace of scalar
    f_ext = jnp.concatenate((f[0:1], f, f[-1:]), axis=0)
    f_ext = jnp.concatenate((f_ext[:, 0:1], f_ext, f_ext[:, -1:]), axis=1)
    return convolve2d(f_ext, kernel, mode='valid')

  def epsilon(u,v,rp):
    return rp['eps0']+rp['mu1']*v/(u+rp['mu2'])

  def eom(y, t, params, iparams, exparams):
        p=params
        u=y['u']
        v=y['v']
        dudt = p['D']*laplace(u)-(10.0**p['logk'])*u*(u-p['a'])*(u-1) - u*v
        dvdt = epsilon(u,v,p)*(-v-(10.0**p['logk'])*u*(u-p['a']-1))
        return {'u':dudt, 'v':dvdt}
    
  def gen_params():
    return {'D':1.5,'a':0.06,'logk':1.0,'eps0':0.001,'mu1':0.2,'mu2':0.3}, {}, {}

  def loss(ys, params, iparams, exparams, targets):
    flat_fit = ys['u']
    flat_target = targets['u']
    return jnp.nanmean((flat_fit - flat_target)**2)  
        
  return eom, loss, gen_params, None, {}

In [23]:
# Use previous data to generate truth for AP model
ys_BOCF = dataset_BOCF.ys
y0 = {'u':ys_BOCF['u'][:,0], 'v':ys_BOCF['v'][:,0]}
target_ys = {'u':ys_BOCF['u'], 'v':ys_BOCF['v']}
t_evals = dataset_BOCF.t_evals/12.9 # The implementations have different timescales, BOCF is in ms while AP in a typically non-dimensionalized unit.

In [24]:
params_guess = {'D':1.17,'a':0.06,'logk':1.0,'eps0':0.001,'mu1':0.2,'mu2':0.3}
kwargs_sys = {'disc_x': 128,
            'disc_y': 128,
            'len_x': 128,
            'len_y': 128,
            'N_sys': 1}
kwargs_adoptODE = {
    'epochs': 200,
    'lr': 5e-3,
    'lr_y0':5e-3,
    'N_backups': 2,
    'atol': 1e-5,
    'rtol': 1e-5,
    'lower_b':tree_util.tree_map(lambda x: 0*x+1e-4, params_guess),
    'lower_b_y0':{'u':y0['u'], 'v':0},
    'upper_b_y0':{'u':y0['u'], 'v':10}
} # lower bound
target_ys = {'u':ys_BOCF['u'], 'v':ys_BOCF['v']}
dataset_AP = dataset_adoptODE(define_AP, target_ys, t_evals, kwargs_sys, kwargs_adoptODE, params_train = params_guess)

In [25]:
params_final, losses, errors, params_history = train_adoptODE(dataset_AP, print_interval=10, save_interval=10)

Epoch 000:  Loss: 1.5e-01,  Params Err.: nan, y0 error: nan, Params Norm: 1.6e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 010:  Loss: 1.4e-01,  Params Err.: nan, y0 error: nan, Params Norm: 1.6e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 020:  Loss: 1.1e-01,  Params Err.: nan, y0 error: nan, Params Norm: 1.6e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 030:  Loss: 8.3e-02,  Params Err.: nan, y0 error: nan, Params Norm: 1.6e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 040:  Loss: 7.4e-02,  Params Err.: nan, y0 error: nan, Params Norm: 1.6e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 050:  Loss: 6.5e-02,  Params Err.: nan, y0 error: nan, Params Norm: 1.7e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 060:  Loss: 5.7e-02,  Params Err.: nan, y0 error: nan, Params Norm: 1.7e+00, iParams Err.: 0.0e+00, iParams Norm: 0.0e+00, 
Epoch 070:  Loss: 5.2e-02,  Params Err.: nan, y0 error: nan, Params Norm: 1.7e+00, iParams

In [31]:
pad = 10
dA = compute_dA(dataset_BOCF.ys['x'][0],1)
# Function to update the plot
def update_plot(frame):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))  # Create 2 side-by-side subplots
    
    # Plot dA_fit
    im1 = axes[0].matshow(dataset_AP.ys['u'][0,frame], cmap='coolwarm')
    # im1 = axes[0].matshow(dA_rec[frame, pad:-pad, pad:-pad], cmap='coolwarm', vmin=vmin, vmax=vmax)
    axes[0].set_title(f"simulation")
    
    # Plot dA_sim
    im2 = axes[1].matshow(dataset_AP.ys_sol['u'][0,frame], cmap='coolwarm')
    # im2 = axes[1].matshow(dA_sim[frame, pad:-pad, pad:-pad], cmap='coolwarm', vmin=vmin, vmax=vmax)
    axes[1].set_title(f"reconstruction")
    
    # Add colorbars
    # fig.colorbar(im1, ax=axes[0])
    # fig.colorbar(im2, ax=axes[1])

    plt.show()

# Create interactive slider
frame_slider = widgets.IntSlider(min=0, max=dataset_BOCF.ys['u'][0,:].shape[0]-1, step=1, value=0, description="Frame")

# Use interactive_output instead of interactive
out = widgets.interactive_output(update_plot, {'frame': frame_slider})

# Display slider and output
display(frame_slider, out)

IntSlider(value=0, description='Frame', max=99)

Output()

In [27]:

# file_path = (f"../data/SpringMassModel/MechanicalData/MSD_BOCF")

# with h5py.File(file_path, 'w') as f:
#     group = f.create_group('datasetBOCF_MSD_AP_fit')  # Create a group instead of a dataset
#     group.create_dataset('u_sol_AP', data=dataset_AP.ys_sol['u'])
#     group.create_dataset('u_BOCF',data=dataset_AP.ys['u'])
#     group.create_dataset('v_sol', data=dataset_AP.ys_sol['v'])
#     group.create_dataset('v_BOCF',data=dataset_AP.ys['v'])
#     group.create_dataset('u_sol', data=dataset_AP.ys_sol['u'])
#     group.create_dataset('u',data=dataset_AP.ys['u'])
#     group.create_dataset('T', data=dataset_BOCF.ys['T'])
#     group.create_dataset('x',data=dataset_BOCF.ys['x'])
#     group.create_dataset('x_dot',data=dataset_BOCF.ys['x_dot'])
#     params = group.create_group("params_train_AP")  # Create a subgroup
#     for key, value in dataset_AP.params_train.items():
#         params.attrs[key] = value  # Store values as attributes
# f.close()

## Recreate with Aliev Pafielov model

In [28]:
def define_MSD_rec(**kwargs_sys):
    N_sys = kwargs_sys['N_sys']

    def gen_params():
        iparams = {'testpar':0}
        params = {key:value + kwargs_sys['par_tol']*value*np.random.uniform(-1.0, 1.0) for key,value in kwargs_sys['params_true'].items()}
        iparams = {key:jnp.array([value + kwargs_sys['par_tol']*value*np.random.uniform(-1.0, 1.0) for _ in range(kwargs_sys['N_sys'])]) for key,value in kwargs_sys['params_true'].items()}
        # iparams = {key:np.full(kwargs_sys['N_sys'],[value + kwargs_sys['par_tol']*value*np.random.uniform(-1.0, 1.0)]) for key,value in kwargs_sys['params_true'].items()}
        return  params,{}, {}
    
    def gen_y0():
        return {'u':kwargs_sys['u0'],'v':kwargs_sys['v0'],'T':kwargs_sys['T0'],'x':kwargs_sys['x0'],'x_dot':kwargs_sys['x_dot0']}
    @jit
    def kernel(spacing):
        kernel = np.array([[1, 4, 1], [4, -20.0, 4], [1, 4, 1]]) / (spacing* spacing * 6)
        return kernel
    @jit
    def laplace(f,params):  #laplace of scalar
        f_ext = jnp.concatenate((f[0:1], f, f[-1:]), axis=0)
        f_ext = jnp.concatenate((f_ext[:, 0:1], f_ext, f_ext[:, -1:]), axis=1)
        return convolve2d(f_ext, kernel(params['spacing']), mode='valid')
    @jit
    def epsilon(u,v,rp):
        return rp['epsilon_0']+rp['mu_1']*v/(u+rp['mu_2'])
    @jit
    def epsilon_T(u):
        return 1 - 0.9*jnp.exp(-jnp.exp(-30*(jnp.abs(u) - 0.1)))
    
    @jit
    def eom(y, t, params, iparams, exparams):
            par=params
            u=y['u']
            v=y['v']
            T=y['T']
            x=y['x']
            x_dot=y['x_dot']

            dudt = par['D']*laplace(u,par)-(par['k'])*u*(u-par['a'])*(u-1) - u*v
            dvdt = epsilon(u,v,par)*(-v-(par['k'])*u*(u-par['a']-1))
            dTdt = epsilon_T(u)*(par['k_T']*jnp.abs(u)-T)
            dx_dotdt = 1/par['m'] *  (force_field_active(x,T,par) + force_field_passive(x,par) + force_field_struct(x,T,par) - x_dot * par['c_damp'])
            dxdt = x_dot

            return {'u':dudt, 'v':dvdt, 'T':dTdt, 'x':zero_out_edges(dxdt), 'x_dot':zero_out_edges(dx_dotdt)}
    
    @jit
    def loss(ys, params, iparams, exparams, targets):
        # u = ys['u']
        # u_target = targets['u']
        pad = 10
        x = ys['x'][:,:,pad:-pad,pad:-pad]
        x_target = targets['x'][:,:,pad:-pad,pad:-pad]
        x_dot = ys['x_dot'][:,:,pad:-pad,pad:-pad]
        x_dot_target = targets['x_dot'][:,:,pad:-pad,pad:-pad]
        u_target = targets['u']
        u = ys['u']
        return  jnp.nanmean((x - x_target)**2 + (x_dot-x_dot_target)**2)#+ jnp.nanmean((u - u_target)**2)
            
    return eom, loss, gen_params, None, {}


In [36]:
import h5py
import numpy as np

def load_run(file_path, run):
    '''Reads the data and parameters from a saved HDF5 file'''
    
    data = {}  # Dictionary to store datasets
    params_dict = {}  # Dictionary to store parameters
    
    with h5py.File(file_path, 'r') as f:
        if run not in f:
            raise ValueError(f"Run '{run}' not found in file.")
        
        group = f[run]  # Access the group corresponding to the given 'run'
        
        # Load all datasets (convert to NumPy arrays)
        for key in ['u_sol', 'u', 'v_sol', 'v', 'T_sol', 'T', 'x_sol', 'x']:
            if key in group:
                data[key] = np.array(group[key])  # Convert dataset to NumPy array
        
        # Load parameter attributes from 'params_train' subgroup
        if "params_train" in group:
            params_group = group["params_train"]
            for key in params_group.attrs:
                params_dict[key] = params_group.attrs[key]  # Store as dictionary
    
    return data, params_dict

In [52]:
file_path = '../data/SpringMassModel/MechanicalData/MSD_BOCF'
list_h5_structure(file_path)
with h5py.File(file_path, 'r') as f:
    data = f['datasetBOCF_MSD_AP_fit']  # Access the group corresponding to the given 'run'
    print(data.keys())
# Load the data and parameters from the HDF5 file



# Accessing the datasets
    params_AP = {}
    u_AP = data['u_sol_AP'][0]
    u_sim = data['u_BOCF'][0]
    v_AP = data['v_sol'][0]
    v_sim = data['v_BOCF'][0]
    T_sim = data['T'][0]
    x_sim = data['x'][0]
    x_dot_sim = data['x_dot'][0]
    params_group = data['params_train_AP']
    for key in params_group.attrs:
        params_AP[key] = params_group.attrs[key]  # Store as dictionary
targets = {}
print(params_AP)

📁 HDF5 File: ../data/SpringMassModel/MechanicalData/MSD_BOCF
📂 Group: datasetBOCF_MSD_AP_fit
  📂 Group: datasetBOCF_MSD_AP_fit/params_train_AP
<KeysViewHDF5 ['T', 'params_train_AP', 'u', 'u_BOCF', 'u_sol', 'u_sol_AP', 'v_BOCF', 'v_sol', 'x', 'x_dot']>
{'D': np.float32(2.1702752), 'a': np.float32(0.0663672), 'eps0': np.float32(1e-04), 'logk': np.float32(1.0454719), 'mu1': np.float32(0.9758727), 'mu2': np.float32(0.1895892)}


In [None]:
kwargs_sys = {'size': 100,
                'N_sys':1,
                'par_tol': .5,
                'params_true': params_true}
kwargs_adoptODE = {'epochs': 10,'N_backups': 1,'lr':2e-3,
                'lower_b': params_low,'upper_b': params_high,
                'lr_y0':2e-3,
                # 'lr_ip': kwargs_training['lr_ip'],
                'lower_b_y0':{'u':0,'v':0,'T':0,'x':x0-x0*.1,'x_dot':x_dot0+x_dot0*.1},
                'upper_b_y0':{'u':1,'v':10,'T':['T_high'],'x':x0+x0*.1,'x_dot':x_dot0+x_dot0*.1}}

dataset_MSD = dataset_adoptODE(define_MSD_rec,
                                targets,
                                t_evals, 
                                kwargs_sys,
                                kwargs_adoptODE,
                                true_params=params_true)#,
                                #true_iparams=params_true)
return dataset_MSD,Simulation_MSD

NameError: name 'kwargs_training' is not defined