# A naive approach: let's try PPO

In [None]:
import tianshou as ts 
from tianshou.utils import TensorboardLogger

import torch
from torch.utils.tensorboard import SummaryWriter


import os
from datetime import datetime

from environments import ResettingEnvironment
from networks import SimpleNetHackActor, SimpleNetHackCritic, NetHackObsNet

from nle.env.tasks import *

## Setup

In [None]:
# A single NetHackGold task instance; its observation_space and action_space
# are read below to build the networks and the policy.
env = ResettingEnvironment(
    NetHackGold()
)

In [None]:
num_train_envs = 5
num_test_envs = 5


def _make_env():
    """Build a fresh, independent NetHackGold environment for one vector-env slot."""
    return ResettingEnvironment(NetHackGold())


# DummyVectorEnv calls each factory once and treats the result as a separate
# environment. The factories must therefore create a NEW env per slot.
# Returning the shared `env` object from every lambda would make all train
# slots (and all test slots, and train vs. test) step and reset the very same
# game, corrupting each other's episodes and observations.
train_envs = ts.env.DummyVectorEnv([_make_env for _ in range(num_train_envs)])
test_envs = ts.env.DummyVectorEnv([_make_env for _ in range(num_test_envs)])

In [None]:
# One observation encoder, sized from the task's observation space. The very
# same obs_net instance is handed to both the actor and the critic heads.
obs_net = NetHackObsNet(env.observation_space)
actor_net, critic_net = (
    SimpleNetHackActor(obs_net, env.action_space),
    SimpleNetHackCritic(obs_net),
)

In [None]:
# PPOPolicy below receives this single optimizer for both networks.
# Adam keeps separate moment estimates per parameter, so sharing one optimizer
# does not by itself couple the actor and critic. The coupling comes from
# both networks being built around the same obs_net instance: presumably its
# parameters receive gradient from both the policy and the value objectives,
# which is the interaction to keep an eye on.
#
# Because of that sharing, the two parameter lists overlap, and each tensor
# must be passed to the optimizer only once. Deduplicate with dict.fromkeys
# rather than set(): it keeps first-seen order, so the optimizer's parameter
# order (and the layout of its state_dict) is identical on every run. Set
# iteration order depends on object ids and changes between runs.
combined_params = list(
    dict.fromkeys([*actor_net.parameters(), *critic_net.parameters()])
)
optimizer = torch.optim.Adam(combined_params, lr=3e-4)

## PPO

In [None]:
def dist_fn(logits: torch.Tensor):
    """Turn the actor's raw outputs into a categorical action distribution.

    ``logits`` are unnormalised log-probabilities; ``Categorical`` applies the
    softmax itself, over the last dimension, so batched logits are supported.
    """
    categorical = torch.distributions.Categorical
    distribution = categorical(logits=logits)
    return distribution

# PPO with a shared optimizer. dist_fn produces categorical (integer) actions,
# so action rescaling is turned off: discrete indices must not be mapped into
# a continuous range.
policy = ts.policy.PPOPolicy(
    optim=optimizer,
    actor=actor_net,
    critic=critic_net,
    action_space=env.action_space,
    dist_fn=dist_fn,
    action_scaling=False,
)

In [None]:
# Training collector: 2000 transition slots in total, split into one
# sub-buffer per training environment.
train_collector = ts.data.Collector(
    policy,
    train_envs,
    ts.data.VectorReplayBuffer(2000, num_train_envs),
)

# Evaluation collector: no explicit buffer is passed here.
test_collector = ts.data.Collector(policy, test_envs)

In [None]:
num_epochs = 5
num_steps_per_epoch = 1000

# Deliberately small values for this first, naive run.
step_per_collect = 10
# NOTE(review): 6 test episodes over num_test_envs=5 environments means one
# env presumably contributes an extra episode; a multiple of num_test_envs
# may give a more even evaluation. Confirm against Tianshou's Collector docs.
episode_per_test = 6
batch_size = 10

# Year-first timestamp so run directories sort chronologically. A day-first
# "%d%m%Y" format interleaves runs from different months and years in `ls`
# and in TensorBoard's alphabetical run list.
# "../logs" is resolved relative to the current working directory.
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_path = os.path.join("../logs", "ppo", timestamp)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

In [None]:
# On-policy trainer wiring. The sizes come from the hyperparameter cell above;
# repeat_per_collect=1 means a single optimisation pass over each freshly
# collected batch of experience.
trainer = ts.trainer.OnpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=num_epochs,
    step_per_epoch=num_steps_per_epoch,
    step_per_collect=step_per_collect,
    repeat_per_collect=1,
    batch_size=batch_size,
    episode_per_test=episode_per_test,
    logger=logger,
)

In [None]:
# Drive training by iterating the trainer; each yielded item is printed as-is.
# TODO: replace the bare print with a more informative summary, plots and logging.
for epoch_stats in trainer:
    print(
        epoch_stats
    )