In [None]:
import matplotlib.pyplot as plt
import numpy as np
import jax 
import jax.numpy as jnp
import networkx as nx

In [None]:
from inference.GA_inference_2 import infer_dynamics 
from visualizer.temporal_graph_matplotlib import animate_temporal_graph


In [None]:
# Reload the project modules so edits on disk are picked up without restarting the kernel.
import importlib
import inference.GA_inference_2
import visualizer.temporal_graph_matplotlib
importlib.reload(inference.GA_inference_2)
importlib.reload(visualizer.temporal_graph_matplotlib)
# Names bound by an earlier `from ... import ...` keep pointing at the pre-reload objects,
# so rebind every name used below -- `infer_dynamics` as well as its alias -- after reloading.
from inference.GA_inference_2 import infer_dynamics
from inference.GA_inference_2 import infer_dynamics as GA_dynamics_inference
from visualizer.temporal_graph_matplotlib import animate_temporal_graph


In [None]:
# Load the trajectory arrays. np.load on an .npz returns an NpzFile that reads lazily and keeps
# the archive open; copy every array into a plain dict inside `with` so the file is closed.
with np.load('data/normalized_trajectory_metamaterial.npz') as npz_file:
    metamat_data = {name: npz_file[name] for name in npz_file.files}
print(list(metamat_data))

In [None]:
# Stack per-node x, y and angle (converted degrees -> radians) as channels on the last axis.
metamat_arr = jnp.stack(
    [
        metamat_data['x'],
        metamat_data['y'],
        jnp.deg2rad(metamat_data['angle']),
    ],
    axis=-1,
)

# Synthesize evenly spaced timestamps over [0, t_f], one per frame (axis 0), copy them to every
# node (axis 1) and prepend them as channel 0, so the columns become (t, x, y, angle).
t_f = 10
ts = jnp.linspace(0, t_f, metamat_arr.shape[0])
ts_expanded = jnp.repeat(ts[:, None, None], metamat_arr.shape[1], axis=1)
metamat_arr = jnp.concat([ts_expanded, metamat_arr], axis=-1)
print(metamat_arr.shape)


In [None]:
# Spacing of `ts`: jnp.linspace(0, t_f, T) places T points on [0, t_f], so consecutive frames are
# t_f / (T - 1) apart. (t_f / T underestimates the step and skews the drive and the integration.)
dt = t_f / (metamat_arr.shape[0] - 1)

In [None]:
# x-y trajectories of nodes 8, 5 and 6 -- these look like the top nodes that get pulled (to confirm).
fig, ax = plt.subplots()
for top_node_id in (8, 5, 6):
    ax.plot(
        metamat_arr[:, top_node_id, 1],
        metamat_arr[:, top_node_id, 2],
        label=f'Node {top_node_id}',
    )
ax.set_xlabel('X Position')
ax.set_ylabel('Y Position')
ax.legend()
plt.show()

In [None]:
top_nodes = np.array([8, 5, 6])
# y (channel 2) trajectories of the pulled nodes: one row per frame, one column per node.
top_nodes_y = metamat_arr[:, top_nodes, 2]
# Net y-displacement over the whole record divided by its duration = mean vertical speed per node.
top_nodes_vel = (top_nodes_y[-1] - top_nodes_y[0]) / t_f
print("Top nodes velocities: ", top_nodes_vel)

# Collapse the per-node estimates into one speed, applied purely along y.
pulling_velocity = np.mean(top_nodes_vel)
print("Pulling velocity: ", pulling_velocity)
pulling_velocity_vec = jnp.array([0, pulling_velocity])

In [None]:
def metamaterial_ext_pulling_force(applied_velocity, pulled_nodes, dt):
    """Build an external-derivative function that drives selected nodes at a fixed velocity.

    Args:
        applied_velocity: value written into channels 1:3 (x, y) of every pulled node; in this
            notebook a length-2 vector ``[0, v_y]``.
        pulled_nodes: indices of the driven nodes (converted to an int32 array).
        dt: value written into channel 0 (the time column) for every frame and node.

    Returns:
        ``vel_fn(D_out, X)``, which returns an array shaped like ``D_out`` that is zero except for
        channel 0 (set to ``dt``) and channels 1:3 of the pulled nodes (set to
        ``applied_velocity``). ``X`` is not used; it is presumably required by the
        ``ext_derivative_fxn`` callback signature.
    """
    pulled_nodes = jnp.array(pulled_nodes, dtype=jnp.int32)

    def vel_fn(D_out, X):
        # D_out is indexed as (frame, node, channel) below; only its shape/dtype are used.
        D_blank = jnp.zeros_like(D_out)
        # NOTE(review): channel 0 is time; its derivative is set to dt rather than 1 -- confirm this
        # matches the scaling of the derivatives produced by infer_dynamics.
        D_blank = D_blank.at[:, :, 0].set(dt)
        D_blank = D_blank.at[:, pulled_nodes, 1:3].set(applied_velocity)
        return D_blank

    return vel_fn

In [None]:
# External-derivative hook: drive the top nodes at the estimated constant velocity along y.
meta_vel_fn = metamaterial_ext_pulling_force(
    applied_velocity=pulling_velocity_vec,
    pulled_nodes=top_nodes,
    dt=dt,
)

In [None]:
# x, y and angle of nodes 1-4 against time. Colour identifies the node and line style the
# channel; the original figure drew 12 unlabeled lines that could not be told apart.
nodes = [1, 2, 3, 4]
channel_styles = ((1, 'x', '-'), (2, 'y', '--'), (3, 'angle', ':'))
fig, ax = plt.subplots(figsize=(10, 10))
for color_idx, node in enumerate(nodes):
    for channel, name, linestyle in channel_styles:
        ax.plot(metamat_arr[:, node, 0], metamat_arr[:, node, channel],
                color=f'C{color_idx}', linestyle=linestyle, label=f'node {node} {name}')
ax.set_xlabel('Time')
ax.set_title('Metamaterial data')
ax.legend(ncol=len(nodes))
plt.show()

In [None]:
# Every state channel of node 0 against time (channel 0 is time itself).
plt.figure(figsize=(10, 10))
for channel, label in enumerate(('x', 'y', 'angle'), start=1):
    plt.plot(metamat_arr[:, 0, 0], metamat_arr[:, 0, channel], label=label)
plt.legend()
plt.title('Metamaterial data, node 0')
plt.show()

In [None]:
# Grade assigned to each output dim; the data channels are (t, x, y, angle).
g_of_d = jnp.array([
    0,     # output dim 0 (t)         <- grade 0 (scalar)
    1, 1,  # output dims 1-2 (x, y)   <- grade 1 (vector)
    2,     # output dim 3 (angle)     <- grade 2 (bivector)
])

import networkx as nx
# Lattice connectivity: a ring of 8 border nodes plus 4 edges from the interior node 7.
n_nodes = 9
G = nx.Graph()
G.add_nodes_from(range(n_nodes))
edges_border = [(0,4), (0,2), (2,1), (1,3), (3,5), (5,8), (8,6), (6,4)]
G.add_edges_from(edges_border)
internal_edges = [(7,4), (7,3), (7,8), (7,2)]
G.add_edges_from(internal_edges)
# A fixed seed makes the drawing identical on every run. Labels show the node indices that the
# rest of the notebook refers to (e.g. the pulled nodes 8, 5, 6).
pos = nx.spring_layout(G, seed=0)
nx.draw(G, pos, with_labels=True)

# Rows/columns follow node order 0..n_nodes-1, matching the node axis of metamat_arr.
coupling_matrix = nx.adjacency_matrix(G, nodelist=list(range(n_nodes))).todense()
print(coupling_matrix)
coupling_matrix = jnp.array(coupling_matrix)
assert coupling_matrix.shape == (metamat_arr.shape[1], metamat_arr.shape[1]), \
    "coupling graph size must match the number of nodes in the data"


In [None]:
# Compare the 'savgol' and 'difference' derivative estimators on node 0. Both models share
# every other setting, so only the estimator differs.
shared_kwargs = dict(
    g_of_d=g_of_d,
    coupling=coupling_matrix,
    coupling_mode='fixed',
    max_poly_degree=4,
    sparsity_alpha=0.,
    ext_derivative_fxn=meta_vel_fn,
    learned_individual_terms=True,
)
model_savgol = infer_dynamics(metamat_arr, derivatives='savgol', **shared_kwargs)
# Pass the estimator explicitly instead of relying on the library default, otherwise the
# comparison can silently end up plotting savgol against itself.
model_grad = infer_dynamics(metamat_arr, derivatives='difference', **shared_kwargs)

_, deriv_savgol = model_savgol.preprocess_data(metamat_arr)
_, deriv_grad = model_grad.preprocess_data(metamat_arr)

fig, ax = plt.subplots()
t_axis = metamat_arr[:, 0, 0]
for channel, name, color in ((1, 'x', 'g'), (2, 'y', 'r')):
    ax.plot(t_axis, deriv_savgol[:, 0, channel], color, label=name)               # smoothed (solid)
    ax.plot(t_axis, deriv_grad[:, 0, channel], f'{color}--', label=f'{name} grad')  # raw (dashed)
ax.set_xlabel('Time')
ax.legend()
ax.set_title('Derivatives of node 0')
plt.show()

In [None]:
# Alternative fit through the reloaded entry point, configured with the algebra dimension `Gn`
# instead of an explicit grade map. `final_preds` is not used further down, and the next cell
# rebinds `model`.
ga_settings = dict(
    Gn=3,                            # algebra dimension
    derivatives='savgol',            # derivative method: 'savgol' or 'difference'
    coupling=coupling_matrix,        # NxN matrix, or 'gaussian' / 'spline'
    coupling_mode='fixed',           # 'fixed' or 'learned'
    max_poly_degree=4,               # highest degree of the polynomial terms
    sparsity_alpha=0.01,             # strength of the sparsity regularizer
    ext_derivative_fxn=meta_vel_fn,  # external drive pulling on the top nodes
    learned_individual_terms=True,   # learn the individual terms
)
model = GA_dynamics_inference(metamat_arr, **ga_settings)

_ = model.preprocess_data(metamat_arr)
final_preds = model.fit()


In [None]:
# Model analysed in the rest of the notebook: degree-2 polynomial library with a small sparsity
# penalty, savgol derivatives, the fixed lattice coupling and the external pulling drive.
poly2_settings = dict(
    g_of_d=g_of_d,
    derivatives='savgol',
    coupling=coupling_matrix,
    coupling_mode='fixed',
    max_poly_degree=2,
    sparsity_alpha=0.001,
    ext_derivative_fxn=meta_vel_fn,
    learned_individual_terms=True,
)
model = infer_dynamics(metamat_arr, **poly2_settings)

In [None]:
# Preprocess the data and snapshot `W` as a NumPy copy before training, so it can be compared
# with the trained values later. 'gray' is the canonical colormap name used elsewhere here.
dat, deriv = model.preprocess_data(metamat_arr)
pretraining_W = np.copy(model.params['W'])
plt.matshow(jnp.abs(pretraining_W), cmap='gray')
plt.title('|W| before training');

In [None]:
# Snapshot the individual-term parameters before training and show their magnitudes.
pretraining_individual_terms = np.copy(model.params['individual_terms'])
plt.imshow(jnp.abs(pretraining_individual_terms), cmap='gray')
plt.title('|individual_terms| before training');

In [None]:
# Train the main model; the returned array is plotted below as the derivative fit.
final_pred = model.fit(lr=1e-4, epochs=10000)

In [None]:
# Show the equation of the model trained in the cell above.
model.print_equation()

In [None]:
# Time step used by the external drive (meta_vel_fn) and by the cumulative-sum integration below.
dt

In [None]:
# Channel 0 (time) of the fitted output at frame 0, node 0 -- presumably checked against `dt`,
# which the external drive writes into channel 0.
final_pred[0, 0, 0]

In [None]:
# Container type of the model's fixed coupling (presumably built from `coupling_matrix`; confirm).
type(model.fixed_K)

In [None]:
# Predictions of the trained model on the full trajectory (shape inspected further down).
preds = model.predict(metamat_arr)

In [None]:
# Magnitudes of the trained W; same colormap as the pre-training snapshot for comparison.
plt.matshow(jnp.abs(model.params['W']), cmap='gray')
plt.title('|W| after training');

In [None]:
# Change in the individual-term parameters due to training (pre-training snapshot minus trained).
pretraining_individual_terms - model.params['individual_terms']

In [None]:
# Magnitudes of the trained individual-term parameters.
plt.imshow(jnp.abs(model.params['individual_terms']), cmap='gray')
plt.title('|individual_terms| after training');

In [None]:
# Largest magnitude among the trained individual-term parameters.
jnp.max(jnp.abs(model.params['individual_terms']))

In [None]:
# Largest magnitude among the trained W entries.
jnp.max(jnp.abs(model.params['W']))

In [None]:
# Shape of model.predict's output.
preds.shape

In [None]:
# Shape of model.fit's output; indexed as (frame, node, channel) in the plots below.
final_pred.shape


In [None]:
# Shape of the model's stored derivatives, used as the reference curves in the next cell.
model.derivatives.shape

In [None]:
# Fitted derivatives (solid) vs the model's reference derivatives (dashed), one figure per node.
deriv_gt = model.derivatives
t_axis = metamat_arr[:, 0, 0]
channels = ((1, 'x', 'g'), (2, 'y', 'r'), (3, 'angle', 'b'))

# Iterate over however many nodes the fit returned instead of a hard-coded 9.
for node_num in range(final_pred.shape[1]):
    fig, ax = plt.subplots(figsize=(10, 10))
    for channel, name, color in channels:
        ax.plot(t_axis, final_pred[:, node_num, channel], color=color,
                label=f'{name} deriv fit')
        ax.plot(t_axis, deriv_gt[:, node_num, channel], color=color, linestyle='dashed',
                label=f'{name} deriv')
    ax.set_xlabel('Time')
    ax.legend()
    ax.set_title(f'Metamaterial derivatives, node {node_num}')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures inside the loop

In [None]:
# Forward-Euler rollout of the fitted derivatives from the first observed state:
#     integrated[k] = metamat_arr[0] + dt * sum_{j < k} final_pred[j]
# An exclusive cumulative sum keeps integrated[0] == metamat_arr[0]. A plain cumsum already adds
# one step at k = 0, which shifts the whole rollout one frame early.
# NOTE(review): channel 0 is time; integrating it only reproduces `ts` if the predicted time
# derivative is 1 -- confirm what `fit` returns in that channel.
steps = jnp.cumsum(final_pred[:-1], axis=0) * dt
steps = jnp.concat([jnp.zeros_like(final_pred[:1]), steps], axis=0)
integrated = metamat_arr[0, :, :] + steps
print(integrated.shape)

In [None]:
# Euler rollout (solid) vs measured data (dashed) for every node, channels x, y and angle.
t_axis = metamat_arr[:, 0, 0]
for node_num in range(integrated.shape[1]):
    fig, ax = plt.subplots(figsize=(10, 10))
    for channel, name in ((1, 'x'), (2, 'y'), (3, 'angle')):
        (fit_line,) = ax.plot(t_axis, integrated[:, node_num, channel], label=name)
        # Reuse the rollout's colour so each fit / ground-truth pair is easy to match.
        ax.plot(t_axis, metamat_arr[:, node_num, channel], color=fit_line.get_color(),
                linestyle='dashed', label=f'{name} GT')
    ax.set_xlabel('Time')
    ax.legend()
    ax.set_title(f'Metamaterial data, node {node_num}')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures inside the loop

In [None]:
# Re-print the trained model's equation next to the rollout comparison.
model.print_equation()

In [None]:
# Worst-case absolute deviation of the rollout from the data over all frames, nodes and channels.
# A signed max would ignore every place where the rollout overshoots the data.
jnp.max(jnp.abs(metamat_arr - integrated))

In [None]:
# Full repr of model.predict's output (large; slice it if only a few entries are needed).
preds

In [None]:
# Adjacency matrix of the lattice graph passed to the model as `coupling`.
plt.imshow(coupling_matrix, cmap='gray')

In [None]:
# The model's stored fixed coupling, for visual comparison with the adjacency matrix above.
plt.imshow(model.fixed_K, cmap='gray')

In [None]:
# Shape of the trained W parameter.
model.params['W'].shape

In [None]:
# Node layout for drawing: each node placed at its measured (x, y) in the first frame, in the
# {node: (x, y)} form that networkx drawing uses. (The original `pos = ` was an incomplete
# statement and raised a SyntaxError.)
pos = {
    node: (float(metamat_arr[0, node, 1]), float(metamat_arr[0, node, 2]))
    for node in range(metamat_arr.shape[1])
}
# TODO: animate_temporal_graph is called without arguments; its signature lives in
# visualizer/temporal_graph_matplotlib (not visible here) -- pass the trajectory, graph and
# `pos` as it expects.
animate_temporal_graph()