In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import importlib
import numpy as np
from jax import grad, jit, vmap
from jax.experimental.ode import odeint

from typing import TYPE_CHECKING, Callable, Union, Optional

from VariablesClass import VariablesClass
from StructureClass import  StructureClass
from StateClass import StateClass

import plot_funcs, colors, dynamics

In [None]:
## Params

# Chain size: number of hinges and number of shims per hinge.
H = 5
S = 1

# --- parameters / variables ---
# Every per-shim quantity below is an (H, S) array with a uniform value.
k_soft = jnp.full((H, S), 1.0, dtype=jnp.float32)
k_stiff = jnp.full((H, S), 4.0, dtype=jnp.float32)
# NOTE(review): the previous comment read "33° per hinge", but -1/2 rad is
# about -28.6°; confirm the intended value and units of thetas_ss.
thetas_ss = -0.5 * jnp.ones((H, S), dtype=jnp.float32)
# Integer buckle state per shim, initialised to +1 everywhere.
buckle = jnp.full((H, S), 1, dtype=jnp.int32)

In [None]:
# Re-import the StructureClass module so edits made on disk since the kernel
# started take effect, then rebind the class name to the fresh definition.
_structure_module = importlib.reload(importlib.import_module("StructureClass"))
StructureClass = _structure_module.StructureClass

# --- build geometry (all topology stays in StructureClass) ---
# H hinges with S shims each; L=1 is passed through to StructureClass
# (presumably the edge rest length — confirm in StructureClass).
Strctr = StructureClass(hinges=H, shims=S, L=1)

In [None]:
# Re-import the VariablesClass module so on-disk edits are picked up, then
# rebind the class name to the reloaded definition.
_variables_module = importlib.reload(importlib.import_module("VariablesClass"))
VariablesClass = _variables_module.VariablesClass


# --- Initiate variables ---
_variables_kwargs = dict(
    k_soft=k_soft,
    k_stiff=k_stiff,
    # rest/target angles
    thetas_ss=thetas_ss,
    # NOTE(review): the previous comment read "k_stretch = 50 * max(k_stiff)"
    # while the value passed is 10.0; presumably k_stretch scales as
    # stretch_scale * max(k_stiff) — confirm in VariablesClass.
    stretch_scale=10.0,
)
Variabs = VariablesClass(Strctr, **_variables_kwargs)

In [None]:
# Re-import the StateClass module so on-disk edits are picked up, then rebind
# the class name to the reloaded definition.
_state_module = importlib.reload(importlib.import_module("StateClass"))
StateClass = _state_module.StateClass

# --- state ---
# The buckle array is passed explicitly here.
# NOTE(review): earlier comments described the initial state as a straight
# chain with unit spacing and said buckle defaults to +1 when omitted —
# confirm both in StateClass.
State = StateClass(Strctr, buckle)

In [None]:
# --- initial energy ---
# Evaluate the energy of the initial configuration; the call returns the
# total together with its rotational and stretching parts.
E_total, E_rot, E_stretch = State.energy(Variabs, Strctr, State.pos_arr)

# Convert each value to a plain Python float before printing.
for _label, _energy in (
    ("Total:", E_total),
    ("Rotation:", E_rot),
    ("Stretch:", E_stretch),
):
    print(_label, float(_energy))

In [None]:
# Pick up any edits made to plot_funcs since it was last imported.
importlib.reload(plot_funcs)

# Draw the arm in its initial configuration.
# State.pos_arr presumably has shape (H + 2, 2) = (7, 2) — confirm in StateClass.
initial_positions = State.pos_arr
plot_funcs.plot_arm(initial_positions, L=Strctr.L)

In [None]:
# Run the state calculation and split its result into the final positions,
# the position and velocity histories, and the potential-force history.
_state_result = State.calculate_state(Variabs, Strctr)
final_pos, pos_in_t, vel_in_t, potential_force_evolution = _state_result

In [None]:
from IPython.display import HTML

# Pick up any edits made to plot_funcs since it was last imported.
importlib.reload(plot_funcs)

# Static snapshot of the final configuration.
plot_funcs.plot_arm(final_pos, Strctr.L)

# Animation of the position history; save_path=None is passed, so presumably
# nothing is written to disk (confirm in plot_funcs.animate_arm).
_animation_options = dict(interval_ms=20, save_path=None, fps=1, stride=12)
fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, **_animation_options)

# Show inline animation (uses JavaScript/HTML5); this must stay the last
# expression of the cell so the notebook renders it.
_animation_markup = anim.to_jshtml()
HTML(_animation_markup)

In [None]:
# importlib.reload(dynamics)

# n_coords = State.pos_arr.size                # = 2 * (H+2)

# # Helper to map (node, component) -> flat DOF index
# # component: 0 = x, 1 = y
# def dof_idx(node: int, comp: int) -> int:
#     return 2*node + comp

# # -------- imposed displacement function ----------
# # Must return a length-n_coords vector (flattened order) for any time t
# def imposed_disp_values(t: float) -> jnp.ndarray:
#     vals = jnp.zeros((n_coords,), dtype=State.pos_arr.dtype)
#     # First node pinned at (0,0) → both DOFs are zero; nothing else imposed.
#     # If you ever want a moving base, set these two entries to your function of t.
#     return vals

# # -------- external forces (optional) ----------
# def force_function(t: float) -> jnp.ndarray:
#     # No external forces; you can add tip forces here if needed
#     return jnp.zeros((n_coords,), dtype=State.pos_arr.dtype)

# # -------- masks (boolean) ----------
# # We impose node 0 at (0,0). No other fixed DOFs.
# fixed_DOFs = jnp.zeros((n_coords,), dtype=bool)

# imposed_disp_DOFs = jnp.zeros((n_coords,), dtype=bool)
# imposed_disp_DOFs = imposed_disp_DOFs.at[dof_idx(0, 0)].set(True)   # x of node 0
# imposed_disp_DOFs = imposed_disp_DOFs.at[dof_idx(0, 1)].set(True)   # y of node 0
# imposed_disp_DOFs = imposed_disp_DOFs.at[dof_idx(1, 0)].set(True)   # x of node 1
# imposed_disp_DOFs = imposed_disp_DOFs.at[dof_idx(1, 1)].set(True)   # y of node 1

# # -------- initial state (positions & velocities) ----------
# x0 = State.pos_arr.flatten()                  # start from current geometry
# v0 = jnp.zeros_like(x0)                    # start at rest
# state_0 = jnp.concatenate([x0, v0], axis=0)

# # -------- time grid ----------
# dt = 1e-3
# t0, t1, n_steps = 0.0, 1600.0, int(1/dt)
# time_points = jnp.linspace(t0, t1, n_steps)

# # -------- run dynamics ----------
# final_pos, pos_in_t, vel_in_t, potential_force_evolution = dynamics.solve_dynamics(
#     time_points,
#     state_0,
#     Variabs,
#     Strctr,
#     State,
#     force_function = None,
#     fixed_DOFs = fixed_DOFs,
#     imposed_disp_DOFs = imposed_disp_DOFs,
#     imposed_disp_values = None
# )

# # final_pos is shaped like State.pos_arr  (H+2, 2)
# # pos_in_t is (T, H+2, 2); vel_in_t is (T, H+2, 2)
# print("final_pos[0] (should be ~[0,0]):", final_pos[0])
# print("tip pos at end:", final_pos[-1])


In [None]:
# Report the length of every edge in the final configuration.
for edge_idx in range(Strctr.edges):
    edge_length = Strctr._get_edge_length(final_pos, edge_idx)
    print('edge', edge_idx, 'length', edge_length)

# Report the angle at every hinge in the final configuration.
for hinge_idx in range(Strctr.hinges):
    hinge_theta = Strctr._get_theta(final_pos, hinge_idx)
    print('hinge ', hinge_idx, 'theta', hinge_theta)