## Atari Games

### Cart Pole

Let's download the data if it's not present in the folder:

In [10]:
!pip3.11 install gymnasium ray tianshou pygame stable_baselines3

Collecting stable_baselines3
  Downloading stable_baselines3-2.6.0-py3-none-any.whl.metadata (4.8 kB)
INFO: pip is looking at multiple versions of stable-baselines3 to determine which version is compatible with other requirements. This could take a while.
  Downloading stable_baselines3-2.5.0-py3-none-any.whl.metadata (4.8 kB)
  Downloading stable_baselines3-2.4.1-py3-none-any.whl.metadata (4.5 kB)
  Downloading stable_baselines3-2.4.0-py3-none-any.whl.metadata (4.5 kB)
  Downloading stable_baselines3-2.3.2-py3-none-any.whl.metadata (5.1 kB)
Downloading stable_baselines3-2.3.2-py3-none-any.whl (182 kB)
Installing collected packages: stable_baselines3
Successfully installed stable_baselines3-2.3.2

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m


## Cart Pole

Let's explore the API behind any reinforcement learning environments, as provided by the gym library:

In [38]:
import gymnasium as gym
env = gym.make("CartPole-v1")
observation, info = env.reset(seed=42)

rewards = []
for _ in range(10000):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    rewards.append(reward)
    if terminated or truncated:
        observation, info = env.reset()
env.close()
print(observation)

[ 0.06839293  0.21602237 -0.12626234 -0.46528688]


Let's create animation methods using this API, as well as average-time evaluator, and check it with random-move generating model:

In [None]:
def animate_model(model, env=None):
    env = gym.make("CartPole-v1", render_mode="rgb_array") if not env else env
    frames = []
    obs, _ = env.reset(seed=42)
    for _ in range(300):
        frames.append(env.render())
        action, _ = model.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break
    env.close()
    return frames

import matplotlib.pyplot as plt
from matplotlib import animation
def show_animation(frames):
    fig = plt.figure()
    plt.axis('off')
    im = plt.imshow(frames[0])
    def animate(i):
        im.set_array(frames[i])
        return [im]
    ani = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
    plt.close()
    return HTML(ani.to_html5_video())

def evaluate_model(model, num_episodes=200, env=None, seed=None):
    total_seconds = 0.0
    env = env if env else gym.make("CartPole-v1")
    for _ in range(num_episodes):
        obs, _ = env.reset(seed=None)
        episode_steps = 0
        while True:
            action, _ = model.predict(obs, deterministic=True)
            obs, _, terminated, truncated, _ = env.step(action)
            episode_steps += 1
            if terminated or truncated:
                break
        total_seconds += episode_steps * 0.05
    env.close()
    return total_seconds / num_episodes

import random
class RandModel:
    def __init__(self):
        self.first = True
    def predict(self, obs, deterministic=False):
        if deterministic and self.first:
            self.first = False
            random.seed(42)
        return random.randint(0, 1), None

randmodel = RandModel()
average_seconds = evaluate_model(randmodel, num_episodes=100)
print(f"Average survival time: {average_seconds:.2f} seconds")
frames = animate_model(randmodel)
show_animation(frames)

Average survival time: 1.05 seconds


Now let's first use Stable Baselines3 library's DQN implementation just to touch the waters and see which models might potantially work. Next step takes around 30 seconds on my CPU:

In [None]:
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv

# first try stable baselines3 DQN agent
env = gym.make("CartPole-v1")
env = DummyVecEnv([lambda: env]) # wrap with stable baselines3

model = DQN(
    "MlpPolicy",
    env,
    verbose=0,
    learning_rate=1e-3,
    buffer_size=50000,
    learning_starts=1000,
    batch_size=128,
    gamma=0.99,
    exploration_final_eps=0.01,
    exploration_fraction=0.1,
)

model.learn(total_timesteps=100000)
model.save("dqn_cartpole")

Using cpu device
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.99     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 29691    |
|    time_elapsed     | 0        |
|    total_timesteps  | 103      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.98     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 25059    |
|    time_elapsed     | 0        |
|    total_timesteps  | 200      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.97     |
| time/               |          |
|    episodes         | 12       |
|    fps              | 21825    |
|    time_elapsed     | 0        |
|    total_timesteps  | 298      |
----------------------------------
----------------------------------
| r

Now let's evaluate and animate this:

In [77]:
average_seconds = evaluate_model(model, num_episodes=200)
print(f"Average survival time: {average_seconds:.2f} seconds")
frames = animate_model(model)
show_animation(frames)

Average survival time: 2.96 seconds


Now, let's switch to an actual model that runs on GPU and is implemented manually!

**(Next step takes around 8m to run on my device with GPU!)**

We will fully transfer the idea to torch for GPU support and a great control over architecture!

Next cell will implement this.

- Uses a fixed seed for reproducibility
- Parameters, training steps, etc are experimentally adjusted for faster training by me
- Training reward is tweaked to penalize cart position, cart speed, pole velocity and pole angular velocity! This is not a requirement: I just wanted to see a beautiful steady result.

This implementation achieves a "perfect" score of 25 seconds out of 25 possible for this game (500 steps), as measured by running on 200 random episodes with seeds that are not present in the training data. All runs are being truncated at 500 which is max possible for this game!

In [None]:
import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import os

def set_global_seed(seed: int):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Reduce threading nondeterminism
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)


# Wrapper class subtracts a penalty proportional to |x|/x_threshold * alpha, then cart speed / max speed by beta, etc.
class DriftPenaltyEnv(gym.Wrapper):
    def __init__(self, env: gym.Env,
                alpha: float = 0.2,
                beta:  float = 0.3,
                gamma: float = 0.3,
                delta: float = 0.03):
        super().__init__(env)
        uw = self.unwrapped
        self.alpha = alpha
        self.beta  = beta
        self.gamma = gamma
        self.delta = delta

        # extract thresholds and max velocities
        self.x_threshold     = uw.x_threshold
        self.theta_threshold = uw.theta_threshold_radians

        self.tau             = uw.tau
        self.v_threshold     = self.x_threshold     / self.tau
        self.omega_threshold = self.theta_threshold / self.tau

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        x, x_dot, theta, theta_dot = obs

        # normalize penalties
        p_pos = abs(x)        / self.x_threshold
        p_vel = abs(x_dot)    / self.v_threshold
        p_ang = abs(theta)    / self.theta_threshold
        p_avel= abs(theta_dot)/ self.omega_threshold

        penalty = ( self.alpha * p_pos
                  + self.beta  * p_vel
                  + self.gamma * p_ang
                  + self.delta * p_avel )

        reward -= penalty
        return obs, reward, terminated, truncated, info

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(states).to(device),
            torch.LongTensor(actions).to(device),
            torch.FloatTensor(rewards).to(device),
            torch.FloatTensor(next_states).to(device),
            torch.FloatTensor(dones).to(device)
        )

    def __len__(self):
        return len(self.buffer)

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )

    def forward(self, x):
        return self.net(x)

class DQNAgent:
    def __init__(self, policy_net: nn.Module, device: torch.device):
        self.policy_net = policy_net
        self.device = device
        self.policy_net.eval()

    def predict(self, obs, deterministic=True):
        with torch.no_grad():
            obs_tensor = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
            q_values = self.policy_net(obs_tensor)
            action = q_values.argmax(dim=1).item()
        return action, None


def train_dqn(
    env,
    state_dim,
    action_dim,
    device,
    episodes=500,
    buffer_size=50000,
    batch_size=128,
    gamma=0.9995,
    lr=1e-4,
    target_update=500,
    eps_start=1.0,
    eps_end=0.01,
    eps_decay_steps=100_000,
    seed=42
):
    set_global_seed(seed)
    env.reset(seed=seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    replay_buffer = ReplayBuffer(buffer_size)

    policy_net = DQN(state_dim, action_dim).to(device)
    target_net = DQN(state_dim, action_dim).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)

    total_steps = 0
    epsilon = eps_start

    for ep in range(1, episodes+1):
        state, _ = env.reset(seed=seed + ep)
        done = False
        ep_reward = 0.0

        while not done:
            total_steps += 1
            epsilon = max(eps_end,
                          eps_start - (eps_start - eps_end) * min(1.0, total_steps / eps_decay_steps))

            # select action
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    st = torch.FloatTensor(state).to(device)
                    action = policy_net(st).argmax().item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            replay_buffer.push((state, action, reward, next_state, float(done)))
            state = next_state
            ep_reward += reward

            # learning step
            if len(replay_buffer) >= batch_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

                # current Q
                curr_q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()
                # Double DQN target
                with torch.no_grad():
                    next_actions = policy_net(next_states).argmax(dim=1, keepdim=True)
                    next_q = target_net(next_states).gather(1, next_actions).squeeze()
                    target_q = rewards + (1 - dones) * gamma * next_q

                loss = nn.MSELoss()(curr_q, target_q)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10)
                optimizer.step()

            if total_steps % target_update == 0:
                target_net.load_state_dict(policy_net.state_dict())

        if ep % 20 == 0:
            print(f"Episode {ep}, Reward: {ep_reward:.2f}, Epsilon: {epsilon:.3f}")

    return policy_net


def evaluate_model(agent, env, episodes=200):
    total_time = 0.0
    truncated_episodes = 0
    for ep in range(episodes):
        state, _ = env.reset(seed=42239+ep) # reproducible yet different from any seed used in training
        done = False
        steps = 0
        while not done:
            action, _ = agent.predict(state)
            state, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            if truncated:
                truncated_episodes += 1
            steps += 1
        total_time += steps * 0.05
    return total_time / episodes, truncated_episodes / episodes


def animate_model(agent, env, max_steps=500):
    from IPython.display import HTML
    import matplotlib.pyplot as plt
    from matplotlib import animation

    frames = []
    state, _ = env.reset(seed=42)
    for _ in range(max_steps):
        frames.append(env.render())
        action, _ = agent.predict(state)
        state, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break
    env.close()

    fig = plt.figure()
    plt.axis('off')
    im = plt.imshow(frames[0])
    def update(i):
        im.set_array(frames[i])
        return [im]
    ani = animation.FuncAnimation(fig, update, frames=len(frames), interval=50)
    plt.close()
    return HTML(ani.to_html5_video())

device = torch.device("mps"  if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Base environment + drift penalty wrapper
base_env = gym.make("CartPole-v1", render_mode="rgb_array")
env = DriftPenaltyEnv(base_env)

# Start training
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy_net = train_dqn(
    env,
    state_dim,
    action_dim,
    device,
    episodes=290,
    buffer_size=100000,
    batch_size=32,
    gamma=0.9995,
    lr=5*1e-4,
    target_update=100,
    eps_start=1.0,
    eps_end=0.01,
    eps_decay_steps=10_000,
    seed=42
)

# Evaluation
agent = DQNAgent(policy_net, device)
avg_time, truncated_fraction = evaluate_model(agent, env)
print(f"Average survival: {avg_time:.2f} seconds, {truncated_fraction*100:.2f}% truncated")

# Animate
html_video = animate_model(agent, env)
from IPython.display import display
display(html_video)

  torch.FloatTensor(states).to(device),


Episode 20, Reward: 8.34, Epsilon: 0.935
Episode 40, Reward: 30.13, Epsilon: 0.891
Episode 60, Reward: 16.15, Epsilon: 0.832
Episode 80, Reward: 34.14, Epsilon: 0.774
Episode 100, Reward: 43.04, Epsilon: 0.686
Episode 120, Reward: 106.12, Epsilon: 0.560
Episode 140, Reward: 96.87, Epsilon: 0.404
Episode 160, Reward: 150.62, Epsilon: 0.129
Episode 180, Reward: 202.35, Epsilon: 0.010
Episode 200, Reward: 160.30, Epsilon: 0.010
Episode 220, Reward: 142.54, Epsilon: 0.010
Episode 240, Reward: 493.54, Epsilon: 0.010
Episode 260, Reward: 480.69, Epsilon: 0.010
Episode 280, Reward: 451.55, Epsilon: 0.010
Average survival: 25.00 seconds, 100.00% truncated


### Space Invaders

Now, onto the next game!

In [126]:
!pip3.11 install ale-py autorom

Collecting ale-py
  Downloading ale_py-0.11.0-cp311-cp311-macosx_13_0_arm64.whl.metadata (8.2 kB)
Collecting autorom
  Downloading AutoROM-0.6.1-py3-none-any.whl.metadata (2.4 kB)
Downloading ale_py-0.11.0-cp311-cp311-macosx_13_0_arm64.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading AutoROM-0.6.1-py3-none-any.whl (9.4 kB)
Installing collected packages: ale-py, autorom
Successfully installed ale-py-0.11.0 autorom-0.6.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m


The next step just accepts the license for ROM, that is supplementary materials such as space invader images. This is only for animation

In [127]:
!AutoROM --accept-license

AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/opt/homebrew/lib/python3.11/site-packages/AutoROM/roms

Existing ROMs will be overwritten.
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/adventure.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/air_raid.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/alien.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/amidar.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/assault.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/asterix.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/asteroids.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/atlantis.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/atlantis2.bin
Installed /opt/homebrew/lib/python3.11/site-packages/AutoROM/roms/backgammon.bin
Installed /opt/homebrew/lib/python3.11/site-packa

In [None]:
!pip3.11 install "gymnasium[atari, accept-rom-license]"

Now, again, we will explore the API for the game, create methods for further use, evaluate and animate it just for the random-action strategy. However, this time, we won't focus on the survival time - only on the rewards of the game

In [None]:
import gymnasium as gym
import random
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import ale_py

env = gym.make("ALE/SpaceInvaders-v5")
observation, info = env.reset(seed=42)
rewards = []
for _ in range(10000):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    rewards.append(reward)
    if terminated or truncated:
        observation, info = env.reset()
env.close()
print("Rewards max, mean, std:", np.max(rewards), np.mean(rewards), np.std(rewards))
print("Observation shape:", observation.shape)

def animate_model(model, env=None):
    env = gym.make("ALE/SpaceInvaders-v5", render_mode="rgb_array") if env is None else env
    frames = []
    obs, _ = env.reset(seed=42)
    for _ in range(300):
        frames.append(env.render())
        action, _ = model.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break
    env.close()
    return frames

def show_animation(frames):
    fig = plt.figure()
    plt.axis('off')
    im = plt.imshow(frames[0])
    def animate(i):
        im.set_array(frames[i])
        return [im]
    ani = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
    plt.close()
    return HTML(ani.to_html5_video())

def evaluate_model(model, num_episodes=200, env=None, seed=42239):
    total_rewards = 0
    env = env if env else gym.make("ALE/SpaceInvaders-v5")
    for ep in range(num_episodes):
        obs, _ = env.reset(seed=seed+ep)
        steps = 0
        while True:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_rewards += reward
            steps += 1
            if terminated or truncated:
                break
    env.close()
    return total_rewards / num_episodes

class RandModel:
    def __init__(self, action_space):
        self.action_space = action_space
        self.first = True
    def predict(self, obs, deterministic=False):
        if deterministic and self.first:
            random.seed(42)
            self.first = False
        return self.action_space.sample(), None

randmodel = RandModel(gym.make("ALE/SpaceInvaders-v5").action_space)
avg_reward = evaluate_model(randmodel, num_episodes=100)
print(f"Average reward: {avg_reward:.2f}")
frames = animate_model(randmodel)
show_animation(frames)

A.L.E: Arcade Learning Environment (version 0.11.0+dfae0bd)
[Powered by Stella]


Rewards max, mean, std: 200.0 0.28 2.9785231239659695
Observation shape: (210, 160, 3)
Average reward: 142.70


Next, after studying DQNs and making a lot of experiments with different techniques and architectures, we will implement a classic double dueling DQN architecture, reward clipping, prioritized experience replay buffer and n-step rewards! Additionally, there is a warm-up and "useful action" prioritization.

However, it takes a long time to learn, even with GPU!

In [None]:
import os
import ale_py
import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
from gymnasium.wrappers import AtariPreprocessing

# Try to import FrameStackObservation for newer gym versions
try:
    from gymnasium.wrappers import FrameStackObservation as FrameStack
    FS = "stack_size"
except ImportError:
    from gymnasium.wrappers import FrameStack
    FS = "num_stack"


def set_global_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class PrioritizedReplayBuffer:
    def __init__(self, capacity: int, device: torch.device, alpha: float = 0.6):
        self.buffer = deque(maxlen=capacity)
        self.priorities = []
        self.device = device
        self.alpha = alpha
        self.next_idx = 0

    def push(self, transition, td_error: torch.Tensor = None):
        if td_error is None:
            priority = max(self.priorities, default=1.0)
        else:
            priority = float(td_error.abs().cpu().item()) + 1e-6

        if len(self.buffer) < self.buffer.maxlen:
            self.buffer.append(transition)
            self.priorities.append(priority)
        else:
            self.buffer[self.next_idx] = transition
            self.priorities[self.next_idx] = priority
            self.next_idx = (self.next_idx + 1) % self.buffer.maxlen

    def sample(self, batch_size: int, beta: float = 0.4):
        N = len(self.buffer)
        prios = np.array(self.priorities[:N], dtype=np.float32) ** self.alpha
        probs = prios / prios.sum()

        indices = np.random.choice(N, batch_size, p=probs)
        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)

        weights = (N * probs[indices]) ** (-beta)
        weights /= weights.max()
        w_tensor = torch.tensor(weights, dtype=torch.float32, device=self.device)

        states_np = np.stack(states, axis=0)
        next_states_np = np.stack(next_states, axis=0)
        s = torch.from_numpy(states_np).float().to(self.device) / 255.0
        s2 = torch.from_numpy(next_states_np).float().to(self.device) / 255.0

        return (
            s,
            torch.tensor(actions, dtype=torch.int64, device=self.device),
            torch.tensor(rewards, dtype=torch.float32, device=self.device),
            s2,
            torch.tensor(dones, dtype=torch.float32, device=self.device),
            indices,
            w_tensor,
        )

    def update_priorities(self, indices, td_errors):
        for idx, td in zip(indices, td_errors):
            self.priorities[idx] = abs(float(td.cpu().item())) + 1e-6

    def __len__(self):
        return len(self.buffer)


class NStepBuffer:
    def __init__(self, n: int, gamma: float):
        self.n = n
        self.gamma = gamma
        self.buffer = deque()

    def append(self, transition):  # (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def is_ready(self):
        return len(self.buffer) >= self.n

    def pop(self):
        return self.buffer.popleft()

    def get_n_step_transition(self):
        R = 0.0
        for idx, (_, _, reward, _, done) in enumerate(self.buffer):
            R += (self.gamma ** idx) * reward
            if done:
                break
        state0, action0, _, _, _ = self.buffer[0]
        _, _, _, next_state_n, done_n = self.buffer[-1]
        return (state0, action0, R, next_state_n, done_n)

    def clear(self):
        self.buffer.clear()


class FastDuelingDQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super().__init__()
        c, h, w = input_shape
        self.features = nn.Sequential(
            nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
            nn.Flatten()
        )
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            flat_size = self.features(dummy).shape[1]
        self.value_stream = nn.Sequential(
            nn.Linear(flat_size, 512), nn.ReLU(),
            nn.Linear(512, 1)
        )
        self.advantage_stream = nn.Sequential(
            nn.Linear(flat_size, 512), nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        f = self.features(x)
        v = self.value_stream(f)
        a = self.advantage_stream(f)
        return v + (a - a.mean(dim=1, keepdim=True))


class DQNAgent:
    def __init__(self, network: nn.Module, device: torch.device):
        self.network = network.to(device)
        self.device = device

    def select_action(self, state_np):
        state = torch.from_numpy(state_np).unsqueeze(0).float().to(self.device) / 255.0
        with torch.no_grad():
            q = self.network(state)
        return int(q.argmax(1).item())


class RewardClippingWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, term, trunc, info = self.env.step(action)
        clipped = np.clip(reward, -1.0, 1.0)
        return obs, clipped, term, trunc, info


def train_optimized_dqn(
    env_name="ALE/SpaceInvaders-v5",
    num_episodes=10000,
    replay_buffer_capacity=1_000_000,
    batch_size=32,
    gamma=0.99,
    lr=6.25e-5,
    adam_eps=1.5e-4,
    n_steps=3,
    learning_freq=4,
    target_update_freq=32_000,
    epsilon_start=1.0,
    epsilon_final=0.01,
    epsilon_decay_steps=500_000,
    prio_alpha=0.6,
    beta_start=0.4,
    beta_frames=1_000_000,
    warmup_steps=50_000,
    report_interval=50,
    eval_interval=100,
    checkpoint_dir="checkpoints",
    seed=42,
):
    set_global_seed(seed)
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Environments
    base_env = gym.make("SpaceInvadersNoFrameskip-v4")
    atari_env = AtariPreprocessing(base_env, frame_skip=4, grayscale_obs=True, scale_obs=False, noop_max=30, terminal_on_life_loss=True)
    env = RewardClippingWrapper(FrameStack(atari_env, **{FS: 4}))
    eval_env = FrameStack(
        AtariPreprocessing(gym.make("SpaceInvadersNoFrameskip-v4"), frame_skip=4, grayscale_obs=True, scale_obs=False, noop_max=30, terminal_on_life_loss=False),
        **{FS: 4}
    )

    obs, _ = env.reset(seed=seed)
    obs_shape = env.observation_space.shape
    n_actions = env.action_space.n
    print(f"Observation shape: {obs_shape}, Actions: {n_actions}")

    # Networks & optimizer
    policy_net = FastDuelingDQN(obs_shape, n_actions).to(device)
    target_net = FastDuelingDQN(obs_shape, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict()); target_net.eval()
    optimizer = optim.Adam(policy_net.parameters(), lr=lr, eps=adam_eps)

    # Replay & agent
    replay = PrioritizedReplayBuffer(replay_buffer_capacity, device, prio_alpha)
    agent = DQNAgent(policy_net, device)

    total_steps = 0
    epsilon = epsilon_start
    episode_rewards = deque(maxlen=100)
    eval_scores = deque(maxlen=20)
    best_eval_score = -float('inf')

    def evaluate_agent(n_eps=5):
        scores = []
        for _ in range(n_eps):
            o, _ = eval_env.reset()
            s = np.array(o, dtype=np.uint8)
            done = False; score = 0
            while not done:
                a = agent.select_action(s)
                o2, r, t, tr, _ = eval_env.step(a)
                done = t or tr; score += r; s = np.array(o2, dtype=np.uint8)
            scores.append(score)
        return np.mean(scores)

    # Warmup with n-step
    print("Warmup...")
    nbuf = NStepBuffer(n_steps, gamma)
    o, _ = env.reset(seed=seed)
    s = np.array(o, dtype=np.uint8)
    for step in range(warmup_steps):
        total_steps += 1
        # exploration heuristic
        if random.random() < 0.6:
            a = random.choice([1,2,3])
        elif random.random() < 0.8:
            a = random.choice([0,1])
        else:
            a = env.action_space.sample()
        o2, r, t, tr, _ = env.step(a)
        done = t or tr
        s2 = np.array(o2, dtype=np.uint8)
        nbuf.append((s, a, r, s2, done))
        if nbuf.is_ready():
            trans = nbuf.get_n_step_transition(); replay.push(trans)
            nbuf.pop()
        if done:
            nbuf.clear(); o2, _ = env.reset(seed=seed)
            s2 = np.array(o2, dtype=np.uint8)
        s = s2

    print("Training...")
    for episode in range(1, num_episodes+1):
        o, _ = env.reset(seed=seed+episode)
        s = np.array(o, dtype=np.uint8)
        done = False; ep_reward = 0
        nbuf.clear()
        while not done:
            total_steps += 1
            # epsilon schedule
            epsilon = max(epsilon_final, epsilon_start - (epsilon_start-epsilon_final)*total_steps/epsilon_decay_steps)
            # select action
            if random.random() < epsilon:
                a = env.action_space.sample()
            else:
                a = agent.select_action(s)
            # step
            o2, r, t, tr, _ = env.step(a)
            done = t or tr; ep_reward += r
            s2 = np.array(o2, dtype=np.uint8)
            # multi-step store
            nbuf.append((s, a, r, s2, done))
            if nbuf.is_ready():
                replay.push(nbuf.get_n_step_transition())
                nbuf.pop()
            if done:
                # flush remaining
                while len(nbuf.buffer) > 0:
                    replay.push(nbuf.get_n_step_transition()); nbuf.pop()
            s = s2

            # learning step
            if total_steps % learning_freq == 0 and len(replay) >= batch_size:
                beta = min(1.0, beta_start + total_steps*(1.0-beta_start)/beta_frames)
                s_b, a_b, r_b, s2_b, d_b, idxs, w_b = replay.sample(batch_size, beta)
                policy_net.train()
                with torch.no_grad():
                    next_a = policy_net(s2_b).argmax(1, keepdim=True)
                    next_q = target_net(s2_b).gather(1, next_a).squeeze(1)
                    target = r_b + (1-d_b)*(gamma**n_steps)*next_q
                current = policy_net(s_b).gather(1, a_b.unsqueeze(1)).squeeze(1)
                td_err = target - current
                loss = (w_b * F.smooth_l1_loss(current, target, reduction='none')).mean()
                optimizer.zero_grad(); loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1.0)
                optimizer.step()
                replay.update_priorities(idxs, td_err.detach())

            # target net update
            if total_steps % target_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())

        episode_rewards.append(ep_reward)
        avg_reward = np.mean(episode_rewards)

        # Periodic evaluation
        if episode % eval_interval == 0:
            eval_score = evaluate_agent()
            eval_scores.append(eval_score)
            avg_eval_score = np.mean(eval_scores)
            
            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_best_model.pth")
                torch.save(policy_net.state_dict(), best_path)
                print(f"! NEW BEST! Episode {episode} | Eval Score: {eval_score:.1f} | Best: {best_eval_score:.1f}")

        # Logging and checkpointing
        if episode % report_interval == 0:
            path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_ep{episode}.pth")
            torch.save(policy_net.state_dict(), path)
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Episode {episode:4d} | Reward {ep_reward:7.2f} | Avg {avg_reward:7.2f} | Eps {epsilon:.3f} | LR {current_lr:.2e} | Buffer {len(replay):6d}")

    print(f"Done! Best eval score {best_eval_score:.1f}")
    return agent

trained = train_optimized_dqn(
    env_name="ALE/SpaceInvaders-v5",
    num_episodes=50000,
    seed=42,
)

Using device: mps
Observation shape: (4, 84, 84), Actions: 6


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Warmup...
Training...
Episode   50 | Reward    9.00 | Avg    4.28 | Eps 0.882 | LR 6.25e-05 | Buffer  59309
! NEW BEST! Episode 100 | Eval Score: 212.0 | Best: 212.0
Episode  100 | Reward    3.00 | Avg    4.43 | Eps 0.862 | LR 6.25e-05 | Buffer  69074
Episode  150 | Reward    6.00 | Avg    4.54 | Eps 0.839 | LR 6.25e-05 | Buffer  80841
Episode  200 | Reward    4.00 | Avg    4.16 | Eps 0.821 | LR 6.25e-05 | Buffer  89844
Episode  250 | Reward    6.00 | Avg    4.01 | Eps 0.801 | LR 6.25e-05 | Buffer 100000
Episode  300 | Reward    5.00 | Avg    4.52 | Eps 0.780 | LR 6.25e-05 | Buffer 110771
Episode  350 | Reward    6.00 | Avg    4.53 | Eps 0.761 | LR 6.25e-05 | Buffer 120467
Episode  400 | Reward    3.00 | Avg    4.50 | Eps 0.740 | LR 6.25e-05 | Buffer 130642
Episode  450 | Reward    8.00 | Avg    4.86 | Eps 0.721 | LR 6.25e-05 | Buffer 140467
Episode  500 | Reward   12.00 | Avg    4.82 | Eps 0.699 | LR 6.25e-05 | Buffer 151542
Episode  550 | Reward    2.00 | Avg    4.07 | Eps 0.682 | LR

Here I realized that the training is going slowly due to buffer's straightforward and inefficient implementation. Here I improved this using segment tree, numba jit and data storage on the device. I also implemented the ability to continue from the existing checkpoint. This implementation is more than 3 times faster!

In [1]:
!pip3.11 install numba


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m


In [None]:
import os
import ale_py
import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
from gymnasium.wrappers import AtariPreprocessing
import glob
import re
import numba
from numba import jit, prange
from typing import Tuple, Optional

# Try to import FrameStackObservation for newer gym versions
try:
    from gymnasium.wrappers import FrameStackObservation as FrameStack
    FS = "stack_size"
except ImportError:
    from gymnasium.wrappers import FrameStack
    FS = "num_stack"


def set_global_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@jit(nopython=True)
def update_segment_tree(tree: np.ndarray, capacity: int, idx: int, value: float):
    idx += capacity
    tree[idx] = value

    while idx > 1:
        idx //= 2
        tree[idx] = tree[2 * idx] + tree[2 * idx + 1]

@jit(nopython=True)
def find_prefix_sum_idx(tree: np.ndarray, capacity: int, prefix_sum: float) -> int:
    # find highest index with prefix sum <= prefix_sum
    idx = 1
    while idx < capacity:
        left = 2 * idx
        if tree[left] > prefix_sum:
            idx = left
        else:
            prefix_sum -= tree[left]
            idx = left + 1    
    return idx - capacity

@jit(nopython=True, parallel=True)
def sample_indices_numba(tree: np.ndarray, capacity: int, batch_size: int, 
                         total_priority: float) -> np.ndarray:
    indices = np.empty(batch_size, dtype=np.int32)
    segment = total_priority / batch_size
    for i in prange(batch_size):
        a = segment * i
        b = segment * (i + 1)
        cumsum = np.random.uniform(a, b)
        indices[i] = find_prefix_sum_idx(tree, capacity, cumsum)
    return indices


class OptimizedPrioritizedReplayBuffer:
    """
    - Segment tree for O(log n) operations
    - Numba JIT compilation
    - Pre-allocated arrays
    - Vectorized operations
    - Direct GPU tensor creation
    """
    
    def __init__(self, capacity: int, device: torch.device, 
                 alpha: float = 0.6, state_shape: Tuple[int, ...] = (4, 84, 84)):
        self.capacity = capacity
        self.device = device
        self.alpha = alpha
        self.beta = 0.4
        self.eps = 1e-6
        self.state_shape = state_shape

        self.states = np.zeros((capacity, *state_shape), dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_states = np.zeros((capacity, *state_shape), dtype=np.uint8)
        self.dones = np.zeros(capacity, dtype=np.float32)

        tree_capacity = 1
        while tree_capacity < capacity:
            tree_capacity *= 2
        self.tree_capacity = tree_capacity
        self.sum_tree = np.zeros(2 * tree_capacity, dtype=np.float32)
        
        self.max_priority = 1.0
        self.ptr = 0
        self.size = 0

        self._preallocate_batch_tensors(32)
        
    def _preallocate_batch_tensors(self, batch_size: int):
        
        self.batch_size_allocated = batch_size
        self.batch_states = torch.zeros((batch_size, *self.state_shape), 
                                       dtype=torch.float32, device=self.device)
        self.batch_next_states = torch.zeros((batch_size, *self.state_shape), 
                                            dtype=torch.float32, device=self.device)
        self.batch_actions = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
        self.batch_rewards = torch.zeros(batch_size, dtype=torch.float32, device=self.device)
        self.batch_dones = torch.zeros(batch_size, dtype=torch.float32, device=self.device)
        self.batch_weights = torch.zeros(batch_size, dtype=torch.float32, device=self.device)
    
    def push(self, transition, td_error: torch.Tensor = None):
        state, action, reward, next_state, done = transition
        idx = self.ptr
        self.states[idx] = state
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.next_states[idx] = next_state
        self.dones[idx] = float(done)

        if td_error is None:
            priority = self.max_priority
        else:
            priority = (float(td_error.abs().cpu().item()) + self.eps) ** self.alpha
        
        update_segment_tree(self.sum_tree, self.tree_capacity, idx, priority)
        self.max_priority = max(self.max_priority, priority)
        
        # update pointers
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
    
    def sample(self, batch_size: int, beta: float = 0.4) -> Tuple[torch.Tensor, ...]:
        if batch_size > self.batch_size_allocated:
            self._preallocate_batch_tensors(int(batch_size * 1.5))

        total_priority = self.sum_tree[1]

        indices = sample_indices_numba(self.sum_tree, self.tree_capacity, batch_size, total_priority)
        indices = np.clip(indices, 0, self.size - 1)
        
        priorities = np.array([self.sum_tree[idx + self.tree_capacity] for idx in indices], dtype=np.float32)

        probs = priorities / total_priority
        weights = (self.size * probs) ** (-beta)
        weights /= weights.max()

        batch_states = self.batch_states[:batch_size]
        batch_actions = self.batch_actions[:batch_size]
        batch_rewards = self.batch_rewards[:batch_size]
        batch_next_states = self.batch_next_states[:batch_size]
        batch_dones = self.batch_dones[:batch_size]
        batch_weights = self.batch_weights[:batch_size]

        # copy data to GPU
        batch_states.copy_(torch.from_numpy(self.states[indices]).float() / 255.0)
        batch_actions.copy_(torch.from_numpy(self.actions[indices]).long())
        batch_rewards.copy_(torch.from_numpy(self.rewards[indices]))
        batch_next_states.copy_(torch.from_numpy(self.next_states[indices]).float() / 255.0)
        batch_dones.copy_(torch.from_numpy(self.dones[indices]))
        batch_weights.copy_(torch.from_numpy(weights))
        
        return (
            batch_states,
            batch_actions,
            batch_rewards,
            batch_next_states,
            batch_dones,
            indices,
            batch_weights
        )

    def update_priorities(self, indices: np.ndarray, td_errors: torch.Tensor):
        td_errors_np = td_errors.detach().cpu().numpy()
        priorities = (np.abs(td_errors_np) + self.eps) ** self.alpha

        for idx, priority in zip(indices, priorities):
            update_segment_tree(self.sum_tree, self.tree_capacity, idx, priority)
        
        self.max_priority = max(self.max_priority, priorities.max())
    
    def __len__(self):
        return self.size


class NStepBuffer:
    def __init__(self, n: int, gamma: float):
        self.n = n
        self.gamma = gamma
        self.buffer = deque()

    def append(self, transition):
        self.buffer.append(transition)

    def is_ready(self):
        return len(self.buffer) >= self.n

    def pop(self):
        return self.buffer.popleft()

    def get_n_step_transition(self):
        R = 0.0
        for idx, (_, _, reward, _, done) in enumerate(self.buffer):
            R += (self.gamma ** idx) * reward
            if done:
                break
        state0, action0, _, _, _ = self.buffer[0]
        _, _, _, next_state_n, done_n = self.buffer[-1]
        return (state0, action0, R, next_state_n, done_n)

    def clear(self):
        self.buffer.clear()


class FastDuelingDQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super().__init__()
        c, h, w = input_shape
        self.features = nn.Sequential(
            nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
            nn.Flatten()
        )
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            flat_size = self.features(dummy).shape[1]
        self.value_stream = nn.Sequential(
            nn.Linear(flat_size, 512), nn.ReLU(),
            nn.Linear(512, 1)
        )
        self.advantage_stream = nn.Sequential(
            nn.Linear(flat_size, 512), nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        f = self.features(x)
        v = self.value_stream(f)
        a = self.advantage_stream(f)
        return v + (a - a.mean(dim=1, keepdim=True))


class DQNAgent:
    def __init__(self, network: nn.Module, device: torch.device):
        self.network = network.to(device)
        self.device = device

    def select_action(self, state_np):
        state = torch.from_numpy(state_np).unsqueeze(0).float().to(self.device) / 255.0
        with torch.no_grad():
            q = self.network(state)
        return int(q.argmax(1).item())


class RewardClippingWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, term, trunc, info = self.env.step(action)
        clipped = np.clip(reward, -1.0, 1.0)
        return obs, clipped, term, trunc, info


def find_latest_checkpoint(checkpoint_dir):
    pattern = os.path.join(checkpoint_dir, "ALE_SpaceInvaders-v5_ep*.pth")
    files = glob.glob(pattern)
    if not files:
        return None, 0
    episodes = []
    for f in files:
        match = re.search(r'_ep(\d+)\.pth$', f)
        if match:
            episodes.append((int(match.group(1)), f))
    if episodes:
        episodes.sort(key=lambda x: x[0])
        return episodes[-1][1], episodes[-1][0]
    return None, 0

def resume_training_dqn(
    env_name="ALE/SpaceInvaders-v5",
    num_episodes=10000,
    replay_buffer_capacity=1_000_000,
    batch_size=32,
    gamma=0.99,
    lr=6.25e-5,
    adam_eps=1.5e-4,
    n_steps=3,
    learning_freq=4,
    target_update_freq=32_000,
    epsilon_start=1.0,
    epsilon_final=0.01,
    epsilon_decay_steps=500_000,
    prio_alpha=0.6,
    beta_start=0.4,
    beta_frames=1_000_000,
    report_interval=50,
    eval_interval=100,
    checkpoint_dir="checkpoints",
    seed=42,
    buffer_refill_episodes=100,
):
    set_global_seed(seed)
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Find latest checkpoint
    latest_checkpoint, last_episode = find_latest_checkpoint(checkpoint_dir)
    if not latest_checkpoint:
        print("No checkpoint found! Please check the checkpoint directory.")
        return
    print(f"Found checkpoint: {latest_checkpoint} (Episode {last_episode})")

    # Calculate approximate total steps
    estimated_steps_per_episode = 200
    total_steps = last_episode * estimated_steps_per_episode
    
    # Environments
    base_env = gym.make("SpaceInvadersNoFrameskip-v4")
    atari_env = AtariPreprocessing(base_env, frame_skip=4, grayscale_obs=True, scale_obs=False, noop_max=30, terminal_on_life_loss=True)
    env = RewardClippingWrapper(FrameStack(atari_env, **{FS: 4}))
    eval_env = FrameStack(
        AtariPreprocessing(gym.make("SpaceInvadersNoFrameskip-v4"), frame_skip=4, grayscale_obs=True, scale_obs=False, noop_max=30, terminal_on_life_loss=False),
        **{FS: 4}
    )

    obs, _ = env.reset(seed=seed)
    obs_shape = env.observation_space.shape
    n_actions = env.action_space.n
    print(f"Observation shape: {obs_shape}, Actions: {n_actions}")

    # Networks & optimizer
    policy_net = FastDuelingDQN(obs_shape, n_actions).to(device)
    target_net = FastDuelingDQN(obs_shape, n_actions).to(device)

    # Load checkpoint
    print(f"Loading model from {latest_checkpoint}")
    checkpoint_state = torch.load(latest_checkpoint, map_location=device)
    policy_net.load_state_dict(checkpoint_state)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=lr, eps=adam_eps)

    # Check for best model
    best_model_path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_best_model.pth")
    best_eval_score = -float('inf')
    if os.path.exists(best_model_path):
        print(f"Found best model at {best_model_path}")

    # Create replay buffer
    print("Creating replay buffer...")
    replay = OptimizedPrioritizedReplayBuffer(
        replay_buffer_capacity, 
        device, 
        prio_alpha,
        state_shape=obs_shape
    )
    agent = DQNAgent(policy_net, device)

    # Calculate current epsilon
    epsilon = max(epsilon_final, epsilon_start - (epsilon_start - epsilon_final) * total_steps / epsilon_decay_steps)
    print(f"Calculated epsilon: {epsilon:.4f} based on estimated {total_steps} total steps")

    episode_rewards = deque(maxlen=100)
    eval_scores = deque(maxlen=20)

    def evaluate_agent(n_eps=5):
        scores = []
        for _ in range(n_eps):
            o, _ = eval_env.reset()
            s = np.array(o, dtype=np.uint8)
            done = False
            score = 0
            while not done:
                a = agent.select_action(s)
                o2, r, t, tr, _ = eval_env.step(a)
                done = t or tr
                score += r
                s = np.array(o2, dtype=np.uint8)
            scores.append(score)
        return np.mean(scores)

    # Evaluate current performance
    print("Evaluating current model performance...")
    current_eval = evaluate_agent()
    print(f"Current model evaluation score: {current_eval:.1f}")

    if os.path.exists(best_model_path):
        # Load and evaluate best model
        best_net = FastDuelingDQN(obs_shape, n_actions).to(device)
        best_net.load_state_dict(torch.load(best_model_path, map_location=device))
        best_agent = DQNAgent(best_net, device)
        temp_agent = agent
        agent = best_agent
        best_eval_score = evaluate_agent()
        agent = temp_agent
        print(f"Best model evaluation score: {best_eval_score:.1f}")
    else:
        best_eval_score = current_eval

    # Refill replay buffer with realistic play
    print(f"\nRefilling replay buffer with {buffer_refill_episodes} episodes of realistic play...")
    nbuf = NStepBuffer(n_steps, gamma)
    
    # Track refill performance
    refill_start_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
    refill_end_time = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None

    if refill_start_time:
        refill_start_time.record()

    import time
    cpu_start_time = time.time()

    for refill_ep in range(buffer_refill_episodes):
        o, _ = env.reset(seed=seed + last_episode + refill_ep)
        s = np.array(o, dtype=np.uint8)
        done = False
        ep_reward = 0
        nbuf.clear()
        while not done:
            if random.random() < epsilon:
                a = env.action_space.sample()
            else:
                a = agent.select_action(s)
            o2, r, t, tr, _ = env.step(a)
            done = t or tr
            ep_reward += r
            s2 = np.array(o2, dtype=np.uint8)

            # Store in n-step buffer
            nbuf.append((s, a, r, s2, done))
            if nbuf.is_ready():
                replay.push(nbuf.get_n_step_transition())
                nbuf.pop()
            if done:
                # Flush remaining transitions
                while len(nbuf.buffer) > 0:
                    replay.push(nbuf.get_n_step_transition())
                    nbuf.pop()
            s = s2
        if (refill_ep + 1) % 10 == 0:
            print(f"Refill episode {refill_ep + 1}/{buffer_refill_episodes}, "
                  f"Reward: {ep_reward:.1f}, Buffer size: {len(replay)}")

    cpu_refill_time = time.time() - cpu_start_time
    print(f"Buffer refilled with {len(replay)} transitions in {cpu_refill_time:.1f} seconds")
    
    if refill_end_time:
        refill_end_time.record()
        torch.cuda.synchronize()
        gpu_refill_time = refill_start_time.elapsed_time(refill_end_time) / 1000.0
        print(f"GPU timing: {gpu_refill_time:.1f} seconds")

    # Continue training
    print(f"\nResuming training from episode {last_episode + 1}...")

    # Track training performance and start training
    step_times = deque(maxlen=1000)
    for episode in range(last_episode + 1, num_episodes + 1):
        o, _ = env.reset(seed=seed + episode)
        s = np.array(o, dtype=np.uint8)
        done = False
        ep_reward = 0
        nbuf.clear()
        episode_steps = 0

        while not done:
            step_start = time.time()
            total_steps += 1
            episode_steps += 1
            epsilon = max(epsilon_final, epsilon_start - (epsilon_start - epsilon_final) * total_steps / epsilon_decay_steps)

            if random.random() < epsilon:
                a = env.action_space.sample()
            else:
                a = agent.select_action(s)

            # Step
            o2, r, t, tr, _ = env.step(a)
            done = t or tr
            ep_reward += r
            s2 = np.array(o2, dtype=np.uint8)

            # Multi-step store
            nbuf.append((s, a, r, s2, done))
            if nbuf.is_ready():
                replay.push(nbuf.get_n_step_transition())
                nbuf.pop()

            if done:
                # Flush remaining
                while len(nbuf.buffer) > 0:
                    replay.push(nbuf.get_n_step_transition())
                    nbuf.pop()

            s = s2

            # Learning step
            if total_steps % learning_freq == 0 and len(replay) >= batch_size:
                beta = min(1.0, beta_start + total_steps * (1.0 - beta_start) / beta_frames)
                s_b, a_b, r_b, s2_b, d_b, idxs, w_b = replay.sample(batch_size, beta)
                
                policy_net.train()
                with torch.no_grad():
                    next_a = policy_net(s2_b).argmax(1, keepdim=True)
                    next_q = target_net(s2_b).gather(1, next_a).squeeze(1)
                    target = r_b + (1 - d_b) * (gamma ** n_steps) * next_q

                current = policy_net(s_b).gather(1, a_b.unsqueeze(1)).squeeze(1)
                td_err = target - current
                loss = (w_b * F.smooth_l1_loss(current, target, reduction='none')).mean()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1.0)
                optimizer.step()

                replay.update_priorities(idxs, td_err.detach())

            # Target net update
            if total_steps % target_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())
                print(f"Updated target network at step {total_steps}")

            step_times.append(time.time() - step_start)

        episode_rewards.append(ep_reward)
        avg_reward = np.mean(episode_rewards)

        # Periodic evaluation
        if episode % eval_interval == 0:
            eval_score = evaluate_agent()
            eval_scores.append(eval_score)

            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_best_model.pth")
                torch.save(policy_net.state_dict(), best_path)
                print(f"! NEW BEST! Episode {episode} | Eval Score: {eval_score:.1f} | Best: {best_eval_score:.1f}")

        # Logging and checkpointing
        if episode % report_interval == 0:
            path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_ep{episode}.pth")
            torch.save(policy_net.state_dict(), path)
            current_lr = optimizer.param_groups[0]['lr']
            avg_step_time = np.mean(step_times) * 1000 if step_times else 0
            steps_per_sec = 1.0 / np.mean(step_times) if step_times and np.mean(step_times) > 0 else 0
            
            print(f"Episode {episode:4d} | Reward {ep_reward:7.2f} | Avg {avg_reward:7.2f} | "
                  f"Eps {epsilon:.3f} | LR {current_lr:.2e} | Buffer {len(replay):6d} | "
                  f"Steps {total_steps} | {steps_per_sec:.1f} steps/s | {avg_step_time:.1f}ms/step")

    print(f"Training complete! Best eval score: {best_eval_score:.1f}")
    return agent

trained = resume_training_dqn(
    env_name="ALE/SpaceInvaders-v5",
    num_episodes=50000,
    seed=42,
    buffer_refill_episodes=1000,
)

Using device: mps
Found checkpoint: checkpoints/ALE_SpaceInvaders-v5_ep6300.pth (Episode 6300)


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Observation shape: (4, 84, 84), Actions: 6
Loading model from checkpoints/ALE_SpaceInvaders-v5_ep6300.pth


  checkpoint_state = torch.load(latest_checkpoint, map_location=device)


Found best model at checkpoints/ALE_SpaceInvaders-v5_best_model.pth
Creating optimized replay buffer...
Calculated epsilon: 0.0100 based on estimated 1260000 total steps
Evaluating current model performance...
Current model evaluation score: 1416.0


  best_net.load_state_dict(torch.load(best_model_path, map_location=device))


Best model evaluation score: 907.0

Refilling replay buffer with 1000 episodes of intelligent play...
Refill episode 10/1000, Reward: 35.0, Buffer size: 8672
Refill episode 20/1000, Reward: 35.0, Buffer size: 16127
Refill episode 30/1000, Reward: 35.0, Buffer size: 22900
Refill episode 40/1000, Reward: 69.0, Buffer size: 30497
Refill episode 50/1000, Reward: 64.0, Buffer size: 38159
Refill episode 60/1000, Reward: 51.0, Buffer size: 47122
Refill episode 70/1000, Reward: 66.0, Buffer size: 56154
Refill episode 80/1000, Reward: 29.0, Buffer size: 63471
Refill episode 90/1000, Reward: 71.0, Buffer size: 71829
Refill episode 100/1000, Reward: 34.0, Buffer size: 78573
Refill episode 110/1000, Reward: 77.0, Buffer size: 87777
Refill episode 120/1000, Reward: 42.0, Buffer size: 94390
Refill episode 130/1000, Reward: 26.0, Buffer size: 101288
Refill episode 140/1000, Reward: 35.0, Buffer size: 107083
Refill episode 150/1000, Reward: 66.0, Buffer size: 116997
Refill episode 160/1000, Reward: 36

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


Updated target network at step 1280000
Episode 6350 | Reward   33.00 | Avg   36.54 | Eps 0.010 | LR 6.25e-05 | Buffer 845576 | Steps 1295831 | 174.1 steps/s | 5.7ms/step
Updated target network at step 1312000
! NEW BEST! Episode 6400 | Eval Score: 958.0 | Best: 958.0
Episode 6400 | Reward   33.00 | Avg   34.40 | Eps 0.010 | LR 6.25e-05 | Buffer 876120 | Steps 1326375 | 172.9 steps/s | 5.8ms/step
Updated target network at step 1344000
Episode 6450 | Reward   33.00 | Avg   33.99 | Eps 0.010 | LR 6.25e-05 | Buffer 910794 | Steps 1361049 | 171.9 steps/s | 5.8ms/step
Updated target network at step 1376000
Episode 6500 | Reward   34.00 | Avg   34.66 | Eps 0.010 | LR 6.25e-05 | Buffer 941772 | Steps 1392027 | 166.8 steps/s | 6.0ms/step
Updated target network at step 1408000
Episode 6550 | Reward   33.00 | Avg   34.35 | Eps 0.010 | LR 6.25e-05 | Buffer 976279 | Steps 1426534 | 172.1 steps/s | 5.8ms/step
Updated target network at step 1440000
! NEW BEST! Episode 6600 | Eval Score: 1107.0 | Best

KeyboardInterrupt: 

Now, let's try to adapt that for Pacman with minimal change!

In [None]:
import os
import ale_py
import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
from gymnasium.wrappers import AtariPreprocessing
import glob
import re
import numba
from numba import jit, prange
from typing import Tuple, Optional

# Try to import FrameStackObservation for newer gym versions
try:
    from gymnasium.wrappers import FrameStackObservation as FrameStack
    FS = "stack_size"
except ImportError:
    from gymnasium.wrappers import FrameStack
    FS = "num_stack"


def set_global_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


@jit(nopython=True)
def update_segment_tree(tree: np.ndarray, capacity: int, idx: int, value: float):    
    idx += capacity
    tree[idx] = value
    while idx > 1:
        idx //= 2
        tree[idx] = tree[2 * idx] + tree[2 * idx + 1]

@jit(nopython=True)
def find_prefix_sum_idx(tree: np.ndarray, capacity: int, prefix_sum: float) -> int:
    idx = 1
    while idx < capacity:
        left = 2 * idx
        if tree[left] > prefix_sum:
            idx = left
        else:
            prefix_sum -= tree[left]
            idx = left + 1
    return idx - capacity

@jit(nopython=True, parallel=True)
def sample_indices_numba(tree: np.ndarray, capacity: int, batch_size: int, 
                         total_priority: float) -> np.ndarray:
    indices = np.empty(batch_size, dtype=np.int32)
    segment = total_priority / batch_size
    for i in prange(batch_size):
        a = segment * i
        b = segment * (i + 1)
        cumsum = np.random.uniform(a, b)
        indices[i] = find_prefix_sum_idx(tree, capacity, cumsum)
    return indices


class OptimizedPrioritizedReplayBuffer:
    """
    - Segment tree for O(log n) operations
    - Numba JIT compilation
    - Pre-allocated arrays
    - Vectorized operations
    - Direct GPU tensor creation
    """
    
    def __init__(self, capacity: int, device: torch.device, 
                 alpha: float = 0.6, state_shape: Tuple[int, ...] = (4, 84, 84)):
        self.capacity = capacity
        self.device = device
        self.alpha = alpha
        self.beta = 0.4
        self.eps = 1e-6
        self.state_shape = state_shape

        self.states = np.zeros((capacity, *state_shape), dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_states = np.zeros((capacity, *state_shape), dtype=np.uint8)
        self.dones = np.zeros(capacity, dtype=np.float32)

        tree_capacity = 1
        while tree_capacity < capacity:
            tree_capacity *= 2
        self.tree_capacity = tree_capacity
        self.sum_tree = np.zeros(2 * tree_capacity, dtype=np.float32)

        self.max_priority = 1.0
        self.ptr = 0
        self.size = 0

        self._preallocate_batch_tensors(32)
        
    def _preallocate_batch_tensors(self, batch_size: int):
        self.batch_size_allocated = batch_size
        self.batch_states = torch.zeros((batch_size, *self.state_shape), 
                                       dtype=torch.float32, device=self.device)
        self.batch_next_states = torch.zeros((batch_size, *self.state_shape), 
                                            dtype=torch.float32, device=self.device)
        self.batch_actions = torch.zeros(batch_size, dtype=torch.int64, device=self.device)
        self.batch_rewards = torch.zeros(batch_size, dtype=torch.float32, device=self.device)
        self.batch_dones = torch.zeros(batch_size, dtype=torch.float32, device=self.device)
        self.batch_weights = torch.zeros(batch_size, dtype=torch.float32, device=self.device)
    
    def push(self, transition, td_error: torch.Tensor = None):        
        state, action, reward, next_state, done = transition

        idx = self.ptr
        self.states[idx] = state
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.next_states[idx] = next_state
        self.dones[idx] = float(done)

        if td_error is None:
            priority = self.max_priority
        else:
            priority = (float(td_error.abs().cpu().item()) + self.eps) ** self.alpha

        update_segment_tree(self.sum_tree, self.tree_capacity, idx, priority)
        self.max_priority = max(self.max_priority, priority)
        # update pointers
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
    
    def sample(self, batch_size: int, beta: float = 0.4) -> Tuple[torch.Tensor, ...]:        
        if batch_size > self.batch_size_allocated:
            self._preallocate_batch_tensors(int(batch_size * 1.5))

        total_priority = self.sum_tree[1]
        indices = sample_indices_numba(self.sum_tree, self.tree_capacity, batch_size, total_priority)
        indices = np.clip(indices, 0, self.size - 1)        
        priorities = np.array([self.sum_tree[idx + self.tree_capacity] for idx in indices], dtype=np.float32)

        probs = priorities / total_priority
        weights = (self.size * probs) ** (-beta)
        weights /= weights.max()
        batch_states = self.batch_states[:batch_size]
        batch_actions = self.batch_actions[:batch_size]
        batch_rewards = self.batch_rewards[:batch_size]
        batch_next_states = self.batch_next_states[:batch_size]
        batch_dones = self.batch_dones[:batch_size]
        batch_weights = self.batch_weights[:batch_size]

        # copy data to GPU
        batch_states.copy_(torch.from_numpy(self.states[indices]).float() / 255.0)
        batch_actions.copy_(torch.from_numpy(self.actions[indices]).long())
        batch_rewards.copy_(torch.from_numpy(self.rewards[indices]))
        batch_next_states.copy_(torch.from_numpy(self.next_states[indices]).float() / 255.0)
        batch_dones.copy_(torch.from_numpy(self.dones[indices]))
        batch_weights.copy_(torch.from_numpy(weights))

        return (
            batch_states,
            batch_actions,
            batch_rewards,
            batch_next_states,
            batch_dones,
            indices,
            batch_weights
        )
    
    def update_priorities(self, indices: np.ndarray, td_errors: torch.Tensor):
        td_errors_np = td_errors.detach().cpu().numpy()
        priorities = (np.abs(td_errors_np) + self.eps) ** self.alpha

        for idx, priority in zip(indices, priorities):
            update_segment_tree(self.sum_tree, self.tree_capacity, idx, priority)
        self.max_priority = max(self.max_priority, priorities.max())
    
    def __len__(self):
        return self.size


class NStepBuffer:
    def __init__(self, n: int, gamma: float):
        self.n = n
        self.gamma = gamma
        self.buffer = deque()

    def append(self, transition):
        self.buffer.append(transition)

    def is_ready(self):
        return len(self.buffer) >= self.n

    def pop(self):
        return self.buffer.popleft()

    def get_n_step_transition(self):
        R = 0.0
        for idx, (_, _, reward, _, done) in enumerate(self.buffer):
            R += (self.gamma ** idx) * reward
            if done:
                break
        state0, action0, _, _, _ = self.buffer[0]
        _, _, _, next_state_n, done_n = self.buffer[-1]
        return (state0, action0, R, next_state_n, done_n)

    def clear(self):
        self.buffer.clear()


class FastDuelingDQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super().__init__()
        c, h, w = input_shape
        self.features = nn.Sequential(
            nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
            nn.Flatten()
        )
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            flat_size = self.features(dummy).shape[1]
        self.value_stream = nn.Sequential(
            nn.Linear(flat_size, 512), nn.ReLU(),
            nn.Linear(512, 1)
        )
        self.advantage_stream = nn.Sequential(
            nn.Linear(flat_size, 512), nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        f = self.features(x)
        v = self.value_stream(f)
        a = self.advantage_stream(f)
        return v + (a - a.mean(dim=1, keepdim=True))


class DQNAgent:
    def __init__(self, network: nn.Module, device: torch.device):
        self.network = network.to(device)
        self.device = device

    def select_action(self, state_np):
        state = torch.from_numpy(state_np).unsqueeze(0).float().to(self.device) / 255.0
        with torch.no_grad():
            q = self.network(state)
        return int(q.argmax(1).item())


class RewardClippingWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, term, trunc, info = self.env.step(action)
        clipped = np.clip(reward, -1.0, 1.0)
        return obs, clipped, term, trunc, info


def find_latest_checkpoint(checkpoint_dir):
    pattern = os.path.join(checkpoint_dir, "ALE_MsPacman-v5_ep*.pth")
    files = glob.glob(pattern)
    if not files:
        return None, 0
    
    episodes = []
    for f in files:
        match = re.search(r'_ep(\d+)\.pth$', f)
        if match:
            episodes.append((int(match.group(1)), f))
    
    if episodes:
        episodes.sort(key=lambda x: x[0])
        return episodes[-1][1], episodes[-1][0]

    return None, 0


def resume_training_dqn(
    env_name="ALE/MsPacman-v5",
    num_episodes=10000,
    replay_buffer_capacity=1_000_000,
    batch_size=32,
    gamma=0.99,
    lr=6.25e-5,
    adam_eps=1.5e-4,
    n_steps=3,
    learning_freq=4,
    target_update_freq=32_000,
    epsilon_start=1.0,
    epsilon_final=0.01,
    epsilon_decay_steps=500_000,
    prio_alpha=0.6,
    beta_start=0.4,
    beta_frames=1_000_000,
    warmup_steps=50_000,
    report_interval=50,
    eval_interval=100,
    checkpoint_dir="checkpoints",
    seed=42,
    resume_from_checkpoint=True,
    buffer_refill_episodes=100,
):
    set_global_seed(seed)
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Find latest checkpoint
    latest_checkpoint, last_episode = find_latest_checkpoint(checkpoint_dir)
    is_resuming = latest_checkpoint is not None and resume_from_checkpoint
    
    if is_resuming:
        print(f"Found checkpoint: {latest_checkpoint} (Episode {last_episode})")
        
        estimated_steps_per_episode = 200
        total_steps = last_episode * estimated_steps_per_episode
    else:
        print("No checkpoint found or resume disabled. Starting fresh training.")
        last_episode = 0
        total_steps = 0
    
    # Environments
    base_env = gym.make("MsPacmanNoFrameskip-v4")
    atari_env = AtariPreprocessing(base_env, frame_skip=4, grayscale_obs=True, scale_obs=False, noop_max=30, terminal_on_life_loss=True)
    env = RewardClippingWrapper(FrameStack(atari_env, **{FS: 4}))
    eval_env = FrameStack(
        AtariPreprocessing(gym.make("MsPacmanNoFrameskip-v4"), frame_skip=4, grayscale_obs=True, scale_obs=False, noop_max=30, terminal_on_life_loss=False),
        **{FS: 4}
    )

    obs, _ = env.reset(seed=seed)
    obs_shape = env.observation_space.shape
    n_actions = env.action_space.n
    print(f"Observation shape: {obs_shape}, Actions: {n_actions}")

    policy_net = FastDuelingDQN(obs_shape, n_actions).to(device)
    target_net = FastDuelingDQN(obs_shape, n_actions).to(device)
    
    if is_resuming:
        # Load checkpoint
        print(f"Loading model from {latest_checkpoint}")
        checkpoint_state = torch.load(latest_checkpoint, map_location=device)
        policy_net.load_state_dict(checkpoint_state)
    
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
    
    optimizer = optim.Adam(policy_net.parameters(), lr=lr, eps=adam_eps)

    # Check for best model
    best_model_path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_best_model.pth")
    best_eval_score = -float('inf')
    if os.path.exists(best_model_path):
        print(f"Found best model at {best_model_path}")

    # Create replay buffer
    print("Creating replay buffer...")
    replay = OptimizedPrioritizedReplayBuffer(
        replay_buffer_capacity, 
        device, 
        prio_alpha,
        state_shape=obs_shape
    )
    agent = DQNAgent(policy_net, device)

    # Calculate current epsilon
    epsilon = max(epsilon_final, epsilon_start - (epsilon_start - epsilon_final) * total_steps / epsilon_decay_steps)
    print(f"Current epsilon: {epsilon:.4f} based on {total_steps} total steps")

    episode_rewards = deque(maxlen=100)
    eval_scores = deque(maxlen=20)

    def evaluate_agent(n_eps=5):
        scores = []
        for _ in range(n_eps):
            o, _ = eval_env.reset()
            s = np.array(o, dtype=np.uint8)
            done = False
            score = 0
            while not done:
                a = agent.select_action(s)
                o2, r, t, tr, _ = eval_env.step(a)
                done = t or tr
                score += r
                s = np.array(o2, dtype=np.uint8)
            scores.append(score)
        return np.mean(scores)

    # Evaluate current performance
    if is_resuming:
        print("Evaluating current model performance...")
        current_eval = evaluate_agent()
        print(f"Current model evaluation score: {current_eval:.1f}")
        if os.path.exists(best_model_path):
            best_net = FastDuelingDQN(obs_shape, n_actions).to(device)
            best_net.load_state_dict(torch.load(best_model_path, map_location=device))
            best_agent = DQNAgent(best_net, device)
            temp_agent = agent
            agent = best_agent
            best_eval_score = evaluate_agent()
            agent = temp_agent
            print(f"Best model evaluation score: {best_eval_score:.1f}")
        else:
            best_eval_score = current_eval

    # Setup n-step buffer
    nbuf = NStepBuffer(n_steps, gamma)
    import time
    if is_resuming:
        print(f"\nRefilling replay buffer with {buffer_refill_episodes} episodes of realistic play...")
        fill_episodes = buffer_refill_episodes
        use_epsilon = epsilon
    else:
        print(f"\nWarming up replay buffer with random play (target: {warmup_steps} steps)...")
        fill_episodes = None 
        use_epsilon = 1.0

    cpu_start_time = time.time()
    warmup_total_steps = 0
    fill_ep = 0
    while True:       
        if is_resuming and fill_ep >= fill_episodes:
            break
        elif not is_resuming and warmup_total_steps >= warmup_steps:
            break

        o, _ = env.reset(seed=seed + last_episode + fill_ep)
        s = np.array(o, dtype=np.uint8)
        done = False
        ep_reward = 0
        nbuf.clear()        
        while not done:
            if random.random() < use_epsilon:
                a = env.action_space.sample()
            else:
                a = agent.select_action(s)

            o2, r, t, tr, _ = env.step(a)
            done = t or tr
            ep_reward += r
            s2 = np.array(o2, dtype=np.uint8)
            warmup_total_steps += 1

            # Store in n-step buffer
            nbuf.append((s, a, r, s2, done))
            if nbuf.is_ready():
                replay.push(nbuf.get_n_step_transition())
                nbuf.pop()
            if done:
                # Flush remaining transitions
                while len(nbuf.buffer) > 0:
                    replay.push(nbuf.get_n_step_transition())
                    nbuf.pop()

            s = s2
            if not is_resuming and warmup_total_steps >= warmup_steps:
                done = True

        fill_ep += 1
        
        if is_resuming and fill_ep % 10 == 0:
            print(f"Refill episode {fill_ep}/{buffer_refill_episodes}, "
                  f"Reward: {ep_reward:.1f}, Buffer size: {len(replay)}")
        elif not is_resuming and fill_ep % 10 == 0:
            print(f"Warmup episode {fill_ep}, Steps: {warmup_total_steps}/{warmup_steps}, "
                  f"Reward: {ep_reward:.1f}, Buffer size: {len(replay)}")

    cpu_fill_time = time.time() - cpu_start_time
    if is_resuming:
        print(f"Buffer refilled with {len(replay)} transitions in {cpu_fill_time:.1f} seconds")
    else:
        print(f"Warmup complete! Added {len(replay)} transitions in {cpu_fill_time:.1f} seconds")
        total_steps = warmup_total_steps

    # Continue/start training
    start_episode = last_episode + 1 if is_resuming else 1
    print(f"\n{'Resuming' if is_resuming else 'Starting'} training from episode {start_episode}...")

    # Track training performance and start training
    step_times = deque(maxlen=1000)
    for episode in range(start_episode, num_episodes + 1):
        o, _ = env.reset(seed=seed + episode)
        s = np.array(o, dtype=np.uint8)
        done = False
        ep_reward = 0
        nbuf.clear()
        episode_steps = 0
        while not done:
            step_start = time.time()
            total_steps += 1
            episode_steps += 1

            epsilon = max(epsilon_final, epsilon_start - (epsilon_start - epsilon_final) * total_steps / epsilon_decay_steps)
            if random.random() < epsilon:
                a = env.action_space.sample()
            else:
                a = agent.select_action(s)

            # Step
            o2, r, t, tr, _ = env.step(a)
            done = t or tr
            ep_reward += r
            s2 = np.array(o2, dtype=np.uint8)

            # Multi-step store
            nbuf.append((s, a, r, s2, done))
            if nbuf.is_ready():
                replay.push(nbuf.get_n_step_transition())
                nbuf.pop()            
            if done:
                # Flush remaining
                while len(nbuf.buffer) > 0:
                    replay.push(nbuf.get_n_step_transition())
                    nbuf.pop()
            s = s2

            # Learning step
            if total_steps % learning_freq == 0 and len(replay) >= batch_size:
                beta = min(1.0, beta_start + total_steps * (1.0 - beta_start) / beta_frames)
                s_b, a_b, r_b, s2_b, d_b, idxs, w_b = replay.sample(batch_size, beta)
                policy_net.train()
                with torch.no_grad():
                    next_a = policy_net(s2_b).argmax(1, keepdim=True)
                    next_q = target_net(s2_b).gather(1, next_a).squeeze(1)
                    target = r_b + (1 - d_b) * (gamma ** n_steps) * next_q
                current = policy_net(s_b).gather(1, a_b.unsqueeze(1)).squeeze(1)
                td_err = target - current
                loss = (w_b * F.smooth_l1_loss(current, target, reduction='none')).mean()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 1.0)
                optimizer.step()
                
                replay.update_priorities(idxs, td_err.detach())

            # Target net update
            if total_steps % target_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())
                print(f"Updated target network at step {total_steps}")

            step_times.append(time.time() - step_start)

        episode_rewards.append(ep_reward)
        avg_reward = np.mean(episode_rewards)

        # Periodic evaluation
        if episode % eval_interval == 0:
            eval_score = evaluate_agent()
            eval_scores.append(eval_score)
            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_best_model.pth")
                torch.save(policy_net.state_dict(), best_path)
                print(f"! NEW BEST! Episode {episode} | Eval Score: {eval_score:.1f} | Best: {best_eval_score:.1f}")

        # Logging and checkpointing
        if episode % report_interval == 0:
            path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_ep{episode}.pth")
            torch.save(policy_net.state_dict(), path)
            current_lr = optimizer.param_groups[0]['lr']
            avg_step_time = np.mean(step_times) * 1000 if step_times else 0
            steps_per_sec = 1.0 / np.mean(step_times) if step_times and np.mean(step_times) > 0 else 0

            print(f"Episode {episode:4d} | Reward {ep_reward:7.2f} | Avg {avg_reward:7.2f} | "
                  f"Eps {epsilon:.3f} | LR {current_lr:.2e} | Buffer {len(replay):6d} | "
                  f"Steps {total_steps} | {steps_per_sec:.1f} steps/s | {avg_step_time:.1f}ms/step")

    print(f"Training complete! Best eval score: {best_eval_score:.1f}")
    return agent


trained = resume_training_dqn(
    env_name="ALE/MsPacman-v5",
    num_episodes=50000,
    seed=42,
    buffer_refill_episodes=1000,
)

Using device: mps
No checkpoint found or resume disabled. Starting fresh training.


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Observation shape: (4, 84, 84), Actions: 9
Creating replay buffer...
Current epsilon: 1.0000 based on 0 total steps

Warming up replay buffer with random play (target: 50000 steps)...
Warmup episode 10, Steps: 1775/50000, Reward: 8.0, Buffer size: 1775
Warmup episode 20, Steps: 3552/50000, Reward: 17.0, Buffer size: 3552
Warmup episode 30, Steps: 5478/50000, Reward: 12.0, Buffer size: 5478


Now, after months of work, R2D2 implementation! Note that I terminate training manually, which is fine because I can continue training from any point

In [None]:
import os
import ale_py
import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
from gymnasium.wrappers import AtariPreprocessing
import glob
import re
import numba
from numba import jit, prange
from typing import Tuple, Optional, List

# Try to import FrameStackObservation for newer gym versions
try:
    from gymnasium.wrappers import FrameStackObservation as FrameStack
    FS = "stack_size"
except ImportError:
    from gymnasium.wrappers import FrameStack
    FS = "num_stack"

def set_global_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Value rescaling function:
def value_transform(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # Forward transform is defined as h(x) = sign(x)*(sqrt(|x|+1)-1) + eps*x
    # It squashes large values while preserving sign and being invertible!
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + eps * x

# Inverse value rescaling function:
def inv_value_transform(y: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # Inverse transform is calculated as h^(-1)(y) = sign(y)*(z^2 - 1)
    # where z = (sqrt(1 + 4 * eps * (|y| + 1 + eps)) - 1) / (2 * eps)
    sign = torch.sign(y)
    y_abs = torch.abs(y)
    # solve quadratic for |x|: z = (sqrt(1+4*eps*(y_abs+1+eps)) -1)/(2*eps)
    z = (torch.sqrt(1.0 + 4.0 * eps * (y_abs + 1.0 + eps)) - 1.0) / (2.0 * eps)
    return sign * (z * z - 1.0)

# Vectorized n-step targets calculation (for sequences), takes rewards, dones and Q-values
def compute_n_step_targets_vectorized(
    rewards: torch.Tensor,      # (B, U)
    dones: torch.Tensor,        # (B, U) with elements in {0,1}
    q_bootstrap: torch.Tensor,  # (B, U) value for s_{t+n} aligned at t (zeros where invalid)
    gamma: float,
    n_steps: int
) -> torch.Tensor:
    """
    Vectorized n-step targets using sliding windows via as_strided.
    For each t:
      G_t^(n) = sum_{k=0}^{n-1} (gamma^k * r_{t+k} * Product_{j=0}^{k-1}(1-d_{t+j}))
                + (gamma^n * Product_{j=0}^{n-1}(1-d_{t+j})) * Q_{t+n}
    This assumes q_bootstrap is 0 wherever t+n exceeds the sequence.
    """
    B, U = rewards.shape
    device = rewards.device
    dtype = rewards.dtype

    not_done = (1.0 - dones).to(dtype)

    # Rewards windows (B, U, n_steps):
    # Right-pad rewards with zeros to make all windows valid.
    rewards_pad = F.pad(rewards, (0, n_steps - 1), value=0.0)  # (B, U + n - 1)
    s0, s1 = rewards_pad.stride()                              # (U+n-1, 1)
    r_win = rewards_pad.as_strided(
        size=(B, U, n_steps),
        stride=(s0, s1, s1)
    )  # (B, U, n_steps)

    # Survival windows via cumulative products:
    # cum[:, t] = Product_{j=0..t} not_done_j
    cum = torch.cumprod(not_done, dim=1)                        # (B, U)
    # cum_pad[:, t] = Product_{j=0..t-1} not_done_j   (leading 1)
    one = torch.ones(B, 1, device=device, dtype=dtype)
    cum_pad = torch.cat([one, cum], dim=1)                      # (B, U+1)
    # pad more 1s at the end so t+n is always addressable
    cum_pad = F.pad(cum_pad, (0, n_steps), value=1.0)           # (B, U+1+n)

    s0c, s1c = cum_pad.stride()
    # Window length n_steps+1 gives us indices t .. t+n
    cum_win = cum_pad.as_strided(
        size=(B, U, n_steps + 1),
        stride=(s0c, s1c, s1c)
    )                                                           # (B, U, n_steps+1)

    den = cum_win[:, :, 0:1]    # Product_{j=0..t-1} not_done_j    (B, U, 1)
    num_seq = cum_win[:, :, 1:] # Product_{j=0..t+k-1} not_done_j   (B, U, n_steps)

    # survival_at_k[:, :, 0] must be 1
    surv_k1 = num_seq / den.clamp_min(1e-12)                    # (B, U, n_steps)
    ones_k0 = torch.ones(B, U, 1, device=device, dtype=dtype)   # (B, U, 1)
    survival_at_k = torch.cat([ones_k0, surv_k1], dim=2)        # (B, U, n_steps+1) but we need first n_steps
    survival_at_k = survival_at_k[:, :, :n_steps]               # (B, U, n_steps)

    # Discounted reward sum
    gamma_powers = (gamma ** torch.arange(n_steps, device=device, dtype=dtype))  # (n_steps,)
    discounts = gamma_powers.view(1, 1, -1)                                      # (1,1,n_steps)

    returns = (r_win * discounts * survival_at_k).sum(dim=2)                     # (B, U)

    # Bootstrap over n steps
    # Product_{j=t..t+n-1} not_done_j = (Product_{j=0..t+n-1} / Product_{j=0..t-1})
    surv_n = cum_win[:, :, -1] / den.squeeze(2).clamp_min(1e-12)                 # (B, U)

    targets = returns + (gamma ** n_steps) * surv_n * q_bootstrap                # (B, U)
    return targets

from collections import OrderedDict
import struct

from episodic_replay_buffer import EpisodicReplayBuffer

from typing import Tuple

class LSTMCellUnroller(nn.Module):
    def __init__(self, cells: nn.ModuleList):
        super().__init__()
        self.cells = cells

    @torch.jit.export
    def forward(
        self,
        lstm_in: torch.Tensor, # (B, T, I)
        h0: torch.Tensor,      # (L, B, H)
        c0: torch.Tensor       # (L, B, H)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        B, T, _ = lstm_in.shape
        L, H    = h0.size(0), h0.size(2)

        h = list(h0.unbind(0))
        c = list(c0.unbind(0))
        out = torch.empty(B, T, H, dtype=lstm_in.dtype, device=lstm_in.device)

        for t in range(T):
            x = lstm_in[:, t, :]
            # TorchScript accepts enumerate over ModuleList
            for l, cell in enumerate(self.cells):
                h[l], c[l] = cell(x, (h[l], c[l]))
                x = h[l] # feed upward
            out[:, t, :] = h[-1]

        return out, torch.stack(h, 0), torch.stack(c, 0)

class RecurrentDuelingDQN(nn.Module):
    # R2D2-style recurrent DQN with CNN -> LSTM -> fully-connected architecture.
    def __init__(self, input_shape, num_actions, lstm_hidden_size=512, turn_off_lstm=True):
        super().__init__()
        c, h, w = input_shape
        self.num_actions = int(num_actions)
        self.lstm_hidden_size = int(lstm_hidden_size)
        self.turn_off_lstm = turn_off_lstm

        # CNN feature extraction
        self.features = nn.Sequential(
            nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
            nn.Flatten()
        )

        # Calculate CNN output size
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            cnn_output_size = self.features(dummy).shape[1]

        if self.turn_off_lstm:
            self.value_stream = nn.Sequential(
                nn.Linear(cnn_output_size, 512), nn.ReLU(),
                nn.Linear(512, 1)
            )
            self.advantage_stream = nn.Sequential(
                nn.Linear(cnn_output_size, 512), nn.ReLU(),
                nn.Linear(512, num_actions)
            )
            return

        # LSTM input = CNN features + one-hot previous action + previous reward
        lstm_input_size = int(cnn_output_size + num_actions + 1)
        # self.lstm = nn.LSTM(lstm_input_size, lstm_hidden_size, batch_first=True)
        self.num_layers = 1
        self.lstm_cells = nn.ModuleList(
            [nn.LSTMCell(lstm_input_size, lstm_hidden_size)]
        )

        # Dueling streams
        self.value_stream = nn.Sequential(
            nn.Linear(lstm_hidden_size, 512), nn.ReLU(),
            nn.Linear(512, 1)
        )
        self.advantage_stream = nn.Sequential(
            nn.Linear(lstm_hidden_size, 512), nn.ReLU(),
            nn.Linear(512, num_actions)
        )
        self._unroller = torch.jit.script(LSTMCellUnroller(self.lstm_cells))

    def forward(self, states, prev_actions, prev_rewards, hidden_state=None, out=None):
        """
        Args:
            states: (batch_size, seq_len, C, H, W) or (batch_size, C, H, W)
            prev_actions: (batch_size, seq_len) or (batch_size,)
            prev_rewards: (batch_size, seq_len) or (batch_size,)
            hidden_state: tuple of (h, c) each (1, batch_size, lstm_hidden_size) or None
            out: optional output tensor to write results to (to avoid allocation)
        """
        if states.dim() == 4: # Single step
            states = states.unsqueeze(1)
            prev_actions = prev_actions.unsqueeze(1)
            prev_rewards = prev_rewards.unsqueeze(1)
            single_step = True
        else:
            single_step = False
            
        batch_size, seq_len = states.shape[:2]
        
        # Process through CNN
        states_flat = states.reshape(-1, states.size(2), states.size(3), states.size(4))  # (B*T, C, H, W)
        cnn_features = self.features(states_flat) # (B*T, cnn_output_size)
        cnn_features = cnn_features.view(batch_size, seq_len, -1) # (B, T, cnn_output_size)
        
        if self.turn_off_lstm:
            # Dueling Q-values
            values = self.value_stream(cnn_features) # (B, T, 1)
            advantages = self.advantage_stream(cnn_features) # (B, T, num_actions)
            
            # Compute dueling Q-values with optional out parameter
            if out is not None:
                # Compute advantages - advantages.mean() in-place into out
                advantages_mean = advantages.mean(dim=-1, keepdim=True)
                torch.sub(advantages, advantages_mean, out=out)
                out.add_(values) # Add values in-place
                # Apply value transform in-place
                out.copy_(value_transform(out))
                if single_step:
                    out.squeeze_(1) # In-place squeeze
                q_values = out
            else:
                # Original allocation-based computation
                q_values = value_transform(values + (advantages - advantages.mean(dim=-1, keepdim=True)))
                if single_step:
                    q_values = q_values.squeeze(1) # (B, num_actions)
            
            return q_values, hidden_state if hidden_state else (torch.zeros((512)), torch.zeros((512)))
        
        # One-hot encode previous actions
        prev_actions_onehot = F.one_hot(prev_actions, self.num_actions).float() # (B, T, num_actions)
        
        # Prepare previous rewards
        prev_rewards = prev_rewards.unsqueeze(-1) # (B, T, 1)

        # Concatenate inputs for LSTM
        lstm_input = torch.cat([cnn_features, prev_actions_onehot, prev_rewards], dim=-1) # (B, T, lstm_input_size)

        # Tensors (L, B, H)
        if hidden_state is None:
            h0 = torch.zeros(self.num_layers, batch_size,
                            self.lstm_hidden_size, device=states.device)
            c0 = torch.zeros_like(h0)
        else:
            h0, c0 = hidden_state

        # Unroll
        lstm_out, h_final, c_final = self._unroller(lstm_input, h0, c0)
        new_hidden = (h_final, c_final)
        
        # Dueling Q-values
        values = self.value_stream(lstm_out) # (B, T, 1)
        advantages = self.advantage_stream(lstm_out) # (B, T, num_actions)

        # Compute dueling Q-values into out buffer
        if out is not None:
            # Check shape compatibility
            expected_shape = (batch_size, seq_len, self.num_actions) if not single_step else (batch_size, self.num_actions)
            if out.shape != expected_shape:
                raise ValueError(f"out tensor shape {out.shape} doesn't match expected shape {expected_shape}")
            
            # Compute advantages - advantages.mean() in-place into out
            advantages_mean = advantages.mean(dim=-1, keepdim=True)
            torch.sub(advantages, advantages_mean, out=out)
            out.add_(values) # Add values in-place
            # Apply value transform in-place
            out.copy_(value_transform(out))
            if single_step:
                out.squeeze_(1) # In-place squeeze
            q_values = out
        else:
            # Calculate into a new buffer
            q_values = value_transform(values + (advantages - advantages.mean(dim=-1, keepdim=True)))
            if single_step:
                q_values = q_values.squeeze(1) # (B, num_actions)
        
        return q_values, new_hidden


class RecurrentDQNAgent:
    def __init__(self, network: nn.Module, device: torch.device, num_actions: int):
        self.network = network.to(device)
        self.device = device
        self.num_actions = num_actions
        self.reset_hidden_state()

    def reset_hidden_state(self):
        self.hidden_state = None
        self.prev_action = 0  # Start with action 0
        self.prev_reward = 0.0

    def select_action(self, state_np, override_action=None):
        # Select action with optional action override for epsilon-greedy
        state = torch.from_numpy(state_np).unsqueeze(0).float().to(self.device) / 255.0
        prev_action = torch.tensor([self.prev_action], dtype=torch.int64, device=self.device)
        prev_reward = torch.tensor([self.prev_reward], dtype=torch.float32, device=self.device)

        with torch.no_grad():
            q_values, new_hidden = self.network(state, prev_action, prev_reward, self.hidden_state)
            greedy_action = int(q_values.argmax(1).item())

        # Use override action if provided (for epsilon-greedy), otherwise greedy
        action = override_action if override_action is not None else greedy_action

        # Always advance hidden state on the observation and set executed action
        self.hidden_state = new_hidden
        self.prev_action = action

        return action
    
    def update_prev_reward(self, reward: float):
        # Update the previous reward for next action selection
        self.prev_reward = reward


def find_latest_checkpoint(checkpoint_dir, game_name):
    pattern = os.path.join(checkpoint_dir, f"ALE_{game_name}-v5_ep*.pth")
    files = glob.glob(pattern)
    if not files:
        return None, 0
    episodes = []
    for f in files:
        match = re.search(r'_ep(\d+)\.pth$', f)
        if match:
            episodes.append((int(match.group(1)), f))
    if episodes:
        episodes.sort(key=lambda x: x[0])
        return episodes[-1][1], episodes[-1][0]
    return None, 0


def resume_training_optimized_r2d2(
    game_name="SpaceInvaders",
    num_episodes=10000,
    replay_buffer_capacity=100_000,
    sequence_length=80,
    burn_in_length=40,
    batch_size=64,
    gamma=0.997,
    lr=6.25e-5,
    adam_eps=1.5e-4,
    n_steps=5,
    learning_freq=4,
    target_update_freq=2500,
    prio_alpha=0.9,
    beta_fixed=0.6,
    beta_start=0.4,
    beta_frames=1_000_000,
    report_interval=10,
    eval_interval=100,
    checkpoint_dir="checkpoints_optimized",
    seed=42,
    resume_from_checkpoint=True,
    N_ACTORS = 256,
    stride = 40,
    warmup_size=5000,
    turn_off_lstm=False,
    clip_rewards=False,
    linear_eps=False,
    override_eps=False
):
    set_global_seed(seed)
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Find latest checkpoint
    latest_checkpoint, last_episode = find_latest_checkpoint(checkpoint_dir, game_name)
    is_resuming = latest_checkpoint is not None and resume_from_checkpoint

    if is_resuming:
        print(f"Found checkpoint: {latest_checkpoint} (Episode {last_episode})")
        estimated_steps_per_episode = 500
        total_steps = last_episode * estimated_steps_per_episode
    else:
        print("No checkpoint found or resume disabled. Starting fresh training.")
        last_episode = 0
        total_steps = 0

    update_count = 0  # Track learner updates separately
    last_target_update_printed = -1

    # Environments - using value rescaling instead of reward clipping
    torch.set_num_threads(1) # let ALE use the cores
    os.environ["OMP_NUM_THREADS"] = "1"

    actor_eps = [0.4 ** (1 + 7 * i/(N_ACTORS - 1)) for i in range(N_ACTORS)] if N_ACTORS > 1 else [0.3]
    if linear_eps:
        actor_eps = list(np.arange(0, 1., 1./N_ACTORS))
    print('Epsilons:', actor_eps)

    from gymnasium.vector import AsyncVectorEnv
    from gymnasium.wrappers import TransformReward
    def make_env(rank: int):
        def _thunk():
            base = gym.make(f"{game_name}NoFrameskip-v4")
            wrapped = AtariPreprocessing(
                base, frame_skip=4, grayscale_obs=True, scale_obs=False,
                noop_max=30, terminal_on_life_loss=False)
            wrapped = FrameStack(wrapped, **{FS: 4})
            if clip_rewards:
                wrapped = TransformReward(wrapped, lambda r: np.clip(r, -1.0, 1.0))
            wrapped.reset()
            return wrapped
        return _thunk
    import multiprocessing as mp

    env_fns = [make_env(i) for i in range(N_ACTORS)]

    # try the new API first, this depends on module versions
    try:
        env = AsyncVectorEnv(
            env_fns,
            shared_memory=False,
            copy=False,
            context="fork",
        )
    except TypeError:
        # fall back to the old API
        ctx = mp.get_context("fork")
        env = AsyncVectorEnv(
            env_fns,
            shared_memory=False,
            copy=False,
            ctx=ctx
        )
    env_name = f"ALE/{game_name}-v5"
    eval_env = FrameStack(
        AtariPreprocessing(gym.make(f"{game_name}NoFrameskip-v4"), frame_skip=4, grayscale_obs=True, scale_obs=False, noop_max=30, terminal_on_life_loss=False),
        **{FS: 4}
    )

    obs, _ = env.reset(seed=seed)
    obs_shape = eval_env.observation_space.shape
    n_actions = eval_env.action_space.n
    print(f"Observation shape: {obs_shape}, Actions: {n_actions}")

    policy_net = RecurrentDuelingDQN(obs_shape, n_actions, turn_off_lstm=turn_off_lstm).to(device)
    target_net = RecurrentDuelingDQN(obs_shape, n_actions, turn_off_lstm=turn_off_lstm).to(device)

    if is_resuming:
        print(f"Loading model from {latest_checkpoint}")
        checkpoint_state = torch.load(latest_checkpoint, map_location=device)
        policy_net.load_state_dict(checkpoint_state)

    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=lr, eps=adam_eps)

    # Check for best model
    best_model_path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_best_model.pth")
    best_eval_score = -float('inf')
    if os.path.exists(best_model_path):
        print(f"Found best model at {best_model_path}")

    # Create sequence replay buffer
    print(f"Creating sequence replay buffer...")
    replay = EpisodicReplayBuffer(
        seq_len=sequence_length,
        device=device,
        capacity_sequences=replay_buffer_capacity,
        alpha=prio_alpha,
        stride=stride
    )
    print("Pre-allocating scratch buffers for memory-efficient learning...")
    # Pre-allocate buffers that will be reused across all learning steps
    scratch_buffers = {
        # Sample batch buffers
        'sample_states': torch.zeros(batch_size, sequence_length, 4, 84, 84, device=device, dtype=torch.float32),  #  C,H,W
        'sample_next_states': torch.zeros(batch_size, sequence_length, 4, 84, 84, device=device, dtype=torch.float32),
        'sample_actions': torch.zeros(batch_size, sequence_length, device=device, dtype=torch.long),
        'sample_rewards': torch.zeros(batch_size, sequence_length, device=device, dtype=torch.float32),
        'sample_dones': torch.zeros(batch_size, sequence_length, device=device, dtype=torch.float32),
        'sample_hidden': torch.zeros(batch_size, 2, 512, device=device, dtype=torch.float32),
        'sample_weights': torch.zeros(batch_size, device=device, dtype=torch.float32),

        # Hidden state buffers
        'h0_policy': torch.zeros(1, batch_size, 512, device=device, dtype=torch.float32),
        'c0_policy': torch.zeros(1, batch_size, 512, device=device, dtype=torch.float32),
        'h0_target': torch.zeros(1, batch_size, 512, device=device, dtype=torch.float32),
        'c0_target': torch.zeros(1, batch_size, 512, device=device, dtype=torch.float32),

        # Action and reward buffers
        'prev_actions': torch.zeros(batch_size, sequence_length, device=device, dtype=torch.long),
        'prev_rewards': torch.zeros(batch_size, sequence_length, device=device, dtype=torch.float32),
        'prev_actions_unroll': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.long),
        'prev_rewards_unroll': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
        'prev_actions_next': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.long),
        'prev_rewards_next': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),

        # Q-value buffers
        'q_next_policy': torch.zeros(batch_size, sequence_length - burn_in_length, n_actions, device=device, dtype=torch.float32),
        'q_next_target': torch.zeros(batch_size, sequence_length - burn_in_length, n_actions, device=device, dtype=torch.float32),
        'next_actions': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.long),
        'q_next_selected': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
        'q_next_real': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
        'q_bootstrap': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
        'targets_real': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
        'targets': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
        'current_q_values': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
        'td_errors': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
        'td_errors_max': torch.zeros(batch_size, device=device, dtype=torch.float32),
        'td_errors_mean': torch.zeros(batch_size, device=device, dtype=torch.float32),
        'td_errors_for_priority': torch.zeros(batch_size, device=device, dtype=torch.float32),
        'per_step_loss': torch.zeros(batch_size, sequence_length - burn_in_length, device=device, dtype=torch.float32),
    }
    sample_out = {
        'states': scratch_buffers['sample_states'],
        'next_states': scratch_buffers['sample_next_states'], 
        'actions': scratch_buffers['sample_actions'],
        'rewards': scratch_buffers['sample_rewards'],
        'dones': scratch_buffers['sample_dones'],
        'hidden': scratch_buffers['sample_hidden'],
        'weights': scratch_buffers['sample_weights']
    }

    agents = [RecurrentDQNAgent(policy_net, device, n_actions)
              for _ in range(N_ACTORS)]

    episode_rewards = deque(maxlen=100)
    eval_scores = deque(maxlen=20)

    def evaluate_agent(n_eps=7):
        scores = []
        for _ in range(n_eps):
            eval_agent = RecurrentDQNAgent(policy_net, device, n_actions)
            o, info = eval_env.reset()
            s = np.array(o, dtype=np.uint8)
            eval_agent.reset_hidden_state()
            done = False
            score = 0
            lives = info["lives"]
            while not done:
                a = eval_agent.select_action(s)
                o2, r, t, tr, i = eval_env.step(a)
                eval_agent.update_prev_reward(r)
                if i["lives"] < lives:
                    # eval_agent.reset_hidden_state() # no longer resetting because agent learned across lives
                    lives = i["lives"]
                done = t or tr
                score += r
                s = np.array(o2, dtype=np.uint8)
            scores.append(score)
        return np.mean(scores)

    def build_actor_state(agents, idx_list, device, hidden_size=512):
        if not idx_list: # no live actors
            return None
        h_list, c_list = [], []
        for i in idx_list:
            if agents[i].hidden_state is None:
                h_list.append(torch.zeros(1, 1, hidden_size, device=device))
                c_list.append(torch.zeros(1, 1, hidden_size, device=device))
            else:
                h, c = agents[i].hidden_state
                h_list.append(h.detach()) # detach -> no graph kept
                c_list.append(c.detach())
        h0 = torch.cat(h_list, dim=1) # (1, k, H)
        c0 = torch.cat(c_list, dim=1) # (1, k, H)
        return (h0, c0)

    # Evaluate current performance if resuming
    if is_resuming:
        print("Evaluating current model performance...")
        current_eval = evaluate_agent()
        print(f"Current model evaluation score: {current_eval:.1f}")

        if os.path.exists(best_model_path):
            best_net = RecurrentDuelingDQN(obs_shape, n_actions, turn_off_lstm=turn_off_lstm).to(device)
            best_net.load_state_dict(torch.load(best_model_path, map_location=device))
            temp_net = policy_net
            policy_net = best_net
            best_eval_score = evaluate_agent()
            policy_net = temp_net
            print(f"Best model evaluation score: {best_eval_score:.1f}")
        else:
            best_eval_score = current_eval

    import time
    
    # Continue/start training
    start_episode = last_episode + 1 if is_resuming else 1
    print(f"\n{'Resuming' if is_resuming else 'Starting'} training from episode {start_episode}...")

    step_times = deque(maxlen=1000)
    obs, _ = env.reset(seed=seed + start_episode)
    dones = np.zeros(N_ACTORS, dtype=bool)

    # incremental compression buffers
    ep_frames   = [[] for _ in range(N_ACTORS)]
    ep_actions = [[] for _ in range(N_ACTORS)]
    ep_rewards = [[] for _ in range(N_ACTORS)]
    ep_dones   = [[] for _ in range(N_ACTORS)]
    ep_hiddens = [[] for _ in range(N_ACTORS)] # LSTM h,c for every frame
    prev_f_gpu = [None] * N_ACTORS
    frames_after_key = [0]*N_ACTORS
    total_parallel_steps = 0

    episode_counter = start_episode
    last_eval = -1
    last_report = -1
    last_loss = 1000.
    while episode_counter < num_episodes:
        if len(replay) >= warmup_size:
            total_steps += N_ACTORS
        total_parallel_steps += 1

        if override_eps:
            actor_eps[:] = [max(0.01, 1.0 - (1.0 - 0.01) * total_steps / 500_000)] * N_ACTORS
        # Choose actions
        step_start = time.time()

        live_idx    = [i for i, d in enumerate(dones) if not d]
        obs_batch_u8 = torch.from_numpy(obs).to(device)
        obs_batch   = obs_batch_u8[live_idx].float() / 255.0
        prev_a      = torch.tensor([agents[i].prev_action  for i in live_idx],
                                device=device)
        prev_r      = torch.tensor([agents[i].prev_reward for i in live_idx],
                                device=device, dtype=torch.float32)
        h_state     = build_actor_state(agents, live_idx, device)

        if live_idx:
            with torch.no_grad():
                q, new_h = policy_net(obs_batch, prev_a, prev_r, h_state)
            greedy = q.argmax(dim=1).tolist() # list of actions

        # Write results back into global action list and agents
        actions = [0] * N_ACTORS
        for slot, i in enumerate(live_idx):
            a = greedy[slot]
            if random.random() < actor_eps[i] or update_count < 3:
                a = env.single_action_space.sample() # random - keep new_h as is
            new_hidden_state = (new_h[0][:, slot:slot+1],
                                new_h[1][:, slot:slot+1])
            if agents[i].hidden_state is None:
                ep_hiddens[i].append((torch.zeros_like(new_hidden_state[0], device=device), torch.zeros_like(new_hidden_state[1], device=device)))
            else:
                ep_hiddens[i].append((agents[i].hidden_state[0].detach(), agents[i].hidden_state[1].detach()))
            agents[i].hidden_state = new_hidden_state
            agents[i].prev_action  = a # this is what matters later
            actions[i] = a

        # Environment step
        next_obs, rewards, terms, truncs, infos = env.step(actions)
        dones = np.logical_or(terms, truncs)

        # Store transitions and handle episode ends
        for i in range(N_ACTORS):
            agents[i].update_prev_reward(rewards[i])

            ep_actions[i].append(actions[i])
            ep_rewards[i].append(rewards[i])
            ep_dones[i].append(dones[i])
            ep_frames[i].append(obs_batch_u8[i])
            if dones[i]:
                frames_u8  = torch.stack(ep_frames[i], dim=0)
                episode_counter += 1
                episode_rewards.append(np.sum(ep_rewards[i]))
                actions_u8 = np.asarray(ep_actions[i], dtype=np.uint8)
                rewards_f32 = np.asarray(ep_rewards[i], dtype=np.float32)
                dones_u8 = np.asarray(ep_dones[i], dtype=np.uint8)
                replay.push_episode(
                    frames_u8=frames_u8,
                    actions_u8=actions_u8,
                    rewards=rewards_f32,
                    dones_u8=dones_u8,
                    hiddens=ep_hiddens[i]
                )
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()
                import gc
                gc.collect()
                # reset buffers, frame data and agent for this actor
                ep_frames[i].clear(); ep_actions[i].clear()
                ep_hiddens[i].clear()
                ep_rewards[i].clear(); ep_dones[i].clear()
                prev_f_gpu[i] = None
                frames_after_key[i] = 0
                agents[i].reset_hidden_state()

        del obs_batch_u8, obs_batch
        obs = next_obs
        step_times.append((time.time() - step_start) / N_ACTORS)

        if total_parallel_steps % learning_freq == 0 and len(replay) >= warmup_size:
            # Update beta based on learner updates instead of steps (ultimately not used, because I use fixed beta 0.6 as in R2D2)
            beta = min(1.0, beta_start + update_count * (1.0 - beta_start) / (beta_frames // learning_freq)) if not beta_fixed else beta_fixed
            
            # Sample sequences with batched GPU transfers
            (s_batch, a_batch, r_batch, s2_batch, d_batch, 
                h_batch, indices, w_batch) = replay.sample(batch_size, beta, out=sample_out)
            
            policy_net.train()
            
            # Reuse buffers for hidden states
            # Copy same hidden states into two separate buffers for policy and target, because policy and target networks diverge after burn-in.
            # Still we use pre-allocated buffers (no new allocations)
            scratch_buffers['h0_policy'].copy_(h_batch[:, 0, :].unsqueeze(0))
            scratch_buffers['c0_policy'].copy_(h_batch[:, 1, :].unsqueeze(0))
            scratch_buffers['h0_target'].copy_(h_batch[:, 0, :].unsqueeze(0)) 
            scratch_buffers['c0_target'].copy_(h_batch[:, 1, :].unsqueeze(0))
            hidden_state_policy = (scratch_buffers['h0_policy'], scratch_buffers['c0_policy'])
            hidden_state_target = (scratch_buffers['h0_target'], scratch_buffers['c0_target'])
            
            # Reuse buffers for actions/rewards
            # Zero out and fill previous action/reward buffers
            scratch_buffers['prev_actions'].zero_()
            scratch_buffers['prev_rewards'].zero_()
            # Shift actions and rewards: prev_action[t] = action[t-1], as LSTM needs prev_a, prev_r
            scratch_buffers['prev_actions'][:, 1:].copy_(a_batch[:, :-1])
            scratch_buffers['prev_rewards'][:, 1:].copy_(r_batch[:, :-1])
            
            # Burn-in phase
            with torch.no_grad():
                # Policy network burn-in
                _, hidden_after_burnin_temp = policy_net(
                    s_batch[:, :burn_in_length], 
                    scratch_buffers['prev_actions'][:, :burn_in_length], 
                    scratch_buffers['prev_rewards'][:, :burn_in_length], 
                    hidden_state_policy)
                
                # Copy burn-in results into our persistent buffers (detached)
                scratch_buffers['h0_policy'].copy_(hidden_after_burnin_temp[0].detach())
                scratch_buffers['c0_policy'].copy_(hidden_after_burnin_temp[1].detach())
                hidden_after_burnin_policy = (scratch_buffers['h0_policy'], scratch_buffers['c0_policy'])
                
                # Target network burn-in 
                _, hidden_after_burnin_temp = target_net(
                    s_batch[:, :burn_in_length], 
                    scratch_buffers['prev_actions'][:, :burn_in_length], 
                    scratch_buffers['prev_rewards'][:, :burn_in_length], 
                    hidden_state_target)
                
                # Copy target burn-in results
                scratch_buffers['h0_target'].copy_(hidden_after_burnin_temp[0].detach())
                scratch_buffers['c0_target'].copy_(hidden_after_burnin_temp[1].detach())
                hidden_after_burnin_target = (scratch_buffers['h0_target'], scratch_buffers['c0_target'])
            
            # Forward pass for unroll (with gradients)
            q_values, _ = policy_net(
                s_batch[:, burn_in_length:], 
                scratch_buffers['prev_actions'][:, burn_in_length:], 
                scratch_buffers['prev_rewards'][:, burn_in_length:], 
                hidden_after_burnin_policy,
                out=None  # Allocate results since we need gradients
            )
            
            # Prepare unroll data (reuse buffers)
            unroll_length = sequence_length - burn_in_length
            
            # Copy unroll portions into pre-allocated buffers
            scratch_buffers['prev_actions_unroll'].copy_(a_batch[:, burn_in_length:])
            scratch_buffers['prev_rewards_unroll'].copy_(r_batch[:, burn_in_length:])
            
            # Prepare previous actions/rewards for next states (reuse buffers)
            scratch_buffers['prev_actions_next'].zero_()
            scratch_buffers['prev_rewards_next'].zero_()
            scratch_buffers['prev_actions_next'][:, 1:].copy_(scratch_buffers['prev_actions_unroll'][:, :-1].float())
            scratch_buffers['prev_rewards_next'][:, 1:].copy_(scratch_buffers['prev_rewards_unroll'][:, :-1])
            # First timestep of next states uses last action from burn-in
            scratch_buffers['prev_actions_next'][:, 0].copy_(a_batch[:, burn_in_length].float())
            scratch_buffers['prev_rewards_next'][:, 0].copy_(r_batch[:, burn_in_length])

            # Target computation (reuse all buffers)
            with torch.no_grad():
                # Double DQN: get actions from policy network
                policy_net(s2_batch[:, burn_in_length:], 
                        scratch_buffers['prev_actions_next'], 
                        scratch_buffers['prev_rewards_next'], 
                        hidden_after_burnin_policy,
                        out=scratch_buffers['q_next_policy']) # In-place

                # Get actions
                torch.argmax(scratch_buffers['q_next_policy'], dim=-1, out=scratch_buffers['next_actions'])

                # Get Q-values from target network
                target_net(s2_batch[:, burn_in_length:], 
                        scratch_buffers['prev_actions_next'], 
                        scratch_buffers['prev_rewards_next'], 
                        hidden_after_burnin_target,
                        out=scratch_buffers['q_next_target'])

                # Gather Q-values for selected actions
                torch.gather(scratch_buffers['q_next_target'], -1, 
                            scratch_buffers['next_actions'].unsqueeze(-1), 
                            out=scratch_buffers['q_next_selected'])
                scratch_buffers['q_next_selected'].squeeze_(-1) # In-place squeeze

                # Apply inverse value transform
                scratch_buffers['q_next_real'].copy_(inv_value_transform(scratch_buffers['q_next_selected']))

                # Build Q_{t+n} by shifting for proper n-step bootstrap
                scratch_buffers['q_bootstrap'].zero_()
                if n_steps > 1:
                    # Shift by n_steps-1 to get Q(s_{t+n}) at position t
                    end_idx = unroll_length - (n_steps - 1)
                    if end_idx > 0:
                        scratch_buffers['q_bootstrap'][:, :end_idx].copy_(
                            scratch_buffers['q_next_real'][:, (n_steps-1):])
                else:
                    scratch_buffers['q_bootstrap'].copy_(scratch_buffers['q_next_real'])

                # Use vectorized n-step target computation
                scratch_buffers['targets_real'].copy_(
                    compute_n_step_targets_vectorized(
                        scratch_buffers['prev_rewards_unroll'], 
                        d_batch[:, burn_in_length:], 
                        scratch_buffers['q_bootstrap'], 
                        gamma, n_steps))

                # Apply value transform to targets
                scratch_buffers['targets'].copy_(value_transform(scratch_buffers['targets_real']))

            # Loss computation
            # Get current Q-values
            current_q = torch.gather(
                q_values, # (B, L, A)
                -1,
                scratch_buffers['prev_actions_unroll'].unsqueeze(-1)  # (B, L, 1)
            ).squeeze(-1) # -> (B, L)
            # scratch_buffers['current_q_values'].squeeze_(-1)  # In-place squeeze

            # Compute TD errors
            torch.sub(scratch_buffers['targets'], current_q.detach(), 
                    out=scratch_buffers['td_errors'])

            # Priority computation: TD-error stats
            torch.abs(scratch_buffers['td_errors'], out=scratch_buffers['td_errors'])  # In-place abs
            td_err_max, _ = torch.max(scratch_buffers['td_errors'], dim=1)

            td_err_mean = scratch_buffers['td_errors'].mean(dim=1) * 0.1

            # Combine max and mean
            torch.mul(td_err_max, 0.9, out=scratch_buffers['td_errors_for_priority'])
            scratch_buffers['td_errors_for_priority'].add_(td_err_mean) # already multiplied by 0.1
            
            # Loss computation
            #F.smooth_l1_loss(scratch_buffers['current_q_values'], scratch_buffers['targets'], 
            #                reduction='none', out=scratch_buffers['per_step_loss'])
            per_step_loss = F.smooth_l1_loss(
                current_q,                  # prediction that carries grad
                scratch_buffers['targets'], # target (no grad)
                reduction='none'
            )

            loss = (w_batch.unsqueeze(1) * per_step_loss).mean()
            last_loss = loss.detach().item() # Don't create new tensor

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            
            if update_count % 500 == 0:
                v, a = None, None
                for n, p in policy_net.named_parameters():
                    if p.grad is None: continue
                    if 'value_stream.2.weight' in n:
                        v = p.grad.abs().mean().item()
                    if 'advantage_stream.2.weight' in n:
                        a = p.grad.abs().mean().item()
                print(f'grad ‖value‖={v:.2e}  ‖advantage‖={a:.2e}')
            
            torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 40.)
            optimizer.step()
            
            update_count += 1
            
            # Update priorities in replay buffer
            replay.update_priorities(indices, scratch_buffers['td_errors_for_priority'].detach())
            
            # Clean up sampled data immediately
            del s_batch, a_batch, r_batch, s2_batch, d_batch, h_batch, w_batch, indices
            
            # Force MPS cache cleanup periodically
            if update_count % 100 == 0 and torch.backends.mps.is_available():
                torch.mps.empty_cache()
        # Target net update based on learner updates
        if update_count > 0 and update_count % target_update_freq == 0:
            if last_target_update_printed != update_count:
                target_net.load_state_dict(policy_net.state_dict())
                last_target_update_printed = update_count
                print(f"Updated target network at update {update_count}")

        avg_reward = np.mean(episode_rewards)

        # Periodic evaluation
        if episode_counter % eval_interval == 0 and last_eval != episode_counter:
            eval_score = evaluate_agent()
            last_eval = episode_counter
            eval_scores.append(eval_score)

            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_best_model.pth")
                torch.save(policy_net.state_dict(), best_path)
                print(f"! NEW BEST! Episode {episode_counter} | Eval Score: {eval_score:.1f} | Best: {best_eval_score:.1f}")

        # Logging and checkpointing
        if episode_counter % report_interval == 0 and last_report != episode_counter:
            last_report = episode_counter
            path = os.path.join(checkpoint_dir, f"{env_name.replace('/', '_')}_ep{episode_counter}.pth")
            torch.save(policy_net.state_dict(), path)
            current_lr = optimizer.param_groups[0]['lr']
            avg_step_time = np.mean(step_times) * 1000 if step_times else 0
            steps_per_sec = 1.0 / np.mean(step_times) if step_times and np.mean(step_times) > 0 else 0
            print(f"Episode {episode_counter:4d} | Reward {episode_rewards[-1]:7.2f} | Avg {avg_reward:7.2f} | "
                  f"Loss {last_loss:.2f} | Updates {update_count} | LR {current_lr:.2e} | Sequences {len(replay):6d} | "
                  f"Steps {total_steps} | {steps_per_sec:.1f} steps/s | {avg_step_time:.1f}ms/step" + (f" | Epsilon {actor_eps[0]:.2f}" if override_eps else ""))

    print(f"Training complete! Best eval score: {best_eval_score:.1f}")
    return agents

In [None]:
# Pacman #1
trained_agent = resume_training_optimized_r2d2(
    game_name="MsPacman",
    num_episodes=50000000,
    replay_buffer_capacity=100_000,
    report_interval=10,
    eval_interval=1000,
    warmup_size=10,
    N_ACTORS=32,
    learning_freq=80,
    stride=40,
    sequence_length=120,
    burn_in_length=40,
    batch_size=64,
    gamma=0.997,
    n_steps=5,
    seed=42,
    lr=1e-4,
    adam_eps=1e-7,
    checkpoint_dir="checkpoints_pacman",
    target_update_freq=2500,
    turn_off_lstm=False,
    clip_rewards=False,
    linear_eps=False,
    override_eps=False,
    beta_fixed=0.6,
    prio_alpha=0.9,
)

Using device: mps
No checkpoint found or resume disabled. Starting fresh training.
Epsilons: [0.4, 0.32523896474908454, 0.2644509604776406, 0.21502439153162226, 0.17483577624386654, 0.14215851716664438, 0.11558872238386095, 0.09398489101199065, 0.07641887163698842, 0.06213598674626823, 0.05052261000754778, 0.041079803438191446, 0.033401881855833176, 0.027158983688656022, 0.022082899346339398, 0.017955548305154154, 0.014599609855676281, 0.011870904988001713, 0.009652202122331054, 0.007848180564539672, 0.006381335304936917, 0.0051886472207361665, 0.004218875626301115, 0.0034303568527583005, 0.002789214278777594, 0.002267902911232471, 0.0018440259875017097, 0.001499372757863661, 0.0012191361088513923, 0.0009912764148276346, 0.0008060042873468099, 0.0006553600000000003]


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Observation shape: (4, 84, 84), Actions: 9
Creating optimized sequence replay buffer with delta compression...
Pre-allocating scratch buffers for memory-efficient learning...

Starting training from episode 1...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


grad ‖value‖=1.13e-03  ‖advantage‖=2.19e-05
Episode   10 | Reward  120.00 | Avg  165.56 | Loss 0.30 | Updates 1 | LR 1.00e-04 | Sequences     77 | Steps 1952 | 2139.0 steps/s | 0.5ms/step


  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


Episode   20 | Reward  280.00 | Avg  194.21 | Loss 0.27 | Updates 2 | LR 1.00e-04 | Sequences    175 | Steps 3520 | 1842.6 steps/s | 0.5ms/step
Episode   30 | Reward  340.00 | Avg  215.52 | Loss 0.34 | Updates 4 | LR 1.00e-04 | Sequences    290 | Steps 8192 | 1778.0 steps/s | 0.6ms/step
Episode   50 | Reward   60.00 | Avg  207.76 | Loss 0.34 | Updates 8 | LR 1.00e-04 | Sequences    485 | Steps 19008 | 1721.9 steps/s | 0.6ms/step
Episode   60 | Reward   90.00 | Avg  208.31 | Loss 0.23 | Updates 11 | LR 1.00e-04 | Sequences    599 | Steps 26272 | 1629.7 steps/s | 0.6ms/step
Episode   70 | Reward   80.00 | Avg  194.78 | Loss 0.02 | Updates 12 | LR 1.00e-04 | Sequences    700 | Steps 30368 | 1524.9 steps/s | 0.7ms/step
Episode   80 | Reward   60.00 | Avg  199.49 | Loss 0.16 | Updates 14 | LR 1.00e-04 | Sequences    822 | Steps 34432 | 1569.3 steps/s | 0.6ms/step
Episode   90 | Reward   60.00 | Avg  188.65 | Loss 0.18 | Updates 17 | LR 1.00e-04 | Sequences    925 | Steps 41888 | 1655.4 step

KeyboardInterrupt: 

In [None]:
trained_agent = resume_training_optimized_r2d2(
    game_name="Pong",
    num_episodes=50000000,
    replay_buffer_capacity=100_000,
    report_interval=50,#50,
    eval_interval=1000,
    warmup_size=10,#5000,
    N_ACTORS=32,#256,
    learning_freq=80,#1, #15,
    stride=40,
    sequence_length=120,
    burn_in_length=40,
    batch_size=64,#64
    gamma=0.997,#,#0.997,
    n_steps=5,# 10,
    seed=42,
    lr=1e-4,
    adam_eps=1e-7,
    checkpoint_dir="checkpoints_pong",
    target_update_freq=2500,
    turn_off_lstm=True,#False,#True,
    clip_rewards=False,#False,#True,
    linear_eps=False,#True,
    override_eps=False,#True,
    beta_fixed=0.6,#False,
    prio_alpha=0.9,
    # beta_start=0.4,
    # beta_frames=1_000_000
)

Using device: mps
Found checkpoint: checkpoints/ALE_Pong-v5_ep117700.pth (Episode 117700)
Epsilons: [0.4, 0.32523896474908454, 0.2644509604776406, 0.21502439153162226, 0.17483577624386654, 0.14215851716664438, 0.11558872238386095, 0.09398489101199065, 0.07641887163698842, 0.06213598674626823, 0.05052261000754778, 0.041079803438191446, 0.033401881855833176, 0.027158983688656022, 0.022082899346339398, 0.017955548305154154, 0.014599609855676281, 0.011870904988001713, 0.009652202122331054, 0.007848180564539672, 0.006381335304936917, 0.0051886472207361665, 0.004218875626301115, 0.0034303568527583005, 0.002789214278777594, 0.002267902911232471, 0.0018440259875017097, 0.001499372757863661, 0.0012191361088513923, 0.0009912764148276346, 0.0008060042873468099, 0.0006553600000000003]


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Observation shape: (4, 84, 84), Actions: 6
Loading model from checkpoints/ALE_Pong-v5_ep117700.pth


  checkpoint_state = torch.load(latest_checkpoint, map_location=device)


Creating optimized sequence replay buffer with delta compression...
Pre-allocating scratch buffers for memory-efficient learning...
Evaluating current model performance...
Current model evaluation score: 21.0

Resuming training from episode 117701...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


grad ‖value‖=3.90e-04  ‖advantage‖=5.79e-06


  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


Episode 117750 | Reward   18.00 | Avg   -2.43 | Loss 0.00 | Updates 28 | LR 1.00e-04 | Sequences   1792 | Steps 58921808 | 1698.6 steps/s | 0.6ms/step
Episode 117800 | Reward   20.00 | Avg    7.81 | Loss 0.00 | Updates 66 | LR 1.00e-04 | Sequences   4033 | Steps 59017616 | 2051.8 steps/s | 0.5ms/step
Episode 117850 | Reward   21.00 | Avg   18.13 | Loss 0.00 | Updates 105 | LR 1.00e-04 | Sequences   6279 | Steps 59117552 | 2035.3 steps/s | 0.5ms/step
Episode 117900 | Reward   17.00 | Avg   18.32 | Loss 0.00 | Updates 144 | LR 1.00e-04 | Sequences   8530 | Steps 59218480 | 2185.1 steps/s | 0.5ms/step
Episode 117950 | Reward   18.00 | Avg   18.45 | Loss 0.00 | Updates 173 | LR 1.00e-04 | Sequences  10698 | Steps 59293648 | 1773.2 steps/s | 0.6ms/step
Episode 118000 | Reward   21.00 | Avg   18.80 | Loss 0.00 | Updates 210 | LR 1.00e-04 | Sequences  12852 | Steps 59386512 | 2024.7 steps/s | 0.5ms/step
Episode 118050 | Reward   21.00 | Avg   18.50 | Loss 0.00 | Updates 248 | LR 1.00e-04 | Se

KeyboardInterrupt: 

In [None]:
trained_agent = resume_training_optimized_r2d2(
    game_name="SpaceInvaders",
    num_episodes=50000000,
    replay_buffer_capacity=100_000,
    report_interval=50,#50,
    eval_interval=1000,
    warmup_size=10,#5000,
    N_ACTORS=32,#256,
    learning_freq=80,#1, #15,
    stride=40,
    sequence_length=120,
    burn_in_length=40,
    batch_size=64,#64
    gamma=0.997,#,#0.997,
    n_steps=5,# 10,
    seed=42,
    lr=1e-4,
    adam_eps=1e-7,
    checkpoint_dir="checkpoints_space_invaders",
    target_update_freq=2500,
    turn_off_lstm=False,#False,#True,
    clip_rewards=False,#False,#True,
    linear_eps=False,#True,
    override_eps=False,#True,
    beta_fixed=0.6,#False,
    prio_alpha=0.9
)

Using device: mps
Found checkpoint: checkpoints_space_invaders/ALE_SpaceInvaders-v5_ep266450.pth (Episode 266450)
Epsilons: [0.4, 0.32523896474908454, 0.2644509604776406, 0.21502439153162226, 0.17483577624386654, 0.14215851716664438, 0.11558872238386095, 0.09398489101199065, 0.07641887163698842, 0.06213598674626823, 0.05052261000754778, 0.041079803438191446, 0.033401881855833176, 0.027158983688656022, 0.022082899346339398, 0.017955548305154154, 0.014599609855676281, 0.011870904988001713, 0.009652202122331054, 0.007848180564539672, 0.006381335304936917, 0.0051886472207361665, 0.004218875626301115, 0.0034303568527583005, 0.002789214278777594, 0.002267902911232471, 0.0018440259875017097, 0.001499372757863661, 0.0012191361088513923, 0.0009912764148276346, 0.0008060042873468099, 0.0006553600000000003]


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Observation shape: (4, 84, 84), Actions: 6
Loading model from checkpoints_space_invaders/ALE_SpaceInvaders-v5_ep266450.pth


  checkpoint_state = torch.load(latest_checkpoint, map_location=device)


Found best model at checkpoints_space_invaders/ALE_SpaceInvaders-v5_best_model.pth
Creating optimized sequence replay buffer with delta compression...
Pre-allocating scratch buffers for memory-efficient learning...
Evaluating current model performance...
Current model evaluation score: 2662.9


  best_net.load_state_dict(torch.load(best_model_path, map_location=device))


Best model evaluation score: 2657.1

Resuming training from episode 266451...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


grad ‖value‖=1.55e-03  ‖advantage‖=7.93e-05


  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


Episode 266500 | Reward 2500.00 | Avg 1035.92 | Loss 0.45 | Updates 29 | LR 1.00e-04 | Sequences   1320 | Steps 133298536 | 1820.6 steps/s | 0.5ms/step
Episode 266600 | Reward 2425.00 | Avg 1939.80 | Loss 0.42 | Updates 94 | LR 1.00e-04 | Sequences   5238 | Steps 133464904 | 2102.5 steps/s | 0.5ms/step
Episode 266650 | Reward 2295.00 | Avg 1938.50 | Loss 0.40 | Updates 128 | LR 1.00e-04 | Sequences   7194 | Steps 133552872 | 1855.8 steps/s | 0.5ms/step
Episode 266700 | Reward 1730.00 | Avg 2070.50 | Loss 0.43 | Updates 161 | LR 1.00e-04 | Sequences   9394 | Steps 133635560 | 1928.0 steps/s | 0.5ms/step
Episode 266750 | Reward 2810.00 | Avg 2153.95 | Loss 0.31 | Updates 197 | LR 1.00e-04 | Sequences  11478 | Steps 133729096 | 2172.8 steps/s | 0.5ms/step
Episode 266850 | Reward 1860.00 | Avg 2079.80 | Loss 0.44 | Updates 264 | LR 1.00e-04 | Sequences  15593 | Steps 133901032 | 1963.5 steps/s | 0.5ms/step
Episode 266900 | Reward 1845.00 | Avg 2086.70 | Loss 0.41 | Updates 300 | LR 1.00e-0

KeyboardInterrupt: 

In [None]:
trained_agent = resume_training_optimized_r2d2(
    game_name="SpaceInvaders",
    num_episodes=50000000,
    replay_buffer_capacity=100_000,
    report_interval=50,#50,
    eval_interval=1000,
    warmup_size=10,#5000,
    N_ACTORS=32,#256,
    learning_freq=80,#1, #15,
    stride=40,
    sequence_length=120,
    burn_in_length=40,
    batch_size=64,#64
    gamma=0.997,#,#0.997,
    n_steps=5,# 10,
    seed=42,
    lr=1e-4,
    adam_eps=1e-7,
    checkpoint_dir="checkpoints_space_invaders",
    target_update_freq=2500,
    turn_off_lstm=False,#False,#True,
    clip_rewards=False,#False,#True,
    linear_eps=False,#True,
    override_eps=False,#True,
    beta_fixed=0.6,#False,
    prio_alpha=0.9,
    # beta_start=0.4,
    # beta_frames=1_000_000
)

Using device: mps
Found checkpoint: checkpoints_space_invaders/ALE_SpaceInvaders-v5_ep275050.pth (Episode 275050)
Epsilons: [0.4, 0.32523896474908454, 0.2644509604776406, 0.21502439153162226, 0.17483577624386654, 0.14215851716664438, 0.11558872238386095, 0.09398489101199065, 0.07641887163698842, 0.06213598674626823, 0.05052261000754778, 0.041079803438191446, 0.033401881855833176, 0.027158983688656022, 0.022082899346339398, 0.017955548305154154, 0.014599609855676281, 0.011870904988001713, 0.009652202122331054, 0.007848180564539672, 0.006381335304936917, 0.0051886472207361665, 0.004218875626301115, 0.0034303568527583005, 0.002789214278777594, 0.002267902911232471, 0.0018440259875017097, 0.001499372757863661, 0.0012191361088513923, 0.0009912764148276346, 0.0008060042873468099, 0.0006553600000000003]


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Observation shape: (4, 84, 84), Actions: 6
Loading model from checkpoints_space_invaders/ALE_SpaceInvaders-v5_ep275050.pth


  checkpoint_state = torch.load(latest_checkpoint, map_location=device)


Found best model at checkpoints_space_invaders/ALE_SpaceInvaders-v5_best_model.pth
Creating optimized sequence replay buffer with delta compression...
Pre-allocating scratch buffers for memory-efficient learning...
Evaluating current model performance...
Current model evaluation score: 2217.9


  best_net.load_state_dict(torch.load(best_model_path, map_location=device))


Best model evaluation score: 2860.7

Resuming training from episode 275051...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


grad ‖value‖=1.50e-03  ‖advantage‖=5.59e-05


  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


Episode 275100 | Reward 3020.00 | Avg  986.02 | Loss 0.42 | Updates 29 | LR 1.00e-04 | Sequences   1244 | Steps 137600232 | 1873.5 steps/s | 0.5ms/step
Episode 275150 | Reward 1835.00 | Avg 1609.49 | Loss 0.42 | Updates 67 | LR 1.00e-04 | Sequences   3458 | Steps 137696968 | 2052.6 steps/s | 0.5ms/step
Episode 275200 | Reward 2200.00 | Avg 2137.15 | Loss 0.42 | Updates 100 | LR 1.00e-04 | Sequences   5496 | Steps 137779816 | 2000.4 steps/s | 0.5ms/step
Episode 275250 | Reward 2495.00 | Avg 2132.90 | Loss 0.42 | Updates 133 | LR 1.00e-04 | Sequences   7680 | Steps 137866472 | 1910.8 steps/s | 0.5ms/step
Episode 275300 | Reward 2510.00 | Avg 2114.65 | Loss 0.39 | Updates 167 | LR 1.00e-04 | Sequences   9654 | Steps 137952360 | 1993.2 steps/s | 0.5ms/step
Episode 275350 | Reward 2725.00 | Avg 2085.60 | Loss 0.41 | Updates 202 | LR 1.00e-04 | Sequences  11724 | Steps 138041448 | 2093.9 steps/s | 0.5ms/step
Episode 275400 | Reward 2805.00 | Avg 2200.35 | Loss 0.38 | Updates 237 | LR 1.00e-0

KeyboardInterrupt: 

In [None]:
# Pacman #3
trained_agent = resume_training_optimized_r2d2(
    game_name="MsPacman",
    num_episodes=50000000,
    replay_buffer_capacity=100_000,
    report_interval=50,
    eval_interval=1000,
    warmup_size=10,
    N_ACTORS=32,
    learning_freq=80,
    stride=40,
    sequence_length=120,
    burn_in_length=40,
    batch_size=64,
    gamma=0.997,
    n_steps=5,
    seed=42,
    lr=1e-4,
    adam_eps=1e-7,
    checkpoint_dir="checkpoints_pacman",
    target_update_freq=2500,
    turn_off_lstm=False,
    clip_rewards=False,
    linear_eps=False,
    override_eps=False,
    beta_fixed=0.6,
    prio_alpha=0.9,
)

Using device: mps
Found checkpoint: checkpoints_pacman/ALE_MsPacman-v5_ep785130.pth (Episode 785130)
Epsilons: [0.4, 0.32523896474908454, 0.2644509604776406, 0.21502439153162226, 0.17483577624386654, 0.14215851716664438, 0.11558872238386095, 0.09398489101199065, 0.07641887163698842, 0.06213598674626823, 0.05052261000754778, 0.041079803438191446, 0.033401881855833176, 0.027158983688656022, 0.022082899346339398, 0.017955548305154154, 0.014599609855676281, 0.011870904988001713, 0.009652202122331054, 0.007848180564539672, 0.006381335304936917, 0.0051886472207361665, 0.004218875626301115, 0.0034303568527583005, 0.002789214278777594, 0.002267902911232471, 0.0018440259875017097, 0.001499372757863661, 0.0012191361088513923, 0.0009912764148276346, 0.0008060042873468099, 0.0006553600000000003]


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Observation shape: (4, 84, 84), Actions: 9
Loading model from checkpoints_pacman/ALE_MsPacman-v5_ep785130.pth


  checkpoint_state = torch.load(latest_checkpoint, map_location=device)


Found best model at checkpoints_pacman/ALE_MsPacman-v5_best_model.pth
Creating optimized sequence replay buffer with delta compression...
Pre-allocating scratch buffers for memory-efficient learning...
Evaluating current model performance...
Current model evaluation score: 7848.7


  best_net.load_state_dict(torch.load(best_model_path, map_location=device))


Best model evaluation score: 14096.7

Resuming training from episode 785131...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


grad ‖value‖=6.65e-03  ‖advantage‖=4.85e-05


  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


Episode 785150 | Reward  180.00 | Avg  210.00 | Loss 1.21 | Updates 2 | LR 1.00e-04 | Sequences    184 | Steps 392569384 | 1833.9 steps/s | 0.5ms/step
Episode 785200 | Reward 12891.00 | Avg 3372.80 | Loss 0.68 | Updates 32 | LR 1.00e-04 | Sequences   1504 | Steps 392645960 | 2135.3 steps/s | 0.5ms/step
Episode 785250 | Reward 3400.00 | Avg 6037.46 | Loss 0.39 | Updates 60 | LR 1.00e-04 | Sequences   3240 | Steps 392718248 | 2003.6 steps/s | 0.5ms/step
Episode 785300 | Reward 6030.00 | Avg 7330.43 | Loss 0.71 | Updates 88 | LR 1.00e-04 | Sequences   4912 | Steps 392788328 | 1879.4 steps/s | 0.5ms/step
Episode 785350 | Reward 2520.00 | Avg 7535.12 | Loss 0.58 | Updates 117 | LR 1.00e-04 | Sequences   6694 | Steps 392862984 | 1861.4 steps/s | 0.5ms/step
Episode 785400 | Reward 3370.00 | Avg 7643.33 | Loss 0.73 | Updates 146 | LR 1.00e-04 | Sequences   8488 | Steps 392938920 | 1891.2 steps/s | 0.5ms/step
Episode 785450 | Reward 6370.00 | Avg 7310.13 | Loss 0.66 | Updates 175 | LR 1.00e-04 

KeyboardInterrupt: 

In [2]:
trained_agent = resume_training_optimized_r2d2(
    game_name="MsPacman",
    num_episodes=50000000,
    replay_buffer_capacity=100_000,
    report_interval=50,
    eval_interval=1000,
    warmup_size=10,
    N_ACTORS=32,
    learning_freq=80,
    stride=40,
    sequence_length=120,
    burn_in_length=40,
    batch_size=64,
    gamma=0.997,
    n_steps=5,
    seed=42,
    lr=1e-4,
    adam_eps=1e-7,
    checkpoint_dir="checkpoints_pacman",
    target_update_freq=2500,
    turn_off_lstm=False,
    clip_rewards=False,
    linear_eps=False,
    override_eps=False,
    beta_fixed=0.6,
    prio_alpha=0.9,
)

Using device: mps
Found checkpoint: checkpoints_pacman/ALE_MsPacman-v5_ep903320.pth (Episode 903320)
Epsilons: [0.4, 0.32523896474908454, 0.2644509604776406, 0.21502439153162226, 0.17483577624386654, 0.14215851716664438, 0.11558872238386095, 0.09398489101199065, 0.07641887163698842, 0.06213598674626823, 0.05052261000754778, 0.041079803438191446, 0.033401881855833176, 0.027158983688656022, 0.022082899346339398, 0.017955548305154154, 0.014599609855676281, 0.011870904988001713, 0.009652202122331054, 0.007848180564539672, 0.006381335304936917, 0.0051886472207361665, 0.004218875626301115, 0.0034303568527583005, 0.002789214278777594, 0.002267902911232471, 0.0018440259875017097, 0.001499372757863661, 0.0012191361088513923, 0.0009912764148276346, 0.0008060042873468099, 0.0006553600000000003]


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Observation shape: (4, 84, 84), Actions: 9
Loading model from checkpoints_pacman/ALE_MsPacman-v5_ep903320.pth


  checkpoint_state = torch.load(latest_checkpoint, map_location=device)


Found best model at checkpoints_pacman/ALE_MsPacman-v5_best_model.pth
Creating sequence replay buffer...
Pre-allocating scratch buffers for memory-efficient learning...
Evaluating current model performance...
Current model evaluation score: 10484.7


  best_net.load_state_dict(torch.load(best_model_path, map_location=device))


Best model evaluation score: 16446.4

Resuming training from episode 903321...


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


grad ‖value‖=7.28e-03  ‖advantage‖=1.88e-05


  episode_data = torch.load(filepath, map_location=self.device)
  torch.gather(scratch_buffers['q_next_target'], -1,


Episode 903350 | Reward  310.00 | Avg  220.34 | Loss 0.34 | Updates 2 | LR 1.00e-04 | Sequences    299 | Steps 451666144 | 1810.2 steps/s | 0.6ms/step
Episode 903400 | Reward 3890.00 | Avg 3198.24 | Loss 0.70 | Updates 34 | LR 1.00e-04 | Sequences   1626 | Steps 451746816 | 1998.8 steps/s | 0.5ms/step
Episode 903450 | Reward 11381.00 | Avg 7326.72 | Loss 0.47 | Updates 66 | LR 1.00e-04 | Sequences   3656 | Steps 451829600 | 1862.8 steps/s | 0.5ms/step
Episode 903500 | Reward 6860.00 | Avg 9537.22 | Loss 0.49 | Updates 100 | LR 1.00e-04 | Sequences   5680 | Steps 451916192 | 1793.2 steps/s | 0.6ms/step
Episode 903550 | Reward 13651.00 | Avg 9341.42 | Loss 0.74 | Updates 134 | LR 1.00e-04 | Sequences   7782 | Steps 452003776 | 1788.8 steps/s | 0.6ms/step
Episode 903600 | Reward 2100.00 | Avg 9268.14 | Loss 0.64 | Updates 169 | LR 1.00e-04 | Sequences   9899 | Steps 452093856 | 1836.6 steps/s | 0.5ms/step
Episode 903650 | Reward 10361.00 | Avg 8922.52 | Loss 0.46 | Updates 203 | LR 1.00e-

KeyboardInterrupt: 