In [7]:
# Standard library
import sys
from functools import partial

# Third-party
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
from jax import random, lax, jit, vmap, pmap
from jax_tqdm import loop_tqdm

# Make the project's `src` package importable. The path is relative to the
# notebook's working directory: three levels up (presumably the repo root).
sys.path.append("../../../")

# Local (project) modules
from src import (
    CliffWalking,
    Q_learning,
    EpsilonGreedy,
    animated_heatmap,
    tabular_rollout,
    tabular_parallel_rollout,
    plot_path,
)

In [8]:
# Experiment configuration: every value a reader might want to tune lives here.
SEED = 2
# Grid coordinates, presumably (row, col) given GRID_SIZE = (rows, cols);
# TODO confirm against CliffWalking.
INITIAL_STATE = jnp.array([3, 0])
GOAL_STATE = jnp.array([3, 10])
GRID_SIZE = jnp.array([4, 11])
N_ACTIONS = 4
DISCOUNT = 0.9          # discount factor passed to Q_learning
LEARNING_RATE = 0.1     # step size passed to Q_learning
EPSILON = 0.01          # passed to EpsilonGreedy; presumably the exploration probability
TIME_STEPS = 100_000    # number of rollout iterations
STOCHASTIC_RESET = False

key = random.PRNGKey(SEED)

# NOTE(review): running this cell emits a JAX dtype-promotion warning
# (int32 -> bool scatter), most likely from inside CliffWalking; the fix
# belongs in `src`, not here.
env = CliffWalking(INITIAL_STATE, GOAL_STATE, GRID_SIZE, STOCHASTIC_RESET)
policy = EpsilonGreedy(EPSILON)
agent = Q_learning(
    DISCOUNT,
    LEARNING_RATE,
)


scatter inputs have incompatible types: cannot safely cast value from dtype=int32 to dtype=bool with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



In [9]:
# Single-run Q-learning: roll out TIME_STEPS iterations, then visualise how the
# Q-values evolve over time and which grid cells were visited.
env_states, action_keys, obs, rewards, done, q_values = tabular_rollout(
    key, TIME_STEPS, N_ACTIONS, GRID_SIZE, env, agent, policy
)

# NOTE(review): Jupyter only auto-displays a cell's last expression; confirm
# animated_heatmap renders its figure itself, otherwise wrap it in display().
animated_heatmap(
    q_values,
    dims=jnp.asarray(GRID_SIZE),
    agent_name="Q-learning",
    sample_freq=500,
    log_scale=False,
)
plot_path(obs, title="Q-learning state visit count")


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:09<00:00, 10706.24it/s]


In [11]:
# Label every step with the episode it belongs to. `done[t]` closes the
# episode that contains step t, so the running count of finished episodes is
# shifted by one step: the terminating step still belongs to its own episode,
# and the first step is episode 0. `fill_value=0` keeps episode ids integer
# (shift + fillna would silently turn them into floats).
trajectories = pd.DataFrame(obs)
trajectories["episode"] = done.cumsum()
trajectories["episode"] = trajectories["episode"].shift(fill_value=0)
trajectories["rewards"] = rewards

# Total (undiscounted) reward collected in each episode.
# NOTE(review): the last episode is presumably cut off at TIME_STEPS, so its
# sum may be partial.
rewards_per_ep = trajectories[["rewards", "episode"]].groupby("episode").agg("sum")
px.line(
    rewards_per_ep,
    title="Q-learning: total reward per episode",
    labels={"episode": "Episode", "value": "Total reward"},
)

In [12]:
# Run N_ENV independent Q-learning agents. Each gets its own PRNG key split
# from the same SEED. Then animate the Q-values averaged across runs.
N_ENV = 30

key = random.PRNGKey(SEED)
keys = random.split(key, num=N_ENV)

all_obs, all_rewards, all_done, all_q_values = tabular_parallel_rollout(
    keys,
    TIME_STEPS,
    N_ACTIONS,
    GRID_SIZE,
    N_ENV,
    env,
    agent,
    policy,
)

# NOTE(review): averaging over axis=-1 assumes the run axis is last in
# all_q_values; confirm against tabular_parallel_rollout.
animated_heatmap(
    jnp.mean(all_q_values, axis=-1),
    dims=jnp.asarray(GRID_SIZE),
    agent_name=f"{N_ENV} Q-learning",
    sample_freq=10_000,
    log_scale=False,
)


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.

Running for 100,000 iterations: 100%|██████████| 100000/100000 [05:43<00:00, 291.06it/s]


In [15]:
# Average the per-episode total reward across the N_ENV parallel runs.
# Assumes all_rewards / all_done are (TIME_STEPS, N_ENV), one column per run;
# TODO confirm against tabular_parallel_rollout.
MAX_EPISODES = 5000  # cap on the number of leading episodes kept per run

rewards_df = pd.DataFrame(all_rewards)
# Same episode labelling as the single-run analysis: shift the cumulative done
# count by one step so the terminating step still belongs to its episode.
# fill_value keeps the episode ids integer.
done_df = pd.DataFrame(all_done.cumsum(axis=0)).shift(fill_value=0)
all_trajectories = {}

for i in range(N_ENV):
    run = pd.DataFrame({"Episodes": done_df[i], "rewards": rewards_df[i]})
    # Select the column explicitly instead of `.squeeze()`, which would
    # collapse a single-episode result into a scalar.
    all_trajectories[i] = (
        run.groupby("Episodes")["rewards"].sum().head(MAX_EPISODES)
    )

# Previously a one-element tuple built inside the loop and never used; now a
# plain string passed to the figure.
title = f"Average Reward per Episode, averaged over {N_ENV} runs"
# Runs with fewer episodes contribute NaN beyond their end; mean() skips NaN.
px.line(
    pd.DataFrame(all_trajectories).mean(axis=1),
    range_y=[-200, 0],
    title=title,
    labels={"value": "Average reward"},
)