# Set up environment

In [1]:
from ale_py import ALEInterface
import gymnasium as gym
import numpy as np
import numba
import plotly.express as px

## Environment constants

In [2]:
FIRE_ACTIONS = [1, 10, 11, 12, 13, 14, 15, 16, 17]
MOVE_ACTIONS = [2, 3, 4, 5, 6, 7, 8, 9]
np.random.seed(42)

In [3]:
UP_DIRECTION = 0
RIGHT_DIRECTION = 1
LEFT_DIRECTION = 2
DOWN_DIRECTION = 3
UPRIGHT_DIRECTION = 4
UPLEFT_DIRECTION = 5
DOWNRIGHT_DIRECTION = 6
DOWNLEFT_DIRECTION = 7

ACTION_TO_DIRECTIONS = {
    2: [0],  # UP
    3: [1],  # RIGHT
    4: [2],  # LEFT
    5: [3],  # DOWN
    6: [0, 1],  # UPRIGHT
    7: [0, 2],  # UPLEFT
    8: [3, 1],  # DOWNRIGHT
    9: [3, 2],  # DOWNLEFT
}

In [4]:
EMPTY_COLOR = np.array([0, 0, 0], dtype=np.uint8)
WALL_COLOR = np.array([84, 92, 214], dtype=np.uint8)
ENEMY_COLOR = np.array([210, 210, 64], dtype=np.uint8)
PLAYER_COLOR = np.array([240, 170, 103], dtype=np.uint8)

EMPTY_INDEX = 0
WALL_INDEX = 1
ENEMY_INDEX = 2
PLAYER_INDEX = 3

PORTAL_INDEX = 4  # add manualy
PORTAL_COLOR = np.array([74, 255, 56], dtype=np.uint8)

AVG_PIXELS_IN_ENEMY = 74
MAX_ENEMIES = 8

## State processing functions

In [5]:
@numba.njit
def rgb_to_index(frame):
    state = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
    enemy_pixel_count = 0
    for i in range(frame.shape[0]):
        for j in range(frame.shape[1]):
            pixel = frame[i, j]
            if pixel[0] == WALL_COLOR[0]:
                state[i, j] = WALL_INDEX
            elif pixel[0] == ENEMY_COLOR[0]:
                state[i, j] = ENEMY_INDEX
                enemy_pixel_count += 1
            elif pixel[0] == PLAYER_COLOR[0]:
                state[i, j] = PLAYER_INDEX
            elif pixel[0] == PORTAL_COLOR[0]:
                state[i, j] = PORTAL_INDEX
            else:
                state[i, j] = EMPTY_INDEX
    return state, enemy_pixel_count


@numba.njit
def cut_empty_layers(state):
    skip_layers = 0
    while True:
        if state[skip_layers][skip_layers] == 0:
            skip_layers += 1
        else:
            break
    state = state[skip_layers:-skip_layers, skip_layers:-skip_layers]

    skip_layers_from_bottom = 0
    while True:
        if state[-(skip_layers_from_bottom + 1)][-(skip_layers_from_bottom + 1)] == 0:
            skip_layers_from_bottom += 1
        else:
            break
    state = state[:-skip_layers_from_bottom, :]
    return state


@numba.njit
def fill_part(matrix, a, b, c, d, fill_value):
    rows, cols = matrix.shape
    min_i = max(0, min(a[0], b[0], c[0], d[0]))
    max_i = min(rows - 1, max(a[0], b[0], c[0], d[0]))
    min_j = max(0, min(a[1], b[1], c[1], d[1]))
    max_j = min(cols - 1, max(a[1], b[1], c[1], d[1]))

    for ii in range(min_i, max_i + 1):
        for jj in range(min_j, max_j + 1):
            if matrix[ii, jj] == EMPTY_INDEX:
                matrix[ii, jj] = fill_value


@numba.njit
def fill_holes(state):
    PORTAL_MIN_LENGHT = 3
    PORTAL_WIDTH = 0
    WALL_WIDTH = 3

    if state[0][0] != WALL_INDEX:
        return

    P = 2 * (state.shape[0] + state.shape[1])

    i, j = 0, 0
    height, width = state.shape
    last_wall_pixel = (0, 0)
    iters = 0
    while True:
        if iters > P + 1:
            raise Warning("No walls found on border during hole filling")

        iters += 1

        if state[i][j] == WALL_INDEX:
            distance_to_last_wall = max(abs(i - last_wall_pixel[0]), abs(j - last_wall_pixel[1]))
            if distance_to_last_wall == 1:
                pass
            else:
                material_to_fill = WALL_INDEX if distance_to_last_wall < PORTAL_MIN_LENGHT else PORTAL_INDEX
                width_to_fill = WALL_WIDTH if material_to_fill == WALL_INDEX else PORTAL_WIDTH
                A, B, C, D = None, None, None, None
                if j == 0:
                    A = (last_wall_pixel[0] + 1, last_wall_pixel[1])
                    B = (last_wall_pixel[0] + 1, last_wall_pixel[1] + width_to_fill)
                    C = (i, j + width_to_fill)
                    D = (i, j)
                elif i == height - 1:
                    A = (last_wall_pixel[0], last_wall_pixel[1] + 1)
                    B = (last_wall_pixel[0] - width_to_fill, last_wall_pixel[1] + 1)
                    C = (i - width_to_fill, j)
                    D = (i, j)
                elif j == width - 1:
                    A = (last_wall_pixel[0] - 1, last_wall_pixel[1])
                    B = (last_wall_pixel[0] - 1, last_wall_pixel[1] - width_to_fill)
                    C = (i, j - width_to_fill)
                    D = (i, j)
                elif i == 0:
                    A = (last_wall_pixel[0], last_wall_pixel[1] - 1)
                    B = (last_wall_pixel[0] - width_to_fill, last_wall_pixel[1] - 1)
                    C = (i + width_to_fill, j)
                    D = (i, j)
                else:
                    raise ValueError("Unexpected wall pixel not on border")

                fill_part(state, A, B, C, D, material_to_fill)

            last_wall_pixel = (i, j)

        if j == 0 and i < height - 1:
            i += 1
        elif i == height - 1 and j < width - 1:
            j += 1
        elif j == width - 1 and i > 0:
            i -= 1
        elif i == 0 and j > 0:
            j -= 1

        if i == 0 and j == 0:
            break


@numba.njit
def find_player_box(state):
    player_positions = np.argwhere(state == PLAYER_INDEX)

    if player_positions.shape[0] == 0:
        return None

    i_pos = player_positions[:, 0]
    j_pos = player_positions[:, 1]
    box = (np.min(i_pos), np.min(j_pos), np.max(i_pos), np.max(j_pos))
    return box

@numba.njit
def _int_linspace(start, stop, count):
    arr = np.empty(count, dtype=np.int32)
    if count == 1:
        arr[0] = start
        return arr
    step = (stop - start) / (count - 1)
    for k in range(count):
        arr[k] = int(start + k * step)
    return arr

@numba.njit
def get_scanning_points(box):
    """
    :return: left_search_points, right_search_points, up_search_points, down_search_points
    """
    vertical_aligments_search_points_y = _int_linspace(box[0], box[2], 4)
    left_search_points = [(y, box[1]) for y in vertical_aligments_search_points_y]
    right_search_points = [(y, box[3]) for y in vertical_aligments_search_points_y]

    horizontal_search_points_x = _int_linspace(box[1], box[3], 2)
    up_search_points = [(box[0], x) for x in horizontal_search_points_x]
    down_search_points = [(box[2], x) for x in horizontal_search_points_x]

    return left_search_points, right_search_points, up_search_points, down_search_points


@numba.njit
def basic_observation(state, box):
    if box is None:
        return None, None, None, None, None, None

    left_search_points, right_search_points, up_search_points, down_search_points = get_scanning_points(box)

    l_wall = (left_search_points[0][0], 0)
    r_wall = (right_search_points[0][0], state.shape[1] - 1)
    u_wall = (0, up_search_points[0][1])
    d_wall = (state.shape[0] - 1, down_search_points[0][1])

    reach_l, reach_r, reach_u, reach_d = False, False, False, False

    closest_enemy = None
    closest_portal = None

    i = 0
    while True:
        if i > max(state.shape):
            raise ValueError("No walls found in any direction")

        if not reach_l:
            cords = [(point[0], point[1] - i) for point in left_search_points]
            for cord in cords:
                material = state[cord]
                if material == WALL_INDEX or cords[0][1] == 0:
                    reach_l = True
                    l_wall = cord
                if material == PORTAL_INDEX:
                    reach_l = True
                    if closest_portal is None:
                        closest_portal = cord
                elif material == ENEMY_INDEX:
                    if closest_enemy is None:
                        closest_enemy = cord

                if reach_l:
                    break

        if not reach_r:
            cords = [(point[0], point[1] + i) for point in right_search_points]
            for cord in cords:
                material = state[cord]
                if material == WALL_INDEX or cord[1] == state.shape[1] - 1:
                    reach_r = True
                    r_wall = cord
                elif material == PORTAL_INDEX:
                    reach_r = True
                    if closest_portal is None:
                        closest_portal = cord
                elif material == ENEMY_INDEX:
                    if closest_enemy is None:
                        closest_enemy = cord

                if reach_r:
                    break

        if not reach_u:
            cords = [(point[0] - i, point[1]) for point in up_search_points]
            for cord in cords:
                material = state[cord]
                if material == WALL_INDEX or cord[0] == 0:
                    reach_u = True
                    u_wall = cord
                elif material == PORTAL_INDEX:
                    reach_u = True
                    if closest_portal is None:
                        closest_portal = cord
                elif material == ENEMY_INDEX:
                    if closest_enemy is None:
                        closest_enemy = cord

                if reach_u:
                    break

        if not reach_d:
            cords = [(point[0] + i, point[1]) for point in down_search_points]
            for cord in cords:
                material = state[cord]
                if material == WALL_INDEX or cord[0] == state.shape[0] - 1:
                    reach_d = True
                    d_wall = cord
                if material == PORTAL_INDEX:
                    reach_d = True
                    if closest_portal is None:
                        closest_portal = cord
                elif material == ENEMY_INDEX:
                    if closest_enemy is None:
                        closest_enemy = cord

                if reach_d:
                    break

        if reach_l and reach_r and reach_u and reach_d:
            break
        i += 1

    return l_wall, r_wall, u_wall, d_wall, closest_enemy, closest_portal

In [6]:
@numba.njit
def cut_empty_layers_in_frame(frame):
    skip_layers = 0
    while True:
        if np.array_equal(frame[skip_layers][skip_layers], EMPTY_COLOR):
            skip_layers += 1
        else:
            break
    frame = frame[skip_layers:-skip_layers, skip_layers:-skip_layers]

    skip_layers_from_bottom = 0
    while True:
        if np.array_equal(frame[-(skip_layers_from_bottom + 1)][-(skip_layers_from_bottom + 1)], EMPTY_COLOR):
            skip_layers_from_bottom += 1
        else:
            break
    frame = frame[:-skip_layers_from_bottom, :]
    return frame


@numba.njit
def from_frame_to_feature_state(frame):
    frame = cut_empty_layers_in_frame(frame)


## State representation class

In [7]:
VECTOR_STATE_SIZE = 16


class State:
    def __init__(self, frame):
        frame = cut_empty_layers_in_frame(frame)
        state, enemy_pixels = rgb_to_index(frame)
        self.is_empty = len(state) == 0 or len(state[0]) == 0
        if not self.is_empty:
            fill_holes(state)
            self.state = state
            self.state_h = state.shape[0]
            self.state_w = state.shape[1]
            self.player_box = find_player_box(state)
            obs = basic_observation(state, self.player_box) if not self.player_box is None else None

            self.left_wall = obs[0] if obs is not None and obs[0] is not None else (0, 0)
            self.right_wall = obs[1] if obs is not None and obs[1] is not None else (0, 0)
            self.up_wall = obs[2] if obs is not None and obs[2] is not None else (0, 0)
            self.down_wall = obs[3] if obs is not None and obs[3] is not None else (0, 0)
            self.closest_enemy = obs[4] if obs is not None and obs[4] is not None else None
            self.closest_portal = obs[5] if obs is not None and obs[5] is not None else None
            self.enemies = np.round(float(enemy_pixels) / AVG_PIXELS_IN_ENEMY).astype(np.int32)

            self.area = self.state_h * self.state_w
        else:
            self.player_box = None
            self.area = 0

    def has_enemy(self):
        return self.enemies > 0 or self.closest_enemy is not None

    def percentage_from_area(self, pixels):
        if self.is_empty or self.area == 0:
            return 0.0
        return pixels / self.area

    def get_direction_on_closest_wall(self):
        center_of_player = self.center_of_player()
        up_dist = center_of_player[0] - self.up_wall[0]
        down_dist = self.down_wall[0] - center_of_player[0]
        left_dist = center_of_player[1] - self.left_wall[1]
        right_dist = self.right_wall[1] - center_of_player[1]

        dists = [up_dist, right_dist, left_dist, down_dist]
        min_dist = min(dists)
        return dists.index(min_dist)

    def get_distance_to_closest_wall(self):
        center_of_player = self.center_of_player()
        up_dist = center_of_player[0] - self.up_wall[0]
        down_dist = self.down_wall[0] - center_of_player[0]
        left_dist = center_of_player[1] - self.left_wall[1]
        right_dist = self.right_wall[1] - center_of_player[1]

        dists = [up_dist, right_dist, left_dist, down_dist]
        min_dist = min(dists)
        return min_dist

    def as_vector(self):
        player_center = self.center_of_player()
        px_norm = player_center[0] / self.state_h
        py_norm = player_center[1] / self.state_w

        up_wall_dist = player_center[0] - self.up_wall[0]
        down_wall_dist = self.down_wall[0] - player_center[0]
        left_wall_dist = player_center[1] - self.left_wall[1]
        right_wall_dist = self.right_wall[1] - player_center[1]

        dist_up_norm = up_wall_dist / self.state_h
        dist_down_norm = down_wall_dist / self.state_h
        dist_left_norm = left_wall_dist / self.state_w
        dist_right_norm = right_wall_dist / self.state_w

        epsilon = 1e-3
        inv_up = 1.0 / (dist_up_norm + epsilon)
        inv_down = 1.0 / (dist_down_norm + epsilon)
        inv_left = 1.0 / (dist_left_norm + epsilon)
        inv_right = 1.0 / (dist_right_norm + epsilon)

        enemy_x_norm = self.closest_enemy[0] / self.state_h if self.closest_enemy is not None else 0
        enemy_y_norm = self.closest_enemy[1] / self.state_w if self.closest_enemy is not None else 0
        enemy_visible = 1.0 if self.closest_enemy is not None else 0.0

        enemy_count = self.enemies / MAX_ENEMIES
        portal_x_norm = self.closest_portal[0] / self.state_h if self.closest_portal is not None else 0
        portal_y_norm = self.closest_portal[1] / self.state_w if self.closest_portal is not None else 0

        return np.array([px_norm, py_norm,
                         dist_up_norm, dist_down_norm,
                         dist_left_norm, dist_right_norm,
                         inv_up, inv_down,
                         inv_left, inv_right,
                         enemy_x_norm, enemy_y_norm,
                         enemy_visible, enemy_count, portal_x_norm, portal_y_norm], dtype=np.float32)

    def center_of_player(self):
        if self.player_box is None:
            return 0, 0
        i_center = (self.player_box[0] + self.player_box[2]) // 2
        j_center = (self.player_box[1] + self.player_box[3]) // 2
        return i_center, j_center

    def distance_from_player(self, cord, failure_val=-1):
        if self.is_empty or self.player_box is None:
            return failure_val

        player_x, player_y = self.center_of_player()
        return np.sqrt((player_x - cord[0]) ** 2 + (player_y - cord[1]) ** 2)

    def distance_to_closest_border(self):
        player_x, player_y = self.center_of_player()
        up_dist = player_x - self.up_wall[0]
        down_dist = self.down_wall[0] - player_x
        left_dist = player_y - self.left_wall[1]
        right_dist = self.right_wall[1] - player_y
        return min(up_dist, down_dist, left_dist, right_dist)

    def direction_to_player(self, cord):
        i, j = cord
        left_up_corner = (self.player_box[0], self.player_box[1])
        left_down_corner = (self.player_box[2], self.player_box[1])
        right_up_corner = (self.player_box[0], self.player_box[3])
        right_down_corner = (self.player_box[2], self.player_box[3])

        if i < left_up_corner[0]:
            if j < left_up_corner[1]:
                return UPLEFT_DIRECTION
            elif j > right_up_corner[1]:
                return UPRIGHT_DIRECTION
            else:
                return UP_DIRECTION
        elif i > left_down_corner[0]:
            if j < left_down_corner[1]:
                return DOWNLEFT_DIRECTION
            elif j > right_down_corner[1]:
                return DOWNRIGHT_DIRECTION
            else:
                return DOWN_DIRECTION
        else:
            if j < left_up_corner[1]:
                return LEFT_DIRECTION
            elif j > right_up_corner[1]:
                return RIGHT_DIRECTION
            else:
                return -1


In [8]:
@numba.njit
def mark_scanned_line(mask, r0, c0, r1, c1):
    """
    Marks pixels along a scan ray as 'seen'.
    Returns number of newly seen pixels.
    """
    new_pixels = 0
    if r0 == r1: # Horizontal line
        c_start = min(c0, c1)
        c_end = max(c0, c1)
        for c in range(c_start, c_end + 1):
             if 0 <= r0 < mask.shape[0] and 0 <= c < mask.shape[1]:
                if mask[r0, c] == 0:
                    mask[r0, c] = 2 # 2 = Scanned/Seen
                    new_pixels += 1
    elif c0 == c1: # Vertical line
        r_start = min(r0, r1)
        r_end = max(r0, r1)
        for r in range(r_start, r_end + 1):
            if 0 <= r < mask.shape[0] and 0 <= c0 < mask.shape[1]:
                if mask[r, c0] == 0:
                    mask[r, c0] = 2
                    new_pixels += 1
    return new_pixels

@numba.njit
def mark_scanned_traces_from_box(mask, box, up_wall, down_wall, left_wall, right_wall):
    """
    Marks the scan lines from the player box to the walls.
    Returns number of newly seen pixels.
    """
    left_points, right_points, up_points, down_points = get_scanning_points(box)
    new_pixels = 0

    for point in left_points:
        new_pixels += mark_scanned_line(mask, point[0], point[1], left_wall[0], left_wall[1])
    for point in right_points:
        new_pixels += mark_scanned_line(mask, point[0], point[1], right_wall[0], right_wall[1])
    for point in up_points:
        new_pixels += mark_scanned_line(mask, point[0], point[1], up_wall[0], up_wall[1])
    for point in down_points:
        new_pixels += mark_scanned_line(mask, point[0], point[1], down_wall[0], down_wall[1])

    return new_pixels

@numba.njit
def mark_pixels_in_box(mask, r_min, c_min, r_max, c_max):
    """
    Marks the rectangular area covered by the player.
    Returns the number of *newly* visited pixels.
    """
    r_min = max(0, min(r_min, mask.shape[0]))
    c_min = max(0, min(c_min, mask.shape[1]))
    r_max = max(0, min(r_max, mask.shape[0]))
    c_max = max(0, min(c_max, mask.shape[1]))

    new_pixels = 0
    for r in range(r_min, r_max):
        for c in range(c_min, c_max):
            if mask[r, c] == 0:
                mask[r, c] = 1 # 1 = Visited by Player Body
                new_pixels += 1
    return new_pixels

class ExplorationTracker:
    def __init__(self, h, w):
        self.height = h
        self.width = w
        self.space = np.zeros((w, h), dtype=np.int32)

    def reset(self):
        self.space = np.zeros((self.width, self.height), dtype=np.int32)

    def cover(self, state: State):
        """
        :return: new visited pixels, new scanned pixels
        """

        if state.player_box is None:
            return 0, 0

        r_min, c_min, r_max, c_max = state.player_box
        new_visited = mark_pixels_in_box(self.space, r_min, c_min, r_max, c_max)

        new_scanned = mark_scanned_traces_from_box(self.space, state.player_box, state.up_wall, state.down_wall, state.left_wall, state.right_wall)

        return new_visited, new_scanned

# Training

## SARSA agent class

In [9]:
class Sarsa:
    alpha = 1e-5
    gamma = 0.99
    epsilon = 1
    lmbda = 0.9
    weight_decay = 1e-5
    z_clip = 10.0

    def __init__(self, n_actions):
        self.state_dim = VECTOR_STATE_SIZE
        self.n_actions = n_actions
        self.w = np.zeros((self.n_actions, self.state_dim), dtype=np.float32)
        self.b = np.zeros(self.n_actions, dtype=np.float32)
        self.z_w = np.zeros_like(self.w, dtype=np.float32)
        self.z_b = np.zeros((self.n_actions,), dtype=np.float32)

    def phi_from_state_action(self, features, action):
        phi_w = np.zeros_like(self.w, dtype=np.float32)
        phi_w[action] = features.astype(np.float32)
        phi_b = np.zeros_like(self.b, dtype=np.float32)
        phi_b[action] = 1.0
        return phi_w, phi_b

    def _q_values_all_actions(self, state_features):
        return self.w.dot(state_features.astype(np.float32)) + self.b

    def epsilon_greedy(self, features):
        eps = float(getattr(self, "epsilon", 0.0))
        eps = max(0.0, min(1.0, eps))

        q_vals = self._q_values_all_actions(features)

        if np.random.rand() < eps:
            action = np.random.randint(self.n_actions)
        else:
            action = np.argmax(q_vals)

        return action, q_vals

    def save(self, file_name="sarsa_weights.npz"):
        np.savez(file_name, w=self.w.reshape(-1), b=self.b, n_actions=self.n_actions, state_dim=self.state_dim)

    def restrict_exploration(self):
        self.epsilon = 0.0

    def reset_traces(self):
        self.z_w.fill(0.0)
        self.z_b.fill(0.0)

    @staticmethod
    def load(file_name="sarsa_weights.npz"):
        data = np.load(file_name)
        n_actions = int(data['n_actions'])
        state_dim = int(data['state_dim'])
        agent = Sarsa(n_actions)
        agent.state_dim = state_dim
        agent.w = data['w'].reshape((n_actions, state_dim)).astype(np.float32)
        agent.b = data['b'].astype(np.float32) if 'b' in data else np.zeros(n_actions, dtype=np.float32)
        agent.z_w = np.zeros_like(agent.w, dtype=np.float32)
        agent.z_b = np.zeros_like(agent.b, dtype=np.float32)
        return agent

In [10]:
def file_exist(file_name):
    try:
        _ = np.load(file_name)
        return True
    except FileNotFoundError:
        return False

## Examinator class for reward shaping

In [11]:
class Examinator:
    NOT_SHOOT_ENEMY_PENALTY = -1.2  # reduced from -1.5
    LIVING_PENALTY = -0.04  # reduced from -0.1
    DEATH_PENALTY = -25  # increased from -12
    STAY_IN_DANGER_PENALTY = -0.4  # reduced from -0.5
    FIRE_ENEMY_BONUS = 1.2  # reduced from 6.0 (or set to 0.0)
    FAR_FROM_WALL_BONUS = 0.15  # reduced from 0.03
    MOVE_BONUS = 0.015
    BONUS_FOR_NOT_SHOOT_NOWHERE = 0.6
    GO_TO_WALL_PENALTY = -0.15
    SHOOT_WHEN_NO_ENEMIES_PENALTY = -0.3
    STABLE_BONUS_ON_CLEARED_LEVEL = 0.05
    BONUS_FOR_SCANNED_PIXEL = 0.0002
    BONUS_FOR_VISITED_PIXEL = 0.005


    def __init__(self):
        pass

    def examine(self, state, action, model, reward, prev_state, scanned_pixels=0, visited_pixels=0):
        shaped_reward = reward / 20

        shaped_reward += self.LIVING_PENALTY

        if state.closest_enemy is not None:
            enemy_dir = state.direction_to_player(state.closest_enemy)
            required_action = FIRE_ACTIONS[1] + enemy_dir
            if action != required_action and reward <= 0:
                shaped_reward += Examinator.NOT_SHOOT_ENEMY_PENALTY
            if action == 0:
                shaped_reward += Examinator.STAY_IN_DANGER_PENALTY
            if action == required_action:
                shaped_reward += Examinator.FIRE_ENEMY_BONUS
        else:
            if action not in FIRE_ACTIONS and action != 0:
                shaped_reward += Examinator.BONUS_FOR_NOT_SHOOT_NOWHERE

        alive = state.player_box is not None
        if not alive:
            shaped_reward += Examinator.DEATH_PENALTY

        min_distance_to_wall = min(
            state.center_of_player()[0] - state.up_wall[0],
            state.down_wall[0] - state.center_of_player()[0],
            state.center_of_player()[1] - state.left_wall[1],
            state.right_wall[1] - state.center_of_player()[1]
        )
        if min_distance_to_wall > 15:
            shaped_reward += Examinator.FAR_FROM_WALL_BONUS

        if action in MOVE_ACTIONS:
            shaped_reward += Examinator.BONUS_FOR_SCANNED_PIXEL * scanned_pixels
            shaped_reward += Examinator.BONUS_FOR_VISITED_PIXEL * visited_pixels

            dir_to_closest_wall = state.get_direction_on_closest_wall()
            action_components = ACTION_TO_DIRECTIONS.get(action, [])

            if dir_to_closest_wall in action_components:
                # Calculate how dangerous this move is based on distance
                dist_to_wall = state.distance_to_closest_border()

                # If very close (e.g., < 15 pixels), apply heavy penalty
                if dist_to_wall < 15:
                    # The closer we are, the higher the penalty.
                    # Example: at 1px dist, penalty is -1.0. At 15px, it is -0.06
                    proximity_penalty = -1.0 * (15.0 / (dist_to_wall + 1.0))
                    shaped_reward += proximity_penalty
                else:
                    shaped_reward += Examinator.GO_TO_WALL_PENALTY  # Standard small penalty

        if state.enemies == 0:
            if not alive:
                shaped_reward += Examinator.DEATH_PENALTY
                return shaped_reward
            else:
                shaped_reward += Examinator.STABLE_BONUS_ON_CLEARED_LEVEL

            if action in FIRE_ACTIONS:
                shaped_reward += Examinator.SHOOT_WHEN_NO_ENEMIES_PENALTY

            # rewarding to be near walls when no enemies to find portals
            distance_to_closest_wall = state.distance_to_closest_border()
            if 15 < distance_to_closest_wall < 30:
                shaped_reward += 0.15
            elif 30 <= distance_to_closest_wall < 40:
                shaped_reward += 0.07

            if state.closest_portal is not None:
                #bonus for finding portal
                if prev_state.closest_portal is None:
                    shaped_reward += 15.0
                else:
                    prev_distance = prev_state.distance_from_player(prev_state.closest_portal)
                    curr_distance = state.distance_from_player(state.closest_portal)
                    if curr_distance < prev_distance:
                        shaped_reward += 1

            if state.closest_portal is None and prev_state.closest_portal is not None:
                #penalty for losing portal
                shaped_reward += -5.0

        return shaped_reward


In [12]:
class LastActionTracker:
    def __init__(self, space_size):
        self.space_size = space_size
        self.actions = []

    def rec(self, action):
        self.actions.append(action)
        if len(self.actions) > self.space_size:
            self.actions.pop(0)

    def last_same_count(self):
        if not self.actions:
            return 0
        last_action = self.actions[-1]
        count = 0
        for action in reversed(self.actions):
            if action == last_action:
                count += 1
            else:
                break
        return count

## Trainer class

In [13]:
class Trainer:
    def __init__(self, epsilon_min=0.05, epsilon_decay_fraction=0.999, initial_epsilon=1.0):
        self.epsilon_min = epsilon_min
        self.epsilon_decay_fraction = epsilon_decay_fraction
        self.initial_epsilon = initial_epsilon

    @staticmethod
    def _file_name_for_class(class_name):
        return f"sarsa-weights-{class_name.lower()}.npz"

    def train_if_needed(self, model, env, class_name, n_episodes=1000):
        file_name = Trainer._file_name_for_class(class_name)
        print(f'Checking for existing model file: {file_name}')
        if not file_exist(file_name):
            self.train(model, env, class_name, n_episodes)
            return model

        return Sarsa.load(file_name)

    def train(self, model, env, class_name, n_episodes=1000):
        print(f"Training {class_name} agent...")
        action_counts = np.zeros(env.action_space.n, dtype=np.float32)

        examinator = Examinator()

        rewards = []
        w_changes = []
        previous_w = model.w.copy()

        model.epsilon = self.initial_epsilon
        decay_episodes = int(n_episodes * self.epsilon_decay_fraction)
        if decay_episodes > 0:
            epsilon_decay_step = (self.initial_epsilon - self.epsilon_min) / decay_episodes
        else:
            epsilon_decay_step = 0
        print(f"Epsilon will decay from {self.initial_epsilon} to {self.epsilon_min} over {decay_episodes} episodes.")

        log_step = max(1, n_episodes // 100)
        action_duplicate_tolerance = 8

        scanned_pixels_by_episode_percentage = []
        visited_pixels_by_episode_percentage = []

        for episode in range(n_episodes):
            _ = env.reset()

            for j in range(0, 6):  # skip initial no-op frames
                _ = env.step(0)

            last_action_tracker = LastActionTracker(space_size=action_duplicate_tolerance)
            exploration_tracker = ExplorationTracker(160, 210)

            state = env.render()
            featured_state = State(state)
            state_vector = featured_state.as_vector()
            distance_to_closest_enemy = featured_state.distance_from_player(
                featured_state.closest_enemy) if featured_state.closest_enemy is not None else -1
            model.reset_traces()

            action, q_values = model.epsilon_greedy(state_vector)
            action_counts[action] += 1

            done = False
            ep_reward = 0

            visited_pixels_percantages = []
            scanned_pixels_percantages = []

            while not done:
                next_state, reward, terminated, truncated, _ = env.step(action)
                next_featured_state = State(next_state)

                #end episode if no player box (death) or empty state
                if next_featured_state.is_empty:
                    done = True
                    shaped_reward = reward + Examinator.DEATH_PENALTY
                    q_next = 0.0
                    next_action = None
                else:
                    done = terminated or truncated or next_featured_state.player_box is None
                    next_features = next_featured_state.as_vector()
                    next_action, next_q_values = model.epsilon_greedy(next_features)

                    visited_pixels, scanned_pixels = exploration_tracker.cover(next_featured_state)

                    visited_percentage = next_featured_state.percentage_from_area(visited_pixels)
                    scanned_percentage = next_featured_state.percentage_from_area(scanned_pixels)
                    visited_pixels_percantages.append(visited_percentage)
                    scanned_pixels_percantages.append(scanned_percentage)

                    shaped_reward = examinator.examine(next_featured_state, action, model, reward, featured_state, scanned_pixels, visited_pixels)

                    if last_action_tracker.last_same_count() >= action_duplicate_tolerance and action == next_action:
                        shaped_reward += -0.2

                    new_distance_to_closest_enemy = next_featured_state.distance_from_player(
                        next_featured_state.closest_enemy) if next_featured_state.closest_enemy is not None else -1

                    if new_distance_to_closest_enemy != -1 and distance_to_closest_enemy != -1 and distance_to_closest_enemy - new_distance_to_closest_enemy > 0 and new_distance_to_closest_enemy < 20:
                        shaped_reward += -0.02
                    distance_to_closest_enemy = new_distance_to_closest_enemy

                    q_next = 0.0 if done else next_q_values[next_action]

                q = q_values[action]
                delta = shaped_reward + model.gamma * q_next - q

                phi_w, phi_b = model.phi_from_state_action(state_vector, action)

                model.z_w = (model.gamma * model.lmbda * model.z_w) + phi_w
                model.z_b = (model.gamma * model.lmbda * model.z_b) + phi_b

                model.w += model.alpha * delta * model.z_w
                model.b += model.alpha * delta * model.z_b

                model.w *= (1.0 - model.weight_decay)
                model.b *= (1.0 - model.weight_decay)

                model.z_w = np.clip(model.z_w, -model.z_clip, model.z_clip)
                model.z_b = np.clip(model.z_b, -model.z_clip, model.z_clip)

                if not done:
                    featured_state = next_featured_state
                    state_vector = next_features
                    action = next_action
                    q_values = next_q_values
                    action_counts[action] += 1
                    last_action_tracker.rec(action)

                ep_reward += reward

            scanned_pixels_by_episode_percentage.append(np.max(scanned_pixels_percantages) if scanned_pixels_percantages else 0.0)
            visited_pixels_by_episode_percentage.append(np.max(visited_pixels_percantages) if visited_pixels_percantages else 0.0)

            if episode > 0 and episode % 5 == 0:
                action_freq = action_counts / max(1, action_counts.sum())
                action_entropy = -np.sum(action_freq * np.log(action_freq + 1e-10))
                target_entropy = np.log(env.action_space.n) * 0.65

                if action_entropy < target_entropy and episode < decay_episodes:
                    most_used = np.argmax(action_counts)
                    second_most = np.argsort(action_counts)[-2]

                    # Penalize top 2 most-used actions
                    model.b[most_used] *= 0.85
                    model.b[second_most] *= 0.92

                    # Boost least-used actions
                    least_used_indices = np.where(action_freq < 0.02)[0]
                    for idx in least_used_indices:
                        model.b[idx] *= 1.05

                    if episode % 50 == 0:
                        print(
                            f"  [Episode {episode}] Entropy={action_entropy:.2f}, most_used={most_used} ({action_freq[most_used] * 100:.1f}%), bias_penalty applied")

            model.epsilon = max(self.epsilon_min, model.epsilon - epsilon_decay_step)

            w_change = np.mean(np.abs(model.w - previous_w))
            w_changes.append(w_change)
            previous_w = model.w.copy()
            rewards.append(ep_reward)

            if (episode + 1) % log_step == 0:
                recent_max = float(np.max(rewards[-log_step:])) if len(rewards) > 0 else float(ep_reward)
                print(
                    f"Episode {episode + 1}/{n_episodes}: Max reward for period={recent_max:.2f}, Eps={model.epsilon:.4f}")

        px.line(x=np.arange(1, n_episodes + 1), y=w_changes, labels={'x': 'Episode', 'y': 'Mean |Δw|'},
                title='Mean Weight Change over Episodes').show()
        px.line(x=np.arange(1, n_episodes + 1), y=rewards, labels={'x': 'Episode', 'y': 'Reward'},
                title='Episode Rewards over Time').show()

        px.line(x=np.arange(1, n_episodes + 1), y=scanned_pixels_by_episode_percentage, labels={'x': 'Episode', 'y': 'Scanned Pixels Percentage'}, title='Scanned Pixels Percentage over Episodes').show()
        px.line(x=np.arange(1, n_episodes + 1), y=visited_pixels_by_episode_percentage, labels={'x': 'Episode', 'y': 'Visited Pixels Percentage'}, title='Visited Pixels Percentage over Episodes').show()

        action_freq = action_counts / action_counts.sum()
        entropy = -np.sum(action_freq * np.log(action_freq + 1e-10))
        print(f'Action distribution during training: {action_counts}')
        print(f'Action entropy: {entropy:.3f} (max={np.log(env.action_space.n):.3f})')
        print(f'Most used action: {np.argmax(action_counts)} ({action_counts.max() / action_counts.sum() * 100:.1f}%)')
        print(f"Training completed. Max score ever: {np.max(rewards)}")

        model.save(self._file_name_for_class(class_name))


## Train run

In [14]:
CLASS_NAME = "Berzerk"

ale = ALEInterface()
gym.register_envs(ale)

env = gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4)
agent = Sarsa(env.action_space.n)
observation, info = env.reset()

In [15]:

epsilon_min = 0.1

trainer = Trainer(epsilon_min=0.05, epsilon_decay_fraction=0.95, initial_epsilon=1.0)
agent = trainer.train_if_needed(agent, env, class_name=CLASS_NAME, n_episodes=1000)

env.close()

Checking for existing model file: sarsa-weights-berzerk.npz
Training Berzerk agent...
Epsilon will decay from 1.0 to 0.05 over 950 episodes.
Episode 10/1000: Max reward for period=100.00, Eps=0.9900
Episode 20/1000: Max reward for period=100.00, Eps=0.9800
Episode 30/1000: Max reward for period=300.00, Eps=0.9700
Episode 40/1000: Max reward for period=300.00, Eps=0.9600
Episode 50/1000: Max reward for period=100.00, Eps=0.9500
Episode 60/1000: Max reward for period=200.00, Eps=0.9400
Episode 70/1000: Max reward for period=300.00, Eps=0.9300
Episode 80/1000: Max reward for period=100.00, Eps=0.9200
Episode 90/1000: Max reward for period=300.00, Eps=0.9100
Episode 100/1000: Max reward for period=50.00, Eps=0.9000
Episode 110/1000: Max reward for period=150.00, Eps=0.8900
Episode 120/1000: Max reward for period=300.00, Eps=0.8800
Episode 130/1000: Max reward for period=150.00, Eps=0.8700
Episode 140/1000: Max reward for period=300.00, Eps=0.8600
Episode 150/1000: Max reward for period=150

Action distribution during training: [ 3881.  4311.  7170. 21934. 13148. 12313.  4384.  4334.  4701.  4892.
  4583.  5063.  4483. 22520.  9795.  5673.  4908.  3999.]
Action entropy: 2.678 (max=2.890)
Most used action: 13 (15.8%)
Training completed. Max score ever: 300.0


# Test

In [16]:
agent = Sarsa.load(Trainer._file_name_for_class(CLASS_NAME))

In [17]:
ale = ALEInterface()
gym.register_envs(ale)

test_env = gym.make("ALE/Berzerk-v5", render_mode="human", frameskip=4)
agent.restrict_exploration()

In [18]:
n_episodes = 5
total_rewards = []

# agent.w = normalize_weights(agent.w)

for ep in range(n_episodes):
    state, _ = test_env.reset()
    done = False
    ep_reward = 0

    actions_count = np.zeros(test_env.action_space.n, dtype=np.int32)
    while not done:
        feature_state = State(state)
        if feature_state.is_empty:
            action = 0
        else:
            action, _ = agent.epsilon_greedy(feature_state.as_vector())
        actions_count[action] += 1
        next_state, reward, terminated, truncated, _ = test_env.step(action)
        done = terminated or truncated

        state = next_state
        ep_reward += reward

    test_env.render()
    print(f"Episode {ep + 1}: Total Reward = {ep_reward}")
    print(f'Action count during round: {actions_count}')
    print('---------------------------')
    total_rewards.append(ep_reward)

test_env.close()

print(f"\nAverage Test Reward over {n_episodes} episodes: {np.mean(total_rewards):.2f}")

Episode 1: Total Reward = 350.0
Action count during round: [73  0 92 15  0  0 37  0  0  0  0  0  0 47  0  0  2  0]
---------------------------
Episode 2: Total Reward = 500.0
Action count during round: [251   0  53  70   0   0   8   4   0   0  17   0   0  54   0  12  32   0]
---------------------------
Episode 3: Total Reward = 400.0
Action count during round: [ 61   0 110  11   0   0  42   0   0   0   2   0   0  60   0   0  13   0]
---------------------------
Episode 4: Total Reward = 300.0
Action count during round: [ 72   0 108  21   0   0   8   0   0   0   4   0   0  54   0  35  14   0]
---------------------------
Episode 5: Total Reward = 450.0
Action count during round: [134   0  58   7   0   0   8   0   0   0   4   0   0  89   0   1   2   0]
---------------------------

Average Test Reward over 5 episodes: 400.00
