In [20]:
import argparse
import os
import random
import time
from distutils.util import strtobool
from typing import Sequence

import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import distrax #probs. distribs. en jax
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
#from torch.utils.tensorboard import SummaryWriter

In [52]:
class SharedNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.))(x)
        x = nn.relu(x)
        x = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.))(x)
        x = nn.relu(x)

        return x

class Actor(nn.Module):
    @nn.compact
    def __call__(self, x):
        actor_head = nn.Dense(2, kernel_init=orthogonal(0.01), bias_init=constant(0.))(x)
        pi = distrax.Categorical(logits=actor_head)
        
        return pi

class Critic(nn.Module):
    @nn.compact
    def __call__(self, x):
        critic_head = nn.Dense(1, kernel_init=orthogonal(1.), bias_init=constant(0.))(x)

        return jnp.squeeze(critic_head, axis=-1)

In [53]:
@flax.struct.dataclass
class AgentParams:
    shared_network_params: flax.core.FrozenDict
    actor_params: flax.core.FrozenDict
    critic_params: flax.core.FrozenDict

@flax.struct.dataclass
class Storage:
    obs: jnp.array
    actions: jnp.array
    logprobs: jnp.array
    dones: jnp.array
    values: jnp.array
    advantages: jnp.array
    returns: jnp.array
    rewards: jnp.array

@flax.struct.dataclass
class EpisodeStatistics:
    episode_returns: jnp.array
    episode_lengths: jnp.array
    returned_episode_returns: jnp.array
    returned_episode_lengths: jnp.array

In [54]:
#env = gym.make("CartPole-v1")
#env = gym.wrappers.RecordEpisodeStatistics(env)

def make_env():
    def thunk():
        return gym.make("CartPole-v1")
    return thunk

env = gym.vector.SyncVectorEnv([make_env()])
obs = env.reset()

In [48]:
for i in range(200):
    act = env.action_space.sample()
    obs, rew, done, _, _ = env.step(act)

In [6]:
obs = env.reset()
done = False

while not done:
    act = env.action_space.sample()
    obs, rew, done, _, _ = env.step(act)

env.close()

In [7]:
info

NameError: name 'info' is not defined

In [55]:
key = jax.random.PRNGKey(0)

key, sn_key, a_key, c_key = jax.random.split(key, 4)

episode_stats = EpisodeStatistics(episode_returns=jnp.zeros((1,), dtype=jnp.float32),
                                  episode_lengths=jnp.zeros((1,), dtype=jnp.int32),
                                  returned_episode_returns=jnp.zeros((1,), dtype=jnp.float32),
                                  returned_episode_lengths=jnp.zeros((1,), dtype=jnp.int32))

In [56]:
def _step_env(episode_stats: EpisodeStatistics, action):
    next_obs, rew, terminated, truncated, info = env.step(jax.device_get(action))

    next_done = terminated | truncated

    new_episode_return = episode_stats.episode_returns + rew
    new_episode_length = episode_stats.episode_lengths + 1

    episode_stats = episode_stats.replace(
            episode_returns=(new_episode_return) * (1 - next_done) * (1 - truncated),
            episode_lengths=(new_episode_length) * (1 - next_done) * (1 - truncated),
            returned_episode_returns=jnp.where(next_done + truncated, new_episode_return, episode_stats.returned_episode_returns),
            returned_episode_lengths=jnp.where(next_done + truncated, new_episode_length, episode_stats.returned_episode_lengths)
    )
    
    return episode_stats, (next_obs, rew, next_done, info)

In [57]:
learning_rate = 2.5e-4
num_minibatches = 4 # 
update_epochs = 4 # nombre de updates de la policy
num_steps = 128 # nombre de steps récoltés par env entre chaque update
num_envs = 1
batch_size = int(num_envs * num_steps) # nombre de steps total pour une update
minibatch_size = int(batch_size // num_minibatches)
total_timesteps = 1000000 # nombre total de steps du training
num_updates = total_timesteps // batch_size # nombre total d'update du training

max_grad_norm = 0.5

def linear_schedule(count):
    # anneal learning rate linearly after one training iteration which contains
    # (args.num_minibatches * args.update_epochs) gradient updates
    frac = 1.0 - (count // (num_minibatches * update_epochs)) / num_updates
    return learning_rate * frac

In [58]:
network = SharedNetwork()
actor = Actor()
critic = Critic()

network_params = network.init(sn_key, np.array([env.observation_space.sample()]))
print(network_params["params"]["Dense_0"]["kernel"].shape) #convention (fan_in, fan_out) pour les W

agent_state = TrainState.create(apply_fn=None,
                                params=AgentParams(network_params,
                                            actor.init(a_key, network.apply(network_params, np.array([env.observation_space.sample()]))),
                                            critic.init(c_key, network.apply(network_params, np.array([env.observation_space.sample()])))),
                                tx=optax.chain(optax.clip_by_global_norm(max_grad_norm), optax.inject_hyperparams(optax.adam)(learning_rate=linear_schedule)))

network.apply = jax.jit(network.apply)
actor.apply = jax.jit(actor.apply)
critic.apply = jax.jit(critic.apply)

(4, 64)


In [59]:
storage = Storage(
    obs = jnp.zeros((num_steps, num_envs) + env.observation_space.shape), # (num_steps, num_envs, obs shape)
    actions = jnp.zeros((num_steps, num_envs) + env.action_space.shape, dtype=jnp.int32), # (num_steps, num_envs, number of actions)
    logprobs = jnp.zeros((num_steps, num_envs)), # (num_steps, num_envs)
    dones = jnp.zeros((num_steps, num_envs)),
    values = jnp.zeros((num_steps, num_envs)),
    advantages = jnp.zeros((num_steps, num_envs)),
    returns = jnp.zeros((num_steps, num_envs)),
    rewards = jnp.zeros((num_steps, num_envs)),
)

In [60]:
@jax.jit
def get_action_and_value(agent_state, next_obs, next_done, storage:Storage, step, key):
    hidden = network.apply(agent_state.params.shared_network_params, next_obs)

    pi = actor.apply(agent_state.params.actor_params, hidden)

    key, subkey = jax.random.split(key)
    action = pi.sample(seed=subkey)
    logprob = pi.log_prob(action)

    value = critic.apply(agent_state.params.critic_params, hidden)

    storage = storage.replace(obs=storage.obs.at[step].set(next_obs),#pb de shape ici : reshape obs en (1, 4) ou bien ne plus inclure le num_envs dans les shapes
                              actions=storage.actions.at[step].set(action),
                              logprobs=storage.logprobs.at[step].set(logprob),
                              dones=storage.dones.at[step].set(next_done),
                              values=storage.values.at[step].set(value))
    
    return storage, action, key

In [61]:
@jax.jit
def get_action_and_value2(params, x:np.ndarray, action:np.ndarray):
    hidden = network.apply(params.shared_network_params, x)

    pi = actor.apply(agent_state.params.actor_params, hidden)
    logprob = pi.log_prob(action)
    p_logprob = pi.prob(action) * logprob #peut etre remplacé par exp(logprob) * logprob
    entropy = -p_logprob.sum(-1)

    value = critic.apply(agent_state.params.critic_params, hidden)

    return logprob, entropy, value

In [62]:
gamma = 0.99
gae_lambda = 0.95

@jax.jit
def compute_gae(agent_state, next_obs: np.ndarray, next_done: np.ndarray, storage):
    storage = storage.replace(advantages=storage.advantages.at[:].set(0.0))

    next_value = critic.apply(agent_state.params.critic_params, network.apply(agent_state.params.shared_network_params, next_obs))

    lastgaelam = 0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - storage.dones[t + 1]
            nextvalues = storage.values[t + 1]
        delta = storage.rewards[t] + gamma * nextvalues * nextnonterminal - storage.values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        storage = storage.replace(advantages=storage.advantages.at[t].set(lastgaelam))
    storage = storage.replace(returns=storage.advantages + storage.values)
    return storage

In [63]:
norm_adv = True
clip_coef = 0.1
ent_coef = 0.01
vf_coef = 0.5

@jax.jit
def update_ppo(agent_state: TrainState, storage: Storage, key: jax.random.PRNGKey,):
    b_obs = storage.obs.reshape((-1,) + env.observation_space.shape)
    b_logprobs = storage.logprobs.reshape(-1)
    b_actions = storage.actions.reshape((-1,) + env.action_space.shape)
    b_advantages = storage.advantages.reshape(-1)
    b_returns = storage.returns.reshape(-1)

    def ppo_loss(params, x, a, logp, mb_advantages, mb_returns):
        newlogprob, entropy, newvalue = get_action_and_value2(params, x, a)
        logratio = newlogprob - logp
        ratio = jnp.exp(logratio)
        approx_kl = ((ratio - 1) - logratio).mean()

        if norm_adv:
            mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

        # Policy loss
        pg_loss1 = -mb_advantages * ratio
        pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - clip_coef, 1 + clip_coef)
        pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()

        # Value loss
        v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

        entropy_loss = entropy.mean()
        loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef
        return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))

    ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
    for _ in range(update_epochs):
        key, subkey = jax.random.split(key)
        b_inds = jax.random.permutation(subkey, batch_size, independent=True)
        for start in range(0, batch_size, minibatch_size):
            end = start + minibatch_size
            mb_inds = b_inds[start:end]
            (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
                agent_state.params,
                b_obs[mb_inds],
                b_actions[mb_inds],
                b_logprobs[mb_inds],
                b_advantages[mb_inds],
                b_returns[mb_inds],
            )
            agent_state = agent_state.apply_gradients(grads=grads)
    return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key

In [64]:
global_step = 0
start_time = time.time()
next_obs, _ = env.reset()
next_done = np.zeros(num_envs)

In [65]:
#@jax.jit #jit seulement pour envs. compatibles avec JAX, car sinon les envs. non JAX veulent une action "concrete" et non pas tracé
def rollout(agent_state, episode_stats, next_obs, next_done, storage, key, global_step):
    for step in range(0, num_steps):
        global_step += 1 * num_envs
        storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key)

        # TRY NOT TO MODIFY: execute the game and log data.
        episode_stats, (next_obs, reward, next_done, _) = _step_env(episode_stats, action)
        storage = storage.replace(rewards=storage.rewards.at[step].set(reward))
    return agent_state, episode_stats, next_obs, next_done, storage, key, global_step

In [66]:
for update in range(1, num_updates + 1):
    update_time_start = time.time()
    agent_state, episode_stats, next_obs, next_done, storage, key, global_step = rollout(
        agent_state, episode_stats, next_obs, next_done, storage, key, global_step
    )
    storage = compute_gae(agent_state, next_obs, next_done, storage)
    agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo(
        agent_state,
        storage,
        key,
    )
    avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns))
    print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")

env.close()

global_step=128, avg_episodic_return=41.0
global_step=256, avg_episodic_return=58.0
global_step=384, avg_episodic_return=28.0
global_step=512, avg_episodic_return=22.0
global_step=640, avg_episodic_return=18.0
global_step=768, avg_episodic_return=11.0
global_step=896, avg_episodic_return=12.0
global_step=1024, avg_episodic_return=13.0
global_step=1152, avg_episodic_return=19.0
global_step=1280, avg_episodic_return=15.0
global_step=1408, avg_episodic_return=23.0
global_step=1536, avg_episodic_return=10.0
global_step=1664, avg_episodic_return=60.0
global_step=1792, avg_episodic_return=15.0
global_step=1920, avg_episodic_return=26.0
global_step=2048, avg_episodic_return=21.0
global_step=2176, avg_episodic_return=18.0
global_step=2304, avg_episodic_return=21.0
global_step=2432, avg_episodic_return=12.0
global_step=2560, avg_episodic_return=14.0
global_step=2688, avg_episodic_return=14.0
global_step=2816, avg_episodic_return=20.0
global_step=2944, avg_episodic_return=12.0
global_step=3072, 

KeyboardInterrupt: 