In [None]:
from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import SubprocVecEnv, VecVideoRecorder, DummyVecEnv
from stable_baselines import PPO2

%load_ext autoreload
%autoreload 2
from src.tesse_gym_interface.navigation_env import NavigationEnv
from src.tesse_gym_interface.treasure_hunt_env import TreasureHuntEnv
import time
import os

In [4]:
# Path of the simulator build handed to every TreasureHuntEnv in this notebook.
# Set the TESSE_BUILD_PATH environment variable to point at a different build
# without editing this cell; the original location is kept as the default.
filename = os.environ.get(
    'TESSE_BUILD_PATH',
    '/home/za27933/tess/builds/no-thread-v0.4.2/tesse_multiscene_v0.4.2-nothread_linux.x86_64',
)

### Agent training

This section contains code to train an agent using PPO2.

In [5]:
# Training budget: the value passed as `total_timesteps` to `model.learn`.
total_timesteps = 600_000

# Settings forwarded to TreasureHuntEnv by `make_unity_env`.
scene_id, n_targets, success_dist = 3, 1, 2

# NOTE(review): not referenced by any other visible cell - confirm whether it is still needed.
max_steps = 65

def make_unity_env(filename, num_env):
    """Build a vectorized set of TreasureHuntEnv instances.

    Args:
        filename: first positional argument of every TreasureHuntEnv (the
            notebook passes the path of the simulator build).
        num_env: number of environments to create; environment ``k`` is
            constructed with ``worker_id=k``.

    Returns:
        A SubprocVecEnv built from ``num_env`` environment constructors.

    ``scene_id``, ``n_targets`` and ``success_dist`` are read from the
    notebook globals at the moment each environment is constructed.
    """
    def env_factory(worker_rank):
        # Bind the rank now; the environment itself is only created when the
        # returned callable is invoked.
        def build():
            return TreasureHuntEnv(
                filename,
                'localhost',
                'localhost',
                worker_id=worker_rank,
                scene_id=scene_id,
                n_targets=n_targets,
                success_dist=success_dist,
            )
        return build

    factories = []
    for worker_rank in range(num_env):
        factories.append(env_factory(worker_rank))
    return SubprocVecEnv(factories)

Next, we launch 4 environments, each running in its own subprocess.

In [None]:
# Vectorized training environment made of four simulator instances.
env = make_unity_env(filename, num_env=4)

Specify the agent model for learning.

In [15]:
# PPO2 agent with a CNN policy acting on `env`; TensorBoard logs go to ./tensorboard/.
model = PPO2(
    policy=CnnPolicy,
    env=env,
    verbose=1,
    tensorboard_log="./tensorboard/",
)

In [16]:
def save_checkpoint_callback(local_vars, global_vars):
    """Checkpoint the model every 20 updates while ``learn`` is running.

    Intended to be passed as ``callback`` to ``model.learn``. On every update
    whose counter is a multiple of 20, the model found under
    ``local_vars["self"]`` is saved to ``<log_dir>/<update>.pkl``.

    Args:
        local_vars: local variables of the training loop; must provide the
            ``'update'`` counter and the model under ``'self'``.
        global_vars: global variables of the training loop (unused).

    Note:
        ``log_dir`` is read from the notebook's global namespace at call time,
        so it must be assigned before training starts.
    """
    total_updates = local_vars['update']
    if total_updates % 20 == 0:
        if log_dir:
            # Create the checkpoint directory on demand so the save does not
            # fail when it does not exist yet.
            os.makedirs(log_dir, exist_ok=True)
        # os.path.join yields the right path whether or not log_dir ends with
        # a path separator (plain concatenation silently produced
        # "policy-120.pkl" for log_dir = "policy-1").
        local_vars["self"].save(os.path.join(log_dir, f'{total_updates}.pkl'))

In [None]:
# Directory prefix that `save_checkpoint_callback` uses for its periodic checkpoints.
log_dir = 'policy-1/'

# Run training; the callback is invoked by `learn` and writes the checkpoints.
model.learn(
    total_timesteps=total_timesteps,
    callback=save_checkpoint_callback,
    reset_num_timesteps=False,
)

In [None]:
# Persist the trained model under the name that the video section loads again.
model.save("the.policy")

### Make a video

Demonstrates loading the model and executing it to construct a video.

In [41]:
# Reload the saved policy. No environment is passed here, so (as the cell
# output states) the loaded model cannot be trained until one is attached.
model = PPO2.load("the.policy")

Loading a model without an environment, this model cannot be trained until it has a valid environment.


In [42]:
# Number of steps to capture in the clip.
video_length = 500

# Wrap the existing vectorized environment with a recorder that starts
# recording at step 0 and writes into results/trial-3/.
video_env = VecVideoRecorder(
    env,
    video_folder='results/trial-3/',
    record_video_trigger=lambda step: step == 0,
    video_length=video_length,
    name_prefix='tesse-2',
)

obs = video_env.reset()
# NOTE(review): one step more than video_length is executed - presumably so the
# recorder finishes and saves the clip; confirm against VecVideoRecorder.
for _step in range(video_length + 1):
    action, _state = model.predict(obs)
    obs, _reward, _done, _info = video_env.step(action)

Saving video to  /home/za27933/tess/TESSE_gym_interface/results/trial-3/tesse-2-step-0-to-step-500.mp4


### Random tests

This section contains a few simple test snippets that step a single environment by hand.

In [17]:
# Single TreasureHuntEnv for the manual checks below.
# Note: this rebinds `env`, which previously held the vectorized training environment.
env = TreasureHuntEnv(
    filename,
    'localhost',
    'localhost',
    worker_id=25,
    scene_id=3,
)

In [None]:
# Run 10 short episodes that repeat action 0 for at most 35 steps each.
for episode in range(10):
    env.reset()
    time.sleep(.2)  # brief pause after each reset
    for _step in range(35):
        obs, reward, done, _info = env.step(0)
        if not done:
            continue
        # The episode ended early; it is reported as a collision.
        # NOTE(review): `done` may also be set for other reasons - confirm.
        print("collision", episode)
        break