In [1]:
# Auxiliary imports (standard library, plotting, numerics)
import sys, os, time, importlib
import matplotlib.pyplot as plt
import numpy as np

# Gym imports
import gym
from gym.vector import SyncVectorEnv

# PyTorch imports
import torch
from torch import nn, optim

# Custom imports
sys.path.append(os.path.abspath('..')) # Make the parent directory (relative to the current working directory) importable

import ppo_network
importlib.reload(ppo_network) # Re-import so edits to the module are picked up without restarting the kernel
from ppo_network import PPONetworkContinuous

import ppo
importlib.reload(ppo) # Re-import so edits to the module are picked up without restarting the kernel
from ppo import PPOContinuous

import hp_tuner
importlib.reload(hp_tuner) # Re-import so edits to the module are picked up without restarting the kernel
from hp_tuner import HPTuner

In [2]:
# BipedalWalker environment settings
env_id = 'BipedalWalker-v3'
max_episode_steps = 1024
num_envs = 16

# Keyword arguments forwarded verbatim to gym.make
env_kwargs = dict(id=env_id, max_episode_steps=max_episode_steps)


def _make_env():
    """Build a single environment instance from `env_kwargs`."""
    return gym.make(**env_kwargs)


# Vectorized environment: `num_envs` synchronous copies, reset once to
# obtain the initial batch of states
envs_vector = SyncVectorEnv([_make_env for _ in range(num_envs)])
states, infos = envs_vector.reset()

In [3]:
# Configuration of the mean / log-variance / value network.
# A shared trunk feeds three heads that all use the same layer layout.
network_kwargs = dict(
    input_dims=24,
    output_dims=4,
    # Shared trunk
    shared_hidden_dims=[1024, 1024, 512],
    shared_norm=nn.LayerNorm,
    shared_activation=nn.ReLU,
    # Mean head
    mean_hidden_dims=[512, 256, 128, 64],
    mean_norm=nn.LayerNorm,
    mean_activation=nn.ReLU,
    # Log-variance head
    log_var_hidden_dims=[512, 256, 128, 64],
    log_var_norm=nn.LayerNorm,
    log_var_activation=nn.ReLU,
    # Value head
    value_hidden_dims=[512, 256, 128, 64],
    value_norm=nn.LayerNorm,
    value_activation=nn.ReLU,
)

# Instantiate the network from the configuration above
network = PPONetworkContinuous(**network_kwargs)

In [4]:
# Smoke test: run the untrained network on the current states, sample
# actions from the resulting Gaussian policy and step the environments.
for _ in range(10):
    states_tensor = torch.tensor(states, dtype=torch.float32)

    # Inference only: no gradients are needed, so skip building an autograd graph
    with torch.no_grad():
        mean, log_var, value = network(states_tensor)
        # The second output is treated as a log-variance: std = exp(log_var / 2)
        std_dev = torch.exp(log_var / 2)

        actions_dist = torch.distributions.Normal(mean, std_dev)
        actions = actions_dist.sample()

    # Hand Gym a NumPy array rather than a torch tensor
    # (`dones` receives the third element of the 5-tuple returned by step)
    states, rewards, dones, truncateds, infos = envs_vector.step(actions.cpu().numpy())
    print(f"State: {states[0]}"[:65])

  if not isinstance(terminated, (bool, np.bool8)):
  logger.warn(


State: [-0.02082795 -0.03081958 -0.03166328 -0.0137893   0.477930
State: [-0.03141257 -0.02118443 -0.01232278 -0.01719513  0.522367
State: [-0.05009324 -0.03701037 -0.01954351 -0.02008068  0.589879
State: [-0.06876309 -0.03750424 -0.01577852 -0.02837856  0.660927
State: [-0.09932724 -0.06123087 -0.0172622  -0.00654481  0.717406
State: [-0.12633781 -0.05424867 -0.01464321 -0.03339153  0.80065 
State: [-0.14317894 -0.03389863 -0.012965   -0.06695844  0.879454
State: [-0.16906707 -0.05184544 -0.01146565 -0.07258829  0.958023
State: [-1.8185523e-01 -2.5651516e-02 -9.9529477e-04 -9.0283722e-
State: [-0.20715906 -0.0506266  -0.0123196  -0.11230776  1.066414


In [5]:
# PPO configuration, passed as keyword arguments to PPOContinuous
ppo_kwargs = dict(
    # Model and environment layout
    network_class=PPONetworkContinuous,
    network_kwargs=network_kwargs,
    action_dims=4,
    num_envs=num_envs,
    # Learning rate, discounting, clipping and loss coefficients
    # (the `final_*` entries pair with the value of the same name)
    lr=3e-4,
    final_lr=5e-6,
    gamma=0.99,
    lam=0.95,
    clip_eps=0.25,
    final_clip_eps=0.025,
    value_coef=0.7,
    entropy_coef=0.05,
    final_entropy_coef=0.025,
    # Batching
    batch_size=2048,
    mini_batch_size=512,
    batch_epochs=8,
    batch_shuffle=True,
    seperate_envs_shuffle=True,
    # Rewards and diagnostics
    reward_normalize=True,
    truncated_reward=50,
    debug_prints=False,
)

# NOTE: this rebinds the name `ppo`, which until now referred to the
# imported `ppo` module; re-running the import cell restores the module.
ppo = PPOContinuous(envs_vector, **ppo_kwargs)

# Quick training check; the returned value is shown as the cell output
ppo.train(1)


array([-100.48])

In [None]:
# Hyperparameter search driver.
# NOTE: this rebinds the name `hp_tuner`, which until now referred to the
# imported `hp_tuner` module; re-running the import cell restores the module.
hp_tuner = HPTuner(
    env_kwargs=env_kwargs, num_envs=num_envs,
    ppo_class=PPOContinuous, ppo_kwargs=ppo_kwargs,
)

# Search space as a list of (hyperparameter name, list) pairs.
# NOTE(review): the `entropy_coef` list contains a negative entry, and the
# `batch_size` entries are all <= the `mini_batch_size` of 512 set in
# `ppo_kwargs` -- confirm both against how HPTuner interprets these lists.
parameters = list({
    'entropy_coef': [0.1, -0.1],
    'batch_size': [64, 128, 256, 512],
    'batch_epochs': [2, 4, 8, 16],
}.items())

# Run the search
evolutions = hp_tuner.optimize_hyperparameters(parameters, generations=50, num_trials=16)

# Render the evolution video into the 'videos' folder.
# NOTE(review): `generations=100` here vs `generations=50` in the search
# above -- confirm the mismatch is intended.
hp_tuner.evolution_video(
    generations=100,
    video_folder='videos',
    increments=10,
    max_frames=max_episode_steps,
)