In [2]:
import mujoco
import mediapy as media

import os
import numpy as np

from brax.mjx.base import State as MjxState 
from mujoco import mjx
from brax.io import html, mjcf, model

In [5]:
class Controller():
    """Position-target controller that replays a recorded peg trajectory.

    The trajectory ``self.pose`` holds one target peg pose per row. Per the
    note in ``get_initial_pose`` each row has 7 values: a position followed
    by a quaternion (presumably ``[x, y, z, qw, qx, qy, qz]`` given the wxyz
    note; confirm against the CSV producer). For the first
    ``STABILIZATION_TIME`` seconds of simulated time, row 0 is held. After
    that, each call to ``get_control_input`` sends the next row.
    """

    # Offset (along the pose's local z) passed to ``get_pose_above`` to turn
    # a peg pose into the flange / end-effector pose.
    FLANGE_OFFSET = 0.1058
    # Simulated seconds during which row 0 is held so the simulation settles.
    STABILIZATION_TIME = 0.4

    def __init__(self, record_target_history = False, csv_file: str = None):
        """Create a controller.

        Args:
            record_target_history: if True, every target sent to the robot is
                also written into ``target_pose_history`` (a buffer supplied
                via ``init_target_pose_history``).
            csv_file: optional path to a comma-separated trajectory file with
                one header row. It is loaded into ``self.pose``.
        """
        # Always define ``pose`` so the ``is None`` check in get_initial_pose
        # works when no CSV is given (previously this raised AttributeError).
        self.pose = None
        if csv_file is not None:
            self.pose = np.loadtxt(csv_file, delimiter=",", skiprows=1)
        # Next trajectory row to send after stabilization (row 0 is the one
        # held during stabilization).
        self.curr_target_idx = 1
        self.target_pose_history = None
        # Next history row to write. It starts at 1, so the first recorded
        # target goes to row 1. Presumably this aligns each target with the
        # state recorded on the following simulator step.
        self.target_pose_history_size = 1

        self.simulation_frequency = 200
        self.record_target_history = record_target_history

    def init_target_pose_history(self, target: np.ndarray):
        """Use ``target`` as the buffer that recorded targets are written into.

        The write position is rewound as well. A buffer handed over for a new
        run is then filled from the start instead of continuing after the
        previous run's last row.
        """
        self.target_pose_history = target
        self.target_pose_history_size = 1

    def set_target_trajectory(self, target: np.ndarray):
        """Replace the trajectory and restart replay from row 1."""
        self.pose = target
        self.curr_target_idx = 1

    def set_simulation_frequency(self, frequency: int):
        """Store the control-step frequency (Hz) reported by the simulator."""
        self.simulation_frequency = frequency

    def get_initial_pose(self, data: mujoco.MjData):
        """Write the flange pose for trajectory row 0 into ``data.qpos``.

        ``qpos[0:3]`` receives the position and ``qpos[3:6]`` the axis-angle
        vector of the orientation. If no trajectory is set, this prints a
        message and leaves ``data`` untouched.
        """
        if self.pose is None:
            print("No target pose is set.")
            return
        # The flange pose is the peg pose shifted along its local z axis.
        flange_pose = get_pose_above(self.pose[0, :], self.FLANGE_OFFSET)

        data.qpos[0:3] = flange_pose[0:3]
        # MuJoCo quaternions are w, x, y, z; convert to an axis-angle vector.
        data.qpos[3:6] = quaternion_to_axis_angle_vector(
            flange_pose[3], flange_pose[4], flange_pose[5], flange_pose[6])

    def get_control_input(self, data: mujoco.MjData):
        """Set ``data.ctrl`` to the next position target.

        Returns:
            0 while targets remain. Once the trajectory is exhausted, it
            returns the (nonzero) index one past the last row and leaves
            ``data.ctrl`` unchanged.
        """
        if data.time < self.STABILIZATION_TIME:
            # Hold the first target while the simulation stabilizes.
            self._send_target(data, self.pose[0, :])
        elif self.curr_target_idx < self.pose.shape[0]:
            self._send_target(data, self.pose[self.curr_target_idx, :])
            self.curr_target_idx += 1
        else:
            return self.curr_target_idx
        return 0

    def _send_target(self, data, peg_pose):
        """Command ``peg_pose``. If recording is enabled, also store it in the history."""
        # Position control: only the joint targets need to be provided.
        data.ctrl = self.get_ee_pose(peg_pose)

        if not self.record_target_history:
            return
        # Stop once the buffer is full. Otherwise the last simulator step
        # would write one row past the end and raise IndexError.
        if self.target_pose_history_size >= len(self.target_pose_history):
            return
        self.target_pose_history[self.target_pose_history_size] = peg_pose
        self.target_pose_history_size += 1

    def get_ee_pose(self, peg_pose):
        """Return the end-effector control target for a target peg pose.

        The result concatenates the target position (ctrl for the slide
        joints) and the axis-angle vector of the target orientation (ctrl
        for the hinge joints, in radians).
        """
        target = get_pose_above(peg_pose, self.FLANGE_OFFSET)
        target_orientation = np.array(quaternion_to_axis_angle_vector(
            target[3], target[4], target[5], target[6]))
        return np.concatenate((target[0:3], target_orientation), axis=0)


In [8]:
class Simulator():
    """MuJoCo simulation harness that runs a Controller and records data.

    The harness loads the model and applies global option overrides. It
    advances physics to ``i / framerate`` seconds at each control step ``i``.
    At every step it records sensor data, joint velocities and
    accelerations, and a peg-insertion metric. Video recording, display and
    saving, and plotting are toggled with ``with_options``.
    """

    def __init__(self, xml_path = "env/env.xml"):
        """Load the model at ``xml_path`` and allocate the recording buffers."""
        self.model = mujoco.MjModel.from_xml_path(xml_path)
        # Global contact-parameter overrides.
        # NOTE(review): MuJoCo uses o_solimp / o_friction only when the
        # contact "override" enable flag is set (presumably in the XML); confirm.
        self.model.opt.o_solimp[:] = [0.95, 0.99, 0.0001, 0.1, 1]
        self.model.opt.o_friction[:] = [1.5, 1.5, 0.005 , 0.0001, 0.0001]
        # Alternative tuning: o_solimp = [0.9, 0.95, 0.001, 0.5, 2]
        self.model.opt.timestep = 1e-4
        self.model.opt.gravity[2] = 0
        self.data = mujoco.MjData(self.model)
        self.renderer = mujoco.Renderer(self.model, 720, 960)

        # Visualization options. Uncomment to show contact frames and forces
        # or to make bodies transparent.
        options = mujoco.MjvOption()
        mujoco.mjv_defaultOption(options)
        # options.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True
        # options.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True
        # options.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True
        self.mj_option = options

        # Scale of contact visualization elements.
        self.model.vis.scale.contactwidth = 0.05
        self.model.vis.scale.contactheight = 0.03
        self.model.vis.scale.forcewidth = 0.05
        self.model.vis.map.force = 0.05

        # Simulation settings
        self.times = 1                      # playback speed multiplier for videos
        self.duration = 40                  # (seconds)
        self.framerate = 50                 # control/record steps per second (Hz)
        self.frames = []
        self.n_steps = int(self.duration * self.framerate)
        n_steps = self.n_steps

        # Data recording buffers, one row per control step.
        self.i = 0              # current step index
        self.sim_time = np.zeros(n_steps)
        self.position = np.zeros((n_steps, 3))
        self.orientation = np.zeros((n_steps, 4))
        self.velocity = np.zeros((n_steps, self.model.nv))
        self.acceleration = np.zeros((n_steps, self.model.nv))
        self.force = np.zeros((n_steps, 3))
        self.torque = np.zeros((n_steps, 3))
        self.target_pose_history = np.zeros((n_steps, 7))

        # Peg corner sites "peg/peg_bottom_1" ... "peg/peg_bottom_<k>". k is
        # read from the model's first numeric field.
        # NOTE(review): mj_name2id returns -1 for a missing name, which would
        # silently index the last site; confirm every name exists.
        sites_on_corner = int(self.model.numeric_data[0])
        self.peg_site_indexes = [
            mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_SITE, "peg/peg_bottom_" + str(i))
            for i in range(1, 1 + sites_on_corner)
        ]
        self.sites_underwater_ratio = np.zeros(n_steps)

        # Options (see with_options). SAVE_VIDEO must have a default because
        # run() reads it unconditionally.
        self.RECORD_VIDEO = False
        self.SHOW_VIDEO = False
        self.SAVE_VIDEO = False
        self.FIGURE = False
        self.FIGURE_TITLE = None

    def with_options(self,
                     RECORD_VIDEO: bool = None,
                     SHOW_VIDEO: bool = None,
                     SAVE_VIDEO: bool = None,
                     FIGURE: bool = None,
                     FIGURE_TITLE: str = None):
        """Update any option that is not None and return ``self`` for chaining."""
        if RECORD_VIDEO is not None: self.RECORD_VIDEO = RECORD_VIDEO
        if SHOW_VIDEO is not None: self.SHOW_VIDEO = SHOW_VIDEO
        if SAVE_VIDEO is not None: self.SAVE_VIDEO = SAVE_VIDEO
        if FIGURE is not None: self.FIGURE = FIGURE
        if FIGURE_TITLE is not None: self.FIGURE_TITLE = FIGURE_TITLE

        return self

    def run(self, controller: Controller, resume_last = False, start_plot_time = 0.4, video_filename=None):
        """Run the simulation with ``controller`` and return the recorded data.

        The run stops when the controller reports the end of its trajectory
        (a nonzero return from ``get_control_input``) or after ``n_steps``
        control steps, whichever comes first.

        Args:
            controller: provides the initial pose and per-step control input.
            resume_last: if True, continue from the current state and step
                index instead of resetting.
            start_plot_time: seconds of data to skip in the outputs. One
                further second is skipped as well (see below).
            video_filename: path for the video written when SAVE_VIDEO is set.

        Returns:
            ``[table, frames]``. ``table`` has one row per step with the
            columns time, position (3), orientation (4), force (3),
            torque (3) and sites_underwater_ratio. ``frames`` holds the
            rendered frames from the same start step onward.
        """
        data = self.data

        if not resume_last:    # start a new simulation
            mujoco.mj_resetData(self.model, self.data)      # Reset state and time.
            controller.get_initial_pose(data)               # Set desired positions for each joint
            # Clear targets from a previous run before handing over the buffer.
            self.target_pose_history[:] = 0
            controller.init_target_pose_history(self.target_pose_history)
            self.frames.clear()
            self.i = 0

        controller.set_simulation_frequency(self.framerate)

        # Exclusive end row of the recorded data. If the controller never
        # finishes, it is the last step reached. Previously it stayed 0, so
        # every output slice came back empty.
        end = None
        while self.i < self.n_steps:
            self.record_state(self.i, data)

            if controller.get_control_input(data) != 0:
                end = self.i
                break

            # Advance physics up to this control step's time.
            while data.time < self.i / self.framerate:
                mujoco.mj_step(self.model, self.data)

            if self.RECORD_VIDEO:
                self.renderer.update_scene(data, "track", self.mj_option)
                self.frames.append(self.renderer.render())

            self.i += 1
        if end is None:
            end = self.i

        # First output row. This skips start_plot_time seconds plus one more
        # second of data (seconds = start_plot_position / framerate).
        start_plot_position = int(self.framerate * start_plot_time) + int(self.framerate)
        fps = self.framerate * self.times

        if self.SHOW_VIDEO:
            media.show_video(self.frames[start_plot_position:], fps=fps)

        if self.SAVE_VIDEO and video_filename is not None:
            media.write_video(video_filename, self.frames[start_plot_position:], fps=fps)

        if self.FIGURE:
            self.plot_figures(start_plot_position, end, controller.target_pose_history)

        window = slice(start_plot_position, end)
        table = np.hstack((self.sim_time[window, None],
                           self.position[window],
                           self.orientation[window],
                           self.force[window],
                           self.torque[window],
                           self.sites_underwater_ratio[window, None]))
        return [table, self.frames[start_plot_position:]]

    def record_state(self, i, data: mujoco.MjData):
        """Store the current state of ``data`` in row ``i`` of the buffers.

        The sensordata slices [0:3] and [3:6] are stored negated as force
        and torque, [6:9] as position and [9:13] as orientation. Presumably
        these are force/torque/frame-position/frame-quaternion sensors
        declared in the XML; confirm. The method also stores the fraction of
        peg corner sites that are "underwater": below z = -1e-3 and within
        0.015 of the world origin in both x and y.
        """
        self.sim_time[i] = data.time
        self.position[i] = data.sensordata[6:9]
        self.orientation[i] = data.sensordata[9:13]
        self.velocity[i] = data.qvel[:]
        self.acceleration[i] = data.qacc[:]
        self.force[i] = -data.sensordata[0:3]
        self.torque[i] = -data.sensordata[3:6]

        inside = 0
        for idx in self.peg_site_indexes:
            x, y, z = data.site_xpos[idx]
            if z < -1e-3 and abs(x) < 0.015 and abs(y) < 0.015:
                inside += 1

        n_sites = len(self.peg_site_indexes)
        self.sites_underwater_ratio[i] = inside / n_sites if n_sites > 0 else 0

    def plot_figures(self, start_plot_position, end, target_pose, save_to_file = False):
        """Plot recorded rows [start_plot_position, end) in a 3x2 grid.

        Panels: linear and angular velocity, position (with the target z,
        ``target_pose[:, 2]``, dashed), orientation, force and torque. When
        ``save_to_file`` is True the figure is also written to
        ``pic/Exp6_record_<FIGURE_TITLE>.png``.
        """
        import matplotlib.pyplot as plt

        window = slice(start_plot_position, end)
        sim_time = self.sim_time[window]
        position = self.position[window]
        orientation = self.orientation[window]
        velocity = self.velocity[window]
        force = self.force[window]
        torque = self.torque[window]
        target_pose = target_pose[window]

        dpi = 150
        figsize = (1200 / dpi, 900 / dpi)
        _, ax = plt.subplots(3, 2, figsize=figsize, dpi=dpi, sharex=True)

        if self.FIGURE_TITLE is not None:
            plt.suptitle(self.FIGURE_TITLE)

        ax[0, 0].plot(sim_time, velocity[:, 0:3])
        ax[0, 0].set_title('velocity')
        ax[0, 0].set_ylabel('meter/s')

        ax[0, 1].plot(sim_time, velocity[:, 3:7])
        ax[0, 1].set_title('angular velocity')
        ax[0, 1].set_ylabel('radian/s')

        lines = ax[1, 0].plot(sim_time, position)
        target_lines = ax[1, 0].plot(sim_time, target_pose[:, 2], '--')
        ax[1, 0].set_title('position')
        ax[1, 0].set_ylabel('meter')
        ax[1, 0].legend(lines + target_lines, ('x', 'y', 'z', 'target z'))

        lines = ax[1, 1].plot(sim_time, orientation)
        ax[1, 1].set_title('orientation')
        ax[1, 1].set_ylabel('radian')
        ax[1, 1].legend(lines, ('w', 'x', 'y', 'z'))

        lines = ax[2, 0].plot(sim_time, force)
        ax[2, 0].set_title('force')
        ax[2, 0].set_ylabel('Newton')
        ax[2, 0].set_xlabel('second')
        ax[2, 0].legend(lines, ('x', 'y', 'z'))

        lines = ax[2, 1].plot(sim_time, torque)
        ax[2, 1].set_title('torque')
        ax[2, 1].set_ylabel('N*m')
        ax[2, 1].set_xlabel('second')
        ax[2, 1].legend(lines, ('x', 'y', 'z'), loc='upper right')

        plt.tight_layout()
        if save_to_file:
            os.makedirs("pic", exist_ok=True)
            plt.savefig("pic/Exp6_record_" + str(self.FIGURE_TITLE) + ".png")

  '''
