In [1]:
import sys
import jax
import jax.numpy as jnp
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import optax
import haiku as hk
import plotly.graph_objects as go
import numpy as np
import chex 

from functools import partial
from jax import random, vmap, lax, tree_map
from chex import dataclass
from jax_tqdm import loop_tqdm
from typing import Tuple, List, Dict

sys.path.append("../../")
from jym import (
    Breakout,
    DQN_PER,
    per_rollout,
    SumTree,
    Experience,
    PrioritizedExperienceReplay,
)

  from .autonotebook import tqdm as notebook_tqdm


In [55]:
@chex.dataclass
class EnvState:
    ball_x: int
    last_x: int
    ball_y: int
    last_y: int
    ball_dir: int
    pos: int
    brick_map: jnp.ndarray
    strike: bool
    time: int
    done: bool

@chex.dataclass
class EnvState_:
    ball_x: int
    last_x: int
    ball_y: int
    last_y: int
    ball_dir: int
    pos: int
    brick_map: jnp.ndarray
    strike: bool
    time: int
    terminal: bool

key = random.PRNGKey(0)
env = Breakout()
state, obs, env_key = env.reset(key)


expected = {"x": 9, "y": 3, "ball_dir": 3}
expected_state = EnvState(
    ball_x=expected["x"],
    last_x=expected["x"],
    ball_y=expected["y"],
    last_y=expected["y"],
    ball_dir=expected["ball_dir"],
    pos=4,
    brick_map=jnp.zeros((10, 10)).at[1:4, :].set(1),
    strike=False,
    time=0,
    done=False,
)

for i in state:
    print(i, jnp.all(state[i] == expected_state[i]))


ball_x True
last_x True
ball_y True
last_y True
ball_dir True
pos True
brick_map True
strike True
time True
done True




In [3]:
env._get_obs(state).sum(axis=-1)



Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 3.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], dtype=float32)

def play_n_steps(n: int):
    key = random.PRNGKey(0)
    env = Breakout()
    state, obs, env_key = env.reset(key)

    for step in range(n):
        print(f"Game state, step {step}:")
        print(obs.sum(axis=-1))
        action = int(input("Select an action:"))
        state, next_state, reward, done, env_key = env.step(state, env_key, action)
    return state, obs

state, obs = play_n_steps(15)

In [112]:
expected = {
    "start_coord": (0, 6),
    "ball_dir": 3,
    "next_ball_dir": 2,
    "next_coord": (1, 7),
}
env = Breakout()
state = EnvState_(
    ball_x=expected["start_coord"][0],
    last_x=expected["start_coord"][0],
    ball_y=expected["start_coord"][1],
    last_y=expected["start_coord"][1],
    ball_dir=expected["ball_dir"],
    pos=9,
    brick_map=jnp.zeros((10, 10)).at[1:4, :].set(1),
    strike=False,
    time=0,
    terminal=False,
)

state_ = EnvState(
    ball_x=expected["start_coord"][0],
    last_x=expected["start_coord"][0],
    ball_y=expected["start_coord"][1],
    last_y=expected["start_coord"][1],
    ball_dir=expected["ball_dir"],
    pos=4,
    brick_map=jnp.zeros((10, 10)).at[1:4, :].set(1),
    strike=False,
    time=0,
    done=False,
)
# state, new_x, new_y = env._agent_step(state, action=0)
# print(state.ball_x, state.ball_y, new_x, new_y)
# state, reward = env._step_ball_brick(state, new_x, new_y)
# print(state.ball_x, state.ball_y)
# state, new_x, new_y = env._agent_step(state, action=0)
# print(state.ball_x, state.ball_y, new_x, new_y)
# state, reward = env._step_ball_brick(state, new_x, new_y)
# print(state.ball_x, state.ball_y)

In [104]:
def step_agent_(state: EnvState, action: int) -> Tuple[EnvState, int, int]:
    """Helper that steps the agent and checks boundary conditions."""
    # Update player position
    pos = (
        # Action left & border condition
        jnp.maximum(0, state.pos - 1) * (action == 1)
        # Action right & border condition
        + jnp.minimum(9, state.pos + 1) * (action == 3)
        # Don't move player if not l/r chosen
        + state.pos * jnp.logical_and(action != 1, action != 3)
    )

    # Update ball position - based on direction of movement
    last_x = state.ball_x
    last_y = state.ball_y
    new_x = (
        (state.ball_x - 1) * (state.ball_dir == 0)
        + (state.ball_x + 1) * (state.ball_dir == 1)
        + (state.ball_x + 1) * (state.ball_dir == 2)
        + (state.ball_x - 1) * (state.ball_dir == 3)
    )
    new_y = (
        (state.ball_y - 1) * (state.ball_dir == 0)
        + (state.ball_y - 1) * (state.ball_dir == 1)
        + (state.ball_y + 1) * (state.ball_dir == 2)
        + (state.ball_y + 1) * (state.ball_dir == 3)
    )

    # Boundary conditions for x position
    border_cond_x = jnp.logical_or(new_x < 0, new_x > 9)
    new_x = jax.lax.select(
        border_cond_x, (0 * (new_x < 0) + 9 * (new_x > 9)), new_x
    )
    # Reflect ball direction if bounced off at x border
    ball_dir = jax.lax.select(
        border_cond_x, jnp.array([1, 0, 3, 2])[state.ball_dir], state.ball_dir
    )
    return (
        state.replace(
            pos=pos,
            last_x=last_x,
            last_y=last_y,
            ball_dir=ball_dir,
        ),
        new_x,
        new_y,
    )


def step_ball_brick_(
    state: EnvState, new_x: int, new_y: int
) -> Tuple[EnvState, float]:
    """Helper that computes reward and termination cond. from brickmap."""
    reward = 0

    # Reflect ball direction if bounced off at y border
    border_cond1_y = new_y < 0
    new_y = lax.select(border_cond1_y, 0, new_y)
    ball_dir = lax.select(
        border_cond1_y, jnp.array([3, 2, 1, 0])[state.ball_dir], state.ball_dir
    )

    # 1st NASTY ELIF BEGINS HERE... = Brick collision
    strike_toggle = jnp.logical_and(
        1 - border_cond1_y, state.brick_map[new_y, new_x] == 1
    )
    strike_bool = jnp.logical_and((1 - state.strike), strike_toggle)
    reward += strike_bool * 1.0
    strike = jax.lax.select(strike_toggle, strike_bool, False)

    brick_map = jax.lax.select(
        strike_bool, state.brick_map.at[new_y, new_x].set(0), state.brick_map
    )
    new_y = jax.lax.select(strike_bool, state.last_y, new_y)
    ball_dir = jax.lax.select(
        strike_bool, jnp.array([3, 2, 1, 0])[ball_dir], ball_dir
    )

    # 2nd NASTY ELIF BEGINS HERE... = Wall collision
    brick_cond = jnp.logical_and(1 - strike_toggle, new_y == 9)

    # Spawn new bricks if there are no more around - everything is collected
    spawn_bricks = jnp.logical_and(
        brick_cond, jnp.count_nonzero(brick_map) == 0
    )
    brick_map = jax.lax.select(
        spawn_bricks, brick_map.at[1:4, :].set(1), brick_map
    )

    # Redirect ball because it collided with old player position
    redirect_ball1 = jnp.logical_and(brick_cond, state.ball_x == state.pos)
    ball_dir = jax.lax.select(
        redirect_ball1, jnp.array([3, 2, 1, 0])[ball_dir], ball_dir
    )
    new_y = jax.lax.select(redirect_ball1, state.last_y, new_y)

    # Redirect ball because it collided with new player position
    redirect_ball2a = jnp.logical_and(brick_cond, 1 - redirect_ball1)
    redirect_ball2 = jnp.logical_and(redirect_ball2a, new_x == state.pos)
    ball_dir = jax.lax.select(
        redirect_ball2, jnp.array([2, 3, 0, 1])[ball_dir], ball_dir
    )
    new_y = jax.lax.select(redirect_ball2, state.last_y, new_y)
    redirect_cond = jnp.logical_and(1 - redirect_ball1, 1 - redirect_ball2)
    terminal = jnp.logical_and(brick_cond, redirect_cond)

    strike = jax.lax.select(1 - strike_toggle == 1, False, True)
    return (
        state.replace(
            ball_dir=ball_dir,
            brick_map=brick_map,
            strike=strike,
            ball_x=new_x,
            ball_y=new_y,
            terminal=terminal,
        ),
        reward,
    )

for _ in range(20):
    state, new_x_, new_y_ = step_agent_(state, 0)
    print(state.ball_x, state.ball_y)
    state, reward = step_ball_brick_(state, new_x_, new_y_)
    print(state.ball_x, state.ball_y)

0 6
0 7
0 7
1 8
1 8
2 9
2 9
3 10
3 10
4 11
4 11
5 12
5 12
6 13
6 13
7 14
7 14
8 15
8 15
9 16
9 16
9 17
9 17
8 18
8 18
7 19
7 19
6 20
6 20
5 21
5 21
4 22
4 22
3 23
3 23
2 24
2 24
1 25
1 25
0 26


In [113]:
for _ in range(20):
    state_, new_x_, new_y_ = env._agent_step(state_, 0)
    state_, reward = env._step_ball_brick(state_, new_x_, new_y_)
    print(state_.ball_x, state_.ball_y, state_.done)


0 7 False
1 8 False
2 9 True
3 9 True
4 9 False
3 8 False
2 7 False
1 6 False
0 5 False
0 4 False
1 4 False
2 5 False
3 6 False
4 7 False
5 8 False
6 9 True
7 9 True
8 9 True
9 9 True
9 9 True


In [11]:
def _dont_move(pos):
    return pos


def _move_left(pos):
    return jnp.clip(pos - 1, 0, 9)


def _move_right(pos):
    return jnp.clip(pos + 1, 0, 9)


# Update agent's position based on the action
pos = lax.switch(
    0,
    [_dont_move, _move_left, _move_right],
    operand=state.pos,
)
last_x = state.ball_x
last_y = state.ball_y

# Update ball position based on its direction
new_x = lax.cond(
    # if the ball is moving right
    jnp.isin(state.ball_dir, jnp.array([1, 2])),
    lambda x: x + 1,
    lambda x: x - 1,
    operand=state.ball_x,
)
new_y = lax.cond(
    # if the ball is moving up
    jnp.isin(state.ball_dir, jnp.array([2, 3])),
    lambda y: y + 1,
    lambda y: y - 1,
    operand=state.ball_y,
)

# Check if the ball's new position is within the grid boundaries
border_cond_x = jnp.logical_or(new_x < 0, new_x > 9)
new_x = jnp.clip(new_x, 0, 9)  # Ensure ball stays within the grid

# Update ball's direction if it hits horizontal boundaries
ball_dir = lax.select(
    border_cond_x,
    jnp.array([1, 0, 3, 2])[state.ball_dir],
    state.ball_dir,
)

state, new_x, new_y = (
    state.replace(
        pos=pos,
        last_x=last_x,
        last_y=last_y,
        ball_dir=ball_dir,
    ),
    new_x,
    new_y,
)
state, new_x, new_y

(EnvState(ball_x=0, last_x=0, ball_y=6, last_y=6, ball_dir=Array(2, dtype=int32), pos=Array(4, dtype=int32, weak_type=True), brick_map=Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), strike=False, time=0, done=False),
 Array(0, dtype=int32, weak_type=True),
 Array(7, dtype=int32, weak_type=True))

In [7]:
assert

SyntaxError: invalid syntax (2389114725.py, line 1)

In [58]:
env = Breakout()
expected = {
    "start_coord": (0, 6),
    "ball_dir": 3,
    "next_ball_dir": 2,
    "next_coord": (1, 7),
}
state = EnvState(
    ball_x=expected["start_coord"][0],
    last_x=expected["start_coord"][0],
    ball_y=expected["start_coord"][1],
    last_y=expected["start_coord"][1],
    ball_dir=expected["ball_dir"],
    pos=4,
    brick_map=jnp.zeros((10, 10)).at[1:4, :].set(1),
    strike=False,
    time=0,
    done=False,
)
# new_x, new_y = expected["next_coord"]
state, new_x, new_y = env._agent_step(state, action=0)
new_x, new_y


(Array(0, dtype=int32, weak_type=True), Array(7, dtype=int32, weak_type=True))

In [None]:
reward = 0

# Reflect the ball's direction if it hits the top border
border_cond_y = new_y < 0
print(new_y)
new_y = jnp.clip(new_y, 0, 9)  # Ensure new_y remains within the grid

print(border_cond_y)
print(new_y)

7
False
7


In [None]:

# Check for collision with a brick
strike_toggle = jnp.logical_and(
    jnp.invert(border_cond_y), state.brick_map[new_y, new_x] == 1
)
strike_bool = jnp.logical_and(jnp.invert(state.strike), strike_toggle)
reward += jnp.float32(strike_bool)  # Increment reward if a brick is struck

print(strike_toggle, strike_bool, reward)

False False 0.0


In [None]:

# Remove the brick on collision
brick_map = lax.select(
    strike_bool, state.brick_map.at[new_y, new_x].set(0), state.brick_map
)

# Update ball position and direction post-collision with brick
new_y = lax.select(strike_bool, state.last_y, new_y)
ball_dir = lax.select(
    strike_bool, jnp.array([3, 2, 1, 0])[state.ball_dir], state.ball_dir
)
print(new_y, ball_dir)

7 2


In [None]:

# Check for ball at the bottom row but not colliding with a brick
brick_cond = jnp.logical_and(jnp.invert(strike_toggle), new_y == 9)

# Spawn new bricks if all are cleared
spawn_bricks = jnp.logical_and(brick_cond, jnp.count_nonzero(brick_map) == 0)
brick_map = lax.select(spawn_bricks, brick_map.at[1:4, :].set(1), brick_map)
print(brick_cond, spawn_bricks)

False False


In [None]:

# Handle ball collision with paddle's old position
redirect_ball_old = jnp.logical_and(brick_cond, state.ball_x == state.pos)
ball_dir = lax.select(redirect_ball_old, jnp.array([3, 2, 1, 0])[ball_dir], ball_dir)
new_y = lax.select(redirect_ball_old, state.last_y, new_y)
print(redirect_ball_old, ball_dir, new_y)

False 2 7


In [None]:
# Handle ball collision with paddle's new position
collision_new_pos = jnp.logical_and(brick_cond, jnp.invert(redirect_ball_old))
redirect_ball_new = jnp.logical_and(collision_new_pos, new_x == state.pos)
ball_dir = lax.select(
    redirect_ball_new, jnp.array([2, 3, 0, 1])[ball_dir], ball_dir
)
new_y = lax.select(redirect_ball_new, state.last_y, new_y)
print(collision_new_pos, redirect_ball_new, ball_dir, new_y)

False False 2 7


In [None]:
state.pos

Array(4, dtype=int32, weak_type=True)

In [None]:

# Check if ball missed the paddle
not_redirected = jnp.logical_and(
    jnp.invert(redirect_ball_old), jnp.invert(redirect_ball_new)
)

# The game ends if the ball is on the bottom row and is not being redirected
done = jnp.logical_and(brick_cond, not_redirected)

# Update the strike state
strike = jnp.bool_(strike_toggle)

print(done)

new_state, reward = (
    state.replace(
        ball_dir=ball_dir,
        brick_map=brick_map,
        strike=strike,
        ball_x=new_x,
        ball_y=new_y,
        done=done,
    ),
    reward,
)

assert new_state.done == jnp.array(False, dtype=jnp.bool_)
assert reward == 0.0

False


In [None]:
new_state.ball_x

Array(0, dtype=int32, weak_type=True)