In [15]:
from ale_py import ALEInterface
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import cv2
import numba
import plotly.express as px

In [16]:
CONST_COLOR_PLAYER = (240, 170, 103)
CONST_COLOR_WALL = (84, 92, 214)
CONST_COLOR_ENEMY = (210, 210, 64)

CAT_EMPTY = 0
CAT_PLAYER = 1
CAT_WALL = 2
CAT_ENEMY = 3

In [17]:
# --- Helper: The One-Hot Encoder (Unchanged from your code) ---
@numba.njit
def _get_one_hot_state(obs_resized, h, w):
    new_obs = np.full((h, w, 4), 0, dtype=np.uint8)
    for i in range(h):
        for j in range(w):
            pixel = obs_resized[i, j]
            if (pixel[0] == CONST_COLOR_PLAYER[0] and
                pixel[1] == CONST_COLOR_PLAYER[1] and
                pixel[2] == CONST_COLOR_PLAYER[2]):
                new_obs[i, j, CAT_PLAYER] = 1
            elif (pixel[0] == CONST_COLOR_WALL[0] and
                  pixel[1] == CONST_COLOR_WALL[1] and
                  pixel[2] == CONST_COLOR_WALL[2]):
                new_obs[i, j, CAT_WALL] = 1
            elif (pixel[0] == CONST_COLOR_ENEMY[0] and
                  pixel[1] == CONST_COLOR_ENEMY[1] and
                  pixel[2] == CONST_COLOR_ENEMY[2]):
                new_obs[i, j, CAT_ENEMY] = 1
            else:
                new_obs[i, j, CAT_EMPTY] = 1
    return new_obs

# --- Helper: Find Entities (Copied from your StateObserver) ---
@numba.njit
def _find_entities(one_hot_state):
    player_coords = np.argwhere(one_hot_state[:, :, CAT_PLAYER] == 1)
    enemy_coords = np.argwhere(one_hot_state[:, :, CAT_ENEMY] == 1)
    player_pos = player_coords[0] if len(player_coords) > 0 else None
    return player_pos, enemy_coords

# --- Helper: Find Closest Enemy Direction ---
@numba.njit
def _get_enemy_direction(player_pos, enemies):
    if player_pos is None or len(enemies) == 0:
        return 0.0, 0.0 # No enemy, no direction

    closest_enemy = enemies[0]
    min_dist = np.inf

    # Find the closest enemy
    for i in range(len(enemies)):
        dist = np.sum(np.abs(enemies[i] - player_pos))
        if dist < min_dist:
            min_dist = dist
            closest_enemy = enemies[i]

    # Calculate direction (returns -1.0, 0.0, or 1.0)
    # We normalize to prevent one direction from seeming "larger"
    dir_y = np.sign(closest_enemy[0] - player_pos[0])
    dir_x = np.sign(closest_enemy[1] - player_pos[1])

    return float(dir_y), float(dir_x)

# --- THE NEW PREPROCESSING FUNCTION ---
def extract_intelligent_features(frame, h=21, w=21):
    """
    This is the new main preprocessing function.
    It returns a 6-feature vector, not a 1764-pixel vector.
    """

    # 1. First, get the raw 0-255 image
    obs_resized = cv2.resize(frame, (w, h), interpolation=cv2.INTER_NEAREST)

    # 2. Get the 21x21x4 one-hot state
    one_hot_state = _get_one_hot_state(obs_resized, h, w)

    # 3. Find the player and enemies
    player_pos, enemies = _find_entities(one_hot_state)

    if player_pos is None:
        # Player is dead, return a zero-vector
        return np.zeros(6, dtype=np.float32)

    (y, x) = player_pos

    # 4. Feature 0-3: Wall Proximity
    # (1.0 if a wall is there, 0.0 if not)
    f_wall_up = one_hot_state[y-1, x, CAT_WALL] if y > 0 else 1.0
    f_wall_down = one_hot_state[y+1, x, CAT_WALL] if y < h-1 else 1.0
    f_wall_left = one_hot_state[y, x-1, CAT_WALL] if x > 0 else 1.0
    f_wall_right = one_hot_state[y, x+1, CAT_WALL] if x < w-1 else 1.0

    # 5. Feature 4-5: Enemy Direction
    f_enemy_dir_y, f_enemy_dir_x = _get_enemy_direction(player_pos, enemies)

    # 6. Return the final, intelligent feature vector
    features = np.array([
        f_wall_up, f_wall_down, f_wall_left, f_wall_right,
        f_enemy_dir_y, f_enemy_dir_x
    ], dtype=np.float32)

    return features

In [18]:
class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, h=21, w=21):
        super().__init__(env)
        self.h, self.w = h, w

        # --- NEW OBSERVATION SPACE ---
        # 6 features: wall_up, wall_down, wall_left, wall_right,
        #             enemy_dir_y, enemy_dir_x
        # Values are -1.0 to 1.0 (for directions and wall flags)
        self.observation_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=(6,), dtype=np.float32
        )

    def observation(self, obs):
        # Call our new, intelligent feature extractor
        return extract_intelligent_features(obs, self.h, self.w)

In [19]:
ale = ALEInterface()
gym.register_envs(ale)

env = gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4)
env = ResizeObservation(env, h=21, w=21)
observation, info = env.reset()


In [20]:
print("action_space:", env.action_space)
print("n actions:", env.action_space.n)

action_space: Discrete(18)
n actions: 18


In [21]:
try:
    meanings = env.unwrapped.get_action_meanings()
except Exception:
    try:
        meanings = env.get_action_meanings()
    except Exception:
        meanings = None

if meanings:
    print("Action index -> meaning:")
    for i, name in enumerate(meanings):
        print(f"{i}: {name}")
else:
    print("No action meanings available from the env. Use index numbers (0..n-1).")


Action index -> meaning:
0: NOOP
1: FIRE
2: UP
3: RIGHT
4: LEFT
5: DOWN
6: UPRIGHT
7: UPLEFT
8: DOWNRIGHT
9: DOWNLEFT
10: UPFIRE
11: RIGHTFIRE
12: LEFTFIRE
13: DOWNFIRE
14: UPRIGHTFIRE
15: UPLEFTFIRE
16: DOWNRIGHTFIRE
17: DOWNLEFTFIRE


In [22]:
FIRE_ACTIONS = [1, 10, 11, 12, 13, 14, 15, 16, 17]
MOVE_ACTIONS = [2, 3, 4, 5, 6, 7, 8, 9]

In [23]:
seed = 42
np.random.seed(seed)

In [24]:
class Sarsa:
    alpha = 1e-4
    gamma = 0.99
    epsilon = 1
    # feature_h, feature_w = 21, 21
    lmbda = 0.9

    def __init__(self, n_actions):
        self.state_dim = 6
        feature_dim = self.state_dim + n_actions
        self.w = np.zeros(feature_dim, dtype=np.float32)
        self.n_actions = n_actions
        self.z = np.zeros_like(self.w, dtype=np.float32)

    def phi_from_state_action(self, features, action):
        a_onehot = np.zeros(self.n_actions, dtype=np.float32)
        a_onehot[action] = 1.0
        return np.concatenate([features, a_onehot])

    def q_value(self, phi):
        return np.dot(self.w, phi)

    def _q_values_all_actions(self, state_features):
        q_base = np.dot(state_features, self.w[:self.state_dim])
        return q_base + self.w[self.state_dim:]

    def epsilon_greedy(self, features):
        eps = float(getattr(self, "epsilon", 0.0))
        eps = max(0.0, min(1.0, eps))

        q_vals = self._q_values_all_actions(features)

        if np.random.rand() < eps:
            action = np.random.randint(self.n_actions)
        else:
            action = np.argmax(q_vals)

        return action, q_vals

    def save(self, file_name="sarsa_weights.npz"):
        np.savez(file_name, w=self.w)

    @staticmethod
    def load(file_name="sarsa_weights.npz"):
        data = np.load(file_name)
        n_featues = 6
        ag = Sarsa(n_actions=data['w'].shape[0] - n_featues)
        ag.w = data['w']

        print("Loaded SARSA agent with rules characteristics:")
        print("w shape:", ag.w.shape)
        print("w norm:", np.linalg.norm(ag.w))
        print("non-zero weights:", np.count_nonzero(ag.w))
        return ag

    def restrict_exploration(self):
        self.epsilon = 0.0

    def reset_traces(self):
        self.z.fill(0.0)

In [25]:
def is_model_trained():
    try:
        _ = np.load("sarsa_weights.npz")
        return True
    except FileNotFoundError:
        return False

In [26]:
def file_exist(file_name):
    try:
        _ = np.load(file_name)
        return True
    except FileNotFoundError:
        return False

In [27]:
R_LIVING = -0.001            # MANUAL: Must be a small negative "living penalty".
R_INACTION = -0.2603       # IRL: Learned from your data.
R_WALL = -2.082           # IRL: Learned from your data.
# R_INACTION = -1.082       # IRL: Learned from your data.
# R_WALL = -1.082           # IRL: Learned from your data.
R_PROXIMITY = -0.0499      # IRL: Learned from your data.
R_HUNTING = 0.2            # MANUAL: Override. The model failed to learn this.
R_KILL = 15.0              # MANUAL: Override. The model failed to learn this.
R_DEATH = -20.0            # MANUAL: Override. The model's value was too small.

class StateObserver:
    def __init__(self, w, h):
        self.w = w
        self.h = h
        self.prev_pos = None
        self.prev_num_enemies = 0
        self.prev_min_distance = None

    def _find_entities(self, state):
        player_coords = np.argwhere(state[:, :, CAT_PLAYER] == 1)
        enemy_coords = np.argwhere(state[:, :, CAT_ENEMY] == 1)
        player_pos = tuple(player_coords[0]) if len(player_coords) > 0 else None
        return player_pos, enemy_coords

    def analyze_state(self, state, reward, action):
        player_pos, enemies = self._find_entities(state)
        num_enemies = len(enemies)

        # Start with the real reward from the game
        shaped_reward = reward

        # 1. Living Penalty: Small cost for every step to encourage speed.
        shaped_reward += R_LIVING

        # if action not in MOVE_ACTIONS:
        #     shaped_reward += R_INACTION

        if (action in MOVE_ACTIONS and
            player_pos is not None and
            player_pos == self.prev_pos):
            shaped_reward += R_WALL

        if num_enemies < self.prev_num_enemies:
            shaped_reward += R_KILL

        # 4. Proximity Penalty: Penalize being too close to an enemy.
        if player_pos is not None and num_enemies > 0:
            distances = np.sum(np.abs(enemies - player_pos), axis=1)
            min_distance = np.min(distances)

            if min_distance <= 2: # Too close!
                shaped_reward += R_PROXIMITY

            if self.prev_min_distance is not None and min_distance < self.prev_min_distance:
                shaped_reward += R_HUNTING

            self.prev_min_distance = min_distance
        else:
            self.prev_min_distance = None

        # 5. Death Penalty: Big penalty if player disappears.
        if player_pos is None and self.prev_pos is not None:
             # Player was alive last step, but is gone now
             shaped_reward += R_DEATH

        # Update memory for the next step
        self.prev_pos = player_pos
        self.prev_num_enemies = num_enemies
        return shaped_reward

    def reset(self):
        self.prev_pos = None
        self.prev_num_enemies = 0
        self.prev_min_distance = None


In [28]:
class Trainer:
    def __init__(self, epsilon_min = 0.05, epsilon_decay_fraction = 0.999, initial_epsilon = 1.0):
        self.epsilon_min = epsilon_min
        self.epsilon_decay_fraction = epsilon_decay_fraction
        self.initial_epsilon = initial_epsilon


    @staticmethod
    def _file_name_for_class(class_name):
        return f"sarsa-weights-{class_name.lower()}.npz"

    def train_if_needed(self, model, env, class_name, n_episodes=1000):
        file_name = Trainer._file_name_for_class(class_name)
        print(f'Checking for existing model file: {file_name}')
        if not file_exist(file_name):
            self.train(model, env, class_name, n_episodes)
            return model

        return Sarsa.load(file_name)

    def train(self, model, env, class_name, n_episodes=1000):
        print(f"Training {class_name} agent...")
        state_observer = StateObserver(w=21, h=21)
        action_counts = np.zeros(env.action_space.n, dtype=np.float32)

        max_score_ever = -np.inf
        rewards = []
        w_changes = []
        previous_w = model.w.copy()

        model.epsilon = self.initial_epsilon
        decay_episodes = int(n_episodes * self.epsilon_decay_fraction)
        if decay_episodes > 0:
             epsilon_decay_step = (self.initial_epsilon - self.epsilon_min) / decay_episodes
        else:
             epsilon_decay_step = 0
        print(f"Epsilon will decay from {self.initial_epsilon} to {self.epsilon_min} over {decay_episodes} episodes.")

        log_step = max(1, n_episodes // 100)

        for episode in range(n_episodes):
            state, _ = env.reset()
            state_observer.reset()
            model.reset_traces()
            features = np.array(state, dtype=np.float32)
            action, q_values = model.epsilon_greedy(features)
            action_counts[action] += 1
            phi = model.phi_from_state_action(features, action)

            done = False
            ep_reward = 0

            while not done:
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Ensure next_state is properly flattened
                next_features = np.array(next_state, dtype=np.float32)
                next_action, next_q_values = model.epsilon_greedy(next_features)
                action_counts[next_action] += 1
                next_phi = model.phi_from_state_action(next_features, next_action)

                if len(q_values) != model.n_actions:
                    raise ValueError(f"Expected q_values of length {model.n_actions}, got {len(q_values)}")

                q = q_values[action]
                if done:
                    q_next = 0.0
                    continue
                else:
                    # next_action, next_q_values = model.epsilon_greedy(next_features)
                    q_next = next_q_values[next_action]

                    # Store these for the *next* loop iteration

                action_counts[next_action] += 1
                # next_phi = model.phi_from_state_action(next_features, next_action)


                # Use the original state (2D array) for state observer
                shaped_reward = reward + R_LIVING
                # if action not in MOVE_ACTIONS:
                #     shaped_reward += R_INACTION

                if action == 2: # UP
                    shaped_reward += next_features[0] * R_WALL
                elif action == 5: # DOWN
                    shaped_reward += next_features[1] * R_WALL # f_wall_down
                elif action == 4: # LEFT
                    shaped_reward += next_features[2] * R_WALL # f_wall_left
                elif action == 3: # RIGHT
                    shaped_reward += next_features[3] * R_WALL

                if reward > 0:
                     shaped_reward += R_KILL

                # 4. Hunting Reward
                f_enemy_dir_y, f_enemy_dir_x = features[4], features[5]
                if f_enemy_dir_y != 0 or f_enemy_dir_x != 0: # If an enemy exists
                    if (action == 2 and f_enemy_dir_y == -1.0) or \
                       (action == 5 and f_enemy_dir_y == 1.0) or \
                       (action == 4 and f_enemy_dir_x == -1.0) or \
                       (action == 3 and f_enemy_dir_x == 1.0):
                        shaped_reward += R_HUNTING

                if terminated:
                    shaped_reward += R_DEATH

                delta = shaped_reward + model.gamma * q_next - q
                model.z = (model.gamma * model.lmbda * model.z) + phi
                model.w += model.alpha * delta * model.z


                state = next_state
                action = next_action
                q_values = next_q_values
                phi = next_phi
                ep_reward += reward

            new_epsilon = model.epsilon - epsilon_decay_step
            model.epsilon = max(self.epsilon_min, new_epsilon)

            w_change = np.mean(np.abs(model.w - previous_w))
            w_changes.append(w_change)
            previous_w = model.w.copy()

            rewards.append(ep_reward)
            if ep_reward > max_score_ever:
                max_score_ever = ep_reward


            if (episode + 1) % log_step == 0:
                recent_max = float(np.max(rewards[-log_step:])) if len(rewards) > 0 else float(ep_reward)
                print(f"Episode {episode+1}/{n_episodes}: Max reward for period={recent_max:.2f}, Eps={model.epsilon:.4f}")

        px.line(x=np.arange(1, n_episodes + 1), y=w_changes, labels={'x': 'Episode', 'y': 'Mean |Δw|'},
                title='Mean Weight Change over Episodes').show()
        print(f'Action distribution during training: {action_counts}')
        print(f"Training completed. Max score ever: {max_score_ever:.2f}")
        model.save(self._file_name_for_class(class_name))


In [29]:
CLASS_NAME = "Berzerk-Default"

agent = Sarsa(env.action_space.n)

epsilon_min = 0.1

trainer = Trainer(epsilon_min, 0.8, initial_epsilon=0.5)
agent = trainer.train_if_needed(agent, env, class_name=CLASS_NAME, n_episodes=1000)

env.close()


Checking for existing model file: sarsa-weights-berzerk-default.npz
Training Berzerk-Default agent...
Epsilon will decay from 0.5 to 0.1 over 800 episodes.
Episode 10/1000: Max reward for period=500.00, Eps=0.4950
Episode 20/1000: Max reward for period=450.00, Eps=0.4900
Episode 30/1000: Max reward for period=300.00, Eps=0.4850
Episode 40/1000: Max reward for period=500.00, Eps=0.4800
Episode 50/1000: Max reward for period=400.00, Eps=0.4750
Episode 60/1000: Max reward for period=500.00, Eps=0.4700
Episode 70/1000: Max reward for period=400.00, Eps=0.4650
Episode 80/1000: Max reward for period=450.00, Eps=0.4600
Episode 90/1000: Max reward for period=400.00, Eps=0.4550
Episode 100/1000: Max reward for period=450.00, Eps=0.4500
Episode 110/1000: Max reward for period=450.00, Eps=0.4450
Episode 120/1000: Max reward for period=450.00, Eps=0.4400
Episode 130/1000: Max reward for period=350.00, Eps=0.4350
Episode 140/1000: Max reward for period=350.00, Eps=0.4300
Episode 150/1000: Max rewar

Action distribution during training: [ 10379.  10666.  10370.  10633.  10406.  10302.  10293.  10511.  10725.
  10488.  10421.  10465.  10396. 639596.  10515.  10735.  10476.  10617.]
Training completed. Max score ever: 750.00


# Test

In [30]:
agent = Sarsa.load(Trainer._file_name_for_class(CLASS_NAME))

Loaded SARSA agent with rules characteristics:
w shape: (24,)
w norm: 53.34482
non-zero weights: 24


In [31]:
ale = ALEInterface()
gym.register_envs(ale)

test_env = gym.make("ALE/Berzerk-v5", render_mode="human", frameskip=4)
agent.restrict_exploration()

In [32]:
def normalize_weights(w):
    avg = np.mean(w)
    irq = np.percentile(w, 75) - np.percentile(w, 25)
    lower_bound = avg - irq * 0.75
    upper_bound = avg + irq * 0.75
    w_clipped = np.clip(w, lower_bound, upper_bound)
    return w_clipped


In [None]:
n_episodes = 5
total_rewards = []

agent.w = normalize_weights(agent.w)

for ep in range(n_episodes):
    state, _ = test_env.reset()
    done = False
    ep_reward = 0

    actions_count = np.zeros(test_env.action_space.n, dtype=np.int32)
    while not done:
        features = extract_intelligent_features(state)
        action, _ = agent.epsilon_greedy(features)
        actions_count[action] += 1
        next_state, reward, terminated, truncated, _ = test_env.step(action)
        done = terminated or truncated

        state = next_state
        ep_reward += reward

    test_env.render()
    print(f"Episode {ep + 1}: Total Reward = {ep_reward}")
    print(f'Action count during round: {actions_count}')
    print('---------------------------')
    total_rewards.append(ep_reward)

test_env.close()

print(f"\nAverage Test Reward over {n_episodes} episodes: {np.mean(total_rewards):.2f}")