In [62]:
# Auxiliar imports
import sys, os, time, importlib
import matplotlib.pyplot as plt
import numpy as np

# Gym imports
import gym
from gym.vector import SyncVectorEnv

# PyTorch imports
import torch
from torch import nn, optim

# Custom imports
sys.path.append(os.path.abspath('..')) # Add parent directory to path

import ppo_network
importlib.reload(ppo_network) # Prevents caching issues with notebooks
from ppo_network import PPONetworkContinuous

import ppo_wrapper
importlib.reload(ppo_wrapper) # Prevents caching issues with notebooks
from ppo_wrapper import PPOWrapper

import hp_optimizer
importlib.reload(hp_optimizer) # Prevents caching issues with notebooks
from hp_optimizer import HPOptimizer

In [63]:
# BipedalWalker environment
env_id = 'BipedalWalker-v3'
num_envs = 16
seed = 42  # fixed seed so env resets, network init and action sampling are reproducible

env_kwargs = {
    'id': env_id,
}

# Seed torch before the network is built and actions are sampled in later cells.
torch.manual_seed(seed)

# Create vectorized environment: SyncVectorEnv calls each factory once, so every
# entry (even though it is the same lambda) yields an independent env instance.
envs_vector = SyncVectorEnv([lambda: gym.make(**env_kwargs)] * num_envs)
# Seeding the first reset makes the initial states reproducible across Restart & Run All.
states, infos = envs_vector.reset(seed=seed)

In [64]:
# Policy-Value Network
# TODO - Move to PPO-kwargs
# Derive the I/O sizes from the vectorized env instead of hard-coding them, so the
# network cannot silently disagree with `env_id` (BipedalWalker-v3: 24 obs -> 4 actions).
assert len(envs_vector.single_observation_space.shape) == 1, 'expected a flat observation space'
assert len(envs_vector.single_action_space.shape) == 1, 'expected a flat action space'
input_dims = envs_vector.single_observation_space.shape[0]
output_dims = envs_vector.single_action_space.shape[0]

# Trunk shared by all heads
shared_hidden_dims = [1024, 512, 256]
shared_norm = nn.LayerNorm
shared_activation = nn.SiLU

# Head producing the action mean
mean_hidden_dims = [256, 128, 64]
mean_norm = nn.LayerNorm
mean_activation = nn.SiLU

# Head producing the action log-variance
log_var_hidden_dims = [256, 128, 64]
log_var_norm = nn.LayerNorm
log_var_activation = nn.SiLU

# Head producing the state-value estimate
value_hidden_dims = [256, 128, 64]
value_norm = nn.LayerNorm
value_activation = nn.SiLU

network_kwargs = {
    'input_dims': input_dims,
    'output_dims': output_dims,
    
    'shared_hidden_dims': shared_hidden_dims,
    'shared_norm': shared_norm,
    'shared_activation': shared_activation,
    
    'mean_hidden_dims': mean_hidden_dims,
    'mean_norm': mean_norm,
    'mean_activation': mean_activation,
    
    'log_var_hidden_dims': log_var_hidden_dims,
    'log_var_norm': log_var_norm,
    'log_var_activation': log_var_activation,
    
    'value_hidden_dims': value_hidden_dims,
    'value_norm': value_norm,
    'value_activation': value_activation,
}

network = PPONetworkContinuous(**network_kwargs)

In [65]:
# Test forward passes: roll the (untrained) policy through the vectorized env,
# print batch reward stats whenever any sub-env finishes an episode, and time it.
num_test_steps = 1000
now = time.perf_counter()  # monotonic clock, suited to measuring intervals
with torch.no_grad():  # pure rollout: no autograd graph is needed
    for _ in range(num_test_steps):
        states_tensor = torch.as_tensor(states, dtype=torch.float32)
        mean, log_var, value = network(states_tensor)
        std_dev = torch.exp(log_var / 2)  # std = sqrt(var) = exp(log_var / 2)

        actions_dist = torch.distributions.Normal(mean, std_dev)
        actions = actions_dist.sample()

        # gym expects NumPy actions; a Gaussian sample is unbounded, so clip it
        # to the action-space Box bounds before stepping.
        actions_np = np.clip(
            actions.cpu().numpy(),
            envs_vector.single_action_space.low,
            envs_vector.single_action_space.high,
        )
        states, rewards, dones, truncateds, infos = envs_vector.step(actions_np)
        # An episode ends on termination OR truncation (e.g. time limit).
        if (dones | truncateds).any():
            print(rewards.min(), rewards.max(), rewards.mean())
elapsed = time.perf_counter() - now
print(
    f'Elapsed time: per vectorized env: {elapsed/num_envs:.2f} s'
    )

  if not isinstance(terminated, (bool, np.bool8)):
  logger.warn(


-100.0 0.2304859267423526 -12.622472972945621
-100.0 0.2873644568522795 -6.36015145400865
-100.0 0.19427680503328523 -6.343521347976601
-100.0 0.20793572314580164 -6.317216069847345
-100.0 0.12989913108944895 -6.327480354468649
-100.0 0.13981885327895363 -6.413346759195362
-100.0 0.15931183652083197 -6.385776503688734
-100.0 -0.0014435435930864257 -6.39929694516138
-100.0 0.1446106155316047 -12.612024927697764
-100.0 0.2925105965733528 -6.3645866462118965
-100.0 0.038820299565792096 -6.3681835183574975
-100.0 0.34818235574165735 -6.32815624402153
-100.0 0.2098088111480065 -6.333597906532697
-100.0 0.21618881428241726 -6.339153132613127
-100.0 0.11594956006606302 -6.364745146981558
-100.0 0.11598382264375687 -6.332459840530996
-100.0 0.05749896252155304 -6.321506784322516
-100.0 0.0444431150754281 -6.326169084674213
-100.0 0.14825553478797515 -6.31801773931152
-100.0 0.012553291082382201 -6.341946622476447
-100.0 0.04517381640275201 -6.364416866095679
-100.0 0.09939028012752534 -6.34166