In [1]:
from ale_py import ALEInterface
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import cv2
import numba
import plotly.express as px
import time

In [2]:
FIRE_ACTIONS = [1, 10, 11, 12, 13, 14, 15, 16, 17]
MOVE_ACTIONS = [2, 3, 4, 5, 6, 7, 8, 9]
np.random.seed(42)

In [3]:
UP_DIRECTION = 0
RIGHT_DIRECTION = 1
DOWN_DIRECTION = 2
LEFT_DIRECTION = 3
UPRIGHT_DIRECTION = 4
UPLEFT_DIRECTION = 5
DOWNRIGHT_DIRECTION = 6
DOWNLEFT_DIRECTION = 7

In [4]:
EMPTY_COLOR = np.array([0, 0, 0], dtype=np.uint8)
WALL_COLOR = np.array([84, 92, 214], dtype=np.uint8)
ENEMY_COLOR = np.array([210, 210, 64], dtype=np.uint8)
PLAYER_COLOR = np.array([240, 170, 103], dtype=np.uint8)

EMPTY_INDEX = 0
WALL_INDEX = 1
ENEMY_INDEX = 2
PLAYER_INDEX = 3

PORTAL_INDEX = 4 # add manualy
PORTAL_COLOR = np.array([74, 255, 56], dtype=np.uint8)

In [5]:
@numba.njit
def rgb_to_index(frame):
    state = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
    for i in range(frame.shape[0]):
        for j in range(frame.shape[1]):
            pixel = frame[i, j]
            if pixel[0] == WALL_COLOR[0]:
                state[i, j] = WALL_INDEX
            elif pixel[0] == ENEMY_COLOR[0]:
                state[i, j] = ENEMY_INDEX
            elif pixel[0] == PLAYER_COLOR[0]:
                state[i, j] = PLAYER_INDEX
            elif pixel[0] == PORTAL_COLOR[0]:
                state[i, j] = PORTAL_INDEX
            else:
                state[i, j] = EMPTY_INDEX
    return state

@numba.njit
def cut_empty_layers(state):
    skip_layers = 0
    while True:
        if state[skip_layers][skip_layers] == 0:
            skip_layers += 1
        else:
            break
    state = state[skip_layers:-skip_layers, skip_layers:-skip_layers]

    skip_layers_from_bottom = 0
    while True:
        if state[-(skip_layers_from_bottom + 1)][-(skip_layers_from_bottom + 1)] == 0:
            skip_layers_from_bottom += 1
        else:
            break
    state = state[:-skip_layers_from_bottom, :]
    return state

@numba.njit
def fill_part(matrix, a, b, c, d, fill_value):
    rows, cols = matrix.shape
    min_i = max(0, min(a[0], b[0], c[0], d[0]))
    max_i = min(rows - 1, max(a[0], b[0], c[0], d[0]))
    min_j = max(0, min(a[1], b[1], c[1], d[1]))
    max_j = min(cols - 1, max(a[1], b[1], c[1], d[1]))

    for ii in range(min_i, max_i + 1):
        for jj in range(min_j, max_j + 1):
            if matrix[ii, jj] == EMPTY_INDEX:
                matrix[ii, jj] = fill_value

@numba.njit
def fill_holes(state):
    PORTAL_MIN_LENGHT = 3
    PORTAL_WIDTH = 0
    WALL_WIDTH = 3

    if state[0][0] != WALL_INDEX:
        return

    P = 2 * (state.shape[0] + state.shape[1])

    i, j = 0, 0
    height, width = state.shape
    last_wall_pixel = (0, 0)
    iters = 0
    while True:
        if iters > P + 1:
            raise Warning("No walls found on border during hole filling")

        iters += 1

        if state[i][j] == WALL_INDEX:
            distance_to_last_wall = max(abs(i - last_wall_pixel[0]), abs(j - last_wall_pixel[1]))
            if distance_to_last_wall == 1:
                pass
            else:
                material_to_fill = WALL_INDEX if distance_to_last_wall < PORTAL_MIN_LENGHT else PORTAL_INDEX
                width_to_fill = WALL_WIDTH if material_to_fill == WALL_INDEX else PORTAL_WIDTH
                A,B,C,D = None, None, None, None
                if j == 0:
                    A = (last_wall_pixel[0] + 1, last_wall_pixel[1])
                    B = (last_wall_pixel[0] + 1, last_wall_pixel[1] + width_to_fill)
                    C = (i, j + width_to_fill)
                    D = (i, j)
                elif i == height - 1:
                    A = (last_wall_pixel[0], last_wall_pixel[1] + 1)
                    B = (last_wall_pixel[0] - width_to_fill, last_wall_pixel[1] + 1)
                    C = (i - width_to_fill, j)
                    D = (i, j)
                elif j == width - 1:
                    A = (last_wall_pixel[0] - 1, last_wall_pixel[1])
                    B = (last_wall_pixel[0] - 1, last_wall_pixel[1] - width_to_fill)
                    C = (i, j - width_to_fill)
                    D = (i, j)
                elif i == 0:
                    A = (last_wall_pixel[0], last_wall_pixel[1] - 1)
                    B = (last_wall_pixel[0] - width_to_fill, last_wall_pixel[1] - 1)
                    C = (i + width_to_fill, j)
                    D = (i, j)
                else:
                    raise ValueError("Unexpected wall pixel not on border")

                fill_part(state, A, B, C, D, material_to_fill)

            last_wall_pixel = (i, j)

        if j == 0 and i < height - 1:
            i += 1
        elif i == height - 1 and j < width - 1:
            j += 1
        elif j == width - 1 and i > 0:
            i -= 1
        elif i == 0 and j > 0:
            j -= 1

        if i == 0 and j == 0:
            break

@numba.njit
def find_player_box(state):
    player_positions = np.argwhere(state == PLAYER_INDEX)

    if player_positions.shape[0] == 0:
        return None

    i_pos = player_positions[:, 0]
    j_pos = player_positions[:, 1]
    box = (np.min(i_pos), np.min(j_pos), np.max(i_pos), np.max(j_pos))
    return box

@numba.njit
def basic_observation(state, box):
    if box is None:
        return None, None, None, None, None

    left_search_point = ((box[0] + box[2]) // 2, box[1])
    right_search_point = ((box[0] + box[2]) // 2, box[3])
    up_search_point = (box[0], (box[1] + box[3]) // 2)
    down_search_point = (box[2], (box[1] + box[3]) // 2)

    l_wall = (left_search_point[0], 0)
    r_wall = (right_search_point[0], state.shape[1] - 1)
    u_wall = (0, up_search_point[1])
    d_wall = (state.shape[0] - 1, down_search_point[1])

    reach_l, reach_r, reach_u, reach_d = False, False, False, False

    closest_enemy = None

    i = 0
    while True:
        if i > max(state.shape):
            raise ValueError("No walls found in any direction")

        if not reach_l:
            cord = (left_search_point[0], left_search_point[1] - i)
            material = state[cord]
            if material == WALL_INDEX:
                reach_l = True
                l_wall = cord
            elif material == ENEMY_INDEX:
                if closest_enemy is None:
                    closest_enemy = cord
        if not reach_r:
            cord = (right_search_point[0], right_search_point[1] + i)
            material = state[cord]
            if material == WALL_INDEX:
                reach_r = True
                r_wall = cord
            elif material == ENEMY_INDEX:
                if closest_enemy is None:
                    closest_enemy = cord
        if not reach_u:
            cord = (up_search_point[0] - i, up_search_point[1])
            material = state[cord]
            if material == WALL_INDEX:
                reach_u = True
                u_wall = cord
            elif material == ENEMY_INDEX:
                if closest_enemy is None:
                    closest_enemy = cord
        if not reach_d:
            cord = (down_search_point[0] + i, down_search_point[1])
            material = state[cord]
            if material == WALL_INDEX:
                reach_d = True
                d_wall = cord
            elif material == ENEMY_INDEX:
                if closest_enemy is None:
                    closest_enemy = cord
        if reach_l and reach_r and reach_u and reach_d:
            break
        i += 1

    return l_wall, r_wall, u_wall, d_wall, closest_enemy

@numba.njit
def has_enemy(state):
    return np.any(state == ENEMY_INDEX)


In [6]:
VECTOR_STATE_SIZE = 9
class State:
    def __init__(self, frame):
        state = rgb_to_index(frame)
        state = cut_empty_layers(state)
        self.is_empty = len(state) == 0 or len(state[0]) == 0
        if not self.is_empty:
            fill_holes(state)
        self.state = state
        self.state_h = state.shape[0]
        self.state_w = state.shape[1]
        self.player_box = find_player_box(state)
        obs = basic_observation(state, self.player_box) if not self.player_box is None else None

        self.left_wall = obs[0] if obs is not None and obs[0] is not None else (0,0)
        self.right_wall = obs[1] if obs is not None and obs[1] is not None else (0,0)
        self.up_wall = obs[2] if obs is not None and obs[2] is not None else (0,0)
        self.down_wall = obs[3] if obs is not None and obs[3] is not None else (0,0)
        self.closest_enemy = obs[4] if obs is not None and obs[4] is not None else None

    def has_enemy(self):
        return has_enemy(self.state)

    def as_vector(self):
        player_center = self.center_of_player()
        px_norm = player_center[0] / self.state_h
        py_norm = player_center[1] / self.state_w

        up_wall_dist = player_center[0] - self.up_wall[0]
        down_wall_dist = self.down_wall[0] - player_center[0]
        left_wall_dist = player_center[1] - self.left_wall[1]
        right_wall_dist = self.right_wall[1] - player_center[1]

        dist_up_norm = up_wall_dist / self.state_h
        dist_down_norm = down_wall_dist / self.state_h
        dist_left_norm = left_wall_dist / self.state_w
        dist_right_norm = right_wall_dist / self.state_w

        enemy_x_norm = self.closest_enemy[0] / self.state_h if self.closest_enemy is not None else 0
        enemy_y_norm = self.closest_enemy[1] / self.state_w if self.closest_enemy is not None else 0
        enemy_visible = 1.0 if self.closest_enemy is not None else 0.0

        return np.array([px_norm, py_norm,
                         dist_up_norm, dist_down_norm,
                         dist_left_norm, dist_right_norm,
                         enemy_x_norm, enemy_y_norm,
                         enemy_visible], dtype=np.float32)

    def center_of_player(self):
        if self.player_box is None:
            return 0, 0
        i_center = (self.player_box[0] + self.player_box[2]) // 2
        j_center = (self.player_box[1] + self.player_box[3]) // 2
        return i_center, j_center

    def direction_to_player(self, cord):
        i, j = cord
        left_up_corner = (self.player_box[0], self.player_box[1])
        left_down_corner = (self.player_box[2], self.player_box[1])
        right_up_corner = (self.player_box[0], self.player_box[3])
        right_down_corner = (self.player_box[2], self.player_box[3])

        if i < left_up_corner[0]:
            if j < left_up_corner[1]:
                return UPLEFT_DIRECTION
            elif j > right_up_corner[1]:
                return UPRIGHT_DIRECTION
            else:
                return UP_DIRECTION
        elif i > left_down_corner[0]:
            if j < left_down_corner[1]:
                return DOWNLEFT_DIRECTION
            elif j > right_down_corner[1]:
                return DOWNRIGHT_DIRECTION
            else:
                return DOWN_DIRECTION
        else:
            if j < left_up_corner[1]:
                return LEFT_DIRECTION
            elif j > right_up_corner[1]:
                return RIGHT_DIRECTION
            else:
                return -1


In [7]:
class Sarsa:
    alpha = 1e-5
    gamma = 0.99
    epsilon = 1
    lmbda = 0.9
    weight_decay = 1e-5
    z_clip = 10.0

    def __init__(self, n_actions):
        self.state_dim = VECTOR_STATE_SIZE
        feature_dim = self.state_dim + n_actions
        self.n_actions = n_actions
        # self.w = np.zeros(feature_dim, dtype=np.float32)
        # self.z = np.zeros_like(self.w, dtype=np.float32)
        self.w = np.zeros((self.n_actions, self.state_dim), dtype=np.float32)
        self.z = np.zeros_like(self.w, dtype=np.float32)

    def phi_from_state_action(self, features, action):
        phi = np.zeros_like(self.w, dtype=np.float32)
        phi[action] = features.astype(np.float32)
        return phi

    def _q_values_all_actions(self, state_features):
        # q_base = np.dot(state_features, self.w[:self.state_dim])
        # return q_base + self.w[self.state_dim:]
        return self.w.dot(state_features.astype(np.float32))

    def epsilon_greedy(self, features):
        eps = float(getattr(self, "epsilon", 0.0))
        eps = max(0.0, min(1.0, eps))

        q_vals = self._q_values_all_actions(features)

        if np.random.rand() < eps:
            action = np.random.randint(self.n_actions)
        else:
            action = np.argmax(q_vals)

        return action, q_vals

    def save(self, file_name="sarsa_weights.npz"):
        # np.savez(file_name, w=self.w)
        np.savez(file_name, w=self.w.reshape(-1), n_actions=self.n_actions, state_dim=self.state_dim)


    def restrict_exploration(self):
        self.epsilon = 0.0

    def reset_traces(self):
        self.z.fill(0.0)

    # @staticmethod
    # def load(file_name="sarsa_weights.npz"):
    #     data = np.load(file_name)
    #     n_actions = data['w'].shape[0] - VECTOR_STATE_SIZE
    #     agent = Sarsa(n_actions)
    #     agent.w = data['w']
    #     return agent
    @staticmethod
    def load(file_name="sarsa_weights.npz"):
        data = np.load(file_name)
        n_actions = int(data['n_actions'])
        state_dim = int(data['state_dim'])
        flat = data['w']
        agent = Sarsa(n_actions)
        agent.state_dim = state_dim
        agent.w = flat.reshape((n_actions, state_dim)).astype(np.float32)
        agent.z = np.zeros_like(agent.w, dtype=np.float32)
        return agent

In [8]:
def file_exist(file_name):
    try:
        _ = np.load(file_name)
        return True
    except FileNotFoundError:
        return False

In [9]:
class Examinator:
    NOT_SHOOT_ENEMY_PENALTY = -1.5
    LIVING_PENALTY = -0.1
    DEATH_PENALTY = -12
    STAY_IN_DANGER_PENALTY = -0.5
    FIRE_ENEMY_BONUS = 6.0
    FAR_FROM_WALL_BONUS = 0.03
    MOVE_BONUS = 0.01
    def __init__(self):
        pass

    def examine(self, state, action, model, reward):
        shaped_reward = reward

        shaped_reward += self.LIVING_PENALTY

        if state.closest_enemy is not None:
            enemy_dir = state.direction_to_player(state.closest_enemy)
            required_action = FIRE_ACTIONS[1] + enemy_dir
            if action != required_action and reward <= 0:
                shaped_reward += Examinator.NOT_SHOOT_ENEMY_PENALTY
            if action == 0:
                shaped_reward += Examinator.STAY_IN_DANGER_PENALTY
            if action == required_action:
                shaped_reward += Examinator.FIRE_ENEMY_BONUS


        alive = state.player_box is not None
        if not alive:
            shaped_reward += Examinator.DEATH_PENALTY

        min_distance_to_wall = min(
            state.center_of_player()[0] - state.up_wall[0],
            state.down_wall[0] - state.center_of_player()[0],
            state.center_of_player()[1] - state.left_wall[1],
            state.right_wall[1] - state.center_of_player()[1]
        )
        if min_distance_to_wall > 15:
            shaped_reward += Examinator.FAR_FROM_WALL_BONUS

        if action in MOVE_ACTIONS:
            shaped_reward += Examinator.MOVE_BONUS

        return shaped_reward


In [10]:
class Trainer:
    def __init__(self, epsilon_min = 0.05, epsilon_decay_fraction = 0.999, initial_epsilon = 1.0):
        self.epsilon_min = epsilon_min
        self.epsilon_decay_fraction = epsilon_decay_fraction
        self.initial_epsilon = initial_epsilon

    @staticmethod
    def _file_name_for_class(class_name):
        return f"sarsa-weights-{class_name.lower()}.npz"

    def train_if_needed(self, model, env, class_name, n_episodes=1000):
        file_name = Trainer._file_name_for_class(class_name)
        print(f'Checking for existing model file: {file_name}')
        if not file_exist(file_name):
            self.train(model, env, class_name, n_episodes)
            return model

        return Sarsa.load(file_name)

    def train(self, model, env, class_name, n_episodes=1000):
        print(f"Training {class_name} agent...")
        action_counts = np.zeros(env.action_space.n, dtype=np.float32)
        examinator = Examinator()

        rewards = []
        w_changes = []
        previous_w = model.w.copy()

        model.epsilon = self.initial_epsilon
        decay_episodes = int(n_episodes * self.epsilon_decay_fraction)
        if decay_episodes > 0:
             epsilon_decay_step = (self.initial_epsilon - self.epsilon_min) / decay_episodes
        else:
             epsilon_decay_step = 0
        print(f"Epsilon will decay from {self.initial_epsilon} to {self.epsilon_min} over {decay_episodes} episodes.")

        log_step = max(1, n_episodes // 100)

        for episode in range(n_episodes):
            _ = env.reset()

            for j in range(0, 6):
                _ = env.step(0)

            state = env.render()
            featured_state = State(state)
            state_vector = featured_state.as_vector()
            model.reset_traces()
            action, q_values = model.epsilon_greedy(state_vector)
            action_counts[action] += 1
            phi = model.phi_from_state_action(state_vector, action)

            done = False
            ep_reward = 0
            episode_start_time = time.time()

            while not done:
                next_state, reward, terminated, truncated, _ = env.step(action)
                next_featured_state = State(next_state)

                if next_featured_state.is_empty:
                    break

                done = terminated or truncated or next_featured_state.is_empty or next_featured_state.player_box is None
                next_features = next_featured_state.as_vector()

                next_action, next_q_values = model.epsilon_greedy(next_features)
                action_counts[next_action] += 1
                next_phi = model.phi_from_state_action(next_features, next_action)

                if len(q_values) != model.n_actions:
                    raise ValueError(f"Expected q_values of length {model.n_actions}, got {len(q_values)}")

                q = q_values[action]
                if done:
                    break
                else:
                    q_next = next_q_values[next_action]

                action_counts[next_action] += 1
                shaped_reward = examinator.examine(next_featured_state, action, model, reward)

                delta = shaped_reward + model.gamma * q_next - q
                model.z = (model.gamma * model.lmbda * model.z) + phi
                model.w += model.alpha * delta * model.z


                action = next_action
                q_values = next_q_values
                phi = next_phi
                ep_reward += reward

            new_epsilon = model.epsilon - epsilon_decay_step
            model.epsilon = max(self.epsilon_min, new_epsilon)

            w_change = np.mean(np.abs(model.w - previous_w))
            w_changes.append(w_change)
            previous_w = model.w.copy()

            rewards.append(ep_reward)


            if (episode + 1) % log_step == 0:
                recent_max = float(np.max(rewards[-log_step:])) if len(rewards) > 0 else float(ep_reward)
                print(f"Episode {episode+1}/{n_episodes}: Max reward for period={recent_max:.2f}, Eps={model.epsilon:.4f}")

        px.line(x=np.arange(1, n_episodes + 1), y=w_changes, labels={'x': 'Episode', 'y': 'Mean |Δw|'},
                title='Mean Weight Change over Episodes').show()
        px.line(x=np.arange(1, n_episodes + 1), y=rewards, labels={'x': 'Episode', 'y': 'Reward'},
                title='Episode Rewards over Time').show()
        print(f'Action distribution during training: {action_counts}')
        print(f"Training completed. Max score ever: {np.max(rewards)}")
        model.save(self._file_name_for_class(class_name))


In [11]:
CLASS_NAME = "Berzerk"

ale = ALEInterface()
gym.register_envs(ale)

env = gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4)
agent = Sarsa(env.action_space.n)
observation, info = env.reset()

In [12]:

epsilon_min = 0.1

trainer = Trainer(epsilon_min, 0.8, initial_epsilon=0.5)
agent = trainer.train_if_needed(agent, env, class_name=CLASS_NAME, n_episodes=10000)

env.close()

Checking for existing model file: sarsa-weights-berzerk.npz
Training Berzerk agent...
Epsilon will decay from 0.5 to 0.1 over 8000 episodes.
Episode 100/10000: Max reward for period=300.00, Eps=0.4950
Episode 200/10000: Max reward for period=250.00, Eps=0.4900
Episode 300/10000: Max reward for period=300.00, Eps=0.4850
Episode 400/10000: Max reward for period=250.00, Eps=0.4800
Episode 500/10000: Max reward for period=300.00, Eps=0.4750
Episode 600/10000: Max reward for period=300.00, Eps=0.4700
Episode 700/10000: Max reward for period=300.00, Eps=0.4650
Episode 800/10000: Max reward for period=300.00, Eps=0.4600
Episode 900/10000: Max reward for period=250.00, Eps=0.4550
Episode 1000/10000: Max reward for period=250.00, Eps=0.4500
Episode 1100/10000: Max reward for period=300.00, Eps=0.4450
Episode 1200/10000: Max reward for period=250.00, Eps=0.4400
Episode 1300/10000: Max reward for period=300.00, Eps=0.4350
Episode 1400/10000: Max reward for period=250.00, Eps=0.4300
Episode 1500/1

Action distribution during training: [  40541.   38435.   38708.   38422.   37970.   38571.   37796.   38066.
   37980.   38222.   37844.   38263. 2456270.   38127.   38376.   38562.
   38231.   38272.]
Training completed. Max score ever: 300.0


# Test

In [13]:
agent = Sarsa.load(Trainer._file_name_for_class(CLASS_NAME))

In [14]:
ale = ALEInterface()
gym.register_envs(ale)

test_env = gym.make("ALE/Berzerk-v5", render_mode="human", frameskip=4)
agent.restrict_exploration()

In [15]:
n_episodes = 5
total_rewards = []

# agent.w = normalize_weights(agent.w)
print(f'w={agent.w}')

for ep in range(n_episodes):
    state, _ = test_env.reset()
    done = False
    ep_reward = 0

    actions_count = np.zeros(test_env.action_space.n, dtype=np.int32)
    while not done:
        feature_state = State(state)
        if feature_state.is_empty:
            action = 0
        else:
            action, _ = agent.epsilon_greedy(feature_state.as_vector())
        actions_count[action] += 1
        next_state, reward, terminated, truncated, _ = test_env.step(action)
        done = terminated or truncated

        state = next_state
        ep_reward += reward

    test_env.render()
    print(f"Episode {ep + 1}: Total Reward = {ep_reward}")
    print(f'Action count during round: {actions_count}')
    print('---------------------------')
    total_rewards.append(ep_reward)

test_env.close()

print(f"\nAverage Test Reward over {n_episodes} episodes: {np.mean(total_rewards):.2f}")

w=[[15.014662    2.0062275  14.356636   15.845463    1.4370284   3.8519533
   6.054344    0.7237504  10.074114  ]
 [15.185865    2.0275002  14.524857   16.014742    1.4513685   3.8977253
   6.714184    0.79274744 11.138592  ]
 [15.366541    2.041494   14.662314   16.230688    1.461675    3.9565687
   6.4250693   0.76273334 10.622774  ]
 [16.593931    2.142232   15.861583   17.51792     1.5172107   4.3349066
   6.971366    0.8217192  11.496807  ]
 [14.108233    1.9147707  13.455069   14.823907    1.3803091   3.588958
   5.470274    0.68132323  9.193173  ]
 [15.465123    2.0528522  14.767695   16.325874    1.4690257   3.9925737
   6.4288044   0.76725084 10.62956   ]
 [16.455408    2.136054   15.710891   17.37002     1.5162139   4.289538
   6.741434    0.79333055 11.16471   ]
 [14.011303    1.8981627  13.389121   14.727691    1.367764    3.5557282
   5.2557907   0.656496    8.939997  ]
 [16.566257    2.1341515  15.800438   17.477926    1.5105457   4.3308907
   6.874131    0.7980438  11.23