In [None]:
import numpy as np
np.set_printoptions(precision=3, suppress=True)

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import importlib
import numpy as np
import copy

from jax import grad, jit, vmap
from jax.experimental.ode import odeint
from IPython.display import HTML
from time import time

from typing import TYPE_CHECKING, Callable, Union, Optional

from VariablesClass import VariablesClass
from StructureClass import  StructureClass
from StateClass import StateClass
from EquilibriumClass import EquilibriumClass
from config import CFG

import plot_funcs, colors, helpers_builders, learning_funcs, file_funcs, numerical_experiments

# Project colour palette; the returned colour list becomes matplotlib's default
# property cycle so every subsequent line plot uses it.
colors_lst, red, custom_cmap = colors.color_scheme()
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=colors_lst)

In [None]:
# Re-import the project modules so edits to their source files are picked up
# without restarting the kernel. Each `import X` / `reload(X)` / `from X import X`
# triple rebinds the name X back to the class defined in the freshly reloaded module.
import config
importlib.reload(config)
from config import CFG

import StructureClass
importlib.reload(StructureClass)
from StructureClass import StructureClass

import VariablesClass
importlib.reload(VariablesClass)
from VariablesClass import VariablesClass

import SupervisorClass
importlib.reload(SupervisorClass)
from SupervisorClass import SupervisorClass

# StateClass and EquilibriumClass are instantiated inside the training loop,
# so they are reloaded as well; otherwise edits to them would be silently ignored.
import StateClass
importlib.reload(StateClass)
from StateClass import StateClass

import EquilibriumClass
importlib.reload(EquilibriumClass)
from EquilibriumClass import EquilibriumClass

importlib.reload(plot_funcs)
importlib.reload(file_funcs);  # trailing ';' suppresses the module repr output

In [None]:
# --- geometry: all topology is owned by StructureClass ---
Strctr = StructureClass(CFG, update_scheme=CFG.Train.update_scheme)

# --- variables built on top of the structure ---
Variabs = VariablesClass(Strctr, CFG)

# --- supervisor, plus its dataset generated with the configured sampling mode ---
Sprvsr = SupervisorClass(Strctr, CFG, supress_prints=True)
Sprvsr.create_dataset(Strctr, CFG, CFG.Train.dataset_sampling)

# Hand-picked tip target that overrides entry 1 of the generated tip trajectory
# (the 4th target used in the Feb 2 presentation).
# tip_pos presumably holds (x, y) — confirm against SupervisorClass.
tip_pos = np.array([0.175, 0.08])
tip_angle = -0.3
Sprvsr.tip_pos_in_t[1], Sprvsr.tip_angle_in_t[1] = tip_pos, tip_angle

In [None]:
# All 2**4 = 16 buckle patterns of the four shims (each entry -1 or +1).
# Every pattern is used both as an initial and as a desired configuration.
from itertools import product

init_buckles = list(product([-1, 1], repeat=4))
desired_buckles = copy.copy(init_buckles)

if CFG.Train.dataset_sampling == 'tile':
    # Number of (tip position, tip angle) combinations in the tile grid.
    # tip_pos may be a single position, e.g. np.array([0.17, 0.08]), or an (N, 2)
    # array such as np.array([[0.17, 0.08], [0.17, -0.08]]); tip_angle may be a
    # scalar or a 1-D array. atleast_2d / size handle both forms, whereas
    # np.shape(...)[0] raises IndexError for a scalar angle and counts the two
    # coordinates of a single position as two instances.
    N_instances = np.atleast_2d(tip_pos).shape[0] * np.size(tip_angle)

In [None]:
# Training sweep: for every (initial, desired) buckle-pattern pair that is not
# skipped, run the measure -> loss -> tip-update -> buckle loop, export a gif and
# a csv of the run, and record the final loss in final_loss[k, l].


def _skip_pair(k, l):
    """Return True when the (init_buckles[k], desired_buckles[l]) pair is not trained.

    Pairs on or below the diagonal (k >= l) are skipped, as are all pairs whose
    initial pattern is init_buckles[0] or init_buckles[1].
    NOTE(review): the k in {0, 1} exclusion is not covered by the skip message
    printed below — confirm it is intended and not a leftover from a resumed run.
    """
    return k >= l or k == 0 or k == 1


# Number of pairs actually trained; used only in the progress message.
loop_length = sum(not _skip_pair(k, l)
                  for k in range(len(init_buckles))
                  for l in range(len(desired_buckles)))

# final_loss[k, l]: loss at the end of the init_buckles[k] -> desired_buckles[l] run.
# Skipped pairs are never written and stay 0.
final_loss = np.zeros([len(init_buckles), len(desired_buckles)])

loop_count = 0
breath_for_gif = 3  # extra frames repeating the final state at the end of a successful run's gif
for k, init_buckle_tup in enumerate(init_buckles):
    for l, desired_buckle_tup in enumerate(desired_buckles):
        if _skip_pair(k, l):
            print('init and desired buckles are the same or symmetrical to another configuration, skipping')
            continue

        # buckle patterns as (4, 1) column arrays
        init_buckle = np.array(init_buckle_tup).reshape(4, 1)
        desired_buckle = np.array(desired_buckle_tup).reshape(4, 1)

        print(f'init_buckle={init_buckle_tup}')
        print(f'desired_buckle={desired_buckle_tup}')

        # target buckle pattern for this run
        Sprvsr.desired_buckle_arr = desired_buckle

        # --- states: measurement, update and desired modalities ---
        State_meas = StateClass(Strctr, Sprvsr, buckle_arr=init_buckle)  # measurement modality
        State_update = StateClass(Strctr, Sprvsr, buckle_arr=init_buckle)  # update modality
        State_des = StateClass(Strctr, Sprvsr, buckle_arr=Sprvsr.desired_buckle_arr)  # desired state

        # --- initialize at t = 0, no tip movement yet ---
        Eq_meas = EquilibriumClass(Strctr, CFG, buckle_arr=State_meas.buckle_arr, pos_arr=State_meas.pos_arr)  # measurement
        Eq_des = EquilibriumClass(Strctr, CFG, buckle_arr=Sprvsr.desired_buckle_arr, pos_arr=State_des.pos_arr)  # desired state
        State_meas._save_data(0, Strctr, State_meas.pos_arr, State_meas.buckle_arr)
        State_update._save_data(0, Strctr, State_meas.pos_arr, State_update.buckle_arr)
        State_des._save_data(0, Strctr, State_des.pos_arr, State_des.buckle_arr)

        buckle_bool = False
        meas_count = 0  # for tile
        trained = False  # set once the loss threshold is reached

        # The last breath_for_gif time slots are left free for the repeated final frames.
        for t in range(1, Sprvsr.T - breath_for_gif):
            print('t=', t)

            ## MEASUREMENT

            # --- equilibrium - measured & desired ---
            Eq_meas = EquilibriumClass(Strctr, CFG, buckle_arr=State_meas.buckle_arr, pos_arr=None)  # measurement
            Eq_des = EquilibriumClass(Strctr, CFG, buckle_arr=Sprvsr.desired_buckle_arr, pos_arr=None)  # desired state
            # With 'specified' sampling the measurement is recomputed only at t == 1 or
            # right after a buckle event; otherwise the previous step's results are reused.
            # All other sampling modes recompute it every step.
            if t == 1 or buckle_bool or CFG.Train.dataset_sampling != 'specified':
                final_pos, pos_in_t, _, F_theta = Eq_meas.calculate_state(Variabs, Strctr, Sprvsr,
                                                                          init_pos=None,
                                                                          tip_pos=Sprvsr.tip_pos_in_t[t],
                                                                          tip_angle=Sprvsr.tip_angle_in_t[t])
                final_pos_des, pos_in_t_des, _, F_theta_des = Eq_des.calculate_state(Variabs, Strctr, Sprvsr,
                                                                                     init_pos=None,
                                                                                     tip_pos=Sprvsr.tip_pos_in_t[t],
                                                                                     tip_angle=Sprvsr.tip_angle_in_t[t])

            meas_count += 1
            # --- save measured & desired data ---
            State_meas._save_data(t, Strctr, final_pos, State_meas.buckle_arr, F_theta)
            State_des._save_data(t, Strctr, final_pos_des, State_des.buckle_arr, F_theta_des)

            Sprvsr.set_desired(final_pos_des, State_des.Fx, State_des.Fy, t)

            # --- loss ---
            Sprvsr.calc_loss(Variabs, t, State_meas.Fx, State_meas.Fy)

            ## UPDATE
            if t == 1:
                # first step: the previous update is the t = 0 dataset tip
                Sprvsr.calc_update_tip(t, Strctr, Variabs, State_meas, current_tip_pos=Sprvsr.tip_pos_in_t[1],
                                       current_tip_angle=Sprvsr.tip_angle_in_t[1],
                                       prev_tip_update_pos=Sprvsr.tip_pos_in_t[0],
                                       prev_tip_update_angle=Sprvsr.tip_angle_in_t[0],
                                       correct_for_total_angle=True, correct_for_coil=True,
                                       correct_for_cut_origin=True)
            else:
                Sprvsr.calc_update_tip(t, Strctr, Variabs, State_meas, current_tip_pos=Sprvsr.tip_pos_in_t[t],
                                       current_tip_angle=Sprvsr.tip_angle_in_t[t],
                                       correct_for_total_angle=True, correct_for_coil=True,
                                       correct_for_cut_origin=True)

            # --- equilibrium at the updated tip, starting from the update state's positions ---
            # NOTE(review): Eq_meas was built with State_meas.buckle_arr; this relies on
            # the measurement and update buckle arrays agreeing — confirm.
            final_pos_update, pos_in_t_update, _, F_theta_update = Eq_meas.calculate_state(Variabs, Strctr, Sprvsr,
                                                                                           State_update.pos_arr,
                                                                                           tip_pos=Sprvsr.tip_pos_update_in_t[t],
                                                                                           tip_angle=Sprvsr.tip_angle_update_in_t[t])

            # --- save update data (with the torque of the update equilibrium, not the measurement's) ---
            State_update._save_data(t, Strctr, final_pos_update, State_update.buckle_arr, F_theta_update)

            # --- shims buckle ---
            buckle_bool = State_update.buckle(Variabs, Strctr, t, State_measured=State_meas)
            if buckle_bool:
                meas_count = 0

            # break if training succeeded
            if Sprvsr.loss_MSE <= 10**(-5):
                print(f'successful training')
                trained = True
                # Repeat the final state in slots t+1 .. t+breath_for_gif so the gif lingers on it.
                for i in range(breath_for_gif):
                    State_update._save_data(t + i + 1, Strctr, final_pos_update, State_update.buckle_arr, F_theta_update)
                break

        # Animate (frames moved to the leading axis)
        pos_in_t_meas = np.moveaxis(State_meas.pos_arr_in_t, 2, 0)
        pos_in_t_udpate = np.moveaxis(State_update.pos_arr_in_t, 2, 0)
        # NOTE(review): buckle frames come from State_meas while positions come from
        # State_update; the repeated success frames are only saved for State_update.
        buckle_in_t = np.moveaxis(State_meas.buckle_in_t, 2, 0)

        init = 1
        # Exclusive slice end: include step t and, after success, the repeated frames.
        final = t + 1 + (breath_for_gif if trained else 0)

        save_path_gif = f"gif_init_{np.array2string(init_buckle.T[0])}_desired{np.array2string(desired_buckle.T[0])}.gif"
        fig, anim = plot_funcs.animate_arm_w_arcs(pos_in_t_udpate[init:final, :, :], Strctr.L, frames=final, interval_ms=400,
                                                  save_path=save_path_gif, fps=2, show_inline=False,
                                                  buckle_traj=buckle_in_t[init:final, :, :])
        plt.close(fig)  # one figure per pair; free it so the sweep does not accumulate open figures

        # training data as a function of t
        save_path_csv = f"final_loss_{Sprvsr.loss_MSE}_init_{np.array2string(init_buckle.T[0])}_desired{np.array2string(desired_buckle.T[0])}.csv"
        file_funcs.export_training_csv(save_path_csv, Strctr, Sprvsr, T=t, State_meas=State_meas,
                                       State_update=State_update, State_des=None)

        # final loss
        final_loss[k, l] = Sprvsr.loss_MSE

        loop_count += 1
        print(f'finished {loop_count} out of {loop_length}')

In [None]:
# Inspect the tip positions used in the update step. Sprvsr is shared by every
# run of the sweep, so after the loop this shows the last trained pair's values.
Sprvsr.tip_pos_update_in_t

In [None]:
# Inspect the dataset tip positions used in the measurement step (entry 1 was
# overridden by hand when the supervisor was built).
Sprvsr.tip_pos_in_t

In [None]:
# Heat map of the final training loss for every (initial, desired) buckle pair:
# rows index init_buckles, columns index desired_buckles, i.e. final_loss[k, l].
# Pairs skipped by the training sweep were never trained and show 0.
fig, ax = plt.subplots(figsize=(8, 7))

im = ax.imshow(final_loss, cmap=custom_cmap)

# One tick per buckle pattern on each axis, labelled with the pattern itself.
row_labels = [str(b) for b in init_buckles]
col_labels = [str(b) for b in desired_buckles]

ax.set_yticks(range(len(row_labels)))
ax.set_yticklabels(row_labels, fontsize=6)
ax.set_xticks(range(len(col_labels)))
ax.set_xticklabels(col_labels, rotation=90, fontsize=6)
ax.set_ylabel('initial buckle pattern')
ax.set_xlabel('desired buckle pattern')

cax = ax.inset_axes((1.05, 0, 0.05, 1.0))  # (x, y, width, height) in axes coords
cbar = fig.colorbar(im, cax=cax)
cbar.set_label(r'$\|\mathcal{L}\|$')

In [None]:
# Persist the full (init pattern x desired pattern) loss matrix. The old
# reshape(4, 3) belonged to a previous 12-entry experiment and raises
# ValueError on the current matrix.
with open('final_loss.npy', 'wb') as f:
    np.save(f, final_loss)

In [None]:
# Read the saved loss matrix back from disk and print it.
a = np.load('final_loss.npy')
print(a)