In [3]:
from datetime import datetime
from enum import Enum

import numba
from IPython.core.pylabtools import figsize
from ale_py import ALEInterface
import gymnasium as gym
import time
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, clear_output
import cv2
from typing import Optional, Tuple

In [4]:
# Exact 3-component pixel colours that _get_one_hot_state matches against
# pixel[0], pixel[1], pixel[2] (presumably RGB order -- confirm against the
# frames actually fed in).
CONST_COLOR_PLAYER = (240, 170, 103)
CONST_COLOR_WALL = (84, 92, 214)
CONST_COLOR_ENEMY = (210, 210, 64)

# Category ids; each one doubles as the channel index of the one-hot grid.
CAT_EMPTY, CAT_PLAYER, CAT_WALL, CAT_ENEMY = 0, 1, 2, 3

In [5]:
@numba.njit
def _get_one_hot_state(obs_resized, h, w):
    """Classify each pixel of an image into one of four one-hot categories.

    Args:
        obs_resized: array indexed as [row, col, channel]; the first three
            channels are compared against the CONST_COLOR_* triplets.
        h: number of rows to scan.
        w: number of columns to scan.

    Returns:
        A uint8 array of shape (h, w, 4). Exactly one channel per pixel is
        set to 1: CAT_PLAYER, CAT_WALL or CAT_ENEMY on an exact colour
        match, otherwise CAT_EMPTY.
    """
    one_hot = np.zeros((h, w, 4), dtype=np.uint8)
    for row in range(h):
        for col in range(w):
            c0 = obs_resized[row, col, 0]
            c1 = obs_resized[row, col, 1]
            c2 = obs_resized[row, col, 2]

            # Anything that is not an exact colour match counts as empty.
            category = CAT_EMPTY
            if (c0 == CONST_COLOR_PLAYER[0] and
                    c1 == CONST_COLOR_PLAYER[1] and
                    c2 == CONST_COLOR_PLAYER[2]):
                category = CAT_PLAYER
            elif (c0 == CONST_COLOR_WALL[0] and
                    c1 == CONST_COLOR_WALL[1] and
                    c2 == CONST_COLOR_WALL[2]):
                category = CAT_WALL
            elif (c0 == CONST_COLOR_ENEMY[0] and
                    c1 == CONST_COLOR_ENEMY[1] and
                    c2 == CONST_COLOR_ENEMY[2]):
                category = CAT_ENEMY

            one_hot[row, col, category] = 1
    return one_hot

# --- Helper: locate player and enemy cells in the one-hot grid ---
@numba.njit
def _find_entities(one_hot_state):
    """Return the player position and all enemy positions.

    Args:
        one_hot_state: array indexed as [row, col, category].

    Returns:
        (player_pos, enemy_coords) where player_pos is the first (row, col)
        entry produced by np.argwhere on the player channel, or None when
        no cell is flagged as player, and enemy_coords holds one (row, col)
        row per cell flagged as enemy.
    """
    player_mask = one_hot_state[:, :, CAT_PLAYER] == 1
    enemy_mask = one_hot_state[:, :, CAT_ENEMY] == 1

    player_hits = np.argwhere(player_mask)
    enemy_hits = np.argwhere(enemy_mask)

    # Only the first player cell is reported, even if several match.
    first_player = player_hits[0] if player_hits.shape[0] > 0 else None
    return first_player, enemy_hits

# --- Helper: sign of the offset towards the nearest enemy ---
@numba.njit
def _get_enemy_direction(player_pos, enemies):
    """Return the (dy, dx) direction from the player to the closest enemy.

    "Closest" is measured by the sum of absolute coordinate differences
    (Manhattan distance); ties go to the enemy that appears first.

    Args:
        player_pos: (row, col) of the player, or None.
        enemies: array with one (row, col) row per enemy.

    Returns:
        Two floats, each -1.0, 0.0 or 1.0. (0.0, 0.0) is returned when
        player_pos is None or there are no enemies.
    """
    if player_pos is None or len(enemies) == 0:
        return 0.0, 0.0

    nearest = enemies[0]
    best = np.inf

    for idx in range(len(enemies)):
        candidate = enemies[idx]
        manhattan = np.abs(candidate - player_pos).sum()
        # Strict '<' keeps the earliest enemy on ties.
        if manhattan < best:
            best = manhattan
            nearest = candidate

    # Only the sign is kept, so magnitude never leaks into the features.
    delta_y = nearest[0] - player_pos[0]
    delta_x = nearest[1] - player_pos[1]
    return float(np.sign(delta_y)), float(np.sign(delta_x))

# --- Feature extraction: frame -> 6-element feature vector ---
def extract_intelligent_features(frame, h=21, w=21):
    """
    Reduce a frame to six hand-crafted features.

    The frame is resized to (h, w) with nearest-neighbour sampling,
    labelled per pixel by _get_one_hot_state, and summarised as

        [wall_up, wall_down, wall_left, wall_right, enemy_dir_y, enemy_dir_x]

    Wall flags are 1.0 when the cell next to the player position is a wall
    (or lies outside the grid) and 0.0 otherwise. The enemy direction
    components come from _get_enemy_direction.

    Returns:
        float32 array of shape (6,). All zeros when no player cell is found.
    """
    small = cv2.resize(frame, (w, h), interpolation=cv2.INTER_NEAREST)
    grid = _get_one_hot_state(small, h, w)
    player_pos, enemies = _find_entities(grid)

    # No player on screen: emit the all-zero vector.
    if player_pos is None:
        return np.zeros(6, dtype=np.float32)

    row, col = player_pos

    features = np.zeros(6, dtype=np.float32)

    # Slots 0-3: wall flags. Start from "wall" so that a neighbour outside
    # the grid is treated as blocked, then overwrite the in-bounds ones.
    features[:4] = 1.0
    if row > 0:
        features[0] = grid[row - 1, col, CAT_WALL]
    if row < h - 1:
        features[1] = grid[row + 1, col, CAT_WALL]
    if col > 0:
        features[2] = grid[row, col - 1, CAT_WALL]
    if col < w - 1:
        features[3] = grid[row, col + 1, CAT_WALL]

    # Slots 4-5: direction (dy, dx) towards the nearest enemy.
    features[4], features[5] = _get_enemy_direction(player_pos, enemies)

    return features

In [6]:
class ResizeObservation(gym.ObservationWrapper):
    """Observation wrapper that swaps each frame for a 6-feature vector.

    The features are, in order: wall_up, wall_down, wall_left, wall_right,
    enemy_dir_y, enemy_dir_x (see extract_intelligent_features).
    """

    def __init__(self, env, h=21, w=21):
        """Wrap `env`; `h` and `w` are the grid size used for extraction."""
        super().__init__(env)
        self.h = h
        self.w = w

        # Wall flags are 0/1 and direction components are -1/0/1, so every
        # entry fits in [-1, 1].
        self.observation_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(6,),
            dtype=np.float32,
        )

    def observation(self, obs):
        """Convert a raw observation into the feature vector."""
        features = extract_intelligent_features(obs, self.h, self.w)
        return features

In [7]:
# Create an ALE interface object and pass it to gym.register_envs before
# the environment is built.
# NOTE(review): the ale_py / Gymnasium docs show `gym.register_envs(ale_py)`
# (the module) rather than an ALEInterface instance -- confirm this is intended.
ale = ALEInterface()
gym.register_envs(ale)

# Build Berzerk (v5) with rgb_array rendering and a frameskip of 4, then wrap
# it so observations are the 6-feature vector instead of the raw frame.
env = gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4)
env = ResizeObservation(env, h=21, w=21)
# First reset: yields the initial (wrapped) observation and the info dict.
observation, info = env.reset()


In [8]:
# Warm-up: hold NOOP for a fixed number of steps so the game is under way
# before the feature vector is inspected.
NOOP_ACTION = 0    # action index 0 is NOOP in the ALE action set
WARMUP_STEPS = 21  # number of no-op steps taken before printing the state

# Named throwaways (instead of `_`) keep IPython's last-output variable intact.
state, _info = env.reset()

for _step in range(WARMUP_STEPS):
    state, _reward, terminated, truncated, _info = env.step(NOOP_ACTION)
    if terminated or truncated:
        # Stepping a finished episode is not valid in Gymnasium; start over.
        state, _info = env.reset()

print("Initial state shape:", state.shape)
print(state)

Initial state shape: (6,)
[0. 0. 1. 0. 1. 0.]
