In [81]:
import time
from datetime import datetime
from enum import Enum

import ale_py
import cv2
import gymnasium as gym
import matplotlib.pyplot as plt
import numba
import numpy as np
from ale_py import ALEInterface
from IPython.core.pylabtools import figsize
from IPython.display import clear_output, display
from PIL import Image

In [82]:
# Exact 3-channel pixel colours used to recognise entities in a frame.
# They are compared channel-by-channel (no tolerance) against the
# downsampled observation in prepare_state_categorical_inner.
CONST_COLOR_PLAYER = (240, 170, 103)
CONST_COLOR_WALL = (84, 92, 214)
CONST_COLOR_ENEMY = (210, 210, 64)

# Integer category codes written into the categorical state grid.
CAT_EMPTY, CAT_PLAYER, CAT_WALL, CAT_ENEMY = range(4)

@numba.njit
def prepare_state_categorical_inner(obs_resized, h, w):
    """Classify every pixel of a small image into a CAT_* category code.

    Parameters
    ----------
    obs_resized : array indexed as ``[row, col, channel]``
        Already-downsampled frame. Only its top-left ``h`` x ``w`` pixels
        are read.
    h, w : int
        Height and width of the output grid.

    Returns
    -------
    np.ndarray, shape (h, w), dtype uint8
        ``CAT_PLAYER`` / ``CAT_WALL`` / ``CAT_ENEMY`` where a pixel exactly
        equals the matching ``CONST_COLOR_*`` tuple; ``CAT_EMPTY`` elsewhere.
    """
    grid = np.full((h, w), CAT_EMPTY, dtype=np.uint8)

    for row in range(h):
        for col in range(w):
            px = obs_resized[row, col]

            # Exact colour match, tested player -> wall -> enemy: the first
            # palette entry that matches decides the cell's category.
            is_player = (px[0] == CONST_COLOR_PLAYER[0]
                         and px[1] == CONST_COLOR_PLAYER[1]
                         and px[2] == CONST_COLOR_PLAYER[2])
            if is_player:
                grid[row, col] = CAT_PLAYER
                continue

            is_wall = (px[0] == CONST_COLOR_WALL[0]
                       and px[1] == CONST_COLOR_WALL[1]
                       and px[2] == CONST_COLOR_WALL[2])
            if is_wall:
                grid[row, col] = CAT_WALL
                continue

            is_enemy = (px[0] == CONST_COLOR_ENEMY[0]
                        and px[1] == CONST_COLOR_ENEMY[1]
                        and px[2] == CONST_COLOR_ENEMY[2])
            if is_enemy:
                grid[row, col] = CAT_ENEMY
    return grid

def prepare_state_categorical(frame, h=21, w=21):
    """Downsample a raw frame to ``h`` x ``w`` and map pixels to CAT_* codes.

    Nearest-neighbour interpolation keeps the original pixel colours intact
    (no blending), which the exact colour match in
    ``prepare_state_categorical_inner`` relies on. Each output cell takes one
    sampled source pixel, so small sprites may fall between samples.

    Returns an (h, w) uint8 grid of category codes.
    """
    # cv2.resize expects the destination size as (width, height).
    target_size = (w, h)
    downsampled = cv2.resize(frame, target_size, interpolation=cv2.INTER_NEAREST)
    return prepare_state_categorical_inner(downsampled, h, w)

In [83]:
class ResizeObservation(gym.ObservationWrapper):
    """Observation wrapper that turns raw frames into an (h, w) category grid.

    Despite the name, this does more than resize. Every observation goes
    through ``prepare_state_categorical``, so the agent sees a small uint8
    grid of CAT_* codes (0..3) instead of pixels.
    """

    def __init__(self, env, h=21, w=21):
        super().__init__(env)
        self.h, self.w = h, w
        # Bounds span CAT_EMPTY..CAT_ENEMY (0..3). The dtype must match what
        # observation() returns (uint8). The former float32 declaration
        # misdescribed the observations: sample() produced non-integer floats,
        # and buffers allocated from the space got the wrong dtype.
        self.observation_space = gym.spaces.Box(0, 3, (h, w), np.uint8)

    def observation(self, obs):
        """Return the categorical (h, w) grid for one raw frame."""
        return prepare_state_categorical(obs, self.h, self.w)


In [84]:
SEED = 42  # seed for the first reset, so the episode sequence is reproducible

# Importing ale_py registers the "ALE/*" environments with Gymnasium.
# gym.register_envs is a no-op that takes that *module*, so the import is
# visibly used. The old code constructed an unused ALEInterface emulator
# instance and passed it instead.
gym.register_envs(ale_py)

env = gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4)
env = ResizeObservation(env, h=21, w=21)
observation, info = env.reset(seed=SEED)

In [85]:
# Number of no-op steps taken before inspecting the state. This is the same
# total (20 + 1) that the cell previously took.
N_WARMUP_STEPS = 21

state, _ = env.reset()

# Advance past the start of the episode with action 0 (NOOP in ALE's action set).
for _ in range(N_WARMUP_STEPS):
    state, _, terminated, truncated, _ = env.step(0)
    if terminated or truncated:
        # Gymnasium requires reset() once an episode has ended; stepping a
        # finished episode is not supported.
        state, _ = env.reset()

print("Initial state shape:", state.shape)

Initial state shape: (21, 21)


In [86]:
# Legend: 0 = empty, 1 = player, 2 = wall, 3 = enemy (CAT_* codes).
print(np.array_str(state))

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 1 0 0 2 2 2 2 2 2 2 2 2 2 2 2 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 2]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]


In [None]:
class StateObserver:
    """Shape the environment reward using the categorical state grid.

    Every call to ``analyze_state`` applies:
      * a constant per-step penalty (``step_penalty``);
      * ``idle_penalty`` if the player cell equals the one from the previous
        call;
      * ``proximity_penalty`` if any enemy cell is within ``danger_radius``
        (Manhattan distance, in grid cells) of the player cell.

    The defaults reproduce the original hard-coded values. Call ``reset()``
    at the start of each episode so the idle check does not compare against
    the last position from the previous episode.
    """

    def __init__(self, w, h, step_penalty=0.1, idle_penalty=0.5,
                 proximity_penalty=0.5, danger_radius=2):
        self.w = w
        self.h = h
        self.step_penalty = step_penalty
        self.idle_penalty = idle_penalty
        self.proximity_penalty = proximity_penalty
        self.danger_radius = danger_radius
        # Player cell (row, col) from the previous call, or None.
        self.prev_pos = None

    def reset(self):
        """Forget the previous player position (call at episode start)."""
        self.prev_pos = None

    def _find_entities(self, state):
        """Return ``(players, enemies)`` as lists of ``(row, col)`` cells.

        Only the top-left ``h`` x ``w`` region of ``state`` is scanned. Cells
        are listed in row-major order.
        """
        # Plain NumPy replaces the former @numba.njit. Numba's nopython mode
        # cannot compile a method receiving ``self`` (an arbitrary Python
        # object), so the decorated version failed on its first call.
        grid = np.asarray(state)[: self.h, : self.w]
        players = [(int(r), int(c)) for r, c in np.argwhere(grid == CAT_PLAYER)]
        enemies = [(int(r), int(c)) for r, c in np.argwhere(grid == CAT_ENEMY)]
        return players, enemies

    def analyze_state(self, state, reward):
        """Return ``reward`` adjusted by the shaping penalties for ``state``.

        Parameters
        ----------
        state : 2-D array of CAT_* codes
            Categorical grid, e.g. from ``prepare_state_categorical``.
        reward : number
            Raw environment reward for this step.

        Returns
        -------
        The shaped reward. As a side effect, ``self.prev_pos`` is updated to
        the current player cell, or None if no player cell was found.
        """
        players, enemies = self._find_entities(state)
        # The grid may contain no player cell at all. The old players[0]
        # raised IndexError in that case.
        player_pos = players[0] if players else None

        shaped_reward = reward - self.step_penalty

        if player_pos is not None and player_pos == self.prev_pos:
            shaped_reward -= self.idle_penalty

        if player_pos is not None and enemies:
            # Manhattan distance from the player cell to every enemy cell.
            # Both operands must be arrays; list - tuple raised TypeError.
            offsets = np.asarray(enemies) - np.asarray(player_pos)
            min_distance = np.abs(offsets).sum(axis=1).min()
            if min_distance <= self.danger_radius:
                shaped_reward -= self.proximity_penalty

        self.prev_pos = player_pos
        return shaped_reward
