# ***Breakout with DQN***

<div align="center">
    <img src="https://gymnasium.farama.org/_images/breakout.gif">
</div>

## ***References***:
* [MinAtar](https://github.com/kenjyoung/MinAtar/blob/master/minatar/environments/breakout.py)
* [Gymnax](https://github.com/RobertTLange/gymnax/blob/main/gymnax/environments/minatar/breakout.py)
* [Gymnasium](https://gymnasium.farama.org/environments/atari/breakout/)

In [1]:
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np

from jax import random, vmap, lax

from jax_tqdm import loop_tqdm

import sys

sys.path.append("../../")

from src import Breakout, DQN, UniformReplayBuffer, deep_rl_rollout

  from .autonotebook import tqdm as notebook_tqdm


In [40]:
# Hyperparameters for DQN on MinAtar Breakout, grouped by purpose.

# Replay buffer and batching
REPLAY_BUFFER_SIZE = 100_000
REPLAY_START_SIZE = 5_000
BATCH_SIZE = 32

# Training schedule (all counts are in frames / steps)
NUM_FRAMES = 5_000_000
TRAINING_FREQ = 1
TARGET_NETWORK_UPDATE_FREQ = 1_000

# Exploration
# NOTE(review): FIRST_N_FRAMES is presumably the epsilon-annealing window
# (EPSILON -> END_EPSILON) — confirm against where it is consumed.
EPSILON = 1.0
END_EPSILON = 0.1
FIRST_N_FRAMES = 100_000

# Optimizer settings and reward discounting
LEARNING_RATE = 2.5e-4
GRAD_MOMENTUM = 0.95
SQUARED_GRAD_MOMENTUM = 0.95
MIN_SQUARED_GRAD = 0.01
DISCOUNT = 0.99

# Other params: base seed used when creating PRNG keys
RANDOM_SEED = 0

In [42]:
# Seed the environment key from RANDOM_SEED (declared with the other
# hyperparameters) instead of a hard-coded 0, so changing the seed in one
# place reseeds both the environment and the network initialisation.
key = random.PRNGKey(RANDOM_SEED)
env = Breakout()


@hk.transform
def model(x):
    """
    Q-network for MinAtar-style DQN: one 3x3 convolution, then a two-layer MLP.

    The observation goes through a 16-channel convolution with ReLU, is
    flattened to a single vector, and is fed to an MLP with a 128-unit
    hidden layer and one linear output per environment action
    (``env.n_actions``). Because the flatten collapses every axis, the
    function handles one observation at a time.

    ref: https://github.com/kenjyoung/MinAtar/blob/master/examples/dqn.py
    """
    # Convolutional feature extractor followed by ReLU.
    features = hk.Conv2D(output_channels=16, kernel_shape=3, stride=1)(x)
    flat_features = jax.nn.relu(features).reshape(-1)

    # Fully connected head; the final layer is left linear (raw outputs).
    q_head = hk.nets.MLP(
        output_sizes=[128, env.n_actions],
        activation=jax.nn.relu,
        activate_final=False,
    )
    return q_head(flat_features)


# One PRNG key per network, derived from consecutive seeds starting at
# RANDOM_SEED, so the online and target networks start from different weights.
online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)

# Haiku's init only traces the input to build parameters of the right
# shape/dtype, so a zero observation is sufficient. This also avoids
# consuming the same key twice (once for sampling a dummy input and once
# for parameter initialisation), which JAX's PRNG design discourages.
dummy_obs = jnp.zeros(env.obs_shape)
online_net_params = model.init(online_key, dummy_obs)
target_net_params = model.init(target_key, dummy_obs)

# `jax.tree_map` is deprecated and removed in recent JAX releases;
# `jax.tree_util.tree_map` is the stable equivalent.
# Display the shape of every parameter leaf as a quick architecture check.
jax.tree_util.tree_map(lambda leaf: leaf.shape, online_net_params)

{'conv2_d': {'b': (16,), 'w': (3, 3, 4, 16)},
 'mlp/~/linear_0': {'b': (128,), 'w': (1600, 128)},
 'mlp/~/linear_1': {'b': (3,), 'w': (128, 3)}}

In [44]:
# Smoke test: reset the environment with `key`, keeping both values it
# returns (`state` and `env_state`), then run the online network once on
# `state`. The rng argument of `apply` is None, i.e. no key is supplied for
# the forward pass. The bare expression on the last line is what the
# notebook cell displays.
state, env_state = env.reset(key)
model.apply(online_net_params, None, state)

Array([ 0.0912613 , -0.10809293, -0.02150637], dtype=float32)