# ***Breakout with DQN***

<div align="center">
    <img src="https://gymnasium.farama.org/_images/breakout.gif">
</div>

## ***References***:
* [Minatar](https://github.com/kenjyoung/MinAtar/blob/master/minatar/environments/breakout.py)
* [Gymnax](https://github.com/RobertTLange/gymnax/blob/main/gymnax/environments/minatar/breakout.py)
* [Gymnasium](https://gymnasium.farama.org/environments/atari/breakout/)

In [1]:
import sys
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from jax import random, vmap, lax
from jax_tqdm import loop_tqdm

sys.path.append("../../")
from src import Breakout, DQN, UniformReplayBuffer, deep_rl_rollout

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# MinAtar Breakout params
BATCH_SIZE = 32
BUFFER_SIZE = 100_000
TARGET_NETWORK_UPDATE_FREQ = 1000
TRAINING_FREQ = 1
NUM_FRAMES = 5_000_000
FIRST_N_FRAMES = 100_000
REPLAY_START_SIZE = 5000
END_EPSILON = 0.1
LEARNING_RATE = 0.00025
GRAD_MOMENTUM = 0.95
SQUARED_GRAD_MOMENTUM = 0.95
MIN_SQUARED_GRAD = 0.01
DISCOUNT = 0.99
EPSILON = 1.0

# other params
RANDOM_SEED = 0
STATE_SHAPE = (10, 10, 4)

In [3]:
key = random.PRNGKey(0)
env = Breakout()


@hk.transform
def model(x):
    """
    MinAtar version of DQN
    ref: https://github.com/kenjyoung/MinAtar/blob/master/examples/dqn.py
    """
    conv_layer = hk.Conv2D(
        output_channels=16,
        kernel_shape=3,
        stride=1,
    )
    fc = hk.nets.MLP(
        output_sizes=[128, env.n_actions],
        activation=jax.nn.relu,
        activate_final=False,
    )

    x = jax.nn.relu(conv_layer(x))
    x = x.reshape(-1)
    return fc(x)


def inverse_scaling_decay(epsilon_start, epsilon_end, current_step, decay_rate):
    return epsilon_end + (epsilon_start - epsilon_end) / (1 + decay_rate * current_step)


replay_buffer = UniformReplayBuffer(BUFFER_SIZE, BATCH_SIZE)
online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)
online_net_params = model.init(online_key, random.normal(online_key, env.obs_shape))
target_net_params = model.init(target_key, random.normal(target_key, env.obs_shape))
optimizer = optax.adam(learning_rate=LEARNING_RATE)
optimizer_state = optimizer.init(online_net_params)
agent = DQN(model, DISCOUNT, len(env.actions))

jax.tree_map(lambda x: x.shape, online_net_params)

{'conv2_d': {'b': (16,), 'w': (3, 3, 4, 16)},
 'mlp/~/linear_0': {'b': (128,), 'w': (1600, 128)},
 'mlp/~/linear_1': {'b': (3,), 'w': (128, 3)}}

In [4]:
buffer_state = {
    "states": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),
    "next_states": jnp.empty((BUFFER_SIZE, *STATE_SHAPE), dtype=jnp.float32),
    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),
}
jax.tree_map(lambda x: x.shape, buffer_state)

{'actions': (100000,),
 'dones': (100000,),
 'next_states': (100000, 10, 10, 4),
 'rewards': (100000,),
 'states': (100000, 10, 10, 4)}

In [6]:
init_key, action_key, buffer_key = vmap(random.PRNGKey)(jnp.arange(3) + 1)
env_state, _ = env.reset(init_key)
all_actions = jnp.zeros([1000])
all_obs = jnp.zeros([1000, *STATE_SHAPE])
all_rewards = jnp.zeros([1000], dtype=jnp.float32)
all_done = jnp.zeros([1000], dtype=jnp.bool_)
losses = jnp.zeros([1000], dtype=jnp.float32)

model_params = model.init(init_key, jnp.zeros(STATE_SHAPE))
target_net_params = model.init(action_key, jnp.zeros(STATE_SHAPE))
optimizer_state = optimizer.init(model_params)

state, _ = env_state
epsilon = inverse_scaling_decay(1.0, 0.1, 0, 1e-3)
action, action_key = agent.act(action_key, model_params, state, epsilon)
env_state, new_state, reward, done = env.step(env_state, action)
experience = (state, action, reward, new_state, done)

AttributeError: 'tuple' object has no attribute 'pos'

In [None]:
state, env_state = env.reset(key)
model.apply(online_net_params, None, state)



AttributeError: 'EnvState' object has no attribute 'ndim'

In [None]:
rollout_params = {
    "timesteps": 1000,
    "random_seed": RANDOM_SEED,
    "target_net_update_freq": TARGET_NETWORK_UPDATE_FREQ,
    "model": model,
    "optimizer": optimizer,
    "buffer_state": buffer_state,
    "agent": agent,
    "env": env,
    "replay_buffer": replay_buffer,
    "state_shape": STATE_SHAPE,
    "buffer_size": BUFFER_SIZE,
    "epsilon_decay_fn": inverse_scaling_decay,
    "epsilon_start": EPSILON,
    "epsilon_end": END_EPSILON,
    "decay_rate": 1e-3,
}

out = deep_rl_rollout(**rollout_params)

ValueError: too many values to unpack (expected 2)