In [1]:
from dataclasses import dataclass
from typing import Sequence

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from gymnax.wrappers.purerl import LogWrapper
import sys
import time
import plotly.express as px

sys.path.append("src/")
from utils import (
    BraxGymnaxWrapper,
    ClipAction,
    NormalizeVecObservation,
    NormalizeVecReward,
    Transition,
    VecEnv,
)


class ActorCritic(nn.Module):
    """Actor-critic network with two independent MLP heads for continuous actions.

    Both the actor and the critic read the same input ``x`` but share no
    layers: each is a stack of two 256-unit hidden layers followed by an
    output layer.

    Attributes:
        action_dim: Number of action dimensions. It is used both as the width
            of the actor's output layer and as the length of the learned
            ``log_std`` vector, so it must be a single integer.
        activation: ``"relu"`` selects ``nn.relu``; any other value selects
            ``nn.tanh``.
    """

    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        """Return the action distribution and the value estimate for ``x``.

        Returns:
            A tuple ``(pi, value)`` where ``pi`` is a
            ``distrax.MultivariateNormalDiag`` built from the actor output and
            the exponentiated ``log_std`` parameter, and ``value`` is the
            critic output with its trailing size-1 axis squeezed away.
        """
        activation = nn.relu if self.activation == "relu" else nn.tanh

        def dense(features, scale):
            # Every layer uses an orthogonal kernel with a per-layer scale
            # and a zero bias.
            return nn.Dense(
                features, kernel_init=orthogonal(scale), bias_init=constant(0.0)
            )

        # NOTE: layers are created in the same order as before (actor first,
        # then critic) so that the auto-generated submodule names, and thus
        # the parameter tree, stay unchanged.
        actor_mean = activation(dense(256, np.sqrt(2))(x))
        actor_mean = activation(dense(256, np.sqrt(2))(actor_mean))
        actor_mean = dense(self.action_dim, 0.01)(actor_mean)
        # The log standard deviation is a free parameter that does not depend
        # on the input; it is initialised to zeros.
        actor_logstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logstd))

        critic = activation(dense(256, np.sqrt(2))(x))
        critic = activation(dense(256, np.sqrt(2))(critic))
        critic = dense(1, 1.0)(critic)

        return pi, jnp.squeeze(critic, axis=-1)

In [2]:
@dataclass
class args:
    """Run configuration for the PPO experiments in this notebook.

    The notebook reads these values as class attributes (e.g. ``args.seed``)
    and hands the class itself to the selected trainer, so it is normally not
    instantiated.
    """

    # Seed for the root `jax.random.PRNGKey`.
    seed: int = 0
    save_model: bool = False
    log_results: bool = False

    wandb_project_name: str = "improved-gradient-steps"
    wandb_entity: str = "rpegoud"
    logging_dir: str = "."

    # Algorithm specific arguments
    # Key into the `trainers` registry selecting which trainer to run.
    trainer: str = "base_ppo_continuous"
    env_name: str = "hopper"
    # Kept as a true integer (previously the float literal 5e5) so that it
    # matches its annotation and integer arithmetic on it stays integral.
    total_timesteps: int = 500_000
    learning_rate: float = 3e-4
    # Number of PRNG keys the root key is split into; training is vmapped
    # over these keys.
    n_agents: int = 16
    num_envs: int = 4
    num_steps: int = 128
    update_epochs: int = 4
    num_minibatches: int = 32
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_eps: float = 0.2
    ent_coef: float = 0.0
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    alpha: float = 0.2
    activation: str = "tanh"
    anneal_lr: bool = False
    normalize_env: bool = True
    debug: bool = False

In [3]:
from continuous_algs import (
    base_ppo_continuous,
    parallel_ppo_1_continuous,
    parallel_ppo_1a_continuous,
    parallel_ppo_1b_continuous,
    parallel_ppo_1c_continuous,
    parallel_ppo_1d_continuous,
)

# Registry of available trainers, keyed by the value stored in `args.trainer`.
# Each entry is a callable that is invoked with `args` to build the training
# function.
trainers = dict(
    base_ppo_continuous=base_ppo_continuous,
    parallel_ppo_1_continuous=parallel_ppo_1_continuous,
    parallel_ppo_1a_continuous=parallel_ppo_1a_continuous,
    parallel_ppo_1b_continuous=parallel_ppo_1b_continuous,
    parallel_ppo_1c_continuous=parallel_ppo_1c_continuous,
    parallel_ppo_1d_continuous=parallel_ppo_1d_continuous,
)

In [4]:
# To run a different variant, override the trainer before this point, e.g.
# args.trainer = "parallel_ppo_1a_continuous"

# perf_counter is monotonic, so the elapsed time cannot be skewed by
# system clock adjustments.
t = time.perf_counter()
rng = jax.random.PRNGKey(args.seed)
# One independent PRNG key per agent; training is vmapped over these keys.
rngs = jax.random.split(rng, args.n_agents)
train_vjit = jax.jit(jax.vmap(trainers[args.trainer](args)))
# JAX dispatches work asynchronously, so wait for the results to be ready
# before stopping the clock. This is the first call of the jitted function,
# so the measured time also includes tracing and compilation.
outs = jax.block_until_ready(train_vjit(rngs))
exec_time = time.gmtime(time.perf_counter() - t)
print(f'Finished training in {time.strftime("%H:%M:%S", exec_time)}')

Finished training in 00:30:11


In [5]:
# Display the raw episode-return metric produced by training (a 4-D array).
# NOTE(review): axes are presumably (agent, update, step, env) — confirm
# against the trainer implementation.
outs["metrics"]["returned_episode_returns"]

Array([[[[   0.      ,    0.      ,    0.      ,    0.      ],
         [   0.      ,    0.      ,    0.      ,    0.      ],
         [   0.      ,    0.      ,    0.      ,    0.      ],
         ...,
         [   0.      ,   44.498642,   11.634997,   43.922867],
         [   0.      ,   44.498642,   11.634997,   43.922867],
         [   0.      ,   44.498642,   11.634997,   43.922867]],

        [[   0.      ,   44.498642,   11.634997,   43.922867],
         [   0.      ,   44.498642,   11.634997,   27.215988],
         [   0.      ,   44.498642,   11.634997,   27.215988],
         ...,
         [ 116.00406 ,   32.294777,   11.634997,   52.76649 ],
         [ 116.00406 ,   32.294777,   11.634997,   52.76649 ],
         [ 116.00406 ,   32.294777,   11.634997,   52.76649 ]],

        [[ 116.00406 ,   32.294777,   11.634997,   52.76649 ],
         [ 116.00406 ,   32.294777,   11.634997,   52.76649 ],
         [ 116.00406 ,   32.294777,   11.634997,   52.76649 ],
         ...,
         

In [6]:
import pandas as pd

In [7]:
# Average the episode-return metric over axes 0, 2 and 3, leaving one mean
# value per index along axis 1.
# NOTE(review): axis 1 is presumably the training-update axis rather than an
# episode count — confirm against the trainer implementation.
avg_ep_returns = pd.Series(
    jnp.mean(outs["metrics"]["returned_episode_returns"], axis=(0, 2, 3))
)
n_episodes = len(avg_ep_returns)
px.line(avg_ep_returns)