# A naive approach: let's try PPO

In [1]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import torch
from torch.utils.tensorboard import SummaryWriter


import os

from environments import Resetting
from networks import SimpleNetHackActor, SimpleNetHackCritic, NetHackObsNet

from nle.env.tasks import *

## Setup

In [2]:
env = Resetting(NetHackGold())

In [3]:
num_train_envs = 5
num_test_envs = 5

train_envs = ts.env.DummyVectorEnv([lambda: env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([lambda: env for _ in range(num_test_envs)])

In [4]:
obs_net = NetHackObsNet(env.observation_space)
actor_net = SimpleNetHackActor(obs_net, env.action_space)
critic_net = SimpleNetHackCritic(obs_net)

In [5]:
# using a single optimizer for actor and critic simplifies the training loop and is more computationally efficient
# BUT gradient updates in one network will influence the gradient updates in the other, and this might create unexpected problems...
combined_params = set(list(actor_net.parameters()) + list(critic_net.parameters()))
optimizer = torch.optim.Adam(combined_params, lr=3e-4)

## PPO

In [6]:
def dist_fn(logits: torch.Tensor):
    return torch.distributions.Categorical(logits=logits)

policy = ts.policy.PPOPolicy(
    actor=actor_net, 
    critic=critic_net, 
    optim=optimizer,
    dist_fn=dist_fn,
    action_space=env.action_space,
    action_scaling=False,
)

In [7]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(2000, num_train_envs))

test_collector = ts.data.Collector(policy, test_envs)

In [8]:
num_epochs = 5
num_steps_per_epoch = 1000

step_per_collect = 10
episode_per_test = 6
batch_size = 10

log_path = os.path.join("../logs", "ppo")
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

In [9]:
trainer = ts.trainer.OnpolicyTrainer(
    policy=policy, 
    train_collector=train_collector, 
    test_collector=test_collector,
    repeat_per_collect=1,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    episode_per_test=episode_per_test,
    batch_size=batch_size,
    logger=logger,
)

In [10]:
epoch_stats = []
for epoch_stat in trainer:
    epoch_stats.append(epoch_stat)

Epoch #1: 1001it [00:12, 78.85it/s, env_step=1000, gradient_step=100, len=0, n/ep=0, n/st=10, rew=0.00]                          


Epoch #1: test_reward: -2.931679 ± 2.327637, best_reward: -2.931679 ± 2.327637 in #1


Epoch #2: 1001it [00:14, 71.07it/s, env_step=2000, gradient_step=200, len=0, n/ep=0, n/st=10, rew=0.00]                          


Epoch #2: test_reward: -2.898345 ± 2.343906, best_reward: -2.898345 ± 2.343906 in #2


Epoch #3: 1001it [00:12, 80.73it/s, env_step=3000, gradient_step=300, len=0, n/ep=0, n/st=10, rew=0.00]                          


Epoch #3: test_reward: -3.365018 ± 2.612935, best_reward: -2.898345 ± 2.343906 in #2


Epoch #4: 1001it [00:13, 73.30it/s, env_step=4000, gradient_step=400, len=0, n/ep=0, n/st=10, rew=0.00]                          


Epoch #4: test_reward: -3.446687 ± 2.733425, best_reward: -2.898345 ± 2.343906 in #2


Epoch #5: 1001it [00:12, 82.47it/s, env_step=5000, gradient_step=500, len=0, n/ep=0, n/st=10, rew=0.00]                          


Epoch #5: test_reward: -2.871670 ± 1.614332, best_reward: -2.871670 ± 1.614332 in #5


In [12]:
epoch_stats

[EpochStats(epoch=1, train_collect_stat=CollectStats(n_collected_episodes=0, n_collected_steps=10, collect_time=0.02412104606628418, collect_speed=414.575718338259, returns=array([], dtype=float64), returns_stat=None, lens=array([], dtype=int64), lens_stat=None), test_collect_stat=CollectStats(n_collected_episodes=6, n_collected_steps=7060, collect_time=24.492645740509033, collect_speed=288.24979035740836, returns=array([-0.65999967, -0.7499996 , -2.2399983 , -2.3799982 , -4.109999  ,
        -7.4500756 ], dtype=float32), returns_stat=SequenceSummaryStats(mean=-2.931678533554077, std=2.3276374340057373, max=-0.659999668598175, min=-7.450075626373291), lens=array([ 274,  274,  872,  971, 1645, 3024]), lens_stat=SequenceSummaryStats(mean=1176.6666666666667, std=948.39437413393, max=3024.0, min=274.0)), training_stat=PPOTrainingStats(train_time=0.10226130485534668, smoothed_loss={'loss': -0.005536845440346951, 'clip_loss': 3.2857062004509886e-09, 'vf_loss': 0.00011184314902493498, 'ent_lo