In [1]:
# Auxiliary imports (standard library, plotting, numerics)
import sys, os, time, importlib
import matplotlib.pyplot as plt
import numpy as np

# Gym imports
import gym
from gym.vector import SyncVectorEnv

# PyTorch imports
import torch
from torch import nn, optim

# Custom imports
# NOTE(review): the project modules below are presumably located in the parent
# directory, which is why it is appended to sys.path first -- confirm layout.
sys.path.append(os.path.abspath('..')) # Add parent directory to the module search path

# Pattern used for every project module: import it, reload it, and only then
# pull names out of it, so the names bound here come from the reloaded module.
import ppo_network
importlib.reload(ppo_network) # Re-execute the module so source edits are picked up without a kernel restart
from ppo_network import PPONetworkDiscrete

import ppo
importlib.reload(ppo) # Re-execute the module so source edits are picked up without a kernel restart
from ppo import PPODiscrete

import hp_tuner
importlib.reload(hp_tuner) # Re-execute the module so source edits are picked up without a kernel restart
from hp_tuner import HPTuner

In [2]:
# LunarLander environment configuration
env_id = 'LunarLander-v2'
max_episode_steps = 1024
num_envs = 16

# Keyword arguments forwarded verbatim to gym.make() for every sub-environment
env_kwargs = dict(id=env_id, max_episode_steps=max_episode_steps)

# Build the vectorized environment: one factory callable per sub-environment,
# each of which creates a fresh env from env_kwargs when invoked.
envs_vector = SyncVectorEnv(
    [lambda: gym.make(**env_kwargs) for _ in range(num_envs)]
)

# Initial reset yields the first batch of observations plus the info dict
states, infos = envs_vector.reset()

In [3]:
# Policy-Value Network
# The input/output sizes are read from the vectorized environment instead of
# being hard-coded, so the network stays consistent with the environment if
# `env_id` is changed. For LunarLander-v2 these evaluate to 8 and 4.
network_kwargs = {
    'input_dims': int(envs_vector.single_observation_space.shape[0]),
    'output_dims': int(envs_vector.single_action_space.n),
    # NOTE(review): presumably the layers shared by both heads -- confirm in ppo_network
    'shared_hidden_dims': [1024, 512, 256],
    'shared_norm': nn.LayerNorm,
    'shared_activation': nn.SiLU,
    # NOTE(review): presumably the policy-head layers -- confirm in ppo_network
    'policy_hidden_dims': [256, 128, 64],
    'policy_norm': nn.LayerNorm,
    'policy_activation': nn.SiLU,
    # NOTE(review): presumably the value-head layers -- confirm in ppo_network
    'value_hidden_dims': [256, 128, 64],
    'value_norm': nn.LayerNorm,
    'value_activation': nn.SiLU,
}

# Create policy-value network
network = PPONetworkDiscrete(**network_kwargs)

In [4]:
# Test forward passes
# Smoke test: run the freshly created network on the current observations,
# sample one action per sub-environment and step the vectorized environment.
# Side effects: this advances `envs_vector` by 10 steps and overwrites `states`
# (and `infos`) with the latest values, so re-running the cell continues from
# wherever the environments currently are.
with torch.no_grad():  # inference only; avoids building an unused autograd graph
    for _ in range(10):
        states_tensor = torch.tensor(states, dtype=torch.float32)
        policy, value = network(states_tensor)

        # The policy output is treated as unnormalized logits over the actions
        actions_dist = torch.distributions.Categorical(logits=policy)
        actions = actions_dist.sample().numpy()

        states, rewards, dones, truncateds, infos = envs_vector.step(actions)
        # Only a truncated preview of the first sub-environment's observation is shown
        print(f"State: {states[0]}"[:65])

State: [-0.01315775  1.4243681  -0.6699277   0.28587648  0.016523
State: [-0.019767    1.4302039  -0.6699555   0.25924158  0.025447
State: [-0.02644834  1.4354489  -0.6789683   0.23289251  0.036168
State: [-0.03304501  1.4400951  -0.6683491   0.20625934  0.044751
State: [-0.03964195  1.4441421  -0.66837287  0.17958583  0.053335
State: [-0.04630079  1.447587   -0.6761035   0.1527146   0.063463
State: [-0.05285187  1.451386   -0.66604125  0.16834451  0.074294
State: [-0.05936918  1.455857   -0.66310704  0.19810703  0.085585
State: [-0.06582022  1.4597296  -0.6547938   0.17153105  0.095203
State: [-0.07227659  1.4641075  -0.65562654  0.19391146  0.105129


  if not isinstance(terminated, (bool, np.bool8)):


In [5]:
# PPO configuration, passed as keyword arguments to PPODiscrete and reused
# later as the base configuration for the hyperparameter tuner.
ppo_kwargs = dict(
    # Network construction
    network_class=PPONetworkDiscrete,
    network_kwargs=network_kwargs,
    num_envs=num_envs,
    # Learning rate (initial / final)
    lr=3e-4,
    final_lr=5e-6,
    # Discounting
    gamma=0.995,
    lam=0.99,
    # Clipping (initial / final)
    clip_eps=0.25,
    final_clip_eps=0.01,
    # Loss weighting (entropy: initial / final)
    value_coef=0.7,
    entropy_coef=0.1,
    final_entropy_coef=0.025,
    # Batching
    batch_size=2048,
    mini_batch_size=256,
    batch_epochs=8,
    batch_shuffle=True,
    seperate_envs_shuffle=True,  # spelling kept as-is; presumably matches PPODiscrete's parameter name
    # Rewards
    reward_normalization=True,
    truncated_reward=-300,
    # Diagnostics
    debug_prints=False,
)

# NOTE(review): this rebinds the name `ppo`, which the import cell bound to the
# `ppo` module; re-running the import cell rebinds it to the module again.
ppo = PPODiscrete(envs_vector, network, **ppo_kwargs)

# Quick sanity check: train for a single generation; the returned value is
# displayed as the cell output.
ppo.train(generations=1)

array([-895.74])

In [None]:
# Hyperparameter tuning
# The tuner receives the environment settings and the base PPO configuration.
# NOTE(review): this rebinds the name `hp_tuner`, which the import cell bound
# to the `hp_tuner` module.
hp_tuner = HPTuner(
    env_kwargs=env_kwargs,
    num_envs=num_envs,
    ppo_class=PPODiscrete,
    ppo_kwargs=ppo_kwargs,
)

# Hyperparameters to optimize: (name, candidate values) pairs
parameters = [
    ('lam', [0.95, 0.975, 0.99]),
]

# Run the search over the candidate values
evolutions = hp_tuner.optimize_hyperparameters(
    parameters,
    generations=25,
    num_trials=8,
)

# Render a video into the 'videos' folder
hp_tuner.evolution_video(
    generations=100,
    video_folder='videos',
    increments=10,
    max_frames=max_episode_steps,
)

Optimizing lam with values: [0.95, 0.975, 0.99]
Running trials for lam = 0.95


  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):


Running trials for lam = 0.975


  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
  if not isinstance(terminated, (bool, np.bool8)):
