In [None]:
import numpy as np
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import importlib
import numpy as np
from jax import grad, jit, vmap
from jax.experimental.ode import odeint

from typing import TYPE_CHECKING, Callable, Union, Optional

from VariablesClass import VariablesClass
from StructureClass import  StructureClass
from StateClass import StateClass
from EquilibriumClass import EquilibriumClass

import plot_funcs, colors, dynamics, helpers_builders

In [None]:
## Params

H, S = 2, 1  # number of hinges, number of shims per hinge

# --- parameters / variables; every per-shim array below has shape (Hinges, Shims) ---
k_soft  = jnp.ones((H, S), dtype=jnp.float32) * 1.0  # spring stiffness in the soft direction, per shim
k_stiff = jnp.ones((H, S), dtype=jnp.float32) * 4.0  # spring stiffness in the stiff direction, per shim
thetas_ss = jnp.full((H, S), 1/6, dtype=jnp.float32)  # rest angles, per shim
thresh = jnp.full((H, S), 1/4, dtype=jnp.float32)  # threshold to buckle shims, per shim
buckle  = jnp.ones((H, S), dtype=jnp.int32)  # initial buckle state, per shim (all +1)

k_stretch_ratio = 1e3  # passed to VariablesClass as `stretch_scale`

T = 10  # total training set time (not time to reach equilibrium during every step)

# Guard against editing H/S or one array without the others: all per-shim inputs must agree.
assert all(arr.shape == (H, S) for arr in (k_soft, k_stiff, thetas_ss, thresh, buckle)), \
    "all per-shim parameter arrays must have shape (H, S)"

In [None]:
# Re-execute the StructureClass module from disk (so source edits are picked up without a
# kernel restart) and rebind the name `StructureClass` to the class it defines.
StructureClass = importlib.reload(importlib.import_module('StructureClass')).StructureClass

# --- build geometry (all topology stays in StructureClass) ---
Strctr = StructureClass(hinges=H, shims=S, L=1)

In [None]:
# Re-execute the VariablesClass module from disk, then rebind the class name.
import VariablesClass
importlib.reload(VariablesClass)
from VariablesClass import VariablesClass

# --- Initiate variables (per-shim arrays come from the Params cell) ---
Variabs = VariablesClass(Strctr,
                         k_soft=k_soft,
                         k_stiff=k_stiff,
                         thetas_ss=thetas_ss,             # rest/target angles
                         thresh=thresh,                   # threshold to buckle shims
                         # stretch-stiffness scale (k_stretch_ratio = 1e3 here); presumably
                         # k_stretch = stretch_scale * max(k_stiff) — confirm in VariablesClass
                         stretch_scale=k_stretch_ratio,
                         )

In [None]:
# Re-execute the SupervisorClass module from disk and rebind the class name to the fresh definition.
SupervisorClass = importlib.reload(importlib.import_module('SupervisorClass')).SupervisorClass

# Supervisor over the T-step training schedule; the dataset is built with uniform sampling
# for the structure `Strctr` and the resulting tip locations are printed.
Sprvsr = SupervisorClass(T)
Sprvsr.create_dataset(Strctr, sampling='Uniform')
print('tip positions=', Sprvsr.tip_loc_in_t)

In [None]:
# Re-execute the StateClass module from disk and rebind the class name to the fresh definition.
StateClass = importlib.reload(importlib.import_module('StateClass')).StateClass

# --- state (straight chain, unit spacing => rest lengths = 1) ---
# Initial buckle configuration is the `buckle` array from the Params cell (all +1).
State = StateClass(Variabs, Strctr, Sprvsr, buckle_arr=buckle)

In [None]:
# Reload helper modules BEFORE the class that may use them: importlib.reload does not
# refresh names another module bound with `from helpers_builders import ...`, so a
# dependent module only sees fresh helper definitions if it is reloaded afterwards.
importlib.reload(helpers_builders)
importlib.reload(plot_funcs)

import EquilibriumClass
importlib.reload(EquilibriumClass)
from EquilibriumClass import EquilibriumClass

# --- initialize, no tip movement yet
Eq = EquilibriumClass(Strctr, State.buckle_arr, State.pos_arr)

# Record and plot the pre-training configuration at time index 0.
# NOTE(review): the training loop also calls `_save_data` with t=0, which may overwrite
# this entry — confirm the intended time indexing in StateClass.
State._save_data(0, Strctr, State.pos_arr, State.buckle_arr, compute_thetas_if_missing=True)
print('pre training configuration')
plot_funcs.plot_arm(State.pos_arr, State.buckle_arr, np.rad2deg(State.theta_arr), Strctr.L)

## Training

In [None]:
# Training loop: at each time step, move the tip to its scheduled location, solve for the
# resulting configuration, record and plot it, then update the shim buckle states and
# rebuild the equilibrium solver with the new buckle array.
# NOTE(review): this cell mutates `State` and rebinds `Eq`; re-running it without re-running
# the cells above continues from the previous run's end state.
for t in range(Sprvsr.T):
    print('t=', t)

    # Drive the tip, then solve with the current buckle state.
    State.position_tip(Sprvsr, t)
    final_pos, pos_in_t, vel_in_t, potential_force_evolution = Eq.calculate_state(
        Variabs, Strctr, tip_loc=State.tip_loc)
    State._save_data(t, Strctr, final_pos, State.buckle_arr, compute_thetas_if_missing=True)

    print('theta', np.rad2deg(State.theta_arr))
    plot_funcs.plot_arm(final_pos, State.buckle_arr, np.rad2deg(State.theta_arr), Strctr.L)

    print('pre buckle', State.buckle_arr)
    print('energy', Eq.energy(Variabs, Strctr, final_pos)[-1])

    # Update buckle states, then rebuild the solver so the new State.buckle_arr is used both
    # for the post-buckle energy below and for the next time step.
    # NOTE(review): assumes `_save_data` stored final_pos into State.pos_arr; otherwise the
    # next step starts from stale positions — confirm in StateClass.
    State.buckle(Variabs, Strctr, t)
    Eq = EquilibriumClass(Strctr, State.buckle_arr, State.pos_arr)

    print('post buckle', State.buckle_arr)
    print('energy', Eq.energy(Variabs, Strctr, final_pos)[-1])
    plot_funcs.plot_arm(final_pos, State.buckle_arr, np.rad2deg(State.theta_arr), Strctr.L)
    

In [None]:
from IPython.display import display

# Only a cell's LAST bare expression is rendered, so the hinge-1 slice was previously
# computed and silently discarded. Display both explicitly:
#   1) positions indexed via the edges of hinge index 1 (presumably that hinge's nodes — confirm)
#   2) the full position array
display(State.pos_arr[Strctr.edges_arr[Strctr.hinges_arr[1]]])
display(State.pos_arr)

In [None]:
# Post-training diagnostics for the last solved configuration `final_pos`.
print('edge lengths', Strctr.all_edge_lengths(final_pos))
print('thetas ', np.rad2deg(Strctr.all_hinge_angles(final_pos)))
print('buckle ', State.buckle_arr)
print('pos_arr', final_pos)
# `buckle` is the initial per-shim array from the Params cell, not the current state above.
print('initial buckle', buckle)
# Default buckle array produced by the builder helper, shown for comparison (cell output).
helpers_builders.numpify(helpers_builders._initiate_buckle(Strctr.hinges, Strctr.shims))

In [None]:
from IPython.display import HTML
importlib.reload(plot_funcs)

# Final configuration after training. Pass the per-shim buckle ARRAY: `State.buckle` is the
# method invoked in the training loop, not the buckle state that plot_arm receives elsewhere.
plot_funcs.plot_arm(final_pos, State.buckle_arr, np.rad2deg(Strctr.all_hinge_angles(final_pos)), Strctr.L)

# Optional: animate the last solve's trajectory `pos_in_t` (uncomment both lines to render inline).
# fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=20, save_path=None, fps=1, stride=12)
# HTML(anim.to_jshtml())  # inline animation via JavaScript/HTML5