In [16]:
from IPython.core.pylabtools import figsize
from ale_py import ALEInterface
import gymnasium as gym
import time
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from IPython.display import display, clear_output
import cv2

In [17]:
class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, h=21, w=21):
        super().__init__(env)
        self.h, self.w = h, w
        self.observation_space = gym.spaces.Box(0, 1, (h, w, 3), np.float32)

    def observation(self, obs):
        obs = cv2.resize(obs, (self.w, self.h), interpolation=cv2.INTER_NEAREST)
        return obs.astype(np.float32) / 255.0


In [18]:
ale = ALEInterface()
gym.register_envs(ale)

env = gym.make("ALE/Berzerk-v5", render_mode="rgb_array", frameskip=4)
env = ResizeObservation(env, h=21, w=21)
observation, info = env.reset()


In [19]:
print("action_space:", env.action_space)
print("n actions:", env.action_space.n)

action_space: Discrete(18)
n actions: 18


In [20]:
try:
    meanings = env.unwrapped.get_action_meanings()
except Exception:
    try:
        meanings = env.get_action_meanings()
    except Exception:
        meanings = None

if meanings:
    print("Action index -> meaning:")
    for i, name in enumerate(meanings):
        print(f"{i}: {name}")
else:
    print("No action meanings available from the env. Use index numbers (0..n-1).")


Action index -> meaning:
0: NOOP
1: FIRE
2: UP
3: RIGHT
4: LEFT
5: DOWN
6: UPRIGHT
7: UPLEFT
8: DOWNRIGHT
9: DOWNLEFT
10: UPFIRE
11: RIGHTFIRE
12: LEFTFIRE
13: DOWNFIRE
14: UPRIGHTFIRE
15: UPLEFTFIRE
16: DOWNRIGHTFIRE
17: DOWNLEFTFIRE


In [21]:
def plot_frame(frame):
    # clear_output(wait=True)
    plt.figure(figsize=(6, 4))
    plt.imshow(frame)
    plt.axis('off')
    display(plt.gcf())
    # plt.show()

In [22]:
def prepare_image(frame):
    img = Image.fromarray(frame)
    # img = img.convert('L')  # Convert to grayscale
    img = np.array(img)
    img = cv2.resize(img, (21, 21))
    img = img / 255.0
    # img = img.flatten()
    return np.array(img)

In [23]:
seed = 42
np.random.seed(seed)

In [24]:
class Sarsa:
    alpha = 1e-4
    gamma = 0.99
    epsilon = 0.1
    feature_h, feature_w = 21, 21
    use_traces = False
    lmbda = 0.9

    def __init__(self, n_actions):
        feature_dim = self.feature_h * self.feature_w * 3 + n_actions
        self.w = np.zeros(feature_dim, dtype=np.float32)
        self.n_actions = n_actions

    def _extract_rgb_features(self, frame):
        f = cv2.resize(frame, (self.feature_h, self.feature_w), interpolation=cv2.INTER_AREA)
        f = f.astype(np.float32) / 255.0
        return f.flatten()

    def phi_from_state_action(self, features, action):
        # features = self._extract_rgb_features(features)
        a_onehot = np.zeros(self.n_actions, dtype=np.float32)
        a_onehot[action] = 1.0
        return np.concatenate([features, a_onehot])

    def q_value(self, phi):
        return np.dot(self.w, phi)

    def _q_values_all_actions(self, state_features):
        features_tiled = np.tile(state_features, (self.n_actions, 1))  # shape (n_actions, state_dim)
        eye_actions = np.eye(self.n_actions, dtype=np.float32)
        phis = np.concatenate([features_tiled, eye_actions], axis=1)   # shape (n_actions, total_dim)
        return np.dot(phis, self.w)

    def epsilon_greedy(self, features):
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.n_actions)

        q_vals = self._q_values_all_actions(features)
        return np.argmax(q_vals)

    def save(self):
        np.savez("sarsa_weights.npz", w=self.w)

    @staticmethod
    def load():
        data = np.load("sarsa_weights.npz")
        agent = Sarsa(n_actions=data['w'].shape[0] - 21*21*3)
        agent.w = data['w']
        return agent

    def restrict_exploration(self):
        self.epsilon = 0.0

In [25]:
agent = Sarsa(env.action_space.n)

for episode in range(100):
    state, _ = env.reset()
    features = np.array(state).flatten()
    action = agent.epsilon_greedy(features)
    phi = agent.phi_from_state_action(features, action)

    done = False
    ep_reward = 0

    q_next = None
    phi_next = None

    while not done:
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        next_features = np.array(next_state).flatten()
        next_action = agent.epsilon_greedy(next_features)
        next_phi = agent.phi_from_state_action(next_features, next_action)

        q = np.dot(agent.w, phi)
        q_next = np.dot(agent.w, next_phi)
        delta = reward + agent.gamma * q_next - q

        agent.w += agent.alpha * delta * phi

        state = next_state
        action = next_action
        phi = next_phi
        ep_reward += reward

    print(f"Episode {episode + 1}: Total Reward: {ep_reward}")

env.close()


Episode 1: Total Reward: 400.0
Episode 2: Total Reward: 400.0
Episode 3: Total Reward: 100.0
Episode 4: Total Reward: 150.0
Episode 5: Total Reward: 100.0
Episode 6: Total Reward: 550.0
Episode 7: Total Reward: 450.0
Episode 8: Total Reward: 50.0
Episode 9: Total Reward: 450.0
Episode 10: Total Reward: 470.0
Episode 11: Total Reward: 200.0
Episode 12: Total Reward: 0.0
Episode 13: Total Reward: 250.0
Episode 14: Total Reward: 100.0
Episode 15: Total Reward: 50.0
Episode 16: Total Reward: 150.0
Episode 17: Total Reward: 100.0
Episode 18: Total Reward: 640.0
Episode 19: Total Reward: 150.0
Episode 20: Total Reward: 100.0
Episode 21: Total Reward: 200.0
Episode 22: Total Reward: 250.0
Episode 23: Total Reward: 400.0
Episode 24: Total Reward: 100.0
Episode 25: Total Reward: 350.0
Episode 26: Total Reward: 200.0
Episode 27: Total Reward: 560.0
Episode 28: Total Reward: 500.0
Episode 29: Total Reward: 200.0
Episode 30: Total Reward: 640.0
Episode 31: Total Reward: 150.0
Episode 32: Total Rew

In [26]:
agent.save()

In [27]:
agent = Sarsa.load()

In [28]:
test_env = gym.make("ALE/Berzerk-v5", render_mode="human")
agent.restrict_exploration()

In [29]:
n_episodes = 10
total_rewards = []

for ep in range(n_episodes):
    state, _ = env.reset()
    done = False
    ep_reward = 0

    while not done:
        action = agent.epsilon_greedy(state)  # now deterministic
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        state = next_state
        ep_reward += reward

    print(f"Episode {ep + 1}: Total Reward = {ep_reward}")
    total_rewards.append(ep_reward)

env.close()

print(f"\nAverage Test Reward over {n_episodes} episodes: {np.mean(total_rewards):.2f}")

ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 2 dimension(s)