In [1]:
# !pip install git+https://github.com/deepmind/dm-haiku optax jax_tqdm jax pandas git+https://github.com/instadeepai/jumanji.git

In [2]:
import jax
import jax.numpy as jnp
import pandas as pd
import optax
import haiku as hk
import numpy as np
from network import CustomNetwork

from jax import jit, lax, random, vmap, debug

from jax_tqdm import loop_tqdm

import matplotlib.pyplot as plt

from cartpole import CartPole
from dqn import DQN
from replay_buffer import ReplayBuffer

import jumanji
from jumanji.wrappers import AutoResetWrapper
from jumanji.types import StepType
BATCH_SIZE = 32

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from typing import Callable

from base_agent import BaseAgent
from base_env import BaseEnv
from base_buffer import BaseReplayBuffer

def rollout(
    timesteps: int,
    random_seed: int,
    target_net_update_freq: int,
    model: hk.Transformed,
    optimizer: optax.GradientTransformation,
    buffer_state: dict,
    agent: BaseAgent,
    env: BaseEnv,
    replay_buffer: BaseReplayBuffer,
    state_shape: int,
    buffer_size: int,
    epsilon_decay_fn: Callable,
    epsilon_start: float,
    epsilon_end: float,
    decay_rate: float,
) -> dict:

    @loop_tqdm(timesteps)
    @jit
    def _fori_body(i: int, val: tuple):
        (
            model_params,
            target_net_params,
            optimizer_state,
            buffer_state,
            action_key,
            buffer_key,
            env_state,
            all_actions,
            all_obs,
            all_masks,
            all_rewards,
            all_done,
            losses,
        ) = val

        state = env_state
        epsilon = epsilon_decay_fn(epsilon_start, epsilon_end, i, decay_rate)
        
        action_key, action_keys = jax.random.split(action_key, 2)
        action_keys = jax.random.split(action_keys, BATCH_SIZE)
        # action_key = jax.random.split(action_key, BATCH_SIZE)
        action, action_keys = vmap(agent.act, in_axes=(0, None, 0, None))(action_keys, model_params, state, epsilon)
        # action, action_key = agent.act(action_key, model_params, state, epsilon)

        # new_state, timestep = env.step(env_state, action)
        new_state, timestep = vmap(env.step, in_axes=(0, 0))(env_state, action)
        # debug.print("_fori_body(timestep): {x}", x=timestep)
        env_state = new_state
        reward = timestep.reward
        done = jnp.asarray(timestep.step_type == StepType.LAST)

        experience = (
            jnp.asarray(env_state.board).reshape(-1, 16),
            jnp.asarray(env_state.action_mask),
            jnp.asarray(action),
            jnp.asarray(reward),
            jnp.asarray(new_state.board).reshape(-1, 16),
            jnp.asarray(new_state.action_mask),
            jnp.asarray(done),
        )

        buffer_state = replay_buffer.add(buffer_state, experience, i, BATCH_SIZE)
        current_buffer_size = jnp.min(jnp.array([i, buffer_size]))

        experiences_batch, buffer_key = replay_buffer.sample(
            buffer_key,
            buffer_state,
            current_buffer_size,
        )

        model_params, optimizer_state, loss = agent.update(
            model_params,
            target_net_params,
            optimizer,
            optimizer_state,
            experiences_batch,
        )
        # debug.print("_fori_body(loss): {x}", x=loss)

        # model_params, optimizer_state, loss = agent.batch_update(
        #     model_params,
        #     target_net_params,
        #     optimizer,
        #     optimizer_state,
        #     experiences_batch,
        # )

        # debug.print("_fori_body(experiences_batch): {x}", x=experiences_batch)

        # model_params, optimizer_state, loss = vmap(agent.update, in_axes=(None, None, None, None, 0), out_axes=(None, None, 0))(
        #     model_params,
        #     target_net_params,
        #     optimizer,
        #     optimizer_state,
        #     experiences_batch,
        # )
        # vmap(DQN.update, in_axes=(None, None, None, None, 0))(
        #     online_net_params, target_net_params, optimizer, optimizer_state, experiences
        # )

        # update the target parameters every ``target_net_update_freq`` steps
        target_net_params = lax.cond(
            i % target_net_update_freq == 0,
            lambda _: model_params,
            lambda _: target_net_params,
            operand=None,
        )

        @jit
        def _update_val(j: int, target_supply: tuple):
            target_data, supplied_data, idx,  = target_supply
        
            all_actions = target_data[0].at[idx].set(supplied_data[0][j])
            all_obs = target_data[1].at[idx].set(jnp.resize(supplied_data[1][j], (16,)))
            all_masks = target_data[2].at[idx].set(supplied_data[2][j])
            all_rewards = target_data[3].at[idx].set(supplied_data[3][j])
            all_done = target_data[4].at[idx].set(supplied_data[4][j])
            # losses = target_data[5].at[idx].set(supplied_data[5][j])
            return (all_actions, all_obs, all_masks, all_rewards, all_done), supplied_data, idx+1


        target_data = (all_actions, all_obs, all_masks, all_rewards, all_done)
        supplied_data = (action, new_state.board, new_state.action_mask, reward, done, loss)

        target_data, _, _ = lax.fori_loop(0, BATCH_SIZE, _update_val, (target_data, supplied_data, i))

        # all_actions = all_actions.at[i].set(action)
        # all_obs = all_obs.at[i].set(jnp.resize(new_state.board, (16,)))
        # all_masks = all_masks.at[i].set(new_state.action_mask)
        # all_rewards = all_rewards.at[i].set(reward)
        # all_done = all_done.at[i].set(done)
        losses = losses.at[i].set(loss)

        all_actions = target_data[0]
        all_obs = target_data[1]
        all_masks = target_data[2]
        all_rewards = target_data[3]
        all_done = target_data[4]
        # losses = target_data[5]

        val = (
            model_params,
            target_net_params,
            optimizer_state,
            buffer_state,
            action_key,
            buffer_key,
            env_state,
            all_actions,
            all_obs,
            all_masks,
            all_rewards,
            all_done,
            losses,
        )

        return val

    init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + random_seed)
    init_key, init_keys = jax.random.split(init_key, 2)
    init_keys = jax.random.split(init_key, BATCH_SIZE)
    env_state, _ = vmap(env.reset)(init_keys)
    
    all_actions = jnp.zeros([timesteps*BATCH_SIZE])
    all_obs = jnp.zeros([timesteps*BATCH_SIZE, *state_shape])
    all_masks = jnp.zeros([timesteps*BATCH_SIZE, 4])
    all_rewards = jnp.zeros([timesteps*BATCH_SIZE], dtype=jnp.float32)
    all_done = jnp.zeros([timesteps*BATCH_SIZE], dtype=jnp.bool_)
    losses = jnp.zeros([timesteps], dtype=jnp.float32)

    model_params = model.init(init_key, jnp.zeros(state_shape))
    target_net_params = model.init(action_key, jnp.zeros(state_shape))
    optimizer_state = optimizer.init(model_params)

    val_init = (
        model_params,
        target_net_params,
        optimizer_state,
        buffer_state,
        action_key,
        buffer_key,
        env_state,
        all_actions,
        all_obs,
        all_masks,
        all_rewards,
        all_done,
        losses,
    )

    vals = lax.fori_loop(0, timesteps, _fori_body, val_init)
    output_dict = {}
    keys = [
        "model_params",
        "target_net_params",
        "optimizer_state",
        "buffer_state",
        "action_key",
        "buffer_key",
        "env_state",
        "all_actions",
        "all_obs",
        "all_masks",
        "all_rewards",
        "all_done",
        "losses",
    ]
    for idx, value in enumerate(vals):
        output_dict[keys[idx]] = value

    return output_dict


In [4]:
# Env parameters
RANDOM_SEED = 1
N_ACTIONS = 4
STATE_SHAPE = (16,)

# Hyperparameters
DISCOUNT = 0.99
TIMESTEPS = 20_000
TARGET_NET_UPDATE_FREQ = 10
BUFFER_SIZE = 4096
BATCH_SIZE = 32
LEARNING_RATE = 0.00005
EPSILON_START = 0.99
EPSILON_END = 0.01
DECAY_RATE = 1e-4

In [5]:
buffer_state = {
    "states": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "action_masks": jnp.empty((BUFFER_SIZE, 4), dtype=jnp.bool_),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "new_action_masks": jnp.empty((BUFFER_SIZE, 4), dtype=jnp.bool_),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
print(jax.tree_util.tree_map(lambda x: x.shape, buffer_state))

{'action_masks': (4096, 4), 'actions': (4096,), 'dones': (4096,), 'new_action_masks': (4096, 4), 'next_states': (4096, 16), 'rewards': (4096,), 'states': (4096, 16)}


In [6]:
model = hk.transform(lambda x: CustomNetwork(output_dim=N_ACTIONS)(x))

def inverse_scaling_decay(epsilon_start, epsilon_end, current_step, decay_rate):
    return epsilon_end + (epsilon_start - epsilon_end) / (1 + decay_rate * current_step)

online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)

env = jumanji.make("Game2048-v1")
env = AutoResetWrapper(env)

replay_buffer = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE)
online_net_params = model.init(online_key, random.normal(online_key, STATE_SHAPE))
target_net_params = model.init(target_key, random.normal(target_key, STATE_SHAPE))
optimizer = optax.adam(learning_rate=LEARNING_RATE)
optimizer_state = optimizer.init(online_net_params)
agent = DQN(model, DISCOUNT, N_ACTIONS)

In [7]:
x_values = range(TIMESTEPS)
y_values = [inverse_scaling_decay(EPSILON_START, EPSILON_END, i, DECAY_RATE) for i in x_values]

plt.figure(figsize=(10, 6))
plt.plot(x_values, y_values, label="Epsilon Decay")
plt.title("Epsilon Decay")
plt.xlabel("Timesteps")
plt.ylabel("Epsilon Value")
plt.legend()
plt.grid(True)
plt.show()

<IPython.core.display.Javascript object>

In [8]:
jax.tree_util.tree_map(lambda x: x.shape, online_net_params)

{'custom_network/~/conv2_d': {'b': (32,), 'w': (2, 2, 1, 32)},
 'custom_network/~/conv2_d_1': {'b': (64,), 'w': (2, 2, 32, 64)},
 'custom_network/~/linear': {'b': (128,), 'w': (1024, 128)},
 'custom_network/~/linear_1': {'b': (256,), 'w': (128, 256)},
 'custom_network/~/linear_2': {'b': (128,), 'w': (256, 128)},
 'custom_network/~/linear_3': {'b': (64,), 'w': (128, 64)},
 'custom_network/~/linear_4': {'b': (4,), 'w': (64, 4)}}

In [9]:
jax.tree_util.tree_map(lambda x: x.shape, optimizer_state)

(ScaleByAdamState(count=(), mu={'custom_network/~/conv2_d': {'b': (32,), 'w': (2, 2, 1, 32)}, 'custom_network/~/conv2_d_1': {'b': (64,), 'w': (2, 2, 32, 64)}, 'custom_network/~/linear': {'b': (128,), 'w': (1024, 128)}, 'custom_network/~/linear_1': {'b': (256,), 'w': (128, 256)}, 'custom_network/~/linear_2': {'b': (128,), 'w': (256, 128)}, 'custom_network/~/linear_3': {'b': (64,), 'w': (128, 64)}, 'custom_network/~/linear_4': {'b': (4,), 'w': (64, 4)}}, nu={'custom_network/~/conv2_d': {'b': (32,), 'w': (2, 2, 1, 32)}, 'custom_network/~/conv2_d_1': {'b': (64,), 'w': (2, 2, 32, 64)}, 'custom_network/~/linear': {'b': (128,), 'w': (1024, 128)}, 'custom_network/~/linear_1': {'b': (256,), 'w': (128, 256)}, 'custom_network/~/linear_2': {'b': (128,), 'w': (256, 128)}, 'custom_network/~/linear_3': {'b': (64,), 'w': (128, 64)}, 'custom_network/~/linear_4': {'b': (4,), 'w': (64, 4)}}),
 EmptyState())

In [10]:
init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + RANDOM_SEED)
keys = jax.random.split(init_key, BATCH_SIZE)
env_state, _ = vmap(env.reset)(keys)

for i in range(BUFFER_SIZE//BATCH_SIZE+1):

    action_key, action_keys = jax.random.split(action_key, 2)
    action_keys = jax.random.split(action_keys, BATCH_SIZE)
    # print(f"action_keys: {len(action_keys)}")
    # action, action_key = agent.batch_act(action_key, online_net_params, env_state, 1)
    action, action_keys = vmap(agent.act, in_axes=(0, None, 0, None))(action_keys, online_net_params, env_state, 1)
    # print(f"action: {action}, action_key: {len(action_key)}")
    # keys = jax.random.split(action_key, BATCH_SIZE)
    new_state, timestep = vmap(env.step, in_axes=(0, 0))(env_state, action)

    # print(f"env_state: {env_state}, new_state: {(new_state)}")#, timestep: {timestep}")
    reward = timestep.reward
    done = jnp.asarray(timestep.step_type == StepType.LAST)

    experience = (
        jnp.asarray(env_state.board).reshape(-1, 16),
        jnp.asarray(env_state.action_mask),
        jnp.asarray(action),
        jnp.asarray(reward),
        jnp.asarray(new_state.board).reshape(-1, 16),
        jnp.asarray(new_state.action_mask),
        jnp.asarray(done),
    )

    # print(experience)

    buffer_state = replay_buffer.add(buffer_state, experience, i, BATCH_SIZE)
    env_state = new_state



In [11]:
print(len(buffer_state['action_masks'])==BUFFER_SIZE, len(buffer_state['states'])==BUFFER_SIZE)

True True


In [12]:
experiences_batch, buffer_key = replay_buffer.sample(
    buffer_key,
    buffer_state,
    32,
)

In [13]:
rollout_params = {
    "timesteps": TIMESTEPS,
    "random_seed": RANDOM_SEED,
    "target_net_update_freq": TARGET_NET_UPDATE_FREQ,
    "model": model,
    "optimizer": optimizer,
    "buffer_state": buffer_state,
    "agent": agent,
    "env": env,
    "replay_buffer": replay_buffer,
    "state_shape": STATE_SHAPE,
    "buffer_size": BUFFER_SIZE,
    "epsilon_decay_fn": inverse_scaling_decay,
    "epsilon_start": EPSILON_START,
    "epsilon_end": EPSILON_END,
    "decay_rate": DECAY_RATE,
}

out = rollout(**rollout_params)

In [14]:
print(out.keys())

dict_keys(['model_params', 'target_net_params', 'optimizer_state', 'buffer_state', 'action_key', 'buffer_key', 'env_state', 'all_actions', 'all_obs', 'all_masks', 'all_rewards', 'all_done', 'losses'])


In [15]:
losses = out["losses"][:-TIMESTEPS]

plt.figure(figsize=(10, 6))
plt.plot(losses, label="Training Loss")
plt.title("Loss during training")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()

Running for 20,000 iterations:   0%|          | 0/20000 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

Running for 20,000 iterations: 100%|██████████| 20000/20000 [04:21<00:00, 76.60it/s]


In [16]:
print((out['all_obs'][10:40]))

[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 3. 0. 2. 3. 1.]
 [0. 0. 0. 0. 1. 2. 0. 0. 3. 0. 0. 0. 2. 3. 1. 0.]
 [0. 0. 1. 0. 0. 0. 1. 2. 0. 0. 0. 3. 0. 2. 3. 1.]
 [1. 0. 0. 0. 1. 2. 0. 0. 3. 1. 0. 0. 2. 3. 1. 0.]
 [2. 2. 1. 0. 3. 1. 0. 0. 2. 3. 0. 0. 0. 0. 1. 0.]
 [0. 2. 0. 0. 2. 2. 0. 0. 3. 1. 0. 0. 2. 3. 2. 0.]
 [2. 0. 0. 0. 3. 1. 0. 0. 3. 1. 0. 0. 2. 3. 2. 0.]
 [1. 0. 0. 2. 0. 0. 3. 1. 0. 0. 3. 1. 0. 2. 3. 2.]
 [0. 0. 0. 0. 0. 2. 0. 2. 0. 0. 3. 2. 1. 2. 4. 2.]
 [0. 0. 0. 0. 1. 0. 0. 3. 0. 0. 3. 2. 1. 2. 4. 2.]
 [2. 2. 3. 3. 0. 0. 4. 3. 0. 0. 0. 0. 1. 0. 0. 0.]
 [3. 4. 0. 0. 4. 3. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0.]
 [3. 4. 1. 0. 4. 3. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 3. 4. 1. 2. 4. 3. 1. 0. 0. 0. 1. 0. 0. 0. 0.]
 [3. 4. 1. 0. 2. 4. 3. 1. 1. 0. 0. 0. 0. 0. 0. 1.]
 [0. 3. 4. 1. 2. 4. 3. 1. 1. 0. 0. 1. 0. 0. 0. 1.]
 [3. 4. 1. 0. 2. 4. 3. 1. 2. 0. 0. 0. 1. 1. 0. 0.]
 [1. 3. 4. 1. 2. 4. 3. 1. 0. 0. 0. 2. 0. 0. 0. 2.]
 [1. 3. 4. 1. 2. 4. 3. 1. 2. 0. 0. 0. 2. 0. 1. 0.]
 [0. 0. 1. 0. 1. 0. 4. 0. 2. 3.

In [17]:
# Assuming `out` and colors are defined as in the provided code
colors = ['#636EFA', '#EF553B']  # Example for Plotly qualitative colors
reward_treshold = 2048
# Create the DataFrame
df = pd.DataFrame(
    data={
        "episode": out["all_done"].cumsum(),
        "reward": out["all_rewards"],
    },
)
print(df.tail())
df["episode"] = df["episode"].shift().fillna(0)

# Aggregate data by episode
episodes_df = df.groupby("episode").agg("sum")

# Define hover text and colors
episodes_df["hover_text"] = np.where(
    episodes_df["reward"] > reward_treshold,
    "Over 200 steps: " + episodes_df["reward"].astype(str),
    "Under 200 steps: " + episodes_df["reward"].astype(str)
)
episodes_df["color"] = np.where(episodes_df["reward"] > reward_treshold, colors[1], colors[0])

# Separate data for different categories
under_200 = episodes_df[episodes_df["reward"] < reward_treshold]
over_200 = episodes_df[episodes_df["reward"] >= reward_treshold]

# Create the figure
plt.figure(figsize=(12, 6))

# Plot bars for "under 200 steps"
plt.bar(under_200.index, under_200["reward"], color=colors[0], label="Under 200 steps")

# Plot bars for "over 200 steps"
plt.bar(over_200.index, over_200["reward"], color=colors[1], label="Over 200 steps")

# Add labels, title, and legend
plt.title("Performances of DQN on the CartPole Environment")
plt.xlabel("Episode")
plt.ylabel("Sum of rewards")
plt.ylim(0, 200)
plt.legend(title="Reward Categories")
plt.grid(axis="y", linestyle="--", alpha=0.7)

# Display the plot
plt.tight_layout()
plt.show()


        episode  reward
639995      151     0.0
639996      151     0.0
639997      151     0.0
639998      151     0.0
639999      151     0.0


<IPython.core.display.Javascript object>

In [18]:
df.groupby("episode").agg("sum").max()

reward    3084.0
dtype: float32

In [19]:
state = env_state
print(jnp.reshape(state.board, (-1, 16)))
q = model.apply(online_net_params, None, jnp.reshape(state.board, (-1,16)))
print(state.action_mask)
print(jnp.argmax((jnp.where(state.action_mask, q, -jnp.inf)), axis=1))

[[1 3 5 1 3 1 7 1 2 4 5 4 1 3 2 3]
 [2 0 0 0 3 0 1 1 7 4 1 5 2 6 2 1]
 [0 0 2 1 0 4 2 1 0 0 1 2 0 0 0 3]
 [1 3 1 0 3 1 0 0 1 3 0 0 0 0 0 0]
 [0 3 4 1 0 0 2 2 0 0 0 0 0 0 0 0]
 [5 2 3 1 2 5 1 4 2 3 7 1 1 2 4 0]
 [2 3 1 2 1 5 4 2 0 0 2 3 0 0 2 2]
 [2 0 1 4 0 2 7 1 2 5 3 2 3 2 5 4]
 [3 1 1 0 2 3 2 0 3 1 0 0 0 0 1 0]
 [4 2 0 1 3 3 0 0 6 5 1 1 2 1 7 2]
 [0 0 0 0 1 0 0 1 0 0 1 2 0 0 1 2]
 [1 1 0 0 4 0 0 0 1 0 0 0 3 1 0 0]
 [0 0 0 1 0 0 0 3 0 0 0 6 1 2 2 2]
 [0 1 2 2 0 0 1 0 0 0 0 0 0 0 0 0]
 [1 3 2 5 1 5 7 4 2 3 2 3 1 1 3 1]
 [1 3 7 1 3 5 1 2 1 2 2 4 5 1 1 2]
 [1 4 3 1 2 3 6 2 1 7 3 3 3 2 1 1]
 [2 1 4 2 0 2 1 1 0 0 4 3 0 1 0 1]
 [2 1 0 1 4 2 2 0 2 5 3 0 1 2 1 0]
 [1 1 7 1 3 2 4 2 4 6 2 1 1 2 1 2]
 [2 1 0 0 1 0 0 0 3 0 0 0 1 5 1 1]
 [1 1 1 1 2 3 5 2 1 6 7 4 0 1 2 1]
 [0 2 2 3 0 5 3 2 2 3 2 1 1 0 0 2]
 [2 1 4 1 5 3 2 3 2 5 1 1 1 2 7 1]
 [2 1 4 1 5 2 3 2 3 4 1 0 2 0 0 0]
 [1 1 2 1 0 0 3 1 0 0 0 1 0 0 0 0]
 [1 4 1 3 3 6 3 1 2 3 2 0 2 1 0 0]
 [1 2 3 0 6 1 0 0 5 4 3 0 3 7 2 1]
 [1 3 2 1 4 7 3 3 2 

In [20]:
import jax
import matplotlib.pyplot as plt
import jumanji

env = jumanji.make("Game2048-v1")
env = AutoResetWrapper(env)

key = jax.random.PRNGKey(0)
init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + RANDOM_SEED)
state, _ = env.reset(init_key)

def render_state(state):
    array = state.board
    plt.figure(figsize=(4, 4))
    plt.imshow(array, vmin=0, vmax=np.max(array))

    for i in range(array.shape[0]):
        for j in range(array.shape[1]):
            plt.text(j, i, str(array[i, j]), ha="center", va="center", color="white")

    plt.title("2D Array with Values on a Grid")
    plt.xticks(range(array.shape[1]))
    plt.yticks(range(array.shape[0]))
    plt.show()

last_action = 0
for step in range(1000):
    action, action_key = agent.act(action_key, out['model_params'], state, 0)
    if action not in np.array(np.where(state.action_mask)):
       print("error!")
    # action = np.random.choice(np.array(np.where(state.action_mask))[0])
    print(f"mask: {state.action_mask}, action: {action}")
    state, timestep = env.step(state, action)
    print(timestep['extras'])
    reward = timestep.reward
    done = jnp.asarray(timestep.step_type == StepType.LAST)
    print(f"Step {step + 1}, Reward: {reward} max value: {np.max(state.board)}")
    # print(state.board)
    if last_action != action:
      render_state(state)
      last_action = action

    if done:
       print(state.board)
    if done:
        print("Game Over!")
        break

mask: [False False  True  True], action: 2
{'highest_tile': Array(2, dtype=int32, weak_type=True)}
Step 1, Reward: 0.0 max value: 1


<IPython.core.display.Javascript object>

mask: [ True False  True  True], action: 0
{'highest_tile': Array(4, dtype=int32, weak_type=True)}
Step 2, Reward: 4.0 max value: 2


<IPython.core.display.Javascript object>

mask: [ True False  True  True], action: 0
{'highest_tile': Array(4, dtype=int32, weak_type=True)}
Step 3, Reward: 0.0 max value: 2
mask: [ True  True  True  True], action: 0
{'highest_tile': Array(4, dtype=int32, weak_type=True)}
Step 4, Reward: 0.0 max value: 2
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(4, dtype=int32, weak_type=True)}
Step 5, Reward: 0.0 max value: 2


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 1
{'highest_tile': Array(4, dtype=int32, weak_type=True)}
Step 6, Reward: 4.0 max value: 2
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(4, dtype=int32, weak_type=True)}
Step 7, Reward: 4.0 max value: 2
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(8, dtype=int32, weak_type=True)}
Step 8, Reward: 8.0 max value: 3
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(8, dtype=int32, weak_type=True)}
Step 9, Reward: 0.0 max value: 3
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(8, dtype=int32, weak_type=True)}
Step 10, Reward: 0.0 max value: 3
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(8, dtype=int32, weak_type=True)}
Step 11, Reward: 4.0 max value: 3
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(8, dtype=int32, weak_type=True)}
Step 12, Reward: 4.0 max value: 3
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(8, dtyp

<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(16, dtype=int32, weak_type=True)}
Step 17, Reward: 8.0 max value: 4
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(16, dtype=int32, weak_type=True)}
Step 18, Reward: 4.0 max value: 4
mask: [ True  True False  True], action: 1
{'highest_tile': Array(16, dtype=int32, weak_type=True)}
Step 19, Reward: 16.0 max value: 4


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 20, Reward: 32.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 21, Reward: 0.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 22, Reward: 4.0 max value: 5
mask: [ True  True False  True], action: 1
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 23, Reward: 12.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 24, Reward: 12.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 25, Reward: 0.0 max value: 5
mask: [ True  True False  True], action: 1
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 26, Reward: 0.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 27, Reward: 0.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 28, Reward: 8.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 29, Reward: 4.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 30, Reward: 4.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 31, Reward: 0.0 max value: 5
mask: [ True  True False  True], action: 1
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 32, Reward: 20.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 33, Reward: 8.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 34, Reward: 0.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 35, Reward: 4.0 max value: 5
mask: [ True  True False  True], action: 1
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 36, Reward: 4.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 37, Reward: 0.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 38, Reward: 4.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 39, Reward: 4.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 40, Reward: 16.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 41, Reward: 0.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 42, Reward: 4.0 max value: 5
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 43, Reward: 0.0 max value: 5
mask: [ True  True False  True], action: 1
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 44, Reward: 48.0 max value: 5


<IPython.core.display.Javascript object>

mask: [ True  True False  True], action: 1
{'highest_tile': Array(32, dtype=int32, weak_type=True)}
Step 45, Reward: 32.0 max value: 5
mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 46, Reward: 64.0 max value: 6
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 47, Reward: 4.0 max value: 6
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 48, Reward: 4.0 max value: 6
mask: [ True False  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 49, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 50, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 51, Reward: 0.0 max value: 6
mask: [ True False  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 52, Reward: 4.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 53, Reward: 12.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 54, Reward: 8.0 max value: 6
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 55, Reward: 4.0 max value: 6
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 56, Reward: 12.0 max value: 6
mask: [ True  True  True  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 57, Reward: 16.0 max value: 6
mask: [ True False  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 58, Reward: 48.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True False  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 59, Reward: 4.0 max value: 6
mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 60, Reward: 8.0 max value: 6


  plt.figure(figsize=(4, 4))


<IPython.core.display.Javascript object>

mask: [ True False False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 61, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 62, Reward: 4.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 63, Reward: 4.0 max value: 6
mask: [ True  True False False], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 64, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 65, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True False  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 66, Reward: 4.0 max value: 6
mask: [ True False  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 67, Reward: 4.0 max value: 6
mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 68, Reward: 4.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True False False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 69, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True False], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 70, Reward: 24.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 71, Reward: 32.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 72, Reward: 0.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 73, Reward: 4.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 74, Reward: 8.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 75, Reward: 20.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 76, Reward: 4.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 77, Reward: 12.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile':

<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 88, Reward: 8.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 89, Reward: 16.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 90, Reward: 4.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 91, Reward: 4.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 92, Reward: 0.0 max value: 6
mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 93, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 94, Reward: 4.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 95, Reward: 12.0 max value: 6
mask: [ True  True  True False], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 96, Reward: 8.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 97, Reward: 4.0 max value: 6
mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 98, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 99, Reward: 16.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 100, Reward: 36.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 101, Reward: 4.0 max value: 6
mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 102, Reward: 8.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True False False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 103, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 104, Reward: 4.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 105, Reward: 4.0 max value: 6
mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 106, Reward: 8.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 107, Reward: 4.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 108, Reward: 12.0 max value: 6
mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 109, Reward: 16.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True False False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 110, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 111, Reward: 4.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True False  True], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 112, Reward: 12.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True False False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 113, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 114, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True False False], action: 1
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 115, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 116, Reward: 8.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 117, Reward: 0.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 118, Reward: 0.0 max value: 6
mask: [False  True False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 119, Reward: 4.0 max value: 6
mask: [False  True False  True], action: 3
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 120, Reward: 8.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 121, Reward: 16.0 max value: 6


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 122, Reward: 32.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 123, Reward: 68.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 124, Reward: 0.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 125, Reward: 4.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 126, Reward: 8.0 max value: 6
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(64, dtype=int32, weak_type=True)}
Step 127, Reward: 4.0 max value: 6
mask: [False  True False  True], action: 3
{'highest_tile': Array(128, dtype=int32, weak_type=True)}
Step 128, Reward: 160.0 max value: 7


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(128, dtype=int32, weak_type=True)}
Step 129, Reward: 32.0 max value: 7


<IPython.core.display.Javascript object>

mask: [False  True False False], action: 1
{'highest_tile': Array(128, dtype=int32, weak_type=True)}
Step 130, Reward: 0.0 max value: 7


<IPython.core.display.Javascript object>

mask: [False  True  True  True], action: 2
{'highest_tile': Array(128, dtype=int32, weak_type=True)}
Step 131, Reward: 0.0 max value: 7


<IPython.core.display.Javascript object>

mask: [ True  True  True  True], action: 2
{'highest_tile': Array(128, dtype=int32, weak_type=True)}
Step 132, Reward: 4.0 max value: 7
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(128, dtype=int32, weak_type=True)}
Step 133, Reward: 0.0 max value: 7
mask: [ True False  True  True], action: 2
{'highest_tile': Array(128, dtype=int32, weak_type=True)}
Step 134, Reward: 0.0 max value: 7
mask: [ True  True  True  True], action: 2
{'highest_tile': Array(128, dtype=int32, weak_type=True)}
Step 135, Reward: 8.0 max value: 1
[[0 0 0 0]
 [0 0 0 1]
 [0 0 0 0]
 [0 0 0 0]]
Game Over!
