In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import importlib
import numpy as np
from jax import grad, jit, vmap
from jax.experimental.ode import odeint

from typing import TYPE_CHECKING, Callable, Union, Optional

from VariablesClass import VariablesClass
from StructureClass import  StructureClass
from StateClass import StateClass
from EquilibriumClass import EquilibriumClass

import plot_funcs, colors, dynamics

In [None]:
## Params

# Chain topology: H hinges, each carrying S shims.
H = 5
S = 1

# --- per-shim parameters; every array below has shape (H, S) ---
# Spring stiffness of each shim when bent in its soft direction.
k_soft = jnp.full((H, S), 1.0, dtype=jnp.float32)
# Spring stiffness of each shim when bent in its stiff direction.
k_stiff = jnp.full((H, S), 4.0, dtype=jnp.float32)
# Rest angle of each shim.
thetas_ss = jnp.full((H, S), -0.5, dtype=jnp.float32)
# Initial buckle state of each shim (+1 everywhere).
buckle = jnp.full((H, S), 1, dtype=jnp.int32)

# Total training-set time (not the time allowed to reach equilibrium at each step).
T = 200

In [None]:
# Hot-reload the StructureClass module so on-disk edits are picked up without restarting
# the kernel, then bind the notebook-level name directly to the freshly loaded class.
StructureClass = importlib.reload(importlib.import_module("StructureClass")).StructureClass

# --- geometry: all topology is owned by StructureClass ---
Strctr = StructureClass(hinges=H, shims=S, L=1)

In [None]:
# Hot-reload the VariablesClass module so on-disk edits are picked up without restarting
# the kernel, then bind the notebook-level name directly to the freshly loaded class.
VariablesClass = importlib.reload(importlib.import_module("VariablesClass")).VariablesClass

# --- spring / material variables, built from the per-shim arrays of the params cell ---
Variabs = VariablesClass(
    Strctr,
    k_soft=k_soft,
    k_stiff=k_stiff,
    thetas_ss=thetas_ss,   # rest/target angles
    # NOTE(review): presumably k_stretch = stretch_scale * max(k_stiff), i.e. 10 * max(k_stiff)
    # with the value below -- confirm the exact formula in VariablesClass.
    stretch_scale=10.0,
)

In [None]:
# Hot-reload the StateClass module so on-disk edits are picked up without restarting
# the kernel, then bind the notebook-level name directly to the freshly loaded class.
StateClass = importlib.reload(importlib.import_module("StateClass")).StateClass

# --- state, given the variables, the geometry and the total training time T ---
# NOTE(review): per the author's notes, the initial configuration is a straight chain with
# unit spacing (rest lengths = 1) and buckle defaults to +1 since none is passed here --
# confirm both in StateClass.
State = StateClass(Variabs, Strctr, T)

In [None]:
# Hot-reload the EquilibriumClass module so on-disk edits are picked up without restarting
# the kernel, then bind the notebook-level name directly to the freshly loaded class.
EquilibriumClass = importlib.reload(importlib.import_module("EquilibriumClass")).EquilibriumClass

# --- equilibrium solver, built from the geometry and the explicit initial `buckle` array
# defined in the params cell ---
# NOTE(review): StateClass above is constructed WITHOUT `buckle`; if `buckle` is ever changed
# from all +1, confirm that State and Eq still see the same initial buckle configuration.
Eq = EquilibriumClass(Strctr, buckle)

In [None]:
# Run the equilibrium calculation and unpack its four outputs (by name: final positions,
# positions over time, velocities over time, potential-force history).
(final_pos,
 pos_in_t,
 vel_in_t,
 potential_force_evolution) = Eq.calculate_state(Variabs, Strctr)

In [None]:
from IPython.display import HTML
importlib.reload(plot_funcs)

# Static plot of the final configuration returned by the equilibrium calculation.
plot_funcs.plot_arm(final_pos, Strctr.L)

# Animate the position trajectory.
# NOTE(review): interval_ms=20 sets the on-screen frame delay, while fps=1 presumably only
# applies when save_path is given -- confirm in plot_funcs.animate_arm.
fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=20, save_path=None, fps=1, stride=12)

# Render the animation to inline JavaScript/HTML5 first, then close its figure. Otherwise the
# inline backend also shows that still-open figure as an extra static image under the animation.
anim_html = HTML(anim.to_jshtml())
plt.close(fig)
anim_html