In [1]:
import os
import warnings
from dataclasses import dataclass

import ipywidgets as widgets
import numpy as np
import plotly.graph_objects as go
from IPython.display import clear_output, display
from ipywidgets import HBox, VBox
from numpy import cos, sin

## C FFI

In [4]:
import ctypes
from ctypes import POINTER, c_double, c_size_t

# A bare library name (no slash) is resolved only via the dynamic loader's search
# path, which on Linux does not include the current working directory. Prefer an
# explicit path to a copy sitting next to the notebook, and fall back to the bare
# name so a library installed on the loader path keeps working.
_CLIB_NAME = "cpython.so"
_clib_local_path = os.path.abspath(_CLIB_NAME)
clibrary = ctypes.CDLL(_clib_local_path if os.path.isfile(_clib_local_path) else _CLIB_NAME)

# bool set_frames_py(double *table_x, double *table_y, double *table_z,
#                    double *slider_x, double *slider_y, double *slider_z,
#                    double *state_A, double *state_B, size_t num_of_frames)
# The caller in this notebook allocates zero-initialised buffers for the six
# table/slider arrays, passes the two poses as double[6], and reads the buffers
# back after a truthy return.
clibrary.set_frames_py.argtypes = [
    POINTER(c_double),  # table_x
    POINTER(c_double),  # table_y
    POINTER(c_double),  # table_z
    POINTER(c_double),  # slider_x
    POINTER(c_double),  # slider_y
    POINTER(c_double),  # slider_z
    POINTER(c_double),  # state_A
    POINTER(c_double),  # state_B
    c_size_t            # num_of_frames
]
clibrary.set_frames_py.restype = ctypes.c_bool

## Visualisation

In [None]:
AXIS_RANGE = [-15, 15]
# Triangles (i, j, k) indexing into one frame's table points; every face shares vertex 6.
MESH_INDICES = [(0, 1, 6), (2, 3, 6), (4, 5, 6), (1, 4, 6), (3, 0, 6), (5, 2, 6)]
STEP_SIZE = 0.0000001
TOTAL_FRAMES = 10

POSITION_LIMIT = 10   # x/y/z sliders span [-POSITION_LIMIT, POSITION_LIMIT]
ANGLE_LIMIT = np.pi   # phi/theta/psi sliders span [-ANGLE_LIMIT, ANGLE_LIMIT]


def _make_pose_widgets(state_label, position_limit=POSITION_LIMIT, angle_limit=ANGLE_LIMIT, step=STEP_SIZE):
    """Build the six pose sliders for one robot state.

    Args:
        state_label: Suffix used in each description, e.g. ``'A'`` -> ``'X, state A'``.
        position_limit: Symmetric bound for the X, Y and Z sliders.
        angle_limit: Symmetric bound for the Phi, Theta and Psi sliders.
        step: Slider step size.

    Returns:
        A list of ``FloatSlider`` in the order x, y, z, phi, theta, psi, all starting at 0.
    """
    axes = (
        ('X', position_limit), ('Y', position_limit), ('Z', position_limit),
        ('Phi', angle_limit), ('Theta', angle_limit), ('Psi', angle_limit),
    )
    return [
        widgets.FloatSlider(value=0, min=-limit, max=limit, step=step,
                            description=f'{axis}, state {state_label}')
        for axis, limit in axes
    ]


WIDGETS_A = _make_pose_widgets('A')
WIDGETS_B = _make_pose_widgets('B')

# Individual names kept for callers that address a single slider directly.
x_widget_A, y_widget_A, z_widget_A, phi_widget_A, theta_widget_A, psi_widget_A = WIDGETS_A
x_widget_B, y_widget_B, z_widget_B, phi_widget_B, theta_widget_B, psi_widget_B = WIDGETS_B

frame_widget = widgets.IntSlider(value=0, min=0, max=TOTAL_FRAMES-1, step=1, description='Current Frame')

def find_element_by_tag(fig, tag):
    """Return the index of the first trace in ``fig.data`` tagged with ``tag``.

    A trace counts as tagged when the first entry of its ``customdata`` equals
    ``tag``. Traces without ``customdata`` (``None`` or empty) are skipped.

    Args:
        fig: A plotly figure (anything exposing an iterable ``data`` of traces).
        tag: Value to compare against ``customdata[0]``.

    Returns:
        The trace index, or ``None`` if no trace matches.
    """
    for index, trace in enumerate(fig.data):
        customdata = trace.customdata
        # Explicit checks: plotly may store customdata as a numpy array, whose
        # truthiness raises ValueError when it holds more than one element.
        if customdata is None or len(customdata) == 0:
            continue
        if customdata[0] == tag:
            return index
    return None


In [None]:
def _base_vertex(index, radius=7.0, pitch=np.deg2rad(60), offset=np.deg2rad(15) / 2):
    """Return vertex ``index`` of the base ring as ``[x, y, 0]``.

    The angle ``pitch * index + offset`` is measured from the +y axis
    (x uses sin, y uses cos).
    """
    angle = pitch * index + offset
    return [radius * np.sin(angle), radius * np.cos(angle), 0]


# Six points 60 degrees apart on a radius-7 circle in the z = 0 plane,
# rotated by 7.5 degrees; shape (6, 3).
base_points = np.array([_base_vertex(index) for index in range(6)])

def initialize_plot():
    """Create the interactive figure that shows the start and selected frames.

    Trace order is significant because ``update_plot`` addresses traces by index:
    0 = initial table points, 1 = selected-frame table points,
    2 = initial table mesh, 3 = selected-frame table mesh.

    Returns:
        A ``go.FigureWidget`` with four empty traces and fixed, cubic 3D axes.
    """
    fig = go.FigureWidget()
    # One layout call: the scene settings were previously split across two calls.
    fig.update_layout(
        title='6-DOF robot with vertical parallel rails',
        scene=dict(
            xaxis=dict(range=AXIS_RANGE, title='X Axis'),
            yaxis=dict(range=AXIS_RANGE, title='Y Axis'),
            zaxis=dict(range=AXIS_RANGE, title='Z Axis'),
            aspectmode='cube',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )

    # In 'markers' mode no line is drawn, so the colour must be set on the marker
    # (a ``line`` colour would leave the points in the default colorway).
    fig.add_scatter3d(mode='markers', marker=dict(size=5, color='blue'), name='Initial Table Points')
    fig.add_scatter3d(mode='markers', marker=dict(size=5, color='red'), name='Final Table Points')
    fig.add_mesh3d(color='blue', opacity=0.5, name='Initial Table Mesh')
    fig.add_mesh3d(color='red', opacity=0.5, name='Final Table Mesh')
    return fig

# Entries per frame in the flat buffers filled by set_frames_py.
_TABLE_POINTS_PER_FRAME = 8
_SLIDER_POINTS_PER_FRAME = 6


def _pose_buffer(state_widgets):
    """Return a ``c_double[6]`` with the slider values in list order (x, y, z, phi, theta, psi)."""
    return (ctypes.c_double * 6)(*(slider.value for slider in state_widgets))


def _frame_points(xs, ys, zs, frame, points_per_frame=_TABLE_POINTS_PER_FRAME):
    """Return the points of ``frame`` as an array of shape ``(points_per_frame, 3)``.

    ``xs``/``ys``/``zs`` are flat buffers storing the frames back to back,
    ``points_per_frame`` consecutive entries per frame.
    """
    start = frame * points_per_frame
    stop = start + points_per_frame
    return np.column_stack((xs[start:stop], ys[start:stop], zs[start:stop]))


def _set_trace_points(trace, points):
    """Assign the columns of an ``(N, 3)`` array to a trace's x, y and z."""
    trace.x = points[:, 0]
    trace.y = points[:, 1]
    trace.z = points[:, 2]


def update_plot(change):
    """Recompute the motion between poses A and B and redraw frame 0 and the selected frame.

    Intended as an ``ipywidgets`` ``observe`` callback; ``change`` is ignored, so
    it can also be called directly with ``None``. If the C routine reports
    failure, a ``RuntimeWarning`` is emitted and the figure is left unchanged.
    """
    state_A = _pose_buffer(WIDGETS_A)
    state_B = _pose_buffer(WIDGETS_B)

    # Zero-initialised output buffers filled by the C routine.
    table_len = TOTAL_FRAMES * _TABLE_POINTS_PER_FRAME
    slider_len = TOTAL_FRAMES * _SLIDER_POINTS_PER_FRAME
    table_x, table_y, table_z = ((ctypes.c_double * table_len)() for _ in range(3))
    slider_x, slider_y, slider_z = ((ctypes.c_double * slider_len)() for _ in range(3))

    success = clibrary.set_frames_py(
        table_x, table_y, table_z,
        slider_x, slider_y, slider_z,
        state_A, state_B,
        TOTAL_FRAMES
    )
    if not success:
        # Previously ignored silently; surface it so a stale plot is explained.
        warnings.warn("set_frames_py reported failure; plot not updated", RuntimeWarning)
        return

    initial_frame_points = _frame_points(table_x, table_y, table_z, 0)
    final_frame_points = _frame_points(table_x, table_y, table_z, frame_widget.value)
    mesh_i, mesh_j, mesh_k = (list(column) for column in zip(*MESH_INDICES))

    with plot.batch_update():
        # Traces 0/1: point clouds for the initial and selected frames.
        _set_trace_points(plot.data[0], initial_frame_points)
        _set_trace_points(plot.data[1], final_frame_points)

        # Traces 2/3: meshes for the same two frames.
        for mesh, points in ((plot.data[2], initial_frame_points), (plot.data[3], final_frame_points)):
            _set_trace_points(mesh, points)
            mesh.i = mesh_i
            mesh.j = mesh_j
            mesh.k = mesh_k


plot = initialize_plot()

# Starting poses in WIDGETS_A / WIDGETS_B order: (x, y, z, phi, theta, psi).
# They are applied before any observer is attached, so no redraw fires yet.
for slider, start_value in zip(WIDGETS_A, (-2.58, 1.77, 1.94, 0.71, 1.22, 0.47)):
    slider.value = start_value
for slider, start_value in zip(WIDGETS_B, (0.0, -1.45, -2.58, 0.20, -0.61, -0.41)):
    slider.value = start_value

# Redraw whenever any pose slider or the frame selector changes.
for widget in [*WIDGETS_A, *WIDGETS_B, frame_widget]:
    widget.observe(update_plot, names='value')

# First draw with the starting poses.
update_plot(None)

# Controls row (state A column, state B column, frame selector), then the figure.
controls = HBox([VBox(WIDGETS_A), VBox(WIDGETS_B), frame_widget])
display(controls)
display(plot)