In [1]:
import ale_py
import gymnasium as gym
import numpy as np
import plotly.express as px
from ale_py import ALEInterface

from env_info import *
from exploration_tracker import ExplorationTracker
from sarsa import Sarsa
from state import State
from utils import file_exist

## Examinator class

In [2]:
class Examinator:
    """Reward shaping for the Berzerk SARSA agent.

    ``examine`` turns the raw environment reward of one transition into a
    shaped reward, using features of the resulting ``State`` (closest enemy,
    wall distances, portals) and the exploration coverage counts.
    """

    ENV_REWARD_DESCALE = 20  # raw env reward is divided by this before shaping

    NOT_SHOOT_ENEMY_PENALTY = -1.3051  # reduced from -1.5
    LIVING_PENALTY = -0.016801  # reduced from -0.1
    DEATH_PENALTY = -25  # increased from -12
    STAY_IN_DANGER_PENALTY = -0.4  # reduced from -0.5
    FIRE_ENEMY_BONUS = 1.2  # reduced from 6.0 (or set to 0.0)
    FAR_FROM_WALL_BONUS = 0.15  # increased from 0.03
    MOVE_BONUS = 0.015  # NOTE(review): not referenced by examine()
    BONUS_FOR_NOT_SHOOT_NOWHERE = 0.62
    GO_TO_WALL_PENALTY = -0.15
    SHOOT_WHEN_NO_ENEMIES_PENALTY = -0.3
    STABLE_BONUS_ON_CLEARED_LEVEL = 0.05
    BONUS_FOR_SCANNED_PIXEL = 0.0002
    BONUS_FOR_VISITED_PIXEL = 0.005

    # Wall distance (pixels) separating "near a wall" from "safely away from walls".
    WALL_SAFE_DISTANCE = 15
    # Cleared-room shaping that nudges the agent toward walls to find the exit portal.
    PORTAL_SEARCH_NEAR_WALL_BONUS = 0.15  # wall distance in (WALL_SAFE_DISTANCE, 30)
    PORTAL_SEARCH_MID_WALL_BONUS = 0.07  # wall distance in [30, 40)
    PORTAL_FOUND_BONUS = 15.0  # portal visible now but not in the previous state
    PORTAL_APPROACH_BONUS = 1.0  # moved closer to the visible portal
    PORTAL_LOST_PENALTY = -5.0  # portal was visible in the previous state, not now

    def __init__(self):
        pass

    def examine(self, state, action, model, reward, prev_state, scanned_pixels=0, visited_pixels=0):
        """Return the shaped reward for the transition that produced ``state``.

        Args:
            state: features of the state reached after taking ``action``.
            action: action index that was taken (checked against
                ``FIRE_ACTIONS`` / ``MOVE_ACTIONS`` from ``env_info``).
            model: unused; kept for interface compatibility.
            reward: raw reward returned by the environment for this step.
            prev_state: features of the state ``action`` was taken from; only
                its ``closest_portal`` and ``distance_from_player`` are read.
            scanned_pixels: exploration count, rewarded only on move actions.
            visited_pixels: exploration count, rewarded only on move actions.

        Returns:
            The shaped reward (descaled env reward plus all bonuses/penalties).
        """
        shaped_reward = reward / Examinator.ENV_REWARD_DESCALE
        shaped_reward += Examinator.LIVING_PENALTY

        if state.closest_enemy is not None:
            enemy_dir = state.direction_to_player(state.closest_enemy)
            # The fire action aimed at the enemy's direction.
            required_action = FIRE_ACTIONS[1] + enemy_dir
            if action != required_action and reward <= 0:
                shaped_reward += Examinator.NOT_SHOOT_ENEMY_PENALTY
            if action == 0:
                shaped_reward += Examinator.STAY_IN_DANGER_PENALTY
            if action == required_action:
                shaped_reward += Examinator.FIRE_ENEMY_BONUS
        elif action not in FIRE_ACTIONS and action != 0:
            shaped_reward += Examinator.BONUS_FOR_NOT_SHOOT_NOWHERE

        alive = state.player_box is not None
        if not alive:
            shaped_reward += Examinator.DEATH_PENALTY

        center = state.center_of_player()
        min_distance_to_wall = min(
            center[0] - state.up_wall[0],
            state.down_wall[0] - center[0],
            center[1] - state.left_wall[1],
            state.right_wall[1] - center[1],
        )
        if min_distance_to_wall > Examinator.WALL_SAFE_DISTANCE:
            shaped_reward += Examinator.FAR_FROM_WALL_BONUS

        if action in MOVE_ACTIONS:
            shaped_reward += Examinator.BONUS_FOR_SCANNED_PIXEL * scanned_pixels
            shaped_reward += Examinator.BONUS_FOR_VISITED_PIXEL * visited_pixels

            dir_to_closest_wall = state.get_direction_on_closest_wall()
            if dir_to_closest_wall in ACTION_TO_DIRECTIONS.get(action, []):
                # The move heads toward the closest wall.
                dist_to_wall = state.distance_to_closest_border()
                if dist_to_wall < Examinator.WALL_SAFE_DISTANCE:
                    # Inverse-distance penalty -15 / (d + 1): -15.0 at 0px,
                    # -7.5 at 1px, -1.0 at 14px.
                    # NOTE(review): at the 15px boundary this is ~7x larger than
                    # GO_TO_WALL_PENALTY; confirm the intended scale.
                    shaped_reward += -1.0 * (Examinator.WALL_SAFE_DISTANCE / (dist_to_wall + 1.0))
                else:
                    shaped_reward += Examinator.GO_TO_WALL_PENALTY

        if state.enemies == 0:
            if not alive:
                # DEATH_PENALTY was already added above; applying it again here
                # would double-penalize dying in a cleared room.
                return shaped_reward
            shaped_reward += Examinator.STABLE_BONUS_ON_CLEARED_LEVEL

            if action in FIRE_ACTIONS:
                shaped_reward += Examinator.SHOOT_WHEN_NO_ENEMIES_PENALTY

            # Reward hugging (but not touching) walls when no enemies remain, to find portals.
            distance_to_closest_wall = state.distance_to_closest_border()
            if Examinator.WALL_SAFE_DISTANCE < distance_to_closest_wall < 30:
                shaped_reward += Examinator.PORTAL_SEARCH_NEAR_WALL_BONUS
            elif 30 <= distance_to_closest_wall < 40:
                shaped_reward += Examinator.PORTAL_SEARCH_MID_WALL_BONUS

            if state.closest_portal is not None:
                if prev_state.closest_portal is None:
                    shaped_reward += Examinator.PORTAL_FOUND_BONUS
                else:
                    prev_distance = prev_state.distance_from_player(prev_state.closest_portal)
                    curr_distance = state.distance_from_player(state.closest_portal)
                    if curr_distance < prev_distance:
                        shaped_reward += Examinator.PORTAL_APPROACH_BONUS
            elif prev_state.closest_portal is not None:
                shaped_reward += Examinator.PORTAL_LOST_PENALTY

        return shaped_reward


In [3]:
class LastActionTracker:
    """Sliding window over the most recently taken actions.

    Keeps at most ``space_size`` entries (oldest dropped first) and reports how
    long the current run of identical trailing actions is.
    """

    def __init__(self, space_size):
        self.space_size = space_size
        self.actions = []

    def rec(self, action):
        """Append ``action``; evict the oldest entry once the window overflows."""
        self.actions.append(action)
        if len(self.actions) > self.space_size:
            del self.actions[0]

    def last_same_count(self):
        """Length of the trailing streak equal to the newest action (0 if empty)."""
        if not self.actions:
            return 0
        newest = self.actions[-1]
        streak = 0
        idx = len(self.actions) - 1
        while idx >= 0 and self.actions[idx] == newest:
            streak += 1
            idx -= 1
        return streak

## Trainer class

In [4]:
class Trainer:
    """Runs on-policy SARSA(lambda) training of a linear ``Sarsa`` model.

    The trainer owns the epsilon schedule; the learning rate, discount, trace
    decay, weight decay and trace clipping are read from the model itself.
    """

    REPEAT_ACTION_PENALTY = -0.2  # same action repeated for a full tracker window
    CLOSE_ENEMY_APPROACH_PENALTY = -0.02  # stepping closer to an already-close enemy
    CLOSE_ENEMY_DISTANCE = 20  # "already close" threshold for the penalty above

    def __init__(self, epsilon_min=0.05, epsilon_decay_fraction=0.999, initial_epsilon=1.0, alpha=1e-5):
        """
        Args:
            epsilon_min: floor for the exploration rate.
            epsilon_decay_fraction: fraction of ``n_episodes`` over which epsilon
                decays linearly from ``initial_epsilon`` to ``epsilon_min``.
            initial_epsilon: exploration rate at the first episode.
            alpha: stored for reference only; ``train`` uses ``model.alpha``.
        """
        self.epsilon_min = epsilon_min
        self.epsilon_decay_fraction = epsilon_decay_fraction
        self.initial_epsilon = initial_epsilon
        # NOTE(review): not applied to the model; the TD update uses model.alpha.
        self.alpha = alpha

    @staticmethod
    def _file_name_for_class(class_name):
        """Weights file name for ``class_name`` (lower-cased)."""
        return f"sarsa-weights-{class_name.lower()}.npz"

    def train_if_needed(self, model, env, class_name, n_episodes=1000):
        """Train and return ``model`` unless a saved weights file exists, in which case load it."""
        file_name = Trainer._file_name_for_class(class_name)
        print(f'Checking for existing model file: {file_name}')
        if not file_exist(file_name):
            self.train(model, env, class_name, n_episodes)
            return model

        return Sarsa.load(file_name)

    def train(self, model, env, class_name, n_episodes=1000):
        """Train ``model`` in place on ``env`` for ``n_episodes`` and save its weights.

        Each step is shaped by ``Examinator``, plus repeat-action and
        enemy-approach penalties, and fed into a SARSA(lambda) update with
        accumulating eligibility traces. Epsilon decays linearly per episode.
        At the end, training curves are plotted, action statistics printed,
        and the model saved under ``_file_name_for_class(class_name)``.
        """
        print(f"Training {class_name} agent...")
        n_actions = env.action_space.n
        action_counts = np.zeros(n_actions, dtype=np.float32)

        examinator = Examinator()

        rewards = []
        w_changes = []
        previous_w = model.w.copy()

        model.epsilon = self.initial_epsilon
        decay_episodes = int(n_episodes * self.epsilon_decay_fraction)
        if decay_episodes > 0:
            epsilon_decay_step = (self.initial_epsilon - self.epsilon_min) / decay_episodes
        else:
            epsilon_decay_step = 0
        print(f"Epsilon will decay from {self.initial_epsilon} to {self.epsilon_min} over {decay_episodes} episodes.")

        log_step = max(1, n_episodes // 100)
        action_duplicate_tolerance = 8

        scanned_pixels_by_episode_percentage = []
        visited_pixels_by_episode_percentage = []

        for episode in range(n_episodes):
            _ = env.reset()
            for _j in range(6):  # skip initial no-op frames
                _ = env.step(0)

            last_action_tracker = LastActionTracker(space_size=action_duplicate_tolerance)
            exploration_tracker = ExplorationTracker(160, 210)

            featured_state = State(env.render())
            state_vector = featured_state.as_vector()
            distance_to_closest_enemy = self._enemy_distance(featured_state)
            model.reset_traces()

            action, q_values = model.epsilon_greedy(state_vector)
            action_counts[action] += 1
            # Record the opening action too, so the repeat-action streak is not undercounted.
            last_action_tracker.rec(action)

            done = False
            ep_reward = 0
            visited_percentages = []
            scanned_percentages = []

            while not done:
                next_state, reward, terminated, truncated, _ = env.step(action)
                next_featured_state = State(next_state)

                if next_featured_state.is_empty:
                    # Unusable frame: treat as a terminal death. Descale the env reward
                    # exactly like Examinator.examine so that scoring on the dying
                    # frame cannot turn the death into a net positive reward.
                    done = True
                    shaped_reward = reward / Examinator.ENV_REWARD_DESCALE + Examinator.DEATH_PENALTY
                    q_next = 0.0
                else:
                    done = terminated or truncated or next_featured_state.player_box is None
                    next_features = next_featured_state.as_vector()
                    next_action, next_q_values = model.epsilon_greedy(next_features)

                    visited_pixels, scanned_pixels = exploration_tracker.cover(next_featured_state)
                    visited_percentages.append(next_featured_state.percentage_from_area(visited_pixels))
                    scanned_percentages.append(next_featured_state.percentage_from_area(scanned_pixels))

                    shaped_reward = examinator.examine(next_featured_state, action, model, reward, featured_state, scanned_pixels, visited_pixels)

                    if last_action_tracker.last_same_count() >= action_duplicate_tolerance and action == next_action:
                        shaped_reward += Trainer.REPEAT_ACTION_PENALTY

                    new_distance_to_closest_enemy = self._enemy_distance(next_featured_state)
                    if (new_distance_to_closest_enemy != -1 and distance_to_closest_enemy != -1
                            and distance_to_closest_enemy - new_distance_to_closest_enemy > 0
                            and new_distance_to_closest_enemy < Trainer.CLOSE_ENEMY_DISTANCE):
                        shaped_reward += Trainer.CLOSE_ENEMY_APPROACH_PENALTY
                    distance_to_closest_enemy = new_distance_to_closest_enemy

                    q_next = 0.0 if done else next_q_values[next_action]

                delta = shaped_reward + model.gamma * q_next - q_values[action]
                self._td_update(model, state_vector, action, delta)

                if not done:
                    featured_state = next_featured_state
                    state_vector = next_features
                    action = next_action
                    q_values = next_q_values
                    action_counts[action] += 1
                    last_action_tracker.rec(action)

                ep_reward += reward

            scanned_pixels_by_episode_percentage.append(np.max(scanned_percentages) if scanned_percentages else 0.0)
            visited_pixels_by_episode_percentage.append(np.max(visited_percentages) if visited_percentages else 0.0)

            if episode > 0 and episode % 5 == 0:
                self._rebalance_action_biases(model, action_counts, episode, decay_episodes, n_actions)

            model.epsilon = max(self.epsilon_min, model.epsilon - epsilon_decay_step)

            w_changes.append(np.mean(np.abs(model.w - previous_w)))
            previous_w = model.w.copy()
            rewards.append(ep_reward)

            if (episode + 1) % log_step == 0:
                recent_max = float(np.max(rewards[-log_step:]))
                print(
                    f"Episode {episode + 1}/{n_episodes}: Max reward for period={recent_max:.2f}, Eps={model.epsilon:.4f}")

        self._plot_training_curves(n_episodes, w_changes, rewards,
                                   scanned_pixels_by_episode_percentage,
                                   visited_pixels_by_episode_percentage)
        self._report_action_distribution(action_counts, n_actions)
        max_score = np.max(rewards) if rewards else None
        print(f"Training completed. Max score ever: {max_score}")

        model.save(self._file_name_for_class(class_name))

    @staticmethod
    def _enemy_distance(featured_state):
        """Distance from the player to the closest enemy, or -1 when there is none."""
        if featured_state.closest_enemy is None:
            return -1
        return featured_state.distance_from_player(featured_state.closest_enemy)

    @staticmethod
    def _td_update(model, state_vector, action, delta):
        """One SARSA(lambda) step: decay and accumulate traces, update, then decay weights and clip traces."""
        phi_w, phi_b = model.phi_from_state_action(state_vector, action)

        model.z_w = (model.gamma * model.lmbda * model.z_w) + phi_w
        model.z_b = (model.gamma * model.lmbda * model.z_b) + phi_b

        model.w += model.alpha * delta * model.z_w
        model.b += model.alpha * delta * model.z_b

        model.w *= (1.0 - model.weight_decay)
        model.b *= (1.0 - model.weight_decay)

        # Clipping after the update only bounds the traces carried into the next step.
        model.z_w = np.clip(model.z_w, -model.z_clip, model.z_clip)
        model.z_b = np.clip(model.z_b, -model.z_clip, model.z_clip)

    @staticmethod
    def _rebalance_action_biases(model, action_counts, episode, decay_episodes, n_actions):
        """Nudge per-action biases when the cumulative action distribution collapses.

        Only acts during the epsilon-decay phase, and only when the action
        entropy is below 65% of its maximum, log(n_actions).
        """
        action_freq = action_counts / max(1, action_counts.sum())
        action_entropy = -np.sum(action_freq * np.log(action_freq + 1e-10))
        target_entropy = np.log(n_actions) * 0.65
        if not (action_entropy < target_entropy and episode < decay_episodes):
            return

        # Most-used first; a stable sort keeps ties in index order, so ranked[0]
        # matches np.argmax and ranked[1] is always a different action.
        ranked = np.argsort(-action_counts, kind="stable")
        most_used, second_most = ranked[0], ranked[1]

        # Scale by |b| instead of multiplying b: `b *= 0.85` would *raise* a
        # negative bias. For positive biases this equals the factors 0.85 / 0.92 / 1.05.
        model.b[most_used] -= 0.15 * np.abs(model.b[most_used])
        model.b[second_most] -= 0.08 * np.abs(model.b[second_most])

        # Boost rarely used actions (< 2% of all actions so far).
        for idx in np.where(action_freq < 0.02)[0]:
            model.b[idx] += 0.05 * np.abs(model.b[idx])

        if episode % 50 == 0:
            print(
                f"  [Episode {episode}] Entropy={action_entropy:.2f}, most_used={most_used} ({action_freq[most_used] * 100:.1f}%), bias_penalty applied")

    @staticmethod
    def _plot_training_curves(n_episodes, w_changes, rewards, scanned_percentages, visited_percentages):
        """Show one plotly line chart per tracked per-episode metric."""
        episodes = np.arange(1, n_episodes + 1)
        curves = [
            (w_changes, 'Mean |Δw|', 'Mean Weight Change over Episodes'),
            (rewards, 'Reward', 'Episode Rewards over Time'),
            (scanned_percentages, 'Scanned Pixels Percentage', 'Scanned Pixels Percentage over Episodes'),
            (visited_percentages, 'Visited Pixels Percentage', 'Visited Pixels Percentage over Episodes'),
        ]
        for values, y_label, title in curves:
            px.line(x=episodes, y=values, labels={'x': 'Episode', 'y': y_label}, title=title).show()

    @staticmethod
    def _report_action_distribution(action_counts, n_actions):
        """Print action counts, their entropy, and the dominant action."""
        total = action_counts.sum()
        if total == 0:
            print('Action distribution during training: no actions recorded')
            return
        action_freq = action_counts / total
        entropy = -np.sum(action_freq * np.log(action_freq + 1e-10))
        print(f'Action distribution during training: {action_counts}')
        print(f'Action entropy: {entropy:.3f} (max={np.log(n_actions):.3f})')
        print(f'Most used action: {np.argmax(action_counts)} ({action_counts.max() / total * 100:.1f}%)')

## Training

In [5]:
CLASS_NAME = "Berzerk"

# Importing ale_py is what registers the "ALE/..." environments with Gymnasium;
# per the Gymnasium docs, register_envs expects that module and only marks the
# import as used.
gym.register_envs(ale_py)

# NOTE(review): this interface is not used by the training code in this notebook.
ale = ALEInterface()

env = gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4)
agent = Sarsa(env.action_space.n)
observation, info = env.reset()

In [6]:
# Exploration floor actually passed to the trainer; epsilon decays linearly
# from 1.0 down to it over 95% of the episodes.
epsilon_min = 0.05

trainer = Trainer(epsilon_min=epsilon_min, epsilon_decay_fraction=0.95, initial_epsilon=1.0)
agent = trainer.train_if_needed(agent, env, class_name=CLASS_NAME, n_episodes=10000)

env.close()

Checking for existing model file: sarsa-weights-berzerk.npz
Training Berzerk agent...
Epsilon will decay from 1.0 to 0.05 over 9500 episodes.
Episode 100/10000: Max reward for period=300.00, Eps=0.9900
Episode 200/10000: Max reward for period=300.00, Eps=0.9800
Episode 300/10000: Max reward for period=300.00, Eps=0.9700
Episode 400/10000: Max reward for period=300.00, Eps=0.9600
Episode 500/10000: Max reward for period=200.00, Eps=0.9500
Episode 600/10000: Max reward for period=300.00, Eps=0.9400
Episode 700/10000: Max reward for period=300.00, Eps=0.9300
Episode 800/10000: Max reward for period=300.00, Eps=0.9200
Episode 900/10000: Max reward for period=200.00, Eps=0.9100
Episode 1000/10000: Max reward for period=300.00, Eps=0.9000
Episode 1100/10000: Max reward for period=300.00, Eps=0.8900
Episode 1200/10000: Max reward for period=200.00, Eps=0.8800
Episode 1300/10000: Max reward for period=300.00, Eps=0.8700
Episode 1400/10000: Max reward for period=300.00, Eps=0.8600
Episode 1500/

Action distribution during training: [ 41432.  39172.  77738. 216894. 177372.  91068.  41264.  36766.  50340.
  38691.  44686.  47120.  43453. 227326.  41360.  38976.  48980.  40529.]
Action entropy: 2.629 (max=2.890)
Most used action: 13 (16.9%)
Training completed. Max score ever: 300.0
