In [210]:
# Auxiliar imports
import sys, os, time, importlib
import matplotlib.pyplot as plt
import numpy as np

# Gym imports
import gym
from gym.vector import SyncVectorEnv

# PyTorch imports
import torch
from torch import nn, optim

# Custom imports
sys.path.append(os.path.abspath('..')) # Add parent directory to path

import ppo_network
importlib.reload(ppo_network) # Prevents caching issues with notebooks
from ppo_network import PPONetworkContinuous

import ppo
importlib.reload(ppo) # Prevents caching issues with notebooks
from ppo import PPOContinuous

import hp_optimizer # TODO - Rename to hp_tuner
importlib.reload(hp_optimizer) # Prevents caching issues with notebooks
from hp_optimizer import HPOptimizer

In [211]:
# BipedalWalker environment
env_id = 'BipedalWalker-v3'
max_episode_steps = 1024
num_envs = 16

env_kwargs = {
    'id': env_id,
    'max_episode_steps': max_episode_steps,
}


# Create vectorized environment
envs_vector = SyncVectorEnv([lambda: gym.make(**env_kwargs)] * num_envs)
states, infos = envs_vector.reset()

In [212]:
# Policy-Value Network
# TODO - Move to PPO-kwargs
input_dims = 24
output_dims = 4

shared_hidden_dims = [1024, 512, 256]
shared_norm = nn.LayerNorm
shared_activation = nn.SiLU

mean_hidden_dims = [256, 128, 64]
mean_norm = nn.LayerNorm
mean_activation = nn.SiLU

log_var_hidden_dims = [256, 128, 64]
log_var_norm = nn.LayerNorm
log_var_activation = nn.SiLU

value_hidden_dims = [256, 128, 64]
value_norm = nn.LayerNorm
value_activation = nn.SiLU

network_kwargs = {
    'input_dims': input_dims,
    'output_dims': output_dims,
    
    'shared_hidden_dims': shared_hidden_dims,
    'shared_norm': shared_norm,
    'shared_activation': shared_activation,
    
    'mean_hidden_dims': mean_hidden_dims,
    'mean_norm': mean_norm,
    'mean_activation': mean_activation,
    
    'log_var_hidden_dims': log_var_hidden_dims,
    'log_var_norm': log_var_norm,
    'log_var_activation': log_var_activation,
    
    'value_hidden_dims': value_hidden_dims,
    'value_norm': value_norm,
    'value_activation': value_activation,
}

network = PPONetworkContinuous(**network_kwargs)

In [213]:
# Test forward passes
now = time.time()
for _ in range(10):
    states_tensor = torch.tensor(states, dtype=torch.float32)
    mean, log_var, value = network(states_tensor)
    std_dev = torch.exp(log_var / 2)
    
    actions_dist = torch.distributions.Normal(mean, std_dev)
    actions = actions_dist.sample()
    
    states, rewards, dones, truncateds, infos = envs_vector.step(actions)
print(
    f'Elapsed time: per vectorized env: {(time.time() - now)/num_envs:.2f} s'
    )

Elapsed time: per vectorized env: 0.01 s


In [None]:
# PPO 
action_dims = 4

lr = 3e-4
final_lr = 5e-6

gamma = 0.99
lam = 0.98

clip_eps = 0.25
final_clip_eps = 0.01

value_coef = 0.7

entropy_coef = 0.1
final_entropy_coef = 0.025

batch_size = 256 # TODO - rename to mini_batch
batch_epochs = 8
batch_shuffle = True
seperate_envs_shuffle = True

iterations = 2048  # TODO - rename to batch

reward_normalize = True
truncated_reward = 10

debug_prints = False

ppo_kwargs = {
    'action_dims': action_dims,
    'num_envs': num_envs,
    'lr': lr,
    'final_lr': final_lr,
    'gamma': gamma,
    'lam': lam,
    'clip_eps': clip_eps,
    'final_clip_eps': final_clip_eps,
    'value_coef': value_coef,
    'entropy_coef': entropy_coef,
    'final_entropy_coef': final_entropy_coef,
    'batch_size': batch_size,
    'batch_epochs': batch_epochs,
    'batch_shuffle': batch_shuffle,
    'seperate_envs_shuffle': seperate_envs_shuffle,
    'iterations': iterations,
    'reward_normalize': reward_normalize,
    'truncated_reward': truncated_reward,
    'debug_prints': debug_prints,   
}

ppo = PPOContinuous(envs_vector, network, **ppo_kwargs)

# now = time.time()
ppo.train(generations=100)
# after = time.time()
# print(f'Elapsed time: {after - now:.2f} s')


### 5 generations ###
# not specifed  - 48s
# cpu:          - 236s
# MPS:          - 110s
#? So slow if device is selected, default is sooo much faster

Generation    0 - Reward:  -124.74, w/o trunc.:  -124.74
