In [None]:
import numpy as np
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import importlib
import numpy as np
from jax import grad, jit, vmap
from jax.experimental.ode import odeint
from IPython.display import HTML

from typing import TYPE_CHECKING, Callable, Union, Optional

from VariablesClass import VariablesClass
from StructureClass import  StructureClass
from StateClass import StateClass
from EquilibriumClass import EquilibriumClass

import plot_funcs, colors, dynamics, helpers_builders, learning_funcs, file_funcs, numerical_experiments

In [None]:
# Fetch the project colour scheme and install its colour list as matplotlib's default line cycle.
colors_lst, red, custom_cmap = colors.color_scheme()
plt.rc('axes', prop_cycle=plt.cycler(color=colors_lst))

## Params

In [None]:
# --- Structure and init params ---

H = 5  # number of hinges
S = 1  # number of shims per hinge

Nin = 3  # tip position in (x, y) and its angle
Nout = 3  # Fx, Fy, torque, all on tip

# initial buckle state, per shim: every entry starts at +1
buckle = jnp.full((H, S), 1, dtype=jnp.int32)
# rows = jnp.array([1, 2, 3])
# cols = jnp.array([0])     # for example, first shim of each row
# buckle = buckle.at[rows, cols].set(-1)

k_type = 'Experimental'
# k_type = 'Numerical'

# Edge length L and the matching scale factor depend on which stiffness model is selected.
if k_type != 'Experimental':
    L = 1
    translate_ratio = 1/L
else:
    L = 0.045  # ~45mm
    translate_ratio = 1
    # L = 1  # keeping it fake bro, otherwise forces explode
    # translate_ratio = 0.045/L

In [None]:
# --- training params ---

# total training set time (not time to reach equilibrium during every step)
# T_training = 116
T_training = 32

# learning rate
alpha = 0.1

# which buckle pattern to aim for: 'random' (also needs desired_buckle_rand_key), 'opposite' or 'straight'
# desired_buckle_type = 'random'
# desired_buckle_rand_key = 169
# desired_buckle_type = 'opposite'
desired_buckle_type = 'straight'

# True: control the tip angle as well; False: tip position only
control_tip_angle = True
# control_tip_angle = False

# True: fix nodes (0, 1) to zero; False: fix just the first node
# control_first_edge = False
control_first_edge = True

# tip values to buckle shims - 'BEASTAL' for the BEASTAL scheme, else 'one_to_one'
update_scheme = 'one_to_one'

# dataset_sampling = 'uniform'  # random uniform vals for x, y, angle
# dataset_sampling = 'almost flat'  # flat piece, single measurement
dataset_sampling = 'stress strain'

# exp_start / distance are only given values for the stress-strain dataset; otherwise both are None.
if dataset_sampling != 'stress strain':
    exp_start = distance = None
else:
    exp_start = 180*1e-3/translate_ratio*(0.99)
    distance = 140*1e-3/translate_ratio
#     exp_start = 4*L - 0.1
#     distance = 2*L

In [None]:
# --- k params ---

# Uniform per-shim values; they are only defined when the 'Numerical' stiffness model is selected.
if k_type == 'Numerical':
    k_soft_uniform, k_stiff_uniform = 1.0, 1.5
    thetas_ss_uniform, thresh_uniform = 1/2, 1

In [None]:
# --- Equilibrium params ---

# k_stretch_ratio = 100/L
k_stretch_ratio = 10

# integration settings for the equilibrium state calculation
n_steps = 1_000
T_eq = 0.2  # time for equilibrium calculation, [s]
damping = 0.5  # something*sqrt(k*m)
mass = 1  # [kg]
calc_through_energy = False

# report the resulting time step
print('dt=', float(T_eq/n_steps))

In [None]:
# --- automatic parameters / variables ---
# Per-shim arrays of shape (H, S) derived from the configuration cells above.
if k_type == 'Numerical':
    k_soft  = jnp.ones((H, S), dtype=jnp.float32) * k_soft_uniform  # (Hinges, Shims) stiff. in soft direction, per shim
    k_stiff = jnp.ones((H, S), dtype=jnp.float32) * k_stiff_uniform  # (Hinges, Shims) stiff. in stiff direction, per shim
    thetas_ss = jnp.full((H, S), thetas_ss_uniform, dtype=jnp.float32)  # (Hinges, Shims) rest angles per shim
    thresh = jnp.full((H, S), thresh_uniform, dtype=jnp.float32)  # (Hinges, Shims) buckling thresholds per shim
    file_name = None
else:
    # No explicit stiffness arrays in this mode; a data file name is handed to VariablesClass instead.
    k_soft  = None
    k_stiff = None
    thetas_ss = jnp.full((H, S), 1.03312, dtype=jnp.float32)  # (Hinges, Shims) rest angles per shim
    thresh = jnp.full((H, S), 1.96257, dtype=jnp.float32)  # (Hinges, Shims) buckling thresholds per shim
    file_name = 'Roee_offset3mm_dl75.txt'

if desired_buckle_type == 'random':
    # Requires desired_buckle_rand_key to be set next to desired_buckle_type.
    key = jax.random.PRNGKey(desired_buckle_rand_key)   # seed
    # Draw 0/1 uniformly (maxval is exclusive) and map to -1/+1. Drawing from {-1, 0, 1} and
    # folding 0 into -1 would make -1 twice as likely as +1.
    desired_buckle = 2 * jax.random.randint(key, (H, S), minval=0, maxval=2) - 1
elif desired_buckle_type == 'opposite':
    desired_buckle = -buckle
elif desired_buckle_type == 'straight':
    desired_buckle = buckle
else:
    # Fail here rather than with a NameError on desired_buckle in a later cell.
    raise ValueError(
        f"unknown desired_buckle_type {desired_buckle_type!r}; expected 'random', 'opposite' or 'straight'"
    )

In [None]:
import StructureClass
importlib.reload(StructureClass)
from StructureClass import StructureClass

# --- build geometry (all topology stays in StructureClass) ---
geometry_kwargs = dict(hinges=H, shims=S, L=L)
Strctr = StructureClass(**geometry_kwargs)

# The learning parameters are only built for the BEASTAL update scheme.
if update_scheme == 'BEASTAL':
    Strctr._build_learning_parameters(Nin, Nout)

In [None]:
import VariablesClass
importlib.reload(VariablesClass)
from VariablesClass import VariablesClass

# --- Initiate variables ---
# Collect the material settings once, then hand them to VariablesClass as keywords.
variabs_kwargs = dict(
    k_type=k_type,
    k_soft=k_soft,
    k_stiff=k_stiff,
    thetas_ss=thetas_ss,            # rest/target angles
    thresh=thresh,                  # threshold to buckle shims
    stretch_scale=k_stretch_ratio,  # stretch stiffness scale from the equilibrium params
    file_name=file_name,
)
Variabs = VariablesClass(Strctr, **variabs_kwargs)

In [None]:
import SupervisorClass
importlib.reload(SupervisorClass)
from SupervisorClass import SupervisorClass

# Supervisor built from the training settings; the dataset is created right after construction.
supervisor_options = dict(
    control_tip_angle=control_tip_angle,
    control_first_edge=control_first_edge,
    update_scheme=update_scheme,
)
Sprvsr = SupervisorClass(Strctr, alpha, T_training, desired_buckle, **supervisor_options)
# print('desired_tau_in_t after init ', Sprvsr.desired_tau_in_t)

dataset_options = dict(sampling=dataset_sampling, exp_start=exp_start, distance=distance)
Sprvsr.create_dataset(Strctr, **dataset_options)
# print('tip positions=', Sprvsr.tip_pos_in_t)
# print('desired buckle=', Sprvsr.desired_buckle_arr)

## One shot - choose tip position

In [None]:
importlib.reload(numerical_experiments)

# Tip target: slightly shortened along x and lifted in y by fractions `a` of the edge length,
# with the tip rotated by a fraction `b` of pi.
a=0.01
b=8/100
tip_pos = np.array([Strctr.edges*Strctr.L*(1-a), Strctr.L*(a)])
print('tip_pos=', tip_pos)
# tip_pos = np.array([1.9, 0.01])
tip_angle = np.pi*b
# tip_angle = None

# First relaxation: uses the current Variabs and no explicit initial positions.
State, pos_in_t, _ = numerical_experiments.one_shot(Strctr, Variabs, Sprvsr, T_eq, n_steps, damping, mass,
                                                    calc_through_energy, buckle, tip_pos, tip_angle, init_pos=None)

# Continuation in stretch stiffness: each stage doubles k_stretch_ratio, rebuilds Variabs and
# restarts from the last frame of the previous stage. This replaces four copy-pasted stanzas.
# NOTE: k_stretch_ratio and Variabs are rebound on purpose, exactly as before, so after this cell
# k_stretch_ratio is 2**4 times its value on entry (re-running the cell keeps doubling it).
n_stiffening_stages = 4  # the unpacking into pos_in_t_1..pos_in_t_4 below assumes exactly 4
stage_trajectories = [pos_in_t]
for stage in range(n_stiffening_stages):
    k_stretch_ratio = 2 * k_stretch_ratio
    # --- Initiate variables - stronger k ---
    Variabs = VariablesClass(Strctr,
                             k_type=k_type,
                             k_soft=k_soft,
                             k_stiff=k_stiff,
                             thetas_ss=thetas_ss,           # rest/target angles
                             thresh=thresh,                 # threshold to buckle shims
                             stretch_scale=k_stretch_ratio,
                             file_name=file_name
                             )
    State, stage_pos_in_t, _ = numerical_experiments.one_shot(Strctr, Variabs, Sprvsr, T_eq, n_steps, damping, mass,
                                                              calc_through_energy, buckle, tip_pos, tip_angle,
                                                              init_pos=stage_trajectories[-1][-1])
    stage_trajectories.append(stage_pos_in_t)

# Keep the per-stage names that later cells refer to.
pos_in_t_1, pos_in_t_2, pos_in_t_3, pos_in_t_4 = stage_trajectories[1:]

In [None]:
importlib.reload(plot_funcs)

# Animate the trajectory of the last stiffening stage; the options are gathered in one place.
animation_opts = dict(interval_ms=100, save_path=f"arm_animation_t{0}.gif", fps=2, frames=30)
fig, anim = plot_funcs.animate_arm(pos_in_t_4, Strctr.L, **animation_opts)

# Show inline animation (uses JavaScript/HTML5)
HTML(anim.to_jshtml())

In [None]:
# # --- state (straight chain, unit spacing => rest lengths = 1) ---
# State = StateClass(Variabs, Strctr, Sprvsr, buckle_arr = buckle)  # buckle defaults to +1
# print('State_notip buckle', State.buckle_arr)

In [None]:
# import EquilibriumClass
# importlib.reload(EquilibriumClass)
# from EquilibriumClass import EquilibriumClass
# importlib.reload(dynamics)

# a=0.05
# b=1/9
# # a=0
# # b=0
# tip_pos = np.array([Strctr.edges*Strctr.L*(1-a), Strctr.L*(a)])
# # tip_pos = np.array([1.9, 0.01])
# tip_angle = np.pi*b
# # tip_angle = None

# # --- initialize, no tip movement yet
# if State.pos_arr is not None:  # calculate equilib from state of previous run, if happened
#     Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, calc_through_energy = calc_through_energy,
#                           buckle_arr=buckle, pos_arr=State.pos_arr)
# else:  # use initial flat position
#     Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, calc_through_energy = calc_through_energy,
#                           buckle_arr=buckle)
# final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr,
#                                                                          control_first_edge=Sprvsr.control_first_edge,
#                                                                          tip_pos=tip_pos, tip_angle=tip_angle)
# State._save_data(0, Strctr, final_pos, State.buckle_arr, potential_force_in_t, compute_thetas_if_missing=True,
#                  control_tip_angle = control_tip_angle)
# print('State_notip buckle', State.buckle_arr)
# plot_funcs.plot_arm(State.pos_arr , State.buckle_arr, np.rad2deg(State.theta_arr), Strctr.L,
#                     modality="measurement")
# print('edge len', Strctr.all_edge_lengths(State.pos_arr))
# print('stretch energy', Variabs.k_stretch*(Strctr.all_edge_lengths(State.pos_arr) - Strctr.rest_lengths)**2)  # stretch energy
# print('torque energy', Variabs.torque(State.theta_arr)*Strctr.L)
# print('Fx', State.Fx)

## Desired values (or ADMET stress strain)

In [None]:
import StateClass
importlib.reload(StateClass)
from StateClass import StateClass

# vanilla state, built from the initial buckle array
State = StateClass(Variabs, Strctr, Sprvsr, buckle_arr=buckle)
State.pos_arr

# --- state (straight chain, unit spacing => rest lengths = 1) ---
# Two further independent states with the same inputs:
#   State_meas - measurement state
#   State_des  - desired states a.f.o training set
State_meas, State_des = (StateClass(Variabs, Strctr, Sprvsr, buckle_arr=buckle) for _ in range(2))

## Parameter study

In [None]:
# dist_noise = 0.001
# angle_noise = 1e-2
# tip_pos = Sprvsr.tip_pos_in_t[3]

# k_strtch_vec = [1e2, 2e2, 4e2, 1e3, 2e3, 4e3, 1e4, 2e4, 4e4]
# Fx_afo_pos = np.zeros(len(k_strtch_vec))

# for i, k_stretch_ratio in enumerate(k_strtch_vec):
    
#     print('k_stretch_ratio', k_stretch_ratio)
    
#     # --- Initiate variables ---
#     Variabs = VariablesClass(Strctr,
#                              k_type = k_type,
#                              k_soft=k_soft,
#                              k_stiff=k_stiff,
#                              thetas_ss=thetas_ss,           # rest/target angles
#                              thresh=thresh,                 # threshold to buckle shims
#                              stretch_scale=k_stretch_ratio,            # k_stretch = 50 * max(k_stiff)
#                              file_name = file_name
#                              )
    
#     Sprvsr.create_dataset(Strctr, sampling=dataset_sampling, 
#                           exp_start = exp_start,
#                           distance = distance,
#                           dist_noise = dist_noise,
#                           angle_noise = angle_noise)
    
#     State_des = StateClass(Variabs, Strctr, Sprvsr, buckle_arr = buckle)  # desired states a.f.o training set
    
# #     if i == 0:
# #         Eq_des = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, buckle_arr=Sprvsr.desired_buckle_arr)
# #     else:
# #         Eq_des = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, buckle_arr=Sprvsr.desired_buckle_arr,
# #                                   pos_arr=final_pos)
#     Eq_des = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, buckle_arr=Sprvsr.desired_buckle_arr)
#     final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq_des.calculate_state(Variabs, Strctr, tip_pos=tip_pos,
#                                                                                  tip_angle=Sprvsr.tip_angle_in_t[i])
    
#     State_des._save_data(i, Strctr, pos_arr=final_pos, buckle_arr=State_des.buckle_arr, Forces=potential_force_in_t,
#                          compute_thetas_if_missing=True, control_tip_angle = control_tip_angle)
#     print('State tip forces ', [State_des.Fx, State_des.Fy])
#     print('edges length', Strctr.all_edge_lengths(State_des.pos_arr))
#     print('stretch energy', Variabs.k_stretch*(Strctr.all_edge_lengths(State_des.pos_arr) - Strctr.rest_lengths)**2)  # stretch energy
#     print('torque energy', Variabs.torque(State_des.theta_arr)*Strctr.L)
#     # print('State tip torque ', State_des.tip_torque)
#     Sprvsr.set_desired(final_pos, State_des.tip_pos[-2], State_des.tip_pos[-1], i, State_des.tip_torque)
#     plot_funcs.plot_arm(State_des.pos_arr , State_des.buckle_arr, np.rad2deg(State_des.theta_arr), Strctr.L,
#                     modality='measurement')
#     plt.plot(potential_force_in_t[-1, ::2], '.')
#     plt.plot(potential_force_in_t[-1, 1::2], '.')
#     plt.show()
    
# #     tot_energies = np.zeros(int(T_eq))
# #     rot_energies = np.zeros(int(T_eq))
# #     strtch_energies = np.zeros(int(T_eq))
# #     for i in range(int(T_eq)):
# #         tot_energies[i], rot_energies[i], strtch_energies[i] = Eq_des.energy(Variabs, Strctr, pos_in_t[i])

# #     plt.plot(tot_energies)
# #     plt.plot(rot_energies)
# #     plt.legend(['tot', 'rot', 'stretch'])
# #     plt.yscale('log')
# #     plt.show()

#     Fx_afo_pos[i] = State_des.Fx
    
# # State_meas._save_data(0, Strctr, State_meas.pos_arr, State_meas.buckle_arr, compute_thetas_if_missing=True,
# #                       control_tip_angle = control_tip_angle)
# # print('pre training configuration')
# # plot_funcs.plot_arm(State_meas.pos_arr , State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr),Strctr.L,
# #                     modality='measurement')
# plt.plot(Fx_afo_pos)
# plt.ylabel('Fx')

## Desired values (or ADMET stress strain), continued

In [None]:
# fx = 3 + 't'

In [None]:
# Step the tip through every dataset position with the desired buckle pattern, relax to
# equilibrium at each one (warm-started from the previous solution), and record Fx on the tip.
Fx_afo_pos = np.zeros(T_training)
tip_angle = -0.004  # the same tip angle is imposed at every position

for i, tip_pos in enumerate(Sprvsr.tip_pos_in_t): 
    
#     a=0.1
#     b=-1/7
#     tip_pos = np.array([Strctr.edges*Strctr.L*(1-a), -Strctr.L*(a)])
#     tip_angle = np.pi*b
    
    # The first step starts from the class default positions; later steps continue from the
    # previous equilibrium.
    if i == 0:
        Eq_des = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, buckle_arr=Sprvsr.desired_buckle_arr)
    else:
        Eq_des = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, buckle_arr=Sprvsr.desired_buckle_arr,
                                  pos_arr=final_pos)
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq_des.calculate_state(Variabs, Strctr, tip_pos=tip_pos,
                                                                                 tip_angle=tip_angle)
    
    State_des._save_data(i, Strctr, pos_arr=final_pos, buckle_arr=State_des.buckle_arr, Forces=potential_force_in_t,
                         compute_thetas_if_missing=True, control_tip_angle = control_tip_angle)
    print('State tip forces ', [State_des.Fx, State_des.Fy])
    print('edge lengths ', State_des.edge_lengths)
    State_des.buckle(Variabs, Strctr, i, State_des)
    # print('State tip torque ', State_des.tip_torque)
    # np.nan instead of np.NaN: the capitalised alias was removed in NumPy 2.0.
    Sprvsr.set_desired(final_pos, np.nan, np.nan, i, State_des.tip_torque)
    if i%4==0:  # draw every fourth configuration only
        plot_funcs.plot_arm(State_des.pos_arr , State_des.buckle_arr, np.rad2deg(State_des.theta_arr), Strctr.L,
                        modality='measurement')
#     plt.plot(potential_force_in_t[-1, ::2], '.')
#     plt.plot(potential_force_in_t[-1, 1::2], '.')
#     plt.show()
    
#     tot_energies = np.zeros(int(T_eq))
#     rot_energies = np.zeros(int(T_eq))
#     strtch_energies = np.zeros(int(T_eq))
#     for i in range(int(T_eq)):
#         tot_energies[i], rot_energies[i], strtch_energies[i] = Eq_des.energy(Variabs, Strctr, pos_in_t[i])

#     plt.plot(tot_energies)
#     plt.plot(rot_energies)
#     plt.legend(['tot', 'rot', 'stretch'])
#     plt.yscale('log')
#     plt.show()

    Fx_afo_pos[i] = State_des.Fx
    
# Store and draw the measurement state as it stands before training (no forces are passed here).
State_meas._save_data(0, Strctr, State_meas.pos_arr, State_meas.buckle_arr, compute_thetas_if_missing=True,
                      control_tip_angle = control_tip_angle)
print('pre training configuration')
plot_funcs.plot_arm(State_meas.pos_arr , State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr),Strctr.L,
                    modality='measurement')

In [None]:
# Same arm, translated so that the last row of pos_arr (presumably the tip node) sits at the origin.
pos_rel_to_last = State_des.pos_arr - State_des.pos_arr[-1, :]
plot_funcs.plot_arm(pos_rel_to_last, State_des.buckle_arr, np.rad2deg(State_des.theta_arr), Strctr.L,
                    modality='measurement')

In [None]:
# Sign-flipped Fx against tip x measured from the first dataset point.
tip_x = Sprvsr.tip_pos_in_t[:, 0]
plt.plot(-Sprvsr.tip_pos_in_t[0, 0] + tip_x, -Fx_afo_pos)

In [None]:
# Rows -5 through -2 of the force history, transposed so each row becomes one dotted series.
late_forces = potential_force_in_t[-5:-1]
plt.plot(late_forces.T, '.')

In [None]:
# Sum of the last and third-from-last entries of the final force row.
final_forces = potential_force_in_t[-1]
final_forces[-1] + final_forces[-3]

In [None]:
# Sum of entries 2 and 4 of the final force row.
final_forces = potential_force_in_t[-1]
final_forces[2] + final_forces[4]

In [None]:
## --- Save csv file ---

# The file name is fixed here; it is not derived from the current L or buckle configuration.
sim_csv_name = "L=1_buckle[1 1 1 1 1] belly down.csv"
file_funcs.export_stress_strain_sim(Sprvsr, Fx_afo_pos, L, State_des.buckle_arr, filename=sim_csv_name)

In [None]:
## --- import CSVs and plot ---

for module in (file_funcs, plot_funcs):
    importlib.reload(module)

# NOTE(review): both paths are relative to the working directory and use Windows separators.
# exp_path = "Experimental ADMET Bi-ax\\Oct4\\leave_every_time\\Roee_buckle11111_bellyDown.csv"
exp_path = "Experimental ADMET Bi-ax\\Oct3\\Roee_buckle11111_init190mm_inside150mm_bellyDown.csv"
sim_path = "Simulation stress strain\\Oct6\\L=1_buckle[1 1 1 1 1] belly down.csv"

# Load both tables without plotting them individually, then draw the joint comparison.
exp_df = file_funcs.import_stress_strain_exp_and_plot(exp_path, plot=False)
sim_df = file_funcs.import_stress_strain_sim_and_plot(sim_path, plot=False)

plot_funcs.plot_compare_sim_exp_stress_strain(exp_df, sim_df, translate_ratio, Strctr.hinges)

In [None]:
# buckle_10001 = file_funcs.import_stress_strain_sim_and_plot("Simulation stress strain\\L=1_buckle[1 0 0 0 1] belly down.csv")
# buckle_10101 = file_funcs.import_stress_strain_sim_and_plot("Simulation stress strain\\L=1_buckle[1 0 1 0 1] belly down.csv")
buckle_11111 = file_funcs.import_stress_strain_sim_and_plot("Simulation stress strain\\Oct3\\L=1_buckle[1 1 1 1 1] belly up qstnmrk.csv")

# buckle_10001_tip = 2*(buckle_10001['x_tip'] - buckle_10001['x_tip'][0])/ translate_ratio
# buckle_10001_Fx = buckle_10001['Fx'] * translate_ratio/Strctr.hinges
# buckle_10101_tip = 2*(buckle_10101['x_tip'] - buckle_10101['x_tip'][0])/ translate_ratio
# buckle_10101_Fx = buckle_10101['Fx'] * translate_ratio/Strctr.hinges

# Tip travel measured from the first sample (doubled and divided by translate_ratio);
# force scaled by translate_ratio and divided by the number of hinges.
x_tip = buckle_11111['x_tip']
buckle_11111_tip = 2*(x_tip - x_tip[0])/ translate_ratio
buckle_11111_Fx = buckle_11111['Fx'] * translate_ratio/Strctr.hinges

# plt.plot(buckle_10001_tip, buckle_10001_tip)
# plt.plot(buckle_10101_tip, buckle_10101_Fx, '.')
plt.plot(buckle_11111_tip, buckle_11111_Fx, '.')
plt.xlabel('pos [mm]')
plt.ylabel('Fx [N]')

In [None]:
# Rescale the simulated curve the same way as the buckle_11111 cell, then overlay it on the experiment.
sim_x = sim_df['x_tip']
sim_tip = 2*(sim_x - sim_x[0])/ translate_ratio
sim_Fx = sim_df['Fx'] * translate_ratio/Strctr.hinges

plt.plot(exp_df["Position (mm)"], exp_df["Load2 (N)"])
plt.plot(sim_tip, sim_Fx, '.')

## Single hinge stress strain

In [None]:
# Build the reference torque-vs-angle loop (thetas, torques) that the simulated sweep is compared with.
skip = 8  # stride used to subsample the experimental curves

if k_type == 'Experimental':
    # Curves come from the data file; torque_of_theta and k_of_theta are not used again in this cell.
    thetas, torques, ks, torque_of_theta, k_of_theta = file_funcs.build_torque_stiffness_from_file(file_name, savgol_window=9)
    # First half: subsampled in reverse order. Second half: thetas and ks run forward, while
    # torques reuse the reversed samples with the sign flipped.
    # NOTE(review): the second halves are built differently for thetas/ks and torques - confirm this is intended.
    thetas = np.concatenate((thetas[::-skip], thetas[::skip]))
    torques = np.concatenate((torques[::-skip], -torques[::-skip]))
    ks = np.concatenate((ks[::-skip], ks[::skip]))
elif k_type == 'Numerical':
    # thetas = np.concatenate((np.arange(-np.pi/2, np.pi/2, np.pi/(T_training/2)), -np.arange(-np.pi/2, np.pi/2, np.pi/(T_training/2))))
    # Sweep from -pi/2 towards +pi/2 and then back with the sign flipped.
    thetas = np.concatenate((np.arange(-np.pi/2, np.pi/2, np.pi/(T_training/2)), -np.arange(-np.pi/2, np.pi/2, np.pi/(T_training/2))))
    torques = np.zeros(np.size(thetas))
    B = -np.ones((Strctr.hinges, Strctr.shims))  # every shim starts at -1
    for i, theta in enumerate(thetas):
        # thetas is 1-D here, so T holds a single element (shape (1,)).
        T = thetas[i, None]
        # NOTE(review): the resulting shape depends on how Variabs stores thetas_ss - confirm the broadcast below.
        TH = Variabs.thetas_ss[:, None]
        # spring constant is position dependent
        stiff_mask = ((B == 1) & (T < TH)) | ((B == -1) & (T > -TH))  # thetas are counter-clockwise
        k_rot_state = jnp.where(stiff_mask, Variabs.k_stiff, Variabs.k_soft)  # (H,S)
        torques[i] = jnp.sum(k_rot_state * (T - B*TH)) 
        # Buckle update: a shim flips once the angle passes its threshold in the opposite direction.
        B_nxt = np.zeros((Strctr.hinges, Strctr.shims))
        for j in range(Strctr.hinges):
            for l in range(Strctr.shims):
                # NOTE(review): T has one element, so T[j] with j >= 1 looks like it raises IndexError
                # whenever there is more than one hinge - confirm before using the Numerical branch.
                if B[j, l] == 1 and T[j] < -Variabs.thresh[j]:  # buckle up since thetas are CCwise
                    B_nxt[j, l] = -1
                    print('buckled up, theta=', T[j])
                elif B[j, l] == -1 and T[j] > Variabs.thresh[j]:  # buckle down, thetas are CCwise
                    B_nxt[j, l] = 1
                    print('buckled down, theta=', T[j])
                else:
                    B_nxt[j, l] = B[j, l]
        B = B_nxt
# NOTE(review): thetas is undefined if k_type is neither 'Experimental' nor 'Numerical'.
thetas_degs = np.rad2deg(thetas)

In [None]:
# Reference torque-angle loop on an explicit figure/axes pair.
torque_fig, torque_ax = plt.subplots()
torque_ax.plot(thetas, torques)
plt.show()

In [None]:
import EquilibriumClass
importlib.reload(EquilibriumClass)
from EquilibriumClass import EquilibriumClass

# Sweep the tip through the (sign-flipped) reference angles, relax to equilibrium at each one
# and record tip forces, torque and total energy for comparison with the reference curve.
torque_sim = np.zeros(np.size(thetas))
Fx_sim = np.zeros(np.size(thetas))
Fy_sim = np.zeros(np.size(thetas))
energy_sim = np.zeros(np.size(thetas))
for i, theta in enumerate(-thetas):
    print('theta', np.rad2deg(theta))
    tip_angle = -np.pi + theta
    # Tip sits one edge length from the point (L, 0), in the direction given by theta.
    tip_pos = Strctr.L*np.array([1, 0]) + Strctr.L*np.array([np.cos(theta), np.sin(theta)])
    # print('tip_angle', tip_angle)
    # print('tip_pos', tip_pos)
    
    # --- initialize, no tip movement yet
    if State.pos_arr is not None:  # calculate equilib from state of previous run, if happened
        Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, calc_through_energy = calc_through_energy,
                              buckle_arr=State.buckle_arr, pos_arr=State.pos_arr)
    else:  # use initial flat position
        # Use the initial `buckle` array: the name `buckle_arr` is never defined in the notebook,
        # so this branch raised NameError.
        Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, calc_through_energy = calc_through_energy,
                              buckle_arr=buckle)
    # Only the tip position is imposed here; tip_angle is used below for the torque, not as a constraint.
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr, tip_pos=tip_pos,
                                                                             tip_angle=None)
    theta_from_dynamics = vmap(lambda h: Strctr._get_theta(final_pos, h))(jnp.arange(Strctr.hinges))
    print('theta_from_dynamics = ', theta_from_dynamics)
    # Rotational stiffness at the relaxed angles, computed for printing only.
    if k_type == 'Experimental':
        theta_eff = State.buckle_arr[0] * theta_from_dynamics              # (H,S)
        k_rot_state = Variabs.k(theta_eff)
    else:
        stiff_mask = ((State.buckle_arr[0] == 1) & (theta_from_dynamics < Variabs.thetas_ss)) \
        | ((State.buckle_arr[0] == -1) & (theta_from_dynamics > -Variabs.thetas_ss))  # thetas are counter-clockwise    
        k_rot_state = jnp.where(stiff_mask, Variabs.k_stiff, Variabs.k_soft)  # (H,S)
    print('k=', k_rot_state)
    print('delta theta=', theta_from_dynamics - State.buckle_arr[0]*Variabs.thetas_ss)
    State._save_data(i, Strctr, final_pos, State.buckle_arr, Forces=potential_force_in_t, compute_thetas_if_missing=True,
                     control_tip_angle=control_tip_angle)
    # NOTE(review): State_meas is passed as the last argument, whereas the desired-values sweep
    # passes the state itself (State_des.buckle(..., State_des)) - confirm which is intended.
    State.buckle(Variabs, Strctr, i, State_meas)
    plot_funcs.plot_arm(State.pos_arr , State.buckle_arr, np.rad2deg(State.theta_arr), Strctr.L,
                        modality="measurement") 
    torque_sim[i] = helpers_builders.torque(tip_angle, State.Fx, State.Fy)
    Fx_sim[i] = State.Fx
    Fy_sim[i] = State.Fy
    energy_sim[i], _, _ = Eq.energy(Variabs, Strctr, final_pos)
    print('torque=', State.tip_torque)
    print('Forces=', potential_force_in_t[-1])
    print('tip angle=', tip_angle)
    
# Reference torque and sign-flipped simulated torque on the same axes.
plt.plot(thetas, torques, thetas, -torque_sim)

In [None]:
plt.plot(thetas, torques)  # experimental

# Model torque evaluated on the sweep with the first and last four samples trimmed.
theta_mid = thetas[4:-4]
model_torque = -Variabs.k(theta_mid)*(theta_mid-Variabs.thetas_ss).reshape(-1)
plt.plot(-theta_mid, model_torque, '--')  # energy

plt.plot(thetas[1:], -torque_sim[1:], '--')  # forces directly
# plt.plot(thetas[1:], torque_from_energy[1:], thetas[1:], torque_sim[1:], '--') 

plt.xlabel(r'$\theta$')
plt.ylabel(r'$\tau$')
# plt.legend(['Experimental', 'Numerical'])
plt.legend(['experimental', 'jax energy', 'jax forces'])
plt.ylim([-6, 6])

In [None]:
# Spot check at a single sweep sample (index 18): its angle, the model torque there, and the thresholds.
probe_theta = thetas[18]
print('theta', probe_theta)
print('torque at theta', Variabs.torque(probe_theta))
print('desiree buckle theta', Variabs.thresh)

In [None]:
import EquilibriumClass
importlib.reload(EquilibriumClass)
from EquilibriumClass import EquilibriumClass

# Energy along the most recent relaxation trajectory.
# The frame count is taken from pos_in_t itself: T_eq is a duration in seconds (0.2 in the
# equilibrium params), so int(T_eq) is 0 and the energy curve would always come out empty.
n_frames = len(pos_in_t)
energies = np.zeros(n_frames)
for i in range(n_frames):
    energies[i], _, _ = Eq.energy(Variabs, Strctr, pos_in_t[i])
    
plt.plot(energies)
plt.yscale('log')

# energies_nxt, _, _ = Eq.calculate_energy_in_t(Variabs, Strctr, pos_in_t)
# plt.plot(energies_nxt)
# plt.yscale('log')

# Final configuration stored in State, drawn after the energy curve.
plot_funcs.plot_arm(State.pos_arr , State.buckle_arr, np.rad2deg(State.theta_arr), Strctr.L,
                    modality="measurement")

In [None]:
importlib.reload(plot_funcs)

# Animate whatever trajectory pos_in_t currently holds; the options are gathered in one place.
animation_opts = dict(interval_ms=100, save_path=f"arm_animation_t{0}.gif", fps=2, frames=30)
fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, **animation_opts)

# Show inline animation (uses JavaScript/HTML5)
HTML(anim.to_jshtml())

In [None]:
# import EquilibriumClass
# importlib.reload(EquilibriumClass)
# from EquilibriumClass import EquilibriumClass

# # --- initialize, no tip movement yet
# if State.pos_arr is not None:  # calculate equilib from state of previous run, if happened
#     Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, calc_through_energy = calc_through_energy,
#                         buckle_arr=buckle, pos_arr=State.pos_arr)
# else:  # use initial flat position
#     Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, calc_through_energy = calc_through_energy,
#                           buckle_arr=buckle)
# final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr, tip_pos=tip_pos,
#                                                                          tip_angle=tip_angle)
# State._save_data(0, Strctr, final_pos, State.buckle_arr, potential_force_in_t, compute_thetas_if_missing=True,
#                  control_tip_angle = control_tip_angle)
# print('State_notip buckle', State.buckle_arr)
# plot_funcs.plot_arm(State.pos_arr , State.buckle_arr, np.rad2deg(State.theta_arr), Strctr.L,
#                     modality="measurement")

## Training

In [None]:
# Reload the StateClass module and re-import the class from it so that edits to
# StateClass.py take effect without restarting the kernel. After these three
# lines the name `StateClass` refers to the class, not the module.
import StateClass
importlib.reload(StateClass)
from StateClass import StateClass
importlib.reload(helpers_builders)

# --- state (straight chain, unit spacing => rest lengths = 1) ---
# Two separate state objects, both started from the same `buckle` array:
# one for the measurement phase and one for the update phase of training.
State_meas = StateClass(Variabs, Strctr, Sprvsr, buckle_arr = buckle)  # measurement-phase state
State_update = StateClass(Variabs, Strctr, Sprvsr, buckle_arr = buckle)  # update-phase state

# --- initialize, no tip movement yet
# Equilibrium solver seeded with the measurement state's buckle pattern and positions.
Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, calc_through_energy = calc_through_energy,
                      buckle_arr=State_meas.buckle_arr, pos_arr=State_meas.pos_arr)

# Record the initial configuration at step 0 (no force array is passed here).
State_meas._save_data(0, Strctr, State_meas.pos_arr, State_meas.buckle_arr, compute_thetas_if_missing=True,
                      control_tip_angle = control_tip_angle)
print('pre training configuration')
plot_funcs.plot_arm(State_meas.pos_arr , State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr), Strctr.L,
                    modality="measurement")

In [None]:
importlib.reload(plot_funcs)  # pick up local edits to the plotting helpers without a kernel restart

# Rich objects created inside a loop are only rendered when passed to display();
# a bare `HTML(...)` expression is shown only as the last statement of a cell.
from IPython.display import display

# Training loop. Every step t runs two phases with the equilibrium solver `Eq`:
#   MEASUREMENT - solve with the tip taken from State_meas, record the state, compute the loss;
#   UPDATE      - compute the tip update, solve with the tip taken from State_update,
#                 record the state, then let the shims buckle and rebuild `Eq`.
for t in range(Sprvsr.T):
    print('t=', t)

    ## MEASUREMENT

    # --- equilibrium ---
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr,
                                                                             tip_pos=State_meas.tip_pos,
                                                                             tip_angle=State_meas.tip_angle)
#     edge_lengths = vmap(lambda e: Strctr._get_edge_length(final_pos, e))(jnp.arange(Strctr.edges))
#     print('edge lengths', helpers_builders.numpify(edge_lengths))

    # --- save sizes and plot ---
    State_meas._save_data(t, Strctr, final_pos, State_meas.buckle_arr, potential_force_in_t,
                          compute_thetas_if_missing=True, control_tip_angle = control_tip_angle)
    plot_funcs.plot_arm(final_pos, State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr), Strctr.L,
                        modality = "measurement")
    print('Forces', potential_force_in_t[-1])
    print('Fx on tip, measurement', State_meas.Fx)
    print('Fy on tip, measurement', State_meas.Fy)
    print('torque on tip, measurement', State_meas.tip_torque)
    plt.plot(potential_force_in_t[-1,:], '.')
    plt.show()

    # --- animation of the measurement trajectory ---
    # Built once per step (the original built and saved the identical animation twice).
    fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=20, save_path=f"arm_animation_t{t}.gif",
                                       fps=4, frames=20)
    display(HTML(anim.to_jshtml()))  # explicit display: we are inside a loop
    plt.close(fig)  # avoid accumulating one open figure per training step

    # --- loss ---
    Sprvsr.calc_loss(t, State_meas.Fx, State_meas.Fy, tau=State_meas.tip_torque)
    print('desired Fx', Sprvsr.desired_Fx_in_t[t])
    print('desired Fy', Sprvsr.desired_Fy_in_t[t])
    print('loss', Sprvsr.loss)

    ## UPDATE

    if t == 0:
        # First step: no previous update exists, so explicit "previous" values are supplied.
        Sprvsr.calc_update_tip(t, Strctr, Variabs, State_meas,
                               current_tip_angle = State_meas.tip_angle,
                               prev_tip_update_pos = np.array([Strctr.L*Strctr.edges, 0.1]),
                               prev_tip_update_angle = 0.0)
    else:
        Sprvsr.calc_update_tip(t, Strctr, Variabs, State_meas,
                               current_tip_angle = State_meas.tip_angle)
    print('update_tip', Sprvsr.tip_pos_update_in_t)
    print('update angle', Sprvsr.tip_angle_update_in_t)

    # --- equilibrium ---
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr,
                                                                             tip_pos=State_update.tip_pos,
                                                                             tip_angle=State_update.tip_angle)

    # --- save sizes and plot ---
    State_update._save_data(t, Strctr, final_pos, State_update.buckle_arr, potential_force_in_t,
                            compute_thetas_if_missing=True, control_tip_angle = control_tip_angle)
    plot_funcs.plot_arm(final_pos, State_update.buckle_arr, np.rad2deg(State_update.theta_arr),
                        Strctr.L, modality = "update")
    print('pre buckle', State_update.buckle_arr)
    # print('energy', Eq.energy(Variabs, Strctr, final_pos)[-1])

    # --- shims buckle ---
    State_update.buckle(Variabs, Strctr, t, State_measured = State_meas)

    # Rebuild the solver from the post-buckle state for the next step.
    # calc_through_energy is passed explicitly so the solver keeps the same mode as the
    # initial `Eq`; omitting it would silently fall back to the class default after step 0.
    Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, calc_through_energy = calc_through_energy,
                          buckle_arr=helpers_builders.jaxify(State_update.buckle_arr),
                          pos_arr=helpers_builders.jaxify(State_update.pos_arr))
    print('post buckle', State_update.buckle_arr)
    # print('post buckle update', State_update.buckle_arr)
    # print('energy', Eq.energy(Variabs, Strctr, final_pos)[-1])
    plot_funcs.plot_arm(final_pos, State_update.buckle_arr, np.rad2deg(State_update.theta_arr),
                        Strctr.L, modality = "update")

In [None]:
# Torque error (desired minus measured) for the step index `t` left over from
# the training loop, shown as a one-element array.
tau_error = Sprvsr.desired_tau_in_t[t] - State_meas.tip_torque
np.array([tau_error])

In [None]:
# Squared loss summed over axis 1 of Sprvsr.loss_in_t, plotted against its first axis
# (presumably axis 0 is the training step and axis 1 the loss components - confirm).
squared_loss = np.abs(Sprvsr.loss_in_t**2)
loss_per_step = np.sum(squared_loss, axis=1)
plt.plot(loss_per_step)
plt.xlabel('t')
plt.ylabel('Loss')

In [None]:
# Buckle history of the first shim (index 0 on the second axis) of every hinge.
# The hinge count is read from the array's first axis instead of being hard-coded
# to five, so the cell keeps working when H in the params cell changes.
n_hinges = State_update.buckle_in_t.shape[0]
for hinge in range(n_hinges):
    plt.plot(State_update.buckle_in_t[hinge, 0, :])
plt.xlabel('t')
plt.ylabel('buckle')
plt.legend([f'hinge {hinge + 1}' for hinge in range(n_hinges)])

In [None]:
# Animate the trajectory currently held in pos_in_t. Both pos_in_t and `t` are
# whatever the training loop left behind, so this cell must run after it.
fig, anim = plot_funcs.animate_arm(
    pos_in_t,
    Strctr.L,
    interval_ms=20,
    save_path=f"arm_animation_t{t}.gif",
    fps=4,
    frames=20,
)

# Last expression of the cell: rendered inline as a JavaScript/HTML5 player.
HTML(anim.to_jshtml())

In [None]:
fontsize = 16
theta_label = r'$\theta\,\left[rad\right]$'

# Torque vs. angle: `torques` (solid) against `torque_sim` (dashed), first sample skipped.
plt.plot(thetas[1:], torques[1:], '-', thetas[1:], torque_sim[1:], '--', linewidth=3.0)
plt.legend(['experimental', 'numerical'], fontsize=fontsize)
plt.xlabel(theta_label, fontsize=fontsize)
plt.ylabel(r'$\tau$', fontsize=fontsize)
plt.show()  # finish this figure; without it the following curves are drawn on the same axes

# Simulated energy, Fx and Fy vs. angle - one figure per quantity.
# Each figure holds a single simulated curve, so the legend has the single entry
# 'numerical' (a two-entry list would label the simulated curve 'theoretical').
for sim_values, first_idx, line_style, y_label in [
    (energy_sim, 1, '--', r'$E$'),
    (Fx_sim, 0, '-', r'$Fx$'),
    (Fy_sim, 0, '-', r'$Fy$'),
]:
    plt.plot(thetas[first_idx:], sim_values[first_idx:], line_style, linewidth=3.0)
    plt.legend(['numerical'], fontsize=fontsize)
    plt.xlabel(theta_label, fontsize=fontsize)
    plt.ylabel(y_label, fontsize=fontsize)
    plt.show()

# Forces over the last 18 stored rows: even columns are plotted as x, odd columns as y.
plt.plot(potential_force_in_t[-18:,::2].T, '.')
plt.ylabel('forces x')
plt.show()
plt.plot(potential_force_in_t[-18:,1::2].T, '.')
plt.ylabel('forces y')
plt.show()

# Velocities over the last 18 stored rows: last-axis index 0 is plotted as x, index 1 as y.
# np.shape(vel_in_t)
plt.plot(vel_in_t[-18:,:,0].T, '.')
plt.ylabel('velocities x')
plt.show()
plt.plot(vel_in_t[-18:,:,1].T, '.')
plt.ylabel('velocities y')
plt.show()

# Last expression of the cell: displays the buckle history stored on `State`.
State.buckle_in_t

## Show what is static and what is not

In [None]:
# import jax
# import jax.numpy as jnp
# import equinox as eqx

# params, static = eqx.partition(Variabs, eqx.is_inexact_array)

# # 1) Quick look: print the leaves that are considered “params”
# print("Param leaves (count):", len(jax.tree_util.tree_leaves(params)))

# # 2) See shapes/dtypes of param leaves
# leaf_summary = jax.tree_util.tree_map(
#     lambda x: (getattr(x, "shape", None), getattr(x, "dtype", None)), params
# )
# print(leaf_summary)

# # 3) If you want a compact structure printout:
# print("PARAM STRUCT:", jax.tree_util.tree_structure(params))
# print("STATIC STRUCT:", jax.tree_util.tree_structure(static))
