### Disclaimer

Distribution authorized to U.S. Government agencies and their contractors. Other requests for this document shall be referred to the MIT Lincoln Laboratory Technology Office.

This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering.

© 2019 Massachusetts Institute of Technology.

The software/firmware is provided to you on an As-Is basis.

Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other than as specifically authorized by the U.S. Government may violate any copyrights that exist in this work.


### Treasure Hunt Challenge
Train an agent to find 'treasures' placed around a TESSE environment.

In [4]:
from pathlib import Path
import time

from stable_baselines import PPO2
from stable_baselines.common.policies import CnnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv, VecVideoRecorder

%load_ext autoreload
%autoreload 2
from tesse.msgs import *  # Camera, SetCameraParametersRequest, SetCameraPositionRequest, ...
from tesse_gym.treasure_hunt import TreasureHunt, HuntMode  # HuntMode assumed to live beside TreasureHunt



#### Path to TESSE build

In [10]:
filename = 'TESSE_BUILD_PATH'

#### Save checkpoints here

In [None]:
log_dir = Path('results/treasure-hunt-agent')
log_dir.mkdir(parents=True, exist_ok=True)  # don't fail if the directory already exists

### Agent training

This section contains code to train an agent using PPO2.

Configure cameras

In [38]:
def set_cameras(tesse_interface):
    """Give the left RGB and segmentation cameras identical parameters and position."""
    cameras = (Camera.RGB_LEFT, Camera.SEGMENTATION)
    # Same image size, field of view, and clip planes for both cameras
    for camera in cameras:
        tesse_interface.env.request(SetCameraParametersRequest(camera=camera,
                                                               height_in_pixels=240,
                                                               width_in_pixels=320,
                                                               field_of_view=45,
                                                               near_clip_plane=0.05,
                                                               far_clip_plane=50))
    # Same position offset for both cameras
    for camera in cameras:
        tesse_interface.env.request(SetCameraPositionRequest(camera=camera,
                                                             x=0,
                                                             y=0,
                                                             z=-0.1))
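As a rough guide to what the camera request above implies for image geometry, the sketch below converts the image size and field of view into ideal pinhole intrinsics. It assumes that `field_of_view` is the vertical field of view in degrees (Unity's `Camera.fieldOfView` convention), square pixels, and a principal point at the image center; none of this is confirmed by the TESSE request API, and `pinhole_intrinsics` and `horizontal_fov_deg` are hypothetical helpers.

```python
import math


def pinhole_intrinsics(width_px, height_px, vertical_fov_deg):
    """Return (fx, fy, cx, cy) for an ideal pinhole camera.

    Assumes a vertical field of view, square pixels, and a centered principal point.
    """
    half_fov = math.radians(vertical_fov_deg) / 2.0
    fy = (height_px / 2.0) / math.tan(half_fov)  # focal length in pixels
    fx = fy                                      # square pixels
    cx, cy = width_px / 2.0, height_px / 2.0     # image center
    return fx, fy, cx, cy


def horizontal_fov_deg(width_px, fx):
    """Horizontal field of view implied by the image width and focal length."""
    return math.degrees(2.0 * math.atan((width_px / 2.0) / fx))


fx, fy, cx, cy = pinhole_intrinsics(320, 240, 45)
print(f"fx = fy = {fx:.1f} px, principal point = ({cx}, {cy})")
print(f"horizontal FOV = {horizontal_fov_deg(320, fx):.1f} deg")
```

Under these assumptions, 320x240 at 45 degrees gives fy ≈ 289.7 px and a horizontal field of view of about 57.8 degrees.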

Callback to save checkpoints

In [7]:
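def save_checkpoint_callback(local_vars, global_vars):
    # local_vars holds PPO2.learn's locals; 'update' is its update counter (starting at 1)
    total_updates = local_vars['update']
    if total_updates % 50 == 0:  # checkpoint every 50 updates
        local_vars["self"].save(str(log_dir / f'{total_updates:09d}.pkl'))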
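To see how often this callback actually writes a file, the sketch below counts PPO2 updates and lists the checkpoint names for this notebook's settings (100000 timesteps, one environment). It assumes the Stable Baselines 2.x (pre-2.10) `PPO2.learn` loop, which runs `total_timesteps // (n_envs * n_steps)` updates numbered from 1 and calls the callback once per update, with PPO2's default `n_steps=128`. The helper names are hypothetical.

```python
from pathlib import Path


def n_ppo2_updates(total_timesteps, n_envs, n_steps=128):
    """Number of PPO2 updates: total_timesteps // (n_envs * n_steps)."""
    return total_timesteps // (n_envs * n_steps)


def checkpoint_paths(log_dir, n_updates, every=50):
    """Files save_checkpoint_callback would write; updates are numbered from 1."""
    return [Path(log_dir) / f"{u:09d}.pkl"
            for u in range(1, n_updates + 1) if u % every == 0]


updates = n_ppo2_updates(100000, n_envs=1)
paths = checkpoint_paths("results/treasure-hunt-agent", updates)
print(updates, len(paths), paths[0].name, paths[-1].name)
# 781 15 000000050.pkl 000000750.pkl
```

With one environment, updates 751 through 781 are never checkpointed by the callback, so the explicit `model.save` after training is what captures the final policy.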

In [5]:
total_timesteps = 100000       # total environment steps to train for
scene_id = 5                   # small decluttered office
success_dist = 2               # max distance to a target for it to count as found
n_targets = 50                 # number of targets spawned in the scene
max_steps = 100                # max episode length (steps)
hunt_mode = HuntMode.MULTIPLE  # find as many targets as possible,
                               # rather than only having to find one

def make_unity_env(filename, num_env, base_id=0):
    """Create `num_env` TreasureHunt environments, each in its own subprocess.

    Environment i gets worker_id = base_id + i, so every instance launched
    by one call has a distinct worker id.
    """
    def make_env(rank):
        def _thunk():
            env = TreasureHunt(filename,
                               'localhost',
                               'localhost',
                               max_steps=max_steps,
                               worker_id=rank,
                               step_rate=30,
                               scene_id=scene_id,
                               n_targets=n_targets,
                               success_dist=success_dist,
                               init_hook=set_cameras,
                               hunt_mode=hunt_mode,
                               target_found_reward=1)
            return env
        return _thunk

    return SubprocVecEnv([make_env(i + base_id) for i in range(num_env)])
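The `success_dist` and `hunt_mode` settings above decide when a target counts as found and when the hunt ends. The sketch below is a toy model of that rule, not `tesse_gym`'s implementation: it assumes 2D positions, Euclidean distance compared with `<=`, a reward of `target_found_reward` per newly found target, found targets being removed, and `SINGLE` mode ending the episode at the first find. It ignores `max_steps`. `ToyHuntMode` and `toy_hunt_step` are hypothetical names.

```python
import math
from enum import Enum


class ToyHuntMode(Enum):
    """Hypothetical stand-in for tesse_gym's HuntMode."""
    SINGLE = 0    # episode ends once a target is found
    MULTIPLE = 1  # keep hunting until every target is found


def toy_hunt_step(agent_xy, targets_xy, success_dist, mode, target_found_reward=1.0):
    """Score one step of a toy treasure hunt.

    Returns (reward, remaining_targets, done). A target counts as found when it
    lies within success_dist of the agent; found targets are removed.
    """
    found, remaining = [], []
    for tx, ty in targets_xy:
        if math.hypot(tx - agent_xy[0], ty - agent_xy[1]) <= success_dist:
            found.append((tx, ty))
        else:
            remaining.append((tx, ty))
    reward = target_found_reward * len(found)
    done = not remaining or (mode is ToyHuntMode.SINGLE and bool(found))
    return reward, remaining, done


targets = [(1.0, 0.0), (0.0, 1.5), (5.0, 5.0)]
print(toy_hunt_step((0.0, 0.0), targets, 2, ToyHuntMode.MULTIPLE))  # (2.0, [(5.0, 5.0)], False)
print(toy_hunt_step((0.0, 0.0), targets, 2, ToyHuntMode.SINGLE))    # (2.0, [(5.0, 5.0)], True)
```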

Next, launch a single environment.

In [40]:
env = make_unity_env(filename, num_env=1, base_id=0)

Specify the agent model for learning.

In [None]:
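# For recurrent policies, PPO2 requires the number of environments to be a
# multiple of nminibatches; with a single environment that means nminibatches=1.
model = PPO2(CnnLstmPolicy, env, verbose=1, tensorboard_log="./tensorboard/", nminibatches=1)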
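Stable Baselines' `PPO2` asserts that, for recurrent policies such as `CnnLstmPolicy`, the number of parallel environments is a multiple of `nminibatches`: each minibatch takes the full `n_steps` rollout of a subset of environments so the LSTM sees contiguous sequences. The sketch below reproduces that check and reports the resulting minibatch layout. `recurrent_minibatch_layout` is a hypothetical helper, and `n_steps=128` is PPO2's default rollout length.

```python
def recurrent_minibatch_layout(n_envs, nminibatches, n_steps=128):
    """Return (envs_per_minibatch, samples_per_minibatch) for a recurrent PPO2 policy.

    Raises ValueError when n_envs is not a multiple of nminibatches, mirroring
    the assertion PPO2 makes for recurrent policies.
    """
    if n_envs % nminibatches != 0:
        raise ValueError(
            f"n_envs={n_envs} is not a multiple of nminibatches={nminibatches}")
    envs_per_minibatch = n_envs // nminibatches
    # Each minibatch holds every one of the n_steps from its environments.
    return envs_per_minibatch, envs_per_minibatch * n_steps


print(recurrent_minibatch_layout(1, 1))  # (1, 128)
print(recurrent_minibatch_layout(3, 3))  # (1, 128)
try:
    recurrent_minibatch_layout(1, 3)     # one environment with nminibatches=3
except ValueError as err:
    print("rejected:", err)
```

To keep `nminibatches=3`, the notebook would need to launch a multiple of three environments, e.g. `make_unity_env(filename, 3)`.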

In [None]:
model.learn(total_timesteps=total_timesteps, callback=save_checkpoint_callback, 
            reset_num_timesteps=False)

In [11]:
model.save(str(log_dir / "the.policy"))  # save the final policy

### Make a video

Load the trained model and run it in the environment to record a video.

In [None]:
model = PPO2.load(str(log_dir / 'the.policy'), env=env)  # load from where the final policy was saved

In [None]:
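video_length = 500

video_env = VecVideoRecorder(env,
                             video_folder='videos',
                             record_video_trigger=lambda step: step == 0,  # start recording at step 0
                             video_length=video_length,
                             name_prefix='tesse')

obs = video_env.reset()
# CnnLstmPolicy is recurrent: carry the LSTM state between steps and pass
# `done` as the mask so the state is reset at episode boundaries.
state, done = None, [False] * video_env.num_envs
for _ in range(video_length + 1):
    action, state = model.predict(obs, state=state, mask=done)
    obs, _, done, _ = video_env.step(action)
video_env.close()  # finish writing the video file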

### Random tests

A few simple test snippets.

In [17]:
env = TreasureHunt(filename, 'localhost', 'localhost', worker_id=25, scene_id=3)

In [None]:
for i in range(10):
    env.reset()
    time.sleep(.2)
    for _ in range(35):
        (obs, reward, done, _) = env.step(0)
        if done:
            print("collision", i)
            break