# Hyperparameters

In [1]:
# --- Experiment identity -----------------------------------------------------
exp_name = "PPO"  # unique experiment name
env_id = "KungFuMasterNoFrameskip-v4"  # id of the environment
seed = 1  # seed for reproducible benchmarks
capture_video = True  # whether to save video of agent gameplay

# --- Rollout collection ------------------------------------------------------
total_timesteps = 10_000_000  # total timesteps of the experiment
num_envs = 8  # number of parallel environments
num_steps = 128  # steps run in each environment per policy rollout

# --- Advantage estimation ----------------------------------------------------
gamma = 0.99  # discount factor
gae_lambda = 0.95  # lambda of the generalized advantage estimation

# --- Optimisation ------------------------------------------------------------
learning_rate = 2.5e-4  # initial learning rate of the optimizer
num_minibatches = 4  # number of mini batches per epoch
update_epochs = 4  # the K epochs used to update the policy
max_grad_norm = 0.5  # maximum global norm for gradient clipping

# --- Loss coefficients -------------------------------------------------------
clip_coef = 0.1  # surrogate clipping coefficient
ent_coef = 0.01  # weight of the entropy bonus
vf_coef = 0.5  # weight of the value-function loss

# --- Derived sizes -----------------------------------------------------------
batch_size = num_steps * num_envs  # transitions gathered by one rollout
minibatch_size = batch_size // num_minibatches  # transitions per mini batch
num_updates = total_timesteps // batch_size  # number of rollout/update cycles

# Disable the GPU for TensorFlow (TensorFlow is used internally by some Flax modules)

In [2]:
import tensorflow as tf

# Pass an empty device list for the 'GPU' device type so TensorFlow sees no GPU.
# NOTE(review): presumably this keeps TensorFlow from reserving accelerator
# memory that JAX needs — confirm.
tf.config.experimental.set_visible_devices([], 'GPU')

2023-04-10 21:03:15.690172: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-10 21:03:17.598713: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-10 21:03:17.617083: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-

# Make environment

In [3]:
import string
import gymnasium as gym
import numpy as np

def make_env(env_id: str, idx: int, capture_video: bool, run_name: str):
    """Return a zero-argument factory that builds one wrapped environment.

    The factory form is what the vector-env constructor expects: it receives a
    list of callables and invokes each one to create its sub-environments.

    Args:
        env_id: Gymnasium id passed to ``gym.make``.
        idx: Position of this environment in the vector env. Only index 0 is
            wrapped with ``RecordVideo``.
        capture_video: When true, environments are created with
            ``render_mode='rgb_array'`` and the first one records videos under
            ``videos/{env_id}``.
        run_name: Currently unused; kept so existing callers keep working.

    Returns:
        A callable with no arguments that creates and returns the environment.
    """
    def thunk():
        if capture_video:
            env = gym.make(env_id, render_mode='rgb_array')
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{env_id}")
        else:
            env = gym.make(env_id)
        # Episode statistics are recorded before preprocessing is applied on top.
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.AtariPreprocessing(env, grayscale_newaxis=True, scale_obs=True)
        return env

    return thunk

In [4]:
# Build `num_envs` environments that are stepped one after another in this process.
# `exp_name` is passed as make_env's `run_name` argument.
envs = gym.vector.SyncVectorEnv(
    [make_env(env_id, i, capture_video, exp_name) for i in range(num_envs)]
) # AsyncVectorEnv is faster, but we cannot extract a single environment from it
# The policy head below emits one logit per discrete action, so other spaces are rejected.
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
# This first batch of observations is reused later as the example input for parameter initialization.
obs, _ = envs.reset()

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
  logger.warn(
  logger.warn(


# Create Agent model

In [5]:
import flax.linen as nn

def linear_layer_init(features, std=np.sqrt(2), bias_const=0.0):
    """Create a dense layer with orthogonal kernel and constant bias initializers.

    Args:
        features: Number of output units.
        std: Scale passed to the orthogonal kernel initializer.
        bias_const: Constant value used to initialize the bias.

    Returns:
        The configured ``nn.Dense`` module.
    """
    kernel_initializer = nn.initializers.orthogonal(std)
    bias_initializer = nn.initializers.constant(bias_const)
    return nn.Dense(
        features=features,
        kernel_init=kernel_initializer,
        bias_init=bias_initializer,
    )

def convolution_layer_init(features, kernel_size, strides, std=np.sqrt(2), bias_const=0.0):
    """Create a square, VALID-padded convolution with orthogonal/constant initializers.

    Args:
        features: Number of output channels.
        kernel_size: Side length of the square kernel.
        strides: Stride applied along both spatial axes.
        std: Scale passed to the orthogonal kernel initializer.
        bias_const: Constant value used to initialize the bias.

    Returns:
        The configured ``nn.Conv`` module.
    """
    kernel_initializer = nn.initializers.orthogonal(std)
    bias_initializer = nn.initializers.constant(bias_const)
    return nn.Conv(
        features=features,
        kernel_size=(kernel_size,) * 2,
        strides=(strides,) * 2,
        padding='VALID',
        kernel_init=kernel_initializer,
        bias_init=bias_initializer,
    )

In [6]:
from jax import Array
import jax.numpy as jnp

class Network(nn.Module):
    """Shared feature extractor: three conv+ReLU layers, flatten, dense(512)+ReLU.

    The output is consumed by both the Actor and the Critic heads.
    """

    @nn.compact
    def __call__(self, x: Array):
        # assumes x is a batch of image observations with the batch on axis 0 — TODO confirm layout
        # Convolution stack: (features, kernel, stride) = (32, 8, 4), (64, 4, 2), (64, 3, 1).
        x = nn.Sequential([
            convolution_layer_init(32, 8, 4),
            nn.relu,
            convolution_layer_init(64, 4, 2),
            nn.relu,
            convolution_layer_init(64, 3, 1),
            nn.relu,
        ])(x)
        # Keep the leading (batch) axis and flatten every remaining axis.
        x = jnp.reshape(x, (x.shape[0], -1))
        # Project to a 512-dimensional feature vector per batch element.
        return nn.Sequential([
            linear_layer_init(512),
            nn.relu
        ])(x)

class Actor(nn.Module):
    """Policy head: one dense layer mapping features to one output per action."""

    # Number of outputs of the dense layer (one per discrete action).
    action_n: int

    @nn.compact
    def __call__(self, x: Array):
        # The small kernel scale (std=0.01) keeps the initial outputs close to zero.
        return linear_layer_init(self.action_n, std=0.01)(x)

class Critic(nn.Module):
    """Value head: one dense layer mapping features to a single scalar output."""

    @nn.compact
    def __call__(self, x: Array):
        # Output keeps a trailing axis of size 1; callers squeeze it.
        return linear_layer_init(1, std=1)(x)

# Instantiate the three modules; their parameters are created separately via `.init`.
network = Network()
actor = Actor(action_n=envs.single_action_space.n) # The action count is read from the env here, so the module only holds a plain int
critic = Critic()

# Create AgentState

In [7]:
import jax.random as random

# Seed the JAX PRNG and NumPy for reproducibility.
# (The environments themselves are seeded later, when they are reset with `seed=seed`.)
key = random.PRNGKey(seed)
np.random.seed(seed)

# One independent key per consumer: module initializers, action sampling, minibatch shuffling.
key, network_key, actor_key, critic_key, action_key, permutation_key = random.split(key, num=6)

# Initializing agent parameters
# `obs` is the observation batch returned by the earlier `envs.reset()` call.
network_params = network.init(network_key, obs)

# NOTE: despite its name, `logits` holds the Network output (the shared features),
# not action logits; it is only used as the example input for the two heads.
logits = network.apply(network_params, obs)

actor_params = actor.init(actor_key, logits)
critic_params = critic.init(critic_key, logits)

In [8]:
import optax

def linear_schedule(count):
    """Linearly anneal the learning rate as policy updates complete.

    Args:
        count: Number of optimizer steps taken so far.

    Returns:
        ``learning_rate`` scaled by the fraction of policy updates still to run.
    """
    # Each policy update performs this many optimizer steps.
    steps_per_update = num_minibatches * update_epochs
    completed_updates = count // steps_per_update
    remaining_fraction = 1.0 - completed_updates / num_updates
    return learning_rate * remaining_fraction

# Gradient transformation: clip by global norm first, then apply AdamW.
# `inject_hyperparams` stores the scheduled learning rate in the optimizer state,
# which is how the training loop later reads it for logging.
# NOTE(review): adamw applies decoupled weight decay, unlike plain adam — confirm this is intended.
tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.inject_hyperparams(optax.adamw)(
        learning_rate=linear_schedule,
        eps=1e-5
    )
)

In [9]:
from flax.core import FrozenDict
from flax.struct import dataclass

@dataclass
class AgentParams:
    """Container bundling the parameters of the three agent modules.

    Gradients in ``update_ppo`` are taken with respect to an instance of this
    class, so all three parameter sets are optimized jointly.
    """
    # Parameters of the Actor head.
    actor_params: FrozenDict
    # Parameters of the Critic head.
    critic_params: FrozenDict
    # Parameters of the shared Network feature extractor.
    network_params: FrozenDict

In [10]:
from jax import jit
from typing import Callable
from flax.training.train_state import TrainState
from flax import struct

# Replace each module's `apply` with a jitted version.
# Jitting here is probably unnecessary, since every function that calls these is already jitted.
actor.apply = jit(actor.apply)
critic.apply = jit(critic.apply)
network.apply = jit(network.apply)

class AgentState(TrainState):
    """TrainState extended with the apply functions of the three agent modules.

    The callables are declared with ``pytree_node=False`` so they are carried as
    static fields instead of pytree leaves, which lets an AgentState be passed
    as an argument to jitted functions.
    """
    # Forward function of the policy head.
    actor_fn: Callable = struct.field(pytree_node=False)
    # Forward function of the value head.
    critic_fn: Callable = struct.field(pytree_node=False)
    # Forward function of the shared feature extractor.
    network_fn: Callable = struct.field(pytree_node=False)

# Bundle the parameters, the optimizer and the module forward functions into one state object.
agent_state = AgentState.create(
    params=AgentParams(
        network_params=network_params,
        actor_params=actor_params,
        critic_params=critic_params
    ),
    tx=tx,
    # As we have separated actor and critic we don't use apply_fn
    apply_fn=None,
    actor_fn=actor.apply,
    critic_fn=critic.apply,
    network_fn=network.apply
)

### Only run this if you want to continue training

In [11]:
# Rebuild the optimizer with the same configuration as the original `tx`.
tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.inject_hyperparams(optax.adamw)(
        learning_rate=linear_schedule,
        eps=1e-5
    )
)

# Create a new state that keeps the current parameters but uses the freshly built optimizer.
# NOTE(review): TrainState.create presumably starts a new optimizer state and step counter,
# which would restart the learning-rate schedule — confirm against the flax docs.
agent_state = AgentState.create(
    params=AgentParams(
        actor_params=agent_state.params.actor_params,
        critic_params=agent_state.params.critic_params,
        network_params=agent_state.params.network_params
    ),
    tx=tx,
    apply_fn=None,
    actor_fn=actor.apply,
    critic_fn=critic.apply,
    network_fn=network.apply
)

# Create storage

In [12]:
@dataclass
class Storage:
    """Rollout buffer holding one batch of experience.

    Every field is allocated in the training cell with leading axes
    ``(num_steps, num_envs)``.
    """
    # Observations fed to the policy at each step.
    obs: jnp.ndarray
    # Actions sampled at each step.
    actions: jnp.ndarray
    # Log-probability of each sampled action under the policy that produced it.
    logprobs: jnp.ndarray
    # Done flag that accompanied the observation stored at the same step.
    dones: jnp.ndarray
    # Critic value estimate for each stored observation.
    values: jnp.ndarray
    # Generalized advantage estimates (filled in by compute_gae).
    advantages: jnp.ndarray
    # advantages + values (filled in by compute_gae).
    returns: jnp.ndarray
    # Rewards returned by the environments for each step.
    rewards: jnp.ndarray

# Sample action

In [13]:
from numpy import ndarray
import tensorflow_probability.substrates.jax.distributions as tfp

@jit
def get_action_and_value(agent_state: AgentState, next_obs: ndarray, next_done: ndarray, storage: Storage, step: int, key: Array):
    """Sample one action per environment and record the step in the rollout buffer.

    Args:
        agent_state: State holding the parameters and the module forward functions.
        next_obs: Current batch of observations (one per environment).
        next_done: Done flags that accompany ``next_obs``.
        storage: Rollout buffer to write into.
        step: Index along the buffer's first axis to write at.
        key: JAX PRNG key. It is annotated as ``Array`` because the
            ``jax.random.PRNGKeyArray`` alias is not available in newer JAX
            releases, and annotations are evaluated when the function is defined.

    Returns:
        ``(storage, action, key)``: the updated buffer, the sampled actions and
        the PRNG key to use for the next call.
    """
    hidden = agent_state.network_fn(agent_state.params.network_params, next_obs)
    action_logits = agent_state.actor_fn(agent_state.params.actor_params, hidden)
    value = agent_state.critic_fn(agent_state.params.critic_params, hidden)

    # Sample discrete actions from the categorical distribution defined by the logits
    probs = tfp.Categorical(action_logits)
    # Split so the caller gets a key that has not been consumed by sampling.
    key, subkey = random.split(key)
    action = probs.sample(seed=subkey)
    logprob = probs.log_prob(action)
    storage = storage.replace(
        obs=storage.obs.at[step].set(next_obs),
        dones=storage.dones.at[step].set(next_done),
        actions=storage.actions.at[step].set(action),
        logprobs=storage.logprobs.at[step].set(logprob),
        # Drop the critic's trailing size-1 axis so values match the buffer layout.
        values=storage.values.at[step].set(value.squeeze()),
    )
    return storage, action, key

In [14]:
@jit
def get_action_and_value2(agent_state: AgentState, params: AgentParams, obs: ndarray, action: ndarray):
    """Evaluate given actions under the policy defined by ``params``.

    Unlike ``get_action_and_value`` this takes the parameters explicitly, so it
    can be differentiated with respect to them.

    Returns:
        ``(log_prob, entropy, value)`` for the supplied observations and actions,
        with the critic's trailing size-1 axis squeezed away.
    """
    features = agent_state.network_fn(params.network_params, obs)
    policy = tfp.Categorical(agent_state.actor_fn(params.actor_params, features))
    state_value = agent_state.critic_fn(params.critic_params, features)
    return policy.log_prob(action), policy.entropy(), state_value.squeeze()

# Rollout

In [15]:
from flax.metrics.tensorboard import SummaryWriter
from jax import device_get

def rollout(
        agent_state: AgentState,
        next_obs: ndarray,
        next_done: ndarray,
        storage: Storage,
        key: Array,
        global_step: int,
        writer: SummaryWriter,
):
    """Collect ``num_steps`` transitions from every environment into ``storage``.

    Args:
        agent_state: State holding the parameters and the module forward functions.
        next_obs: Observations to act on at the first step.
        next_done: Done flags that accompany ``next_obs``.
        storage: Rollout buffer to fill.
        key: JAX PRNG key used for action sampling. It is annotated as ``Array``
            because the ``jax.random.PRNGKeyArray`` alias is not available in
            newer JAX releases, and annotations are evaluated at definition time.
        global_step: Number of environment steps taken so far.
        writer: Summary writer that receives the episodic statistics.

    Returns:
        ``(next_obs, next_done, storage, key, global_step)`` ready to be passed
        to the next rollout.
    """
    for step in range(0, num_steps):
        # Every call to envs.step advances all environments by one step.
        global_step += num_envs
        storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key)
        # The vector env is stepped with host-side arrays, so fetch the actions from the device.
        next_obs, reward, terminated, truncated, infos = envs.step(device_get(action))
        # Termination and truncation are both treated as the end of an episode.
        next_done = terminated | truncated
        storage = storage.replace(rewards=storage.rewards.at[step].set(reward))

        # Only log when at least 1 env is done
        if "final_info" not in infos:
            continue

        for info in infos["final_info"]:
            # Skip the envs that are not done
            if info is None:
                continue
            writer.scalar("charts/episodic_return", info["episode"]["r"], global_step)
            writer.scalar("charts/episodic_length", info["episode"]["l"], global_step)
    return next_obs, next_done, storage, key, global_step

# Compute GAE (generalized advantage estimation)

In [16]:
@jit
def compute_gae(
        agent_state: AgentState,
        next_obs: ndarray,
        next_done: ndarray,
        storage: Storage
):
    """Fill ``storage.advantages`` and ``storage.returns`` using GAE.

    Args:
        agent_state: State holding the parameters and the module forward functions.
        next_obs: Observations that follow the last stored step; used to bootstrap.
        next_done: Done flags that accompany ``next_obs``.
        storage: Rollout buffer whose rewards, values and dones are read.

    Returns:
        The buffer with advantages and returns (advantages + values) replaced.
    """
    # Critic estimate for the observation after the final stored step.
    features = agent_state.network_fn(agent_state.params.network_params, next_obs)
    bootstrap_value = agent_state.critic_fn(agent_state.params.critic_params, features).squeeze()

    # Start from a zeroed copy and fill it from the last step back to the first.
    advantages = storage.advantages.at[:].set(0.0)
    running_gae = 0
    for t in range(num_steps - 1, -1, -1):
        is_last_step = t == num_steps - 1
        # The final step looks at the state after the rollout; earlier steps look at step t + 1.
        not_terminal = 1.0 - (next_done if is_last_step else storage.dones[t + 1])
        value_after = bootstrap_value if is_last_step else storage.values[t + 1]
        td_error = storage.rewards[t] + gamma * value_after * not_terminal - storage.values[t]
        running_gae = td_error + gamma * gae_lambda * not_terminal * running_gae
        advantages = advantages.at[t].set(running_gae)

    # Returns are the advantages plus the value estimates they were measured against.
    return storage.replace(advantages=advantages, returns=advantages + storage.values)

# PPO loss

In [17]:
from jax.lax import stop_gradient

@jit
def ppo_loss(
        agent_state: AgentState,
        params: AgentParams,
        obs: ndarray,
        act: ndarray,
        logp: ndarray,
        adv: ndarray,
        ret: ndarray,
        val: ndarray,
):
    """Clipped PPO objective for one minibatch.

    Args:
        agent_state: State providing the module forward functions.
        params: Parameters the loss is evaluated (and differentiated) at.
        obs: Minibatch observations.
        act: Actions taken during the rollout.
        logp: Log-probabilities of ``act`` recorded during the rollout.
        adv: Advantage estimates; normalized per minibatch inside this function.
        ret: Return targets for the value function.
        val: Value estimates recorded during the rollout.

    Returns:
        ``(loss, (pg_loss, v_loss, entropy_loss, approx_kl))`` where ``loss`` is
        the scalar to minimize and the tuple holds diagnostics.
    """
    new_logp, policy_entropy, new_val = get_action_and_value2(agent_state, params, obs, act)
    log_ratio = new_logp - logp
    prob_ratio = jnp.exp(log_ratio)

    # Diagnostic only: estimate of how far the policy has moved.
    kl_estimate = ((prob_ratio - 1) - log_ratio).mean()

    # Normalize advantages within the minibatch.
    norm_adv = (adv - adv.mean()) / (adv.std() + 1e-8)

    # Policy term: element-wise worst case of the unclipped and clipped surrogate.
    clipped_ratio = jnp.clip(prob_ratio, 1 - clip_coef, 1 + clip_coef)
    pg_loss = jnp.maximum(-norm_adv * prob_ratio, -norm_adv * clipped_ratio).mean()

    # Value term: the new estimate may move at most clip_coef away from the old one.
    value_shift = jnp.clip(new_val - val, -clip_coef, clip_coef)
    squared_error = (new_val - ret) ** 2
    squared_error_clipped = ((val + value_shift) - ret) ** 2
    v_loss = 0.5 * jnp.maximum(squared_error, squared_error_clipped).mean()

    # Entropy term; it is subtracted below, so higher entropy lowers the loss.
    entropy_loss = policy_entropy.mean()

    loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef
    return loss, (pg_loss, v_loss, entropy_loss, stop_gradient(kl_estimate))

# Update PPO

In [18]:
from jax import value_and_grad


# Built once, at definition time. Wrapping `value_and_grad(ppo_loss, ...)` in `jit`
# inside update_ppo would create a new function object on every call, and JAX
# caches traces per function object, so the loss would be retraced and
# recompiled at every policy update.
# NOTE: re-run this cell after redefining `ppo_loss`, otherwise the old loss stays bound.
_ppo_loss_grad_fn = jit(value_and_grad(ppo_loss, argnums=1, has_aux=True))


def update_ppo(
        agent_state: AgentState,
        storage: Storage,
        key: Array
):
    """Run ``update_epochs`` epochs of minibatch PPO updates on one rollout.

    Args:
        agent_state: State to optimize.
        storage: Rollout buffer with advantages and returns already computed.
        key: JAX PRNG key used to shuffle the batch. It is annotated as ``Array``
            because the ``jax.random.PRNGKeyArray`` alias is not available in
            newer JAX releases, and annotations are evaluated at definition time.

    Returns:
        ``(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl,
        explained_var, key)``. The loss values come from the last minibatch.
    """
    # Flatten collected experiences
    b_obs = storage.obs.reshape((-1,) + envs.single_observation_space.shape)
    b_logprobs = storage.logprobs.reshape(-1)
    b_actions = storage.actions.reshape((-1,) + envs.single_action_space.shape)
    b_advantages = storage.advantages.reshape(-1)
    b_returns = storage.returns.reshape(-1)
    b_values = storage.values.reshape(-1)

    for epoch in range(update_epochs):
        # A new permutation per epoch, so minibatches differ between epochs.
        key, subkey = random.split(key)
        b_inds = random.permutation(subkey, batch_size, independent=True)
        for start in range(0, batch_size, minibatch_size):
            end = start + minibatch_size
            mb_inds = b_inds[start:end]
            # Gradients are taken w.r.t. argument 1 (the parameters).
            (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = _ppo_loss_grad_fn(
                agent_state,
                agent_state.params,
                b_obs[mb_inds],
                b_actions[mb_inds],
                b_logprobs[mb_inds],
                b_advantages[mb_inds],
                b_returns[mb_inds],
                b_values[mb_inds],
            )
            # Update an agent
            agent_state = agent_state.apply_gradients(grads=grads)

    # Calculate how good an approximation of the return is the value function
    y_pred, y_true = b_values, b_returns
    var_y = jnp.var(y_true)
    # Undefined when the returns have zero variance.
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
    return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, explained_var, key

# Train agent

In [19]:
import time
import wandb
from tqdm.notebook import tqdm
from termcolor import colored
import signal

# Install Python's default SIGINT handler so a kernel interrupt raises
# KeyboardInterrupt, which the training loop below catches.
signal.signal(signal.SIGINT, signal.default_int_handler)

# Run name: experiment name, seed and the local time with whitespace replaced by underscores.
run_name = f"{exp_name}_{seed}_{time.asctime(time.localtime(time.time())).replace('  ', ' ').replace(' ', '_')}"
# Start the Weights & Biases run and record every hyperparameter in its config.
wandb.init(
    project=env_id,
    sync_tensorboard=True,
    name=run_name,
    save_code=True,
    monitor_gym=capture_video,
    config={
        'total_timesteps': total_timesteps,
        'learning_rate': learning_rate,
        'num_envs': num_envs,
        'num_steps': num_steps,
        'gamma': gamma,
        'gae_lambda': gae_lambda,
        'num_minibatches': num_minibatches,
        'update_epochs': update_epochs,
        'clip_coef': clip_coef,
        'ent_coef': ent_coef,
        'vf_coef': vf_coef,
        'max_grad_norm': max_grad_norm,
        'seed': seed,
        'batch_size': batch_size,
        'minibatch_size': minibatch_size,
        'num_updates': num_updates,
    }
)
# TensorBoard-style writer; the run above was started with sync_tensorboard=True.
writer = SummaryWriter(f'runs/{env_id}/{run_name}')

# Initialize the storage: every buffer has leading axes (num_steps, num_envs).
# NOTE(review): jnp.zeros is called without a dtype, so `actions` uses the default
# floating dtype even though the sampled actions are discrete — confirm that is acceptable.
storage = Storage(
    obs=jnp.zeros((num_steps, num_envs) + envs.single_observation_space.shape),
    actions=jnp.zeros((num_steps, num_envs) + envs.single_action_space.shape),
    logprobs=jnp.zeros((num_steps, num_envs)),
    dones=jnp.zeros((num_steps, num_envs)),
    values=jnp.zeros((num_steps, num_envs)),
    advantages=jnp.zeros((num_steps, num_envs)),
    returns=jnp.zeros((num_steps, num_envs)),
    rewards=jnp.zeros((num_steps, num_envs)),
)
# Counter of environment steps, used as the x-axis for every logged scalar.
global_step = 0
start_time = time.time()
# Seed the environments on this reset; no environment is done at the start.
next_obs, _ = envs.reset(seed=seed)
next_done = jnp.zeros(num_envs)

# Main loop: collect a rollout, compute advantages, run the PPO update, log metrics.
# `action_key` and `permutation_key` come from the PRNG split made at parameter initialization.
try:
    for update in tqdm(range(1, num_updates + 1)):
        next_obs, next_done, storage, action_key, global_step = rollout(agent_state, next_obs, next_done, storage, action_key, global_step, writer)
        storage = compute_gae(agent_state, next_obs, next_done, storage)
        agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, explained_var, permutation_key = update_ppo(agent_state, storage, permutation_key)

        # opt_state[1] is the state of the second transformation in the optax chain,
        # i.e. the inject_hyperparams wrapper that exposes the current learning rate.
        writer.scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step)
        writer.scalar("losses/value_loss", v_loss.item(), global_step)
        writer.scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.scalar("losses/explained_variance", explained_var, global_step)
        # Environment steps per second since training started.
        writer.scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
    print(colored('Training complete!', 'green'))
except KeyboardInterrupt:
    # A kernel interrupt stops training; cleanup below still runs.
    print(colored('Training interrupted!', 'red'))
finally:
    # Always release the environments and flush/close the loggers.
    envs.close()
    writer.close()
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33majaskowiec[0m. Use [1m`wandb login --relogin`[0m to force relogin




  0%|          | 0/9765 [00:00<?, ?it/s]

Moviepy - Building video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-0.mp4.
Moviepy - Writing video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-0.mp4




t:   0%|          | 0/3338 [00:00<?, ?it/s, now=None][A
t:  10%|█         | 341/3338 [00:00<00:00, 3407.58it/s, now=None][A
t:  20%|██        | 682/3338 [00:00<00:00, 3361.76it/s, now=None][A
t:  31%|███       | 1019/3338 [00:00<00:00, 3316.33it/s, now=None][A
t:  41%|████      | 1374/3338 [00:00<00:00, 3406.35it/s, now=None][A
t:  51%|█████▏    | 1715/3338 [00:00<00:00, 3314.69it/s, now=None][A
t:  61%|██████▏   | 2047/3338 [00:00<00:00, 3293.31it/s, now=None][A
t:  71%|███████▏  | 2386/3338 [00:00<00:00, 3320.73it/s, now=None][A
t:  82%|████████▏ | 2723/3338 [00:00<00:00, 3333.55it/s, now=None][A
t:  92%|█████████▏| 3057/3338 [00:00<00:00, 3258.51it/s, now=None][A
                                                                  [A

Moviepy - Done !
Moviepy - video ready /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-0.mp4
Moviepy - Building video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-1.mp4.
Moviepy - Writing video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-1.mp4




t:   0%|          | 0/2530 [00:00<?, ?it/s, now=None][A
t:  17%|█▋        | 433/2530 [00:00<00:00, 4327.48it/s, now=None][A
t:  34%|███▍      | 866/2530 [00:00<00:00, 4248.43it/s, now=None][A
t:  51%|█████     | 1291/2530 [00:00<00:00, 4164.52it/s, now=None][A
t:  68%|██████▊   | 1711/2530 [00:00<00:00, 4177.36it/s, now=None][A
t:  84%|████████▍ | 2129/2530 [00:00<00:00, 4075.22it/s, now=None][A
                                                                  [A

Moviepy - Done !
Moviepy - video ready /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-1.mp4
Moviepy - Building video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-8.mp4.
Moviepy - Writing video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-8.mp4




t:   0%|          | 0/4383 [00:00<?, ?it/s, now=None][A
t:   7%|▋         | 318/4383 [00:00<00:01, 3179.25it/s, now=None][A
t:  15%|█▍        | 643/4383 [00:00<00:01, 3219.14it/s, now=None][A
t:  22%|██▏       | 965/4383 [00:00<00:01, 3187.47it/s, now=None][A
t:  29%|██▉       | 1284/4383 [00:00<00:00, 3100.03it/s, now=None][A
t:  36%|███▋      | 1595/4383 [00:00<00:00, 3102.65it/s, now=None][A
t:  43%|████▎     | 1906/4383 [00:00<00:00, 3072.68it/s, now=None][A
t:  51%|█████     | 2214/4383 [00:00<00:00, 3043.29it/s, now=None][A
t:  57%|█████▋    | 2519/4383 [00:00<00:00, 2994.61it/s, now=None][A
t:  64%|██████▍   | 2819/4383 [00:00<00:00, 2972.53it/s, now=None][A
t:  71%|███████   | 3117/4383 [00:01<00:00, 2899.57it/s, now=None][A
t:  78%|███████▊  | 3408/4383 [00:01<00:00, 2821.44it/s, now=None][A
t:  84%|████████▍ | 3691/4383 [00:01<00:00, 2808.98it/s, now=None][A
t:  91%|█████████ | 3973/4383 [00:01<00:00, 2792.59it/s, now=None][A
t:  97%|█████████▋| 4253/4383 [00:0

Moviepy - Done !
Moviepy - video ready /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-8.mp4
Moviepy - Building video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-27.mp4.
Moviepy - Writing video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-27.mp4




t:   0%|          | 0/5519 [00:00<?, ?it/s, now=None][A
t:   5%|▍         | 259/5519 [00:00<00:02, 2582.29it/s, now=None][A
t:   9%|▉         | 518/5519 [00:00<00:01, 2569.42it/s, now=None][A
t:  14%|█▍        | 775/5519 [00:00<00:01, 2462.34it/s, now=None][A
t:  19%|█▊        | 1032/5519 [00:00<00:01, 2502.16it/s, now=None][A
t:  23%|██▎       | 1288/5519 [00:00<00:01, 2522.62it/s, now=None][A
t:  28%|██▊       | 1541/5519 [00:00<00:01, 2474.00it/s, now=None][A
t:  33%|███▎      | 1794/5519 [00:00<00:01, 2489.76it/s, now=None][A
t:  37%|███▋      | 2045/5519 [00:00<00:01, 2493.24it/s, now=None][A
t:  42%|████▏     | 2296/5519 [00:00<00:01, 2497.01it/s, now=None][A
t:  46%|████▌     | 2546/5519 [00:01<00:01, 2486.24it/s, now=None][A
t:  51%|█████     | 2798/5519 [00:01<00:01, 2495.60it/s, now=None][A
t:  55%|█████▌    | 3048/5519 [00:01<00:01, 2449.38it/s, now=None][A
t:  60%|█████▉    | 3294/5519 [00:01<00:00, 2434.10it/s, now=None][A
t:  64%|██████▍   | 3539/5519 [00:0

Moviepy - Done !
Moviepy - video ready /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-27.mp4
Moviepy - Building video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-64.mp4.
Moviepy - Writing video /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-64.mp4




t:   0%|          | 0/5487 [00:00<?, ?it/s, now=None][A
t:   4%|▍         | 244/5487 [00:00<00:02, 2436.71it/s, now=None][A
t:   9%|▉         | 516/5487 [00:00<00:01, 2600.64it/s, now=None][A
t:  14%|█▍        | 793/5487 [00:00<00:01, 2673.70it/s, now=None][A
t:  19%|█▉        | 1061/5487 [00:00<00:01, 2654.58it/s, now=None][A
t:  24%|██▍       | 1327/5487 [00:00<00:01, 2554.25it/s, now=None][A
t:  29%|██▉       | 1584/5487 [00:00<00:01, 2529.43it/s, now=None][A
t:  33%|███▎      | 1838/5487 [00:00<00:01, 2504.50it/s, now=None][A
t:  38%|███▊      | 2089/5487 [00:00<00:01, 2480.09it/s, now=None][A
t:  43%|████▎     | 2338/5487 [00:00<00:01, 2445.66it/s, now=None][A
t:  47%|████▋     | 2583/5487 [00:01<00:01, 2376.01it/s, now=None][A
t:  52%|█████▏    | 2833/5487 [00:01<00:01, 2412.38it/s, now=None][A
t:  56%|█████▌    | 3075/5487 [00:01<00:00, 2413.51it/s, now=None][A
t:  60%|██████    | 3317/5487 [00:01<00:00, 2410.55it/s, now=None][A
t:  65%|██████▍   | 3559/5487 [00:0

Moviepy - Done !
Moviepy - video ready /home/adrian/Projekty/PPO/videos/KungFuMasterNoFrameskip-v4/rl-video-episode-64.mp4


# Validation

In [23]:
# Only works in SyncVectorEnv!
# Take the sub-environment at index 1 (the second one) from the training pool;
# index 0 is the one make_env wraps with RecordVideo when capture_video is on.
# NOTE(review): the training cell closes `envs` in its `finally` block, so this
# environment may already be closed; the following cell builds a fresh test_env instead.
test_env = envs.envs[1]
# Make it render on the screen
test_env.unwrapped.render_mode = 'human'
# Bare expression so the notebook displays the environment object.
test_env

In [29]:
# Build a fresh evaluation environment with the same preprocessing arguments as training.
test_env = gym.make(env_id, render_mode='rgb_array')
test_env = gym.wrappers.AtariPreprocessing(test_env, grayscale_newaxis=True, scale_obs=True)
# Record the evaluation episode to the 'test_run/' folder.
test_env = gym.wrappers.RecordVideo(test_env, 'test_run/')
# Override the frame rate stored in the environment metadata.
# NOTE(review): presumably this sets the frame rate of the recorded video — confirm.
test_env.metadata['render_fps'] = 24

In [27]:
from numpy import ndarray
import tensorflow_probability.substrates.jax.distributions as tfp

def get_action(agent_state: AgentState, obs: ndarray, key: Array):
    """Sample an action from the current policy without touching the rollout buffer.

    Args:
        agent_state: State holding the parameters and the module forward functions.
        obs: Batch of observations (leading batch axis required).
        key: JAX PRNG key. It is annotated as ``Array`` because the
            ``jax.random.PRNGKeyArray`` alias is not available in newer JAX
            releases, and annotations are evaluated when the function is defined.

    Returns:
        ``(action, key)``: the sampled actions and the key to use for the next call.
    """
    hidden = agent_state.network_fn(agent_state.params.network_params, obs)
    action_logits = agent_state.actor_fn(agent_state.params.actor_params, hidden)
    probs = tfp.Categorical(action_logits)
    # Split so the returned key has not been consumed by sampling.
    key, subkey = random.split(key)
    action = probs.sample(seed=subkey)
    return action, key

In [None]:
from termcolor import colored

# Play one episode with the trained policy, sampling actions from it.
try:
    observation, _ = test_env.reset()
    while True:
        # The networks expect a leading batch axis, so wrap the single observation.
        observation = np.expand_dims(observation, 0)
        action, action_key = get_action(agent_state, observation, action_key)
        # Convert the one-element action array to a Python scalar for the single env.
        action = action.item()
        observation, reward, terminated, truncated, info = test_env.step(action)
        # Stop at the end of the episode.
        if terminated or truncated:
            break
except KeyboardInterrupt:
    print(colored('Validation stopped!', 'red'))
finally:
    # Always close the environment, also when the episode is interrupted.
    test_env.close()

# Saving model for future usage

In [27]:
from flax.training import checkpoints

# Persist the whole agent state; `update` is the loop variable left over from the training cell.
checkpoints.save_checkpoint(ckpt_dir='checkpoints/',  # Folder to save the checkpoint in
                            target=agent_state,  # What to save. To only save parameters, use agent_state.params
                            step=update,  # Training step stored with the checkpoint
                            prefix=f'{env_id}-',  # Checkpoint file name prefix
                            overwrite=True   # Overwrite existing checkpoint files
                            )

In [12]:
from flax.training import checkpoints

# Load a saved agent state; an existing `agent_state` is required as the template.
agent_state = checkpoints.restore_checkpoint(ckpt_dir='checkpoints/',  # Folder to load the checkpoint from
                               target=agent_state,  # Template object whose structure the restored state follows
                               prefix=f'{env_id}-',  # Checkpoint file name prefix (must match the one used when saving)
                               )