In [None]:
#Add the path to the compiled dgm and dgh packages 
import sys
sys.path.append('/usr/local/lib/python3/dist-packages/')
#Import DGM and math libs
import dynamic_graph_manager_cpp_bindings
import threading
import numpy as np
#Import libs for handling the LCM synchronization messages from the DGM
from lcm_msgs import ipc_trigger_t
import lcm
#Import the communication libs for publishing data to PlotJuggler
import zmq
import math
import json
from time import sleep

In [None]:
class FrankaDynamicDghManager():
    """Dynamic Graph Head (DGH) bridge between a Franka arm's DGM and Python.

    ``start()`` connects a ``DGMHead`` to the DGM and spawns a thread that
    listens for LCM sync triggers on the ``dgm_franka_control_trigger``
    channel. On every ``ds_ratio``-th trigger the callback:

    * reads ``[q, dq, T]`` (joint positions, velocities, torques),
    * optionally publishes them as JSON on a ZMQ PUB socket (for PlotJuggler),
    * writes ``self.controller(state, self.initial_states)`` as the joint
      torque command, or zeros when no controller is set,
    * optionally appends ``[state, cmd, timestamp]`` to ``self.log_data``.

    Assign ``self.controller`` after ``start()`` so that ``initial_states``
    is available; ``stop()`` resets it to ``None``.
    """

    # Upper bound (ms) on how long the LCM thread blocks waiting for a trigger,
    # so that stop() is honoured even when the DGM stops publishing.
    LCM_TIMEOUT_MS = 100

    def __init__(self, robot_config,
                 ds_ratio = 10,
                 plotting = True,
                 logging = True,
                 max_log_count = 1e7,
                 plotter_port=5555):
        """Store the configuration; nothing is connected until ``start()``.

        Args:
            robot_config: DGM robot configuration handed to ``DGMHead``
                (a YAML file path in the cells below).
            ds_ratio: Process only every ``ds_ratio``-th LCM trigger.
            plotting: Publish each processed state as JSON over ZMQ.
            logging: Record processed samples in ``self.log_data``.
            max_log_count: Maximum number of samples kept in ``self.log_data``.
            plotter_port: TCP port the ZMQ PUB socket binds to.
        """
        self.controller = None
        self.plotter_port = plotter_port
        self.robot_config = robot_config
        self.ds_ratio = ds_ratio
        self.trigger_counter = 0
        self.trigger_timestamp = None
        self.logging = logging
        self.max_log_count = max_log_count
        self.plotting = plotting
        self.log_data = []
        self.initial_states = None
        self.running = False
        # ZMQ handles only exist while running with plotting enabled.
        self.context = None
        self.socket = None

    def thread(self):
        """LCM listener loop executed by ``self.lcm_thread`` until ``stop()``."""
        while self.running:
            # Wait at most LCM_TIMEOUT_MS for a trigger so the running flag is
            # re-checked even if the DGM stops publishing; a blocking handle()
            # would make stop() hang forever in join().
            self.lc.handle_timeout(self.LCM_TIMEOUT_MS)

    def start(self):
        """Connect to the DGM and start processing sync triggers.

        Also stores a first state read in ``self.initial_states`` for use by
        the controller.
        """
        print('Starting the Thread ...')

        if self.plotting:
            # Interface for plotting and logging the data
            self.context = zmq.Context()
            self.socket = self.context.socket(zmq.PUB)
            self.socket.bind(f"tcp://*:{self.plotter_port}")

        # Interface for getting sync triggers from the DGM over LCM
        self.msg = ipc_trigger_t()
        self.lc = lcm.LCM()
        self.subscription = self.lc.subscribe("dgm_franka_control_trigger",
                                              self.trigger_callback)
        # Queue at most one pending trigger so a slow callback does not build
        # up a backlog.
        self.subscription.set_queue_capacity(1)

        # Instantiate a Dynamic Graph Head (DGH) class to connect to the robot
        self.head = dynamic_graph_manager_cpp_bindings.DGMHead(self.robot_config)
        self.log_data = []
        self.trigger_counter = 0

        self.running = True
        self.lcm_thread = threading.Thread(target = self.thread)
        self.lcm_thread.start()
        sleep(0.2)
        # initial_states can be used by the controller.
        # NOTE(review): this read runs on the main thread while the LCM thread
        # may also be using self.head; confirm DGMHead tolerates that.
        self.initial_states = self.read_states()
        print('Thread Started!')

    def stop(self):
        """Stop the LCM thread and release the LCM, ZMQ and DGM handles."""
        print('Stopping the thread ...')
        self.running = False
        self.controller = None
        self.lcm_thread.join()
        self.lc.unsubscribe(self.subscription)
        # The ZMQ socket/context only exist when plotting was enabled.
        if self.socket is not None:
            # linger=0: drop unsent plot messages instead of blocking term().
            self.socket.close(linger=0)
            self.socket = None
        if self.context is not None:
            self.context.term()
            self.context = None
        del(self.lc)
        del(self.head)
        print('Thread stopped!')

    def read_states(self):
        """Read the latest sensor values from the DGM.

        Returns:
            ``[q, dq, T]``: copies of the ``joint_positions``,
            ``joint_velocities`` and ``joint_torques`` sensors.
        """
        self.head.read()
        T  = self.head.get_sensor("joint_torques").copy()
        q  = self.head.get_sensor("joint_positions").copy()
        dq = self.head.get_sensor("joint_velocities").copy()
        return [q, dq, T]

    def write_command(self, cmd):
        """Send a 7-element joint torque command plus the trigger stamp.

        Args:
            cmd: Array-like holding exactly 7 numbers (e.g. shape ``(7,)`` or
                ``(7, 1)``); it is sent as a ``(7, 1)`` float array.

        Raises:
            AssertionError: if ``cmd`` does not hold exactly 7 numbers.
        """
        cmd = np.asarray(cmd, dtype=float)
        # Check the element count (not the largest dimension) so the reshape
        # below cannot fail on e.g. a (7, 7) array.
        assert cmd.size == 7, "The control command should be a vector of 7 numbers!"
        self.head.set_control("ctrl_joint_torques", cmd.reshape(7, 1))
        # Trigger timestamp scaled by 1e-6 (presumably us -> s; confirm with the DGM).
        self.head.set_control("ctrl_stamp", np.array(self.trigger_timestamp).reshape(1, 1) / 1000000)
        self.head.write()

    def generate_plot_data(self, state):
        """Build the JSON-serialisable message published to the plotter.

        Args:
            state: ``[q, dq, T]`` as returned by ``read_states()``.
        """
        q, dq, T = state
        data = {
            "timestamp": self.trigger_timestamp,
            "robot_states": {
                "q": q.tolist(),
                "dq": dq.tolist(),
                "torques": T.tolist()
            }
        }
        return data

    def trigger_callback(self, channel, data):
        """LCM handler: decode the trigger message and process it."""
        msg = ipc_trigger_t.decode(data)
        self._on_trigger(msg.timestamp)

    def _on_trigger(self, timestamp):
        """Process one sync trigger: read, plot, command and log (decimated)."""
        self.trigger_timestamp = timestamp
        self.trigger_counter += 1
        if self.trigger_counter % self.ds_ratio != 0:
            return

        state = self.read_states()

        if self.plotting:
            data = self.generate_plot_data(state)
            self.socket.send_string(json.dumps(data))

        # Read the controller once: stop() may reset it to None from another
        # thread between a None-check and the call.
        controller = self.controller
        if controller is not None:
            cmd = controller(state, self.initial_states)
        else:
            cmd = np.zeros((7, 1))
        # Normalise so that list commands can be both written and logged.
        cmd = np.asarray(cmd, dtype=float)
        self.write_command(cmd)

        # Keep at most max_log_count samples.
        if self.logging and len(self.log_data) < self.max_log_count:
            self.log_data.append([state, cmd.copy(), self.trigger_timestamp])

    def get_recorded_dataset(self):
        """Stack the logged samples into arrays.

        Returns:
            ``(stamps, states, cmds)``: ``stamps`` is ``(N, 1)``; ``states`` is
            a list (in ``[q, dq, T]`` order) of each signal vstacked over the
            samples (``(N, 7)`` when the sensors are 1-D -- presumably);
            ``cmds`` is ``(N, 7)``.

        Raises:
            ValueError: if nothing has been logged.
        """
        # Snapshot: the LCM thread may still be appending samples.
        log = list(self.log_data)
        if not log:
            raise ValueError("No data has been recorded; call start() with logging enabled first.")

        n_signals = len(log[0][0])
        states = [np.vstack([entry[0][i] for entry in log]) for i in range(n_signals)]
        cmds = np.vstack([np.reshape(entry[1], (1, -1)) for entry in log])
        stamps = np.vstack([entry[2] for entry in log])

        return stamps, states, cmds


In [None]:
# DGM robot configuration file handed to DGMHead when the manager starts.
robot_yaml = "../dgm_franka/franka_dynamic_interface.yaml"
# ds_ratio=1 -> the control/plot/log callback runs on every DGM sync trigger.
franka_dgh = FrankaDynamicDghManager(robot_config=robot_yaml, ds_ratio=1)

In [None]:
# Connect to the DGM and start the LCM trigger thread; until
# franka_dgh.controller is assigned, zero torques are commanded.
franka_dgh.start()

In [None]:
# Stop the trigger thread, clear the controller and release the LCM/ZMQ/DGM handles.
franka_dgh.stop()

In [None]:
# Stack the logged samples: stamps (N, 1), states as [q, dq, T] arrays, cmds (N, 7).
stamps, states, cmds = franka_dgh.get_recorded_dataset()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Distribution of the time between consecutive logged samples -- a quick check
# of how regularly the trigger callback was executed.
sample_intervals = np.diff(stamps, axis=0)
fig, ax = plt.subplots()
ax.hist(sample_intervals.ravel(), bins=100)
ax.set(title="Interval between consecutive logged samples",
       xlabel="timestamp difference (raw trigger units, presumably us)",
       ylabel="count");

In [None]:
# states follows read_states' [q, dq, T] order, so states[2] is the
# "joint_torques" sensor. NOTE(review): the file saved below is named
# "velocity_test"; use states[1] if joint velocities were meant here.
JOINT_IDX = 3
fig, ax = plt.subplots()
ax.plot(states[2][:, JOINT_IDX])
ax.set(title=f"Measured joint torque, joint index {JOINT_IDX}",
       xlabel="logged sample",
       ylabel="joint_torques sensor value");

In [None]:
import pickle

# Persist the recorded dataset as one pickled list [stamps, states, cmds].
recorded = [stamps, states, cmds]
with open('franka_states_velocity_test.pckl', 'wb') as out_file:
    pickle.dump(recorded, out_file)