In [None]:
from dataclasses import dataclass
from jax import random, numpy as jnp
import optax
import equinox as eqx
from pathlib import Path
from tqdm import trange
import gym
import pandas as pd
import timeit

from src.actorcritic import ActorCritic
from src.training import policy_trajectory, step_model_ppo

import warnings
warnings.filterwarnings("ignore")


@dataclass
class Hyperparameters:
    """Tunable settings for the CartPole PPO training run.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field. Without annotations the decorator sees no fields at
    all, so values cannot be overridden through the constructor and the
    generated ``repr``/``==`` ignore them. Defaults are unchanged, so
    ``Hyperparameters()`` behaves as before, while
    ``Hyperparameters(epochs=10)`` now works too.
    """

    # Number of trajectories collected (one optimizer step each) per epoch.
    batchsize: int = 1
    # Number of iterations of the outer training loop.
    epochs: int = 1000
    # Forwarded to policy_trajectory as the per-episode step limit; also used
    # as the logged episode length when no terminal step is found.
    steps_per_episode: int = 500
    # Learning rate handed to the Adam optimizer.
    learning_rate: float = 0.003
    # Forwarded to policy_trajectory as the reward discount.
    reward_discount: float = 0.99
    # NOTE(review): not read anywhere in the visible code; presumably the PPO
    # clipping range — confirm it is wired into step_model_ppo.
    epsilon: float = 0.2


hyperparameters = Hyperparameters()


if __name__ == "__main__":
    # Train an ActorCritic agent on CartPole-v1 with step_model_ppo updates.
    # Each epoch logs the average episode length; every 100th epoch the agent
    # and the log collected so far are written under checkpoints/cartpole/.
    env = gym.make("CartPole-v1")
    observation = env.reset()  # only used here to read the observation size

    key = random.PRNGKey(0)
    agent = ActorCritic(observation.shape[0], env.action_space.n, [64], [64], key=key)

    optimizer = optax.adam(hyperparameters.learning_rate)
    optimizer_state = optimizer.init(eqx.filter(agent, eqx.filters.is_array))

    env.seed(0)

    # The checkpoint directory is the same for every epoch, so create it once.
    checkpoint_dir = Path("checkpoints/cartpole")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Profiling hint: wrap the loop below in
    # jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True).
    log_columns = ["time", "parameters", "avg_reward"]
    # Rows are accumulated as plain dicts and turned into a DataFrame only
    # when needed; concatenating a frame every epoch is quadratic in epochs.
    records = []
    df = pd.DataFrame(columns=log_columns)
    for epoch in (pbar := trange(hyperparameters.epochs)):
        batch_rewards = []

        for _ in range(hyperparameters.batchsize):
            # Advance the PRNG key so every trajectory uses fresh randomness.
            key = random.split(key, 2)[1]

            trajectory = policy_trajectory(
                env,
                agent,
                hyperparameters.steps_per_episode,
                hyperparameters.reward_discount,
                key,
            )

            loss, agent, optimizer_state = step_model_ppo(
                agent, trajectory, optimizer, optimizer_state
            )

            # Index of the first zero discount is taken as the episode length
            # (presumably the step at which the episode terminated — confirm
            # against policy_trajectory). If there is none, fall back to the
            # step limit. An explicit emptiness check replaces the former bare
            # `except:`, which also hid unrelated errors such as a missing key.
            terminal_steps = jnp.where(trajectory["discounts"] == 0.0)[0]
            if terminal_steps.size > 0:
                trajectory_reward = terminal_steps[0]
            else:
                trajectory_reward = hyperparameters.steps_per_episode
            batch_rewards.append(trajectory_reward)

        avg_reward = jnp.array(batch_rewards).mean()
        pbar.set_description(f"{avg_reward}")

        save_path = checkpoint_dir / f"checkpoint_{epoch}.eqx"
        records.append(
            {
                "time": epoch,
                # Recorded for every epoch, although the file is only written
                # on checkpoint epochs (see below).
                "parameters": save_path.absolute().as_posix(),
                # "loss": loss,
                # Stored as a plain float so the pickled log can be read
                # without JAX installed.
                "avg_reward": float(avg_reward),
            }
        )
        if epoch % 100 == 0:
            df = pd.DataFrame(records, columns=log_columns)
            eqx.tree_serialise_leaves(save_path, agent)
            pd.to_pickle(df, save_path.parent / f"data_{epoch}.pkl")

    # Expose the complete log for later cells, as the original loop did.
    df = pd.DataFrame(records, columns=log_columns)

    # %%

In [18]:
# Render an episode and save as a GIF file

from IPython import display as ipythondisplay
from PIL import Image
from pyvirtualdisplay import Display
import gym
from src.actorcritic import ActorCritic
import equinox as eqx
import jax
import jax.numpy as jnp

# Create and start a pyvirtualdisplay Display (visible=0, 400x300) before any
# rendering is done below.
# NOTE(review): presumably needed so env.render(mode='rgb_array') works on a
# headless machine — confirm. The display is never stopped in the visible code.
display = Display(visible=0, size=(400, 300))
display.start()


def render_episode(env: gym.Env, model: ActorCritic, max_steps: int, key: "jax.Array") -> list:
    """Roll out one episode with ``model`` and collect rendered frames.

    Args:
        env: Environment to run; it is reset at the start of the call.
        model: Agent whose ``act(state, key)`` picks the action for each step.
        max_steps: Upper bound on the number of environment steps.
        key: JAX PRNG key; a fresh key is split off before every action.
            The annotation is a string so it is not evaluated at definition
            time (``jax.random.KeyArray`` is missing from newer JAX releases).

    Returns:
        A list of PIL images: the initial frame, followed by one frame for
        every 10th step taken before the episode ended.
    """
    # Reset before the first render so the opening frame shows this episode's
    # initial state rather than whatever the environment held previously.
    state = jnp.array(env.reset())
    images = [Image.fromarray(env.render(mode='rgb_array'))]

    for i in range(1, max_steps + 1):
        key = jax.random.split(key, 2)[1]
        action = model.act(state, key)
        # .item() converts the action array to a plain Python scalar for gym.
        state, _, done, _ = env.step(action.item())
        state = jnp.array(state)

        # Render screen every 10 steps
        if i % 10 == 0:
            images.append(Image.fromarray(env.render(mode='rgb_array')))

        # Stop once the episode is over; stepping a finished environment
        # yields undefined observations and frames.
        if done:
            break

    return images


# Load the epoch-500 checkpoint, render one episode with it and save the
# frames as an animated GIF.
env = gym.make("CartPole-v1")
try:
    observation = env.reset()
    key = jax.random.PRNGKey(0)

    # tree_deserialise_leaves needs a model of the same structure to load the
    # saved leaves into, so build one with the same sizes as in training.
    agent_template = ActorCritic(
        observation.shape[0], env.action_space.n, [64], [64], key=key
    )
    trained_agent = eqx.tree_deserialise_leaves(
        "checkpoints/cartpole/checkpoint_500.eqx", agent_template
    )

    images = render_episode(env, trained_agent, 1000, key=key)
finally:
    # Release the environment's rendering resources even if loading or
    # rendering fails.
    env.close()

# Save GIF image
image_file = 'cartpole-v0.gif'
# loop=0: loop forever, duration=1: play each frame for 1ms
images[0].save(
    image_file, save_all=True, append_images=images[1:], loop=0, duration=1)
