# Atari Gas Exploration on Ms. Pac-Man

This experiment demonstrates how the `fragile.atari_gas` module behaves on the classic Atari task **Ms. Pac-Man**.  We instantiate the Atari Gas algorithm with a uniform action prior, let it explore for a few hundred steps, and then inspect the evolution of the best cumulative reward alongside the screen observed by the top-performing walker.

**Key parameters:**
- `dt_range=(1, 4)`: Each walker repeats its chosen action a random number of times, between 1 and 4, before selecting a new action. This parameter controls action persistence and exploration speed.

In [1]:
import numpy as np
import pandas as pd
import torch
import holoviews as hv
import hvplot.pandas
import panel as pn

import plangym

from fragile.atari_gas import AtariGas, AtariGasParams
from fragile.euclidean_gas import CloningParams

# Enable HoloViews with the Bokeh backend and initialise Panel
hv.extension('bokeh')
pn.extension()

# Seed NumPy and PyTorch from a single constant so both generators are
# changed together when experimenting with a different seed.
SEED = 0
np.random.seed(SEED)
# The trailing `;` stops the Generator object returned by torch.manual_seed
# from being echoed as the cell's output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7b6a4cd41d70>

In [2]:
def observation_transform(obs: np.ndarray) -> np.ndarray:
    """Convert a raw frame to a float32, channel-last array scaled by 1/255.

    Args:
        obs: Array-like frame. A 2-D input is treated as a single-channel
            image and receives a trailing channel axis. A 3-D input whose
            first axis has length 1 or 3 is treated as channel-first and is
            transposed to channel-last. Any other input keeps its layout.

    Returns:
        A ``float32`` array with every value divided by 255.0.
    """
    arr = np.asarray(obs, dtype=np.float32)
    if arr.ndim == 2:
        # (H, W) -> (H, W, 1). The result is already channel-last, so it must
        # not fall through to the channel-first check below: a frame with
        # H in (1, 3) would otherwise be transposed by mistake.
        arr = arr[..., None]
    elif arr.ndim == 3 and arr.shape[0] in (1, 3):
        # Heuristic: a leading axis of length 1 or 3 is taken to be channels.
        arr = np.transpose(arr, (1, 2, 0))
    return arr / 255.0

# Ms. Pac-Man environment built through plangym; these keyword arguments are
# forwarded as-is (their exact semantics are defined by plangym, not here).
env = plangym.make(
    "MsPacman-v4",
    obs_type="grayscale",
    return_image=True,
    frameskip=3,
    episodic_life=False,
)

# Cloning settings are built separately so the swarm parameters below stay flat.
cloning_params = CloningParams(
    sigma_x=0.05,
    lambda_alg=0.01,
    alpha_restitution=0.0,
    use_inelastic_collision=False,
)

# Swarm configuration: CPU tensors in float32, the dt_range described in the
# introduction, and the frame preprocessing function defined above.
params = AtariGasParams(
    N=32,
    env=env,
    cloning=cloning_params,
    device="cpu",
    dtype="float32",
    dt_range=(1, 4),
    observation_transform=observation_transform,
)

gas = AtariGas(params)
state = gas.initialize_state()

# Echo the N attribute of the freshly initialised state as the cell output.
state.N

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


32

In [3]:
n_steps = 300
history = []
best_snapshot: tuple[int, np.ndarray] | None = None
best_value = float('-inf')


def _walker_frame(swarm_state, walker_ix: int) -> np.ndarray:
    """Return a displayable frame for one walker of ``swarm_state``.

    Prefers the ``'rgb'`` entry of the walker's info dict when there is one.
    Otherwise the walker's observation tensor is converted to a uint8 image:
    a leading axis of length 1 or 3 is moved last, singleton axes are
    dropped, and values are multiplied by 255 and clipped to [0, 255].
    The rescaling assumes observations are stored in [0, 1], as produced by
    dividing by 255 in ``observation_transform`` -- TODO confirm.
    """
    infos = swarm_state.infos
    info = infos[walker_ix] if walker_ix < len(infos) else None
    if isinstance(info, dict) and 'rgb' in info:
        return np.asarray(info['rgb'])

    raw = swarm_state.observations[walker_ix].detach().cpu().numpy()
    if raw.ndim == 3 and raw.shape[0] in (1, 3):
        raw = np.transpose(raw, (1, 2, 0))
    raw = np.squeeze(raw)
    return np.clip(raw * 255.0, 0, 255).astype(np.uint8)


# NOTE: `state` is reassigned on every iteration, so re-running this cell
# continues from the current swarm state instead of restarting the run.
for step in range(1, n_steps + 1):
    _, state, _ = gas.step(state)

    cumulative = state.rewards.detach().cpu()
    step_rewards = state.step_rewards.detach().cpu()

    # The "best" walker is the one with the highest cumulative reward.
    best_ix = int(torch.argmax(cumulative).item())
    best_reward = float(cumulative[best_ix].item())

    # Extract and copy a frame only when the record improves; the frame is
    # not used otherwise, so converting it on every step is wasted work.
    if best_reward > best_value:
        best_value = best_reward
        best_snapshot = (step, _walker_frame(state, best_ix).copy())

    history.append(
        {
            'step': step,
            'best_reward': best_reward,
            'best_step_reward': float(step_rewards[best_ix].item()),
            'mean_step_reward': float(step_rewards.mean().item()),
            'best_index': best_ix,
        }
    )

history_df = pd.DataFrame(history)
history_df.tail()


Unnamed: 0,step,best_reward,best_step_reward,mean_step_reward,best_index
295,296,830.0,0.0,0.0,0
296,297,830.0,0.0,0.0,0
297,298,830.0,0.0,0.0,0
298,299,830.0,0.0,0.0,0
299,300,830.0,0.0,0.0,0


In [4]:
# Interactive hvPlot line chart of the three reward series recorded per step
reward_columns = ['best_reward', 'mean_step_reward', 'best_step_reward']
plot_options = dict(
    kind='line',
    width=800,
    height=400,
    title='Atari Gas progress on Ms. Pac-Man',
    xlabel='Step',
    ylabel='Reward',
    legend='top_left',
    line_width=2,
    grid=True,
)
plot = history_df.hvplot(x='step', y=reward_columns, **plot_options)

plot

In [5]:
# Build the element inside the branches but leave `img` as the cell's final
# top-level expression: Jupyter only auto-renders the last top-level
# expression statement, so a bare `img` nested inside the `if` body is never
# displayed. When no frame was captured, `img` stays None and nothing renders.
img = None
if best_snapshot is None:
    print('No frame captured yet.')
else:
    best_step, best_frame = best_snapshot
    frame_height, frame_width = best_frame.shape[:2]
    frame_bounds = (0, 0, frame_width, frame_height)
    # Display options shared by the grayscale and colour elements.
    frame_opts = dict(
        width=400,
        height=400,
        title=f'Best walker screen (step {best_step})',
        toolbar='above',
        xaxis=None,
        yaxis=None,
    )

    if best_frame.ndim == 2:
        # Single-channel frame: shown as a scalar image with a gray colormap.
        img = hv.Image(best_frame, bounds=frame_bounds).opts(cmap='gray', **frame_opts)
    else:
        # Multi-channel frame: shown as an RGB element.
        # NOTE(review): this branch flips the frame vertically while the
        # grayscale branch does not; one of the two is probably rendered
        # upside-down -- confirm against the HoloViews Image/RGB row
        # orientation and drop the flip if it is not needed.
        img = hv.RGB(np.flipud(best_frame), bounds=frame_bounds).opts(**frame_opts)

img

In [6]:
# Close the plangym environment created earlier now that the run is finished.
env.close()
