In [None]:
import matplotlib.pyplot as plt
import numpy as np
import jax 
import jax.numpy as jnp
import networkx as nx
from types import SimpleNamespace

from inference.GA_inference_2 import infer_dynamics 
from visualizer.temporal_graph_matplotlib import animate_temporal_graph

from dynamics.system import GetDynamicalSystem
from simulation.ode_solve_engine import ODEEngine
from visualizer.visualize_dynamics import visualize_dynamics

In [None]:
# Reload the project modules so edits on disk take effect without restarting the kernel, then re-bind the imported names
import importlib 
import inference.GA_inference_2
import visualizer.temporal_graph_matplotlib
import dynamics.system
import simulation.ode_solve_engine
import visualizer.visualize_dynamics

# Reload each project module in dependency-agnostic but fixed order
# (same order as the imports above).
for _module in (
    inference.GA_inference_2,
    visualizer.temporal_graph_matplotlib,
    dynamics.system,
    simulation.ode_solve_engine,
    visualizer.visualize_dynamics,
):
    importlib.reload(_module)
del _module

from inference.GA_inference_2 import infer_dynamics 
from visualizer.temporal_graph_matplotlib import animate_temporal_graph
from dynamics.system import GetDynamicalSystem
from simulation.ode_solve_engine import ODEEngine
from visualizer.visualize_dynamics import visualize_dynamics

In [None]:
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf

# Compose the Hydra config from the local `configs/` directory.
# `version_base=None` mirrors `hydra.main(version_base=None)` in the script, so
# `cfg` is the same DictConfig the `@hydra.main` entry point would receive.
# `compose` is called inside the `initialize` context manager.
with initialize(version_base=None, config_path="configs"):
    cfg: DictConfig = compose(config_name="config")

In [None]:
# Build the dynamical system from the composed config and integrate it.
# To inspect the config, display `cfg` (or `print(OmegaConf.to_yaml(cfg))`) in
# its own cell; a bare expression in the middle of a cell is never shown.
system = GetDynamicalSystem(cfg.dynamics)
engine = ODEEngine(system, cfg.engine.ode_solver)
ys, ts = engine.run()

In [None]:
# Solver output sampling interval; passed to `time_driver` below as the time-channel derivative.
dt = cfg["engine"]["ode_solver"]["saveat"]["dt"]

In [None]:
# Split the raw solver states into positions and (presumably) orientations.
(pos, ori) = system.unwrap_state(ys)

In [None]:
# Sanity check: positions are expected to be rank 3, presumably (T, N, D).
jnp.shape(pos)

In [None]:
# Sanity check on the orientation array's shape.
jnp.shape(ori)

In [None]:
# Time channel: repeat each sample time across every object so it matches the
# leading axes of `pos` with a single trailing feature, i.e. (T, N, 1) when
# `pos` is (T, N, D).
ts_broadcasted = jnp.broadcast_to(ts[:, None, None], pos.shape[:-1] + (1,))

In [None]:
# Zero-filled placeholder for the volume-element channel, shaped like the time channel.
vol_elt = jnp.full_like(ts_broadcasted, 0)

In [None]:
# Feature layout along the last axis: [time | positions | orientations | volume element].
swarm_full = jnp.concatenate((ts_broadcasted, pos, ori, vol_elt), axis=-1)
swarm_full.shape

In [None]:
# Grade of each feature column of `swarm_full` (built as time | pos | ori | vol):
#   dim 0     -> grade 0 (time)
#   dims 1-3  -> grade 1 (positions)
#   dims 4-6  -> grade 2 (orientations)
#   dim 7     -> grade 3 (volume element)
# This assumes `pos` and `ori` each have 3 components; the assert guards that
# the grade map and the feature count agree. Also reused for `grav_full` below.
g_of_d = jnp.array([
    0,
    1, 1, 1,
    2, 2, 2,
    3,
])
assert g_of_d.shape[0] == swarm_full.shape[-1], (
    f"g_of_d has {g_of_d.shape[0]} entries but swarm_full has "
    f"{swarm_full.shape[-1]} feature columns"
)

In [None]:
def metamaterial_ext_pulling_force(applied_velocity, pulled_nodes, dt, velocity_dims=slice(1, 3)):
    """Build an external-derivative function that drives time and pulls selected nodes.

    The returned ``vel_fn(D_out, X)`` ignores the values of ``D_out`` and ``X``
    and returns an array shaped like ``D_out`` that is zero everywhere except:

    * feature 0 (the time channel) is set to ``dt`` for every sample and node;
    * for the nodes in ``pulled_nodes``, the feature columns selected by
      ``velocity_dims`` are set to ``applied_velocity``.

    Args:
        applied_velocity: Scalar or array broadcastable to
            ``D_out[:, pulled_nodes, velocity_dims]``.
        pulled_nodes: Indices of the nodes to pull (converted to an int32 array).
        dt: Value written into the time channel.
        velocity_dims: Slice of feature columns that receive the velocity.
            Defaults to ``slice(1, 3)`` (features 1 and 2), the original
            hard-coded range. NOTE(review): with 3-D positions in features 1-3
            this leaves the third position component unforced; confirm intent.

    Returns:
        ``vel_fn(D_out, X)``. ``D_out`` must be rank 3, presumably (T, N, D).
    """
    pulled_nodes = jnp.array(pulled_nodes, dtype=jnp.int32)

    def vel_fn(D_out, X):
        # Start from zeros so only the explicitly forced entries are non-zero.
        D_blank = jnp.zeros_like(D_out)
        D_blank = D_blank.at[:, :, 0].set(dt)
        D_blank = D_blank.at[:, pulled_nodes, velocity_dims].set(applied_velocity)
        return D_blank

    return vel_fn

In [None]:
def time_driver(dt):
    """Build an external-derivative function that only drives the time channel.

    Args:
        dt: Value written into feature 0 (the last-axis index 0) everywhere.

    Returns:
        ``vel_fn(D_out, X)`` returning an array shaped like ``D_out`` that is
        zero everywhere except last-axis index 0, which is ``dt``. The values of
        ``D_out`` and ``X`` are ignored; ``X`` is accepted for interface
        compatibility with the caller.
    """
    def vel_fn(D_out, X):
        # Zero every derivative, then pin the time channel's derivative to dt.
        return jnp.zeros_like(D_out).at[..., 0].set(dt)

    return vel_fn

In [None]:
# Set up inference on the swarm data; the time channel's derivative is supplied
# externally by `time_driver(dt)`.
model = infer_dynamics(
    swarm_full,
    g_of_d=g_of_d,
    derivatives='savgol',
    coupling_mode='gaussian',
    max_poly_degree=2,
    sparsity_alpha=0.1,
    ext_derivative_fxn=time_driver(dt),
)

In [None]:
# Preprocessed data and (presumably) its estimated derivatives, which the fit is compared against below.
(dat, deriv) = model.preprocess_data(swarm_full)


In [None]:
# Train the model; the result is plotted against `deriv` below.
final_pred = model.fit(epochs=1000, lr=1e-4)

In [None]:
# Sanity check on the assembled feature tensor.
jnp.shape(swarm_full)

In [None]:
# Raw solver state over time.
# NOTE(review): `ys` is the raw solver state (before `system.unwrap_state`); the
# 'x'/'y'/'z' labels assume indices 1-3 along axis 1 are spatial coordinates.
# Confirm against `system.unwrap_state`.
plt.figure(figsize=(10, 5))
plt.plot(ts, ys[:, 1], label='x')
plt.plot(ts, ys[:, 2], label='y')
plt.plot(ts, ys[:, 3], label='z')
plt.xlabel('Time')
plt.ylabel('State')
# Without a legend, the labels set above are never displayed.
plt.legend()
plt.show()


In [None]:
# Overlay the target derivatives (solid) and the model's predictions (dashed)
# for the selected object(s), one colour per spatial component.
objs = [1]
_components = (('x', 'b', 1), ('y', 'g', 2), ('z', 'r', 3))

for _name, _colour, _dim in _components:
    plt.plot(ts, deriv[:, objs, _dim], _colour, label=f'{_name}_deriv')
for _name, _colour, _dim in _components:
    plt.plot(ts, final_pred[:, objs, _dim], f'{_colour}--', label=f'{_name}_pred')

plt.xlabel('Time')
plt.legend()
plt.show()
del _components, _name, _colour, _dim



# Try the same inference pipeline on gravitational dynamics



In [None]:
from hydra import initialize, compose
from omegaconf import DictConfig, OmegaConf

# Compose the Hydra config again for the gravitational run.
# `version_base=None` mirrors `hydra.main(version_base=None)` in the script.
# NOTE(review): this composes the same `config` with no overrides as the first
# config cell, so on a fresh kernel (Restart & Run All) it rebuilds the first
# system unless the YAML on disk was edited in between. Consider selecting the
# gravity dynamics explicitly, e.g.
# `compose(config_name="config", overrides=["dynamics=<gravity option>"])`.
# TODO: confirm the group/option names.
with initialize(version_base=None, config_path="configs"):
    cfg: DictConfig = compose(config_name="config")

In [None]:
# Build and integrate the system from the freshly composed config, then split its state.
system = GetDynamicalSystem(cfg.dynamics)
engine = ODEEngine(system, cfg.engine.ode_solver)
(ys, ts) = engine.run()
dt = cfg["engine"]["ode_solver"]["saveat"]["dt"]
(pos, ori) = system.unwrap_state(ys)

In [None]:
# Sanity check on the gravitational positions' shape.
jnp.shape(pos)

In [None]:
# Number of leading time steps of the gravitational trajectory used for fitting.
N_STEPS_GRAV = 100

# Orientations are replaced by zeros (the unwrapped `ori` is discarded) so the
# feature layout matches `g_of_d`: time | pos | ori | vol.
ori = jnp.zeros_like(pos)
ts_broadcasted = jnp.broadcast_to(ts[:, None, None], pos.shape[:-1] + (1,))
vol_elt = jnp.zeros_like(ts_broadcasted)

# Keep only the first N_STEPS_GRAV time steps; built in one expression so
# re-running the cell is idempotent.
grav_full = jnp.concatenate((ts_broadcasted, pos, ori, vol_elt), axis=-1)[:N_STEPS_GRAV]

In [None]:
# Same model configuration as for the swarm, but with `sparsity_alpha=0.0`.
model_grav = infer_dynamics(
    grav_full,
    g_of_d=g_of_d,
    derivatives='savgol',
    coupling_mode='gaussian',
    max_poly_degree=2,
    sparsity_alpha=0.0,
    ext_derivative_fxn=time_driver(dt),
)

(dat, deriv) = model_grav.preprocess_data(grav_full)


In [None]:
# Train the gravitational model; the result is plotted against `deriv` below.
final_pred_grav = model_grav.fit(epochs=10000, lr=1e-3)

In [None]:
# Inspect the fitted gravitational model's `params` attribute.
model_grav.params

In [None]:
# Sanity check on the derivative targets' shape.
jnp.shape(deriv)

In [None]:
# Compare the target derivatives (solid) with the model's predictions (dashed)
# for one object of the gravitational run.
objs = [13]

# x-axis: the time channel (feature 0) of the preprocessed data returned together
# with `deriv`, so it is presumably aligned sample-for-sample with it.
# `deriv[..., 0]` is the *derivative* of the time channel (expected ~constant,
# since samples are spaced by `dt`), not time itself. The solver's `ts` spans
# the full trajectory while only the first steps were fitted. Using a separate
# name keeps the global `ts` intact for earlier cells.
# NOTE(review): assumes `preprocess_data` keeps the time channel in time units.
# TODO: confirm.
t_grav = dat[:, objs[0], 0]

plt.plot(t_grav, deriv[:, objs, 1], 'b', label='x_deriv')
plt.plot(t_grav, deriv[:, objs, 2], 'g', label='y_deriv')
plt.plot(t_grav, deriv[:, objs, 3], 'r', label='z_deriv')

plt.plot(t_grav, final_pred_grav[:, objs, 1], 'b--', label='x_pred')
plt.plot(t_grav, final_pred_grav[:, objs, 2], 'g--', label='y_pred')
plt.plot(t_grav, final_pred_grav[:, objs, 3], 'r--', label='z_pred')
plt.xlabel('Time')
plt.ylabel('Derivative')
plt.legend()
plt.show()



In [None]:
# Print the fitted gravitational model's equation(s).
model_grav.print_equation()