In [None]:
import pickle
import rospy
import baxter_interface
import os.path as path
import copy
from tqdm import tqdm_notebook as tqdmn
import matplotlib.pyplot as plt
import itertools
import numpy as np
import time

from collections import namedtuple
import random

# Taken from
# https://github.com/pytorch/tutorials/blob/master/Reinforcement%20(Q-)Learning%20with%20PyTorch.ipynb

Transition = namedtuple('Transition', ('state', 'action', 'mask', 'next_state',
                                       'reward'))


class Memory(object):
    def __init__(self):
        self.memory = []

    def push(self, *args):
        """Saves a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size=None):
        if batch_size is None:
            return Transition(*zip(*self.memory))
        else:
            random_batch = random.sample(self.memory, batch_size)
            return Transition(*zip(*random_batch))

    def append(self, new_memory):
        self.memory += new_memory.memory

    def __len__(self):
        return len(self.memory)

In [None]:
# initialize ros node
rospy.init_node('lambda_trainer')
limb = baxter_interface.Limb('right')

In [None]:
# define relevant variables
PLAYBACK_MODE = 'bend'
REACH_TIME = 0.05
file_seed = path.expanduser('~/data/moveit_data/')
if PLAYBACK_MODE == 'bend':
    moveit_file = file_seed + 'bend_dof_'
else:
    file_seed = file_seed + 'full_dof_'
limb.set_joint_position_speed(0.75)

In [None]:

# class for the PPO model


In [2]:
# in a loop load all the plans and replay
joint_angles = {}
err = []
xaxis = []
# observations = []

def getState(limb, names):
    measured_pos = []
    measured_vel = []
    measured_torque = []
    for name in names:
        measured_pos.append(limb.joint_angles()[name])
        measured_vel.append(limb.joint_velocities()[name])
        measured_torque.append(limb.joint_efforts()[name])
    return np.array(measured_pos), np.array(measured_vel), np.array(measured_torque)

def reward(goal, measured):
    return -(np.sum(np.array(goal) - np.array(measured))**2)

def within_threshold(goal, measured, threshold):
    if np.abs(np.array(goal) - np.array(measured)).all() < threshold:
        return True
    else:
        return False
    
def action_step(initial_setpoint, delta, joint_names):
    



# every 1:00min update policy parameters
# after a full file playback, grab a new random file
# full outer loop is max updates

trajectory_done = True
num_files = 50
angle_threshold = 0.01
moving_joints = ['right_s1', 'right_e1', 'right_w1']

for episode in range(MAX_EPISODES):
    if trajectory_done:
        file_number = np.random.randint(num_files)
        plan = pickle.load(open(moveit_file + str(file_number) + .'.pkl', 'rb'))
        
        goal_list = []
        for i, pt in enumerate(plan):
            if i == 0:
                joint_names = copy.deepcopy(plan[i])
            else:
                goal_list.append(pt)
                
        trajectory_done = False
        memory = Memory()
        
    while goal_list:
        goal_current = goal_list.pop(0)
        done = False
        
        angles, velocities, torques = getState(limb, moving_joints)
        s_goal = np.array(goal_current)
        s = np.hstack([angles, velocities, torques, s_goal])
        
        start_time = time.time()
        while((time.time() - start_time) < REACH_TIME):
#             a = policy.sample_action(s)
#             joint_angles = f(goal_current, a)
            limb.set_joint_positions(joint_angles)
    
            angles_new, velocities_new, torques_new = getState(limb, moving_joints)
            s_ = np.hstack([angles_new, velocities_new, torques_new, s_goal])
            
            r = reward(goal_current, angles_)
            
            if within_threshold(goal_current, angles, angle_threshold):
                done = 1
                memory.push(s, a, r, s_, done)
                break
                
            done = 0
            memory.push(s, a, done, s_, r)
            s = s_
            
    trajectory_done = True
        

for file_iter in tqdmn(xrange(50), desc='Files read:'):
    plan = pickle.load(open(moveit_file + str(file_iter) + '.pkl', 'rb'))
    ctr = 0
    err.append(dict())
    err[file_iter] = {'s1':[], 'e1': [], 'w1':[]}
    obs = []
    for ctr in tqdmn(xrange(len(plan)), desc='Waypoints achieved:'):
        if ctr == 0:
            joint_names = copy.deepcopy(plan[ctr])
#             print joint_names
        else:
            current_velocities = limb.joint_velocities()
            current_torques = limb.joint_efforts()
            for (i, joint) in enumerate(joint_names):
                joint_angles[joint] = plan[ctr][i]
            if ctr == 1:
                limb.move_to_joint_positions(joint_angles)
            else:
                start_time = time.time()
                while ((time.time() - start_time) < REACH_TIME):
                    limb.set_joint_positions(joint_angles)
            measured_angles = limb.joint_angles()
            err[file_iter]['e1'].append(measured_angles['right_e1']-joint_angles['right_e1'])
            err[file_iter]['s1'].append(measured_angles['right_s1']-joint_angles['right_s1'])
            err[file_iter]['w1'].append(measured_angles['right_w1']-joint_angles['right_w1'])
            # append all dicts to a list
#     observations.append(obs)

SyntaxError: invalid syntax (<ipython-input-2-65a23f1f4ffb>, line 39)

In [None]:
# plot commanded trajectory and error
joint_angles = {}
for file_iter in tqdmn(xrange(2), desc='Files read:'):
    plan = pickle.load(open(moveit_file + str(file_iter) + '.pkl', 'rb'))
    s1 = [row[1] for row in plan[1:]]
    e1 = [row[3] for row in plan[1:]]
    w1 = [row[5] for row in plan[1:]]
    plt.figure(file_iter)
    plt.plot(np.array(s1)*180/np.pi, 'b')
    plt.plot(np.array(e1)*180/np.pi, 'r')
    plt.plot(np.array(w1)*180/np.pi, 'g')
    plt.plot(np.array(err[file_iter]['e1'])*180./np.pi, 'b--')
    plt.plot(np.array(err[file_iter]['s1'])*180./np.pi, 'r--')
    plt.plot(np.array(err[file_iter]['w1'])*180./np.pi, 'g--')
    plt.legend(['e1', 's1', 'w1'])
    plt.show()

plt.figure(0)
plt.plot(np.array(err[0]['e1'])*180./np.pi, 'b')
plt.plot(np.array(err[0]['s1'])*180./np.pi, 'r')
plt.plot(np.array(err[0]['w1'])*180./np.pi, 'g')
plt.legend(['e1', 's1', 'w1'])
plt.show()
plt.figure(1)
plt.plot(np.array(err[1]['e1'])*180./np.pi, 'b')
plt.plot(np.array(err[1]['s1'])*180./np.pi, 'r')
plt.plot(np.array(err[1]['w1'])*180./np.pi, 'g')
plt.legend(['e1', 's1', 'w1'])
plt.show()