In [81]:
import os
import time
from datetime import datetime
from enum import Enum
from typing import Optional, Tuple

import cv2
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from ale_py import ALEInterface
from IPython.core.pylabtools import figsize
from IPython.display import display, clear_output
from PIL import Image

In [82]:
class ResizeObservation(gym.ObservationWrapper):
    """Observation wrapper that shrinks every frame and rescales it by 1/255.

    Each observation is resized to ``h`` x ``w`` with nearest-neighbour
    interpolation, cast to ``float32`` and divided by 255. The advertised
    observation space is ``Box(0, 1, (h, w, 3), float32)``.

    NOTE(review): assumes the wrapped env emits HxWx3 uint8 frames, so that
    the scaled values really fall in [0, 1] -- confirm against the env used.
    """

    def __init__(self, env, h=21, w=21):
        """Wrap ``env`` and declare the resized ``(h, w, 3)`` observation space."""
        super().__init__(env)
        self.h = h
        self.w = w
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=(h, w, 3),
            dtype=np.float32,
        )

    def observation(self, obs):
        """Return ``obs`` resized to ``(self.h, self.w)`` as float32 divided by 255."""
        # cv2.resize takes the target size as (width, height), not (rows, cols).
        target_size = (self.w, self.h)
        shrunk = cv2.resize(obs, target_size, interpolation=cv2.INTER_NEAREST)
        scaled = shrunk.astype(np.float32)
        scaled /= 255.0
        return scaled


In [83]:
def prepare_image(frame, h=21, w=21):
    """Resize ``frame`` to ``h`` x ``w`` and return it as a flat float32 vector.

    Floating-point input is scaled by 255 and clipped to 0..255 before the
    uint8 cast; any other dtype is cast to uint8 directly. The uint8 image is
    resized with nearest-neighbour interpolation, divided by 255 and flattened
    to 1-D.
    """
    pixels = np.asarray(frame)
    is_float = np.issubdtype(pixels.dtype, np.floating)
    if is_float:
        # Float frames are treated as intensities on a 0..1 scale.
        pixels = (pixels * 255.0).clip(0, 255)
    pixels = pixels.astype(np.uint8)

    # cv2.resize takes the target size as (width, height).
    shrunk = cv2.resize(pixels, (w, h), interpolation=cv2.INTER_NEAREST)
    normalised = shrunk.astype(np.float32) / 255.0
    return normalised.flatten()


In [84]:
# Create an ALE interface object and pass it to Gymnasium's registration hook.
# NOTE(review): Gymnasium's docs show `gym.register_envs(ale_py)` with the
# module itself; here an ALEInterface instance is passed -- confirm intended.
ale = ALEInterface()
gym.register_envs(ale)

# Berzerk with RGB-array rendering and frameskip=4, wrapped so that every
# observation is resized to 21x21 and divided by 255 (see ResizeObservation).
env = ResizeObservation(
    gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4),
    h=21,
    w=21,
)
observation, info = env.reset()


In [85]:
# Start a fresh episode.
state, _ = env.reset()

# Send action 0 for 21 consecutive steps (20 warm-up steps plus one more) to
# get into the game; `state` ends up holding the observation of the last step.
for _ in range(21):
    state, _, _, _, _ = env.step(0)

print("Initial state shape:", state.shape)

Initial state shape: (21, 21, 3)


In [86]:
# Turn the current frame into a flat float32 vector using prepare_image's
# defaults (h=21, w=21).
image = prepare_image(state)

In [87]:
# Dump the current frame to PNG through three different libraries
# (PIL, OpenCV and matplotlib).
arr = np.asarray(state)

# Frame size used to un-flatten 1-D input. It has to be defined in this cell:
# no earlier cell creates `h`/`w`, so the reshape branch below would otherwise
# fail with NameError on a fresh kernel. Must match ResizeObservation(h=21, w=21).
h, w = 21, 21

# handle flattened arrays
if arr.ndim == 1:
    if arr.size == h * w * 3:
        arr = arr.reshape((h, w, 3))
    elif arr.size == h * w:
        arr = arr.reshape((h, w))
    else:
        raise ValueError(f"Unexpected flattened size: {arr.size}")

# single-channel -> replicate to 3 channels
if arr.ndim == 2:
    arr = np.stack([arr, arr, arr], axis=-1)

# ensure shape is HxWx3
if arr.ndim != 3 or arr.shape[2] != 3:
    raise ValueError(f"Unexpected image shape after processing: {arr.shape}")

# convert floats in [0,1] to uint8, and clip (rather than wrap) other dtypes
if np.issubdtype(arr.dtype, np.floating):
    img_uint8 = (arr * 255.0).clip(0, 255).astype(np.uint8)
elif arr.dtype == np.uint8:
    img_uint8 = arr.astype(np.uint8)
else:
    # Going through float64 keeps the clip independent of the integer width,
    # so e.g. 300 saturates to 255 instead of wrapping to 44.
    img_uint8 = np.clip(arr.astype(np.float64), 0, 255).astype(np.uint8)

# Save files (no backticks in filenames)
Image.fromarray(img_uint8).save('out_pil.png')
# OpenCV writes channels in BGR order, so swap from RGB first.
cv2.imwrite('out_cv.png', cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR))
plt.imsave('out_plt.png', img_uint8)

In [88]:
def _to_rgb_uint8(img, h: Optional[int] = None, w: Optional[int] = None) -> np.ndarray:
    """Normalise an image-like input to a new ``(H, W, 3)`` ``uint8`` array.

    Parameters
    ----------
    img : str, os.PathLike or array-like
        Either a path to an image file (opened with PIL) or pixel data shaped
        ``(H, W, 3)``, ``(H, W, 4)``, ``(H, W)`` or flattened to 1-D.
    h, w : int, optional
        Height and width; only used to un-flatten 1-D input.

    Returns
    -------
    np.ndarray
        ``(H, W, 3)`` array of dtype ``uint8``. Floating-point input is scaled
        by 255 and clipped to 0..255; other non-uint8 input is clipped to
        0..255 (saturating instead of wrapping around).

    Raises
    ------
    ValueError
        If a 1-D input is given without ``h``/``w``, if its size matches
        neither ``h*w*3`` nor ``h*w``, or if the data is not 2-D or 3-D with
        3 or 4 channels.
    """
    if isinstance(img, (str, os.PathLike)):
        # Decide "is this a path?" before calling np.asarray, and close the
        # file handle as soon as the pixel data has been read.
        with Image.open(img) as pil_img:
            # Palette-mode files store colour-table indices, not colours;
            # expand them so they are not mistaken for grey levels below.
            source = pil_img.convert("RGB") if pil_img.mode == "P" else pil_img
            arr = np.asarray(source)
    else:
        arr = np.asarray(img)

    # flattened -> try to reshape
    if arr.ndim == 1:
        if h is None or w is None:
            raise ValueError("Provide h and w for flattened image")
        if arr.size == h * w * 3:
            arr = arr.reshape((h, w, 3))
        elif arr.size == h * w:
            arr = arr.reshape((h, w))
        else:
            raise ValueError(f"Unexpected flattened size: {arr.size}")

    # single-channel -> replicate to 3 channels
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=-1)
    if arr.ndim != 3 or arr.shape[2] not in (3, 4):
        raise ValueError(f"Unexpected image shape: {arr.shape}")

    # drop alpha if present
    if arr.shape[2] == 4:
        arr = arr[:, :, :3]

    if np.issubdtype(arr.dtype, np.floating):
        # floats are treated as intensities in [0, 1]
        arr = (arr * 255.0).clip(0, 255)
    elif arr.dtype != np.uint8:
        # Saturate out-of-range values (e.g. 300 -> 255, -5 -> 0) instead of
        # letting the uint8 cast wrap them. Going through float64 keeps the
        # clip independent of the source integer width.
        arr = np.clip(arr.astype(np.float64), 0, 255)
    # astype always copies, so the caller's array is never aliased.
    return arr.astype(np.uint8)

In [89]:
def find_unique_colors(img, h: Optional[int] = None, w: Optional[int] = None,
                       return_counts: bool = True, save_palette: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """List the distinct RGB colours of an image, most frequent first.

    Parameters
    ----------
    img : file path or array-like
        Anything accepted by ``_to_rgb_uint8`` (HxWx3, HxW, flattened, or
        floats in [0, 1]).
    h, w : int, optional
        Required only for flattened images.
    return_counts : bool
        If True, the second element of the result holds the pixel count of
        each colour; otherwise it is ``None``.
    save_palette : str, optional
        Filename (e.g. ``palette.png``) to which a grid of 32x32 colour
        swatches, at most 16 per row, is written.

    Returns
    -------
    (unique_colors, counts)
        ``unique_colors`` is an ``(N, 3)`` uint8 array sorted by descending
        frequency; colours with equal counts keep the lexicographic order
        produced by ``np.unique``. ``counts`` is an ``(N,)`` array or ``None``.

    Raises
    ------
    ValueError
        If ``save_palette`` is given but the image contains no pixels, plus
        whatever ``_to_rgb_uint8`` raises for malformed input.
    """
    arr = _to_rgb_uint8(img, h=h, w=w)
    flat = arr.reshape((-1, 3))

    # One row per pixel -> unique rows are the unique colours.
    unique_colors, counts = np.unique(flat, axis=0, return_counts=True)

    # Sort by frequency, descending. A stable sort makes the order of ties
    # well defined instead of depending on the sort implementation.
    order = np.argsort(-counts, kind="stable")
    unique_colors = unique_colors[order]
    counts = counts[order]

    if save_palette:
        n_colors = unique_colors.shape[0]
        if n_colors == 0:
            # Without this guard the grid computation below divides by zero.
            raise ValueError("Cannot save a palette for an image with no pixels")
        # build a small palette image: rows of color swatches
        sw = 32
        cols = min(16, n_colors)
        rows = (n_colors + cols - 1) // cols  # ceiling division
        palette = np.zeros((rows * sw, cols * sw, 3), dtype=np.uint8)
        for idx, color in enumerate(unique_colors):
            top = (idx // cols) * sw
            left = (idx % cols) * sw
            palette[top:top + sw, left:left + sw] = color
        Image.fromarray(palette).save(save_palette)

    return (unique_colors, counts if return_counts else None)


In [90]:
# List the distinct RGB colours of the current frame with their pixel counts
# and write the swatch grid to palette.png. `state` is already 3-D, so h and w
# (used only for flattened input) have no effect here.
find_unique_colors(state, h=21, w=21, save_palette='palette.png')

(array([[  0,   0,   0],
        [ 84,  92, 214],
        [210, 210,  64],
        [240, 170, 103]], dtype=uint8),
 array([389,  46,   5,   1]))