### Disclaimer

Distribution authorized to U.S. Government agencies and their contractors. Other requests for this document shall be referred to the MIT Lincoln Laboratory Technology Office.

This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering.

© 2019 Massachusetts Institute of Technology.

The software/firmware is provided to you on an As-Is basis

Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other than as specifically authorized by the U.S. Government may violate any copyrights that exist in this work.


### Treasure Hunt Challenge

This notebook demonstrates using [Stable Baselines](https://stable-baselines.readthedocs.io/en/master/) Proximal Policy Optimization to train a CNN-LSTM agent for the GOSEEK-Challenge. An agent must find as many treasures, placed around a TESSE environment, as possible in the allotted time.

`tesse_gym` allows for interface customizations, some of which are demonstrated here. Specifically, this notebook contains an example of using combined rgb, segmentation, depth, and pose as the agent's observation.

__Contents__
- [Configure Environment](#Configuration)
- [Define Model](#Define-the-Model)
- [Train Model](#Train-the-Model)
- [Visualize Results](#Visualize-Results)

In [None]:
from pathlib import Path

import numpy as np
from gym import spaces
from stable_baselines.common.policies import CnnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines import PPO2

from tesse.msgs import *

from tesse_gym.tasks.goseek import MultiModalGoSeek
from tesse_gym import get_network_config

## Configuration

#### Set TESSE build path

In [None]:
# Location of the GOSEEK simulator build, resolved relative to the current
# user's home directory.
filename = Path.home().joinpath('tess/builds/goseek/v0.0.2/goseek-0.0.2.x86_64')

#### Set environment parameters


In [None]:
# Switch between training (several environments in worker processes) and
# single-environment evaluation / visualization.
training = True

# Total number of environment steps passed to `model.learn`.
total_timesteps = 5000000

# Per-environment settings: `make_unity_env` indexes both lists by the
# environment's rank.
scene_id = [1, 2, 4, 5]  # holdout scenes 3, 6
n_targets = [30, 30, 30, 30]

# Settings shared by every environment; handed unchanged to MultiModalGoSeek.
success_dist = 2
max_steps = 400
target_found_reward = 2
step_rate = 20

# Training uses four subprocess-backed environments; otherwise a single
# in-process environment is created.
n_environments = 4 if training else 1
VecEnv = SubprocVecEnv if training else DummyVecEnv
        
def make_unity_env(filename, num_env):
    """ Create a wrapped Unity environment.

    Builds `num_env` MultiModalGoSeek environment constructors and wraps them
    in the vectorized-environment class selected by the module-level `VecEnv`.
    The environment with rank `i` uses `scene_id[i]` and `n_targets[i]`; all
    other settings come from the module-level configuration.

    Args:
        filename (str or Path): Path to the simulator build, passed to
            MultiModalGoSeek.
        num_env (int): Number of environments to create.

    Returns:
        VecEnv: Vectorized environment wrapping `num_env` environments.

    Raises:
        ValueError: If `num_env` exceeds the number of configured scenes or
            target counts.
    """
    # Validate here, in the calling process. Otherwise the out-of-range index
    # only surfaces as an IndexError inside an environment constructor, which
    # SubprocVecEnv executes in a worker process.
    n_configured = min(len(scene_id), len(n_targets))
    if num_env > n_configured:
        raise ValueError(
            f"num_env={num_env} exceeds the {n_configured} per-environment "
            "entries configured in scene_id / n_targets"
        )

    def make_env(rank):
        # `rank` is bound per call, so each thunk keeps its own worker index.
        def _thunk():
            env = MultiModalGoSeek(filename,
                                   network_config=get_network_config(worker_id=rank),
                                   n_targets=n_targets[rank],
                                   success_dist=success_dist,
                                   max_steps=max_steps,
                                   step_rate=step_rate,
                                   scene_id=scene_id[rank],
                                   target_found_reward=target_found_reward)
            return env
        return _thunk

    return VecEnv([make_env(i) for i in range(num_env)])

#### Next, we launch environments.

In [None]:
# Build the vectorized environment from the configured build path and
# environment count.
env = make_unity_env(filename=filename, num_env=n_environments)

# Define the Model 
The following network assumes an observation of RGB, segmentation, and depth images along with the agent's relative pose. Images are processed using the Stable Baselines default CNN. The resulting feature vector is concatenated with the pose vector and fed into an LSTM (defined when we initialize PPO2 below).

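The OpenAI Gym [dictionary space](https://github.com/openai/gym/blob/master/gym/spaces/dict.py) is not supported, so we'll flatten the images into one vector and concatenate that with pose. This flattening is undone inside the network, which splits the vector back into images and pose before the CNN is applied.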

In [None]:
from stable_baselines.common.policies import nature_cnn
import tensorflow as tf
from tensorflow import keras
from stable_baselines.a2c.utils import conv, linear, conv_to_fc

In [None]:
def postprocess_observation(obs, img_shape=(240, 320, 5), pose_dim=3):
    """ Decode observation numpy array into images and pose.

    Each row of `obs` is one flattened observation: the image values
    followed by `pose_dim` pose values.

    Args:
        obs (np.ndarray): Batch of flattened observations with shape
            (batch, prod(img_shape) + pose_dim).
        img_shape (Tuple[int, int, int]): Shape each flattened image block is
            restored to. Defaults to (240, 320, 5).
        pose_dim (int): Number of trailing pose values in each row.
            Defaults to 3.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Images with shape (-1, *img_shape)
            and pose with shape (batch, pose_dim). """
    # Split on an explicit column index rather than a negative slice so that
    # pose_dim == 0 yields an empty pose instead of an empty image block.
    n_image_values = obs.shape[1] - pose_dim
    imgs = obs[:, :n_image_values]
    pose = obs[:, n_image_values:]

    imgs = imgs.reshape((-1, *img_shape))
    return imgs, pose

In [None]:
def postprocess_observation_tensor(observation, img_shape=(240, 320, 5), pose_dim=3):
    """ Decode observation tensorflow Tensor into images and pose.

    Each row of `observation` is one flattened observation: the image values
    followed by `pose_dim` pose values.

    Args:
        observation (tf.Tensor): Batch of flattened observations with shape
            (batch, prod(img_shape) + pose_dim).
        img_shape (Tuple[int, int, int]): Shape each flattened image block is
            restored to. Defaults to (240, 320, 5).
        pose_dim (int): Number of trailing pose values in each row. Must be
            at least 1, because the split uses negative slicing.
            Defaults to 3.

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: Images with shape (-1, *img_shape)
            and pose with shape (batch, pose_dim). """
    imgs = tf.reshape(observation[:, :-pose_dim], shape=(-1, *img_shape))
    pose = observation[:, -pose_dim:]

    return imgs, pose

In [None]:
def cnn(scaled_images, **kwargs):
    """ Convolutional feature extractor: three conv layers and one dense layer.

    The layer layout is described by the original author as the Stable
    Baselines default CNN.
    NOTE(review): the previous docstring said "with batch norm", but no
    normalization layer is applied in this function.

    Args:
        scaled_images (tf.Tensor): Image batch fed to the first convolution.
        **kwargs: Forwarded to each `conv` call.

    Returns:
        tf.Tensor: ReLU-activated output of the 512-unit `fc1` layer.
    """
    init_scale = np.sqrt(2)
    relu = tf.nn.relu

    # (scope name, number of filters, filter size, stride) per conv layer.
    conv_specs = (
        ('c1', 32, 8, 4),
        ('c2', 64, 4, 2),
        ('c3', 64, 3, 1),
    )

    features = scaled_images
    for scope, n_filters, filter_size, stride in conv_specs:
        features = relu(conv(features, scope, n_filters=n_filters,
                             filter_size=filter_size, stride=stride,
                             init_scale=init_scale, **kwargs))

    flattened = conv_to_fc(features)
    return relu(linear(flattened, 'fc1', n_hidden=512, init_scale=init_scale))

In [None]:
def image_and_pose_network(raw_observation, **kwargs):
    """ Network to process image and pose data.

    Splits the flattened observation into images and pose, runs the images
    through `cnn`, and appends the pose to the resulting features.

    Args:
        raw_observation (tf.Tensor): Batch of flattened observations
            containing image and pose data.
        **kwargs: Extra keyword arguments, forwarded to `cnn` (which passes
            them on to its convolution layers).

    Returns:
        tf.Tensor: Feature vector; image features concatenated with pose
            along the last axis.
    """
    imgs, pose = postprocess_observation_tensor(raw_observation)
    # Forward kwargs instead of silently discarding them, so options supplied
    # by the caller reach the convolution layers.
    image_features = cnn(imgs, **kwargs)
    return tf.concat((image_features, pose), axis=-1)

In [None]:
# Use the image-and-pose network as the policy's feature extractor.
policy_kwargs = dict(cnn_extractor=image_and_pose_network)

In [None]:
# PPO2 agent with a CNN-LSTM policy; training metrics go to ./tensorboard/.
# NOTE(review): Stable Baselines documents that recurrent policies need the
# number of parallel environments to be a multiple of `nminibatches` -
# confirm this holds for the configured environment count.
model = PPO2(
    policy=CnnLstmPolicy,
    env=env,
    gamma=0.995,
    learning_rate=0.00025,
    nminibatches=2,
    policy_kwargs=policy_kwargs,
    tensorboard_log="./tensorboard/",
    verbose=1,
)

# Train the Model

#### Define logging directory and callback function to save checkpoints
The callback saves an intermediate checkpoint to the logging directory every 100 updates.

In [None]:
# Directory that receives the checkpoints; created up front so saving from
# the callback cannot fail on a missing folder.
log_dir = Path('results/stable-baselines-ppo-1')
log_dir.mkdir(parents=True, exist_ok=True)


def save_checkpoint_callback(local_vars, global_vars):
    """ Save the model to `log_dir` on every 100th update.

    Args:
        local_vars (dict): Local variables of the training loop. Reads
            'update' (the update counter) and 'self' (the object whose
            `save` method is called).
        global_vars (dict): Global variables of the training loop; unused.
    """
    update_count = local_vars['update']
    if update_count % 100 != 0:
        return

    # Zero-padded file names keep checkpoints in order when sorted by name.
    checkpoint_path = log_dir / f'{update_count:09d}.pkl'
    local_vars["self"].save(str(checkpoint_path))

In [None]:
# Run training; the callback writes periodic checkpoints along the way.
model.learn(callback=save_checkpoint_callback, total_timesteps=total_timesteps)

# Visualize Results

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt

In [None]:
# Path of the saved model to visualize. Left empty here; set it to a saved
# checkpoint before running this cell.
MODEL_PATH = ''
model = PPO2.load(load_path=str(MODEL_PATH))

In [None]:
# Begin a fresh episode with no recurrent state carried over.
lstm_state = None
obs = env.reset()
imgs, pose = postprocess_observation(obs)

In [None]:
# Show the first environment's initial observation side by side.
fig, ax = plt.subplots(nrows=1, ncols=2)
first_obs = imgs[0]
ax[0].imshow(first_obs[..., :3])  # channels 0-2, drawn as a colour image
ax[1].imshow(first_obs[..., 3])   # channel 3 (presumably segmentation - confirm)

In [None]:
# Roll out one episode with the loaded model, redrawing the RGB view after
# every step.
done = False
fig, ax = plt.subplots()

# Batch size of the policy's recurrent state. The single observation is tiled
# to this size below so `predict` receives the batch size the model expects.
n_train_envs = model.act_model.initial_state.shape[0]

# Bound the rollout by the environment's configured episode length instead of
# repeating the number as a literal.
for step_idx in range(max_steps):
    actions, lstm_state = model.predict(np.repeat(obs, n_train_envs, 0),
                                        state=lstm_state,
                                        deterministic=False)

    # All tiled rows hold the same observation; only the first action is used.
    action = actions[0]
    obs, reward, done, _ = env.step([action])
    done = done[0]

    # Clear the axes owned by this figure rather than the implicit "current"
    # axes.
    ax.cla()
    imgs, pose = postprocess_observation(obs)

    # display RGB image
    ax.imshow((255*imgs[0, ..., :3]).astype(np.uint8))
    fig.canvas.draw()

    # Stop at the end of the episode: continuing would feed the LSTM state
    # from the finished episode into the next one.
    if done:
        break

# Leave the environment and recurrent state ready for another rollout.
obs = env.reset()
imgs, pose = postprocess_observation(obs)
lstm_state = None