In [None]:
from __future__ import division
import pickle
import os
import random
import uuid
import time
import types
from copy import deepcopy as copy

import gym
from gym import spaces
from gym.envs.classic_control import rendering
import numpy as np
import tensorflow as tf
import tempfile

import baselines.common.tf_util as U

from baselines import logger
from baselines.common.schedules import LinearSchedule
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from baselines.deepq.simple import ActWrapper

from pyglet.window import key as pygkey

In [None]:
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
%matplotlib inline

In [None]:
import matplotlib as mpl
mpl.rc('savefig', dpi=300)

In [None]:
data_dir = os.path.join('data', 'cursor-control')

In [None]:
sess = tf.Session()

helpers

In [None]:
def save_tf_vars(scope, path):
  """Checkpoint the global TF variables whose names start with `scope + '/'`.

  The checkpoint is written to `path` using the notebook-level session `sess`.
  """
  prefix = scope + '/'
  scoped_vars = []
  for var in tf.global_variables():
    if var.name.startswith(prefix):
      scoped_vars.append(var)
  tf.train.Saver(scoped_vars).save(sess, save_path=path)

In [None]:
def load_tf_vars(scope, path):
  """Restore the global TF variables named `scope + '/...'` from `path`.

  Values are loaded into the notebook-level session `sess`.
  """
  prefix = scope + '/'
  restorer = tf.train.Saver(
    [var for var in tf.global_variables() if var.name.startswith(prefix)])
  restorer.restore(sess, path)

In [None]:
def build_mlp(
    input_placeholder,
    output_size,
    scope,
    n_layers=1,
    size=256,
    activation=tf.nn.relu,
    output_activation=tf.nn.softmax,
    reuse=False
  ):
  """Stack dense layers on `input_placeholder` inside variable scope `scope`.

  Builds `n_layers` hidden layers of width `size` using `activation`, then an
  output layer of width `output_size` using `output_activation`.
  Returns the output tensor.
  """
  with tf.variable_scope(scope, reuse=reuse):
    hidden = input_placeholder
    layer_idx = 0
    while layer_idx < n_layers:
      hidden = tf.layers.dense(hidden, size, activation=activation)
      layer_idx += 1
    return tf.layers.dense(hidden, output_size, activation=output_activation)

In [None]:
def plot_trajectories(
  rollouts, goal, title='', file_name=None):
  """Scatter-plot the cursor path of each rollout together with the goal.

  rollouts: list of rollouts; each rollout is a list of transition tuples
    whose first element is an observation starting with (x, y).
  goal: (x, y) position of the target, drawn as a star.
  title: figure title.
  file_name: if given, the figure is also saved under `data_dir`.

  Successful rollouts (per `is_succ`) use a yellow-green colormap and the
  others a gray one; colour encodes the timestep within the rollout.
  """
  plt.title(title)

  for rollout in rollouts:
    # Only the first two observation dimensions (the position) are plotted.
    positions = np.array([x[0] for x in rollout])[:, :2]
    cmap = mpl.cm.YlGn if is_succ(rollout) else mpl.cm.gray
    plt.scatter(
      positions[:, 0], positions[:, 1], c=np.arange(len(positions)),
      cmap=cmap, alpha=0.75, linewidth=0)

  # The goal does not depend on the rollout, so draw it exactly once. Drawing
  # it inside the loop stacked N semi-transparent stars on top of each other
  # and omitted the goal entirely when `rollouts` was empty.
  plt.scatter(
    [goal[0]], [goal[1]], marker='*', color='yellow',
    edgecolor='black', linewidth=1, s=300, alpha=0.5)

  plt.xlim([-0.05, 1.05])
  plt.ylim([-0.05, 1.05])
  plt.xticks([])
  plt.yticks([])
  plt.axis('off')
  if file_name is not None:
    plt.savefig(os.path.join(data_dir, file_name), bbox_inches='tight')
  plt.show()

In [None]:
def cart_to_polar(v):
  """Turn a 2D vector into a (heading, speed) pair.

  The heading is arctan2(v[0], v[1]). The vector's magnitude is ignored:
  the speed component is always the notebook-level `max_speed`.
  """
  heading = np.arctan2(v[0], v[1])
  return np.array([heading, max_speed])

def normalize_polar(v):
  """Map heading v[0] to (pi/2 - v[0]) wrapped into [0, 2*pi).

  The incoming speed v[1] is ignored: the speed component of the result is
  always the notebook-level `max_speed`.
  """
  wrapped_angle = (-v[0] + 0.5*np.pi) % (2 * np.pi)
  return np.array([wrapped_angle, max_speed])

def polar_to_cart(v):
  """Convert (angle, radius) to Cartesian: radius * [cos(angle), sin(angle)]."""
  angle, radius = v[0], v[1]
  direction = np.array([np.cos(angle), np.sin(angle)])
  return radius * direction

In [None]:
def perp_dist(pos, init_pos, goal):
  """Perpendicular distance from `pos` to the line through `init_pos` and `goal`.

  pos, init_pos, goal: 1-D numpy arrays of equal length.
  Returns a non-negative float. If `goal` coincides with `init_pos` the line
  is undefined, and the distance from `pos` to `init_pos` is returned.
  """
  u = pos - init_pos
  v = goal - init_pos
  v_sq = v.dot(v)
  if v_sq == 0:
    return np.linalg.norm(u)
  # Orthogonal projection of pos onto the line. The previous code added the
  # scalar u.dot(v) to every coordinate of init_pos, which is not a point on
  # the line and produced a meaningless distance.
  w = init_pos + (u.dot(v) / float(v_sq)) * v
  return np.linalg.norm(pos - w)

setup env

In [None]:
# Global task/environment configuration read by the cells below.
# NOTE(review): n_tasks is 10, but the cell that places goals on a circle
# builds only 8 of them; code indexing `goals` by range(n_tasks) can overrun.
n_tasks = 10 # each task/env has a unique goal/target

# NOTE(review): CursorControl.action_space is bounded by [2*pi, max_speed],
# i.e. (heading, speed) rather than (vx, vy) — confirm which is intended.
n_act_dim = 2 # vx, vy

# NOTE(review): grid_size is commented out here, yet features_2D and the
# argmax-based internal_decode_act still read a global `grid_size`; they
# raise NameError unless it is defined elsewhere.
#grid_size = (28, 32) # ecog grid dimensions
#bci_dim = grid_size[0]*grid_size[1] # number of BCI channels
bci_dim = 384

# external observation = pos (2) + vel (2) + init_pos (2) + goal-reached flag (1),
# see CursorControl._obs
n_ext_obs_dim = 7 # number of external state observation dimensions
n_obs_dim = n_ext_obs_dim + bci_dim

gamma = 0.99 # discount factor
max_ep_len = 500 # number of timesteps
# The success checks compare |pos - goal| <= goal_dist_thresh per axis, so the
# goal region is a square with this half-width rather than a circle.
goal_dist_thresh = 0.05 # radius of goal
succ_rew_bonus = 1 # for reaching goal
# speed component returned by cart_to_polar / normalize_polar
max_speed = 0.01

In [None]:
def is_succ(r):
  """Return the 'succ' flag from the info dict of rollout r's last transition."""
  final_info = r[-1][-1]
  return final_info['succ']

def get_ttt(r):
  """Return the 'ttt' entry from the info dict of rollout r's last transition."""
  final_info = r[-1][-1]
  return final_info['ttt']

def get_dfot(r):
  """Return the 'dfot' entry from the info dict of rollout r's last transition."""
  final_info = r[-1][-1]
  return final_info['dfot']

In [None]:
goals = np.random.random((n_tasks, 2))

In [None]:
# Replace the random goals drawn in the previous cell with 8 goals spaced
# every 2*pi/8 radians on a circle of radius 0.5 centred at (0.5, 0.5).
# NOTE(review): this produces 8 goals while n_tasks is 10 — confirm which
# count is intended.
goals = []
for ang in np.arange(0, 2*np.pi, 2*np.pi/8):
  goals.append(polar_to_cart(np.array([ang, 0.5])) + 0.5)
goals = np.array(goals)

In [None]:
plt.scatter(goals[:, 0], goals[:, 1], linewidth=0, color='gray', s=100, marker='*')
plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.show()

In [None]:
with open(os.path.join(data_dir, 'goals.pkl'), 'wb') as f:
  pickle.dump(goals, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(data_dir, 'goals.pkl'), 'rb') as f:
  goals = pickle.load(f)

In [None]:
def create_2D_gaussian(data, **kwargs):
  '''
  Generate a 2D Gaussian.
  data: N_X x N_Y x 2 matrix. N_X and N_Y are sizes of the ECoG grid;
    data[:, :, 0] holds x coordinates and data[:, :, 1] holds y coordinates.
  kwargs (all optional):
    amp: peak amplitude (default 2)
    cent_x, cent_y: centre of the bump (default 0, 0)
    spreadata_X, spreadata_Y: standard deviations along x and y (default 7, 4.5)
  Unrecognised keyword arguments are ignored.
  Returns an array with the shape of data[:, :, 0].
  '''
  # Look options up by value. The previous `key is 'amp'` test compared string
  # identity, which only works when the interpreter happens to intern the key.
  amp = kwargs.get('amp', 2)
  cent_x = kwargs.get('cent_x', 0)
  cent_y = kwargs.get('cent_y', 0)
  spreadata_X = kwargs.get('spreadata_X', 7)
  spreadata_Y = kwargs.get('spreadata_Y', 4.5)
  F = amp*np.exp(
    -(((data[:,:,0] - cent_x)**2)/(2*spreadata_X**2) +
    ((data[:,:,1] - cent_y)**2)/(2*spreadata_Y**2)))
  return F

def features_2D(x, y, **kwargs):
  '''
  Simulate 2D neural features for decoding direction.
  Calls create_2D_gaussian() to generate features.
  Arguments:
  x = x direction magnitude (offset of the bump centre from the grid centre)
  y = y direction magnitude (offset of the bump centre from the grid centre)
  Keyword arguments:
  noise = amplitude for uniformly distributed noise as % of the peak
          amplitude (default 10)
  grid_size = (n_x, n_y) tuple for the size of the generated neural features
              (default: the notebook-level `grid_size`)
  plot = boolean for plotting the features (default False)
  All keyword arguments are also forwarded to create_2D_gaussian(), which
  picks up amp / spreadata_X / spreadata_Y.
  Returns the features flattened to a 1-D array of n_x*n_y values.
  '''
  # Read options by value; `key is 'noise'` relied on string interning.
  noise = kwargs.get('noise', 10)
  plot = kwargs.get('plot', False)
  grid = kwargs.get('grid_size')
  if grid is None:
    grid = grid_size  # fall back to the notebook-level setting

  # Shift the centre by the requested offset and clamp it to the grid.
  x = int(grid[0]/2 + x)
  x = np.max([np.min([x, grid[0]]), 0])
  y = int(grid[1]/2 + y)
  y = np.max([np.min([y, grid[1]]), 0])

  data_X, data_Y = np.meshgrid(range(grid[0]), range(grid[1]))
  data = np.stack((data_X, data_Y),axis=2)
  gauss_kwargs = dict(kwargs, cent_x=x, cent_y=y)
  Z = create_2D_gaussian(data, **gauss_kwargs)

  # Convert the noise percentage into an absolute amplitude.
  noise = noise/100 * np.max(Z)
  Z += noise*np.random.rand(grid[1], grid[0])
  if plot:
    f = plt.figure()
    plt.imshow(Z)
    plt.colorbar()
    plt.xlabel('X')
    plt.ylabel('Y')
    str_title = 'Simulated directional features'
    str_title += '\nCentroid = ({},{}) Noise amp = {:.3f}'
    plt.title(str_title.format(x, y, noise))
    plt.gca().invert_yaxis()
    plt.show()
  return Z.flatten()
  
# Scale factor applied to the Cartesian form of the intended action before it
# is used as the bump offset (in grid cells) passed to features_2D.
mag = 1000
# Simulated BCI encoder: turns the user's intended (angle, speed) action into
# flattened grid features. `ext_obs` and `goal` are accepted but unused here.
# This name is rebound to the trained encoder later in the notebook.
internal_encode_obs = lambda action, ext_obs, goal: features_2D(*(polar_to_cart(action)*mag))
def internal_decode_act(obs):
  """Decode an action from the BCI part of `obs` by locating its peak.

  The BCI features are reshaped to (grid_size[1], grid_size[0]); the offset
  of the argmax from the grid centre is passed through cart_to_polar and
  normalize_polar to produce the returned (angle, speed) action.
  """
  grid = extract_bci_obs(obs).reshape((grid_size[1], grid_size[0]))
  n_rows, n_cols = grid.shape
  row, col = np.unravel_index(np.argmax(grid, axis=None), grid.shape)
  offset = np.array([col - n_cols // 2, row - n_rows // 2])
  return normalize_polar(cart_to_polar(offset))

In [None]:
class CursorControl(gym.Env):
  """2D cursor-control task driven by simulated (or keyboard) BCI features.

  The cursor lives in the unit square. An observation is the concatenation of
  the external state (pos, vel, init_pos, goal-reached flag) and the BCI
  features produced by `internal_encode_obs` for the user's intended action.
  Actions are (angle, speed) pairs; they are passed through normalize_polar
  and polar_to_cart and low-pass filtered into the cursor velocity.

  An episode ends on success, on leaving the unit square, or after
  `max_ep_len` steps.

  NOTE(review): only the underscore-prefixed hooks (_step/_reset/_render) are
  implemented, which presumably relies on a gym version whose Env.step/reset/
  render dispatch to them — confirm against the installed gym.
  """
  metadata = {
    'render.modes': ['human']
  }

  def __init__(
      self,
      max_ep_len=max_ep_len, # max number of timesteps
      goal=None, # target position
      init_pos=None, # fixed initial position
      rand_init=True, # True -> new, random initial position for each episode
      return_to_init=False, # True -> need to return to initial position to succeed
      human_user=False, # True -> human at the keyboard, instead of optimal synthetic user
      blending=0, # between 0 and 1 (0 -> no blending, 1 -> ignore agent and take optimal actions)
      rand_goal=False, # True -> new, random goal for each episode
      using_reward_shaping=False,
      vel_smoothing=0.8
    ):
    if not rand_init and init_pos is None:
      raise ValueError('init_pos must be given when rand_init is False')

    self.observation_space = spaces.Box(np.zeros(n_obs_dim), np.ones(n_obs_dim))
    self.action_space = spaces.Box(np.zeros(2), np.array([2*np.pi, max_speed]))

    self.max_ep_len = max_ep_len
    self.goal = goal
    self.return_to_init = return_to_init
    self.init_pos = init_pos
    self.rand_init = rand_init
    self.blending = blending
    self.rand_goal = rand_goal
    self.human_user = human_user
    self.using_reward_shaping = using_reward_shaping
    self.vel_smoothing = vel_smoothing

    # With a fixed goal the policies and reward function can be built now;
    # with rand_goal they are (re)built on every reset.
    if not rand_goal:
      assert goal is not None
      self._set_goal(goal)

    self.pos = None # position
    self.vel = None # velocity
    self.curr_step = None # timestep in current episode
    self.viewer = None
    self.curr_obs = None # latest observation generated by self._obs()
    self.succ = None # True -> most recent episode was successful
    self.dfot = None # running sum of perp_dist(pos, init_pos, goal) over the episode

    self.goal_reached = None # True -> goal has been reached and agent needs to return to initial position
    self.init_reached = None # True -> initial position has been reached after goal has been reached

  def _set_goal(self, goal):
    """Set the goal and rebuild the user policies and reward function for it."""
    self.goal = goal
    self.optimal_user_policy = make_synth_user_policy(goal, using_ext_obs=True) # for blending
    self.user_policy = make_human_user_policy() if self.human_user else make_synth_user_policy(
      goal, using_ext_obs=True) # for BCI features
    self.reward_func = make_reward_func(
      goal, using_reward_shaping=self.using_reward_shaping, return_to_init=self.return_to_init)

  def _obs(self):
    """Build, cache (in self.curr_obs) and return the current observation."""
    goal_reached_ind = np.array([1.0 if self.goal_reached else 0.0])
    ext_obs = np.concatenate((self.pos, self.vel, self.init_pos, goal_reached_ind)) # external state observations ("context")
    int_act = self.user_policy(ext_obs) # intended action
    bci_obs = internal_encode_obs(int_act, ext_obs, self.goal) # BCI output
    self.curr_obs = np.concatenate((ext_obs, bci_obs))
    return self.curr_obs

  def _step(self, action):
    """Advance one timestep; returns (obs, reward, done, info).

    info holds 'goal', 'succ', 'ttt' (steps taken so far) and 'dfot'.
    """
    # Mix the agent's action with the optimal one according to self.blending.
    opt_act = self.optimal_user_policy(extract_ext_obs(self.curr_obs))
    action = self.blending * opt_act + (1 - self.blending) * action

    action = polar_to_cart(normalize_polar(action))

    # Exponentially smooth the velocity, integrate, then clamp to the unit square.
    self.vel = self.vel_smoothing * self.vel + (1 - self.vel_smoothing) * action.ravel()
    self.pos += self.vel
    oob = (self.pos < 0).any() or (self.pos >= 1).any()
    self.pos = np.minimum(np.ones(2), np.maximum(np.zeros(2), self.pos))

    self.dfot += perp_dist(self.pos, self.init_pos, self.goal)

    self.curr_step += 1

    # goal_reached latches: once True it stays True for the rest of the episode.
    if not self.goal_reached:
      self.goal_reached = (np.abs(self.pos - self.goal) <= goal_dist_thresh).all()

    if self.return_to_init:
      at_init = (np.abs(self.pos - self.init_pos) <= goal_dist_thresh).all()
      # Keep the documented attribute in sync; it used to stay False forever.
      self.init_reached = self.goal_reached and at_init
      self.succ = self.init_reached
    else:
      self.succ = self.goal_reached

    oot = self.curr_step >= self.max_ep_len # out of time

    obs = self._obs()
    r = self.reward_func(self.prev_obs, action, obs)
    done = oot or self.succ or oob
    info = {'goal': self.goal, 'succ': self.succ, 'ttt': self.curr_step, 'dfot': self.dfot}

    self.prev_obs = obs

    return obs, r, done, info

  def _reset(self):
    """Start a new episode and return its first observation."""
    if self.rand_init:
      self.pos = np.random.random(2)
      self.init_pos = copy(self.pos)
    else:
      # Copy so the in-place position update in _step never touches init_pos.
      self.pos = copy(self.init_pos)

    self.vel = np.zeros(2)

    if self.rand_goal:
      self._set_goal(np.random.random(2))

    self.succ = False
    self.goal_reached = False
    self.init_reached = False
    self.dfot = 0

    self.curr_step = 0
    self.prev_obs = self._obs()
    return self.prev_obs

  def _render(self, mode='human', close=False):
    """Draw goal, initial position (if relevant) and cursor into the viewer."""
    if close:
      if self.viewer is not None:
        self.viewer.close()
        self.viewer = None
      return

    if self.viewer is None:
      self.viewer = rendering.SimpleImageViewer()

    fig = plt.figure()
    try:
      # Attach an Agg canvas so the figure can be rasterised off-screen.
      canvas = FigureCanvas(fig)
      ax = fig.gca()

      size = 100
      # Green marks the current target: the goal first, then (when
      # return_to_init is set and the goal was reached) the initial position.
      swap_goal_colors_flag = not self.goal_reached or not self.return_to_init

      ax.scatter(
        [self.goal[0]], [self.goal[1]],
        color=('green' if swap_goal_colors_flag else 'gray'),
        linewidth=0, alpha=0.75, marker='o', s=size*5
      )

      if self.return_to_init:
        ax.scatter(
          [self.init_pos[0]], [self.init_pos[1]],
          color=('gray' if swap_goal_colors_flag else 'green'),
          linewidth=0, alpha=0.75, marker='s', s=size
        )

      ax.scatter(
        [self.pos[0]], [self.pos[1]],
        color=('orange' if not self.succ else 'teal'), linewidth=0,
        alpha=0.75, s=size
      )

      ax.set_xlim([-0.05, 1.05])
      ax.set_ylim([-0.05, 1.05])
      ax.axis('off')

      # The canvas is already an Agg canvas, so draw on it directly instead of
      # going through the deprecated canvas.switch_backends(). np.frombuffer
      # replaces the deprecated binary mode of np.fromstring.
      canvas.draw()
      width, height = canvas.get_width_height()
      frame = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(height, width, 3)
      self.viewer.imshow(frame)
    finally:
      # Always release the figure, even if drawing fails part-way.
      plt.close(fig)

In [None]:
def extract_goal_reached_ind(obs):
  """Goal-reached flag: the last external-observation entry."""
  return obs[n_ext_obs_dim-1]

def extract_pos(obs):
  """Cursor position: the first two entries of the observation."""
  return obs[:2]

def extract_init_pos(obs):
  """Initial position: the two entries just before the goal-reached flag."""
  start = n_ext_obs_dim-3
  stop = n_ext_obs_dim-1
  return obs[start:stop]

def extract_bci_obs(obs):
  """BCI features: everything after the external observation."""
  return obs[n_ext_obs_dim:]

def extract_ext_obs(obs):
  """External observation: the first n_ext_obs_dim entries."""
  return obs[:n_ext_obs_dim]

In [None]:
# simulate user with optimal intended actions
# that go directly to the goal, then directly back to the initial position if needed
def make_synth_user_policy(goal, using_ext_obs=False):
  """Build a policy that always heads straight for the current target.

  The target is `goal` while the goal-reached flag in the observation is 0,
  and the episode's initial position afterwards. When `using_ext_obs` is
  False the policy first strips the BCI features from the observation.
  The returned policy maps an observation to cart_to_polar(target - pos).
  """
  def synth_user_policy(obs):
    ext_obs = obs if using_ext_obs else extract_ext_obs(obs)
    if extract_goal_reached_ind(ext_obs) == 0:
      target = goal
    else:
      target = extract_init_pos(ext_obs)
    return cart_to_polar(target - extract_pos(ext_obs))
  return synth_user_policy

human_user = False

In [None]:
# instead of simulating ground-truth user actions, get keyboard input
# Index into cont_act_of_disc of the action currently held by the human;
# 4 is the no-op row.
init_human_action = lambda: 4 # noop
human_action = init_human_action()
# Set to True by key_press on any key event, reset by run_ep after an episode.
human_active = False

# Logical directions are bound to *different* physical arrow keys (UP is the
# LEFT key, and so on).
# NOTE(review): presumably this compensates for the orientation of the
# rendered image — confirm before changing.
UP = pygkey.LEFT
DOWN = pygkey.RIGHT
RIGHT = pygkey.UP
LEFT = pygkey.DOWN

def key_press(key, mod):
  """Window key-press handler: record which discrete action the user holds.

  Marks the human as active on every key press; `human_action` only changes
  when the key is one of the four mapped direction keys.
  """
  global human_action, human_active
  human_active = True
  disc_action_of_key = {LEFT: 0, RIGHT: 1, UP: 2, DOWN: 3}
  pressed = int(key)
  if pressed in disc_action_of_key:
    human_action = disc_action_of_key[pressed]
    
def key_release(key, mod):
  """Window key-release handler: fall back to the no-op action (index 4)
  when one of the four direction keys is released."""
  global human_action
  if int(key) in (LEFT, RIGHT, UP, DOWN):
    human_action = 4
    
# Continuous action for each discrete keyboard action; rows are indexed by
# the `human_action` values set in key_press / key_release.
# NOTE(review): these rows look like Cartesian velocities, whereas
# CursorControl passes actions through normalize_polar/polar_to_cart as
# (angle, speed) — confirm the intended convention.
cont_act_of_disc = np.array([
  [-max_speed, 0], # left
  [max_speed, 0], # right
  [0, max_speed], # up
  [0, -max_speed], # down
  [0, 0] # noop
])
  
# Policy for a human user: ignores the observation and returns the row of
# cont_act_of_disc selected by the global `human_action` at call time.
human_policy = lambda obs: cont_act_of_disc[human_action, :]
# Factory with the same call shape as make_synth_user_policy; all arguments
# are ignored and the single shared human_policy is returned.
make_human_user_policy = lambda *args, **kwargs: human_policy

#human_user = True # DEBUG

In [None]:
def make_reward_func(goal, using_reward_shaping=False, return_to_init=False):
  """Build the reward function for a task with target `goal`.

  goal: target position.
  using_reward_shaping: add the potential-based shaping term
    gamma * phi(obs) - phi(prev_obs) to the sparse reward.
  return_to_init: success requires returning to the initial position after
    the goal has been reached.

  Returns reward_func(prev_obs, action, obs) -> float. `action` is unused.
  """

  def reward_shaping(obs, init_pos, goal_reached_ind):
    """Potential phi: negative remaining straight-line distance to success."""
    pos = extract_pos(obs)
    if return_to_init and goal_reached_ind != 0:
      # Goal already visited: only the way back to the start remains.
      return -np.linalg.norm(pos - init_pos)
    # penalize distance to target
    phi = -np.linalg.norm(pos - goal)
    if return_to_init:
      # Still heading out: the return leg is part of the remaining distance,
      # which keeps phi continuous when the goal-reached flag flips.
      phi += -np.linalg.norm(goal - init_pos)
    return phi

  def reward_func(prev_obs, action, obs):
    pos = extract_pos(obs)
    init_pos = extract_init_pos(obs)
    goal_reached_ind = extract_goal_reached_ind(obs)

    at_goal = (np.abs(pos - goal) <= goal_dist_thresh).all()
    at_init = (np.abs(pos - init_pos) <= goal_dist_thresh).all()
    if (at_goal and not return_to_init) or (
      at_init and goal_reached_ind == 1 and return_to_init):
      r = succ_rew_bonus # bonus for reaching target
    else:
      r = 0

    if using_reward_shaping:
      # Evaluate each potential with its own state's goal-reached flag. The
      # previous code computed prev_goal_reached_ind but then passed the
      # current flag for both terms, so phi(prev_obs) was wrong on the step
      # where the goal is first reached.
      prev_init_pos = extract_init_pos(prev_obs)
      prev_goal_reached_ind = extract_goal_reached_ind(prev_obs)
      r += gamma * reward_shaping(obs, init_pos, goal_reached_ind) - reward_shaping(
        prev_obs, prev_init_pos, prev_goal_reached_ind) # standard reward shaping formula

    return r

  return reward_func

In [None]:
# one env/task per goal
envs = [CursorControl(goal=goal, human_user=human_user) for goal in goals]

In [None]:
def run_ep(policy, env, max_ep_len=max_ep_len, render=False, blending=0, human_user=False):
  """Run one episode of `policy` in `env` and return the rollout.

  policy: either a plain callable obs -> action, or an object exposing
    observe(obs) and act() (and optionally reset()).
  env: environment with a `blending` attribute; it is set to `blending` for
    the episode and restored afterwards, even if the episode raises.
  max_ep_len: at most max_ep_len + 1 environment steps are taken.
  render: render after every step and close the viewer at the end.
  human_user: hook up the keyboard handlers; forces render=True.

  Returns a list of (prev_obs, action, reward, obs, float(done), info).
  """
  global human_action, human_active

  if human_user:
    human_action = init_human_action()
    render = True
    # The viewer window only exists after a first render.
    env.reset()
    env.render()
    env.unwrapped.viewer.window.on_key_press = key_press
    env.unwrapped.viewer.window.on_key_release = key_release

  # Probe the policy's interface once instead of wrapping the calls in a bare
  # `except:`, which also hid genuine errors raised inside the policy.
  has_reset = callable(getattr(policy, 'reset', None))
  is_stateful = callable(getattr(policy, 'observe', None)) and callable(
    getattr(policy, 'act', None))

  old_blending = copy(env.blending)
  env.blending = blending

  rollout = []
  try:
    obs = env.reset()

    if has_reset:
      policy.reset()

    done = False
    prev_obs = obs

    for step_idx in range(max_ep_len+1):
      if done:
        break

      if is_stateful:
        policy.observe(obs)
        action = policy.act()
      else:
        action = policy(obs)

      obs, r, done, info = env.step(action)

      rollout.append((prev_obs, action, r, obs, float(done), info))
      prev_obs = obs
      if render:
        env.render()
  finally:
    env.blending = old_blending

    if render:
      env.close()

    if human_user:
      human_active = False

  return rollout

In [None]:
oracle_policies = [make_synth_user_policy(env.goal) for env in envs]

In [None]:
oracle_decoder_policy = internal_decode_act

In [None]:
# fixed, random, linear decoder
D_rand = np.random.random((n_act_dim, n_obs_dim))
rand_decoder_policy = D_rand.dot

fit bci simulator to recordings

In [None]:
# Directory holding the recorded sessions (one sub-directory per session).
rec_dir = os.path.join(data_dir, 'Bravo1_PythonData')

# Recorded coordinates are rescaled with this half-extent.
# NOTE(review): presumably the recording's screen spans
# [-screen_size, screen_size] — confirm the units against the recording format.
screen_size = 250
# Maps [-screen_size, screen_size] linearly onto [0, 1].
normalize_pos = lambda p: (p + screen_size) / (2 * screen_size)
# Same scale as normalize_pos, without the offset.
normalize_vel = lambda v: v / (2 * screen_size)

def rollout_from_raw_rec(raw_rec):
  """Convert one raw recording into a rollout of transition tuples.

  raw_rec: mapping with 'TargetPosition', 'CursorState' (rows 0-1 are used
    as position and rows 2-3 as velocity, one column per timestep) and
    'NeuralFeatures' (one column per timestep).

  The leading timesteps during which the cursor position does not change are
  skipped. Returns a list of (prev_obs, action, reward, obs, float(done),
  info) with action always None and info = {'goal': ..., 'succ': False}.

  Raises ValueError if the cursor never moves in the recording.
  """
  goal = normalize_pos(raw_rec['TargetPosition'])
  reward_func = make_reward_func(goal)

  contexts = raw_rec['CursorState']
  bci_outs = raw_rec['NeuralFeatures']
  assert contexts.shape[1] == bci_outs.shape[1]
  T_end = contexts.shape[1] - 1

  # Find the first timestep at which the cursor position changes. The bound
  # on T_start keeps the scan inside the array for a stationary recording,
  # which previously ran off the end with an IndexError.
  T_start = 0
  while T_start < T_end and (contexts[:2, T_start] == contexts[:2, T_start+1]).all():
    T_start += 1
  if T_start >= T_end:
    raise ValueError('recording contains no cursor movement')
  T_start += 1

  init_pos = normalize_pos(contexts[:2, T_start])
  # The goal-reached flag is never set for recordings.
  goal_reached = False
  goal_reached_ind = np.array([1.0 if goal_reached else 0.0])
  action = None
  succ = False
  info = {'goal': goal, 'succ': succ}

  rollout = []
  prev_obs = None
  for t in range(T_start, T_end):
    pos = normalize_pos(contexts[:2, t])
    vel = normalize_vel(contexts[2:4, t])
    bci_feats = bci_outs[:, t]
    obs = np.concatenate((pos, vel, init_pos, goal_reached_ind, bci_feats))

    if t > T_start:
      done = t == T_end - 1
      r = reward_func(prev_obs, action, obs)
      # Each transition gets its own info dict so that later per-transition
      # annotations do not leak into every other transition.
      rollout.append((prev_obs, action, r, obs, float(done), dict(info)))

    prev_obs = obs
  return rollout

In [None]:
# Walk rec_dir/<session>/*.pkl and convert every recording into a rollout.
# NOTE: pickle.load executes arbitrary code; only load trusted recordings.
rec_rollouts = []
for sess_dir in os.listdir(rec_dir):
  sess_path = os.path.join(rec_dir, sess_dir)
  if not os.path.isdir(sess_path):
    continue
  for rec_file in os.listdir(sess_path):
    if not rec_file.endswith('.pkl'):
      continue
    with open(os.path.join(sess_path, rec_file), 'rb') as f:
      raw_rec = pickle.load(f)
      rollout = rollout_from_raw_rec(raw_rec)
      rec_rollouts.append(rollout)

In [None]:
with open(os.path.join(data_dir, 'rec_rollouts.pkl'), 'wb') as f:
  pickle.dump(rec_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(data_dir, 'rec_rollouts.pkl'), 'rb') as f:
  rec_rollouts = pickle.load(f)

In [None]:
def make_context(ext_obs, goal):
  """Encoder input: the external observation followed by the goal."""
  return np.concatenate((ext_obs, goal))

In [None]:
def vectorize_rollouts(rollouts):
  """Flatten rollouts into per-timestep training arrays for the BCI simulator.

  For the first observation of every transition this collects the encoder
  context (external observation + goal), the BCI features, the full
  observation, and the action cart_to_polar(goal - pos).
  Returns (contexts, bci_outs, obses, acts) as numpy arrays.

  Note: a later cell defines another function with this same name.
  """
  contexts, bci_outs, obses, acts = [], [], [], []
  for rollout in rollouts:
    goal = rollout[0][-1]['goal']
    for transition in rollout:
      obs = transition[0]
      contexts.append(make_context(extract_ext_obs(obs), goal))
      bci_outs.append(extract_bci_obs(obs))
      obses.append(obs)
      acts.append(cart_to_polar(goal - extract_pos(obs)))
  return np.array(contexts), np.array(bci_outs), np.array(obses), np.array(acts)

In [None]:
contexts = None
bci_outs = None
train_idxes = None
val_batch = None
obses = None
acts = None

In [None]:
def process_rec_rollouts(rec_rollouts):
  """Vectorize the recorded rollouts and make a shuffled 90/10 split.

  Populates the notebook-level `contexts`, `bci_outs`, `obses`, `acts`,
  `train_idxes` (training indices) and `val_batch` (held-out arrays).
  """
  global contexts, bci_outs, train_idxes, val_batch, obses, acts

  contexts, bci_outs, obses, acts = vectorize_rollouts(rec_rollouts)

  shuffled = list(range(contexts.shape[0]))
  random.shuffle(shuffled)

  n_train_examples = int(0.9 * len(shuffled))
  train_idxes = shuffled[:n_train_examples]
  held_out = shuffled[n_train_examples:]
  val_batch = tuple(arr[held_out] for arr in (contexts, bci_outs, obses, acts))

In [None]:
process_rec_rollouts(rec_rollouts)

In [None]:
contexts.shape, bci_outs.shape

In [None]:
# BEGIN DEBUG

In [None]:
from sklearn.manifold import TSNE

In [None]:
X = TSNE(n_components=2).fit_transform(bci_outs)

In [None]:
plt.scatter(X[:, 0], X[:, 1], alpha=0.1)
plt.show()

In [None]:
# END DEBUG

In [None]:
def sample_batch(size):
  """Draw `size` distinct training indices and return the matching rows of
  (contexts, bci_outs, obses, acts)."""
  picked = random.sample(train_idxes, size)
  return contexts[picked], bci_outs[picked], obses[picked], acts[picked]

In [None]:
iterations = 100000
batch_size = 512
learning_rate = 1e-3

val_update_freq = 100

n_layers = 1
layer_size = 256
activation = tf.nn.relu

In [None]:
with open(os.path.join(data_dir, 'bci_sim_scope.pkl'), 'rb') as f:
  bci_sim_enc_scope, bci_sim_dec_scope = pickle.load(f)

In [None]:
bci_sim_enc_scope = str(uuid.uuid4())

In [None]:
bci_sim_dec_scope = str(uuid.uuid4())

In [None]:
# Encoder half of the BCI simulator: an MLP mapping the context
# (external observation + goal) to BCI features, trained with a
# mean-squared-error loss against the recorded features.
goal_dim = 2
context_ph = tf.placeholder(tf.float32, [None, n_ext_obs_dim + goal_dim]) # see make_context
bci_out_ph = tf.placeholder(tf.float32, [None, bci_dim])

# output_activation=None overrides build_mlp's softmax default, so the
# output layer is linear.
bci_out = build_mlp(
  context_ph, bci_dim, bci_sim_enc_scope, 
  n_layers=n_layers, size=layer_size,
  activation=activation, output_activation=None
)

enc_loss = tf.reduce_mean((bci_out - bci_out_ph)**2)

In [None]:
# Decoder half of the BCI simulator: an MLP mapping a full observation to an
# action, trained with a mean-squared-error loss.
# NOTE: the names obs_ph and act_ph are reassigned to differently shaped
# placeholders in the imitation-learning section further down.
obs_ph = tf.placeholder(tf.float32, [None, n_obs_dim])
act_ph = tf.placeholder(tf.float32, [None, n_act_dim])

# output_activation=None overrides build_mlp's softmax default, so the
# output layer is linear.
decoded_act = build_mlp(
  obs_ph, n_act_dim, bci_sim_dec_scope, 
  n_layers=n_layers, size=layer_size,
  activation=activation, output_activation=None
)

dec_loss = tf.reduce_mean((act_ph - decoded_act)**2)

In [None]:
loss = enc_loss + dec_loss

In [None]:
update_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

In [None]:
def trained_bci_sim_enc(ext_obs, goal, _context_ph=context_ph, _bci_out=bci_out):
  """Run the trained encoder on one (external observation, goal) pair.

  ext_obs, goal: 1-D arrays; they are concatenated by make_context.
  _context_ph, _bci_out: graph handles captured when this function is
    defined, so later reassignment of the notebook-level names cannot
    redirect the call. Callers should not pass them.
  Returns the predicted BCI features as a 1-D array.
  """
  context = make_context(ext_obs, goal)
  # No variable scope is needed here: sess.run creates no variables.
  return sess.run(_bci_out, feed_dict={_context_ph: context[None, :]})[0, :]

def trained_bci_sim_dec(obs, _obs_ph=obs_ph, _decoded_act=decoded_act):
  """Run the trained decoder on a single observation.

  obs: 1-D observation array.
  _obs_ph, _decoded_act: graph handles captured when this function is
    defined. The notebook later rebinds the global `obs_ph` to a placeholder
    of a different shape; looking the name up at call time would then feed
    the wrong placeholder. Callers should not pass them.
  Returns the decoded action as a 1-D array.
  """
  # No variable scope is needed here: sess.run creates no variables.
  return sess.run(_decoded_act, feed_dict={_obs_ph: obs[None, :]})[0, :]

In [None]:
def compute_batch_loss(batch, step=False, t=None):
  """Evaluate the simulator losses on `batch`, optionally taking a train step.

  batch: (contexts, bci_outs, obses, acts) arrays.
  step: if True, also run `update_op` on the same batch.
  t: unused; kept for call compatibility.
  Returns {'loss', 'enc_loss', 'dec_loss'} with the evaluated values.
  """
  batch_context, batch_bci_out, batch_obs, batch_act = batch
  feed_dict = {
    context_ph: batch_context,
    bci_out_ph: batch_bci_out,
    obs_ph: batch_obs,
    act_ph: batch_act
  }

  # Fetch the losses and (when training) the update in one session call, so
  # the forward pass is computed once instead of twice per step.
  fetches = [loss, enc_loss, dec_loss]
  if step:
    fetches.append(update_op)
  loss_eval, enc_loss_eval, dec_loss_eval = sess.run(fetches, feed_dict=feed_dict)[:3]

  d = {'loss': loss_eval, 'enc_loss': enc_loss_eval, 'dec_loss': dec_loss_eval}
  return d

In [None]:
tf.global_variables_initializer().run(session=sess)

In [None]:
train_logs = {
  'train_loss': [],
  'val_loss': [],
  'train_enc_loss': [],
  'val_enc_loss': [],
  'train_dec_loss': [],
  'val_dec_loss': []
}

In [None]:
# Train the BCI simulator. The loop resumes from however many iterations are
# already logged; validation loss is refreshed every val_update_freq steps and
# the most recent value is re-logged in between.
val_log = None
for t in range(len(train_logs['train_loss']), iterations):
  batch = sample_batch(batch_size)
  train_log = compute_batch_loss(batch, step=True, t=t)

  if val_log is None or t % val_update_freq == 0:
    val_log = compute_batch_loss(val_batch, step=False, t=t)

  print('%d %d %f %f %f %f' % (
    t, iterations, train_log['loss'], val_log['loss'], val_log['enc_loss'], val_log['dec_loss']))

  for prefix, log in (('train_', train_log), ('val_', val_log)):
    for k, v in log.items():
      train_logs['%s%s' % (prefix if 'loss' in k else '', k)].append(v)

In [None]:
plt.xlabel('Iterations')
plt.ylabel('Validation Encoder Loss')
plt.plot(train_logs['val_enc_loss'])
plt.yscale('log')
plt.show()

In [None]:
plt.xlabel('Iterations')
plt.ylabel('Validation Decoder Loss')
plt.plot(train_logs['val_dec_loss'])
plt.yscale('log')
plt.show()

In [None]:
internal_encode_obs = lambda action, ext_obs, goal: trained_bci_sim_enc(ext_obs, goal)
internal_decode_act = trained_bci_sim_dec

In [None]:
with open(os.path.join(data_dir, 'bci_sim_scope.pkl'), 'wb') as f:
  pickle.dump((bci_sim_enc_scope, bci_sim_dec_scope), f, pickle.HIGHEST_PROTOCOL)

In [None]:
save_tf_vars(bci_sim_enc_scope, os.path.join(data_dir, 'bci_sim_enc.tf'))

In [None]:
save_tf_vars(bci_sim_dec_scope, os.path.join(data_dir, 'bci_sim_dec.tf'))

In [None]:
load_tf_vars(bci_sim_enc_scope, os.path.join(data_dir, 'bci_sim_enc.tf'))

In [None]:
load_tf_vars(bci_sim_dec_scope, os.path.join(data_dir, 'bci_sim_dec.tf'))

In [None]:
oracle_decoder_policy = internal_decode_act

# one env/task per goal
envs = [CursorControl(goal=goal, human_user=human_user) for goal in goals]

sanity-check env, decoders

In [None]:
task_idx = 0

In [None]:
rollout = run_ep(oracle_policies[task_idx], envs[task_idx], render=True, human_user=human_user)

In [None]:
rollout = run_ep(oracle_decoder_policy, envs[task_idx], render=True, human_user=human_user)

In [None]:
rollout = run_ep(rand_decoder_policy, envs[task_idx], render=True, human_user=human_user)

In [None]:
envs[task_idx].close()

evaluate oracle and random decoders

In [None]:
n_eval_rollouts = 100

def make_env(train_goal=True):
  """Create a CursorControl env with a training goal or a fresh random goal.

  train_goal: True -> pick one of the rows of `goals` uniformly at random;
    False -> draw a new goal uniformly from the unit square.
  """
  if train_goal:
    # Index by the actual number of goals; n_tasks can exceed it.
    test_goal = goals[np.random.choice(len(goals))]
  else:
    test_goal = np.random.random(2)
  env = CursorControl(goal=test_goal)
  return env

def compute_act_pred_err(decoder_policy):
  """Mean distance between decoded and oracle actions on the recorded rollouts.

  decoder_policy: either a plain callable obs -> action, or an object
    exposing observe(obs) and act() (and optionally reset(), which is called
    at the start of every rollout).
  Returns the mean over all recorded transitions of
  ||oracle_action - decoded_action||.
  """
  # Probe the policy's interface once instead of using bare `except:` blocks,
  # which also hid genuine errors raised inside the policy.
  has_reset = callable(getattr(decoder_policy, 'reset', None))
  is_stateful = callable(getattr(decoder_policy, 'observe', None)) and callable(
    getattr(decoder_policy, 'act', None))

  errs = []
  for rollout in rec_rollouts:
    goal = rollout[0][-1]['goal']
    oracle_policy = make_synth_user_policy(goal)
    if has_reset:
      decoder_policy.reset()
    for obs, act, rew, next_obs, done, info in rollout:
      opt_act = oracle_policy(obs)
      if is_stateful:
        decoder_policy.observe(obs)
        decoded_act = decoder_policy.act()
      else:
        decoded_act = decoder_policy(obs)
      errs.append(np.linalg.norm(opt_act - decoded_act))
  return np.mean(errs)

def evaluate_decoder_policy(decoder_policy, env=None, n_rollouts=n_eval_rollouts):
  """Roll out `decoder_policy` and summarise its performance.

  env: environment to evaluate in; when None, one env with a random goal is
    created and shared by all rollouts.
  Returns (rollouts, perf) where perf holds the mean return ('rew'), success
  rate ('succ'), mean 'ttt', mean 'dfot', and the action-prediction error on
  the recordings ('ape').
  """
  if env is None:
    env = make_env(train_goal=False)

  rollouts = []
  for _ in range(n_rollouts):
    rollouts.append(run_ep(
      decoder_policy, env, render=human_user,
      blending=0, human_user=human_user))

  returns = [sum(step[2] for step in rollout) for rollout in rollouts]
  successes = [1 if is_succ(rollout) else 0 for rollout in rollouts]
  times = [get_ttt(rollout) for rollout in rollouts]
  deviations = [get_dfot(rollout) for rollout in rollouts]

  perf = {
    'rew': np.mean(returns),
    'succ': np.mean(successes),
    'ttt': np.mean(times),
    'dfot': np.mean(deviations),
    'ape': compute_act_pred_err(decoder_policy)
  }
  return rollouts, perf

In [None]:
oracle_rollouts, oracle_perf = evaluate_decoder_policy(oracle_decoder_policy, n_rollouts=n_eval_rollouts)

In [None]:
with open(os.path.join(data_dir, 'oracle_eval.pkl'), 'wb') as f:
  pickle.dump((oracle_rollouts, oracle_perf), f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(data_dir, 'oracle_eval.pkl'), 'rb') as f:
  oracle_rollouts, oracle_perf = pickle.load(f)

In [None]:
rand_rollouts, rand_perf = evaluate_decoder_policy(rand_decoder_policy, n_rollouts=n_eval_rollouts)

In [None]:
with open(os.path.join(data_dir, 'rand_eval.pkl'), 'wb') as f:
  pickle.dump((rand_rollouts, rand_perf), f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(data_dir, 'rand_eval.pkl'), 'rb') as f:
  rand_rollouts, rand_perf = pickle.load(f)

train decoder with imitation learning

[Neuroprosthetic decoder training as imitation learning](https://arxiv.org/abs/1511.04156)

In [None]:
n_demo_rollouts_per_task = 10

In [None]:
demo_policies = [rand_decoder_policy for _ in range(n_tasks)]

In [None]:
demo_blending = 0.75 # mixture coefficient for random vs. oracle decoder (1 -> pure oracle)

In [None]:
def label_actions(rollout, policy):
  """Return a copy of `rollout` whose actions are replaced by `policy` labels.

  rollout: list of (obs, action, ..., info) transition tuples.
  policy: callable mapping a transition's first observation to the label.

  In each returned transition the action is policy(obs), and the action that
  was actually taken is kept under info['action_taken']. The input rollout
  and its info dicts are left untouched; the previous implementation
  rewrote them in place, silently altering rollouts shared with other code.
  """
  labeled = []
  for transition in rollout:
    transition = list(transition)
    info = dict(transition[-1])
    info['action_taken'] = transition[1]
    transition[1] = policy(transition[0]) # replace taken action with action label
    transition[-1] = info
    labeled.append(tuple(transition))
  return labeled

In [None]:
demo_rollouts = [label_actions(run_ep(
  demo_policies[task_idx], env, render=human_user, 
  blending=demo_blending, human_user=human_user
), oracle_policies[task_idx]) for _ in range(
  n_demo_rollouts_per_task) for task_idx, env in enumerate(envs)]

In [None]:
demo_rollouts += [label_actions(
  rollout, make_synth_user_policy(rollout[0][-1]['goal'])) for rollout in rec_rollouts]

In [None]:
with open(os.path.join(data_dir, 'demo_rollouts.pkl'), 'wb') as f:
  pickle.dump(demo_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
with open(os.path.join(data_dir, 'demo_rollouts.pkl'), 'rb') as f:
  demo_rollouts = pickle.load(f)

In [None]:
# max number of timesteps into the past 
# that the RNN decoder can look at
history_len = 1

In [None]:
def build_mask(i, n): # for RNN training
  """Length-n float vector whose first i entries are 1 and the rest 0."""
  mask = np.zeros(n)
  mask[:i] = 1
  return mask

def pad_obses(obses, n):
  """Return `obses` as a list, right-padded to length n with zero arrays
  shaped like its last element."""
  padded = list(obses)
  padded.extend([np.zeros(obses[-1].shape)] * (n - len(obses)))
  return padded

def pad_acts(acts, n):
  """Return `acts` as a list, right-padded to length n with zero arrays
  shaped like its last element."""
  padded = list(acts)
  padded.extend([np.zeros(acts[-1].shape)] * (n - len(acts)))
  return padded

def vectorize_rollouts(rollouts):
  """Cut rollouts into windows of `history_len` steps for decoder training.

  Every rollout contributes one window per start index (at least one, even
  when it is shorter than history_len); short windows are zero-padded and
  their mask marks the real steps.
  Returns (obses, actions, masks) as numpy arrays.

  Note: this replaces the earlier function of the same name.
  """
  obses, actions, masks = [], [], []
  for rollout in rollouts:
    more_obses, more_actions = list(zip(*rollout))[:2]
    n_windows = max(1, len(more_obses)-history_len+1)
    for start in range(n_windows):
      stop = start + history_len
      window_obses = more_obses[start:stop]
      window_acts = more_actions[start:stop]
      obses.append(pad_obses(window_obses, history_len))
      actions.append(pad_acts(window_acts, history_len))
      masks.append(build_mask(len(window_obses), history_len))
  return np.array(obses), np.array(actions), np.array(masks)

In [None]:
demo_obses = None
demo_actions = None
demo_masks = None
train_idxes = None
val_batch = None

In [None]:
def process_demo_rollouts(demo_rollouts):
  """Vectorize the demonstrations and make a shuffled 90/10 split.

  Populates the notebook-level `demo_obses`, `demo_actions`, `demo_masks`,
  `train_idxes` (training indices) and `val_batch` (held-out arrays).
  """
  global demo_obses, demo_actions, demo_masks, train_idxes, val_batch

  demo_obses, demo_actions, demo_masks = vectorize_rollouts(demo_rollouts)

  order = list(range(demo_obses.shape[0]))
  random.shuffle(order)

  n_train_examples = int(0.9 * len(order))
  train_idxes = order[:n_train_examples]
  held_out = order[n_train_examples:]
  val_batch = demo_obses[held_out], demo_actions[held_out], demo_masks[held_out]

In [None]:
process_demo_rollouts(demo_rollouts)

In [None]:
demo_obses.shape, demo_actions.shape

In [None]:
def aggregate_rollouts(): # DAgger step
  """Collect new on-policy rollouts, label them, and grow the demo dataset.

  For every (oracle policy, env) pair, runs `n_agg_rollouts` episodes of
  `trained_decoder_policy` blended with the optimal action by
  `dagger_blending`, relabels their actions with the oracle, appends them to
  `demo_rollouts` and re-vectorizes the dataset. Finally sets
  `dagger_blending` to one minus the success rate of the new rollouts.
  """
  # Both globals must be declared before first use: declaring
  # `global dagger_blending` after reading the name is a SyntaxError on
  # Python 3.6+.
  global demo_rollouts, dagger_blending

  rollouts = []
  for oracle_policy, env in zip(oracle_policies, envs):
    if human_user:
      input('Hit ENTER to begin %d episodes...' % n_agg_rollouts)
      time.sleep(5)
    for _ in range(n_agg_rollouts):
      rollouts.append(label_actions(run_ep(
        trained_decoder_policy, env, render=human_user,
        blending=dagger_blending, human_user=human_user), oracle_policy))
  demo_rollouts += rollouts
  process_demo_rollouts(demo_rollouts)

  # dynamically adjust blending coeff
  dagger_blending = 1 - np.mean([1 if is_succ(rollout) else 0 for rollout in rollouts])

In [None]:
def sample_batch(size):
  """Return a random training batch as an (obses, actions, masks) tuple.

  Draws `size` distinct indices from the global `train_idxes` and uses them to
  index the global demonstration arrays.
  """
  picked = random.sample(train_idxes, size)
  return tuple(arr[picked] for arr in (demo_obses, demo_actions, demo_masks))

In [None]:
# Imitation-learning hyper-parameters.
iterations = 100000 # total number of training iterations (gradient steps)
batch_size = 512 # examples drawn per training step
learning_rate = 1e-3 # step size passed to the Adam optimizer

# RNN hidden layer size
num_hidden = 256

val_update_freq = 100 # how frequently (in iterations) to evaluate trained decoder on validation env
n_val_eval_rollouts = 10 # number of rollouts in validation env

# DAgger params
agg_freq = 100 # how frequently (in iterations) to run a DAgger aggregation step
n_agg_rollouts = 10 # number of rollouts to aggregate into dataset per iteration of DAgger
dagger_blending = 0.75 # initial blending coeff

In [None]:
# Restore the variable-scope name of a previously trained decoder from disk.
with open(os.path.join(data_dir, 'imi_decoder_scope.pkl'), 'rb') as f:
  imi_decoder_scope = pickle.load(f)

In [None]:
# Start a fresh, uniquely named variable scope for a new decoder.
# NOTE: running this after the load cell above replaces the loaded scope name.
imi_decoder_scope = str(uuid.uuid4())

In [None]:
# Graph inputs for the recurrent decoder; the leading dimension is the batch size.
obs_ph = tf.placeholder(tf.float32, [None, history_len, n_obs_dim]) # observations
act_ph = tf.placeholder(tf.float32, [None, history_len, n_act_dim]) # actions
mask_ph = tf.placeholder(tf.float32, [None, history_len]) # masks for RNN training
init_state_a_ph = tf.placeholder(tf.float32, [None, num_hidden]) # initial state for RNN training
init_state_b_ph = tf.placeholder(tf.float32, [None, num_hidden]) # second half of the initial LSTM state tuple

In [None]:
# Build the recurrent decoder: an LSTM cell unrolled by hand over `history_len`
# steps, with one shared linear read-out producing an action at every step.
with tf.variable_scope(imi_decoder_scope, reuse=tf.AUTO_REUSE):
  # linear read-out from the LSTM output (num_hidden) to the action space (n_act_dim)
  weights = {'out': tf.Variable(tf.random_normal([num_hidden, n_act_dim]))}
  biases = {'out': tf.Variable(tf.random_normal([n_act_dim]))}

  # split obs_ph along axis 1 into a list of history_len per-step tensors
  unstacked_X = tf.unstack(obs_ph, history_len, 1)

  lstm_cell = tf.nn.rnn_cell.LSTMCell(num_hidden)

  # the initial state is fed through placeholders so callers can carry state across calls
  state = (init_state_a_ph, init_state_b_ph)
  rnn_outputs = [] # per-step action predictions
  rnn_states = [] # per-step LSTM states (after consuming that step's input)
  for input_ in unstacked_X:
    output, state = lstm_cell(input_, state)
    rnn_outputs.append(tf.matmul(output, weights['out']) + biases['out'])
    rnn_states.append(state)

# reassemble the per-step predictions into [batch, history_len, n_act_dim]
reshaped_rnn_outputs = tf.reshape(
  tf.concat(rnn_outputs, axis=1), 
  shape=[tf.shape(obs_ph)[0], history_len, n_act_dim]
)

# masked MSE: squared error averaged over action dims, summed over the
# time steps selected by mask_ph, and normalized by the mask total
loss = tf.reduce_sum(tf.reduce_mean((
  reshaped_rnn_outputs - act_ph)**2, axis=2) * mask_ph) / tf.reduce_sum(mask_ph)

In [None]:
# Training op: one Adam step on the masked imitation loss.
update_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

In [None]:
class TrainedDecoderPolicy(object):
  """Stateful wrapper that steps the trained LSTM decoder one observation at a time.

  Usage: call `reset()` at the start of an episode, `observe(obs)` for each new
  observation, then `act()` / `get_hidden_state()` to read the results.
  """

  def __init__(self):
    self.hidden_state = None # (c, h) LSTM state; populated by reset()
    self.action = None # last predicted action; populated by observe()
    # reusable feed buffer; only time step 0 is ever written or read
    self.obs_feed = np.zeros((1, history_len, n_obs_dim))

  def reset(self):
    """Zero the recurrent state and forget the last action."""
    self.hidden_state = (np.zeros((1, num_hidden)), np.zeros((1, num_hidden)))
    self.action = None

  def _feed_dict(self, obs):
    """Build the feed dict for a single decoder step on `obs`."""
    self.obs_feed[0, 0, :] = obs
    return {
      init_state_a_ph: self.hidden_state[0],
      init_state_b_ph: self.hidden_state[1],
      obs_ph: self.obs_feed
    }

  def observe(self, obs):
    """Run one decoder step on `obs`, storing the predicted action and new state."""
    with tf.variable_scope(imi_decoder_scope, reuse=tf.AUTO_REUSE):
      self.action, hidden_state = sess.run(
        [rnn_outputs[0], rnn_states[0]], feed_dict=self._feed_dict(obs))
    # with history_len == 1 the state is deliberately left at its reset value
    if history_len > 1:
      self.hidden_state = hidden_state

  def act(self):
    """Return the action predicted by the most recent `observe()` call."""
    assert self.action.shape[0] == 1
    return self.action[0, :]

  def get_hidden_state(self):
    """Return the first (cell) component of the recurrent state as a 1-D array.

    Indexes the state positionally so this works both for the plain tuple
    created by `reset()` and for the state tuple returned by `sess.run`.
    """
    cell_state = self.hidden_state[0]
    assert cell_state.shape[0] == 1
    return cell_state[0, :]
  
# Shared policy instance used for evaluation, DAgger rollouts and as the copilot's pilot.
trained_decoder_policy = TrainedDecoderPolicy()

In [None]:
# Initialize every global TF variable in `sess`.
# NOTE: this also re-initializes any variables that were already trained or loaded.
tf.global_variables_initializer().run(session=sess)

In [None]:
# Per-iteration training history: one independent list per logged quantity.
train_logs = {
  metric: []
  for metric in ('train_loss', 'val_loss', 'rew', 'succ', 'ttt', 'dfot', 'ape')
}

In [None]:
def compute_batch_loss(batch, step=False, t=None):
  """Evaluate the imitation loss on `batch`, optionally taking a gradient step.

  Args:
    batch: (obses, actions, masks) arrays with a shared leading batch dimension.
    step: when True, apply `update_op` after measuring the loss; when False,
      additionally roll out the trained decoder in a fresh validation env and
      merge the resulting performance metrics into the returned dict.
    t: current iteration index (accepted for the caller's convenience; unused).

  Returns:
    dict with key 'loss' (measured before any update) plus, when `step` is
    False, the metrics returned by `evaluate_decoder_policy`.
  """
  obs_batch, act_batch, mask_batch = batch
  n_examples = obs_batch.shape[0]
  feeds = {
    obs_ph: obs_batch,
    act_ph: act_batch,
    mask_ph: mask_batch,
    # every sequence starts from a zero recurrent state
    init_state_a_ph: np.zeros((n_examples, num_hidden)),
    init_state_b_ph: np.zeros((n_examples, num_hidden))
  }
  result = {'loss': sess.run(loss, feed_dict=feeds)}

  if step:
    sess.run(update_op, feed_dict=feeds)
    return result

  _, val_perf = evaluate_decoder_policy(
    trained_decoder_policy,
    env=make_env(train_goal=False),
    n_rollouts=n_val_eval_rollouts
  )
  result.update(val_perf)
  return result

In [None]:
# Main imitation-learning loop. The number of entries already in
# train_logs['train_loss'] is the iteration counter, so re-running this cell
# resumes from where the logs left off.
val_log = None
while len(train_logs['train_loss']) < iterations:
  batch = sample_batch(batch_size)
  
  t = len(train_logs['train_loss'])
  train_log = compute_batch_loss(batch, step=True, t=t)
  # validation is refreshed every val_update_freq iterations; in between, the
  # most recent val_log is re-logged so all series stay the same length
  if val_log is None or t % val_update_freq == 0:
    val_log = compute_batch_loss(val_batch, step=False, t=t)
    
  # DAgger: periodically collect on-policy rollouts labeled by the oracle
  if t % agg_freq == 0:
    aggregate_rollouts()
  
  print('%d %d %f %f %f %f %f %f %f' % (
    t, iterations, train_log['loss'], val_log['loss'], 
    val_log['rew'], val_log['succ'], val_log['ttt'], val_log['dfot'], val_log['ape']))
  
  # 'loss' is stored under 'train_loss' / 'val_loss'; every other key keeps its name
  for k, v in train_log.items():
    train_logs['%s%s' % ('train_' if k == 'loss' else '', k)].append(v)
  for k, v in val_log.items():
    train_logs['%s%s' % ('val_' if k == 'loss' else '', k)].append(v)

In [None]:
# Validation loss per training iteration.
plt.xlabel('Iterations')
plt.ylabel('Validation Loss')
plt.plot(train_logs['val_loss'])
plt.show()

In [None]:
# Validation reward of the trained decoder, with the oracle and random
# baselines (oracle_perf / rand_perf, defined outside this cell) as reference lines.
plt.xlabel('Iterations')
plt.ylabel('Reward')
plt.axhline(y=oracle_perf['rew'], linestyle='--', color='teal', label='Oracle')
plt.axhline(y=rand_perf['rew'], linestyle=':', color='gray', label='Random')
plt.plot(train_logs['rew'], color='orange', label='Trained')
plt.legend(loc='best')
plt.show()

In [None]:
# Validation success rate of the trained decoder against the oracle and random baselines.
plt.xlabel('Iterations')
plt.ylabel('Success Rate')
plt.axhline(y=oracle_perf['succ'], linestyle='--', color='teal', label='Oracle')
plt.axhline(y=rand_perf['succ'], linestyle=':', color='gray', label='Random')
plt.plot(train_logs['succ'], color='orange', label='Trained')
plt.ylim([-0.05, 1.05]) # small margin around the [0, 1] range
plt.legend(loc='best')
plt.show()

In [None]:
# Validation time-to-target of the trained decoder against the oracle and random baselines.
plt.xlabel('Iterations')
plt.ylabel('Time to Target')
plt.axhline(y=oracle_perf['ttt'], linestyle='--', color='teal', label='Oracle')
plt.axhline(y=rand_perf['ttt'], linestyle=':', color='gray', label='Random')
plt.plot(train_logs['ttt'], color='orange', label='Trained')
# NOTE(review): these limits presume 'ttt' is normalized to [0, 1] - confirm
plt.ylim([-0.05, 1.05])
plt.legend(loc='best')
plt.show()

In [None]:
# Validation deviation-from-optimal-trajectory of the trained decoder against the baselines.
plt.xlabel('Iterations')
plt.ylabel('Deviation from Optimal Trajectory')
plt.axhline(y=oracle_perf['dfot'], linestyle='--', color='teal', label='Oracle')
plt.axhline(y=rand_perf['dfot'], linestyle=':', color='gray', label='Random')
plt.plot(train_logs['dfot'], color='orange', label='Trained')
# NOTE(review): these limits presume 'dfot' is normalized to [0, 1] - confirm
plt.ylim([-0.05, 1.05])
plt.legend(loc='best')
plt.show()

In [None]:
# Validation action-prediction error of the trained decoder against the baselines.
plt.xlabel('Iterations')
plt.ylabel('Action Prediction Error')
plt.axhline(y=oracle_perf['ape'], linestyle='--', color='teal', label='Oracle')
plt.axhline(y=rand_perf['ape'], linestyle=':', color='gray', label='Random')
plt.plot(train_logs['ape'], color='orange', label='Trained')
# NOTE(review): these limits presume 'ape' is normalized to [0, 1] - confirm
plt.ylim([-0.05, 1.05])
plt.legend(loc='best')
plt.show()

In [None]:
# Persist the training history and the aggregated (DAgger) demonstration rollouts.
with open(os.path.join(data_dir, 'train_logs.pkl'), 'wb') as f:
  pickle.dump(train_logs, f, pickle.HIGHEST_PROTOCOL)
  
# demo_rollouts is written under the name 'agg_rollouts.pkl'
with open(os.path.join(data_dir, 'agg_rollouts.pkl'), 'wb') as f:
  pickle.dump(demo_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the training history and aggregated rollouts saved by the cell above.
with open(os.path.join(data_dir, 'train_logs.pkl'), 'rb') as f:
  train_logs = pickle.load(f)
  
# the saved demo_rollouts come back under the name agg_rollouts
with open(os.path.join(data_dir, 'agg_rollouts.pkl'), 'rb') as f:
  agg_rollouts = pickle.load(f)

In [None]:
# Visualization setup: a cursor-control env with a fixed goal at (0.5, 0.5).
n_viz_rollouts = 50 # rollouts collected per policy for visualization

center_goal = np.array([0.5, 0.5])
viz_env = CursorControl(goal=center_goal)

In [None]:
# Collect rollouts of the random decoder in the visualization env (metrics discarded).
rand_rollouts, _ = evaluate_decoder_policy(
  rand_decoder_policy, env=viz_env, n_rollouts=n_viz_rollouts)

In [None]:
# Pick 20 distinct random-decoder rollouts to plot.
rand_rollouts_sample = random.sample(rand_rollouts, 20)

In [None]:
# Plot the sampled random-decoder trajectories.
plot_trajectories(rand_rollouts_sample, center_goal, 'Random', 'rand-traj.pdf')

In [None]:
# Collect rollouts of the trained decoder in the visualization env (metrics discarded).
trained_rollouts, _ = evaluate_decoder_policy(
  trained_decoder_policy, env=viz_env, n_rollouts=n_viz_rollouts)

In [None]:
# Pick 20 distinct trained-decoder rollouts to plot.
trained_rollouts_sample = random.sample(trained_rollouts, 20)

In [None]:
# Plot the sampled trained-decoder trajectories.
plot_trajectories(trained_rollouts_sample, center_goal, 'Trained', 'trained-traj.pdf')

In [None]:
# Save the decoder's variable-scope name so its weights can be located again later.
with open(os.path.join(data_dir, 'imi_decoder_scope.pkl'), 'wb') as f:
  pickle.dump(imi_decoder_scope, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Save all TF variables under the decoder's scope.
save_tf_vars(imi_decoder_scope, os.path.join(data_dir, 'imi_decoder.tf'))

In [None]:
# Restore the decoder's TF variables from the checkpoint written above.
load_tf_vars(imi_decoder_scope, os.path.join(data_dir, 'imi_decoder.tf'))

In [None]:
# Run one rendered episode with the trained decoder for visual inspection.
rollout = run_ep(trained_decoder_policy, viz_env, render=True, human_user=human_user)

In [None]:
# Release the visualization env.
viz_env.close()

fine-tune decoder with q-learning + hard-coded reward function (simulating user feedback)

In [None]:
def co_build_act(make_obs_ph, q_func, mu_func, scope='deepq', reuse=tf.AUTO_REUSE):
  """Build the acting function that blends the pilot's action with the copilot's.

  Args:
    make_obs_ph: callable name -> observation input (wrapped with U.ensure_tf_input).
    q_func: accepted for signature parity with co_build_train; unused here.
    mu_func: callable (obs, scope, reuse) -> the copilot's preferred action.
    scope: variable scope that holds the copilot's variables.
    reuse: reuse flag for `scope`.

  Returns:
    A U.function taking (observation, pilot_action, pilot_tol) and returning a
    one-element list holding
    `pilot_tol * pilot_action + (1 - pilot_tol) * mu(observation)`.
  """
  with tf.variable_scope(scope, reuse=reuse):
    observations_ph = U.ensure_tf_input(make_obs_ph('observation'))
    pilot_action_ph = tf.placeholder(tf.float32, [None, n_act_dim], name='pilot_action')
    pilot_tol_ph = tf.placeholder(tf.float32, (), name='pilot_tol')

    # Build mu under 'q_func' (relative to the enclosing scope). With this
    # notebook's convention that mu_func appends '/mu_func' to the scope it is
    # given, this resolves to the same `q_func/mu_func` variables the Q-function
    # creates and the optimizer trains. Passing `scope` again here would nest it
    # as `<scope>/<scope>/mu_func`, a separate network that is never updated.
    opt_actions = mu_func(observations_ph.get(), scope='q_func', reuse=tf.AUTO_REUSE)
    actions = pilot_tol_ph * pilot_action_ph + (1 - pilot_tol_ph) * opt_actions

    act = U.function(inputs=[observations_ph, pilot_action_ph, pilot_tol_ph], outputs=[actions])
    return act

In [None]:
def co_build_train(
  make_obs_ph, q_func, mu_func, v_func, optimizer, grad_norm_clipping=None, gamma=1.0,
  double_q=True, scope='deepq', reuse=tf.AUTO_REUSE):
  """Build the act / train / update-target functions for the copilot Q-learner.

  Args:
    make_obs_ph: callable name -> observation input.
    q_func: callable (obs, action, scope, reuse) -> Q(obs, action).
    mu_func: callable (obs, scope, reuse) -> greedy action.
    v_func: callable (obs, scope, reuse) -> state value.
    optimizer: TF optimizer used to minimize the weighted Huber TD error.
    grad_norm_clipping: if not None, clip gradients to this norm.
    gamma: discount factor.
    double_q: if True, pick the next action with the online network and score
      it with the target network; otherwise bootstrap from the target V.
    scope: variable scope that holds all copilot variables.
    reuse: reuse flag for `scope` and for the online Q-function.

  Returns:
    (act_f, train, update_target, debug) where `train` takes
    (obs_t, action, reward, obs_tp1, done, weight) and returns the TD error,
    `update_target` copies online variables into the target network, and
    `debug['q_values']` takes (obs_t, action) and returns Q(obs_t, action).
  """
  act_f = co_build_act(make_obs_ph, q_func, mu_func, scope=scope, reuse=reuse)

  with tf.variable_scope(scope, reuse=reuse):
    # set up placeholders
    obs_t_input = U.ensure_tf_input(make_obs_ph('obs_t'))
    act_t_ph = tf.placeholder(tf.float32, [None, n_act_dim], name='action')
    rew_t_ph = tf.placeholder(tf.float32, [None], name='reward')
    obs_tp1_input = U.ensure_tf_input(make_obs_ph('obs_tp1'))
    done_mask_ph = tf.placeholder(tf.float32, [None], name='done')
    importance_weights_ph = tf.placeholder(tf.float32, [None], name='weight')

    obs_t_input_get = obs_t_input.get()
    obs_tp1_input_get = obs_tp1_input.get()

    # q network evaluation
    q_t_selected = q_func(obs_t_input_get, act_t_ph, scope='q_func', reuse=reuse)
    q_scope_name = U.absolute_scope_name('q_func')
    q_func_vars = U.scope_vars(q_scope_name)

    # compute estimate of best possible value starting from state at t + 1
    if double_q:
      # Select the next action with the ONLINE mu network. mu_func appends
      # '/mu_func' to the scope it receives, so 'q_func' resolves to the
      # existing `q_func/mu_func` variables; AUTO_REUSE shares them rather than
      # creating a second, never-trained copy.
      best_act_tp1 = mu_func(obs_tp1_input_get, scope='q_func', reuse=tf.AUTO_REUSE)
      q_tp1_best = q_func(obs_tp1_input_get, best_act_tp1, scope='target_q_func')
    else:
      # v_func appends '/v_func' itself; flatten to [batch] so the value lines
      # up element-wise with the [batch] reward and done placeholders
      q_tp1_best = tf.reshape(
        v_func(obs_tp1_input_get, scope='target_q_func', reuse=reuse), [-1])
    q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best

    target_scope_name = U.absolute_scope_name('target_q_func')
    target_q_func_vars = U.scope_vars(target_scope_name)

    # compute RHS of bellman equation
    q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked

    # compute the error (potentially clipped)
    td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
    errors = U.huber_loss(td_error)
    weighted_error = tf.reduce_mean(importance_weights_ph * errors)

    # compute optimization op (potentially with gradient clipping)
    if grad_norm_clipping is not None:
      optimize_expr = U.minimize_and_clip(
        optimizer,
        weighted_error,
        var_list=q_func_vars,
        clip_val=grad_norm_clipping
      )
    else:
      optimize_expr = optimizer.minimize(weighted_error, var_list=q_func_vars)

    # update_target_fn will be called periodically to copy Q network to target Q network.
    # Variables are paired by their name relative to each network's scope, so
    # the copy stays correct even when the two scopes hold different numbers
    # of variables (e.g. the target network only contains a value head).
    online_by_suffix = {v.name[len(q_scope_name):]: v for v in q_func_vars}
    update_target_expr = []
    for var_target in sorted(target_q_func_vars, key=lambda v: v.name):
      suffix = var_target.name[len(target_scope_name):]
      if suffix not in online_by_suffix:
        raise ValueError(
          'target variable %s has no counterpart in the online Q network' % var_target.name)
      update_target_expr.append(var_target.assign(online_by_suffix[suffix]))
    update_target_expr = tf.group(*update_target_expr)

    # Create callable functions
    train = U.function(
      inputs=[
        obs_t_input,
        act_t_ph,
        rew_t_ph,
        obs_tp1_input,
        done_mask_ph,
        importance_weights_ph
      ],
      outputs=td_error,
      updates=[optimize_expr]
    )
    update_target = U.function([], [], updates=[update_target_expr])

    # Q(obs, action) depends on the action placeholder, so it must be an input too
    q_values = U.function([obs_t_input, act_t_ph], q_t_selected)

  return act_f, train, update_target, {'q_values': q_values}

In [None]:
def co_dqn_learn(
  env,
  q_func,
  mu_func,
  v_func,
  lr=1e-3,
  max_timesteps=100000,
  buffer_size=50000,
  train_freq=1,
  batch_size=32,
  print_freq=1,
  checkpoint_freq=10000,
  learning_starts=1000,
  gamma=1.0,
  target_network_update_freq=500,
  exploration_fraction=0.1,
  exploration_final_eps=0.02,
  num_cpu=5,
  callback=None,
  scope='deepq',
  pilot_tol=0,
  pilot_is_human=False,
  reuse=tf.AUTO_REUSE,
  buff_init_rollouts=None):
  """Train the copilot with Q-learning while it shares control with a pilot.

  Args:
    env: environment whose observations end with the pilot's action
      (read back with `extract_pilot_action`).
    q_func, mu_func, v_func: network builders forwarded to `co_build_train`.
    lr: Adam learning rate.
    max_timesteps: total number of environment steps to run.
    buffer_size: replay buffer capacity.
    batch_size: minibatch size for each training step.
    print_freq: log every `print_freq` episodes (None disables logging).
    checkpoint_freq: consider saving the model every this many steps
      (None disables checkpointing).
    learning_starts: number of steps to collect before training begins.
    gamma: discount factor.
    target_network_update_freq: steps between target-network syncs.
    num_cpu: CPU count for the session created when none exists yet.
    scope: variable scope for the copilot's variables.
    pilot_tol: blending weight given to the pilot's action.
    pilot_is_human: render each step, pause between episodes, and only apply
      `pilot_tol` while the human agent is active.
    reuse: reuse flag forwarded to `co_build_train`.
    buff_init_rollouts: optional rollouts used to pre-fill the replay buffer;
      each step is (obs, _, rew, new_obs, done, info) with the executed action
      stored in info['action_taken'].
    train_freq, exploration_fraction, exploration_final_eps, callback:
      accepted for signature compatibility; not used by this implementation.

  Returns:
    (ActWrapper(act, act_params), reward_data), where reward_data holds the
    per-episode 'rewards' and 'outcomes'.
  """
  # Declared exactly once, up front: Python 3.6+ raises a SyntaxError for a
  # `global` statement that follows a use or assignment of the same name.
  global human_agent_action
  global human_agent_active

  # Create all the functions necessary to train the model

  sess = U.get_session()
  if sess is None:
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()

  def make_obs_ph(name):
    return U.BatchInput(env.observation_space.shape, name=name)

  act, train, update_target, debug = co_build_train(
    scope=scope,
    make_obs_ph=make_obs_ph,
    q_func=q_func,
    mu_func=mu_func,
    v_func=v_func,
    optimizer=tf.train.AdamOptimizer(learning_rate=lr),
    gamma=gamma,
    grad_norm_clipping=10,
    reuse=reuse
  )

  act_params = {
    'make_obs_ph': make_obs_ph,
    'q_func': q_func,
    'mu_func': mu_func,
    'v_func': v_func
  }

  replay_buffer = ReplayBuffer(buffer_size)

  # Optionally seed the buffer with previously collected transitions, using the
  # action that was actually executed rather than the recorded one.
  if buff_init_rollouts is not None:
    for rollout in buff_init_rollouts:
      for init_obs, _, init_rew, init_new_obs, init_done, init_info in rollout:
        replay_buffer.add(
          init_obs, init_info['action_taken'], init_rew, init_new_obs, float(init_done))

  # Initialize the parameters and copy them to the target network.
  U.initialize()
  update_target()

  episode_rewards = [0.0]
  episode_outcomes = []
  saved_mean_reward = None
  obs = env.reset()
  prev_t = 0
  rollouts = []

  if pilot_is_human:
    human_agent_action = init_human_action()
    human_agent_active = False

  with tempfile.TemporaryDirectory() as td:
    model_saved = False
    model_file = os.path.join(td, 'model')
    for t in range(max_timesteps):
      act_kwargs = {
        'pilot_action': extract_pilot_action(obs)[None, :],
        # an idle human pilot contributes nothing to the blended action
        'pilot_tol': pilot_tol if not pilot_is_human or (pilot_is_human and human_agent_active) else 0
      }

      action = act(obs[None, :], **act_kwargs)[0][0]
      new_obs, rew, done, info = env.step(action)

      if pilot_is_human:
        env.render()
        time.sleep(delay_between_steps_for_human)

      # Store transition in the replay buffer.
      replay_buffer.add(obs, action, rew, new_obs, float(done))
      obs = new_obs

      episode_rewards[-1] += rew

      if done:
        # Train at episode boundaries: one gradient step per environment step
        # taken since the previous episode ended.
        if t > learning_starts:
          for _ in range(t - prev_t):
            obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)
            weights = np.ones_like(rewards)
            train(obses_t, actions, rewards, obses_tp1, dones, weights)

        obs = env.reset()

        episode_outcomes.append(info)
        episode_rewards.append(0.0)

        if pilot_is_human:
          human_agent_action = init_human_action()

        prev_t = t

        if pilot_is_human:
          time.sleep(delay_between_episodes_for_human)

      if t > learning_starts and t % target_network_update_freq == 0:
        # Update target network periodically.
        update_target()

      # statistics cover the last 100 completed episodes (the in-progress one is excluded)
      mean_100ep_reward = round(np.mean(episode_rewards[-101:-1]), 1)
      num_episodes = len(episode_rewards)
      if done and print_freq is not None and len(episode_rewards) % print_freq == 0:
        # outcome statistics are only needed for logging, so compute them here
        mean_100ep_succ = round(np.mean([1 if x['succ'] else 0 for x in episode_outcomes[-101:-1]]), 2)
        mean_100ep_ttt = round(np.mean([x['ttt'] for x in episode_outcomes[-101:-1]]), 2)
        mean_100ep_dfot = round(np.mean([x['dfot'] for x in episode_outcomes[-101:-1]]), 2)
        logger.record_tabular('steps', t)
        logger.record_tabular('episodes', num_episodes)
        logger.record_tabular('mean 100 episode reward', mean_100ep_reward)
        logger.record_tabular('mean 100 episode succ', mean_100ep_succ)
        logger.record_tabular('mean 100 episode ttt', mean_100ep_ttt)
        logger.record_tabular('mean 100 episode dfot', mean_100ep_dfot)
        logger.dump_tabular()

      if checkpoint_freq is not None and t > learning_starts and num_episodes > 100 and t % checkpoint_freq == 0 and (saved_mean_reward is None or mean_100ep_reward > saved_mean_reward):
        if print_freq is not None:
          print('Saving model due to mean reward increase:')
          print(saved_mean_reward, mean_100ep_reward)
        U.save_state(model_file)
        model_saved = True
        saved_mean_reward = mean_100ep_reward

    # restore the best checkpoint before the temporary directory is removed
    if model_saved:
      U.load_state(model_file)

  reward_data = {
    'rewards': episode_rewards,
    'outcomes': episode_outcomes
  }

  return ActWrapper(act, act_params), reward_data

In [None]:
def make_co_policy(
  env, pilot_tol, pilot_is_human, 
  n_training_episodes, copilot_scope, copilot_q_func, 
  copilot_mu_func, copilot_v_func, buff_init_rollouts, reuse=tf.AUTO_REUSE):
  """Train a copilot in `env` and return `co_dqn_learn`'s (policy, reward_data).

  The step budget is `max_ep_len * n_training_episodes`; the remaining learner
  settings come from the global `copilot_dqn_learn_kwargs`.
  """
  total_timesteps = max_ep_len * n_training_episodes
  return co_dqn_learn(
    env,
    q_func=copilot_q_func,
    mu_func=copilot_mu_func,
    v_func=copilot_v_func,
    scope=copilot_scope,
    reuse=reuse,
    max_timesteps=total_timesteps,
    pilot_tol=pilot_tol,
    pilot_is_human=pilot_is_human,
    buff_init_rollouts=buff_init_rollouts,
    **copilot_dqn_learn_kwargs
  )

In [None]:
# Pauses (in seconds, passed to time.sleep) applied when a human is the pilot.
delay_between_episodes_for_human = 1
delay_between_steps_for_human = 0

In [None]:
# Shared-control configuration for the Q-learning fine-tuning stage.
pilot_tol = 0.5 # blending coefficient
# 1 -> full "pilot", i.e., imitation-learned decoder
# 0 -> full "copilot", i.e., Q-learned decoder

pilot_policy = trained_decoder_policy
# we will be learning to assist `trained_decoder_policy`, 
# which was trained earlier with imitation learning and is now a fixed decoder

buff_init_rollouts = agg_rollouts
# initialize replay buffer with rollouts collected during imitation learning
# NOTE: this is an alias, not a copy - in-place edits also change agg_rollouts

using_reward_shaping = False
# WARNING: make sure this is consistent with the rewards in buff_init_rollouts

return_to_init = False # True -> need to return to initial position

In [None]:
def augment_obs(obs, pilot_hidden_state, pilot_action):
  """Append the pilot's hidden state and action to a raw observation."""
  return np.concatenate((obs, pilot_hidden_state, pilot_action))

def extract_pilot_action(obs):
  """Return the pilot action stored in the last `n_act_dim` entries of `obs`."""
  return obs[-n_act_dim:]

In [None]:
# Re-express every stored rollout in the copilot's augmented observation space
# by replaying its raw observations through the pilot policy.
# NOTE: each rollout is overwritten in place with its augmented version, so this
# cell must not be re-run on rollouts that have already been augmented.
if buff_init_rollouts is not None:
  for i, rollout in enumerate(buff_init_rollouts):
    pilot_policy.reset()
    aug_obses = []
    unzipped_rollout = [list(x) for x in zip(*rollout)]
    obses = unzipped_rollout[0]
    next_obses = unzipped_rollout[3]
    # include the final next-observation so every step has an augmented successor
    obses.append(next_obses[-1])
    for obs in obses:
      pilot_policy.observe(obs)
      pilot_action = pilot_policy.act()
      pilot_hidden_state = pilot_policy.get_hidden_state()
      aug_obses.append(augment_obs(obs, pilot_hidden_state, pilot_action))
    aug_next_obses = aug_obses[1:]
    aug_obses = aug_obses[:-1]
    unzipped_rollout[0] = aug_obses
    unzipped_rollout[3] = aug_next_obses
    # materialize the steps: a bare zip() is a single-use iterator in Python 3,
    # so the rollout would look empty to anything that reads it a second time
    buff_init_rollouts[i] = list(zip(*unzipped_rollout))

In [None]:
n_training_episodes = 500 # multiplied by max_ep_len to get the copilot's step budget

# Network builders for the copilot. Each one appends its own suffix to the
# scope it is given, so callers pass the PARENT scope (e.g. 'q_func').

# mu: single 64-unit hidden layer -> n_act_dim outputs (the preferred action)
copilot_mu_func = lambda obs_ph, scope, reuse: deepq.models._mlp(
  [64], obs_ph, n_act_dim, scope=scope+'/mu_func', reuse=reuse)

# V: single 64-unit hidden layer -> 1 output (state value)
copilot_v_func = lambda obs_ph, scope, reuse: deepq.models._mlp(
  [64], obs_ph, 1, scope=scope+'/v_func', reuse=reuse)

# L: single 64-unit hidden layer -> 1 output (scale of the advantage term)
copilot_l_func = lambda obs_ph, scope, reuse: deepq.models._mlp(
  [64], obs_ph, 1, scope=scope+'/l_func', reuse=reuse)

def copilot_q_func(obs_ph, act_ph, scope, reuse=tf.AUTO_REUSE): # NAF
  """Q(obs, act) = V(obs) - l(obs)^2 * ||act - mu(obs)||^2, one value per example.

  Args:
    obs_ph: observations, one row per example.
    act_ph: actions, one row per example.
    scope: parent variable scope for the mu / l / v sub-networks.
    reuse: reuse flag forwarded to the sub-networks.

  Returns:
    A rank-1 tensor with one Q-value per example.
  """
  opt_act = copilot_mu_func(obs_ph, scope=scope, reuse=reuse)
  # The l and v heads have a single output column; drop it so that the products
  # and sums below are element-wise per example. Left as [batch, 1] they would
  # broadcast against the [batch] einsum result into a [batch, batch] matrix.
  adv_std = tf.squeeze(copilot_l_func(obs_ph, scope=scope, reuse=reuse), axis=1)
  act_err = act_ph - opt_act
  A = -tf.einsum('ij,ij->i', act_err, act_err) * (adv_std**2)
  V = tf.squeeze(copilot_v_func(obs_ph, scope=scope, reuse=reuse), axis=1)
  return A + V

# Extra keyword arguments that make_co_policy forwards to co_dqn_learn.
copilot_dqn_learn_kwargs = {
  'lr': 1e-3,
  'exploration_fraction': 0.1,
  'exploration_final_eps': 0.02,
  'target_network_update_freq': 1500,
  'print_freq': 100,
  'num_cpu': 5,
  'gamma': 0.99,
}

In [None]:
# Fresh, uniquely named variable scope for the copilot's networks.
copilot_scope = str(uuid.uuid4())

In [None]:
# Training environment for the copilot.
env = CursorControl(
  rand_goal=True, # new, random goal for each episode
  human_user=human_user, return_to_init=return_to_init, 
  using_reward_shaping=using_reward_shaping, 
)

In [None]:
# Patch the env so its observations carry the pilot's hidden state and action.

# widened observation space: raw observation + pilot hidden state + pilot action
env.unwrapped.observation_space = spaces.Box(
  np.zeros(n_obs_dim + num_hidden + n_act_dim), 
  np.ones(n_obs_dim + num_hidden + n_act_dim))

env.unwrapped.pilot_policy = pilot_policy

def _obs(self):
  """Return the current observation augmented with the pilot's state and action."""
  goal_reached_ind = np.array([1.0 if self.goal_reached else 0.0])
  ext_obs = np.concatenate((self.pos, self.vel, self.init_pos, goal_reached_ind)) # external state observations ("context")
  int_act = self.user_policy(ext_obs) # intended action
  bci_obs = internal_encode_obs(int_act, ext_obs, self.goal) # BCI output
  self.curr_obs = np.concatenate((ext_obs, bci_obs))
  self.pilot_policy.observe(self.curr_obs) # WARNING: side effect
  self.curr_obs = np.concatenate((self.curr_obs, self.pilot_policy.get_hidden_state(), self.pilot_policy.act()))
  return self.curr_obs
env.unwrapped._obs = types.MethodType(_obs, env.unwrapped)

# Keep a handle on the ORIGINAL reset only the first time this cell runs.
# Without the guard, re-running the cell would store the already-patched
# _reset as _reset_orig, and the patched _reset would then call itself forever.
if not hasattr(env.unwrapped, '_reset_orig'):
  env.unwrapped._reset_orig = env.unwrapped._reset
def _reset(self):
  """Reset the pilot's recurrent state, then run the env's original reset."""
  self.pilot_policy.reset()
  return self._reset_orig()
env.unwrapped._reset = types.MethodType(_reset, env.unwrapped)

In [None]:
# Train the copilot; returns the wrapped act function and per-episode reward data.
raw_copilot_policy, reward_data = make_co_policy(
  env=env, pilot_tol=pilot_tol, 
  pilot_is_human=human_user, copilot_scope=copilot_scope, 
  copilot_q_func=copilot_q_func, copilot_mu_func=copilot_mu_func,
  copilot_v_func=copilot_v_func, n_training_episodes=n_training_episodes,
  buff_init_rollouts=buff_init_rollouts)

In [None]:
# Save all TF variables under the copilot's scope.
save_tf_vars(copilot_scope, os.path.join(data_dir, 'copilot.tf'))

In [None]:
# Restore the copilot's TF variables from the checkpoint written above.
load_tf_vars(copilot_scope, os.path.join(data_dir, 'copilot.tf'))

In [None]:
# Persist the copilot's scope name together with its training reward data.
with open(os.path.join(data_dir, 'copilot_training_data.pkl'), 'wb') as f:
  pickle.dump((copilot_scope, reward_data), f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the copilot's scope name and training reward data.
# The file must be opened for reading: 'wb' would truncate the file that was
# just saved and then fail inside pickle.load.
with open(os.path.join(data_dir, 'copilot_training_data.pkl'), 'rb') as f:
  copilot_scope, reward_data = pickle.load(f)

In [None]:
def copilot_policy(obs):
  """Return the blended pilot/copilot action output for a single augmented observation.

  The pilot's action is read from the tail of `obs`. Both the observation and
  the pilot action are given a leading batch axis of size 1, matching the
  [None, n_act_dim] pilot-action placeholder and the way the act function is
  called during training.
  """
  pilot_action = extract_pilot_action(obs)
  with tf.variable_scope(copilot_scope, reuse=tf.AUTO_REUSE):
    return raw_copilot_policy._act(
      obs[None, :], pilot_action=pilot_action[None, :], pilot_tol=pilot_tol)[0]

In [None]:
# Run one rendered episode under shared control for visual inspection.
rollout = run_ep(copilot_policy, env, render=True, human_user=human_user)

In [None]:
# Release the copilot training env.
env.close()