In [None]:
import mouse
import keyboard

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display

import tensorflow as tf
import tensorflow.keras as keras
import gym
import transformers

import ComputerEnv

from .task_guided_behavior_distilation_env_wrapper \
  import TaskGuidedBehaviorDistilationEnvWrapper

In [None]:
# Settings for the data-collection cells.
demonstration_length = 10  # handed to sleep() while a single demonstration is recorded
num_demonstrations_to_collect = 10  # how many demonstrations to record

## Collect Data

In [None]:
# The tokenizer and the encoder are both loaded from the same pretrained
# checkpoint, named once here so the two cannot drift apart.
_bert_checkpoint = 'bert-base-uncased'
task_tokenizer = transformers.BertTokenizer.from_pretrained(_bert_checkpoint)
task_encoder = transformers.TFBertModel.from_pretrained(_bert_checkpoint)

def encode(sentence: str):
  """Tokenize `sentence` and run the token ids through the task encoder.

  Args:
    sentence: Text to encode.

  Returns:
    Whatever `task_encoder` returns when called on the output of
    `task_tokenizer.encode(sentence, return_tensors='tf')`.
  """
  return task_encoder(task_tokenizer.encode(sentence, return_tensors='tf'))

In [None]:
def collect_demo():
  """Record one demonstration and have the user label it with a task.

  Recording is started on `replayer`, kept running for `demonstration_length`
  (the value is passed to `time.sleep`, i.e. seconds), then stopped. The user
  is then prompted for a free-text task description.

  Returns:
    A dict with the keys 'demo', 'task_description' and
    'task_description_encoding' (the result of `encode(task_description)`).
  """
  # Function-scope import: `sleep` is not imported by the notebook's import cell.
  from time import sleep

  replayer.start_record()
  sleep(demonstration_length)
  replayer.stop_record()
  task_description = input("Enter task description: ")
  # NOTE(review): `demo` is not assigned in this function; it must be provided
  # by the surrounding notebook (presumably the recording made by `replayer`)
  # -- confirm where the recorded demonstration actually comes from.
  #
  # `return`, not `yield`: with `yield` this function was a generator, so
  # calling it recorded nothing and produced no labeled demonstration.
  return {
    'demo': demo,
    'task_description': task_description,
    'task_description_encoding': encode(task_description),
  }

In [None]:
# Record as many demonstrations as configured instead of a hard-coded count.
labeled_demos = [collect_demo() for _ in range(num_demonstrations_to_collect)]

## Evaluate Policy

In [None]:
class ReplayPolicy:
  """Policy that replays a recorded demonstration one step per call.

  Every call replays the next recorded step through `replay` and reports
  whether the task found in the observation is the task this demonstration
  was labeled with.
  """

  def __init__(self, demo, task_description):
    """Store the demonstration and its label.

    Args:
      demo: Sized, indexable collection of recorded steps; each step is
        passed to `replay`.
      task_description: Label of the task the demonstration was recorded for.
    """
    self.demo = demo
    self.task_description = task_description
    self.demo_index = 0  # index of the next step of `demo` to replay

  def __call__(self, obs):
    """Replay the next step and evaluate the observed task.

    Args:
      obs: Mapping with a 'task' entry.

    Returns:
      A dict with 'task_eval' (1.0 when `obs['task']` equals the stored task
      description, otherwise -1.0) and 'task_eval_confidence' (always 1.0).
    """
    # NOTE(review): `replay` is not defined in the visible notebook; it is
    # assumed to be provided elsewhere -- confirm.
    if not self.done:
      replay(self.demo[self.demo_index])
      # Advance the cursor; without this `done` could never become True.
      self.demo_index += 1
    # Compare against the task label, not against the recorded steps.
    if obs['task'] == self.task_description:
      return {'task_eval': 1.0, 'task_eval_confidence': 1.0}
    return {'task_eval': -1.0, 'task_eval_confidence': 1.0}

  @property
  def done(self):
    """True once every recorded step has been replayed."""
    return self.demo_index >= len(self.demo)

In [None]:
def evaluate(policy):
  """Roll out `policy` once per labeled demonstration and log every step.

  For each entry of `labeled_demos` a fresh environment is created and wrapped
  with a `ReplayPolicy` teacher built from that demonstration. The rollout
  stops when the environment reports `done` or the teacher has replayed its
  whole demonstration.

  Args:
    policy: Callable mapping an observation to an action dict containing
      'task_eval' and 'task_eval_confidence'.

  Returns:
    A list with one dict per environment step, holding 'step', 'task',
    'task_eval', 'task_eval_confidence', 'reward' and one 'info_<key>' entry
    per item of the step's `info` dict.
  """
  records = []
  for labeled_demo in labeled_demos:
    # `collect_demo` builds dicts with more than two keys, so entries are read
    # by key; plain (demo, task_description) pairs are still accepted.
    if isinstance(labeled_demo, dict):
      demo = labeled_demo['demo']
      task_description = labeled_demo['task_description']
    else:
      demo, task_description = labeled_demo
    env = ComputerEnv.LocalGUIEnv(...)
    replay_policy = ReplayPolicy(demo, task_description)
    env = TaskGuidedBehaviorDistilationEnvWrapper(
      env=env,
      task=task_description,
      teacher_policy=replay_policy,
      loss_fn=tf.keras.losses.binary_crossentropy,
      # Unbounded box: `None` is not a valid bound for gym.spaces.Box.
      task_space=gym.spaces.Box(-np.inf, np.inf, shape=[768,]),
      task_eval_space=gym.spaces.Box(0, 1, shape=[1]),
    )
    step = 0
    obs, done = env.reset(), False
    while not (done or replay_policy.done):
      action = policy(obs)
      obs, reward, done, info = env.step(action)
      record = {
        'step': step,
        'task': obs['task'],
        'task_eval': action['task_eval'],
        'task_eval_confidence': action['task_eval_confidence'],
        'reward': reward,
      }
      # Flatten the step's info dict into prefixed columns.
      record.update({
        f'info_{key}': value
        for key, value in info.items()
      })
      records.append(record)
      step += 1
  return records

In [None]:
class RandomPolicy:
  """Baseline policy that ignores the observation and guesses at random."""

  def __call__(self, obs):
    """Return a random task evaluation with full confidence.

    Args:
      obs: Ignored.

    Returns:
      A dict with 'task_eval' drawn from `np.random.uniform(0, 1)` and
      'task_eval_confidence' fixed at 1.0.
    """
    sampled_eval = np.random.uniform(0, 1)
    action = dict(task_eval=sampled_eval, task_eval_confidence=1.0)
    return action

In [None]:
# Roll out the random baseline and tabulate every recorded step.
policy = RandomPolicy()
records = evaluate(policy)
df = pd.DataFrame(records)
display(df)

# Reward per task.
sns.barplot(data=df, x='task', y='reward')
plt.show()

# One stacked panel per per-step metric, coloured by task.
fig, axes = plt.subplots(nrows=3, ncols=1)
for panel, metric in zip(axes, ('task_eval', 'task_eval_confidence', 'reward')):
  sns.lineplot(data=df, x='step', y=metric, hue='task', ax=panel)
plt.show()