In [1]:
import sys
import time
import math
import shutil
import pickle
from functools import partial
from typing import NamedTuple, Literal

import jax
import jax.numpy as jnp
import optax
from flax import nnx
from omegaconf import OmegaConf
from pydantic import BaseModel
import wandb


import sys
import os

# Add the path to sys.path for the current Python session
new_path = "/home/ubuntu/tensorflow_test/control/real-timeRL/realtime-atari-jax"

# Add to sys.path if not already there
if new_path not in sys.path:
    sys.path.insert(0, new_path)

# Also set PYTHONPATH for any subprocesses
os.environ["PYTHONPATH"] = f"{new_path}:{os.environ.get('PYTHONPATH', '')}"

# Verify it worked
print("Python path updated:")
print(f"sys.path includes: {new_path}")
print(f"PYTHONPATH env var: {os.environ['PYTHONPATH']}")

import pgx
from pgx.experimental import auto_reset


# -----------------------------
# Simple Categorical distribution wrapper using JAX built-ins
# -----------------------------
class Categorical:
    def __init__(self, logits):
        self.logits = logits

    def sample(self, seed):
        return jax.random.categorical(seed, self.logits)

    def log_prob(self, value):
        log_probs = jax.nn.log_softmax(self.logits)
        return jnp.take_along_axis(log_probs, value[..., None], axis=-1).squeeze(-1)

    def entropy(self):
        log_probs = jax.nn.log_softmax(self.logits)
        probs = jax.nn.softmax(self.logits)
        return -(probs * log_probs).sum(axis=-1)


# -----------------------------
# Config
# -----------------------------
class PPOConfig(BaseModel):
    env_name: Literal[
        "minatar-breakout",
        "minatar-freeway",
        "minatar-space_invaders",
        "minatar-asterix",
        "minatar-seaquest",
    ] = "minatar-space_invaders"
    seed: int = 0
    lr: float = 0.0003
    num_envs: int = 4096
    num_eval_envs: int = 100
    num_steps: int = 128
    total_timesteps: int = 20_000_000
    update_epochs: int = 3
    minibatch_size: int = 4096
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_eps: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    wandb_project: str = "pgx-minatar-ppo"
    save_model: bool = True
    out_models_dir: str = "/home/ubuntu/tensorflow_test/control/real-timeRL/realtime-atari-jax/examples/minatar-ppo/space_models"

    class Config:
        extra = "forbid"

Python path updated:
sys.path includes: /home/ubuntu/tensorflow_test/control/real-timeRL/realtime-atari-jax
PYTHONPATH env var: /home/ubuntu/tensorflow_test/control/real-timeRL/realtime-atari-jax:


In [6]:
"""MinAtar/SpaceInvaders: A fork of github.com/kenjyoung/MinAtar

The authors of original MinAtar implementation are:
    * Kenny Young (kjyoung@ualberta.ca)
    * Tian Tian (ttian@ualberta.ca)
The original MinAtar implementation is distributed under GNU General Public License v3.0
    * https://github.com/kenjyoung/MinAtar/blob/master/License.txt
"""
from typing import Literal, Optional

import jax
import jax.lax as lax
from jax import numpy as jnp

import pgx.core as core
from pgx._src.struct import dataclass
from pgx._src.types import Array, PRNGKey

FALSE = jnp.bool_(False)
TRUE = jnp.bool_(True)

SHOT_COOL_DOWN = jnp.int32(5)
ENEMY_MOVE_INTERVAL = jnp.int32(12)
ENEMY_SHOT_INTERVAL = jnp.int32(10)

ZERO = jnp.int32(0)
NINE = jnp.int32(9)


@dataclass
class State(core.State):
    current_player: Array = jnp.int32(0)
    observation: Array = jnp.zeros((10, 10, 6), dtype=jnp.bool_)
    rewards: Array = jnp.zeros(1, dtype=jnp.float32)  # (1,)
    terminated: Array = FALSE
    truncated: Array = FALSE
    legal_action_mask: Array = jnp.ones(4, dtype=jnp.bool_)
    _step_count: Array = jnp.int32(0)
    # --- MinAtar SpaceInvaders specific ---
    _pos: Array = jnp.int32(5)
    _f_bullet_map: Array = jnp.zeros((10, 10), dtype=jnp.bool_)
    _e_bullet_map: Array = jnp.zeros((10, 10), dtype=jnp.bool_)
    _alien_map: Array = (
        jnp.zeros((10, 10), dtype=jnp.bool_).at[0:4, 2:8].set(TRUE)
    )
    _alien_dir: Array = jnp.int32(-1)
    _enemy_move_interval: Array = ENEMY_MOVE_INTERVAL
    _alien_move_timer: Array = ENEMY_MOVE_INTERVAL
    _alien_shot_timer: Array = ENEMY_SHOT_INTERVAL
    _ramp_index: Array = jnp.int32(0)
    _shot_timer: Array = jnp.int32(0)
    _terminal: Array = FALSE
    _last_action: Array = jnp.int32(0)

    @property
    def env_id(self) -> core.EnvId:
        return "minatar-space_invaders"

    def to_svg(
        self,
        *,
        color_theme: Optional[Literal["light", "dark"]] = None,
        scale: Optional[float] = None,
    ) -> str:
        del color_theme, scale
        from .utils import visualize_minatar

        return visualize_minatar(self)

    def save_svg(
        self,
        filename,
        *,
        color_theme: Optional[Literal["light", "dark"]] = None,
        scale: Optional[float] = None,
    ) -> None:
        from .utils import visualize_minatar

        visualize_minatar(self, filename)


class MinAtarSpaceInvaders(core.Env):
    def __init__(
        self,
        *,
        use_minimal_action_set: bool = True,
        sticky_action_prob: float = 0.1,
    ):
        super().__init__()
        self.use_minimal_action_set = use_minimal_action_set
        self.sticky_action_prob: float = sticky_action_prob
        self.minimal_action_set = jnp.int32([0, 1, 3, 5])
        self.legal_action_mask = jnp.ones(6, dtype=jnp.bool_)
        if self.use_minimal_action_set:
            self.legal_action_mask = jnp.ones(
                self.minimal_action_set.shape[0], dtype=jnp.bool_
            )

    def step(
        self, state: core.State, action: Array, key: Optional[Array] = None
    ) -> core.State:
        assert key is not None, (
            "v2.0.0 changes the signature of step. Please specify PRNGKey at the third argument:\n\n"
            "  * <  v2.0.0: step(state, action)\n"
            "  * >= v2.0.0: step(state, action, key)\n\n"
            "See v2.0.0 release note for more details:\n\n"
            "  https://github.com/sotetsuk/pgx/releases/tag/v2.0.0"
        )
        return super().step(state, action, key)

    def _init(self, key: PRNGKey) -> State:
        state = State()
        state = state.replace(legal_action_mask=self.legal_action_mask)  # type: ignore
        return state  # type: ignore

    def _step(self, state: core.State, action, key) -> State:
        state = state.replace(legal_action_mask=self.legal_action_mask)  # type: ignore
        action = jax.lax.select(
            self.use_minimal_action_set,
            self.minimal_action_set[action],
            action,
        )
        return _step(state, action, key, self.sticky_action_prob)  # type: ignore

    def _observe(self, state: core.State, player_id: Array) -> Array:
        assert isinstance(state, State)
        return _observe(state)

    @property
    def id(self) -> core.EnvId:
        return "minatar-space_invaders"

    @property
    def version(self) -> str:
        return "v1"

    @property
    def num_players(self):
        return 1


def _step(
    state: State,
    action: Array,
    key,
    sticky_action_prob,
):
    action = jnp.int32(action)
    action = jax.lax.cond(
        jax.random.uniform(key) < sticky_action_prob,
        lambda: state._last_action,
        lambda: action,
    )
    return _step_det(state, action)


def _observe(state: State) -> Array:
    obs = jnp.zeros((10, 10, 6), dtype=jnp.bool_)
    obs = obs.at[9, state._pos, 0].set(TRUE)
    obs = obs.at[:, :, 1].set(state._alien_map)
    obs = obs.at[:, :, 2].set(
        lax.cond(
            state._alien_dir < 0,
            lambda: state._alien_map,
            lambda: jnp.zeros_like(state._alien_map),
        )
    )
    obs = obs.at[:, :, 3].set(
        lax.cond(
            state._alien_dir < 0,
            lambda: jnp.zeros_like(state._alien_map),
            lambda: state._alien_map,
        )
    )
    obs = obs.at[:, :, 4].set(state._f_bullet_map)
    obs = obs.at[:, :, 5].set(state._e_bullet_map)
    return obs


def _step_det(
    state: State,
    action: Array,
):
    r = jnp.float32(0)

    pos = state._pos
    f_bullet_map = state._f_bullet_map
    e_bullet_map = state._e_bullet_map
    alien_map = state._alien_map
    alien_dir = state._alien_dir
    enemy_move_interval = state._enemy_move_interval
    alien_move_timer = state._alien_move_timer
    alien_shot_timer = state._alien_shot_timer
    ramp_index = state._ramp_index
    shot_timer = state._shot_timer
    terminal = state._terminal

    # Resolve player action
    # action_map = ['n','l','u','r','d','f']
    pos, f_bullet_map, shot_timer = _resole_action(
        pos, f_bullet_map, shot_timer, action
    )

    # Update Friendly Bullets
    f_bullet_map = jnp.roll(f_bullet_map, -1, axis=0)
    f_bullet_map = f_bullet_map.at[9, :].set(FALSE)

    # Update Enemy Bullets
    e_bullet_map = jnp.roll(e_bullet_map, 1, axis=0)
    e_bullet_map = e_bullet_map.at[0, :].set(FALSE)
    terminal = lax.cond(e_bullet_map[9, pos], lambda: TRUE, lambda: terminal)

    # Update aliens
    terminal = lax.cond(alien_map[9, pos], lambda: TRUE, lambda: terminal)
    alien_move_timer, alien_map, alien_dir, terminal = lax.cond(
        alien_move_timer == 0,
        lambda: _update_alien_by_move_timer(
            alien_map, alien_dir, enemy_move_interval, pos, terminal
        ),
        lambda: (alien_move_timer, alien_map, alien_dir, terminal),
    )
    timer_zero = alien_shot_timer == 0
    alien_shot_timer = lax.cond(
        timer_zero, lambda: ENEMY_SHOT_INTERVAL, lambda: alien_shot_timer
    )
    e_bullet_map = lax.cond(
        timer_zero,
        lambda: e_bullet_map.at[_nearest_alien(pos, alien_map)].set(TRUE),
        lambda: e_bullet_map,
    )

    kill_locations = alien_map & (alien_map == f_bullet_map)

    r += jnp.sum(kill_locations, dtype=jnp.float32)
    alien_map = alien_map & (~kill_locations)
    f_bullet_map = f_bullet_map & (~kill_locations)

    # Update various timers
    shot_timer -= shot_timer > 0
    alien_move_timer -= 1
    alien_shot_timer -= 1
    ramping = True
    is_enemy_zero = jnp.count_nonzero(alien_map) == 0
    enemy_move_interval, ramp_index = lax.cond(
        is_enemy_zero & (enemy_move_interval > 6) & ramping,
        lambda: (enemy_move_interval - 1, ramp_index + 1),
        lambda: (enemy_move_interval, ramp_index),
    )
    alien_map = lax.cond(
        is_enemy_zero,
        lambda: alien_map.at[0:4, 2:8].set(TRUE),
        lambda: alien_map,
    )

    return state.replace(  # type: ignore
        _pos=pos,
        _f_bullet_map=f_bullet_map,
        _e_bullet_map=e_bullet_map,
        _alien_map=alien_map,
        _alien_dir=alien_dir,
        _enemy_move_interval=enemy_move_interval,
        _alien_move_timer=alien_move_timer,
        _alien_shot_timer=alien_shot_timer,
        _ramp_index=ramp_index,
        _shot_timer=shot_timer,
        _terminal=terminal,
        _last_action=action,
        rewards=r[jnp.newaxis],
        terminated=terminal,
    )


def _resole_action(pos, f_bullet_map, shot_timer, action):
    f_bullet_map = lax.cond(
        (action == 5) & (shot_timer == 0),
        lambda: f_bullet_map.at[9, pos].set(TRUE),
        lambda: f_bullet_map,
    )
    shot_timer = lax.cond(
        (action == 5) & (shot_timer == 0),
        lambda: SHOT_COOL_DOWN,
        lambda: shot_timer,
    )
    pos = lax.cond(
        action == 1, lambda: jax.lax.max(ZERO, pos - 1), lambda: pos
    )
    pos = lax.cond(
        action == 3, lambda: jax.lax.min(NINE, pos + 1), lambda: pos
    )
    return pos, f_bullet_map, shot_timer


def _nearest_alien(pos, alien_map):
    search_order = jnp.argsort(jnp.abs(jnp.arange(10, dtype=jnp.int32) - pos))
    ix = lax.while_loop(
        lambda i: jnp.sum(alien_map[:, search_order[i]]) <= 0,
        lambda i: i + 1,
        0,
    )
    ix = search_order[ix]
    j = lax.while_loop(lambda i: alien_map[i, ix] == 0, lambda i: i - 1, 9)
    return (j, ix)


def _update_alien_by_move_timer(
    alien_map, alien_dir, enemy_move_interval, pos, terminal
):
    alien_move_timer = lax.min(
        jnp.sum(alien_map, dtype=jnp.int32), enemy_move_interval
    )
    cond = ((jnp.sum(alien_map[:, 0]) > 0) & (alien_dir < 0)) | (
        (jnp.sum(alien_map[:, 9]) > 0) & (alien_dir > 0)
    )
    terminal = lax.cond(
        cond & (jnp.sum(alien_map[9, :]) > 0),
        lambda: jnp.bool_(True),
        lambda: terminal,
    )
    alien_dir = lax.cond(cond, lambda: -alien_dir, lambda: alien_dir)
    alien_map = lax.cond(
        cond,
        lambda: jnp.roll(alien_map, 1, axis=0),
        lambda: jnp.roll(alien_map, alien_dir, axis=1),
    )
    terminal = lax.cond(
        alien_map[9, pos], lambda: jnp.bool_(True), lambda: terminal
    )
    return alien_move_timer, alien_map, alien_dir, terminal


In [7]:
args = PPOConfig(
    env_name="minatar-space_invaders",)
env = MinAtarSpaceInvaders()
num_updates = args.total_timesteps // args.num_envs // args.num_steps
num_minibatches = args.num_envs * args.num_steps // args.minibatch_size

In [8]:
# -----------------------------
# NNX Actor-Critic
# -----------------------------
def pool_out_dim(n: int, window: int = 2, stride: int = 2, padding: str = "VALID") -> int:
    # Matches flax.linen/nnx pooling semantics for VALID padding
    if padding.upper() == "VALID":
        return (n - window) // stride + 1
    # Fallback (not used here)
    return math.ceil(n / stride)


class ActorCritic(nnx.Module):
    def __init__(self, num_actions: int, obs_shape, activation: str = "tanh", *, rngs: nnx.Rngs):
        assert activation in ["relu", "tanh"]
        self.num_actions = num_actions
        self.activation = activation

        H, W, C = obs_shape  # NHWC expected by flax.nnx.Conv
        # Convolution (channels-last). Default padding is 'SAME'.
        self.conv = nnx.Conv(in_features=C, out_features=32, kernel_size=(2, 2), rngs=rngs)

        # AvgPool params are fixed; keep a partial for clean callsites
        self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2), padding="VALID")

        # After conv ('SAME') + avg_pool('VALID', 2x2, stride 2) the spatial dims become:
        H2 = pool_out_dim(H, 2, 2, "VALID")
        W2 = pool_out_dim(W, 2, 2, "VALID")
        flatten_dim = H2 * W2 * 32

        # Shared torso
        self.fc = nnx.Linear(flatten_dim, 64, rngs=rngs)

        # Actor head: 64 -> 64 -> 64 -> num_actions (two hidden layers like original)
        self.actor_h1 = nnx.Linear(64, 64, rngs=rngs)
        self.actor_h2 = nnx.Linear(64, 64, rngs=rngs)
        self.actor_out = nnx.Linear(64, num_actions, rngs=rngs)

        # Critic head: 64 -> 64 -> 64 -> 1 (two hidden layers like original)
        self.critic_h1 = nnx.Linear(64, 64, rngs=rngs)
        self.critic_h2 = nnx.Linear(64, 64, rngs=rngs)
        self.critic_out = nnx.Linear(64, 1, rngs=rngs)

    def _act(self, x):
        return nnx.relu(x) if self.activation == "relu" else nnx.tanh(x)

    def __call__(self, x):
        x = x.astype(jnp.float32)
        x = self.conv(x)
        x = nnx.relu(x)
        x = self.avg_pool(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nnx.relu(self.fc(x))

        a = self._act(self.actor_h1(x))
        a = self._act(self.actor_h2(a))
        logits = self.actor_out(a)

        v = self._act(self.critic_h1(x))
        v = self._act(self.critic_h2(v))
        value = self.critic_out(v)

        return logits, jnp.squeeze(value, axis=-1)


# -----------------------------
# Optimizer (Optax via NNX wrapper)
# -----------------------------
tx = optax.chain(
    optax.clip_by_global_norm(args.max_grad_norm),
    optax.adam(args.lr, eps=1e-5),
)


# -----------------------------
# Rollout container
# -----------------------------
class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray


def save_checkpoint(model: nnx.Module, step: int) -> str:
    checkpoint_path = os.path.join(
        args.out_models_dir,
        f"{args.env_name}-seed={args.seed}-steps={step}.ckpt",
    )
    with open(checkpoint_path, "wb") as f:
        pickle.dump(nnx.state(model, nnx.Param), f)
    return checkpoint_path


# -----------------------------
# Update step (collect + optimize), jitted with NNX
# -----------------------------
def make_update_step():
    step_fn = jax.vmap(auto_reset(env.step, env.init))

    @nnx.jit(donate_argnames=("model", "optimizer"))
    def _update_step(model: nnx.Module,
                     optimizer: nnx.Optimizer,
                     env_state,
                     last_obs,
                     rng):
        # -------- Collect trajectories --------
        def _env_step(runner_state, _):
            model, optimizer, env_state, last_obs, rng = runner_state

            # Policy
            rng, _rng = jax.random.split(rng)
            logits, value = model(last_obs)
            pi = Categorical(logits=logits)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # Env step
            rng, _rng = jax.random.split(rng)
            keys = jax.random.split(_rng, env_state.observation.shape[0])
            env_state = step_fn(env_state, action, keys)

            transition = Transition(
                env_state.terminated,
                action,
                value,
                jnp.squeeze(env_state.rewards),
                log_prob,
                last_obs,
            )
            runner_state = (model, optimizer, env_state, env_state.observation, rng)
            return runner_state, transition

        runner_state = (model, optimizer, env_state, last_obs, rng)
        runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, length=args.num_steps)

        # -------- Advantage / targets (GAE) --------
        model, optimizer, env_state, last_obs, rng = runner_state
        _, last_val = model(last_obs)

        def _get_advantages(gae_and_next_value, transition):
            gae, next_value = gae_and_next_value
            done, value, reward = transition.done, transition.value, transition.reward
            delta = reward + args.gamma * next_value * (1 - done) - value
            gae = delta + args.gamma * args.gae_lambda * (1 - done) * gae
            return (gae, value), gae

        (_, _), advantages = jax.lax.scan(
            _get_advantages,
            (jnp.zeros_like(last_val), last_val),
            traj_batch,
            reverse=True,
            unroll=16,
        )
        targets = advantages + traj_batch.value

        # -------- SGD epochs --------
        def _update_epoch(update_state, _):
            model, optimizer, traj_batch, advantages, targets, rng = update_state

            def _update_minibatch(state, minibatch):
                model, optimizer = state
                mb_traj, mb_adv, mb_targets = minibatch

                def _loss_fn(model: nnx.Module, traj: Transition, gae, targets):
                    # Re-run policy
                    logits, value = model(traj.obs)
                    pi = Categorical(logits=logits)
                    log_prob = pi.log_prob(traj.action)

                    # Value loss (clipped)
                    value_pred_clipped = traj.value + (value - traj.value).clip(-args.clip_eps, args.clip_eps)
                    v_loss_unclipped = jnp.square(value - targets)
                    v_loss_clipped = jnp.square(value_pred_clipped - targets)
                    value_loss = 0.5 * jnp.maximum(v_loss_unclipped, v_loss_clipped).mean()

                    # Policy loss (clipped)
                    ratio = jnp.exp(log_prob - traj.log_prob)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = jnp.clip(ratio, 1.0 - args.clip_eps, 1.0 + args.clip_eps) * gae
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean()

                    # Entropy bonus
                    entropy = pi.entropy().mean()

                    total = loss_actor + args.vf_coef * value_loss - args.ent_coef * entropy
                    return total, (value_loss, loss_actor, entropy)

                # Compute grads w.r.t. model Params
                (total_loss, aux), grads = nnx.value_and_grad(
                    _loss_fn, has_aux=True, argnums=nnx.DiffState(0, nnx.Param)
                )(model, mb_traj, mb_adv, mb_targets)

                # Optax step via NNX Optimizer (updates model in-place)
                optimizer.update(model, grads)

                return (model, optimizer), (total_loss, aux)

            # Shuffle + minibatch
            rng, _rng = jax.random.split(rng)
            batch_size = args.minibatch_size * num_minibatches
            assert batch_size == args.num_steps * args.num_envs, "batch size must equal steps * envs"

            batch = (traj_batch, advantages, targets)
            batch = jax.tree.map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
            permutation = jax.random.permutation(_rng, batch_size)
            shuffled = jax.tree.map(lambda x: jnp.take(x, permutation, axis=0), batch)
            minibatches = jax.tree.map(
                lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])),
                shuffled,
            )

            (model, optimizer), losses = jax.lax.scan(_update_minibatch, (model, optimizer), minibatches)
            update_state = (model, optimizer, traj_batch, advantages, targets, rng)
            return update_state, losses

        update_state = (model, optimizer, traj_batch, advantages, targets, rng)
        update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, length=args.update_epochs)

        model, optimizer, _, _, _, rng = update_state
        runner_state = (model, optimizer, env_state, last_obs, rng)
        return runner_state, loss_info

    return _update_step


# -----------------------------
# Evaluation (greedy sample)
# -----------------------------
@nnx.jit
def evaluate(model: nnx.Module, rng_key):
    step_fn = jax.vmap(env.step)
    rng_key, sub_key = jax.random.split(rng_key)
    subkeys = jax.random.split(sub_key, args.num_eval_envs)
    state = jax.vmap(env.init)(subkeys)
    R = jnp.zeros_like(state.rewards)

    def cond_fn(tup):
        state, _, _ = tup
        return ~state.terminated.all()

    def loop_fn(tup):
        state, R, rng_key = tup
        logits, _value = model(state.observation)
        pi = Categorical(logits=logits)
        rng_key, _rng = jax.random.split(rng_key)
        action = pi.sample(seed=_rng)
        rng_key, _rng = jax.random.split(rng_key)
        keys = jax.random.split(_rng, state.observation.shape[0])
        state = step_fn(state, action, keys)
        return state, R + state.rewards, rng_key

    state, R, _ = jax.lax.while_loop(cond_fn, loop_fn, (state, R, rng_key))
    return R.mean()


# -----------------------------
# Training Loop
# -----------------------------
def train(rng):
    tt = 0.0
    st = time.time()

    # Model + optimizer
    rng, _rng = jax.random.split(rng)
    obs_shape = env.observation_shape
    model = ActorCritic(env.num_actions, obs_shape=obs_shape, activation="tanh", rngs=nnx.Rngs(_rng))
    optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

    # Update function
    update_step = make_update_step()

    # Init envs
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, args.num_envs)
    env_state = jax.jit(jax.vmap(env.init))(reset_rng)
    last_obs = env_state.observation

    rng, _rng = jax.random.split(rng)
    runner_state = (model, optimizer, env_state, last_obs, _rng)

    # Warmup (compile)
    _, _ = update_step(*runner_state)

    # initial evaluation
    et = time.time()
    tt += et - st
    rng, _rng = jax.random.split(rng)
    eval_R = evaluate(runner_state[0], _rng)
    steps = 0
    training_total = args.num_envs * args.num_steps * num_updates
    checkpoint_targets = []
    checkpoint_paths = {}
    if args.save_model:
        os.makedirs(args.out_models_dir, exist_ok=True)
        base_interval = max(1, math.ceil(training_total / 4))
        checkpoint_targets = [min(training_total, base_interval * i) for i in range(1, 4)]
        checkpoint_targets.append(training_total)
        checkpoint_targets = sorted(set(checkpoint_targets))
    log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
    print(log)
    wandb.log(log)
    st = time.time()

    for _ in range(num_updates):
        runner_state, loss_info = update_step(*runner_state)
        model, optimizer, env_state, last_obs, rng = runner_state
        steps += args.num_envs * args.num_steps

        # evaluation
        et = time.time()
        tt += et - st
        rng, _rng = jax.random.split(rng)
        eval_R = evaluate(model, _rng)
        log = {"sec": tt, f"{args.env_name}/eval_R": float(eval_R), "steps": steps}
        print(log)
        wandb.log(log)
        st = time.time()
        if args.save_model:
            for target in checkpoint_targets:
                if steps >= target and target not in checkpoint_paths:
                    checkpoint_paths[target] = save_checkpoint(model, target)

    if args.save_model:
        for target in checkpoint_targets:
            if steps >= target and target not in checkpoint_paths:
                checkpoint_paths[target] = save_checkpoint(runner_state[0], target)

    return runner_state, checkpoint_paths  # (model, optimizer, env_state, last_obs, rng)


In [9]:
wandb.init(project=args.wandb_project, config=args.dict())
rng = jax.random.PRNGKey(args.seed)
out, checkpoint_paths = train(rng)
if args.save_model:
    model = out[0]
    if checkpoint_paths:
        final_step = max(checkpoint_paths)
        shutil.copyfile(
            checkpoint_paths[final_step],
            os.path.join(
                args.out_models_dir,
                f"{args.env_name}-seed={args.seed}.ckpt",
            ),
        )
    else:
        # Save only learnable parameters (nnx.Param state) like Haiku params
        os.makedirs(args.out_models_dir, exist_ok=True)
        with open(
            os.path.join(
                args.out_models_dir,
                f"{args.env_name}-seed={args.seed}.ckpt",
            ),
            "wb",
        ) as f:
            pickle.dump(nnx.state(model, nnx.Param), f)
wandb.finish()

/tmp/ipykernel_1544095/2311977280.py:1: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  wandb.init(project=args.wandb_project, config=args.dict())


{'sec': 7.559342622756958, 'minatar-space_invaders/eval_R': 4.269999980926514, 'steps': 0}
{'sec': 7.819483041763306, 'minatar-space_invaders/eval_R': 4.630000114440918, 'steps': 524288}
{'sec': 8.08708643913269, 'minatar-space_invaders/eval_R': 6.119999885559082, 'steps': 1048576}
{'sec': 8.360270023345947, 'minatar-space_invaders/eval_R': 7.869999885559082, 'steps': 1572864}
{'sec': 8.639054775238037, 'minatar-space_invaders/eval_R': 12.029999732971191, 'steps': 2097152}
{'sec': 8.918978691101074, 'minatar-space_invaders/eval_R': 13.039999961853027, 'steps': 2621440}
{'sec': 9.198667764663696, 'minatar-space_invaders/eval_R': 17.75, 'steps': 3145728}
{'sec': 9.479210376739502, 'minatar-space_invaders/eval_R': 19.850000381469727, 'steps': 3670016}
{'sec': 9.759523391723633, 'minatar-space_invaders/eval_R': 23.94999885559082, 'steps': 4194304}
{'sec': 10.039384603500366, 'minatar-space_invaders/eval_R': 27.44999885559082, 'steps': 4718592}
{'sec': 10.31794261932373, 'minatar-space_inva

0,1
minatar-space_invaders/eval_R,▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▂▃▃▃▃▃▃▄▄▅▄▄▅▅▆▆▆▆▆▇████
sec,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
steps,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
minatar-space_invaders/eval_R,179.70999
sec,18.13342
steps,19922944.0
