# ⚡***Lightning fast parallelized training with*** <img src='https://upload.wikimedia.org/wikipedia/commons/8/86/Google_JAX_logo.svg' alt="Environment" width="60" />

## **3) Double Q-learning on a GridWorld**

In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px

from jax import random, lax, jit, vmap, pmap
from functools import partial
from jax_tqdm import loop_tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sys

sys.path.append("../../../")

from src import GridWorld, Double_Q_learning, DoubleEpsilonGreedy
from utils import animated_heatmap, double_rollout, plot_path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 2
INITIAL_STATE = jnp.array([8, 12])
GOAL_STATE = jnp.array([0, 0])
GRID_SIZE = jnp.array([8, 12])
N_STATES = jnp.prod(GRID_SIZE)
N_ACTIONS = 4
DISCOUNT = 0.9
LEARNING_RATE = 0.1
TIME_STEPS = 100_000
STOCHASTIC_RESET = False

key = random.PRNGKey(SEED)

env = GridWorld(INITIAL_STATE, GOAL_STATE, GRID_SIZE, STOCHASTIC_RESET)
policy = DoubleEpsilonGreedy(0.1, sum_qs=False)
agent = Double_Q_learning(
    key,
    N_STATES,
    N_ACTIONS, 
    DISCOUNT,
    LEARNING_RATE,
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
env_states, action_keys, obs, rewards, done, q1, q2 = double_rollout(key, TIME_STEPS, N_ACTIONS, GRID_SIZE, env, agent, policy)
animated_heatmap(
                 jnp.mean(jnp.array([q1,q2]), axis=0), 
                 dims=jnp.asarray(GRID_SIZE), 
                 agent_name="Double Q-learning", 
                 sample_freq=1000, 
                 log_scale=False
                )
plot_path(obs)

Running for 100,000 iterations: 100%|██████████| 100000/100000 [00:00<00:00, 1666595.41it/s]


In [4]:
rewards.sum()

Array(3256, dtype=int32)