In [None]:
import numpy as np
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import importlib
import numpy as np
from jax import grad, jit, vmap
from jax.experimental.ode import odeint
from IPython.display import HTML

from typing import TYPE_CHECKING, Callable, Union, Optional

from VariablesClass import VariablesClass
from StructureClass import  StructureClass
from StateClass import StateClass
from EquilibriumClass import EquilibriumClass

import plot_funcs, colors, dynamics, helpers_builders, learning_funcs

In [None]:
## Params

H, S = 5, 1  # number of hinges, number of shims per hinge
k_type = 'Experimental'  # 'Numerical' (uniform stiffnesses set below) or 'Experimental' (loaded via file_name)

Nin = 3  # tip position in (x, y) and its angle
Nout = 3  # Fx, Fy, torque, all on tip

control_tip_angle = True

update_scheme = 'one_to_one'

# dataset_sampling = 'Uniform'  # random uniform vals for x, y, angle
dataset_sampling = 'Flat'  # flat piece, single measurement

dt = 1e-3
# Kept under its own name so it is not confused with the equilibrium-solver `n_steps` defined further down.
# NOTE(review): time_points spacing is (t1 - t0) / (n_time_points - 1) ~= 1.6, not dt -- confirm intent.
t0, t1, n_time_points = 0.0, 1600.0, int(1/dt)
time_points = jnp.linspace(t0, t1, n_time_points)

k_soft_uniform = 1.0
k_stiff_uniform = 4.0
thetas_ss_uniform = 1/6
thresh_uniform = 1/4

buckle = jnp.ones((H, S), dtype=jnp.int32)  # initial buckle state, per shim (+1 everywhere)
# rows = jnp.array([0, 2])
# cols = jnp.array([0, 0])     # for example, first shim of each row
# buckle = buckle.at[rows, cols].set(-1)

k_stretch_ratio = 1e2

T_training = 4  # total training set time (not time to reach equilibrium during every step)

alpha = 0.025  # learning rate

# desired_buckle_type = 'random'
# desired_buckle_rand_key = 169
desired_buckle_type = 'opposite'

# for equilibrium state calculation
n_steps = int(2e4)
T_eq = 1400.0  # time for equilibrium calculation
# Printed value is T_eq / n_steps (presumably the equilibrium solver's time step), not `dt` above.
print('dt=', float(T_eq/n_steps))
damping = 5.0
mass = 4.0

In [None]:
# --- parameters / variables ---
# All per-shim arrays below have shape (Hinges, Shims).
if k_type == 'Numerical':
    k_soft = jnp.ones((H, S), dtype=jnp.float32) * k_soft_uniform  # stiffness in soft direction, per shim
    k_stiff = jnp.ones((H, S), dtype=jnp.float32) * k_stiff_uniform  # stiffness in stiff direction, per shim
    thetas_ss = jnp.full((H, S), thetas_ss_uniform, dtype=jnp.float32)  # rest angles per shim
    thresh = jnp.full((H, S), thresh_uniform, dtype=jnp.float32)  # buckling thresholds per shim
    file_name = None
elif k_type == 'Experimental':
    # Stiffnesses are left unset here; VariablesClass receives file_name instead
    # (presumably it loads the measured stiffness data from that file -- confirm in VariablesClass).
    k_soft = None
    k_stiff = None
    thetas_ss = jnp.full((H, S), 1.03312, dtype=jnp.float32)  # rest angles per shim
    thresh = jnp.full((H, S), 1.96257, dtype=jnp.float32)  # buckling thresholds per shim
    file_name = 'Roee_offset3mm_dl75.txt'
else:
    # Fail loudly instead of silently treating a typo as 'Experimental'.
    raise ValueError(f"unknown k_type {k_type!r}; expected 'Numerical' or 'Experimental'")

if desired_buckle_type == 'random':
    key = jax.random.PRNGKey(desired_buckle_rand_key)   # seed
    # Draw each shim's desired state uniformly from {-1, +1}. Sampling {-1, 0, 1} and mapping 0 -> -1
    # would make -1 twice as likely as +1. (Changes the pattern produced by a given seed.)
    desired_buckle = jax.random.choice(key, jnp.array([-1, 1], dtype=jnp.int32), shape=(H, S))
elif desired_buckle_type == 'opposite':
    desired_buckle = -buckle  # flip every shim relative to the initial buckle state
else:
    # Otherwise desired_buckle would be left undefined (or stale from a previous run of this cell).
    raise ValueError(f"unknown desired_buckle_type {desired_buckle_type!r}; expected 'random' or 'opposite'")

In [None]:
# Re-read StructureClass from disk so module edits are picked up without restarting the kernel.
StructureClass = importlib.reload(importlib.import_module("StructureClass")).StructureClass

# --- build geometry (all topology stays in StructureClass) ---
Strctr = StructureClass(hinges=H, shims=S, L=1)

# Only the 'BEASTAL' update scheme needs the extra learning parameters (Nin inputs, Nout outputs).
if update_scheme == 'BEASTAL':
    Strctr._build_learning_parameters(Nin, Nout)

In [None]:
# Re-read VariablesClass from disk so module edits are picked up without restarting the kernel.
VariablesClass = importlib.reload(importlib.import_module("VariablesClass")).VariablesClass

# --- Initiate variables ---
Variabs = VariablesClass(
    Strctr,
    k_type=k_type,
    k_soft=k_soft,
    k_stiff=k_stiff,
    thetas_ss=thetas_ss,            # rest angles per shim
    thresh=thresh,                  # threshold to buckle shims
    stretch_scale=k_stretch_ratio,  # edge-stretch stiffness scale (k_stretch_ratio from Params)
    file_name=file_name,
)

In [None]:
# Re-read SupervisorClass from disk so module edits are picked up without restarting the kernel.
SupervisorClass = importlib.reload(importlib.import_module("SupervisorClass")).SupervisorClass

# Supervisor: learning rate, number of training steps, target buckle pattern, and the tip dataset.
Sprvsr = SupervisorClass(
    Strctr, alpha, T_training, desired_buckle,
    control_tip_angle=control_tip_angle,
    update_scheme=update_scheme,
)
Sprvsr.create_dataset(Strctr, sampling=dataset_sampling)

for label, value in (('tip positions=', Sprvsr.tip_pos_in_t),
                     ('desired buckle=', Sprvsr.desired_buckle_arr)):
    print(label, value)

## Desired values

In [None]:
# Re-read StateClass from disk so module edits are picked up without restarting the kernel.
StateClass = importlib.reload(importlib.import_module("StateClass")).StateClass

# --- measurement state, starting from the initial buckle pattern set in Params (all +1 there) ---
State_meas = StateClass(Variabs, Strctr, Sprvsr, buckle_arr=buckle)

## One shot - choose tip position

In [None]:
# import EquilibriumClass
# importlib.reload(EquilibriumClass)
# from EquilibriumClass import EquilibriumClass

# a=0.15
# b=1/11
# # tip_pos = None
# tip_pos = np.array([Strctr.edges*Strctr.L*(1-a), Strctr.L*(a)])

# tip_angle = np.pi*b

# # --- state (straight chain, unit spacing => rest lengths = 1) ---
# State = StateClass(Variabs, Strctr, Sprvsr, buckle_arr = buckle)  # buckle defaults to +1
# print('State_notip buckle', State.buckle_arr)

# # --- initialize, no tip movement yet
# Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, buckle_arr=buckle)
# final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr, tip_pos=tip_pos,
#                                                                          tip_angle=tip_angle)
# State._save_data(0, Strctr, final_pos, State.buckle_arr, potential_force_in_t, compute_thetas_if_missing=True)
# print('State_notip buckle', State.buckle_arr)
# plot_funcs.plot_arm(State.pos_arr , State.buckle_arr, np.rad2deg(State.theta_arr), Strctr.L,
#                     modality="measurement")

In [None]:
# Energy_vec = np.zeros((3, n_steps))
# for t in range(n_steps):
#     Energy_vec[:, t] = Eq.energy(Variabs, Strctr, pos_in_t[t])
# plt.plot(Energy_vec.T)

# plt.plot(potential_force_in_t[2500:,:], '.')
# plt.show()

In [None]:
# importlib.reload(plot_funcs)

# fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=100, save_path=f"arm_animation_t{0}.gif", fps=2,
#                                    frames=30)

# # Show inline animation (uses JavaScript/HTML5)
# HTML(anim.to_jshtml())

## Desired values (continued)

In [None]:
# Solve for the equilibrium under the *desired* buckle pattern at every tip pose of the dataset and store it
# as the training target. Every solve after the first starts from the previous solve's final positions.
warm_start = {}
for i, tip_pos in enumerate(Sprvsr.tip_pos_in_t):
    Eq_des = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, buckle_arr=Sprvsr.desired_buckle_arr,
                              **warm_start)
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq_des.calculate_state(
        Variabs, Strctr, tip_pos=tip_pos, tip_angle=Sprvsr.tip_angle_in_t[i])

    # Last two components of the final force vector are passed as the desired forces
    # (presumably Fx, Fy on the tip -- confirm against SupervisorClass.set_desired).
    last_forces = potential_force_in_t[-1]
    Sprvsr.set_desired(final_pos, last_forces[-2], last_forces[-1], i)
    warm_start = {'pos_arr': final_pos}

State_meas._save_data(0, Strctr, State_meas.pos_arr, State_meas.buckle_arr, compute_thetas_if_missing=True)
print('pre training configuration')
plot_funcs.plot_arm(State_meas.pos_arr, State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr), Strctr.L,
                    modality='measurement')

## Training

In [None]:
# Re-read StateClass and helpers_builders from disk so module edits are picked up without a kernel restart.
StateClass = importlib.reload(importlib.import_module("StateClass")).StateClass
importlib.reload(helpers_builders)

# --- fresh measurement and update states, both starting from the initial buckle pattern set in Params ---
State_meas, State_update = (StateClass(Variabs, Strctr, Sprvsr, buckle_arr=buckle) for _ in range(2))

# --- equilibrium solver seeded with the measurement state; no tip movement yet ---
Eq = EquilibriumClass(
    Strctr, T_eq, n_steps, damping, mass,
    buckle_arr=State_meas.buckle_arr,
    pos_arr=State_meas.pos_arr,
)

State_meas._save_data(0, Strctr, State_meas.pos_arr, State_meas.buckle_arr, compute_thetas_if_missing=True)
print('pre training configuration')
plot_funcs.plot_arm(State_meas.pos_arr, State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr), Strctr.L,
                    modality="measurement")

In [None]:
# Re-read SupervisorClass from disk.
# NOTE: the existing `Sprvsr` object was constructed earlier and keeps using the class from before this reload;
# re-run the cell that builds `Sprvsr` to pick up changes.
SupervisorClass = importlib.reload(importlib.import_module("SupervisorClass")).SupervisorClass

In [None]:
importlib.reload(plot_funcs)
from IPython.display import display


def _show_arm_animation(pos_in_t, t):
    """Animate one equilibrium trajectory, save it as ``arm_animation_t{t}.gif`` and render it inline.

    Args:
        pos_in_t: positions over the course of one equilibrium solve, as returned by ``Eq.calculate_state``.
        t: training step; used only in the GIF file name.
    """
    fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=20, save_path=f"arm_animation_t{t}.gif",
                                       fps=4, frames=20)
    # A bare HTML(...) expression inside a loop body is discarded, so it must go through display() explicitly.
    display(HTML(anim.to_jshtml()))
    # Close the figure so the loop does not accumulate open figures (and stray static frames at cell end).
    plt.close(fig)


for t in range(Sprvsr.T):
    print('t=', t)

    ## MEASUREMENT

    # --- position tip at the dataset pose of step t ---
    State_meas.position_tip(Sprvsr, t)

    # --- equilibrium ---
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr, tip_pos=State_meas.tip_pos,
                                                                             tip_angle=State_meas.tip_angle)

    # --- save sizes and plot ---
    State_meas._save_data(t, Strctr, final_pos, State_meas.buckle_arr, potential_force_in_t, compute_thetas_if_missing=True)
    plot_funcs.plot_arm(final_pos, State_meas.buckle_arr, np.rad2deg(State_meas.theta_arr), Strctr.L,
                        modality="measurement")
    print('Forces', potential_force_in_t[-1])
    print('Fx on tip, measurement', State_meas.Fx)
    print('Fy on tip, measurement', State_meas.Fy)
    plt.plot(potential_force_in_t[-1, :], '.')  # final force vector, one marker per component
    plt.show()
    # Build, save and show the measurement animation once per step.
    _show_arm_animation(pos_in_t, t)
    print('torque on tip, measurement', State_meas.tip_torque)

    # --- loss ---
    Sprvsr.calc_loss(State_meas.Fx, State_meas.Fy, t)
    print('desired Fx', Sprvsr.desired_Fx_in_t[t])
    print('desired Fy', Sprvsr.desired_Fy_in_t[t])
    print('loss', Sprvsr.loss)

    ## UPDATE

    if t == 0:
        # There is no previous tip update at the first step, so one is supplied explicitly.
        # NOTE(review): (L * edges, 0.1) looks like the straight-chain tip x with a small y offset -- confirm.
        Sprvsr.calc_update_tip(t, Strctr, Variabs, State_meas,
                               current_tip_angle=State_meas.tip_angle,
                               prev_tip_update_pos=np.array([Strctr.L*Strctr.edges, 0.1]),
                               prev_tip_update_angle=0.0)
    else:
        Sprvsr.calc_update_tip(t, Strctr, Variabs, State_meas,
                               current_tip_angle=State_meas.tip_angle)
    print('update_tip', Sprvsr.tip_pos_update_in_t)
    print('update angle', Sprvsr.tip_angle_update_in_t)

    # --- position tip at the update pose ---
    State_update.position_tip(Sprvsr, t, "update")

    # --- equilibrium ---
    final_pos, pos_in_t, vel_in_t, potential_force_in_t = Eq.calculate_state(Variabs, Strctr, tip_pos=State_update.tip_pos,
                                                                             tip_angle=State_update.tip_angle)

    # --- save sizes and plot ---
    State_update._save_data(t, Strctr, final_pos, State_update.buckle_arr, potential_force_in_t,
                            compute_thetas_if_missing=True)
    plot_funcs.plot_arm(final_pos, State_update.buckle_arr, np.rad2deg(State_update.theta_arr), Strctr.L, modality="update")
    print('pre buckle', State_update.buckle_arr)

    # --- shims buckle ---
    State_update.buckle(Variabs, Strctr, t, State_measured=State_meas)

    # Rebuild the solver so the next step's measurement starts from the post-buckle pattern and current positions.
    Eq = EquilibriumClass(Strctr, T_eq, n_steps, damping, mass, buckle_arr=helpers_builders.jaxify(State_update.buckle_arr),
                          pos_arr=helpers_builders.jaxify(State_update.pos_arr))
    print('post buckle', State_update.buckle_arr)
    # Same positions as the pre-buckle plot above, redrawn with the post-buckle pattern.
    plot_funcs.plot_arm(final_pos, State_update.buckle_arr, np.rad2deg(State_update.theta_arr),
                        Strctr.L, modality="update")

In [None]:
# U = np.array([[1, 0, -1, 0, 0],
#                [1, 0, 0, -1, 0],
#                [1, 0, 0, 0, -1],
#                [0, 1, -1, 0, 0],
#                [0 ,1, 0, -1, 0],
#                [0, 1, 0, 0, -1]])
# U_dagger = np.linalg.pinv(U)
# U_dagger.T

In [None]:
# Per-step training loss: sum over output components of |loss|^2.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(np.sum(np.abs(Sprvsr.loss_in_t**2), axis=1))
loss_ax.set_xlabel('t')
loss_ax.set_ylabel('Loss')

In [None]:
# Buckle state of the first shim (index 0) of every hinge over the training steps.
# Loops over all hinges instead of hard-coding five, so the plot follows H from the Params cell.
buckle_fig, buckle_ax = plt.subplots()
for hinge in range(State_update.buckle_in_t.shape[0]):
    buckle_ax.plot(State_update.buckle_in_t[hinge, 0, :], label=f'hinge {hinge + 1}')
buckle_ax.set_xlabel('t')
buckle_ax.set_ylabel('buckle')
buckle_ax.set_title('Buckle state of shim 0 per hinge')
buckle_ax.legend();

In [None]:
# Animate the last equilibrium solved in the training loop, i.e. the UPDATE solve of the final step.
# Saved under a distinct name: arm_animation_t{t}.gif already holds that step's MEASUREMENT animation,
# which reusing the same path would silently overwrite with a different trajectory.
fig, anim = plot_funcs.animate_arm(pos_in_t, Strctr.L, interval_ms=20, save_path=f"arm_animation_update_t{t}.gif",
                                   fps=4, frames=20)
final_animation = HTML(anim.to_jshtml())
plt.close(fig)  # avoid an extra static frame being rendered below the animation

# Show inline animation (uses JavaScript/HTML5)
final_animation