In [None]:
from __future__ import division
import pickle
import os
import random
import uuid
import time
import sys
import json
from copy import deepcopy as copy

import gym
from gym import spaces
from gym.envs.classic_control import rendering
import numpy as np
import tensorflow as tf

from pyglet.window import key as pygkey

In [None]:
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
%matplotlib inline

In [None]:
import matplotlib as mpl
# Set the 'savefig' rc group so saved figures use 300 dpi.
mpl.rc('savefig', dpi=300)

In [None]:
# Directory where this notebook reads and writes its pickled artifacts
# (goals, evaluation rollouts, demonstrations).
data_dir = os.path.join('data', 'typing')

In [None]:
# One TF session shared by the GPT-2 language model and the trained decoder below.
sess = tf.Session()

Set up the environment

In [None]:
n_tasks = 10 # each task/env has a unique goal/string
goal_len = 10 # number of characters in each goal string

n_act_dim = 97 # number of discrete actions (95 printable ASCII chars, codes 32-126, + backspace + noop)
grid_size = (28, 32) # ecog grid dimensions
bci_dim = grid_size[0]*grid_size[1] # number of BCI channels
n_ext_obs_dim = n_act_dim # number of external state observation dimensions
n_obs_dim = n_ext_obs_dim + bci_dim # full observation = external obs followed by BCI channels

gamma = 0.99 # discount factor
max_ep_len = 100 # number of timesteps
succ_rew_bonus = 1 # for reaching goal

In [None]:
def is_succ(r):
  """Return the 'succ' flag from the info dict (last field) of the rollout's final step."""
  final_step = r[-1]
  return final_step[-1]['succ']


def get_ttt(r):
  """Return the 'ttt' value from the info dict (last field) of the rollout's final step."""
  final_step = r[-1]
  return final_step[-1]['ttt']

In [None]:
def sample_random_goal():
  """Return a string of goal_len characters drawn uniformly from ASCII codes 32-126.

  NOTE(review): a later cell defines sample_random_goal again (sampling text
  from the language model). Once that cell has run, this definition is shadowed.
  """
  return ''.join(chr(np.random.randint(32, 127)) for _ in range(goal_len))

In [None]:
# Root of the GPT-2 checkout; the cells below read its src/ and models/ subdirectories.
# Raises KeyError when the GPT2_DIR environment variable is not set.
gpt2_dir = os.environ['GPT2_DIR']

In [None]:
# Make the GPT-2 source modules importable before importing them.
sys.path.append(os.path.join(gpt2_dir, 'src'))
import model, sample, encoder

In [None]:
# Name of the model subdirectory under <gpt2_dir>/models.
model_name = '117M'

# Tokenizer and hyperparameters; defaults are overridden from the model's hparams.json.
enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
with open(os.path.join(gpt2_dir, 'models', model_name, 'hparams.json')) as f:
  hparams.override_from_dict(json.load(f))
  
# Token-id input fed by lm_prior: a batch of one sequence of any length.
context = tf.placeholder(tf.int32, [1, None])

# Op that samples goal_len*10 tokens starting from the <|endoftext|> token.
# The [:, 1:] slice drops the first column (presumably the start token - TODO confirm).
lm_samp = sample.sample_sequence(
  hparams=hparams, length=(goal_len*10),
  start_token=enc.encoder['<|endoftext|>'],
  batch_size=1,
  temperature=1, top_k=0
)[:, 1:]

# Load the latest checkpoint found in the model directory into the shared session.
saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(os.path.join(gpt2_dir, 'models', model_name))
saver.restore(sess, ckpt)

In [None]:
# Map each token's decoded text to its token id (if two ids decode to the same
# text, the later id wins).
idx_of_char = {enc.decode([i]): i for i in range(hparams.n_vocab)}
# Token ids of the 95 single characters with ASCII codes 32-126, in code order.
# Raises KeyError if any of those characters is not the decoded text of some token.
valid_idxes = [idx_of_char[chr(i)] for i in range(32, 127)]

In [None]:
def normalize_logits(x):
  """Return log(softmax(x)) over all entries of x, computed stably.

  The naive form log(exp(x) / sum(exp(x))) overflows to inf/inf = nan for
  large logits and underflows to log(0) = -inf for very negative ones.
  Subtracting the maximum before exponentiating gives the same value
  without either problem.

  Args:
    x: array-like of unnormalized log-probabilities.

  Returns:
    Array of the same shape whose exponentials sum to 1.
  """
  x = np.asarray(x)
  if x.size == 0:
    # Nothing to normalize; np.max would raise on an empty array.
    return x.astype(float)
  shifted = x - np.max(x)
  return shifted - np.log(np.sum(np.exp(shifted)))

In [None]:
# Logits tensor of the language model applied to `context`; built on first use.
_lm_prior_logits = None

def lm_prior(curr_string):
  """Return normalized log-probabilities of the next character under the language model.

  Args:
    curr_string: text typed so far. An empty string is replaced by the
      <|endoftext|> token as context.

  Returns:
    Length-95 array of log-probabilities over the characters with ASCII codes
    32-126 (the entries selected by valid_idxes), renormalized to sum to 1.
  """
  global _lm_prior_logits
  if _lm_prior_logits is None:
    # Build the forward pass once. Calling model.model on every query adds a new
    # copy of the ops to the default graph each time, so the graph keeps growing.
    _lm_prior_logits = model.model(
      hparams=hparams, X=context, past=None, reuse=tf.AUTO_REUSE)['logits']

  if curr_string == '':
    # Use the special token id directly, as the sampler's start_token does.
    # NOTE(review): enc.encode('<|endoftext|>') presumably tokenizes the literal
    # characters rather than returning this id - confirm against encoder.py.
    context_tokens = [enc.encoder['<|endoftext|>']]
  else:
    context_tokens = enc.encode(curr_string)

  logits = sess.run(_lm_prior_logits, feed_dict={context: [context_tokens]})
  # Keep only the prediction at the last position of the single batch row.
  logits = logits[0, -1, :hparams.n_vocab]

  return normalize_logits(logits[valid_idxes])

In [None]:
def sample_random_goal():
  """Sample text from the language model and return its first goal_len printable characters.

  Characters outside ASCII codes 32-126 are dropped. Sampling repeats until
  at least goal_len characters remain.
  """
  while True:
    sampled_text = enc.decode(sess.run(lm_samp)[0])
    printable = ''.join([ch for ch in sampled_text if 32 <= ord(ch) < 127])
    if len(printable) >= goal_len:
      return printable[:goal_len]

In [None]:
# One sampled goal string per task.
goals = [sample_random_goal() for _ in range(n_tasks)]

In [None]:
# Display the sampled goal strings.
goals

In [None]:
# Persist the goals so that later sessions can reuse the same tasks.
with open(os.path.join(data_dir, 'goals.pkl'), 'wb') as f:
  pickle.dump(goals, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the saved goals; this replaces whatever `goals` currently holds.
with open(os.path.join(data_dir, 'goals.pkl'), 'rb') as f:
  goals = pickle.load(f)

In [None]:
def onehot_encode(i, n):
  """Return a length-n float vector of zeros with a 1 written at position i."""
  vec = np.zeros(n)
  vec[i] = 1
  return vec


def onehot_decode(x):
  """Return the first index (along the first axis) at which x is nonzero."""
  first_axis_hits = np.nonzero(x)[0]
  return first_axis_hits[0]

In [None]:
def make_decoder(D):
  """Return a function mapping a full observation to an action.

  The returned function projects the BCI part of the observation through D
  and passes the resulting scores to sample_act.
  """
  def decode_act(obs):
    scores = D.dot(extract_bci_obs(obs))
    return sample_act(scores)
  return decode_act

def make_encoder(D):
  """Return a function mapping an action index to a BCI signal.

  The signal is the pseudo-inverse of D applied to the one-hot encoding of
  the action.
  """
  D_inv = np.linalg.pinv(D)
  def encode_obs(action):
    return D_inv.dot(onehot_encode(action, n_act_dim))
  return encode_obs

# The simulated user is a fixed random linear map between intended actions
# and BCI output. The decoder applies the map; the encoder applies its pseudo-inverse.
D_int = np.random.random((n_act_dim, bci_dim))
internal_decode_act = make_decoder(D_int)
internal_encode_obs = make_encoder(D_int)

In [None]:
# Special action indices; actions 0-94 type the character chr(32 + action).
NOOP = 96 # leave the string unchanged
BACK = 95 # delete the last character

In [None]:
class Typing(gym.Env):
  """Typing task: the string is edited one action per step until it equals the goal.

  Actions 0-94 append chr(32 + action), BACK deletes the last character and
  NOOP leaves the string unchanged. Each observation concatenates a one-hot
  encoding of the previous action with the BCI signal obtained by encoding
  the synthetic user's intended next action.

  NOTE(review): the underscore-prefixed _step/_reset/_render hooks presumably
  rely on a gym version whose Env.step/reset/render forward to them - confirm
  against the installed gym.
  """
  metadata = {
    'render.modes': ['human']
  }
  
  def __init__(
      self, 
      max_ep_len=max_ep_len,
      goal=None,
      rand_goal=False,
      using_reward_shaping=True,
      blending=0 # between 0 and 1 (0 -> no blending, 1 -> ignore agent and take optimal actions)
    ):
    """Create the environment.

    Args:
      max_ep_len: number of steps after which an episode ends.
      goal: target string; required unless rand_goal is True.
      rand_goal: when True, a new goal is sampled at every reset.
      using_reward_shaping: passed through to make_reward_func.
      blending: probability, per step, of replacing the agent's action with
        the optimal user policy's action.
    """
    lows = np.zeros(n_obs_dim)
    highs = np.ones(n_obs_dim)
      
    self.observation_space = spaces.Box(lows, highs)
    self.action_space = spaces.Discrete(n_act_dim)
    
    self.curr_string = None
    self.curr_step = None
    self.viewer = None
    self.curr_obs = None
    self.succ = None
    self.prev_action = None
    
    self.succ_rew_bonus = succ_rew_bonus
    self.max_ep_len = max_ep_len
    self.goal = goal
    self.blending = blending
    self.using_reward_shaping = using_reward_shaping
    self.rand_goal = rand_goal
    
    if not rand_goal:
      assert goal is not None
      self._set_goal(goal)
      
  def _set_goal(self, goal):
    """Install `goal` and build the user policies and reward function for it."""
    self.goal = goal
    self.optimal_user_policy = make_synth_user_policy(goal, using_ext_obs=True)
    self.user_policy = make_synth_user_policy(goal, using_ext_obs=True)
    self.reward_func = make_reward_func(goal, using_reward_shaping=self.using_reward_shaping)
    
  def _ext_obs(self):
    """Return the external observation: a one-hot encoding of the previous action."""
    return onehot_encode(self.prev_action, n_act_dim) # context
        
  def _obs(self):
    """Build, store and return the full observation (external obs followed by BCI signal)."""
    ext_obs = self._ext_obs()
    int_act = self.user_policy(ext_obs) # intended action
    bci_obs = internal_encode_obs(int_act) # BCI output
    self.curr_obs = np.concatenate((ext_obs, bci_obs))
    return self.curr_obs
  
  def _state(self):
    """Return the environment state, i.e. the string typed so far."""
    return self.curr_string

  def _step(self, action):  
    """Apply one action and return (obs, reward, done, info).

    Raises:
      ValueError: if the action is not NOOP, BACK or in the range 0-94.
    """
    # With probability `blending`, replace the agent's action with the optimal one.
    if np.random.random() < self.blending:
      action = self.optimal_user_policy(self._ext_obs())
            
    if action == NOOP:
      pass
    elif action == BACK:
      if len(self.curr_string) > 0:
        self.curr_string = self.curr_string[:-1]
    elif 0 <= action and action < 95:
      self.curr_string += chr(32+action)
    else:
      raise ValueError('invalid action')
              
    self.curr_step += 1
    self.succ = self.curr_string == self.goal
    oot = self.curr_step >= self.max_ep_len # out of time
    
    self.prev_action = action
    
    # Let both user models see the action that was actually executed, so their
    # internal strings stay in sync with the environment.
    ext_obs = self._ext_obs()
    self.optimal_user_policy.observe(ext_obs)
    self.user_policy.observe(ext_obs)
    
    obs = self._obs()
        
    state = self._state()
    r = self.reward_func(self.prev_state, action, state)
    done = oot or self.succ
    info = {'goal': self.goal, 'succ': self.succ, 'ttt': self.curr_step}
      
    self.prev_state = state
    self.prev_obs = obs
        
    return obs, r, done, info
    
  def _reset(self):
    """Start a new episode and return the first observation."""
    # Sample the goal first. _set_goal is what creates the user policies and the
    # reward function, so with rand_goal=True they do not exist before the first
    # call. Previously the policies were reset before they had been created.
    if self.rand_goal:
      self._set_goal(sample_random_goal())

    self.curr_string = ''
    self.succ = False
    self.prev_action = NOOP
    self.optimal_user_policy.reset()
    self.user_policy.reset()
      
    self.curr_step = 0
    
    ext_obs = self._ext_obs()
    self.optimal_user_policy.observe(ext_obs)
    self.user_policy.observe(ext_obs)
    self.prev_obs = self._obs()
    
    self.prev_state = self._state()
    return self.prev_obs
  
  def _render(self, mode='human', close=False):
    """Draw the current string into an image viewer, or close the viewer when close=True."""
    if close:
      if self.viewer is not None:
        self.viewer.close()
        self.viewer = None
      return
    
    if self.viewer is None:
      self.viewer = rendering.SimpleImageViewer()
    
    fig = plt.figure()
    canvas = FigureCanvas(fig)
    
    plt.text(0.01, 0.5, self.curr_string)
    plt.axis('off')
    
    agg = canvas.switch_backends(FigureCanvas)
    agg.draw()
    width, height = fig.get_size_inches() * fig.get_dpi()
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    self.viewer.imshow(
      np.frombuffer(agg.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3))
    # Close this frame's figure explicitly so figures do not accumulate.
    plt.close(fig)

In [None]:
temperature = 0 # deterministic

def sample_act(logits):
  """Return the argmax of logits after adding temperature-scaled Gumbel noise.

  With temperature 0 this is a plain argmax. The noise is still drawn, so the
  NumPy random stream advances on every call.
  """
  noise = np.random.gumbel(0, 1, logits.size)
  return np.argmax(logits + temperature * noise)

In [None]:
def extract_bci_obs(obs):
  """Return the BCI part of an observation: everything after the first n_ext_obs_dim entries."""
  return obs[n_ext_obs_dim:]


def extract_ext_obs(obs):
  """Return the external part of an observation: its first n_ext_obs_dim entries."""
  return obs[:n_ext_obs_dim]

In [None]:
def len_common_prefix(a, b):
  """Return how many leading elements of a and b are equal."""
  matched = 0
  for item_a, item_b in zip(a, b):
    if item_a != item_b:
      break
    matched += 1
  return matched

In [None]:
# Synthetic user that always intends the optimal next action for its goal.
class SynthUser(object):
  """Tracks the typed string from observed actions and proposes the next edit.

  Args:
    goal: target string.
    using_ext_obs: True when observe() receives only the external observation
      (one-hot previous action); False when it receives the full observation.
  """
  
  def __init__(self, goal, using_ext_obs=False):
    self.curr_string = ''
    self.using_ext_obs = using_ext_obs
    self.goal = goal
    
  def reset(self):
    """Forget the tracked string."""
    self.curr_string = ''
    
  def observe(self, obs):
    """Apply the previous action encoded in obs to the tracked string."""
    ext_obs = obs if self.using_ext_obs else extract_ext_obs(obs)
    prev_action = onehot_decode(ext_obs)
    if prev_action == NOOP:
      return
    if prev_action == BACK:
      # Slicing an empty string yields an empty string, so no length check is needed.
      self.curr_string = self.curr_string[:-1]
    elif 0 <= prev_action and prev_action < 95:
      self.curr_string += chr(32+prev_action)
    
  def __call__(self, obs):
    """Return the intended action for the tracked string; obs itself is not read."""
    typed = self.curr_string
    if len_common_prefix(typed, self.goal) < len(typed):
      # Something typed deviates from the goal: undo it first.
      return BACK
    if len(typed) < len(self.goal):
      return ord(self.goal[len(typed)])-32
    return NOOP

make_synth_user_policy = SynthUser

In [None]:
def make_reward_func(goal, using_reward_shaping=True):
  """Return reward_func(prev_state, action, state) for the given goal string.

  The reward is succ_rew_bonus when state equals the goal and 0 otherwise.
  With shaping enabled, gamma * potential(state) - potential(prev_state) is added.
  """

  def potential(state):
    # Negative count of goal characters not yet matched, plus any characters
    # typed beyond the goal's length.
    # NOTE(review): this is not the exact number of edits to the goal when the
    # state is no longer than the goal but ends in wrong characters (their
    # deletions are not counted).
    missing = len(goal) - len_common_prefix(state, goal)
    overshoot = max(0, len(state) - len(goal))
    return -(missing + overshoot)

  def reward_func(prev_state, action, state):
    r = succ_rew_bonus if state == goal else 0
    if using_reward_shaping:
      r += gamma * potential(state) - potential(prev_state)
    return r
  
  return reward_func

In [None]:
# one env/task per goal; envs[i] uses goals[i] and is never given a random goal
envs = [Typing(goal=goal) for goal in goals]

In [None]:
def run_ep(policy, env, max_ep_len=max_ep_len, render=False, blending=0):
  """Run one episode of `policy` in `env` and return the list of transitions.

  Args:
    policy: callable mapping an observation to an action. If it also has
      `reset` and/or `observe` methods, they are called at the start of the
      episode and before every action respectively. Plain functions work too.
    env: environment with reset/step/render/close and a `blending` attribute.
    max_ep_len: upper bound on the number of steps (the loop allows one extra).
    render: render after every step and close the env at the end.
    blending: value assigned to env.blending for the duration of the episode.

  Returns:
    List of (prev_obs, action, reward, obs, float(done), info) tuples.
  """
  # Look the optional hooks up explicitly. A bare `except: pass` around the
  # calls also hid genuine failures inside reset()/observe().
  reset_policy = getattr(policy, 'reset', None)
  observe = getattr(policy, 'observe', None)

  old_blending = env.blending
  env.blending = blending
  rollout = []

  try:
    obs = env.reset()
    if callable(reset_policy):
      reset_policy()
    done = False
    prev_obs = obs

    for step_idx in range(max_ep_len+1):
      if done:
        break

      if callable(observe):
        observe(obs)

      action = policy(obs)
      obs, r, done, info = env.step(action)

      rollout.append((prev_obs, action, r, obs, float(done), info))
      prev_obs = obs
      if render:
        env.render()
  finally:
    # Restore the caller's blending (and close the viewer) even if the episode fails.
    env.blending = old_blending
    if render:
      env.close()

  return rollout

In [None]:
# One synthetic-user policy per task. using_ext_obs keeps its default (False),
# so these policies expect full observations.
oracle_policies = [make_synth_user_policy(env.goal) for env in envs]

In [None]:
# Decoder that uses the simulated user's own projection matrix D_int.
oracle_decoder_policy = internal_decode_act

In [None]:
# Baseline decoder: a fixed random linear map unrelated to the user's D_int.
D_rand = np.random.random((n_act_dim, bci_dim))

def rand_decode_act(obs):
  """Decode an action by projecting the BCI part of obs through D_rand."""
  return sample_act(D_rand.dot(extract_bci_obs(obs)))

rand_decoder_policy = rand_decode_act

Sanity-check the environments and agents

In [None]:
# Index of the task used for the sanity-check rollouts below.
task_idx = 0

In [None]:
# Rendered rollout of the synthetic-user policy acting directly in the env.
rollout = run_ep(oracle_policies[task_idx], envs[task_idx], render=True)

In [None]:
# Rendered rollout of the decoder built from the user's own projection matrix.
rollout = run_ep(oracle_decoder_policy, envs[task_idx], render=True)

In [None]:
# Rendered rollout of the random linear decoder.
rollout = run_ep(rand_decoder_policy, envs[task_idx], render=True)

In [None]:
# Close the selected environment after the rendered sanity-check rollouts.
envs[task_idx].close()

In [None]:
n_eval_rollouts = 100

def make_env(train_goal=True):
  """Return a new Typing env.

  Args:
    train_goal: when True, use a goal picked uniformly from `goals`;
      otherwise use a freshly sampled goal.
  """
  if train_goal:
    # np.random.choice(n) draws from range(n); no index list is needed.
    goal = goals[np.random.choice(n_tasks)]
  else:
    goal = sample_random_goal()
  return Typing(goal=goal)

def evaluate_decoder_policy(decoder_policy, env=None, n_rollouts=n_eval_rollouts):
  """Roll out `decoder_policy` and summarize its performance.

  Args:
    decoder_policy: policy accepted by run_ep.
    env: environment to evaluate in; defaults to a new env with a fresh goal.
    n_rollouts: number of episodes to run (without blending or rendering).

  Returns:
    (rollouts, perf) where perf has the mean return ('rew'), the success rate
    ('succ') and the mean time-to-target over successful rollouts ('ttt';
    nan when no rollout succeeded).
  """
  if env is None:
    env = make_env(train_goal=False)
  rollouts = [run_ep(
    decoder_policy, env, render=False, blending=0) for _ in range(n_rollouts)]
  successes = [bool(is_succ(rollout)) for rollout in rollouts]
  succ_ttts = [get_ttt(rollout) for rollout, ok in zip(rollouts, successes) if ok]
  perf = {
    'rew': np.mean([sum(x[2] for x in rollout) for rollout in rollouts]),
    'succ': np.mean([1 if ok else 0 for ok in successes]),
    # np.mean([]) also yields nan but emits a RuntimeWarning; return nan directly.
    'ttt': np.mean(succ_ttts) if succ_ttts else np.nan
  }
  return rollouts, perf

In [None]:
# Evaluate the oracle decoder on a freshly sampled goal (env is left as None).
oracle_rollouts, oracle_perf = evaluate_decoder_policy(oracle_decoder_policy, n_rollouts=n_eval_rollouts)

In [None]:
# Cache the oracle evaluation on disk.
with open(os.path.join(data_dir, 'oracle_eval.pkl'), 'wb') as f:
  pickle.dump((oracle_rollouts, oracle_perf), f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the cached oracle evaluation, replacing the in-memory results.
with open(os.path.join(data_dir, 'oracle_eval.pkl'), 'rb') as f:
  oracle_rollouts, oracle_perf = pickle.load(f)

In [None]:
# Evaluate the random decoder on a freshly sampled goal (env is left as None).
rand_rollouts, rand_perf = evaluate_decoder_policy(rand_decoder_policy, n_rollouts=n_eval_rollouts)

In [None]:
# Cache the random-decoder evaluation on disk.
with open(os.path.join(data_dir, 'rand_eval.pkl'), 'wb') as f:
  pickle.dump((rand_rollouts, rand_perf), f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the cached random-decoder evaluation, replacing the in-memory results.
with open(os.path.join(data_dir, 'rand_eval.pkl'), 'rb') as f:
  rand_rollouts, rand_perf = pickle.load(f)

Train the decoder with imitation learning

[Neuroprosthetic decoder training as imitation learning](https://arxiv.org/abs/1511.04156)

In [None]:
# Number of initial demonstration rollouts collected for each task.
n_demo_rollouts_per_task = 10

In [None]:
# Behaviour policy used to collect demonstrations: the random decoder for every task.
demo_policies = [rand_decoder_policy for _ in range(n_tasks)]

In [None]:
# Blending coefficient passed to run_ep while collecting demonstrations.
demo_blending = 0.75

In [None]:
def label_actions(rollout, policy):
  """Relabel every transition of `rollout` with the action `policy` would take.

  The action that was actually executed is kept in the transition's info dict
  under 'action_taken'. The rollout list and the info dicts are modified in
  place; the same list is returned.

  Args:
    rollout: list of tuples whose item 0 is the observation, item 1 the action
      and last item the info dict.
    policy: callable mapping an observation to an action. If it also has
      `reset` and/or `observe` methods, they are called before the first
      transition and before every label respectively.
  """
  # Look the optional hooks up explicitly. A bare `except: pass` around the
  # calls also hid genuine failures inside reset()/observe().
  reset_policy = getattr(policy, 'reset', None)
  observe = getattr(policy, 'observe', None)

  if callable(reset_policy):
    reset_policy()
  for i, x in enumerate(rollout):
    x = list(x)
    x[-1]['action_taken'] = x[1]
    if callable(observe):
      observe(x[0])
    x[1] = policy(x[0]) # replace taken action with action label
    rollout[i] = tuple(x)
  return rollout

In [None]:
# Collect n_demo_rollouts_per_task rollouts per task with the demo policy (blended
# with the optimal actions), then relabel every step with that task's oracle action.
# The list is ordered round by round: all tasks once, then all tasks again, and so on.
demo_rollouts = [label_actions(run_ep(
  demo_policies[task_idx], env, render=False, blending=demo_blending
), oracle_policies[task_idx]) for _ in range(
  n_demo_rollouts_per_task) for task_idx, env in enumerate(envs)]

In [None]:
# Cache the labelled demonstrations on disk.
with open(os.path.join(data_dir, 'demo_rollouts.pkl'), 'wb') as f:
  pickle.dump(demo_rollouts, f, pickle.HIGHEST_PROTOCOL)

In [None]:
# Reload the cached demonstrations, replacing the in-memory list.
with open(os.path.join(data_dir, 'demo_rollouts.pkl'), 'rb') as f:
  demo_rollouts = pickle.load(f)

In [None]:
# max number of timesteps into the past 
# that the RNN decoder can look at
# (also the length of each training window and of the placeholders' time axis)
history_len = 1

In [None]:
def build_mask(i, n): # for RNN training
  """Return a length-n float vector whose first i entries are 1 and the rest 0."""
  mask = np.zeros(n)
  mask[:i] = 1
  return mask

def pad_obses(obses, n):
  """Return obses as a list, right-padded to length n with zero arrays shaped like the last one."""
  filler = [np.zeros(obses[-1].shape)] * (n - len(obses))
  return list(obses) + filler

def pad_acts(acts, n):
  """Return acts as a list, right-padded to length n with action 0."""
  return list(acts) + [0] * (n - len(acts))

def vectorize_rollouts(rollouts):
  """Cut rollouts into windows of history_len steps and stack them into arrays.

  Each rollout contributes one window per start index, or a single padded
  window when it is shorter than history_len.

  Returns:
    (obses, actions, masks) arrays with one row per window; masks mark the
    unpadded steps.
  """
  obses, actions, masks = [], [], []
  for rollout in rollouts:
    ep_obses, ep_actions = list(zip(*rollout))[:2]
    n_windows = max(1, len(ep_obses)-history_len+1)
    for start in range(n_windows):
      stop = start + history_len
      window_obses = ep_obses[start:stop]
      obses.append(pad_obses(window_obses, history_len))
      actions.append(pad_acts(ep_actions[start:stop], history_len))
      masks.append(build_mask(len(window_obses), history_len))
  return np.array(obses), np.array(actions), np.array(masks)

In [None]:
# Module-level training data; filled in by process_demo_rollouts.
demo_obses = None
demo_actions = None
demo_masks = None
train_idxes = None
val_batch = None

In [None]:
def process_demo_rollouts(demo_rollouts):
  """Vectorize the demonstrations and make a new random 90/10 train/validation split.

  Updates the module-level demo_obses, demo_actions, demo_masks, train_idxes
  and val_batch; returns nothing.
  """
  global demo_obses, demo_actions, demo_masks, train_idxes, val_batch

  demo_obses, demo_actions, demo_masks = vectorize_rollouts(demo_rollouts)

  shuffled = list(range(demo_obses.shape[0]))
  random.shuffle(shuffled)

  split = int(0.9 * len(shuffled))
  train_idxes = shuffled[:split]
  held_out = shuffled[split:]
  val_batch = tuple(
    arr[held_out] for arr in (demo_obses, demo_actions, demo_masks))

In [None]:
# Build the initial training arrays and validation batch from the demonstrations.
process_demo_rollouts(demo_rollouts)

In [None]:
# Display the shapes of the vectorized observations and action labels.
demo_obses.shape, demo_actions.shape

In [None]:
def aggregate_rollouts(): # DAgger step
  """Collect oracle-labelled rollouts of the trained decoder and add them to the dataset.

  Runs n_agg_rollouts episodes per task with the current dagger_blending,
  relabels them with the task's oracle policy, appends them to demo_rollouts
  and rebuilds the training arrays. dagger_blending is then set to the failure
  rate of the rollouts just collected.
  """
  # Both globals are declared before first use: declaring `global` after a name
  # has already been used in the function is a SyntaxError on Python 3.6+.
  global demo_rollouts, dagger_blending
  rollouts = []
  for oracle_policy, env in zip(oracle_policies, envs):
    for _ in range(n_agg_rollouts):
      rollouts.append(label_actions(run_ep(
        trained_decoder_policy, env, render=False, 
        blending=dagger_blending), oracle_policy))
  demo_rollouts += rollouts
  process_demo_rollouts(demo_rollouts)
  
  # dynamically adjust blending coeff
  dagger_blending = 1 - np.mean([1 if is_succ(rollout) else 0 for rollout in rollouts])

In [None]:
def sample_batch(size):
  """Return (obses, actions, masks) for `size` distinct randomly chosen training examples.

  random.sample raises ValueError when size exceeds the number of training examples.
  """
  picked = random.sample(train_idxes, size)
  return demo_obses[picked], demo_actions[picked], demo_masks[picked]

In [None]:
def build_mlp(
    input_placeholder,
    output_size,
    scope,
    n_layers=1,
    size=256,
    activation=tf.nn.relu,
    output_activation=tf.nn.softmax,
    reuse=False
  ):
  """Build a fully connected network inside variable scope `scope`.

  Stacks n_layers dense layers of width `size` using `activation`, followed by
  one dense output layer of width output_size using `output_activation`.
  Returns the output tensor.
  """
  with tf.variable_scope(scope, reuse=reuse):
    hidden = input_placeholder
    for _ in range(n_layers):
      hidden = tf.layers.dense(hidden, size, activation=activation)
    return tf.layers.dense(hidden, output_size, activation=output_activation)

In [None]:
iterations = 100000 # number of gradient steps in the training loop
batch_size = 512 # training examples sampled per gradient step
learning_rate = 1e-3 # Adam step size

# RNN hidden layer size
num_hidden = 512

val_update_freq = 100 # how frequently to evaluate trained decoder on validation env
n_val_eval_rollouts = 10 # number of rollouts in validation env

# DAgger params
agg_freq = 1000 # aggregate new rollouts every agg_freq training iterations
n_agg_rollouts = 1 # number of rollouts to aggregate into dataset per iteration of DAgger
dagger_blending = 0.75 # initial blending coeff

In [None]:
# Load a previously saved decoder scope name; open() raises if the file does not exist.
# NOTE(review): the next cell assigns imi_decoder_scope a fresh UUID, so on a
# top-to-bottom run the value loaded here is discarded. Run only one of the two cells.
with open(os.path.join(data_dir, 'imi_decoder_scope.pkl'), 'rb') as f:
  imi_decoder_scope = pickle.load(f)

In [None]:
# Fresh, unique variable-scope name for the decoder graph built below.
# NOTE(review): no visible cell writes this value to imi_decoder_scope.pkl.
imi_decoder_scope = str(uuid.uuid4())

In [None]:
# Inputs of the decoder graph. The leading dimension is the batch size.
obs_ph = tf.placeholder(tf.float32, [None, history_len, n_obs_dim]) # observations
act_ph = tf.placeholder(tf.int32, [None, history_len]) # actions
mask_ph = tf.placeholder(tf.float32, [None, history_len]) # masks for RNN training
# The two parts of the LSTM's initial state, passed to the cell as (a, b).
init_state_a_ph = tf.placeholder(tf.float32, [None, num_hidden]) # initial state for RNN training
init_state_b_ph = tf.placeholder(tf.float32, [None, num_hidden])

In [None]:
# Decoder graph: an LSTM unrolled over history_len steps, with a linear readout
# to per-action log-probabilities at every step.
with tf.variable_scope(imi_decoder_scope, reuse=tf.AUTO_REUSE):
  weights = {'out': tf.Variable(tf.random_normal([num_hidden, n_act_dim]))}
  biases = {'out': tf.Variable(tf.random_normal([n_act_dim]))}

  # One [batch, n_obs_dim] tensor per timestep.
  unstacked_X = tf.unstack(obs_ph, history_len, 1)

  lstm_cell = tf.nn.rnn_cell.LSTMCell(num_hidden)

  state = (init_state_a_ph, init_state_b_ph)
  rnn_outputs = []
  rnn_states = []
  for input_ in unstacked_X:
    output, state = lstm_cell(input_, state)
    # log_softmax equals log(softmax(.)) but never forms probabilities that
    # underflow to 0, which would turn the log into -inf and the loss into inf/nan.
    rnn_outputs.append(tf.nn.log_softmax(
      tf.matmul(output, weights['out']) + biases['out']))
    rnn_states.append(state)

# [batch, history_len, n_act_dim] log-probabilities.
act_log_likelihoods = tf.reshape(
  tf.concat(rnn_outputs, axis=1), 
  shape=[tf.shape(obs_ph)[0], history_len, n_act_dim]
)

# Log-probability assigned to the labelled action at each timestep.
selected_act_lls = tf.reduce_sum(
  act_log_likelihoods * tf.one_hot(act_ph, n_act_dim, axis=-1), axis=2)

# Mean negative log-likelihood over the unmasked timesteps.
loss = -tf.reduce_sum(selected_act_lls * mask_ph) / tf.reduce_sum(mask_ph)

In [None]:
# One Adam step on the imitation loss.
update_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

In [None]:
class TrainedDecoderPolicy(object):
  """Policy that runs the trained LSTM decoder one observation at a time.

  observe() advances the LSTM state and picks the next action; __call__
  returns that action. When using_prior is True, the decoder's log-probabilities
  over the 95 character actions are combined with the language-model prior for
  the string typed so far.

  Args:
    using_prior: whether to combine the decoder output with lm_prior.
  """
  
  def __init__(self, using_prior=True):
    self.using_prior = using_prior
    self.hidden_state = None
    self.action = None
    self.curr_string = None
    # Reusable feed buffer; only timestep 0 is ever written and read.
    self.obs_feed = np.zeros((1, history_len, n_obs_dim))
    
  def reset(self):
    """Zero the LSTM state and forget the tracked string and last action."""
    self.hidden_state = (np.zeros((1, num_hidden)), np.zeros((1, num_hidden)))
    self.curr_string = ''
    self.action = None
    
  def _feed_dict(self, obs):
    """Return the feed dict for a single-step forward pass on obs."""
    self.obs_feed[0, 0, :] = obs
    return {
      init_state_a_ph: self.hidden_state[0],
      init_state_b_ph: self.hidden_state[1],
      obs_ph: self.obs_feed
    }

  def _track_prev_action(self, obs):
    """Apply the previous action encoded in obs to the tracked string.

    Uses the same editing rules as the environment: NOOP changes nothing, BACK
    deletes the last character, actions 0-94 append chr(32 + action).
    Previously every action, including NOOP and BACK, was appended as a
    character (chr(128) / chr(127)).
    """
    prev_action = onehot_decode(extract_ext_obs(obs))
    if prev_action == NOOP:
      return
    if prev_action == BACK:
      self.curr_string = self.curr_string[:-1]
    elif 0 <= prev_action and prev_action < 95:
      self.curr_string += chr(32+prev_action)
    
  def observe(self, obs):
    """Consume one observation: advance the LSTM and choose the next action."""
    # sess.run builds no graph ops, so no variable scope is needed around it.
    logits, self.hidden_state = sess.run(
      [rnn_outputs[0], rnn_states[0]], feed_dict=self._feed_dict(obs))
    logits = logits[0, :]

    # Bring the tracked string up to date before querying the prior, so the
    # prior is conditioned on the text as it stands now. It was previously
    # updated afterwards, which left the prior one step stale.
    self._track_prev_action(obs)
    
    if self.using_prior:
      # The last two actions are BACK and NOOP; the prior covers only characters.
      cond_logits = normalize_logits(logits[:-2])
      prior_logits = lm_prior(self.curr_string)
      post_probs = np.exp(cond_logits + prior_logits)
      # Rescale so the characters keep the total probability mass the decoder gave them.
      post_probs = post_probs / np.sum(post_probs) * np.sum(np.exp(logits[:-2]))
      logits[:-2] = np.log(post_probs)
    
    self.action = sample_act(logits)
    
  def __call__(self, obs):
    """Return the action chosen by the latest observe() call; obs itself is not read."""
    return self.action
  
  def get_hidden_state(self):
    """Return the first component of the LSTM state as a 1-D array.

    Index 0 is used rather than `.c` so this also works right after reset(),
    when the state is a plain tuple.
    NOTE(review): index 0 is assumed to be the cell state `c` of the
    LSTMStateTuple - confirm against the TensorFlow version in use.
    """
    cell_state = self.hidden_state[0]
    assert cell_state.shape[0] == 1
    return cell_state[0, :]
  
trained_decoder_policy = TrainedDecoderPolicy()

In [None]:
# Initialize the decoder and optimizer variables.
tf.global_variables_initializer().run(session=sess)
# The global initializer also re-initializes the language-model variables that
# were restored from the checkpoint earlier, so load the checkpoint again.
# `saver` was created before any decoder variable existed.
saver.restore(sess, ckpt)

In [None]:
# Per-iteration training history; the training loop appends one value per key per iteration.
train_logs = {
  'train_loss': [],
  'val_loss': [],
  'rew': [],
  'succ': [],
  'ttt': []
}

In [None]:
def compute_batch_loss(batch, step=False, t=None):
  """Evaluate the imitation loss on a batch, optionally taking a gradient step.

  Args:
    batch: (observations, actions, masks) arrays.
    step: when True, run one optimizer update after computing the loss and
      return only the loss. When False, also evaluate the trained decoder in a
      new env with a fresh goal and include 'rew', 'succ' and 'ttt'.
    t: iteration counter; not used by this function.

  Returns:
    Dict with 'loss' and, when step is False, the evaluation metrics.
  """
  batch_obs, batch_act, batch_mask = batch
  n_examples = batch_obs.shape[0]
  feed = {
    obs_ph: batch_obs,
    act_ph: batch_act,
    mask_ph: batch_mask,
    init_state_a_ph: np.zeros((n_examples, num_hidden)),
    init_state_b_ph: np.zeros((n_examples, num_hidden))
  }
  # The loss is evaluated before the update, so it reflects the pre-step weights.
  metrics = {'loss': sess.run(loss, feed_dict=feed)}

  if step:
    sess.run(update_op, feed_dict=feed)
    return metrics

  _, val_perf = evaluate_decoder_policy(
    trained_decoder_policy,
    env=make_env(train_goal=False),
    n_rollouts=n_val_eval_rollouts
  )
  metrics.update(val_perf)
  return metrics

In [None]:
# Training loop. Progress is measured by the length of train_logs['train_loss'],
# so re-running this cell resumes from where the logs left off.
val_log = None
while len(train_logs['train_loss']) < iterations:
  batch = sample_batch(batch_size)
  
  t = len(train_logs['train_loss'])
  train_log = compute_batch_loss(batch, step=True, t=t)
  # Refresh the validation metrics every val_update_freq iterations; in between,
  # the most recent values are reused.
  if val_log is None or t % val_update_freq == 0:
    val_log = compute_batch_loss(val_batch, step=False, t=t)
    
  # DAgger: periodically add oracle-labelled rollouts of the current decoder
  # (this also happens at t == 0).
  if t % agg_freq == 0:
    aggregate_rollouts()
  
  print('%d %d %f %f %f %f %f' % (
    t, iterations, train_log['loss'], val_log['loss'], 
    val_log['rew'], val_log['succ'], val_log['ttt']))
  
  # 'loss' is stored as 'train_loss' / 'val_loss'; other keys keep their names.
  for k, v in train_log.items():
    train_logs['%s%s' % ('train_' if k == 'loss' else '', k)].append(v)
  for k, v in val_log.items():
    train_logs['%s%s' % ('val_' if k == 'loss' else '', k)].append(v)

In [None]:
# Validation loss recorded at every training iteration.
fig, ax = plt.subplots()
ax.set_xlabel('Iterations')
ax.set_ylabel('Validation Loss')
ax.plot(train_logs['val_loss'])
plt.show()

In [None]:
# Validation reward of the trained decoder against the oracle and random baselines.
fig, ax = plt.subplots()
ax.set_xlabel('Iterations')
ax.set_ylabel('Reward')
ax.axhline(y=oracle_perf['rew'], linestyle='--', color='teal', label='Oracle')
ax.axhline(y=rand_perf['rew'], linestyle=':', color='gray', label='Random')
ax.plot(train_logs['rew'], color='orange', label='Trained')
ax.legend(loc='best')
plt.show()

In [None]:
# Validation success rate of the trained decoder against the oracle and random baselines.
fig, ax = plt.subplots()
ax.set_xlabel('Iterations')
ax.set_ylabel('Success Rate')
ax.axhline(y=oracle_perf['succ'], linestyle='--', color='teal', label='Oracle')
ax.axhline(y=rand_perf['succ'], linestyle=':', color='gray', label='Random')
ax.plot(train_logs['succ'], color='orange', label='Trained')
ax.legend(loc='best')
plt.show()

In [None]:
# Validation time-to-target of the trained decoder against the oracle and random baselines.
fig, ax = plt.subplots()
ax.set_xlabel('Iterations')
ax.set_ylabel('Time to Target')
ax.axhline(y=oracle_perf['ttt'], linestyle='--', color='teal', label='Oracle')
ax.axhline(y=rand_perf['ttt'], linestyle=':', color='gray', label='Random')
ax.plot(train_logs['ttt'], color='orange', label='Trained')
ax.legend(loc='best')
plt.show()

In [None]:
# New environment with a freshly sampled goal (not one of the training goals list entries).
env = make_env(train_goal=False)

In [None]:
# Display the goal string of the demo environment.
env.goal

In [None]:
# Rendered rollout of the trained decoder in the demo environment.
rollout = run_ep(trained_decoder_policy, env, render=True)

In [None]:
# Display whether the rollout ended with the goal string typed.
is_succ(rollout)

In [None]:
# Close the demo environment.
env.close()