In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm
import sys

sys.path.append("../../../")

from src import CliffWalking, Expected_Sarsa, EpsilonGreedy, animated_heatmap, tabular_rollout, tabular_parallel_rollout, plot_path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 2
INITIAL_STATE = jnp.array([3, 0])
GOAL_STATE = jnp.array([3, 10])
GRID_SIZE = jnp.array([4, 11])
N_STATES = jnp.prod(GRID_SIZE)
N_ACTIONS = 4
DISCOUNT = 0.9
LEARNING_RATE = 0.1
TIME_STEPS = 100_000
STOCHASTIC_RESET = False

key = random.PRNGKey(SEED)

env = CliffWalking(INITIAL_STATE, GOAL_STATE, GRID_SIZE, STOCHASTIC_RESET)
policy = EpsilonGreedy(0.1)
agent = Expected_Sarsa(
    key,
    N_STATES,
    N_ACTIONS, 
    DISCOUNT,
    LEARNING_RATE,
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
env_states, action_keys, obs, rewards, done, q_values = tabular_rollout(key, TIME_STEPS, N_ACTIONS, GRID_SIZE, env, agent, policy)
animated_heatmap(q_values, dims=jnp.asarray(GRID_SIZE), agent_name="Q-learning", sample_freq=500, log_scale=False)
plot_path(obs)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:09<00:00, 10673.41it/s]


In [4]:
trajectories = pd.DataFrame(obs)
trajectories["episode"] = done.cumsum()
trajectories["episode"] = trajectories["episode"].shift().fillna(0)
trajectories["rewards"] = rewards
trajectories.tail(20)

Unnamed: 0,0,1,episode,rewards
99980,1.0,10.0,5276.0,-1
99981,2.0,10.0,5276.0,-1
99982,3.0,0.0,5276.0,0
99983,2.0,0.0,5277.0,-1
99984,2.0,1.0,5277.0,-1
99985,2.0,2.0,5277.0,-1
99986,2.0,3.0,5277.0,-1
99987,2.0,4.0,5277.0,-1
99988,2.0,5.0,5277.0,-1
99989,2.0,6.0,5277.0,-1


In [5]:
rewards_per_ep = trajectories[["rewards", "episode"]].groupby("episode").agg("sum")
px.line(rewards_per_ep)