# A naive approach: let's try PPO

In [1]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import torch
from torch.utils.tensorboard import SummaryWriter


import os
from datetime import datetime

from utils import ResettingEnvironment
from networks import SimpleNetHackActor, SimpleNetHackCritic

from nle.env.tasks import *

## Setup

In [2]:
env = ResettingEnvironment(NetHackGold())

In [3]:
num_train_envs = 5
num_test_envs = 5

train_envs = ts.env.DummyVectorEnv([lambda: env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([lambda: env for _ in range(num_test_envs)])

In [4]:
actor_net = SimpleNetHackActor(env.observation_space, env.action_space)
critic_net = SimpleNetHackCritic(env.observation_space)

In [5]:
# using a single optimizer for actor and critic simplifies the training loop and is more computationally efficient
# BUT gradient updates in one network will influence the gradient updates in the other, and this might create unexpected problems...
combined_params = list(actor_net.parameters()) + list(critic_net.parameters())
optimizer = torch.optim.Adam(combined_params, lr=3e-4)

## PPO

In [6]:
def dist_fn(logits: torch.Tensor):
    return torch.distributions.Categorical(logits=logits)

policy = ts.policy.PPOPolicy(
    actor=actor_net, 
    critic=critic_net, 
    optim=optimizer,
    dist_fn=dist_fn,
    action_space=env.action_space,
    action_scaling=False,
)

In [7]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(2000, num_train_envs))

test_collector = ts.data.Collector(policy, test_envs)

In [8]:
num_epochs = 5
num_steps_per_epoch = 1000

step_per_collect = 10
episode_per_test = 6
batch_size = 10

timestamp = datetime.now().strftime("%d%m%Y-%H%M%S")
log_path = os.path.join("logs", "ppo", timestamp)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

In [9]:
trainer = ts.trainer.OnpolicyTrainer(
    policy=policy, 
    train_collector=train_collector, 
    test_collector=test_collector,
    repeat_per_collect=1,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    logger=logger,
)

In [12]:
for epoch_stats in trainer:
    # TODO a more informative print, plots, logging, etc.
    print(epoch_stats)

Epoch #1: 1001it [00:12, 78.86it/s, env_step=1000, gradient_step=300, len=523, n/ep=0, n/st=10, rew=-1.73]                          


Epoch #1: test_reward: -2.983333 ± 2.134359, best_reward: -2.363333 ± 2.438399 in #0
Epoch: EpochStats(epoch=1, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.020969867706298828, collect_speed=476.8747299724856, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=7153, collect_time=20.25036096572876, collect_speed=353.2282714419546, returns=array([-0.6 , -1.44, -2.08, -2.53, -4.16, -7.09]), returns_stat=SequenceSummaryStats(mean=-2.983333333333307, std=2.1343591283775853, max=-0.6000000000000003, min=-7.089999999999893), lens=array([ 258,  537,  833, 1031, 1581, 2913]), lens_stat=SequenceSummaryStats(mean=1192.1666666666667, std=872.6726031119703, max=2913.0, min=258.0)), training_stat=PPOTrainingStats(train_time=0.10072898864746094, smoothed_loss={'loss': 3.236992934944283e-05, 'clip_loss': 1.117121399829557e-09, 'vf_loss

Epoch #2: 1001it [00:12, 78.33it/s, env_step=2000, gradient_step=400, len=523, n/ep=0, n/st=10, rew=-1.73]                          


Epoch #2: test_reward: -2.648333 ± 2.292105, best_reward: -2.363333 ± 2.438399 in #0
Epoch: EpochStats(epoch=2, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.021883249282836914, collect_speed=456.9705289535327, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=6400, collect_time=19.48100209236145, collect_speed=328.52519442567365, returns=array([-0.56, -1.32, -1.92, -1.3 , -3.4 , -7.39]), returns_stat=SequenceSummaryStats(mean=-2.6483333333333103, std=2.2921054125458618, max=-0.5600000000000003, min=-7.389999999999887), lens=array([ 274,  553,  765,  510, 1459, 2839]), lens_stat=SequenceSummaryStats(mean=1066.6666666666667, std=874.4660593121318, max=2839.0, min=274.0)), training_stat=PPOTrainingStats(train_time=0.10609817504882812, smoothed_loss={'loss': 1.4486352015410375e-05, 'clip_loss': -5.55068220275956e-10, 'vf_l

Epoch #3: 1001it [00:12, 79.26it/s, env_step=3000, gradient_step=500, len=523, n/ep=0, n/st=10, rew=-1.73]                          


Epoch #3: test_reward: -2.711667 ± 2.068860, best_reward: -2.363333 ± 2.438399 in #0
Epoch: EpochStats(epoch=3, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.021863222122192383, collect_speed=457.3891233465284, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=6465, collect_time=19.63693904876709, collect_speed=329.2264636532498, returns=array([-0.61, -0.99, -1.  , -2.92, -4.62, -6.13]), returns_stat=SequenceSummaryStats(mean=-2.7116666666666407, std=2.0688597235084405, max=-0.6100000000000003, min=-6.129999999999914), lens=array([ 283,  365,  418, 1151, 1721, 2527]), lens_stat=SequenceSummaryStats(mean=1077.5, std=826.0721820761186, max=2527.0, min=283.0)), training_stat=PPOTrainingStats(train_time=0.10517597198486328, smoothed_loss={'loss': 1.427446832167334e-05, 'clip_loss': -2.499669899957979e-09, 'vf_loss': 2.85539

Epoch #4: 1001it [00:13, 76.17it/s, env_step=4000, gradient_step=600, len=523, n/ep=0, n/st=10, rew=-1.73]                          


Epoch #4: test_reward: -2.315000 ± 2.001564, best_reward: -2.315000 ± 2.001564 in #4
Epoch: EpochStats(epoch=4, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.02273416519165039, collect_speed=439.8666023449462, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=5661, collect_time=17.361329078674316, collect_speed=326.0695062196393, returns=array([-0.02, -0.66, -1.12, -2.26, -4.18, -5.65]), returns_stat=SequenceSummaryStats(mean=-2.3149999999999795, std=2.0015639718313327, max=-0.02, min=-5.649999999999924), lens=array([  18,  267,  539, 1001, 1701, 2135]), lens_stat=SequenceSummaryStats(mean=943.5, std=760.9870235424518, max=2135.0, min=18.0)), training_stat=PPOTrainingStats(train_time=0.10118818283081055, smoothed_loss={'loss': 1.4058788274269318e-05, 'clip_loss': 1.7695129150840927e-09, 'vf_loss': 2.8114101323808426e-05

Epoch #5: 1001it [00:13, 72.89it/s, env_step=5000, gradient_step=700, len=523, n/ep=0, n/st=10, rew=-1.73]                          


Epoch #5: test_reward: -2.535000 ± 1.931474, best_reward: -2.315000 ± 2.001564 in #4
Epoch: EpochStats(epoch=5, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.022660017013549805, collect_speed=441.30593520827415, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=6078, collect_time=18.988116979599, collect_speed=320.09493129467535, returns=array([-0.66, -1.44, -1.76, -1.81, -2.96, -6.58]), returns_stat=SequenceSummaryStats(mean=-2.5349999999999815, std=1.93147396220258, max=-0.6600000000000004, min=-6.579999999999904), lens=array([ 274,  553,  614,  762, 1134, 2741]), lens_stat=SequenceSummaryStats(mean=1013.0, std=814.5088499621515, max=2741.0, min=274.0)), training_stat=PPOTrainingStats(train_time=0.10525321960449219, smoothed_loss={'loss': 1.410766929438978e-05, 'clip_loss': 1.8142162561129994e-09, 'vf_loss': 2.8211754

Epoch #6: 1001it [00:13, 74.57it/s, env_step=6000, gradient_step=800, len=523, n/ep=0, n/st=10, rew=-1.73]                          


Epoch #6: test_reward: -2.428333 ± 2.337095, best_reward: -2.315000 ± 2.001564 in #4
Epoch: EpochStats(epoch=6, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.022851228713989258, collect_speed=437.6132296937764, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=5811, collect_time=19.133152961730957, collect_speed=303.71366452893733, returns=array([-0.06, -0.71, -1.69, -1.53, -3.49, -7.09]), returns_stat=SequenceSummaryStats(mean=-2.428333333333311, std=2.3370951818205246, max=-0.060000000000000005, min=-7.089999999999893), lens=array([  42,  313,  650,  715, 1381, 2710]), lens_stat=SequenceSummaryStats(mean=968.5, std=880.9810346047941, max=2710.0, min=42.0)), training_stat=PPOTrainingStats(train_time=0.11280608177185059, smoothed_loss={'loss': 1.3667898247149424e-05, 'clip_loss': -1.788139403213762e-10, 'vf_loss': 2.733

Epoch #7: 1001it [00:14, 67.96it/s, env_step=7000, gradient_step=900, len=1149, n/ep=0, n/st=10, rew=-3.06]                          


Epoch #7: test_reward: -2.458333 ± 2.130614, best_reward: -2.315000 ± 2.001564 in #4
Epoch: EpochStats(epoch=7, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.024194717407226562, collect_speed=413.3133622388648, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=5860, collect_time=19.84908103942871, collect_speed=295.2277734349288, returns=array([-0.67, -1.42, -1.61, -0.88, -3.36, -6.81]), returns_stat=SequenceSummaryStats(mean=-2.4583333333333126, std=2.130613500588207, max=-0.6700000000000004, min=-6.809999999999899), lens=array([ 245,  517,  650,  413, 1344, 2691]), lens_stat=SequenceSummaryStats(mean=976.6666666666666, std=841.1235871671231, max=2691.0, min=245.0)), training_stat=PPOTrainingStats(train_time=0.12093210220336914, smoothed_loss={'loss': 7.57765032858515e-05, 'clip_loss': -1.0672957451163256e-09, 'vf_loss

Epoch #8: 1001it [00:13, 75.34it/s, env_step=8000, gradient_step=1000, len=1884, n/ep=0, n/st=10, rew=-5.18]                          


Epoch #8: test_reward: -2.451667 ± 2.436099, best_reward: -2.315000 ± 2.001564 in #4
Epoch: EpochStats(epoch=8, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.023846149444580078, collect_speed=419.3549161150993, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=5616, collect_time=18.7675302028656, collect_speed=299.2402270993813, returns=array([-0.01, -0.67, -1.48, -1.76, -3.43, -7.36]), returns_stat=SequenceSummaryStats(mean=-2.4516666666666436, std=2.4360994551855555, max=-0.01, min=-7.3599999999998875), lens=array([   3,  269,  596,  654, 1368, 2726]), lens_stat=SequenceSummaryStats(mean=936.0, std=903.8091612724447, max=2726.0, min=3.0)), training_stat=PPOTrainingStats(train_time=0.10635495185852051, smoothed_loss={'loss': 4.757172201607318e-05, 'clip_loss': 1.436099410501157e-09, 'vf_loss': 9.514066794963583e-05, 'e

Epoch #9: 1001it [00:13, 75.51it/s, env_step=9000, gradient_step=1100, len=1884, n/ep=0, n/st=10, rew=-5.18]                          


KeyboardInterrupt: 