In [None]:
# stdlib
import pathlib

# Environment
import gym

# Jax, Flax, Optax
import jax
from flax.metrics import tensorboard
from flax.training import checkpoints
import optax

# PPO
import proximal_policy_optimization as ppo

# other
from tqdm.notebook import tqdm


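This is needed on my machine: it limits the fraction of GPU memory that JAX/XLA preallocates. It only takes effect if it is set before JAX first initializes the GPU backend.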

In [None]:
# Limit the fraction of GPU memory that JAX/XLA preallocates for this process.
# NOTE: only honored if set before JAX initializes its GPU backend (i.e. before
# the first JAX computation in this kernel).
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.85

Functions for creating the environments: a vectorized CartPole environment for training and a single CartPole environment for evaluation

In [None]:
def make_env():
    """Create one CartPole-v1 environment.

    Used directly as the evaluation environment and as the per-sub-environment
    factory handed to the vector-environment wrappers.
    """
    env = gym.make("CartPole-v1")
    return env


def make_vector_env(num_envs: int = 8, asynchronous: bool = False):
    """Create a vectorized environment made of ``num_envs`` CartPole copies.

    Args:
        num_envs: number of sub-environments, each built by ``make_env``.
        asynchronous: if True build a ``gym.vector.AsyncVectorEnv``,
            otherwise a ``gym.vector.SyncVectorEnv``.

    Returns:
        The constructed vector environment.
    """
    factories = [make_env] * num_envs
    vector_env_cls = (
        gym.vector.AsyncVectorEnv if asynchronous else gym.vector.SyncVectorEnv
    )
    return vector_env_cls(factories)  # type: ignore


# Vectorized environment for collecting training rollouts, plus one plain
# environment for evaluation. num_envs should agree with `n_actors` in the
# Configuration cell (its batch-size math is horizon * n_actors).
train_env = make_vector_env(num_envs=8, asynchronous=False)
eval_env = make_env()


Configuration

In [None]:
def ppo_num_opt_steps(
    total_frames: int, horizon: int, n_actors: int, epochs: int, mini_batch_size: int
) -> int:
    """Return the total number of optimizer updates performed over a PPO run.

    Every train step gathers one batch of ``horizon * n_actors`` frames and then
    makes ``epochs`` passes over it, each pass split into
    ``batch // mini_batch_size`` mini-batches. Both divisions floor, so a
    partial final train step and any leftover partial mini-batch are not
    counted.

    Args:
        total_frames: total environment frames budgeted for training.
        horizon: rollout length per actor per train step.
        n_actors: number of parallel actors (sub-environments).
        epochs: passes over each collected batch.
        mini_batch_size: samples per optimizer update.

    Returns:
        Number of optimizer updates (used as the schedules' transition steps).
    """
    frames_per_batch = horizon * n_actors
    updates_per_batch = (frames_per_batch // mini_batch_size) * epochs
    return updates_per_batch * (total_frames // frames_per_batch)


total_frames = int(1e5)
n_actors = 8
horizon = 32
mini_batch_size = 256
epochs = 20

# train_env (environment cell) supplies one sub-environment per actor; the
# batch size horizon * n_actors used everywhere below is only right if these match.
assert train_env.num_envs == n_actors, (
    f"train_env has {train_env.num_envs} sub-envs but n_actors={n_actors}"
)

total_opt_steps = ppo_num_opt_steps(
    total_frames, horizon, n_actors, epochs, mini_batch_size
)
# 0 happens e.g. when mini_batch_size > horizon * n_actors; the linear
# schedules below would then have no decay range.
assert total_opt_steps > 0, (
    f"no optimizer steps: mini_batch_size={mini_batch_size} vs "
    f"batch size {horizon * n_actors}, total_frames={total_frames}"
)

# Clipping range and learning rate both decay linearly to 0 over
# total_opt_steps (assumes the schedules are advanced once per optimizer update).
epsilon = optax.linear_schedule(0.1, 0.0, total_opt_steps)
learning_rate = optax.linear_schedule(2.5e-4, 0.0, total_opt_steps)
max_grad_norm = 0.5

# Configuration
# NOTE(review): lam/gamma/c1/c2 presumably follow PPO-paper notation
# (GAE lambda, discount, value-loss weight, entropy weight) — confirm in ppo.
config = ppo.PPOConfig(
    n_actors=n_actors,
    total_frames=total_frames,
    horizon=horizon,
    mini_batch_size=mini_batch_size,
    lam=0.8,
    gamma=0.98,
    epochs=epochs,
    c1=1.0,
    c2=0.0,
    epsilon=epsilon,
)


Model creation

In [None]:
# Create Model
key = jax.random.PRNGKey(0)
n_hidden = 512
# A gym vector env exposes the per-sub-environment action space as
# `single_action_space`; for CartPole it is Discrete, so `.n` is the action count.
n_actions = train_env.single_action_space.n  # type: ignore
model = ppo.ActorCriticMlp(n_hidden=n_hidden, n_actions=n_actions)

# Initialize model parameters from a real batch of observations. `observation`
# is also the starting observation for the first rollout in the training cell.
observation = ppo.env_reset(train_env)
key, rng = jax.random.split(key, 2)
params = model.init(rng, observation)
state = ppo.PPOTrainState.create(
    apply_fn=model.apply,
    params=params,
    lr=learning_rate,
    config=config,
    max_grad_norm=max_grad_norm,
)
# `state` now holds the parameters; drop the standalone reference.
del params


Log and checkpoint configuration

In [None]:
# Output locations, as absolute POSIX paths under the current working directory.
# Change run_name for a new run so earlier checkpoints/logs are not reused.
run_name = "run1"
checkpoint_dir = (
    pathlib.Path.cwd().joinpath("checkpoints", "cartpole", run_name).as_posix()
)
log_dir = pathlib.Path.cwd().joinpath("logs", "cartpole", run_name).as_posix()

log_frequency = 1  # write training scalars every N train steps
eval_frequency = 25  # evaluate the policy every N train steps
eval_episodes = 100  # episode count passed to ppo.evaluate_model

summary_writer = tensorboard.SummaryWriter(log_dir)
# PPOConfig does not carry the gradient-clip norm or model width; log them too.
summary_writer.hparams(
    {**config._asdict(), "max_grad_norm": max_grad_norm, "n_hidden": n_hidden}
)


Train!

In [None]:
batch_size = config.horizon * config.n_actors
frames_per_train_step = batch_size
num_train_steps = config.total_frames // frames_per_train_step

# Most recent evaluation reward; shown in the progress bar between evaluations.
reward = 0.0

horizon = state.config.horizon
gamma = state.config.gamma
lam = state.config.lam

with tqdm(range(num_train_steps)) as t:
    for step in t:
        # Environment frames collected before this train step (x-axis for logs).
        frame = step * frames_per_train_step

        # Roll out the current policy on train_env, then run one PPO update.
        key, rng1, rng2 = jax.random.split(key, 3)
        trajectory, observation = ppo.create_trajectory(
            observation,
            state.apply_fn,
            state.params,
            train_env,
            rng1,
            horizon,
            gamma,
            lam,
        )
        state, losses = ppo.train_step(state, trajectory, rng2)

        if step % log_frequency == 0:
            summary_writer.scalar("train/loss", losses["total"], frame)
            summary_writer.scalar("train/loss-actor", losses["actor"], frame)
            summary_writer.scalar("train/loss-critic", losses["critic"], frame)
            summary_writer.scalar("train/loss-entropy", losses["entropy"], frame)
            summary_writer.scalar("train/learning-rate", state.learning_rate(), frame)
            summary_writer.scalar("train/clipping", state.epsilon(), frame)

        if step % eval_frequency == 0:
            key, rng = jax.random.split(key, 2)
            reward = ppo.evaluate_model(state, eval_env, eval_episodes, rng)
            summary_writer.scalar("train/reward", reward, frame)

        # Single description so the frame count is not immediately overwritten.
        t.set_description_str(
            f"frame: {frame}, loss: {losses['total']}, reward: {reward}"
        )

        # Set checkpoint_dir = None in the config cell to disable checkpointing.
        # NOTE(review): flax's save_checkpoint refuses steps <= an existing one
        # unless overwrite=True, so re-running into the same run_name may fail.
        if checkpoint_dir is not None:
            checkpoints.save_checkpoint(checkpoint_dir, state, frame)

# Ensure all buffered TensorBoard events are written to disk.
summary_writer.flush()
