In [1]:
import sys
import os

# Add the path to sys.path for the current Python session
new_path = "/home/ubuntu/tensorflow_test/control/real-timeRL/realtime-atari-jax"

# Add to sys.path if not already there
if new_path not in sys.path:
    sys.path.insert(0, new_path)

# Also set PYTHONPATH for any subprocesses
os.environ["PYTHONPATH"] = f"{new_path}:{os.environ.get('PYTHONPATH', '')}"

# Verify it worked
print("Python path updated:")
print(f"sys.path includes: {new_path}")
print(f"PYTHONPATH env var: {os.environ['PYTHONPATH']}")

Python path updated:
sys.path includes: /home/ubuntu/tensorflow_test/control/real-timeRL/realtime-atari-jax
PYTHONPATH env var: /home/ubuntu/tensorflow_test/control/real-timeRL/realtime-atari-jax:


In [2]:
import sys
import math

import jax.numpy as jnp


def get_sizes(state):
    try:
        size = len(state.current_player)
        width = math.ceil(math.sqrt(size - 0.1))
        if size - (width - 1) ** 2 >= width:
            height = width
        else:
            height = width - 1
    except TypeError:
        size = 1
        width = 1
        height = 1
    return size, width, height


def get_cmap(n_channels):
    # import seaborn as sns  # type: ignore
    # return cmap = sns.color_palette("cubehelix", n_channels)
    assert n_channels in (4, 6, 7, 10)
    if n_channels == 4:
        return [(0.08605633600581405, 0.23824692404212, 0.30561236308077167), (0.32927729263408284, 0.4762845556584382, 0.1837155549758328), (0.8146245329198283, 0.49548316572322215, 0.5752525936416857), (0.7587183008012618, 0.7922069335474338, 0.9543861221913403)]
    elif n_channels == 6:
        return [(0.10231025194333628, 0.13952898866828906, 0.2560120319409181), (0.10594361078604106, 0.3809739011595331, 0.27015111282899046), (0.4106130272672762, 0.48044780541672255, 0.1891154277778484), (0.7829183382530567, 0.48158303462490826, 0.48672451968362596), (0.8046168329276406, 0.6365733569301846, 0.8796578402926125), (0.7775608374378459, 0.8840392521212448, 0.9452007992345052)]
    elif n_channels == 7:
        return [(0.10419418740482515, 0.11632019220053316, 0.2327552016195138), (0.08523511613408935, 0.32661779003565533, 0.2973201282529313), (0.26538761550634205, 0.4675654910052002, 0.1908220644759285), (0.6328422475018423, 0.4747981096220677, 0.29070209208025455), (0.8306875710682655, 0.5175161303658079, 0.6628221028832032), (0.7779565181455343, 0.7069421942599752, 0.9314406084043191), (0.7964528047840354, 0.908668973545918, 0.9398253500983916)]
    elif n_channels == 10:
        return [(0.09854228363950114, 0.07115215572295082, 0.16957891809124037), (0.09159726558869188, 0.20394337960213008, 0.29623965888210324), (0.09406611799930162, 0.3578871412608098, 0.2837709711722866), (0.23627685553553793, 0.46114369021199075, 0.19770731888985724), (0.49498740849493095, 0.4799034869159042, 0.21147789468974837), (0.7354526513473981, 0.4748861903571046, 0.40254094042448907), (0.8325928529853291, 0.5253446757844744, 0.6869376931865354), (0.7936920632275369, 0.6641337211433709, 0.9042311843062529), (0.7588424692372241, 0.8253990353420474, 0.9542699331220588), (0.8385645211683802, 0.9411869386771845, 0.9357655639413166)]


# /home/ubuntu/tensorflow_test/control/real-timeRL/realtime-atari-jax/pgx/minatar/utils.py

def visualize_minatar(state, savefile=None, fmt="svg", dpi=160):
    # Modified from https://github.com/kenjyoung/MinAtar
    try:
        import matplotlib.colors as colors  # type: ignore
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError:
        sys.stderr.write("MinAtar environment requires matplotlib for visualization. Please install matplotlib.")
        sys.exit(1)

    obs = state.observation
    n_channels = obs.shape[-1]
    cmap = get_cmap(n_channels)
    cmap.insert(0, (0, 0, 0))
    cmap = colors.ListedColormap(cmap)
    bounds = [i for i in range(n_channels + 2)]
    norm = colors.BoundaryNorm(bounds, n_channels + 1)
    size, w, h = get_sizes(state)
    fig, ax = plt.subplots(h, w)
    n_channels = obs.shape[-1]
    if size == 1:
        numerical_state = (
            jnp.amax(
                obs * jnp.reshape(jnp.arange(n_channels) + 1, (1, 1, -1)), 2
            )
            + 0.5
        )
        ax.imshow(numerical_state, cmap=cmap, norm=norm, interpolation="none")
        ax.set_axis_off()
    else:
        for j in range(size):
            numerical_state = (
                jnp.amax(
                    obs[j] * jnp.reshape(jnp.arange(n_channels) + 1, (1, 1, -1)),
                    2,
                )
                + 0.5
            )
            if h == 1:
                ax[j].imshow(numerical_state, cmap=cmap, norm=norm, interpolation="none")
                ax[j].set_axis_off()
            else:
                ax[j // w, j % w].imshow(numerical_state, cmap=cmap, norm=norm, interpolation="none")
                ax[j // w, j % w].set_axis_off()

    if savefile is None:
        # Return in-memory image
        if fmt == "svg":
            from io import StringIO
            sio = StringIO()
            plt.savefig(sio, format="svg", bbox_inches="tight")
            plt.close(fig)
            return sio.getvalue()  # str (SVG markup)
        else:
            from io import BytesIO
            bio = BytesIO()
            plt.savefig(bio, format=fmt, bbox_inches="tight", dpi=dpi)
            plt.close(fig)
            bio.seek(0)
            return bio.getvalue()  # bytes (e.g., PNG)
    else:
        plt.savefig(savefile, format=fmt, bbox_inches="tight", dpi=(None if fmt == "svg" else dpi))
        plt.close(fig)
        return savefile


In [7]:
"""
MinAtar/Freeway with JIT-compatible K-frame skipping.

Changes vs. baseline:
- Add `frame_skip: int = 2` to MinAtarFreeway.__init__.
- Implement frame skipping inside `_step` via `jax.lax.fori_loop`.
- Compute one sticky-processed effective action per external step and repeat it
  for `frame_skip` internal steps, accumulating rewards. Each micro-step samples
  new car speeds/directions (as in the original logic).
"""

from typing import Literal, Optional

import jax
from jax import numpy as jnp

import pgx.core as core
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey

player_speed = jnp.array(3, dtype=jnp.int32)
time_limit = jnp.array(2500, dtype=jnp.int32)

FALSE = jnp.bool_(False)
TRUE = jnp.bool_(True)
ZERO = jnp.array(0, dtype=jnp.int32)
ONE = jnp.array(1, dtype=jnp.int32)
NINE = jnp.array(9, dtype=jnp.int32)


@dataclass
class State(core.State):
    current_player: Array = jnp.int32(0)
    observation: Array = jnp.zeros((10, 10, 7), dtype=jnp.bool_)
    rewards: Array = jnp.zeros(1, dtype=jnp.float32)  # (1,)
    terminated: Array = FALSE
    truncated: Array = FALSE
    legal_action_mask: Array = jnp.ones(3, dtype=jnp.bool_)
    _step_count: Array = jnp.int32(0)
    # --- MinAtar Freeway specific ---
    _cars: Array = jnp.zeros((8, 4), dtype=jnp.int32)
    _pos: Array = jnp.array(9, dtype=jnp.int32)
    _move_timer: Array = jnp.array(player_speed, dtype=jnp.int32)
    _terminate_timer: Array = jnp.array(time_limit, dtype=jnp.int32)
    _terminal: Array = jnp.array(False, dtype=jnp.bool_)
    _last_action: Array = jnp.array(0, dtype=jnp.int32)

    @property
    def env_id(self) -> core.EnvId:
        return "minatar-freeway"

    def to_svg(
        self,
        *,
        color_theme: Optional[Literal["light", "dark"]] = None,
        scale: Optional[float] = None,
    ) -> str:
        del color_theme, scale
        from .utils import visualize_minatar

        return visualize_minatar(self)

    def save_svg(
        self,
        filename,
        *,
        color_theme: Optional[Literal["light", "dark"]] = None,
        scale: Optional[float] = None,
    ) -> None:
        from .utils import visualize_minatar

        visualize_minatar(self, filename)


class MinAtarFreeway(core.Env):
    def __init__(
        self,
        *,
        use_minimal_action_set: bool = True,
        sticky_action_prob: float = 0.1,
        frame_skip: int = 2,  # NEW: K-frame skipping (default 2)
    ):
        super().__init__()
        assert frame_skip >= 1, "frame_skip must be >= 1"
        self.use_minimal_action_set = use_minimal_action_set
        self.sticky_action_prob: float = float(sticky_action_prob)
        self.frame_skip: int = int(frame_skip)

        self.minimal_action_set = jnp.int32([0, 2, 4])
        self.legal_action_mask = jnp.ones(6, dtype=jnp.bool_)
        if self.use_minimal_action_set:
            self.legal_action_mask = jnp.ones(
                self.minimal_action_set.shape[0], dtype=jnp.bool_
            )

    def step(
        self, state: core.State, action: Array, key: Optional[Array] = None
    ) -> core.State:
        assert key is not None, (
            "v2.0.0 changes the signature of step. Please specify PRNGKey at the third argument:\n\n"
            "  * <  v2.0.0: step(state, action)\n"
            "  * >= v2.0.0: step(state, action, key)\n\n"
            "See v2.0.0 release note for more details:\n\n"
            "  https://github.com/sotetsuk/pgx/releases/tag/v2.0.0"
        )
        return super().step(state, action, key)

    def _init(self, key: PRNGKey) -> State:
        state = _init(rng=key)  # type: ignore
        state = state.replace(legal_action_mask=self.legal_action_mask)  # type: ignore
        return state  # type: ignore

    def _step(self, state: core.State, action, key) -> State:
        assert isinstance(state, State)
        # Keep mask current
        state = state.replace(legal_action_mask=self.legal_action_mask)  # type: ignore

        # Map minimal action set if enabled
        action = jax.lax.select(
            jnp.bool_(self.use_minimal_action_set),
            self.minimal_action_set[action],
            action,
        )
        action = jnp.int32(action)

        # Sticky override ONCE per external step; then repeat effective action
        key_sticky, key_loop = jax.random.split(key, 2)
        effective_action = jax.lax.cond(
            jax.random.uniform(key_sticky) < self.sticky_action_prob,
            lambda: state._last_action,
            lambda: action,
        )
        effective_action = jnp.int32(effective_action)

        # fori_loop carry: (State, total_reward(float32), done(bool), rng)
        def body_fn(i, carry):
            s, rsum, done, rng = carry

            def do_step(args):
                s_inner, rsum_inner, _done_inner, rng_inner = args
                rng_inner, sub = jax.random.split(rng_inner)
                speeds, directions = _random_speed_directions(sub)
                s_next = _step_det(state=s_inner, action=effective_action, speeds=speeds, directions=directions)
                rsum_next = rsum_inner + s_next.rewards[0]
                done_next = s_next.terminated
                return (s_next, rsum_next, done_next, rng_inner)

            return jax.lax.cond(done, lambda x: x, do_step, (s, rsum, done, rng))

        r0 = jnp.array(0.0, dtype=jnp.float32)
        d0 = FALSE
        state, total_r, _, _ = jax.lax.fori_loop(
            0, int(self.frame_skip), body_fn, (state, r0, d0, key_loop)
        )

        # Overwrite reward with accumulated total for this external step
        state = state.replace(rewards=total_r[jnp.newaxis])  # type: ignore
        return state  # type: ignore

    def _observe(self, state: core.State, player_id: Array) -> Array:
        assert isinstance(state, State)
        return _observe(state)

    @property
    def id(self) -> core.EnvId:
        return "minatar-freeway"

    @property
    def version(self) -> str:
        return "v1"

    @property
    def num_players(self):
        return 1


def _step(
    state: State,
    action: Array,
    key,
    sticky_action_prob,
):
    # (Kept for API completeness; not used when class-level frame_skip is active)
    action = jnp.int32(action)
    key0, key1 = jax.random.split(key, 2)
    action = jax.lax.cond(
        jax.random.uniform(key0) < sticky_action_prob,
        lambda: state._last_action,
        lambda: action,
    )
    speeds, directions = _random_speed_directions(key1)
    return _step_det(state, action, speeds=speeds, directions=directions)


def _init(rng: Array) -> State:
    speeds, directions = _random_speed_directions(rng)
    return _init_det(speeds=speeds, directions=directions)


def _step_det(
    state: State,
    action: Array,
    speeds: Array,
    directions: Array,
):
    cars = state._cars
    pos = state._pos
    move_timer = state._move_timer
    terminate_timer = state._terminate_timer
    terminal = state._terminal
    last_action = action

    r = jnp.array(0, dtype=jnp.float32)

    move_timer, pos = jax.lax.cond(
        (action == 2) & (move_timer == 0),
        lambda: (player_speed, jax.lax.max(ZERO, pos - ONE)),
        lambda: (move_timer, pos),
    )
    move_timer, pos = jax.lax.cond(
        (action == 4) & (move_timer == 0),
        lambda: (player_speed, jax.lax.min(NINE, pos + ONE)),
        lambda: (move_timer, pos),
    )

    # Win condition
    cars, r, pos = jax.lax.cond(
        pos == 0,
        lambda: (
            _randomize_cars(speeds, directions, cars, initialize=False),
            r + 1,
            NINE,
        ),
        lambda: (cars, r, pos),
    )

    pos, cars = _update_cars(pos, cars)

    # Update various timers
    move_timer = jax.lax.cond(
        move_timer > 0, lambda: move_timer - 1, lambda: move_timer
    )
    terminate_timer -= ONE
    terminal = terminate_timer < 0

    next_state = state.replace(  # type: ignore
        _cars=cars,
        _pos=pos,
        _move_timer=move_timer,
        _terminate_timer=terminate_timer,
        _terminal=terminal,
        _last_action=last_action,
        rewards=r[jnp.newaxis],
        terminated=terminal,
    )

    return next_state


def _update_cars(pos, cars):
    def _update_stopped_car(pos, car):
        car = car.at[2].set(jax.lax.abs(car[3]))
        car = jax.lax.cond(
            car[3] > 0, lambda: car.at[0].add(1), lambda: car.at[0].add(-1)
        )
        car = jax.lax.cond(car[0] < 0, lambda: car.at[0].set(9), lambda: car)
        car = jax.lax.cond(car[0] > 9, lambda: car.at[0].set(0), lambda: car)
        pos = jax.lax.cond(
            (car[0] == 4) & (car[1] == pos), lambda: NINE, lambda: pos
        )
        return pos, car

    def _update_car(pos, car):
        pos = jax.lax.cond(
            (car[0] == 4) & (car[1] == pos), lambda: NINE, lambda: pos
        )
        pos, car = jax.lax.cond(
            car[2] == 0,
            lambda: _update_stopped_car(pos, car),
            lambda: (pos, car.at[2].add(-1)),
        )
        return pos, car

    pos, cars = jax.lax.scan(_update_car, pos, cars)

    return pos, cars


def _init_det(speeds: Array, directions: Array) -> State:
    cars = _randomize_cars(speeds, directions, initialize=True)
    return State(_cars=cars)  # type: ignore


def _randomize_cars(
    speeds: Array,
    directions: Array,
    cars: Array = jnp.zeros((8, 4), dtype=int),
    initialize: bool = False,
) -> Array:
    speeds *= directions

    def _init(_cars):
        _cars = _cars.at[:, 1].set(jnp.arange(1, 9))
        _cars = _cars.at[:, 2].set(jax.lax.abs(speeds))
        _cars = _cars.at[:, 3].set(speeds)
        return _cars

    def _update(_cars):
        _cars = _cars.at[:, 2].set(abs(speeds))
        _cars = _cars.at[:, 3].set(speeds)
        return _cars

    return jax.lax.cond(initialize, _init, _update, cars)


def _random_speed_directions(rng):
    rng1, rng2 = jax.random.split(rng, 2)
    speeds = jax.random.randint(rng1, [8], 1, 6, dtype=jnp.int32)
    directions = jax.random.choice(
        rng2, jnp.array([-1, 1], dtype=jnp.int32), [8]
    )
    return speeds, directions


def _observe(state: State) -> Array:
    obs = jnp.zeros((10, 10, 7), dtype=jnp.bool_)
    obs = obs.at[state._pos, 4, 0].set(TRUE)

    def _update_obs(i, _obs):
        car = state._cars[i]
        _obs = _obs.at[car[1], car[0], 1].set(TRUE)
        back_x = jax.lax.cond(
            car[3] > 0, lambda: car[0] - 1, lambda: car[0] + 1
        )
        back_x = jax.lax.cond(back_x < 0, lambda: NINE, lambda: back_x)
        back_x = jax.lax.cond(back_x > 9, lambda: ZERO, lambda: back_x)
        trail = jax.lax.abs(car[3]) + 1
        _obs = _obs.at[car[1], back_x, trail].set(TRUE)
        return _obs

    obs = jax.lax.fori_loop(0, 8, _update_obs, obs)
    return obs


In [9]:
"""
MinAtar/Breakout with JIT-compatible K-frame skipping.

Changes vs. baseline:
- Add `frame_skip: int = 2` to MinAtarBreakout.__init__.
- Implement frame skipping inside `_step` via `jax.lax.fori_loop`.
- Compute a single effective action (after sticky override) and repeat it
  for `frame_skip` internal steps, accumulating rewards. Observation comes
  from the final internal step, and `terminated` reflects any terminal reached
  during the repeated steps.
"""

from typing import Literal, Optional

import jax
from jax import numpy as jnp

import pgx.core as core
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey

FALSE = jnp.bool_(False)
TRUE = jnp.bool_(True)
ZERO = jnp.array(0, dtype=jnp.int32)
ONE = jnp.array(1, dtype=jnp.int32)
TWO = jnp.array(2, dtype=jnp.int32)
THREE = jnp.array(3, dtype=jnp.int32)
FOUR = jnp.array(4, dtype=jnp.int32)
NINE = jnp.array(9, dtype=jnp.int32)


@dataclass
class State(core.State):
    current_player: Array = jnp.int32(0)
    observation: Array = jnp.zeros((10, 10, 4), dtype=jnp.bool_)
    rewards: Array = jnp.zeros(1, dtype=jnp.float32)  # (1,)
    terminated: Array = FALSE
    truncated: Array = FALSE
    legal_action_mask: Array = jnp.ones(3, dtype=jnp.bool_)
    _step_count: Array = jnp.int32(0)
    # --- MinAtar Breakout specific ---
    _ball_y: Array = THREE
    _ball_x: Array = ZERO
    _ball_dir: Array = TWO
    _pos: Array = FOUR
    _brick_map: Array = (
        jnp.zeros((10, 10), dtype=jnp.bool_).at[1:4, :].set(True)
    )
    _strike: Array = jnp.array(False, dtype=jnp.bool_)
    _last_x: Array = ZERO
    _last_y: Array = THREE
    _terminal: Array = jnp.array(False, dtype=jnp.bool_)
    _last_action: Array = ZERO

    @property
    def env_id(self) -> core.EnvId:
        return "minatar-breakout"

    def to_svg(
        self,
        *,
        color_theme: Optional[Literal["light", "dark"]] = None,
        scale: Optional[float] = None,
    ) -> str:
        del color_theme, scale
        from .utils import visualize_minatar

        return visualize_minatar(self)

    def save_svg(
        self,
        filename,
        *,
        color_theme: Optional[Literal["light", "dark"]] = None,
        scale: Optional[float] = None,
    ) -> None:
        from .utils import visualize_minatar

        visualize_minatar(self, filename)


class MinAtarBreakout(core.Env):
    def __init__(
        self,
        *,
        use_minimal_action_set: bool = True,
        sticky_action_prob: float = 0.1,
        frame_skip: int = 2,  # NEW: K-frame skipping (default 2)
    ):
        super().__init__()
        assert frame_skip >= 1, "frame_skip must be >= 1"
        self.use_minimal_action_set = use_minimal_action_set
        self.sticky_action_prob: float = float(sticky_action_prob)
        self.frame_skip: int = int(frame_skip)

        # Minimal action set mapping (NOOP/LEFT/RIGHT for Breakout)
        self.minimal_action_set = jnp.int32([0, 1, 3])

        # Legal mask is either 6 (full) or 3 (minimal)
        self.legal_action_mask = jnp.ones(6, dtype=jnp.bool_)
        if self.use_minimal_action_set:
            self.legal_action_mask = jnp.ones(
                self.minimal_action_set.shape[0], dtype=jnp.bool_
            )

    def step(
        self, state: core.State, action: Array, key: Optional[Array] = None
    ) -> core.State:
        assert key is not None, (
            "v2.0.0 changes the signature of step. Please specify PRNGKey at the third argument:\n\n"
            "  * <  v2.0.0: step(state, action)\n"
            "  * >= v2.0.0: step(state, action, key)\n\n"
            "See v2.0.0 release note for more details:\n\n"
            "  https://github.com/sotetsuk/pgx/releases/tag/v2.0.0"
        )
        return super().step(state, action, key)

    def _init(self, key: PRNGKey) -> State:
        state = _init(rng=key)  # type: ignore
        state = state.replace(legal_action_mask=self.legal_action_mask)  # type: ignore
        return state  # type: ignore

    def _step(self, state: core.State, action, key) -> State:
        """One external env step = repeat the (sticky-processed) action for `frame_skip` internal steps."""
        # Ensure the state carries the current legal mask
        state = state.replace(legal_action_mask=self.legal_action_mask)  # type: ignore

        # Minimal action set mapping (JAX-friendly select)
        action = jax.lax.select(
            jnp.bool_(self.use_minimal_action_set),
            self.minimal_action_set[action],
            action,
        )

        # Compute effective action once (sticky override) for this macro-step
        # Then repeat that effective action for `frame_skip` internal steps.
        effective_action = jax.lax.cond(
            jax.random.uniform(key) < self.sticky_action_prob,
            lambda: jnp.int32(state._last_action),
            lambda: jnp.int32(action),
        )

        # fori_loop carry: (State, total_reward(float32), done(bool))
        def body_fn(i, carry):
            s, rsum, done = carry

            def do_step(args):
                s_inner, rsum_inner, _ = args
                s_next = _step_det(s_inner, effective_action)
                rsum_next = rsum_inner + s_next.rewards[0]
                done_next = s_next.terminated
                return (s_next, rsum_next, done_next)

            # If already done, keep state as-is (no-op), preserving JIT compatibility
            return jax.lax.cond(done, lambda x: x, do_step, (s, rsum, done))

        r0 = jnp.array(0.0, dtype=jnp.float32)
        d0 = FALSE
        state, total_r, _ = jax.lax.fori_loop(0, int(self.frame_skip), body_fn, (state, r0, d0))

        # Overwrite rewards with the accumulated total for this external step
        state = state.replace(rewards=total_r[jnp.newaxis])  # type: ignore
        return state  # type: ignore

    def _observe(self, state: core.State, player_id: Array) -> Array:
        assert isinstance(state, State)
        return _observe(state)

    @property
    def id(self) -> core.EnvId:
        return "minatar-breakout"

    @property
    def version(self) -> str:
        return "v1"

    @property
    def num_players(self):
        return 1


def _step_det(state: State, action: Array):
    ball_y = state._ball_y
    ball_x = state._ball_x
    ball_dir = state._ball_dir
    pos = state._pos
    brick_map = state._brick_map
    strike = state._strike
    terminal = state._terminal

    r = jnp.array(0, dtype=jnp.float32)

    pos = _apply_action(pos, action)

    # Update ball position
    last_x = ball_x
    last_y = ball_y
    new_x, new_y = _update_ball_pos(ball_x, ball_y, ball_dir)

    new_x, ball_dir = jax.lax.cond(
        (new_x < 0) | (new_x > 9),
        lambda: _update_ball_pos_x(new_x, ball_dir),
        lambda: (new_x, ball_dir),
    )

    is_new_y_negative = new_y < 0
    is_strike = brick_map[new_y, new_x] == 1
    is_bottom = new_y == 9
    new_y, ball_dir = jax.lax.cond(
        is_new_y_negative,
        lambda: _update_ball_pos_y(ball_dir),
        lambda: (new_y, ball_dir),
    )
    strike_toggle = ~is_new_y_negative & is_strike
    r, strike, brick_map, new_y, ball_dir = jax.lax.cond(
        ~is_new_y_negative & is_strike & ~strike,
        lambda: _update_by_strike(
            r, brick_map, new_x, new_y, last_y, ball_dir, strike
        ),
        lambda: (r, strike, brick_map, new_y, ball_dir),
    )
    brick_map, new_y, ball_dir, terminal = jax.lax.cond(
        ~is_new_y_negative & ~is_strike & is_bottom,
        lambda: _update_by_bottom(
            brick_map, ball_x, new_x, new_y, pos, ball_dir, last_y, terminal
        ),
        lambda: (brick_map, new_y, ball_dir, terminal),
    )

    strike = jax.lax.cond(
        ~strike_toggle, lambda: jnp.zeros_like(strike), lambda: strike
    )

    state = state.replace(  # type: ignore
        _ball_y=new_y,
        _ball_x=new_x,
        _ball_dir=ball_dir,
        _pos=pos,
        _brick_map=brick_map,
        _strike=strike,
        _last_x=last_x,
        _last_y=last_y,
        _terminal=terminal,
        _last_action=action,
        rewards=r[jnp.newaxis],
        terminated=terminal,
    )
    return state


def _init(rng: Array) -> State:
    ball_start = jax.random.choice(rng, 2)
    return _init_det(ball_start=ball_start)


def _apply_action(pos, action):
    pos = jax.lax.cond(
        action == 1, lambda: jax.lax.max(ZERO, pos - ONE), lambda: pos
    )
    pos = jax.lax.cond(
        action == 3, lambda: jax.lax.min(NINE, pos + ONE), lambda: pos
    )
    return pos


def _update_ball_pos(ball_x, ball_y, ball_dir):
    return jax.lax.switch(
        ball_dir,
        [
            lambda: (ball_x - ONE, ball_y - ONE),
            lambda: (ball_x + ONE, ball_y - ONE),
            lambda: (ball_x + ONE, ball_y + ONE),
            lambda: (ball_x - ONE, ball_y + ONE),
        ],
    )


def _update_ball_pos_x(new_x, ball_dir):
    new_x = jax.lax.max(ZERO, new_x)
    new_x = jax.lax.min(NINE, new_x)
    ball_dir = jnp.array([1, 0, 3, 2], dtype=jnp.int32)[ball_dir]
    return new_x, ball_dir


def _update_ball_pos_y(ball_dir):
    ball_dir = jnp.array([3, 2, 1, 0], dtype=jnp.int32)[ball_dir]
    return ZERO, ball_dir


def _update_by_strike(r, brick_map, new_x, new_y, last_y, ball_dir, strike):
    brick_map = brick_map.at[new_y, new_x].set(False)
    new_y = last_y
    ball_dir = jnp.array([3, 2, 1, 0], dtype=jnp.int32)[ball_dir]
    return r + 1, jnp.ones_like(strike), brick_map, new_y, ball_dir


def _update_by_bottom(
    brick_map, ball_x, new_x, new_y, pos, ball_dir, last_y, terminal
):
    brick_map = jax.lax.cond(
        brick_map.sum() == 0,
        lambda: brick_map.at[1:4, :].set(True),
        lambda: brick_map,
    )
    new_y, ball_dir, terminal = jax.lax.cond(
        ball_x == pos,
        lambda: (
            last_y,
            jnp.array([3, 2, 1, 0], dtype=jnp.int32)[ball_dir],
            terminal,
        ),
        lambda: jax.lax.cond(
            new_x == pos,
            lambda: (
                last_y,
                jnp.array([2, 3, 0, 1], dtype=jnp.int32)[ball_dir],
                terminal,
            ),
            lambda: (new_y, ball_dir, jnp.array(True, dtype=jnp.bool_)),
        ),
    )
    return brick_map, new_y, ball_dir, terminal


def _init_det(ball_start: Array) -> State:
    ball_x, ball_dir = jax.lax.switch(
        ball_start,
        [lambda: (ZERO, TWO), lambda: (NINE, THREE)],
    )
    last_x = ball_x
    return State(
        _ball_x=ball_x, _ball_dir=ball_dir, _last_x=last_x
    )  # type: ignore


def _observe(state: State) -> Array:
    obs = jnp.zeros((10, 10, 4), dtype=jnp.bool_)
    obs = obs.at[state._ball_y, state._ball_x, 1].set(True)
    obs = obs.at[9, state._pos, 0].set(True)
    obs = obs.at[state._last_y, state._last_x, 2].set(True)
    obs = obs.at[:, :, 3].set(state._brick_map)
    return obs


In [None]:
# -----------------------------
# Config
# -----------------------------
from pydantic import BaseModel
import pgx
class PPOConfig(BaseModel):
    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-breakout"
    seed: int = 0
    lr: float = 0.0003
    num_envs: int = 4096
    num_eval_envs: int = 100
    num_steps: int = 128
    plan_horizon: int = 4
    total_timesteps: int = 20_000_000
    frame_skip: int = 1
    update_epochs: int = 3
    minibatch_size: int = 4096
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_eps: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    wandb_project: str = "pgx-minatar-ppo"
    save_model: bool = True
    

    class Config:
        extra = "forbid"


# In Jupyter, directly create the config instead of parsing CLI args
# You can override any default values here
args = PPOConfig(
    env_name="minatar-freeway",  # Change this to test different games
    num_envs=4096,  # Smaller for testing in notebook
    total_timesteps=200000000,  # Shorter for testing
    frame_skip = 1,
    save_model=True,  # Don't save in notebook by default
)
print(f"Config: {args}")

env = pgx.make(str(args.env_name))
if args.env_name == "minatar-freeway":
    print("using custom env")
    env = MinAtarFreeway(
        use_minimal_action_set=True,
        sticky_action_prob=0.1,
        frame_skip=args.frame_skip,
    )
if args.env_name == "minatar-breakout":
    env = MinAtarBreakout(
        use_minimal_action_set=True,
        sticky_action_prob=0.1,
        frame_skip=args.frame_skip,
    )

num_updates = args.total_timesteps // args.num_envs // args.num_steps
num_minibatches = args.num_envs * args.num_steps // args.minibatch_size

Config: env_name='minatar-freeway' seed=0 lr=0.0003 num_envs=4096 num_eval_envs=100 num_steps=128 plan_horizon=4 total_timesteps=200000000 frame_skip=1 update_epochs=3 minibatch_size=4096 gamma=0.99 gae_lambda=0.95 clip_eps=0.2 ent_coef=0.01 vf_coef=0.5 max_grad_norm=0.5 wandb_project='pgx-minatar-ppo' save_model=True
using custom env


In [14]:
# === PPO with Expanded Discrete Macro-Actions (repeat 1..N of primitive) ===
import os, io, math, time, pickle, sys
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp
import optax
from flax import nnx
import wandb

import pgx
from pgx.experimental import auto_reset
from pydantic import BaseModel
from breakout_frame_skip import MinAtarBreakout
from freeway_frame_skip import MinAtarFreeway

# ---------------------------------------
# Simple categorical wrapper
# ---------------------------------------
class Categorical:
    def __init__(self, logits):
        self.logits = logits  # [..., A]

    def sample(self, seed):
        return jax.random.categorical(seed, self.logits, axis=-1)

    def log_prob(self, value):
        log_probs = jax.nn.log_softmax(self.logits, axis=-1)
        return jnp.take_along_axis(log_probs, value[..., None], axis=-1).squeeze(-1)

    def entropy(self):
        log_probs = jax.nn.log_softmax(self.logits, axis=-1)
        probs = jax.nn.softmax(self.logits, axis=-1)
        return -(probs * log_probs).sum(axis=-1)

# ---------------------------------------
# Utility
# ---------------------------------------
def pool_out_dim(n: int, window: int = 2, stride: int = 2, padding: str = "VALID") -> int:
    if padding.upper() == "VALID":
        return (n - window) // stride + 1
    return math.ceil(n / stride)

def _tree_where_batch(mask_b, new, old):
    """Select per-batch: where(mask_b, new, old) broadcasting mask along trailing dims."""
    def sel(n, o):
        m = mask_b.reshape((mask_b.shape[0],) + (1,) * (n.ndim - 1))
        return jnp.where(m, n, o)
    return jax.tree.map(sel, new, old)

# ---------------------------------------
# Actor-Critic (single-step) with expanded macro-action head
# ---------------------------------------
class ActorCritic(nnx.Module):
    def __init__(
        self,
        base_num_actions: int,   # primitive actions (e.g., 3 for MinAtar minimal set)
        obs_shape,
        activation: str = "tanh",
        *,
        rngs: nnx.Rngs,
        plan_horizon: int = 4,   # N (repeat choices 1..N)
    ):
        assert activation in ["relu", "tanh"]
        self.base_num_actions = int(base_num_actions)
        self.plan_horizon = int(plan_horizon)
        self.macro_num_actions = self.base_num_actions * self.plan_horizon
        self.activation = activation

        H, W, C = obs_shape
        # shared torso
        self.conv = nnx.Conv(in_features=C, out_features=32, kernel_size=(2, 2), rngs=rngs)
        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2), padding="VALID")
        H2 = pool_out_dim(H, 2, 2, "VALID")
        W2 = pool_out_dim(W, 2, 2, "VALID")
        flatten_dim = H2 * W2 * 32
        self.fc = nnx.Linear(flatten_dim, 64, rngs=rngs)

        # actor (macro head)
        self.actor_h1 = nnx.Linear(64, 64, rngs=rngs)
        self.actor_h2 = nnx.Linear(64, 64, rngs=rngs)
        self.actor_out = nnx.Linear(64, self.macro_num_actions, rngs=rngs)

        # critic
        self.critic_h1 = nnx.Linear(64, 64, rngs=rngs)
        self.critic_h2 = nnx.Linear(64, 64, rngs=rngs)
        self.critic_out = nnx.Linear(64, 1, rngs=rngs)

    def _act(self, x): 
        return nnx.relu(x) if self.activation == "relu" else nnx.tanh(x)

    def _torso(self, x):
        x = x.astype(jnp.float32)
        x = self.conv(x)
        x = nnx.relu(x)
        x = self.avg_pool(x)
        x = x.reshape((x.shape[0], -1))
        x = nnx.relu(self.fc(x))
        return x  # [B,64]

    def __call__(self, x):
        h = self._torso(x)
        a = self._act(self.actor_h1(h)); a = self._act(self.actor_h2(a))
        logits = self.actor_out(a)  # [B, base*A_horizon]
        v = self._act(self.critic_h1(h)); v = self._act(self.critic_h2(v))
        value = self.critic_out(v)  # [B,1]
        return logits, jnp.squeeze(value, axis=-1)

    # decode macro id -> (primitive in [0..base-1], repeat in [1..N])
    def decode_macro(self, macro_id: jnp.ndarray):
        rep_idx, prim = jnp.divmod(macro_id, self.base_num_actions)  # both [B]
        repeat = rep_idx + 1
        return prim.astype(jnp.int32), repeat.astype(jnp.int32)

# ---------------------------------------
# Transition container (macro as a single PPO step)
# ---------------------------------------
class Transition(NamedTuple):
    done: jnp.ndarray       # [B]
    action: jnp.ndarray     # [B]  (macro id)
    value: jnp.ndarray      # [B]
    reward: jnp.ndarray     # [B]  (sum with inner gamma^i)
    log_prob: jnp.ndarray   # [B]
    obs: jnp.ndarray        # [B,H,W,C]
    gamma_p: jnp.ndarray    # [B]  (gamma ** repeat)

# ---------------------------------------
# Env setup (use custom env if present)
# ---------------------------------------
ppo_args = args  # assumes provided externally
if ppo_args.env_name == "minatar-breakout" and "MinAtarBreakout" in globals():
    print("using custom breakout env")
    env = MinAtarBreakout(
        use_minimal_action_set=True,
        sticky_action_prob=0.1,
        frame_skip=ppo_args.frame_skip,
    )
elif ppo_args.env_name == "minatar-freeway" and "MinAtarFreeway" in globals():
    print("using custom freeway env")
    env = MinAtarFreeway(
        use_minimal_action_set=True,
        sticky_action_prob=0.1,
        frame_skip=ppo_args.frame_skip,
    )
else:
    print("using default env")
    env = pgx.make(str(ppo_args.env_name))

# ---------------------------------------
# Optimizer
# ---------------------------------------
tx = optax.chain(
    optax.clip_by_global_norm(ppo_args.max_grad_norm),
    optax.adam(ppo_args.lr, eps=1e-5),
)

# ---------------------------------------
# Update step (standard PPO, but each step executes a macro)
# ---------------------------------------
def make_update_step():
    step_fn = jax.vmap(auto_reset(env.step, env.init))
    base_A = int(env.num_actions)
    N = int(ppo_args.plan_horizon)
    gamma = jnp.float32(ppo_args.gamma)

    def apply_macro(env_state, primitive: jnp.ndarray, repeats: jnp.ndarray, rng):
        """
        Execute up to N external steps; only step items where i < repeats & not done_any.
        Accumulate discounted reward inside the macro:
          R = r_0 + gamma*r_1 + ... + gamma^(k-1)*r_{k-1}
        Return final state, R, done_any, and gamma^k for bootstrapping.
        """
        B = env_state.observation.shape[0]
        R = jnp.zeros((B,), dtype=jnp.float32)
        done_any = jnp.zeros((B,), dtype=jnp.bool_)

        def body(i, carry):
            state, R, done_any, rng = carry
            active = jnp.logical_and(i < repeats, jnp.logical_not(done_any))  # [B]
            rng, sub = jax.random.split(rng)
            keys = jax.random.split(sub, B)

            # step everyone, then select per-batch whether to accept this step
            state_next = step_fn(state, primitive, keys)
            r_i = jnp.squeeze(state_next.rewards, -1)  # [B]
            term_i = state_next.terminated             # [B]

            # accumulate reward only for active
            R = R + (gamma ** i) * active.astype(jnp.float32) * r_i

            # update done_any
            done_any = jnp.logical_or(done_any, jnp.logical_and(active, term_i))

            # keep or discard per batch
            state = _tree_where_batch(active, state_next, state)
            return (state, R, done_any, rng)

        state, R, done_any, rng = jax.lax.fori_loop(0, N, body, (env_state, R, done_any, rng))
        gamma_p = gamma ** repeats.astype(jnp.float32)  # [B]
        return state, R, done_any, gamma_p, rng

    @nnx.jit(donate_argnames=("model", "optimizer"))
    def _update_step(model: nnx.Module,
                     optimizer: nnx.Optimizer,
                     env_state,
                     last_obs,
                     rng):
        # -------- Collect trajectories (num_steps macros) --------
        def _env_step(runner_state, _):
            model, optimizer, env_state, last_obs, rng = runner_state

            # policy
            rng, _rng = jax.random.split(rng)
            logits, value = model(last_obs)                       # logits over macro actions
            pi = Categorical(logits=logits)
            macro_id = pi.sample(seed=_rng)                       # [B]
            log_prob = pi.log_prob(macro_id)                      # [B]

            # decode macro -> primitive + repeats
            primitive, repeats = model.decode_macro(macro_id)     # both [B]

            # apply macro to env
            rng, _rng = jax.random.split(rng)
            env_state, R_chunk, done_any, gamma_p, _ = apply_macro(env_state, primitive, repeats, _rng)

            transition = Transition(
                done_any,               # done
                macro_id,               # action (macro id)
                value,                  # value at obs0
                R_chunk,                # discounted inside macro
                log_prob,               # old logp(macro)
                last_obs,               # obs0
                gamma_p,                # gamma ** repeats
            )
            runner_state = (model, optimizer, env_state, env_state.observation, rng)
            return runner_state, transition

        runner_state = (model, optimizer, env_state, last_obs, rng)
        runner_state, traj_batch = nnx.scan(
            _env_step, in_axes=(nnx.Carry, None), out_axes=(nnx.Carry, 0), length=ppo_args.num_steps
        )(runner_state, None)

        # -------- Advantage / targets (GAE with per-step gamma_p) --------
        model, optimizer, env_state, last_obs, rng = runner_state
        _, last_val = model(last_obs)  # [B]

        def _gae(carry, tr: Transition):
            gae, next_value = carry
            delta = tr.reward + tr.gamma_p * next_value * (1 - tr.done) - tr.value
            gae = delta + tr.gamma_p * ppo_args.gae_lambda * (1 - tr.done) * gae
            return (gae, tr.value), gae

        (_, _), advantages = jax.lax.scan(
            _gae, (jnp.zeros_like(last_val), last_val), traj_batch, reverse=True, unroll=16
        )
        targets = advantages + traj_batch.value  # [T,B]

        # -------- SGD epochs (standard PPO) --------
        def _update_epoch(update_state, _):
            model, optimizer, traj_batch, advantages, targets, rng = update_state

            def _update_minibatch(state, minibatch):
                model, optimizer = state
                mb_traj, mb_adv, mb_targets = minibatch

                def _loss_fn(model: nnx.Module, traj: Transition, gae, targets):
                    # re-run policy on obs0
                    logits, value = model(traj.obs)            # logits over macro actions
                    pi = Categorical(logits=logits)
                    new_log_prob = pi.log_prob(traj.action)

                    # value loss (clipped) at obs0
                    value_pred_clipped = traj.value + (value - traj.value).clip(-ppo_args.clip_eps, ppo_args.clip_eps)
                    v_loss_unclipped = jnp.square(value - targets)
                    v_loss_clipped = jnp.square(value_pred_clipped - targets)
                    value_loss = 0.5 * jnp.maximum(v_loss_unclipped, v_loss_clipped).mean()

                    # policy loss (clipped)
                    ratio = jnp.exp(new_log_prob - traj.log_prob)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = jnp.clip(ratio, 1.0 - ppo_args.clip_eps, 1.0 + ppo_args.clip_eps) * gae
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean()

                    # entropy bonus (over macro action space)
                    entropy = pi.entropy().mean()

                    total = loss_actor + ppo_args.vf_coef * value_loss - ppo_args.ent_coef * entropy
                    return total, (value_loss, loss_actor, entropy)

                (total_loss, aux), grads = nnx.value_and_grad(
                    _loss_fn, has_aux=True, argnums=nnx.DiffState(0, nnx.Param)
                )(model, mb_traj, mb_adv, mb_targets)

                optimizer.update(grads=grads)
                return (model, optimizer), (total_loss, aux)

            # flatten (T,B) -> (T*B)
            rng, _rng = jax.random.split(rng)
            batch_size = ppo_args.minibatch_size * num_minibatches
            assert batch_size == ppo_args.num_steps * ppo_args.num_envs, \
                "batch size must equal (num_steps * num_envs)"

            batch = (traj_batch, advantages, targets)
            batch = jax.tree.map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
            perm = jax.random.permutation(_rng, batch_size)
            shuffled = jax.tree.map(lambda x: jnp.take(x, perm, axis=0), batch)
            minibatches = jax.tree.map(
                lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])),
                shuffled,
            )

            (model, optimizer), losses = nnx.scan(
                _update_minibatch, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0)
            )((model, optimizer), minibatches)
            update_state = (model, optimizer, traj_batch, advantages, targets, rng)
            return update_state, losses

        update_state = (model, optimizer, traj_batch, advantages, targets, rng)
        update_state, loss_info = nnx.scan(
            _update_epoch, in_axes=(nnx.Carry, None), out_axes=(nnx.Carry, 0), length=ppo_args.update_epochs
        )(update_state, None)

        model, optimizer, _, _, _, rng = update_state
        runner_state = (model, optimizer, env_state, last_obs, rng)
        return runner_state, loss_info

    return _update_step

# ---------------------------------------
# Evaluation (sample macro, execute repeats)
# ---------------------------------------
@nnx.jit
def evaluate_macro(model: nnx.Module, rng_key):
    step_fn = jax.vmap(env.step)
    rng_key, sub_key = jax.random.split(rng_key)
    subkeys = jax.random.split(sub_key, ppo_args.num_eval_envs)
    state = jax.vmap(env.init)(subkeys)
    R = jnp.zeros_like(state.rewards)  # [B,1]

    def cond_fn(tup):
        state, *_ = tup
        return ~state.terminated.all()

    def loop_macro(tup):
        state, R, rng_key = tup
        logits, _ = model(state.observation)
        pi = Categorical(logits=logits)
        rng_key, sub = jax.random.split(rng_key)
        macro_id = pi.sample(sub)
        # decode macro
        base_A = env.num_actions
        rep_idx, prim = jnp.divmod(macro_id, base_A)
        repeats = rep_idx + 1

        # execute repeats (up to N) per batch
        def body(i, carry):
            s, Racc, rng = carry
            active = i < repeats
            rng, sub2 = jax.random.split(rng)
            keys = jax.random.split(sub2, s.observation.shape[0])
            s_next = step_fn(s, prim, keys)
            # select per-batch
            s = _tree_where_batch(active, s_next, s)
            Racc = Racc + jnp.where(active[:, None], s_next.rewards, 0.0)
            return (s, Racc, rng)

        state, R, rng_key = jax.lax.fori_loop(0, ppo_args.plan_horizon, body, (state, R, rng_key))
        return state, R, rng_key

    state, R, _ = jax.lax.while_loop(cond_fn, loop_macro, (state, R, rng_key))
    return R.mean()

# ---------------------------------------
# Training loop (unchanged structure, macro steps collected)
# ---------------------------------------
def build_model(env, rng):
    obs_shape = env.observation_shape
    return ActorCritic(
        env.num_actions,
        obs_shape=obs_shape,
        activation="tanh",
        rngs=nnx.Rngs(rng),
        plan_horizon=ppo_args.plan_horizon,
    )

def train(rng):
    # model + optimizer
    rng, _rng = jax.random.split(rng)
    model = build_model(env, _rng)
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

    # update fn
    update_step = make_update_step()

    # init envs
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, ppo_args.num_envs)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)
    last_obs = env_state.observation
    rng, _rng = jax.random.split(rng)
    runner_state = (model, optimizer, env_state, last_obs, _rng)

    # warmup compile
    _, _ = update_step(*runner_state)

    # initial eval
    steps = 0
    rng, _rng = jax.random.split(rng)
    eval_R = evaluate_macro(runner_state[0], _rng)
    log = {"sec": 0.0, f"{ppo_args.env_name}/eval_R": float(eval_R), "steps": steps}
    print(log)
    if wandb.run is not None: wandb.log(log)
    st = time.time(); tt = 0.0

    for _ in range(num_updates):
        runner_state, loss_info = update_step(*runner_state)
        model, optimizer, env_state, last_obs, rng = runner_state
        # rough env-steps accounting: assume avg repeat ~ plan_horizon
        steps += ppo_args.num_envs * ppo_args.num_steps * ppo_args.plan_horizon

        et = time.time(); tt += et - st
        rng, _rng = jax.random.split(rng)
        eval_R = evaluate_macro(model, _rng)
        log = {"sec": tt, f"{ppo_args.env_name}/eval_R": float(eval_R), "steps": steps}
        print(log)
        if wandb.run is not None: wandb.log(log)
        st = time.time()

    return runner_state

# ---------------------------------------
# Bookkeeping, updates, and run
# ---------------------------------------
# If you want total_timesteps to be counted in ENV steps, this is a rough conversion
num_updates = ppo_args.total_timesteps // (ppo_args.num_envs * ppo_args.num_steps * ppo_args.plan_horizon)
num_minibatches = (ppo_args.num_envs * ppo_args.num_steps) // ppo_args.minibatch_size

wandb.init(
    project=f"ppo-{ppo_args.env_name}-frameskip",
    name=f"{ppo_args.env_name}-frameskip{ppo_args.frame_skip}-N{ppo_args.plan_horizon}-macroflat",
    config=ppo_args.dict(),
    mode="disabled",  # set "online" to enable logging
)

print("Starting training (PPO with expanded macro actions)...")
rng = jax.random.PRNGKey(ppo_args.seed)
runner_state = train(rng)

# save
model_dir = f"./minatar-ppo-models/{ppo_args.env_name}/"
os.makedirs(model_dir, exist_ok=True)

if ppo_args.save_model:
    model = runner_state[0]
    ckpt = f"{model_dir}/{ppo_args.env_name}-frameskip={ppo_args.frame_skip}-N={ppo_args.plan_horizon}-macroflat.ckpt"
    with open(ckpt, "wb") as f:
        pickle.dump(nnx.state(model, nnx.Param), f)
    print(f"Model saved to {ckpt}")

wandb.finish()


using custom freeway env
Starting training (PPO with expanded macro actions)...


2025-11-01 20:19:04.428212: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1157] Compiling 64 configs for 13 fusions on a single thread.
2025-11-01 20:19:25.000216: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-11-01 20:19:25.000253: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-11-01 20:19:25.000276: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1157] Compiling 77 configs for 4 fusions on a single thread.


{'sec': 0.0, 'minatar-freeway/eval_R': 1.7400000095367432, 'steps': 0}
{'sec': 0.7022140026092529, 'minatar-freeway/eval_R': 5.029999732971191, 'steps': 2097152}
{'sec': 1.4007649421691895, 'minatar-freeway/eval_R': 9.239999771118164, 'steps': 4194304}
{'sec': 2.0995163917541504, 'minatar-freeway/eval_R': 14.979999542236328, 'steps': 6291456}
{'sec': 2.7979683876037598, 'minatar-freeway/eval_R': 18.44999885559082, 'steps': 8388608}
{'sec': 3.4968459606170654, 'minatar-freeway/eval_R': 22.85999870300293, 'steps': 10485760}
{'sec': 4.195537805557251, 'minatar-freeway/eval_R': 26.489999771118164, 'steps': 12582912}
{'sec': 4.893589735031128, 'minatar-freeway/eval_R': 28.44999885559082, 'steps': 14680064}
{'sec': 5.591802597045898, 'minatar-freeway/eval_R': 30.889999389648438, 'steps': 16777216}
{'sec': 6.289898872375488, 'minatar-freeway/eval_R': 33.119998931884766, 'steps': 18874368}
{'sec': 6.9886181354522705, 'minatar-freeway/eval_R': 34.45000076293945, 'steps': 20971520}
{'sec': 7.686