In [7]:
import os
import sys
import argparse
import time
import datetime
import json
import random
import numpy as np
import torch

from bnp_options import *
from utils import *
from eval import *
from train import *
from env.toy_env import ToyEnv
from env.grid_env import GridEnv
from env.line_env import LineEnv
from env.roboturk_env import RoboturkEnv
from env.atari_env import AtariEnv
from env.augmented_atari_env import AugmentedAtariEnv

sys.path.append('../multilevel_discovery')
from models.AtariRamModel import AtariRamModel

In [8]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [9]:
# Experiment configuration. The whole dict is forwarded as keyword arguments
# to BNPOptions (`**args`), and 'random_seed' is read when the RNGs are seeded.
args = dict(
    K=1,
    tolerance=0.1,
    hidden_layer_sizes_policy=[32, 32],
    hidden_layer_sizes_termination=[32, 32],
    LSTM_hidden_layer_size=32,
    LSTM_MLP_hidden_layer_sizes=[32, 32],
    action_space='discrete',
    learning_rate=0.001,
    clip=5.0,
    batch_size=128,
    max_epochs=20,
    random_seed=0,
    relaxation_type='GS',
    temperature=1.0,
    temperature_ratio=0.95,
    env_name='atari',
    nb_rooms=5,
    nb_traj=1000,
    noise_level=0.0,
    max_steps=500,
    demo_file='../datasets/atari/seaquest/trajectories.npy',
    baseline=False,
)

In [10]:
# A single configured seed drives every source of randomness in the notebook.
random_seed = args['random_seed']

# Seed the global RNGs too (there were some reproducibility issues without this).
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# The master RNG is only used to derive seeds for the other RNGs. The three
# draws must stay in this order (env, action, split) to keep the derived seeds.
rng_master = np.random.RandomState(random_seed)
env_seed, action_seed, split_seed = (rng_master.randint(100000) for _ in range(3))

rng_env = np.random.RandomState(env_seed)
rng_split = np.random.RandomState(split_seed)

In [47]:
# Build a BNPOptions model and restore it from a saved checkpoint.
# NOTE(review): `model` is rebound to an AtariRamModel in the following cell and the
# execution counts are out of order, so which model later cells use depends on run order.
state_dim = 1024  # NOTE(review): magic number — presumably the state feature size; confirm
action_dim = 18  # NOTE(review): magic number — presumably the size of the Atari action set; confirm
# The first positional argument is None; every entry of `args` is forwarded as a keyword argument.
model = BNPOptions(None, state_dim, action_dim, device, rng=rng_master, **args)
model.load("runs/atari_Sep07_11-50-54/checkpoint.pth")

In [11]:
# Build an AtariRamModel and restore it from a saved checkpoint.
# This rebinds `model`, replacing any BNPOptions instance created in the previous cell.
# NOTE(review): the meaning of the leading 5 is not visible here — presumably the number
# of options/levels; confirm against AtariRamModel's signature.
model = AtariRamModel(5, statedim=(1024,), actiondim=(18,))
model.load('runs/atari_Sep17_10-55-10/checkpoint.pth')

In [48]:
# Put the sibling stable-baselines3 directory on the import path; this has to run
# before the `stable_baselines3` imports in the next cell.
sys.path.append('../stable-baselines3')

In [49]:
from stable_baselines3.common.env_util import make_vec_env, make_atari_env
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.vec_env import SubprocVecEnv, VecEnv
from stable_baselines3.common.utils import set_random_seed

def augmented_atari_wrapper(env, model):
    """Wrap ``env`` in ``AtariWrapper`` and then in ``AugmentedAtariEnv``.

    Args:
        env: The environment to wrap.
        model: Object passed to ``AugmentedAtariEnv`` along with the wrapped env.

    Returns:
        The resulting ``AugmentedAtariEnv`` instance.
    """
    return AugmentedAtariEnv(AtariWrapper(env), model)

In [50]:
# Vectorised Seaquest environment (10 copies); each copy is passed through
# `augmented_atari_wrapper` with whatever `model` is bound at construction time.
augmented_env = make_vec_env(
    'SeaquestNoFrameskip-v4',
    n_envs=10,
    seed=0,
    wrapper_class=lambda raw_env: augmented_atari_wrapper(raw_env, model),
)

In [51]:
from stable_baselines3 import PPO

In [52]:
# PPO agent with a CNN policy on the augmented environment; metrics go to TensorBoard.
ppo = PPO(
    'CnnPolicy',
    augmented_env,
    verbose=1,
    tensorboard_log="./tensorboard_logs/",
)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [11]:
# Train the augmented-environment agent; runs are logged under 'augmented_seaquest'.
ppo.learn(
    total_timesteps=4000,
    tb_log_name='augmented_seaquest',
)

2021-09-14 18:23:29.401997: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/valentinv/.mujoco/mujoco200/bin:/home/valentinv/.mujoco/mujoco200/bin
2021-09-14 18:23:29.402032: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


Logging to ./tensorboard_logs/augmented_seaquest_6
{'r': 140.0, 'l': 2369, 't': 10.846369}
{'r': 180.0, 'l': 2094, 't': 10.242359}
{'r': 100.0, 'l': 1954, 't': 12.152447}
{'r': 180.0, 'l': 2093, 't': 12.187857}
{'r': 180.0, 'l': 2006, 't': 12.041183}
{'r': 180.0, 'l': 2154, 't': 13.763974}
{'r': 120.0, 'l': 1578, 't': 14.16591}
{'r': 120.0, 'l': 2230, 't': 19.907314}
{'r': 140.0, 'l': 2218, 't': 21.87494}
{'r': 60.0, 'l': 1450, 't': 22.500906}
{'r': 180.0, 'l': 2053, 't': 21.154456}
{'r': 140.0, 'l': 1826, 't': 24.47659}
{'r': 140.0, 'l': 2182, 't': 24.201934}
{'r': 200.0, 'l': 2354, 't': 26.678966}
{'r': 140.0, 'l': 2398, 't': 24.485709}
{'r': 240.0, 'l': 3658, 't': 27.429169}
{'r': 180.0, 'l': 2145, 't': 30.954999}
{'r': 140.0, 'l': 2609, 't': 30.884346}
{'r': 120.0, 'l': 2186, 't': 30.384131}
{'r': 180.0, 'l': 2606, 't': 32.316584}
{'r': 140.0, 'l': 1886, 't': 34.032183}
{'r': 160.0, 'l': 1902, 't': 35.517934}
{'r': 220.0, 'l': 2302, 't': 33.990178}
{'r': 200.0, 'l': 2421, 't': 39.9

{'r': 120.0, 'l': 1769, 't': 260.623291}
{'r': 100.0, 'l': 1913, 't': 263.75253}
{'r': 120.0, 'l': 2378, 't': 263.255622}
{'r': 140.0, 'l': 2310, 't': 265.031072}
{'r': 180.0, 'l': 1986, 't': 265.766071}
{'r': 280.0, 'l': 2554, 't': 269.64805}
{'r': 180.0, 'l': 3442, 't': 270.3524}
{'r': 220.0, 'l': 2926, 't': 271.973063}
{'r': 180.0, 'l': 3841, 't': 273.175966}
{'r': 300.0, 'l': 3634, 't': 276.401135}
{'r': 140.0, 'l': 2038, 't': 277.445981}
{'r': 140.0, 'l': 2410, 't': 275.887526}
{'r': 80.0, 'l': 1802, 't': 279.155359}
{'r': 200.0, 'l': 2249, 't': 278.612675}
{'r': 160.0, 'l': 2798, 't': 279.818805}
{'r': 200.0, 'l': 2822, 't': 282.827326}
{'r': 80.0, 'l': 1926, 't': 282.493991}
{'r': 60.0, 'l': 1598, 't': 284.420278}
{'r': 80.0, 'l': 2234, 't': 288.067638}
{'r': 160.0, 'l': 2545, 't': 290.960187}
{'r': 100.0, 'l': 2190, 't': 289.227834}
{'r': 100.0, 'l': 2278, 't': 290.281069}
{'r': 80.0, 'l': 1613, 't': 291.62313}
{'r': 160.0, 'l': 1698, 't': 295.682851}
{'r': 180.0, 'l': 2758, 't

{'r': 160.0, 'l': 1906, 't': 526.073398}
{'r': 60.0, 'l': 1850, 't': 527.737931}
{'r': 80.0, 'l': 1754, 't': 528.489425}
{'r': 220.0, 'l': 2634, 't': 532.083752}
{'r': 180.0, 'l': 3078, 't': 529.564309}
{'r': 180.0, 'l': 2594, 't': 532.914026}
{'r': 260.0, 'l': 2994, 't': 536.913466}
{'r': 200.0, 'l': 2162, 't': 535.999662}
{'r': 260.0, 'l': 3177, 't': 537.148235}
{'r': 180.0, 'l': 2273, 't': 539.519397}
{'r': 100.0, 'l': 1605, 't': 537.056704}
{'r': 60.0, 'l': 1798, 't': 542.509742}
{'r': 80.0, 'l': 1722, 't': 541.765539}
{'r': 200.0, 'l': 2014, 't': 543.068954}
{'r': 100.0, 'l': 1869, 't': 541.672893}
{'r': 140.0, 'l': 1902, 't': 541.57345}
{'r': 340.0, 'l': 3246, 't': 544.536937}
{'r': 220.0, 'l': 2378, 't': 547.025638}
{'r': 100.0, 'l': 1713, 't': 551.456368}
{'r': 220.0, 'l': 2894, 't': 553.060215}
{'r': 200.0, 'l': 3402, 't': 552.584579}
{'r': 160.0, 'l': 3270, 't': 552.469811}
{'r': 200.0, 'l': 2065, 't': 555.295087}
{'r': 160.0, 'l': 1914, 't': 556.201832}
{'r': 140.0, 'l': 200

<stable_baselines3.ppo.ppo.PPO at 0x7f112c11d640>

In [10]:
# Plain (non-augmented) Seaquest environment with 20 copies, used for the baseline agent.
base_env = make_atari_env(
    'SeaquestNoFrameskip-v4',
    n_envs=20,
    seed=0,
)

In [13]:
# Baseline PPO agent with a CNN policy on the plain environment; metrics go to TensorBoard.
ppo2 = PPO(
    'CnnPolicy',
    base_env,
    verbose=1,
    tensorboard_log="./tensorboard_logs/",
)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [14]:
# Train the baseline agent; runs are logged under 'base_seaquest'.
ppo2.learn(
    total_timesteps=1000000,
    tb_log_name='base_seaquest',
)

2021-09-15 13:58:48.187933: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/valentinv/.mujoco/mujoco200/bin:/home/valentinv/.mujoco/mujoco200/bin
2021-09-15 13:58:48.187965: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


Logging to ./tensorboard_logs/base_seaquest_4
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1.9e+03  |
|    ep_rew_mean     | 57.1     |
| time/              |          |
|    fps             | 741      |
|    iterations      | 1        |
|    time_elapsed    | 55       |
|    total_timesteps | 40960    |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 2.22e+03    |
|    ep_rew_mean          | 92.8        |
| time/                   |             |
|    fps                  | 490         |
|    iterations           | 2           |
|    time_elapsed         | 167         |
|    total_timesteps      | 81920       |
| train/                  |             |
|    approx_kl            | 0.025736084 |
|    clip_fraction        | 0.162       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.88       |
|    explained_variance   

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 3.76e+03   |
|    ep_rew_mean          | 299        |
| time/                   |            |
|    fps                  | 381        |
|    iterations           | 12         |
|    time_elapsed         | 1289       |
|    total_timesteps      | 491520     |
| train/                  |            |
|    approx_kl            | 0.37504393 |
|    clip_fraction        | 0.697      |
|    clip_range           | 0.2        |
|    entropy_loss         | -2.33      |
|    explained_variance   | 0.58       |
|    learning_rate        | 0.0003     |
|    loss                 | -0.126     |
|    n_updates            | 110        |
|    policy_gradient_loss | -0.101     |
|    value_loss           | 0.0355     |
----------------------------------------
---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 3.96e+03  |
|    ep_rew_mean   

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4.61e+03   |
|    ep_rew_mean          | 417        |
| time/                   |            |
|    fps                  | 391        |
|    iterations           | 22         |
|    time_elapsed         | 2299       |
|    total_timesteps      | 901120     |
| train/                  |            |
|    approx_kl            | 0.75776017 |
|    clip_fraction        | 0.71       |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.61      |
|    explained_variance   | 0.675      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0949    |
|    n_updates            | 210        |
|    policy_gradient_loss | -0.0933    |
|    value_loss           | 0.0498     |
----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4.72e+03   |
|    ep_rew_mean

<stable_baselines3.ppo.ppo.PPO at 0x7f2254736a00>

In [53]:
# Restore the PPO agent trained on the augmented Seaquest environment.
# `load` is a classmethod that builds and returns a NEW model; it does not update
# the instance it is called on. Without rebinding `ppo`, the checkpoint would be
# discarded and `ppo.predict` would keep using the old weights.
# (Pass `env=augmented_env` as well if the restored agent should be trained further.)
ppo = PPO.load('runs/atari_exp_Sep14_18-35-07/ppo_augmented_seaquest.zip')

<stable_baselines3.ppo.ppo.PPO at 0x7f8b14748a30>

In [17]:
# Restore the baseline PPO checkpoint.
# `load` is a classmethod that builds and returns a NEW model; it does not update
# the instance it is called on, so the result has to be bound to a name.
# NOTE: this replaces whatever `ppo` currently holds (including the checkpoint loaded
# in the preceding cell) — run only the load cell for the agent you want to evaluate.
ppo = PPO.load('runs/atari_exp_Sep19_16-59-23/ppo_augmented_seaquest_baseline.zip')

<stable_baselines3.ppo.ppo.PPO at 0x7f8b5b0ce730>

In [54]:
# Reset the vectorised augmented environment and keep the initial observations.
obs = augmented_env.reset()

In [67]:
# One interaction step: ask the agent for actions, then apply them to the environment.
# `obs` is overwritten, so re-running this cell advances the rollout by one more step.
action, _states = ppo.predict(obs)
obs, reward, done, info = augmented_env.step(action)

In [68]:
# Show the actions chosen in the last step (one entry per environment copy).
action

array([21,  3, 22,  2, 23, 16,  2,  1,  6, 21])

In [69]:
# Show the rewards returned by the last step (one entry per environment copy).
reward

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [70]:
# Show the info dicts returned by the last step (one dict per environment copy).
info

[{'ale.lives': 4, 'steps_count': 41},
 {'ale.lives': 4, 'steps_count': 1},
 {'ale.lives': 4, 'steps_count': 2},
 {'ale.lives': 4, 'steps_count': 1},
 {'ale.lives': 3, 'steps_count': 16},
 {'ale.lives': 4, 'steps_count': 1},
 {'ale.lives': 3, 'steps_count': 1},
 {'ale.lives': 4, 'steps_count': 1},
 {'ale.lives': 4, 'steps_count': 1},
 {'ale.lives': 4, 'steps_count': 5}]