# Interactive Atari Gas Visualization for Ms. Pac-Man

This notebook provides real-time, interactive visualizations for the Atari Gas algorithm running on Ms. Pac-Man.

**Features:**
- Real-time best walker screen display
- Cumulative reward curve over time
- Visit heatmap showing explored regions
- Summary table with step information (rewards, actions, dt values)
- Interactive controls (play/pause/reset/step)

**Observation Support:**
- **Image observations** (`obs_type='grayscale'` or `'rgb'`): visualizations are built directly from the observations
- **RAM observations** (`obs_type='ram'`): require `return_image=True` so that RGB frames are provided in the info dict

**Visualization Stack:**
- HoloViews with Bokeh backend for interactive plots
- Panel for dashboards and controls
- Streaming data for real-time updates
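
To make the streaming idea concrete, here is a minimal sketch of a bounded history buffer, similar in spirit to the `reward_history` limit used for the reward curve later in this notebook. It uses only the standard library. The `RollingBuffer` name is hypothetical, and this is not meant to describe how `fragile.shaolin` implements its streams.

```python
from collections import deque


class RollingBuffer:
    """Keep only the most recent `maxlen` (step, reward) points."""

    def __init__(self, maxlen: int):
        self._points = deque(maxlen=maxlen)

    def send(self, step: int, reward: float) -> None:
        # Appending past maxlen silently drops the oldest point
        self._points.append((step, reward))

    def as_columns(self):
        """Return the buffered points as two parallel lists."""
        steps = [s for s, _ in self._points]
        rewards = [r for _, r in self._points]
        return steps, rewards


buffer = RollingBuffer(maxlen=3)
for step, reward in enumerate([0.0, 10.0, 10.0, 30.0, 50.0]):
    buffer.send(step, reward)

steps, rewards = buffer.as_columns()
print(steps)    # [2, 3, 4]
print(rewards)  # [10.0, 30.0, 50.0]
```

A bounded buffer keeps memory use and redraw cost constant no matter how long the run lasts, at the price of forgetting early history.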

In [1]:
import ray
#ray.init(runtime_env={"working_dir": "/home/guillem/fragile"})

In [2]:
import numpy as np
import pandas as pd
import torch
import holoviews as hv
import hvplot.pandas
import panel as pn
import param
import time
from functools import partial

import plangym
from plangym.utils import process_frame

from fragile.atari_gas import AtariGas, AtariGasParams
from fragile.euclidean_gas import CloningParams
from fragile.shaolin.stream_plots import RGB, Image, Curve

# Enable HoloViews with Bokeh backend
hv.extension('bokeh')
pn.extension('tabulator')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)


## PacManDisplay Class

Interactive display for Pac-Man with:
- Best walker's RGB screen
- Grayscale screen with visit heatmap overlay
- Cumulative reward curve
- Summary table with walker statistics
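
The visit heatmap can be read as a running count of how often each pixel was bright in the best walker's frame, displayed on a log scale. Below is a minimal NumPy sketch of that accumulation, assuming the same luminance threshold of 25 that the class uses. The `bright_mask` helper and the tiny 2x2 frames are made up for illustration.

```python
import numpy as np


def bright_mask(rgb_frame: np.ndarray, threshold: float = 25.0) -> np.ndarray:
    """Return a boolean mask of pixels whose luminance exceeds `threshold`."""
    # Standard luma weights for converting RGB to perceived brightness
    luminance = (
        0.299 * rgb_frame[..., 0]
        + 0.587 * rgb_frame[..., 1]
        + 0.114 * rgb_frame[..., 2]
    )
    return luminance > threshold


visits = np.zeros((2, 2), dtype=np.int32)

frame_a = np.zeros((2, 2, 3), dtype=np.uint8)
frame_a[0, 0] = (255, 255, 255)  # one white pixel

frame_b = np.zeros((2, 2, 3), dtype=np.uint8)
frame_b[0, 0] = (255, 255, 255)  # same white pixel again
frame_b[1, 1] = (0, 200, 0)      # plus one green pixel

for frame in (frame_a, frame_b):
    visits += bright_mask(frame).astype(np.int32)

# Log scale compresses large counts; NaN marks never-visited pixels as transparent
heat = np.log1p(visits.astype(np.float32))
heat[visits == 0] = np.nan

print(visits.tolist())  # [[2, 0], [0, 1]]
```

Note that this counts every bright pixel in the frame, so static scenery contributes as much as the moving sprites; it is a rough proxy for exploration rather than a true position tracker.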

**Observation Compatibility:**
- Works with both image and RAM observations
- For RAM observations, `return_image=True` must be set on the plangym environment
- Detects the observation type automatically and extracts RGB frames accordingly
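
The frame-selection logic described above can be sketched in a few lines. This is a simplified illustration, not the class's exact code: `is_ram_observation` and `pick_frame` are hypothetical standalone helpers, and the sketch assumes image observations are already scaled to [0, 1] and stored channel-last.

```python
from typing import Optional

import numpy as np


def is_ram_observation(obs: np.ndarray) -> bool:
    """Treat 1D vectors as RAM; anything with two or more dims as an image."""
    return obs.ndim == 1


def pick_frame(obs: np.ndarray, info: dict) -> Optional[np.ndarray]:
    """Prefer info['rgb'], fall back to image observations, else return None."""
    if "rgb" in info:
        return np.asarray(info["rgb"], dtype=np.uint8)
    if is_ram_observation(obs):
        return None  # RAM bytes carry no pixels to display
    # Assumes values in [0, 1]; rescale to 8-bit for display
    return np.clip(obs * 255.0, 0, 255).astype(np.uint8)


ram = np.zeros(128, dtype=np.float32)
rgb = np.zeros((210, 160, 3), dtype=np.uint8)
image_obs = np.full((210, 160, 1), 0.5, dtype=np.float32)

print(pick_frame(ram, {}))                  # None
print(pick_frame(ram, {"rgb": rgb}).shape)  # (210, 160, 3)
print(pick_frame(image_obs, {}).shape)      # (210, 160, 1)
```

Checking `info` first means the same display code works for both observation types whenever `return_image=True` is set.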

In [3]:
class PacManDisplay:
    """Interactive visualization for Atari Gas on Pac-Man.
    
    Supports both image and RAM observations. When using RAM observations,
    RGB frames must be provided via the 'rgb' key in info dictionaries.
    
    Uses HoloViz stack (HoloViews + Panel) with streaming data for real-time updates.
    """
    
    SUMMARY_COLUMNS = [
        'step',
        'best_walker',
        'step_reward',
        'cumulative_reward',
        'mean_step_reward',
        'action',
        'dt',  # Number of times action was applied consecutively
    ]
    
    def __init__(
        self,
        frame_shape: tuple[int, int, int] = (210, 160, 3),
        reward_history: int = 5000,
        visit_shape: tuple[int, int] = (210, 160),
        use_ram_obs: bool = False,
    ):
        """Initialize the display.
        
        Args:
            frame_shape: Expected RGB frame shape (height, width, channels)
            reward_history: Number of points to keep in reward history
            visit_shape: Shape for visit counting (height, width)
            use_ram_obs: If True, expects RGB frames from info['rgb'] instead of observations
        """
        self.frame_shape = frame_shape
        self.visit_shape = visit_shape
        self.use_ram_obs = use_ram_obs
        
        # Initialize streaming plots using fragile.shaolin
        self.best_rgb = RGB(
            bokeh_opts={
                'height': 300,
                'width': 300,
                'toolbar': None,
                'title': 'Best Walker Screen',
            }
        )
        
        self.screen_grey = Image(
            bokeh_opts={
                'height': 300,
                'width': 300,
                'cmap': 'greys',
                'toolbar': None,
                'title': 'Grayscale Screen',
            }
        )
        
        self.visits_image = Image(
            bokeh_opts={
                'height': 300,
                'width': 300,
                'alpha': 0.7,
                'cmap': 'fire',
                'toolbar': None,
                'title': 'Visit Heatmap',
            }
        )
        
        self.reward_curve = Curve(
            data=pd.DataFrame({'step': [], 'reward': []}),
            buffer_length=reward_history,
            data_names=('step', 'reward'),
            bokeh_opts={
                'height': 300,
                'width': 620,
                'line_width': 3,
                'color': '#ffbf00',
                'xlabel': 'Step',
                'ylabel': 'Best Cumulative Reward',
                'tools': ['hover'],
                'title': 'Reward Progress',
            }
        )
        
        # Initialize visit tracking
        self.visits = np.zeros(visit_shape, dtype=np.int32)
        self.grey_frame = np.zeros(visit_shape, dtype=np.float32)
        self._curr_best_idx = -1
        self._step_counter = 0
        
    def reset(self):
        """Reset the display state."""
        self.visits = np.zeros(self.visit_shape, dtype=np.int32)
        self.grey_frame = np.zeros(self.visit_shape, dtype=np.float32)
        self._curr_best_idx = -1
        self._step_counter = 0
        
        # Clear plots
        self.best_rgb.send(np.zeros(self.frame_shape, dtype=np.uint8))
        self.screen_grey.send(self.grey_frame)
        self.visits_image.send(np.zeros(self.visit_shape))
        
    def update(self, state, step: int):
        """Update visualizations with new state.
        
        Args:
            state: AtariSwarmState from the gas algorithm
            step: Current step number
        """
        self._step_counter = step
        
        # Find best walker
        best_idx = int(torch.argmax(state.rewards).item())
        best_reward = float(state.rewards[best_idx].item())
        
        # Refresh the best walker's screen on every update: its frame changes
        # even when the same walker remains the best one
        self._update_best_screen(state, best_idx)
        self._curr_best_idx = best_idx
        
        # Update visit heatmap
        self._update_visits(state)
        
        # Update reward curve - send as DataFrame
        self.reward_curve.send(pd.DataFrame({'step': [step], 'reward': [best_reward]}))
        
    def _is_ram_observation(self, obs: np.ndarray) -> bool:
        """Check if observation is RAM (1D vector) or image (2D/3D)."""
        return obs.ndim == 1 or (obs.ndim == 2 and obs.shape[0] == 128)
    
    def _get_rgb_frame(self, state, best_idx: int) -> np.ndarray | None:
        """Extract RGB frame from state.
        
        Returns:
            RGB frame as uint8 array [H, W, 3], or None if not available
        """
        info = state.infos[best_idx] if best_idx < len(state.infos) else {}
        
        # Try to get RGB from info (works for both RAM and image obs when return_image=True)
        if 'rgb' in info:
            rgb_frame = np.asarray(info['rgb'])
            if rgb_frame.dtype != np.uint8:
                rgb_frame = np.clip(rgb_frame, 0, 255).astype(np.uint8)
            return rgb_frame
        
        # If no RGB in info, try to extract from observations (only works for image obs)
        obs = state.observations[best_idx].detach().cpu().numpy()
        
        # Check if this is RAM observation
        if self._is_ram_observation(obs):
            # RAM observations - can't visualize without RGB in info
            return None
        
        # Convert observation to RGB (for image observations)
        if obs.ndim == 3 and obs.shape[0] in (1, 3):
            obs = np.transpose(obs, (1, 2, 0))
        obs = np.squeeze(obs)
        rgb_frame = np.clip(obs * 255.0, 0, 255).astype(np.uint8)
        
        # Convert grayscale to RGB if needed
        if rgb_frame.ndim == 2:
            rgb_frame = np.stack([rgb_frame] * 3, axis=-1)
            
        return rgb_frame
        
    def _update_best_screen(self, state, best_idx: int):
        """Update the best walker's screen displays."""
        rgb_frame = self._get_rgb_frame(state, best_idx)
        
        if rgb_frame is None:
            # No visual data available (RAM obs without return_image=True)
            # Keep previous frame or show blank
            return
        
        # Update RGB display
        self.best_rgb.send(rgb_frame)
        
        # Update grayscale display
        self.grey_frame = process_frame(rgb_frame, mode='L').astype(np.float32)
        self.screen_grey.send(self.grey_frame)
        
    def _update_visits(self, state):
        """Accumulate the bright pixels of the best walker's frame into the visit heatmap."""
        best_idx = int(torch.argmax(state.rewards).item())
        
        # Try to use RGB frame for visit tracking
        rgb_frame = self._get_rgb_frame(state, best_idx)
        
        if rgb_frame is not None:
            # Use RGB frame for visit tracking
            if rgb_frame.ndim == 2:
                # Grayscale
                mask = rgb_frame > 25  # Use threshold for binary mask
            else:
                # RGB - use luminance
                luminance = (0.299 * rgb_frame[:, :, 0] + 
                           0.587 * rgb_frame[:, :, 1] + 
                           0.114 * rgb_frame[:, :, 2])
                mask = luminance > 25
        else:
            # RAM observations without RGB - can't track spatial visits
            # Skip visit tracking
            return
            
        self.visits += mask.astype(np.int32)
        
        # Prepare visit data for display (log scale for better visualization)
        visits_display = np.log1p(self.visits.astype(np.float32))
        visits_display[visits_display == 0] = np.nan  # Make zeros transparent
        
        self.visits_image.send(visits_display)
        
    def get_summary(self, state, step: int) -> dict:
        """Get summary statistics for display.
        
        Returns:
            Dictionary with summary statistics
        """
        best_idx = int(torch.argmax(state.rewards).item())
        
        return {
            'step': step,
            'best_walker': best_idx,
            'step_reward': float(state.step_rewards[best_idx].item()),
            'cumulative_reward': float(state.rewards[best_idx].item()),
            'mean_step_reward': float(state.step_rewards.mean().item()),
            'action': int(state.actions[best_idx].item()),
            'dt': int(state.dts[best_idx].item()),
        }
    
    def __panel__(self):
        """Return Panel layout for display."""
        return pn.Column(
            pn.Row(
                self.best_rgb.plot,
                self.screen_grey.plot * self.visits_image.plot,
            ),
            self.reward_curve.plot,
        )

## GasRunner Class

Interactive control panel for running the Gas algorithm:
- Play/Pause/Reset/Step controls
- Progress bar
- Summary statistics table
- Adjustable parameters
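
Stripped of its widgets, the play/pause/step behaviour is a small state machine driven by a periodic tick. The sketch below shows that idea in plain Python; `RunState` is a hypothetical stand-in and deliberately simpler than the `GasRunner` class (for instance, it has no early-termination check).

```python
class RunState:
    """Track play/pause/step state without any UI (illustrative only)."""

    def __init__(self, n_steps: int):
        self.n_steps = n_steps
        self.curr_step = 0
        self.is_running = False

    @property
    def finished(self) -> bool:
        return self.curr_step >= self.n_steps

    def play(self) -> None:
        if not self.finished:
            self.is_running = True

    def pause(self) -> None:
        self.is_running = False

    def step(self) -> bool:
        """Advance one step; return True if a step was actually taken."""
        if self.finished:
            return False
        self.curr_step += 1
        if self.finished:
            self.is_running = False  # stop automatically at the end
        return True

    def tick(self) -> bool:
        """What a periodic callback would call: step only while playing."""
        return self.step() if self.is_running else False


run = RunState(n_steps=3)
run.tick()         # paused, so nothing happens
run.step()         # manual step, curr_step becomes 1
run.play()
while run.tick():  # runs until the step budget is exhausted
    pass
print(run.curr_step, run.is_running)  # 3 False
```

Keeping the state logic separate from the widgets makes it easy to test the control flow without starting a Panel server.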

In [None]:
class GasRunner(param.Parameterized):
    """Interactive runner for Atari Gas with Panel controls."""
    
    is_running = param.Boolean(default=False)
    
    def __init__(self, gas, n_steps: int, display: PacManDisplay, report_interval: int = 100):
        super().__init__()
        self.gas = gas
        self.state = None
        self.n_steps = n_steps
        self.curr_step = 0
        self.display = display
        self.report_interval = report_interval
        self.terminated_early = False
        self.termination_reason = ""
        
        # Create controls
        self.reset_btn = pn.widgets.Button(
            icon='restore',
            button_type='primary',
            name='Reset',
        )
        self.play_btn = pn.widgets.Button(
            icon='player-play',
            button_type='success',
            name='Play',
        )
        self.pause_btn = pn.widgets.Button(
            icon='player-pause',
            button_type='warning',
            name='Pause',
            disabled=True,
        )
        self.step_btn = pn.widgets.Button(
            name='Step',
            button_type='primary',
        )
        
        self.progress = pn.indicators.Progress(
            name='Progress',
            value=0,
            max=n_steps,
            width=800,
            bar_color='primary',
        )
        
        self.status_text = pn.pane.Markdown('')
        
        self.sleep_input = pn.widgets.FloatInput(
            name='Sleep (s)',
            value=0.0,
            start=0.0,
            end=1.0,
            step=0.01,
            width=100,
        )
        
        self.report_input = pn.widgets.IntInput(
            name='Report Interval',
            value=report_interval,
            start=1,
            width=120,
        )
        
        self.summary_table = pn.widgets.Tabulator(
            pagination='remote',
            page_size=1,
            width=800,
        )
        
        # Wire up callbacks
        self.reset_btn.on_click(lambda event: self.on_reset())
        self.play_btn.on_click(lambda event: self.on_play())
        self.pause_btn.on_click(lambda event: self.on_pause())
        self.step_btn.on_click(lambda event: self.on_step())
        
        # Initialize
        self.on_reset()
        
    def on_reset(self):
        """Reset the simulation."""
        self.state = self.gas.initialize_state()
        self.curr_step = 0
        self.progress.value = 0
        self.progress.bar_color = 'primary'
        self.is_running = False
        self.terminated_early = False
        self.termination_reason = ""
        self.status_text.object = ""
        self.play_btn.disabled = False
        self.pause_btn.disabled = True
        self.step_btn.disabled = False
        
        # Reset display
        self.display.reset()
        
        # Update summary
        summary = self.display.get_summary(self.state, self.curr_step)
        self.summary_table.value = pd.DataFrame([summary])
        
    def on_play(self):
        """Start continuous execution."""
        if not self.terminated_early:
            self.is_running = True
            self.play_btn.disabled = True
            self.pause_btn.disabled = False
        
    def on_pause(self):
        """Pause execution."""
        self.is_running = False
        self.play_btn.disabled = False
        self.pause_btn.disabled = True
        
    def on_step(self):
        """Execute a single step."""
        if not self.is_running and not self.terminated_early:
            self.run_single_step()
    
    def run_single_step(self):
        """Execute one step of the gas algorithm."""
        if self.curr_step >= self.n_steps or self.terminated_early:
            return
            
        # Step the algorithm
        _, self.state, _ = self.gas.step(self.state)
        self.curr_step += 1
        self.progress.value = self.curr_step
        
        # Check for early termination
        should_stop, reason = self.gas.should_terminate(self.state)
        if should_stop:
            self.terminated_early = True
            self.termination_reason = reason
            self.is_running = False
            self.progress.bar_color = 'warning'
            self.status_text.object = f'**⚠️ Terminated Early:** {reason}'
            self.play_btn.disabled = True
            self.pause_btn.disabled = True
            self.step_btn.disabled = True
        elif self.curr_step >= self.n_steps:
            # Normal completion
            self.is_running = False
            self.progress.bar_color = 'success'
            self.status_text.object = '**✓ Completed all steps**'
            self.play_btn.disabled = True
            self.pause_btn.disabled = True
            self.step_btn.disabled = True
            
        # Update display at report intervals
        if self.curr_step % self.report_input.value == 0:
            self.display.update(self.state, self.curr_step)
            summary = self.display.get_summary(self.state, self.curr_step)
            self.summary_table.value = pd.DataFrame([summary])
    
    def run(self):
        """Periodic callback for continuous execution."""
        if self.is_running:
            self.run_single_step()
            time.sleep(self.sleep_input.value)
    
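    def __panel__(self):
        """Return Panel layout for controls."""
        # Register the periodic callback only once, even if the layout is rendered again
        if getattr(self, '_periodic_cb', None) is None:
            self._periodic_cb = pn.state.add_periodic_callback(self.run, period=50)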
        
        return pn.Column(
            self.summary_table,
            self.progress,
            self.status_text,
            pn.Row(
                self.play_btn,
                self.pause_btn,
                self.reset_btn,
                self.step_btn,
                self.sleep_input,
                self.report_input,
            ),
        )

## Setup Environment and Gas Algorithm

In [5]:
def observation_transform(obs: np.ndarray) -> np.ndarray:
    """Scale observations to [0, 1] and move image observations to [H, W, C].

    1D RAM vectors are only rescaled.
    """
    arr = np.asarray(obs, dtype=np.float32)
    if arr.ndim == 2:
        arr = arr[..., None]
    if arr.ndim == 3 and arr.shape[0] in (1, 3):
        arr = np.transpose(arr, (1, 2, 0))
    return arr / 255.0

# Create environment
env = plangym.make(
    'MsPacman-v4',
    obs_type='ram',
    return_image=True,
    frameskip=3,
    episodic_life=False,
    #ray=True,
    n_workers=8,
)

# Configure gas parameters
params = AtariGasParams(
    N=64,  # Number of walkers
    env=env,
    cloning=CloningParams(
        sigma_x=0.05,
        lambda_alg=0.01,
        alpha_restitution=0.0,
        use_inelastic_collision=False,
    ),
    device='cpu',
    dtype='float32',
    dt_range=(2, 10),  # Apply each action between 2 and 10 times consecutively
    observation_transform=observation_transform,
)

# Create gas instance
gas = AtariGas(params)

print(f"Environment: {env.gym_env.spec.id}")
print(f"Action space: {env.action_space}")
print(f"Number of walkers: {params.N}")

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]

Environment: MsPacman-v4
Action space: Discrete(9)
Number of walkers: 64
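
The `dt_range` parameter bounds how many consecutive times each sampled action is applied. The sketch below illustrates that idea with the standard library only. `sample_dt` and `repeat_action` are hypothetical helpers, and treating the upper bound as inclusive is an assumption, not something confirmed by `AtariGas`.

```python
import random


def sample_dt(dt_range, rng):
    """Draw how many times to repeat an action (both ends inclusive: an assumption)."""
    low, high = dt_range
    return rng.randint(low, high)


def repeat_action(step_fn, action, dt):
    """Apply `action` dt times and return the summed reward."""
    total = 0.0
    for _ in range(dt):
        total += step_fn(action)
    return total


rng = random.Random(42)
dts = [sample_dt((2, 10), rng) for _ in range(1000)]
print(min(dts) >= 2, max(dts) <= 10)  # True True

# Toy environment step that always pays a reward of 1.0
print(repeat_action(lambda action: 1.0, action=3, dt=4))  # 4.0
```

Holding an action for several frames lets each walker cover more ground per decision, which matters in a game where a single frame rarely changes the outcome.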



## Create Interactive Dashboard

In [6]:
# Detect if using RAM observations
use_ram = hasattr(env, 'obs_type') and env.obs_type == 'ram'

# Create display and runner
display = PacManDisplay(
    frame_shape=(210, 160, 3),
    reward_history=5000,
    visit_shape=(210, 160),
    use_ram_obs=use_ram,
)

runner = GasRunner(
    gas=gas,
    n_steps=5000,
    display=display,
    report_interval=10,
)

# Create dashboard layout
dashboard = pn.Column(
    pn.pane.Markdown('# Atari Gas Interactive Dashboard - Ms. Pac-Man'),
    pn.pane.Markdown(
        f'**Observation Type:** {"RAM" if use_ram else "Image"} | '
        f'**Return Image:** {getattr(env, "return_image", "N/A")}'
    ),
    runner,
    display,
)

# Display
dashboard

## Cleanup

In [7]:
# Close environment when done
#env.close()
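
Because the environment above is created with `n_workers=8`, it is worth making sure `close()` runs even if a cell raises. Here is a minimal sketch using `contextlib.closing`, with a hypothetical `DummyEnv` class standing in for the real plangym environment:

```python
from contextlib import closing


class DummyEnv:
    """Stand-in for an environment that owns resources (hypothetical)."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


env_stub = DummyEnv()
try:
    with closing(env_stub):
        raise RuntimeError("simulated failure during a run")
except RuntimeError:
    pass

print(env_stub.closed)  # True
```

`closing` calls `close()` on exit whether or not the body raised, so a failed run does not leave the resource open.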