In [None]:
import numpy as np
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import importlib
import numpy as np
from jax import grad, jit, vmap
from jax.experimental.ode import odeint
from IPython.display import HTML


from typing import TYPE_CHECKING, Callable, Union, Optional

from VariablesClass import VariablesClass
from StructureClass import  StructureClass
from StateClass import StateClass
from EquilibriumClass import EquilibriumClass

import plot_funcs, colors, dynamics, helpers_builders

In [None]:
## Params

H, S = 5, 1  # number of hinges, number of shims per hinge
# 'Numerical' -> constant k_soft / k_stiff below; otherwise k_soft / k_stiff are None and
# file_name is set (presumably VariablesClass then takes the stiffness from that file).
k_type = 'Experimental'

# --- parameters / variables ---
if k_type == 'Numerical':
    k_soft  = jnp.ones((H, S), dtype=jnp.float32) * 1.0  # (Hinges, Shims) spring stiffnesses in soft direction, per shim
    k_stiff = jnp.ones((H, S), dtype=jnp.float32) * 4.0  # (Hinges, Shims) spring stiffnesses in stiff direction, per shim
    thetas_ss = jnp.full((H, S), 1/6, dtype=jnp.float32)  # (Hinges, Shims) rest angles per shim
    thresh = jnp.full((H, S), 1/4, dtype=jnp.float32)  # (Hinges, Shims) buckling threshold per shim
    file_name = None
else:
    k_soft  = None
    k_stiff = None
    thetas_ss = jnp.full((H, S), 1.03312, dtype=jnp.float32)  # (Hinges, Shims) rest angles per shim
    thresh = jnp.full((H, S), 1.96257, dtype=jnp.float32)  # (Hinges, Shims) buckling threshold per shim
    file_name = 'Roee_offset3mm_dl75.txt'  # data file handed to VariablesClass
buckle  = jnp.ones((H, S), dtype=jnp.int32)  # initial buckle state (+1), per shim
# Optional: flip selected shims to -1, e.g. the first shim of hinges 0, 2 and 4:
# rows = jnp.array([0, 2, 4])
# cols = jnp.array([0, 0, 0])
# buckle = buckle.at[rows, cols].set(-1)

key = jax.random.PRNGKey(169)   # seed for the random target buckle pattern
# Draw integers from {-1, 0, 1} (maxval is exclusive), then map 0 -> -1 so every entry is +/-1.
# NOTE(review): this makes -1 twice as likely as +1; kept so seed 169 reproduces the same target.
desired_buckle = jax.random.randint(key, (H, S), minval=-1, maxval=2)
# jnp.where instead of boolean-mask .at[...] indexing: same values, and also valid under jit.
desired_buckle = jnp.where(desired_buckle == 0, -1, desired_buckle)

k_stretch_ratio = 1e2  # stretch-stiffness scale, passed to VariablesClass as stretch_scale

T = 16  # total training set time (not time to reach equilibrium during every step)

alpha = 0.012  # learning rate

In [None]:
# Display the initial buckle pattern defined in the params cell.
buckle

In [None]:
# Reload the StructureClass module (picks up source edits) and rebind the class name to it.
StructureClass = importlib.reload(importlib.import_module("StructureClass")).StructureClass

# --- geometry / topology (owned entirely by StructureClass) ---
Strctr = StructureClass(hinges=H, shims=S, L=1)

In [None]:
# Reload the VariablesClass module (picks up source edits) and rebind the class name to it.
VariablesClass = importlib.reload(importlib.import_module("VariablesClass")).VariablesClass

# --- material / hinge variables, all taken from the params cell ---
Variabs = VariablesClass(
    Strctr,
    k_type=k_type,
    k_soft=k_soft,
    k_stiff=k_stiff,
    thetas_ss=thetas_ss,            # rest/target angles
    thresh=thresh,                  # threshold to buckle shims
    stretch_scale=k_stretch_ratio,  # stretch-stiffness scale factor (k_stretch_ratio)
    file_name=file_name,            # None when k_type == 'Numerical'
)

In [None]:
# Reload the SupervisorClass module (picks up source edits) and rebind the class name to it.
SupervisorClass = importlib.reload(importlib.import_module("SupervisorClass")).SupervisorClass

# Supervisor: holds learning rate, training length and the desired buckle target,
# and generates the tip-position dataset.
Sprvsr = SupervisorClass(Strctr, alpha, T, desired_buckle)
Sprvsr.create_dataset(Strctr, sampling='Uniform')

print('tip positions=', Sprvsr.tip_pos_in_t)
print('desired buckle=', Sprvsr.desired_buckle_arr)

## Desired values

In [None]:
# Reload the StateClass module (picks up source edits) and rebind the class name to it.
StateClass = importlib.reload(importlib.import_module("StateClass")).StateClass

# Measurement state: straight chain with unit spacing (rest lengths = 1),
# starting from the initial buckle pattern `buckle` from the params cell.
State_meas = StateClass(Variabs, Strctr, Sprvsr, buckle_arr=buckle)

## Leave tip free (no imposed tip position)

In [None]:
# --- free-tip state (straight chain, unit spacing => rest lengths = 1), built from the initial `buckle` ---
State_notip = StateClass(Variabs, Strctr, Sprvsr, buckle_arr=buckle)
print('State_notip buckle', State_notip.buckle_arr)

# --- relax to equilibrium without imposing any tip position ---
Eq_notip = EquilibriumClass(Strctr, buckle)
final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq_notip.calculate_state(Variabs, Strctr, tip_pos=None)

# Save with this state's own buckle array so the cell does not depend on whatever State_meas
# currently holds (State_meas is mutated by the training cells if those ran first).
State_notip._save_data(0, Strctr, final_pos, State_notip.buckle_arr, compute_thetas_if_missing=True)
print('State_notip buckle', State_notip.buckle_arr)
plot_funcs.plot_arm(State_notip.pos_arr, State_notip.buckle_arr, np.rad2deg(State_notip.theta_arr), Strctr.L,
                    modality="measurement")

In [None]:
# Animate the free-tip relaxation. Use a dedicated file name: the training loop saves
# arm_animation_t0.gif at t=0 and would otherwise overwrite this animation.
fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=20, save_path="arm_animation_notip.gif", fps=4,
                                   stride=4)
# Close the static figure so only the HTML animation is rendered below, not an extra still frame.
plt.close(fig)

# Show inline animation (uses JavaScript/HTML5)
HTML(anim.to_jshtml())


In [None]:
# Reload the EquilibriumClass module (picks up source edits) and rebind the class name to it.
EquilibriumClass = importlib.reload(importlib.import_module("EquilibriumClass")).EquilibriumClass

# Relax the arm with the *desired* buckle pattern at every dataset tip position and hand the
# resulting positions / last-step tip forces to the supervisor as training targets.
for i, tip_pos in enumerate(Sprvsr.tip_pos_in_t):
    # From the second step on, the previous equilibrium positions are passed as a third argument
    # (presumably the starting configuration of the relaxation).
    Eq_des = (EquilibriumClass(Strctr, Sprvsr.desired_buckle_arr) if i == 0
              else EquilibriumClass(Strctr, Sprvsr.desired_buckle_arr, final_pos))
    final_pos, pos_in_t, vel_in_t, potential_force_evolution = Eq_des.calculate_state(Variabs, Strctr,
                                                                                      tip_pos=tip_pos)
    Sprvsr.set_desired(
        final_pos,
        potential_force_evolution[-1][-2],  # entry the training loop prints as 'normal force on tip'
        potential_force_evolution[-1][-1],
        i,
    )

# Re-save and plot the measurement state; the loop above does not touch State_meas.
State_meas._save_data(0, Strctr, State_meas.pos_arr, State_meas.buckle_arr, compute_thetas_if_missing=True)
print('pre training configuration')
plot_funcs.plot_arm(
    State_meas.pos_arr,
    State_meas.buckle_arr,
    np.rad2deg(State_meas.theta_arr),
    Strctr.L,
    modality='measurement',
)

## Training

In [None]:
# --- state (straight chain, unit spacing => rest lengths = 1) ---
State_meas = StateClass(Variabs, Strctr, Sprvsr, buckle_arr = buckle)  # buckle defaults to +1
State_update = StateClass(Variabs, Strctr, Sprvsr, buckle_arr = buckle)  # buckle defaults to +1

# --- initialize, no tip movement yet
Eq = EquilibriumClass(Strctr, State_meas.buckle_arr, State_meas.pos_arr)

State_meas._save_data(0, Strctr, State_meas.pos_arr, State_meas.buckle_arr, compute_thetas_if_missing=True)
print('pre training configuration')
plot_funcs.plot_arm(State_meas.pos_arr , State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr), Strctr.L,
                    modality="measurement")

In [None]:
# Training loop: each step runs a MEASUREMENT phase (place tip, relax, compute loss) followed by
# an UPDATE phase (supervisor proposes an update tip position, relax, let shims buckle).
#
# SupervisorClass is not reloaded here: `Sprvsr` was already built, and reloading the module
# would not change the class of that existing object.
importlib.reload(plot_funcs)  # plot_funcs functions are looked up at call time, so this reload takes effect

for t in range(Sprvsr.T):
    print('t=', t)

    ## MEASUREMENT

    # --- position tip ---
    State_meas.position_tip(Sprvsr, t)

    # --- equilibrium ---
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr, tip_pos=State_meas.tip_pos)
    print('normal force on tip', potential_force_in_t[-1][-2])

    # --- save sizes and plot ---
    State_meas._save_data(t, Strctr, final_pos, State_meas.buckle_arr, potential_force_in_t, compute_thetas_if_missing=True)
    plot_funcs.plot_arm(final_pos, State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr), Strctr.L,
                        modality = "measurement")
    print('pre buckle', State_meas.buckle_arr)
    print('energy', Eq.energy(Variabs, Strctr, final_pos)[-1])

    fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=20, save_path=f"arm_animation_t{t}.gif", fps=4,
                                       stride=10)
    # The animation is only written to disk here; close its figure so Sprvsr.T iterations do not
    # accumulate open figures (assumes animate_arm saves to save_path before returning — confirm).
    plt.close(fig)

    # --- loss ---
    Sprvsr.calc_loss(State_meas.Fx, State_meas.Fy, t)
    print('loss', Sprvsr.loss)

    ## UPDATE

    if t == 0:
        # No previous tip position exists yet; seed it with (L * edges, 0.1),
        # presumably the straight-chain tip with a small y offset.
        Sprvsr.calc_update_tip(t, State_meas.Fx, State_meas.Fy, State_meas.tip_pos,
                               prev_pos = np.array([Strctr.L*Strctr.edges, 0.1]))
    else:
        Sprvsr.calc_update_tip(t, State_meas.Fx, State_meas.Fy, State_meas.tip_pos)
    print('update_tip', Sprvsr.tip_pos_update_in_t)

    # --- position tip ---
    State_update.position_tip(Sprvsr, t, "update")

    # --- equilibrium ---
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr, tip_pos=State_update.tip_pos)
    print('normal force on tip', potential_force_in_t[-1][-2])

    # --- save sizes and plot ---
    # Save the forces of *this* update equilibrium, not a variable left over from another cell.
    State_update._save_data(t, Strctr, final_pos, State_update.buckle_arr, potential_force_in_t,
                            compute_thetas_if_missing=True)
    plot_funcs.plot_arm(final_pos, State_update.buckle_arr, np.rad2deg(State_update.theta_arr), Strctr.L, modality = "update")
    print('pre buckle', State_update.buckle_arr)
    print('energy', Eq.energy(Variabs, Strctr, final_pos)[-1])

    # --- shims buckle ---
    State_update.buckle(Variabs, Strctr, t, State_measured = State_meas)

    # Rebuild the equilibrium solver from the post-buckle state for the next step.
    Eq = EquilibriumClass(Strctr, helpers_builders.jaxify(State_update.buckle_arr),
                          helpers_builders.jaxify(State_update.pos_arr))
    print('post buckle', State_update.buckle_arr)
    print('energy', Eq.energy(Variabs, Strctr, final_pos)[-1])
    plot_funcs.plot_arm(final_pos, State_update.buckle_arr, np.rad2deg(State_update.theta_arr),
                        Strctr.L, modality = "update")
    

In [None]:
# Total squared loss per training step: squared magnitudes of loss_in_t summed over axis 1.
squared_loss = np.abs(Sprvsr.loss_in_t ** 2)
plt.plot(squared_loss.sum(axis=1))
plt.xlabel('t')
plt.ylabel('Loss')

In [None]:
# Buckle state of the first shim of every hinge over training. Loop over all hinges in
# buckle_in_t instead of hard-coding four, so the plot follows H.
fig, ax = plt.subplots()
for hinge in range(State_update.buckle_in_t.shape[0]):
    ax.plot(State_update.buckle_in_t[hinge, 0, :], label=f'hinge {hinge + 1}')
ax.set_xlabel('t')
ax.set_ylabel('buckle')
ax.legend();

In [None]:
# Animate the update-phase relaxation of the last training step (`t` and `pos_in_t` are left over
# from the training loop). Use a distinct file name: the loop already wrote the measurement-phase
# animation to arm_animation_t{t}.gif, which the original path overwrote.
fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=20, save_path=f"arm_animation_update_t{t}.gif",
                                   fps=4, stride=10)
# Close the static figure so only the HTML animation is rendered below.
plt.close(fig)

# Show inline animation (uses JavaScript/HTML5)
HTML(anim.to_jshtml())

In [None]:
import numpy as np
import jax.numpy as jnp

def build_torque_stiffness_from_file(
    path: str,
    *,
    angles_in_degrees: bool = True,
    savgol_window: Optional[int] = None,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Callable, Callable]:
    """
    Load (angle, torque) samples from a text file and construct JAX-friendly interpolants.

    Parameters
    ----------
    path : str
        Text file readable by ``np.loadtxt``; column 0 is the angle, column 1 the torque.
        Further columns are ignored.
    angles_in_degrees : bool, default True
        If True, column 0 is converted from degrees to radians.
    savgol_window : int or None, default None
        Optional Savitzky-Golay window used to smooth the stiffness samples. Only odd
        values > 2 enable smoothing; anything else skips it. The polynomial order is
        ``min(4, savgol_window - 1)`` so small windows (3, 5) are valid too.

    Returns
    -------
    theta_grid : jnp.ndarray, shape (N,)
        Sorted grid of distinct angles (radians if ``angles_in_degrees``).
    torque_grid : jnp.ndarray, shape (N,)
        Torque samples on the grid (first occurrence in file order for duplicate angles).
    k_grid : jnp.ndarray, shape (N,)
        Stiffness samples: numeric derivative of torque w.r.t. theta (optionally smoothed).
    torque_of_theta : callable
        JAX function theta -> torque (linear interpolation, clamped at the grid ends).
    k_of_theta : callable
        JAX function theta -> stiffness (linear interpolation, clamped at the grid ends).

    Raises
    ------
    ValueError
        If the file yields fewer than two distinct angles (no derivative can be formed).
    """
    # --- load --- (ndmin=2 keeps a single-row file 2-D so column indexing works)
    data = np.loadtxt(path, ndmin=2)
    theta = data[:, 0]
    tau = data[:, 1]

    # degrees -> radians if needed
    if angles_in_degrees:
        theta = np.deg2rad(theta)

    # --- sort & de-duplicate (interp requires increasing x) ---
    # A stable sort makes np.unique keep the first duplicate in file order rather than an arbitrary one.
    order = np.argsort(theta, kind="stable")
    theta = theta[order]
    tau = tau[order]
    theta_u, idx = np.unique(theta, return_index=True)
    tau_u = tau[idx]

    if theta_u.size < 2:
        raise ValueError(
            f"{path!r}: need at least two distinct angles to build a torque curve, got {theta_u.size}"
        )

    # --- numeric derivative: k = d(tau)/d(theta) ---
    # np.gradient uses central differences and handles non-uniform spacing when given x.
    k = np.gradient(tau_u, theta_u)

    # --- optional light smoothing of k (pure NumPy/SciPy, outside JAX) ---
    if savgol_window is not None and savgol_window > 2 and savgol_window % 2 == 1:
        # savgol_filter requires polyorder < window_length
        polyorder = min(4, savgol_window - 1)
        try:
            from scipy.signal import savgol_filter
        except ImportError:
            print('SciPy isnt available, just skip smoothing')
        else:
            try:
                k = savgol_filter(k, window_length=savgol_window, polyorder=polyorder, mode="interp")
                print('smoothed your k')
            except ValueError as err:
                # e.g. window longer than the data with mode="interp": keep the raw derivative
                print(f'skipping smoothing of k: {err}')

    # --- wrap as JAX arrays ---
    theta_grid = jnp.asarray(theta_u, dtype=jnp.float32)
    torque_grid = jnp.asarray(tau_u, dtype=jnp.float32)
    k_grid = jnp.asarray(k, dtype=jnp.float32)

    # --- clamped linear interpolators (JAX) ---
    def torque_of_theta(theta_query: jnp.ndarray) -> jnp.ndarray:
        th = jnp.clip(theta_query, theta_grid[0], theta_grid[-1])
        return jnp.interp(th, theta_grid, torque_grid)

    def k_of_theta(theta_query: jnp.ndarray) -> jnp.ndarray:
        th = jnp.clip(theta_query, theta_grid[0], theta_grid[-1])
        return jnp.interp(th, theta_grid, k_grid)

    return theta_grid, torque_grid, k_grid, torque_of_theta, k_of_theta


In [None]:
# Build the torque / stiffness interpolants from the measured curve (angles in degrees in the file).
(theta_grid,
 torque_grid,
 k_grid,
 torque_of_theta,
 k_of_theta) = build_torque_stiffness_from_file(
    'Roee_offset3mm_dl75.txt',
    savgol_window=9,
)

In [None]:
# Raw torque samples on the de-duplicated angle grid (radians, converted by the loader).
plt.plot(theta_grid, torque_grid)
plt.xlabel(r'$\theta$ [rad]')
plt.ylabel('torque')
plt.title('Torque samples from file');

In [None]:
# Evaluate both JAX interpolants on a dense angle grid (values are clamped outside the sampled range).
lnspc = np.linspace(-0.9 * np.pi, 0.9 * np.pi, 200)
plt.plot(lnspc, torque_of_theta(lnspc), label='torque')
plt.plot(lnspc, k_of_theta(lnspc), label=r'$k = d\tau / d\theta$')
plt.ylim([-10, 2])
plt.xlabel(r'$\theta$ [rad]')
plt.legend();

In [None]:
# Angle(s) at which the sampled torque equals a value read off the output above. Exact float
# equality against float32 samples breaks as soon as the displayed value is rounded, so compare
# with jnp.isclose's default tolerance instead.
torque_query = -0.00761004
theta_grid[jnp.isclose(torque_grid, torque_query)]