## Preview


In [1]:
import os
import torch
import numpy as np
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing
import ale_py

try:
    from gymnasium.wrappers import FrameStackObservation as FrameStack
    FS = "stack_size"
except ImportError:
    from gymnasium.wrappers import FrameStack
    FS = "num_stack"

import torch.nn as nn
import torch.nn.functional as F

from typing import Tuple
class LSTMCellUnroller(nn.Module):
    def __init__(self, cells: nn.ModuleList):
        super().__init__()
        self.cells = cells

    @torch.jit.export
    def forward(
        self,
        lstm_in: torch.Tensor, # (B, T, I)
        h0: torch.Tensor,    # (L, B, H)
        c0: torch.Tensor     # (L, B, H)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        B, T, _ = lstm_in.shape
        L, H    = h0.size(0), h0.size(2)

        h = list(h0.unbind(0))
        c = list(c0.unbind(0))
        out = torch.empty(B, T, H, dtype=lstm_in.dtype, device=lstm_in.device)

        for t in range(T):
            x = lstm_in[:, t, :]
            # TorchScript accepts enumerate over ModuleList
            for l, cell in enumerate(self.cells):
                h[l], c[l] = cell(x, (h[l], c[l]))
                x = h[l] # feed upward
            out[:, t, :] = h[-1]

        return out, torch.stack(h, 0), torch.stack(c, 0)

class RecurrentDuelingDQN(nn.Module):
    # R2D2-style recurrent DQN with CNN -> LSTM -> FC architecture.
    def __init__(self, input_shape, num_actions, lstm_hidden_size=512, turn_off_lstm=False):
        super().__init__()
        c, h, w = input_shape
        self.num_actions = int(num_actions)
        self.lstm_hidden_size = int(lstm_hidden_size)
        self.turn_off_lstm = turn_off_lstm

        # CNN feature extractor
        self.features = nn.Sequential(
            nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
            nn.Flatten()
        )
        
        # Calculate CNN output size
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            cnn_output_size = self.features(dummy).shape[1]
        
        if self.turn_off_lstm:
            self.value_stream = nn.Sequential(
                nn.Linear(cnn_output_size, 512), nn.ReLU(),
                nn.Linear(512, 1)
            )
            self.advantage_stream = nn.Sequential(
                nn.Linear(cnn_output_size, 512), nn.ReLU(),
                nn.Linear(512, num_actions)
            )
            return
        
        # CNN features + one-hot previous action + previous reward
        lstm_input_size = int(cnn_output_size + num_actions + 1)
        self.num_layers = 1
        self.lstm_cells = nn.ModuleList(
            [nn.LSTMCell(lstm_input_size, lstm_hidden_size)]
        )

        self.value_stream = nn.Sequential(
            nn.Linear(lstm_hidden_size, 512), nn.ReLU(),
            nn.Linear(512, 1)
        )
        self.advantage_stream = nn.Sequential(
            nn.Linear(lstm_hidden_size, 512), nn.ReLU(),
            nn.Linear(512, num_actions)
        )
        self._unroller = torch.jit.script(LSTMCellUnroller(self.lstm_cells))

    def forward(self, states, prev_actions, prev_rewards, hidden_state=None, out=None):
        """
        Args:
            states: (batch_size, seq_len, C, H, W) or (batch_size, C, H, W)
            prev_actions: (batch_size, seq_len) or (batch_size,)
            prev_rewards: (batch_size, seq_len) or (batch_size,)
            hidden_state: tuple of (h, c) each (1, batch_size, lstm_hidden_size) or None
            out: optional output tensor to write results to (avoids allocation)
        """
        if states.dim() == 4: # Single step
            states = states.unsqueeze(1)
            prev_actions = prev_actions.unsqueeze(1)
            prev_rewards = prev_rewards.unsqueeze(1)
            single_step = True
        else:
            single_step = False
            
        batch_size, seq_len = states.shape[:2]
        
        # Process through CNN
        states_flat = states.reshape(-1, states.size(2), states.size(3), states.size(4)) # (B*T, C, H, W)
        cnn_features = self.features(states_flat) # (B*T, cnn_output_size)
        cnn_features = cnn_features.view(batch_size, seq_len, -1) # (B, T, cnn_output_size)
        
        if self.turn_off_lstm:
            # Dueling Q-values
            values = self.value_stream(cnn_features) # (B, T, 1)
            advantages = self.advantage_stream(cnn_features) # (B, T, num_actions)
            
            # Compute dueling Q-values with optional out parameter
            if out is not None:
                # Compute advantages - advantages.mean() in-place into out
                advantages_mean = advantages.mean(dim=-1, keepdim=True)
                torch.sub(advantages, advantages_mean, out=out)
                torch.add_(out, values) # Add values in-place
                out.copy_(out)
                if single_step:
                    out.squeeze_(1) # In-place squeeze
                q_values = out
            else:
                # Original allocation-based computation
                q_values = (values + (advantages - advantages.mean(dim=-1, keepdim=True)))
                if single_step:
                    q_values = q_values.squeeze(1) # (B, num_actions)
            
            return q_values, hidden_state if hidden_state else (torch.zeros((512)), torch.zeros((512)))
        
        # One-hot encode previous actions
        prev_actions_onehot = F.one_hot(prev_actions, self.num_actions).float() # (B, T, num_actions)
        
        # Prepare previous rewards
        prev_rewards = prev_rewards.unsqueeze(-1) # (B, T, 1)

        # Concatenate inputs for LSTM
        lstm_input = torch.cat([cnn_features, prev_actions_onehot, prev_rewards], dim=-1) # (B, T, lstm_input_size)

        # Tensors (L, B, H)
        if hidden_state is None:
            h0 = torch.zeros(self.num_layers, batch_size,
                            self.lstm_hidden_size, device=states.device)
            c0 = torch.zeros_like(h0)
        else:
            h0, c0 = hidden_state

        # Unroll
        lstm_out, h_final, c_final = self._unroller(lstm_input, h0, c0)
        new_hidden = (h_final, c_final)
        
        # Dueling Q-values
        values = self.value_stream(lstm_out) # (B, T, 1)
        advantages = self.advantage_stream(lstm_out) # (B, T, num_actions)

        # Compute dueling Q-values into out buffer
        if out is not None:
            # Check shape compatibility
            expected_shape = (batch_size, seq_len, self.num_actions) if not single_step else (batch_size, self.num_actions)
            if out.shape != expected_shape:
                raise ValueError(f"out tensor shape {out.shape} doesn't match expected shape {expected_shape}")
            
            # Compute advantages - advantages.mean() in-place into out
            advantages_mean = advantages.mean(dim=-1, keepdim=True)
            torch.sub(advantages, advantages_mean, out=out)
            out.add_(values) # Add values in-place
            
            # Apply value transform in-place
            if hasattr(self, '_use_value_transform') and self._use_value_transform:
                out.copy_(out)
            
            if single_step:
                out.squeeze_(1) # In-place squeeze
            q_values = out
        else:
            # Calculate into a new buffer
            q_values = (values + (advantages - advantages.mean(dim=-1, keepdim=True)))
            if single_step:
                q_values = q_values.squeeze(1) # (B, num_actions)
        
        return q_values, new_hidden

class RecurrentDQNAgent:
    def __init__(self, network: nn.Module, device: torch.device, num_actions: int):
        self.network = network.to(device)
        self.device = device
        self.num_actions = num_actions
        self.reset_hidden_state()

    def reset_hidden_state(self):
        self.hidden_state = None
        self.prev_action = 0 # Start with action 0
        self.prev_reward = 0.0

    def select_action(self, state_np, override_action=None):
        # Select action with optional action override for epsilon-greedy.
        state = torch.from_numpy(state_np).unsqueeze(0).float().to(self.device) / 255.0
        prev_action = torch.tensor([self.prev_action], dtype=torch.int64, device=self.device)
        prev_reward = torch.tensor([self.prev_reward], dtype=torch.float32, device=self.device)
        
        with torch.no_grad():
            q_values, new_hidden = self.network(state, prev_action, prev_reward, self.hidden_state)
            greedy_action = int(q_values.argmax(dim=1).item())
            
        # Use override action if provided (for epsilon-greedy), otherwise greedy
        action = override_action if override_action is not None else greedy_action
        
        # Always advance hidden state on the observation and set executed action
        self.hidden_state = new_hidden
        self.prev_action = action
        
        return action
    
    def update_prev_reward(self, reward: float):
        # Update the previous reward for next action selection
        self.prev_reward = reward


def _device():
    return torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda") if torch.cuda.is_available() else "cpu"

def preview_recurrent_model(
    checkpoint_path: str,
    num_episodes: int = 5,
    render: bool = True,
    epsilon: float = 0.0,
    game_name: str = 'MsPacman'
):
    device = _device()
    env_name = f"{game_name}NoFrameskip-v4"

    print(f"Using device: {device}")
    print(f"Loading model from: {checkpoint_path}")

    base_env = gym.make(env_name, render_mode="human" if render else None)
    atari_env = AtariPreprocessing(
        base_env,
        frame_skip=4,
        grayscale_obs=True,
        scale_obs=False,
        noop_max=30,
        terminal_on_life_loss=False
    )
    env = FrameStack(atari_env, **{FS: 4})

    obs_shape = env.observation_space.shape
    n_actions = env.action_space.n
    action_meanings = env.unwrapped.get_action_meanings()
    action_counts = {i: 0 for i in range(n_actions)}

    print(f"Observation shape: {obs_shape}, Actions: {n_actions}")
    print("Action meanings:", action_meanings)

    net = RecurrentDuelingDQN(obs_shape, n_actions).to(device)
    if not os.path.exists(checkpoint_path):
        print(f"✗ Checkpoint not found: {checkpoint_path}")
        env.close()
        return

    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()

    agent = RecurrentDQNAgent(net, device, n_actions)

    episode_rewards, episode_lengths, episode_lives = [], [], []

    obs, _ = env.reset(seed=100500)
    s = np.array(obs, dtype=np.uint8)
    lives = 3
    for ep in range(1, num_episodes + 1):
        if lives == 0:
            obs, i = env.reset(seed=100500 + ep * 23917)
            lives = i['lives']
        done = False
        total_reward, steps = 0.0, 0

        print(f"\n=== Episode {ep} ===")
        while not done:
            if np.random.random() < epsilon:
                a = env.action_space.sample()
                agent.select_action(s, override_action=a)
            else:
                a = agent.select_action(s)

            action_counts[a] += 1
            obs2, r, term, trunc, i = env.step(a)

            agent.update_prev_reward(r)
            if i["lives"] == 0:
                lives = i["lives"]
                print("Lost last life! Left:", lives)
                agent.reset_hidden_state()

            s = np.array(obs2, dtype=np.uint8)
            total_reward += r
            steps += 1
            done = term or trunc

            if steps % 200 == 0:
                try:
                    print(f"Steps: {steps}, Lives: {lives}, Reward: {total_reward:.2f}")
                except Exception:
                    print(f"Steps: {steps}, Reward: {total_reward:.2f}")

        final_lives = lives

        episode_rewards.append(total_reward)
        episode_lengths.append(steps)
        episode_lives.append(final_lives)

        print(f"Episode {ep} finished: Reward={total_reward:.2f}, "
              f"Lives={final_lives}, Steps={steps}")

    env.close()

    print("\n" + "=" * 40)
    print(f"SUMMARY OVER {num_episodes} EPISODES")
    print("=" * 40)
    print(f"Avg Reward: {np.mean(episode_rewards):.2f}, "
          f"Max: {np.max(episode_rewards):.2f}, Min: {np.min(episode_rewards):.2f}")
    print(f"Avg Steps: {np.mean(episode_lengths):.1f}")

    print("\nAction Usage:")
    total_actions = sum(action_counts.values())
    for i, cnt in action_counts.items():
        pct = cnt / total_actions * 100 if total_actions else 0.0
        print(f"  {action_meanings[i]}: {cnt} ({pct:.1f}%)")

    return {
        "rewards": episode_rewards,
        "lengths": episode_lengths,
        "lives": episode_lives,
        "action_counts": action_counts,
    }

def find_latest_checkpoint(checkpoint_dir="checkpoints", pattern="MsPacman"):
    if not os.path.exists(checkpoint_dir):
        return None
    files = [f for f in os.listdir(checkpoint_dir)
             if f.endswith(".pth") and pattern in f]
    if not files:
        return None
    files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    return os.path.join(checkpoint_dir, files[0])

def find_specific_checkpoint(checkpoint_dir="checkpoints", number=20400, pattern="MsPacman"):
    if not os.path.exists(checkpoint_dir):
        return None
    files = [f for f in os.listdir(checkpoint_dir)
             if f.endswith(".pth") and pattern in f and (('ep' + str(number)) in f)]
    if not files:
        return None
    files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    return os.path.join(checkpoint_dir, files[0])

### Preview Ms. Pac-Man

In [5]:
game_name = "MsPacman"
latest = find_latest_checkpoint(checkpoint_dir="checkpoints", pattern=game_name)
specific = find_specific_checkpoint(checkpoint_dir="checkpoints", pattern=game_name, number=747560)
checkpoint = specific or os.path.join("checkpoints_pacman", f"ALE_{game_name}-v5_best_model.pth")
print(f"Using checkpoint: {checkpoint}")

preview_recurrent_model(
    checkpoint_path=checkpoint,
    num_episodes=1,
    render=True,
    epsilon=0.00,
    game_name=game_name
)


Using checkpoint: checkpoints_pacman/ALE_MsPacman-v5_best_model.pth
Using device: mps
Loading model from: checkpoints_pacman/ALE_MsPacman-v5_best_model.pth
Observation shape: (4, 84, 84), Actions: 9
Action meanings: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT']


  state_dict = torch.load(checkpoint_path, map_location=device)



=== Episode 1 ===
Steps: 200, Lives: 3, Reward: 3340.00
Steps: 400, Lives: 3, Reward: 5230.00
Steps: 600, Lives: 3, Reward: 7020.00
Steps: 800, Lives: 3, Reward: 10341.00
Steps: 1000, Lives: 3, Reward: 10511.00
Steps: 1200, Lives: 3, Reward: 10691.00
Steps: 1400, Lives: 3, Reward: 14441.00
Steps: 1600, Lives: 3, Reward: 16601.00
Steps: 1800, Lives: 3, Reward: 17531.00
Steps: 2000, Lives: 3, Reward: 17791.00
Steps: 2200, Lives: 3, Reward: 17901.00
Steps: 2400, Lives: 3, Reward: 18601.00
Steps: 2600, Lives: 3, Reward: 19581.00
Steps: 2800, Lives: 3, Reward: 19821.00
Lost last life! Left: 0
Episode 1 finished: Reward=19821.00, Lives=0, Steps=2825

SUMMARY OVER 1 EPISODES
Avg Reward: 19821.00, Max: 19821.00, Min: 19821.00
Avg Steps: 2825.0

Action Usage:
  NOOP: 42 (1.5%)
  UP: 169 (6.0%)
  RIGHT: 35 (1.2%)
  LEFT: 626 (22.2%)
  DOWN: 613 (21.7%)
  UPRIGHT: 912 (32.3%)
  UPLEFT: 140 (5.0%)
  DOWNRIGHT: 272 (9.6%)
  DOWNLEFT: 16 (0.6%)


{'rewards': [19821.0],
 'lengths': [2825],
 'lives': [0],
 'action_counts': {0: 42,
  1: 169,
  2: 35,
  3: 626,
  4: 613,
  5: 912,
  6: 140,
  7: 272,
  8: 16}}

### Preview Space Invaders

In [6]:
game_name = "SpaceInvaders"
latest = find_latest_checkpoint(checkpoint_dir="checkpoints", pattern=game_name)
specific = find_specific_checkpoint(checkpoint_dir="checkpoints", pattern=game_name, number=747560)
checkpoint = specific or os.path.join("checkpoints_space_invaders", f"ALE_{game_name}-v5_best_model.pth")
print(f"Using checkpoint: {checkpoint}")

preview_recurrent_model(
    checkpoint_path=checkpoint,
    num_episodes=1,
    render=True,
    epsilon=0.00,
    game_name=game_name
)


Using checkpoint: checkpoints_space_invaders/ALE_SpaceInvaders-v5_best_model.pth
Using device: mps
Loading model from: checkpoints_space_invaders/ALE_SpaceInvaders-v5_best_model.pth
Observation shape: (4, 84, 84), Actions: 6
Action meanings: ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']


  state_dict = torch.load(checkpoint_path, map_location=device)



=== Episode 1 ===
Steps: 200, Lives: 3, Reward: 200.00
Steps: 400, Lives: 3, Reward: 465.00
Steps: 600, Lives: 3, Reward: 700.00
Steps: 800, Lives: 3, Reward: 1025.00
Steps: 1000, Lives: 3, Reward: 1270.00
Steps: 1200, Lives: 3, Reward: 1590.00
Steps: 1400, Lives: 3, Reward: 1905.00
Steps: 1600, Lives: 3, Reward: 2295.00
Steps: 1800, Lives: 3, Reward: 2595.00
Lost last life! Left: 0
Episode 1 finished: Reward=2820.00, Lives=0, Steps=1918

SUMMARY OVER 1 EPISODES
Avg Reward: 2820.00, Max: 2820.00, Min: 2820.00
Avg Steps: 1918.0

Action Usage:
  NOOP: 190 (9.9%)
  FIRE: 290 (15.1%)
  RIGHT: 274 (14.3%)
  LEFT: 252 (13.1%)
  RIGHTFIRE: 576 (30.0%)
  LEFTFIRE: 336 (17.5%)


{'rewards': [2820.0],
 'lengths': [1918],
 'lives': [0],
 'action_counts': {0: 190, 1: 290, 2: 274, 3: 252, 4: 576, 5: 336}}

### Preview CartPole

In [3]:
import os
import torch
import numpy as np
import gymnasium as gym
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, action_dim)
        )
    def forward(self, x):
        return self.net(x)

class DQNAgent:
    def __init__(self, policy_net: nn.Module, device: torch.device):
        self.policy_net = policy_net.to(device).eval()
        self.device = device
    def predict(self, obs):
        with torch.no_grad():
            obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
            q_values = self.policy_net(obs_tensor)
            action = int(q_values.argmax(dim=1).item())
        return action, None

def _device():
    return (torch.device("mps") if torch.backends.mps.is_available()
            else torch.device("cuda") if torch.cuda.is_available()
            else torch.device("cpu"))

def find_latest_checkpoint(checkpoint_dir="checkpoints_cartpole", pattern="CartPole-v1"):
    if not os.path.exists(checkpoint_dir):
        return None
    files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth") and pattern in f]
    if not files:
        return None
    files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    return os.path.join(checkpoint_dir, files[0])

def find_specific_checkpoint(checkpoint_dir="checkpoints_cartpole", number=10000, pattern="CartPole-v1"):
    if not os.path.exists(checkpoint_dir):
        return None
    tag = f"ep{number}"
    files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth") and pattern in f and tag in f]
    if not files:
        return None
    files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    return os.path.join(checkpoint_dir, files[0])

def preview_cartpole(
    checkpoint_path: str = None,
    num_episodes: int = 3,
    render: bool = True,
    epsilon: float = 0.0
):
    env_id = "CartPole-v1"
    device = _device()
    print(f"Using device: {device}")

    if checkpoint_path is None:
        latest = find_latest_checkpoint("checkpoints_cartpole", env_id)
        best   = os.path.join("checkpoints_cartpole", f"{env_id}_best_model.pth")
        checkpoint_path = latest if latest is not None else best

    print(f"Loading model from: {checkpoint_path}")
    if not os.path.exists(checkpoint_path):
        print("Checkpoint not found.")
        return

    env = gym.make(env_id, render_mode="human" if render else None)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    print(f"Observation dim: {state_dim}, Actions: {action_dim}")

    net = DQN(state_dim, action_dim).to(device)
    state_dict = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state_dict)
    net.eval()
    agent = DQNAgent(net, device)

    episode_rewards, episode_lengths = [], []

    for ep in range(1, num_episodes+1):
        state, _ = env.reset(seed=100500 + ep)
        done, steps, total_reward = False, 0, 0.0
        while not done:
            if np.random.random() < epsilon:
                action = env.action_space.sample()
            else:
                action, _ = agent.predict(state)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
            steps += 1
        episode_rewards.append(total_reward)
        episode_lengths.append(steps)
        print(f"Episode {ep}: reward={total_reward:.2f}, steps={steps}")

    env.close()
    print("\n==== SUMMARY ====")
    print(f"Avg Reward: {np.mean(episode_rewards):.2f} "
          f"(max {np.max(episode_rewards):.2f}, min {np.min(episode_rewards):.2f})")
    print(f"Avg Steps: {np.mean(episode_lengths):.1f}")

env_id = "CartPole-v1"
specific = None# find_specific_checkpoint("checkpoints_cartpole", number=290, pattern=env_id)
latest = None # find_latest_checkpoint("checkpoints_cartpole", pattern=env_id)
checkpoint = specific or latest or os.path.join("checkpoints_cartpole", f"{env_id}_best_model.pth")
print(f"Using checkpoint: {checkpoint}")

preview_cartpole(
    checkpoint_path=checkpoint,
    num_episodes=3,
    render=True,
    epsilon=0.0
)


Using checkpoint: checkpoints_cartpole/CartPole-v1_best_model.pth
Using device: mps
Loading model from: checkpoints_cartpole/CartPole-v1_best_model.pth
Observation dim: 4, Actions: 2


  state_dict = torch.load(checkpoint_path, map_location=device)


Episode 1: reward=500.00, steps=500
Episode 2: reward=500.00, steps=500
Episode 3: reward=500.00, steps=500

==== SUMMARY ====
Avg Reward: 500.00 (max 500.00, min 500.00)
Avg Steps: 500.0
