In [1]:
from ale_py import ALEInterface
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import cv2
import numba

In [2]:
CONST_COLOR_PLAYER = (240, 170, 103)
CONST_COLOR_WALL = (84, 92, 214)
CONST_COLOR_ENEMY = (210, 210, 64)

CAT_EMPTY = 0
CAT_PLAYER = 1
CAT_WALL = 2
CAT_ENEMY = 3

In [3]:
@numba.njit
def prepare_state_categorical_inner(obs_resized, h, w):
    new_obs = np.full((h, w, 4), 0, dtype=np.uint8)

    for i in range(h):
        for j in range(w):
            pixel = obs_resized[i, j]

            # Compare pixel (which is a 1D np.array) to the tuple elements
            if (pixel[0] == CONST_COLOR_PLAYER[0] and
                pixel[1] == CONST_COLOR_PLAYER[1] and
                pixel[2] == CONST_COLOR_PLAYER[2]):
                new_obs[i, j, CAT_PLAYER] = 1
            elif (pixel[0] == CONST_COLOR_WALL[0] and
                  pixel[1] == CONST_COLOR_WALL[1] and
                  pixel[2] == CONST_COLOR_WALL[2]):
                new_obs[i, j, CAT_WALL] = 1
            elif (pixel[0] == CONST_COLOR_ENEMY[0] and
                  pixel[1] == CONST_COLOR_ENEMY[1] and
                  pixel[2] == CONST_COLOR_ENEMY[2]):
                new_obs[i, j, CAT_ENEMY] = 1
            else:
                new_obs[i, j, CAT_EMPTY] = 1
    return new_obs

def prepare_state_categorical(frame, h=21, w=21):
    obs_resized = cv2.resize(frame, (w, h), interpolation=cv2.INTER_NEAREST)
    return prepare_state_categorical_inner(obs_resized, h, w)

In [4]:
class ResizeObservation(gym.ObservationWrapper):

    def __init__(self, env, h=21, w=21):
        super().__init__(env)
        self.h, self.w = h, w
        self.observation_space = gym.spaces.Box(0, 1, (h, w, 4), dtype=np.uint8)

    def observation(self, obs):
        return prepare_state_categorical(obs, self.h, self.w)


In [5]:
ale = ALEInterface()
gym.register_envs(ale)

env = gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4)
env = ResizeObservation(env, h=21, w=21)
observation, info = env.reset()


In [6]:
print("action_space:", env.action_space)
print("n actions:", env.action_space.n)

action_space: Discrete(18)
n actions: 18


In [7]:
try:
    meanings = env.unwrapped.get_action_meanings()
except Exception:
    try:
        meanings = env.get_action_meanings()
    except Exception:
        meanings = None

if meanings:
    print("Action index -> meaning:")
    for i, name in enumerate(meanings):
        print(f"{i}: {name}")
else:
    print("No action meanings available from the env. Use index numbers (0..n-1).")


Action index -> meaning:
0: NOOP
1: FIRE
2: UP
3: RIGHT
4: LEFT
5: DOWN
6: UPRIGHT
7: UPLEFT
8: DOWNRIGHT
9: DOWNLEFT
10: UPFIRE
11: RIGHTFIRE
12: LEFTFIRE
13: DOWNFIRE
14: UPRIGHTFIRE
15: UPLEFTFIRE
16: DOWNRIGHTFIRE
17: DOWNLEFTFIRE


In [8]:
FIRE_ACTIONS = [1, 10, 11, 12, 13, 14, 15, 16, 17]
MOVE_ACTIONS = [2, 3, 4, 5, 6, 7, 8, 9]

In [9]:
seed = 42
np.random.seed(seed)

In [10]:
class Sarsa:
    alpha = 1e-5
    gamma = 0.99
    epsilon = 1
    feature_h, feature_w = 21, 21
    use_traces = False
    lmbda = 0.9

    def __init__(self, n_actions):
        self.state_dim = self.feature_h * self.feature_w * 4
        feature_dim = self.state_dim + n_actions
        self.w = np.zeros(feature_dim, dtype=np.float32)
        self.n_actions = n_actions

    def phi_from_state_action(self, features, action):
        a_onehot = np.zeros(self.n_actions, dtype=np.float32)
        a_onehot[action] = 1.0
        return np.concatenate([features, a_onehot])

    def q_value(self, phi):
        return np.dot(self.w, phi)

    def _q_values_all_actions(self, state_features):
        q_base = np.dot(state_features, self.w[:self.state_dim])
        return q_base + self.w[self.state_dim:]

    def epsilon_greedy(self, features):
        eps = float(getattr(self, "epsilon", 0.0))
        eps = max(0.0, min(1.0, eps))

        q_vals = self._q_values_all_actions(features)

        if np.random.rand() < eps:
            action = np.random.randint(self.n_actions)
        else:
            action = np.argmax(q_vals)

        return action, q_vals

    def save(self, file_name="sarsa_weights.npz"):
        np.savez(file_name, w=self.w)

    @staticmethod
    def load(file_name="sarsa_weights.npz"):
        data = np.load(file_name)
        n_featues = 21*21*4
        ag = Sarsa(n_actions=data['w'].shape[0] - n_featues)
        ag.w = data['w']

        print("Loaded SARSA agent with rules characteristics:")
        print("w shape:", ag.w.shape)
        print("w norm:", np.linalg.norm(ag.w))
        print("non-zero weights:", np.count_nonzero(ag.w))
        return ag

    def restrict_exploration(self):
        self.epsilon = 0.0

In [11]:
def is_model_trained():
    try:
        _ = np.load("sarsa_weights.npz")
        return True
    except FileNotFoundError:
        return False

In [12]:
def file_exist(file_name):
    try:
        _ = np.load(file_name)
        return True
    except FileNotFoundError:
        return False

In [13]:
R_LIVING = -0.1            # MANUAL: Must be a small negative "living penalty".
R_INACTION = -0.4603       # IRL: Learned from your data.
R_WALL = -0.6082           # IRL: Learned from your data.
R_PROXIMITY = -0.0499      # IRL: Learned from your data.
R_HUNTING = 0.2            # MANUAL: Override. The model failed to learn this.
R_KILL = 15.0              # MANUAL: Override. The model failed to learn this.
R_DEATH = -10.0            # MANUAL: Override. The model's value was too small.

class StateObserver:
    def __init__(self, w, h):
        self.w = w
        self.h = h
        self.prev_pos = None
        self.prev_num_enemies = 0
        self.prev_min_distance = None

    def _find_entities(self, state):
        player_coords = np.argwhere(state[:, :, CAT_PLAYER] == 1)
        enemy_coords = np.argwhere(state[:, :, CAT_ENEMY] == 1)
        player_pos = tuple(player_coords[0]) if len(player_coords) > 0 else None
        return player_pos, enemy_coords

    def analyze_state(self, state, reward, action):
        player_pos, enemies = self._find_entities(state)
        num_enemies = len(enemies)

        # Start with the real reward from the game
        shaped_reward = 0

        # 1. Living Penalty: Small cost for every step to encourage speed.
        shaped_reward += R_LIVING

        if action not in MOVE_ACTIONS:
            shaped_reward += R_INACTION

        if (action in MOVE_ACTIONS and
            player_pos is not None and
            player_pos == self.prev_pos):
            shaped_reward -= R_WALL

        if num_enemies < self.prev_num_enemies:
            shaped_reward += R_KILL

        # 4. Proximity Penalty: Penalize being too close to an enemy.
        if player_pos is not None and num_enemies > 0:
            distances = np.sum(np.abs(enemies - player_pos), axis=1)
            min_distance = np.min(distances)

            if min_distance <= 2: # Too close!
                shaped_reward += R_PROXIMITY

            if self.prev_min_distance is not None and min_distance < self.prev_min_distance:
                shaped_reward += R_HUNTING

            self.prev_min_distance = min_distance
        else:
            self.prev_min_distance = None

        # 5. Death Penalty: Big penalty if player disappears.
        if player_pos is None and self.prev_pos is not None:
             # Player was alive last step, but is gone now
             shaped_reward += R_DEATH

        # Update memory for the next step
        self.prev_pos = player_pos
        self.prev_num_enemies = num_enemies
        return shaped_reward

    def reset(self):
        self.prev_pos = None
        self.prev_num_enemies = 0
        self.prev_min_distance = None


In [14]:
class Trainer:
    def __init__(self, epsilon_min = 0.05, epsilon_decay_fraction = 0.999, initial_epsilon = 1.0):
        self.epsilon_min = epsilon_min
        self.epsilon_decay_fraction = epsilon_decay_fraction
        self.initial_epsilon = initial_epsilon


    @staticmethod
    def _file_name_for_class(class_name):
        return f"sarsa-weights-{class_name.lower()}.npz"

    def train_if_needed(self, model, env, class_name, n_episodes=1000):
        file_name = Trainer._file_name_for_class(class_name)
        print(f'Checking for existing model file: {file_name}')
        if not file_exist(file_name):
            self.train(model, env, class_name, n_episodes)
            return model

        return Sarsa.load(file_name)

    def train(self, model, env, class_name, n_episodes=1000):
        print(f"Training {class_name} agent...")
        state_observer = StateObserver(w=21, h=21)
        action_counts = np.zeros(env.action_space.n, dtype=np.float32)

        max_score_ever = -np.inf

        model.epsilon = self.initial_epsilon
        decay_episodes = int(n_episodes * self.epsilon_decay_fraction)
        if decay_episodes > 0:
             epsilon_decay_step = (self.initial_epsilon - self.epsilon_min) / decay_episodes
        else:
             epsilon_decay_step = 0
        print(f"Epsilon will decay from {self.initial_epsilon} to {self.epsilon_min} over {decay_episodes} episodes.")

        log_step = max(1, n_episodes // 100)

        for episode in range(n_episodes):
            state, _ = env.reset()
            state_observer.reset()
            # Ensure state is properly flattened
            features = np.array(state, dtype=np.float32).flatten()
            action, q_values = model.epsilon_greedy(features)
            action_counts[action] += 1
            phi = model.phi_from_state_action(features, action)

            done = False
            ep_reward = 0

            while not done:
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Ensure next_state is properly flattened
                next_features = np.array(next_state, dtype=np.float32).flatten()
                next_action, next_q_values = model.epsilon_greedy(next_features)
                action_counts[next_action] += 1
                next_phi = model.phi_from_state_action(next_features, next_action)

                if len(q_values) != model.n_actions:
                    raise ValueError(f"Expected q_values of length {model.n_actions}, got {len(q_values)}")

                q = q_values[action]
                q_next = next_q_values[next_action]

                # Use the original state (2D array) for state observer
                shaped_reward = state_observer.analyze_state(state, reward, action)
                delta = shaped_reward + model.gamma * q_next - q
                model.w += model.alpha * delta * phi


                state = next_state
                action = next_action
                q_values = next_q_values
                phi = next_phi
                ep_reward += reward

            new_epsilon = model.epsilon - epsilon_decay_step
            model.epsilon = max(self.epsilon_min, new_epsilon)

            if ep_reward > max_score_ever:
                max_score_ever = ep_reward


            if (episode + 1) % log_step == 0:
                print(f"Episode {episode+1}/{n_episodes}: Reward={ep_reward:.2f}, Eps={model.epsilon:.4f}")

        print(f'Action distribution during training: {action_counts}')
        print(f"Training completed. Max score ever: {max_score_ever:.2f}")
        model.save(self._file_name_for_class(class_name))


In [15]:
CLASS_NAME = "Berzerk-Default"

agent = Sarsa(env.action_space.n)

epsilon_min = 0.1

trainer = Trainer(epsilon_min)
agent = trainer.train_if_needed(agent, env, class_name=CLASS_NAME, n_episodes=1000)

env.close()


Checking for existing model file: sarsa-weights-berzerk-default.npz
Training Berzerk-Default agent...
Epsilon will decay from 1.0 to 0.1 over 999 episodes.
Episode 10/1000: Reward=250.00, Eps=0.9910
Episode 20/1000: Reward=300.00, Eps=0.9820
Episode 30/1000: Reward=200.00, Eps=0.9730
Episode 40/1000: Reward=100.00, Eps=0.9640
Episode 50/1000: Reward=400.00, Eps=0.9550
Episode 60/1000: Reward=0.00, Eps=0.9459
Episode 70/1000: Reward=100.00, Eps=0.9369
Episode 80/1000: Reward=50.00, Eps=0.9279
Episode 90/1000: Reward=250.00, Eps=0.9189
Episode 100/1000: Reward=410.00, Eps=0.9099
Episode 110/1000: Reward=200.00, Eps=0.9009
Episode 120/1000: Reward=250.00, Eps=0.8919
Episode 130/1000: Reward=350.00, Eps=0.8829
Episode 140/1000: Reward=200.00, Eps=0.8739
Episode 150/1000: Reward=50.00, Eps=0.8649
Episode 160/1000: Reward=250.00, Eps=0.8559
Episode 170/1000: Reward=50.00, Eps=0.8468
Episode 180/1000: Reward=250.00, Eps=0.8378
Episode 190/1000: Reward=150.00, Eps=0.8288
Episode 200/1000: Rewa

# Test

In [16]:
agent = Sarsa.load(Trainer._file_name_for_class(CLASS_NAME))

Loaded SARSA agent with rules characteristics:
w shape: (1782,)
w norm: 3.1991591
non-zero weights: 815


In [17]:
ale = ALEInterface()
gym.register_envs(ale)

test_env = gym.make("ALE/Berzerk-v5", render_mode="human", frameskip=4)
agent.restrict_exploration()

In [18]:
n_episodes = 5
total_rewards = []

for ep in range(n_episodes):
    state, _ = test_env.reset()
    done = False
    ep_reward = 0

    actions_count = np.zeros(test_env.action_space.n, dtype=np.int32)
    while not done:
        state = prepare_state_categorical(state).flatten()
        action, _ = agent.epsilon_greedy(state)
        actions_count[action] += 1
        next_state, reward, terminated, truncated, _ = test_env.step(action)
        done = terminated or truncated

        state = next_state
        ep_reward += reward

    test_env.render()
    print(f"Episode {ep + 1}: Total Reward = {ep_reward}")
    print(f'Action count during round: {actions_count}')
    print('---------------------------')
    total_rewards.append(ep_reward)

test_env.close()

print(f"\nAverage Test Reward over {n_episodes} episodes: {np.mean(total_rewards):.2f}")

Episode 1: Total Reward = 50.0
Action count during round: [  0   0   0   0   0 196   0   0   0   0   0   0   0   0   0   0   0   0]
---------------------------
Episode 2: Total Reward = 50.0
Action count during round: [  0   0   0   0   0 196   0   0   0   0   0   0   0   0   0   0   0   0]
---------------------------
Episode 3: Total Reward = 50.0
Action count during round: [  0   0   0   0   0 196   0   0   0   0   0   0   0   0   0   0   0   0]
---------------------------
Episode 4: Total Reward = 50.0
Action count during round: [  0   0   0   0   0 196   0   0   0   0   0   0   0   0   0   0   0   0]
---------------------------
Episode 5: Total Reward = 50.0
Action count during round: [  0   0   0   0   0 196   0   0   0   0   0   0   0   0   0   0   0   0]
---------------------------

Average Test Reward over 5 episodes: 50.00
