In [3]:
import gymnasium as gym
from gymnasium import spaces
from gymnasium.wrappers import RecordVideo
import pygame
import numpy as np
import time
from tqdm.notebook import tqdm
from collections.abc import Mapping

# Helper function to flatten the complex observation space for state mapping
def flatten_obs(obs):
    """Flattens the MultiDiscrete observation into a single tuple."""
    return tuple(obs.flatten())

print(f"Gymnasium version: {gym.__version__}")
print(f"Pygame version: {pygame.version.ver}")
print(f"Numpy version: {np.__version__}")

Gymnasium version: 1.2.1
Pygame version: 2.6.1
Numpy version: 2.0.2


In [4]:
class GridMazeEnv(gym.Env):
    """
    Custom Stochastic Grid Maze Environment based on Gymnasium.

    State: [agent_x, agent_y, goal_x, goal_y, bad1_x, bad1_y, bad2_x, bad2_y]
    Actions: 0: Right, 1: Up, 2: Left, 3: Down
    """
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, size=5, render_mode=None):
        super().__init__()
        self.size = size
        self.window_size = 512  # Pygame window size
        self.window = None
        self.clock = None

        # Observation Space: [ax, ay, gx, gy, b1x, b1y, b2x, b2y]
        # Each coordinate is from 0 to size-1
        self.observation_space = spaces.MultiDiscrete(np.array([size] * 8))

        # Action Space: 4 discrete actions
        self.action_space = spaces.Discrete(4)

        # Action to (x, y) change mapping
        # 0: Right, 1: Up, 2: Left, 3: Down
        self._action_to_direction = {
            0: np.array([1, 0]),  # Right
            1: np.array([0, -1]), # Up
            2: np.array([-1, 0]), # Left
            3: np.array([0, 1]),  # Down
        }

        # Perpendicular directions for stochasticity
        self._perpendicular_dirs = {
            0: [np.array([0, -1]), np.array([0, 1])],  # Right -> Up/Down
            1: [np.array([-1, 0]), np.array([1, 0])],  # Up -> Left/Right
            2: [np.array([0, -1]), np.array([0, 1])],  # Left -> Up/Down
            3: [np.array([-1, 0]), np.array([1, 0])],  # Down -> Left/Right
        }

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        if self.render_mode == "human":
            self._pygame_init()

    def _pygame_init(self):
        """Initialize PyGame for rendering."""
        pygame.init()
        pygame.display.init()
        self.window = pygame.display.set_mode((self.window_size, self.window_size))
        self.clock = pygame.time.Clock()
        self._font = pygame.font.Font(None, 36)

    def _get_obs(self):
        """Get the current observation from the state."""
        return np.array([
            self._agent_location[0], self._agent_location[1],
            self._goal_location[0], self._goal_location[1],
            self._bad1_location[0], self._bad1_location[1],
            self._bad2_location[0], self._bad2_location[1]
        ])

    def _get_info(self):
        """Get auxiliary info."""
        return {}

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        # Choose 4 unique locations for agent, goal, and bad cells
        locations = self.np_random.choice(
            self.size * self.size, 4, replace=False
        )
        coords = [np.array([loc % self.size, loc // self.size]) for loc in locations]

        self._agent_location = coords[0]
        self._goal_location = coords[1]
        self._bad1_location = coords[2]
        self._bad2_location = coords[3]

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def _design_reward(self, terminated):
        """Implements the reward function as required."""
        if terminated:
            if np.array_equal(self._agent_location, self._goal_location):
                return 1.0  # High positive reward for reaching the goal
            else:
                return -1.0 # High negative reward for hitting a bad cell
        else:
            return -0.01 # Small negative penalty for each step

    def step(self, action):
        # Stochastic movement logic
        p = self.np_random.random()

        if p < 0.70:
            # 70% chance of intended direction
            direction = self._action_to_direction[action]
        elif p < 0.85:
            # 15% chance of perpendicular 1
            direction = self._perpendicular_dirs[action][0]
        else:
            # 15% chance of perpendicular 2
            direction = self._perpendicular_dirs[action][1]

        # Apply movement and clip to stay within grid boundaries (0 to size-1)
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1
        )

        # Check for termination
        terminated = (
            np.array_equal(self._agent_location, self._goal_location) or
            np.array_equal(self._agent_location, self._bad1_location) or
            np.array_equal(self._agent_location, self._bad2_location)
        )

        # Get reward
        reward = self._design_reward(terminated)

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, terminated, False, info # False for truncated

    def render(self):
        if self.render_mode == "rgb_array":
            return self._render_frame()
        elif self.render_mode == "human":
            self._render_frame()
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])

    def _render_frame(self):
        if self.window is None and self.render_mode == "human":
            self._pygame_init()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        pix_size = self.window_size / self.size

        # Draw Goal (Green)
        pygame.draw.rect(
            canvas,
            (0, 255, 0),
            pygame.Rect(
                pix_size * self._goal_location[0],
                pix_size * self._goal_location[1],
                pix_size,
                pix_size,
            ),
        )

        # Draw Bad Cells (Red)
        for bad_loc in [self._bad1_location, self._bad2_location]:
            pygame.draw.rect(
                canvas,
                (255, 0, 0),
                pygame.Rect(
                    pix_size * bad_loc[0],
                    pix_size * bad_loc[1],
                    pix_size,
                    pix_size,
                ),
            )

        # Draw Agent (Blue)
        pygame.draw.circle(
            canvas,
            (0, 0, 255),
            (self._agent_location + 0.5) * pix_size,
            pix_size / 3,
        )

        # Draw gridlines
        for x in range(self.size + 1):
            pygame.draw.line(
                canvas,
                (0, 0, 0),
                (0, pix_size * x),
                (self.window_size, pix_size * x),
                width=3,
            )
            pygame.draw.line(
                canvas,
                (0, 0, 0),
                (pix_size * x, 0),
                (pix_size * x, self.window_size),
                width=3,
            )

        if self.render_mode == "human":
            self.window.blit(canvas, canvas.get_rect())
        else:  # rgb_array
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    def close(self):
        if self.window:
            pygame.display.quit()
            pygame.quit()
            self.window = None

In [5]:
class StateMapper:
    """
    Handles mapping the 8-tuple state to a unique integer index and back.
    Implements the optimization for Q2 by treating bad cells as an unordered set.
    """
    def __init__(self, size=5):
        self.size = size
        self.n_pos = size * size  # 25

        # Calculate size of an unordered pair of bad cells (combinations w/ replacement)
        # N*(N+1) / 2 = 25*(26)/2 = 325
        self.n_bad_pairs = (self.n_pos * (self.n_pos + 1)) // 2

        self.n_states = self.n_pos * self.n_pos * self.n_bad_pairs

        # Pre-compute mapping for unordered pairs for speed
        self._bad_pair_to_idx = {}
        self._idx_to_bad_pair = [None] * self.n_bad_pairs
        idx = 0
        for b1 in range(self.n_pos):
            for b2 in range(b1, self.n_pos): # b2 >= b1
                self._bad_pair_to_idx[(b1, b2)] = idx
                self._idx_to_bad_pair[idx] = (b1, b2)
                idx += 1

    def _map_bad_pair(self, b1_pos, b2_pos):
        """Maps an unordered (b1, b2) pair to a unique index 0-324."""
        if b1_pos > b2_pos:
            b1_pos, b2_pos = b2_pos, b1_pos
        return self._bad_pair_to_idx[(b1_pos, b2_pos)]

    def _unmap_bad_pair(self, idx):
        """Maps an index 0-324 back to a (b1, b2) pair."""
        return self._idx_to_bad_pair[idx]

    def state_to_index(self, obs):
        """Converts an 8-tuple observation [ax,ay,gx,gy,b1x,b1y,b2x,b2y] to an index."""
        ax, ay, gx, gy, b1x, b1y, b2x, b2y = obs

        a_pos = ax * self.size + ay
        g_pos = gx * self.size + gy
        b1_pos = b1x * self.size + b1y
        b2_pos = b2x * self.size + b2y

        b_pair_idx = self._map_bad_pair(b1_pos, b2_pos)

        # Final index calculation
        index = (a_pos * self.n_pos * self.n_bad_pairs) + \
                (g_pos * self.n_bad_pairs) + \
                b_pair_idx
        return index

    def index_to_state(self, index):
        """Converts a unique index back to an 8-tuple state."""
        b_pair_idx = index % self.n_bad_pairs
        index //= self.n_bad_pairs

        g_pos = index % self.n_pos
        index //= self.n_pos

        a_pos = index

        b1_pos, b2_pos = self._unmap_bad_pair(b_pair_idx)

        return np.array([
            a_pos // self.size, a_pos % self.size,  # ax, ay
            g_pos // self.size, g_pos % self.size,  # gx, gy
            b1_pos // self.size, b1_pos % self.size, # b1x, b1y
            b2_pos // self.size, b2_pos % self.size  # b2x, b2y
        ])

# --- Test the mapper ---
mapper = StateMapper(size=5)
print(f"Grid size: {mapper.size}x{mapper.size}")
print(f"Total positions (N): {mapper.n_pos}")
print(f"Unordered bad pairs: {mapper.n_bad_pairs}")
print(f"Optimized State Space Size |S|: {mapper.n_states}")

# Test case: Bad1=(1,2), Bad2=(3,4)
obs1 = np.array([0,0, 4,4, 1,2, 3,4])
# Test case: Bad1=(3,4), Bad2=(1,2)
obs2 = np.array([0,0, 4,4, 3,4, 1,2])

idx1 = mapper.state_to_index(obs1)
idx2 = mapper.state_to_index(obs2)

print(f"\nIndex for obs1: {idx1}")
print(f"Index for obs2: {idx2}")
assert idx1 == idx2
print("Test PASSED: Symmetrical states map to the same index.")

restored_obs = mapper.index_to_state(idx1)
print(f"Restored obs from index: {restored_obs}")
# Note: The restored obs may have b1/b2 swapped, which is correct.

Grid size: 5x5
Total positions (N): 25
Unordered bad pairs: 325
Optimized State Space Size |S|: 203125

Index for obs1: 7966
Index for obs2: 7966
Test PASSED: Symmetrical states map to the same index.
Restored obs from index: [0 0 4 4 1 2 3 4]


In [8]:
def build_transition_model(mapper):
    """
    Builds the P(s'|s,a) and R(s,a,s') model for the entire MDP.

    Returns:
    model[s_idx][a] = [ (prob, next_s_idx, reward), ... ]
    """
    n_states = mapper.n_states
    n_actions = 4
    size = mapper.size

    # Action map
    _action_to_direction = {
        0: np.array([1, 0]), 1: np.array([0, -1]),
        2: np.array([-1, 0]), 3: np.array([0, 1])
    }
    _perpendicular_dirs = {
        0: [np.array([0, -1]), np.array([0, 1])],
        1: [np.array([-1, 0]), np.array([1, 0])],
        2: [np.array([0, -1]), np.array([0, 1])],
        3: [np.array([-1, 0]), np.array([1, 0])]
    }

    # model[s_idx][action] = list of (prob, next_s_idx, reward)
    model = [[[] for _ in range(n_actions)] for _ in range(n_states)]

    print(f"Building transition model for {n_states} states...")

    for s_idx in tqdm(range(n_states)):
        state = mapper.index_to_state(s_idx)
        agent_pos = state[0:2]
        goal_pos = state[2:4]
        bad1_pos = state[4:6]
        bad2_pos = state[6:8]

        # Check if s_idx is a terminal state
        is_terminal = (
            np.array_equal(agent_pos, goal_pos) or
            np.array_equal(agent_pos, bad1_pos) or
            np.array_equal(agent_pos, bad2_pos)
        )

        if is_terminal:
            # In a terminal state, all actions lead to self-loop w/ 0 reward
            for a in range(n_actions):
                model[s_idx][a].append((1.0, s_idx, 0.0))
            continue

        # If not terminal, calculate transitions
        for a in range(n_actions):
            outcomes = [
                (0.70, _action_to_direction[a]),
                (0.15, _perpendicular_dirs[a][0]),
                (0.15, _perpendicular_dirs[a][1])
            ]

            for prob, direction in outcomes:
                # Get next agent position
                next_agent_pos = np.clip(agent_pos + direction, 0, size - 1)

                # Check for termination
                terminated = (
                    np.array_equal(next_agent_pos, goal_pos) or
                    np.array_equal(next_agent_pos, bad1_pos) or
                    np.array_equal(next_agent_pos, bad2_pos)
                )

                # Calculate reward
                if terminated:
                    reward = 1.0 if np.array_equal(next_agent_pos, goal_pos) else -1.0
                else:
                    reward = -0.01 # Step penalty

                # Construct next state
                next_state_obs = np.array([
                    next_agent_pos[0], next_agent_pos[1],
                    goal_pos[0], goal_pos[1],
                    bad1_pos[0], bad1_pos[1],
                    bad2_pos[0], bad2_pos[1]
                ])
                next_s_idx = mapper.state_to_index(next_state_obs)

                # Store (prob, next_state, reward)
                model[s_idx][a].append((prob, next_s_idx, reward))

    print("Model build complete.")
    return model


def policy_iteration(model, mapper, gamma=0.99, theta=1e-6):
    """
    Performs Policy Iteration.

    Args:
    - model: The transition model P(s'|s,a)
    - mapper: The StateMapper object
    - gamma: Discount factor
    - theta: Convergence threshold for Policy Evaluation
    """
    n_states = mapper.n_states
    n_actions = 4

    # 1. Initialize V(s) and pi(s)
    V = np.zeros(n_states)
    policy = np.zeros(n_states, dtype=int) # Default to action 0 (Right)

    policy_iteration_loops = 0
    total_evaluation_sweeps = 0

    while True:
        policy_iteration_loops += 1
        print(f"\n--- Policy Iteration Loop #{policy_iteration_loops} ---")

        # 2. Policy Evaluation
        # 2. Policy Evaluation
        print("Running Policy Evaluation...")
        eval_sweeps = 0
        while True:
            delta = 0
            eval_sweeps += 1

            # --- ADD THIS LINE ---
            if eval_sweeps % 10 == 0:
                print(f"    ... PE sweep #{eval_sweeps}, current delta: {delta:.2e}")
            # --- END ADD ---

            for s_idx in range(n_states):
                v_old = V[s_idx]
                v_new = 0

                # Get the action from the current policy
                a = policy[s_idx]

                # Calculate expected value
                for (prob, next_s_idx, reward) in model[s_idx][a]:
                    v_new += prob * (reward + gamma * V[next_s_idx])

                V[s_idx] = v_new
                delta = max(delta, abs(v_old - v_new))

            if delta < theta:
                print(f"Policy Evaluation converged in {eval_sweeps} sweeps.")
                total_evaluation_sweeps += eval_sweeps
                break

        # 3. Policy Improvement
        print("Running Policy Improvement...")
        policy_stable = True
        for s_idx in range(n_states):
            old_action = policy[s_idx]

            # Find the best action
            action_values = np.zeros(n_actions)
            for a in range(n_actions):
                for (prob, next_s_idx, reward) in model[s_idx][a]:
                    action_values[a] += prob * (reward + gamma * V[next_s_idx])

            policy[s_idx] = np.argmax(action_values)

            if old_action != policy[s_idx]:
                policy_stable = False

        if policy_stable:
            print(f"\nPolicy converged and is stable after {policy_iteration_loops} iterations.")
            break

    return policy, V, policy_iteration_loops, total_evaluation_sweeps

In [9]:
# --- 1. Setup ---
start_time = time.time()
mapper = StateMapper(size=5)

# --- 2. Build Model ---
model = build_transition_model(mapper)
model_build_time = time.time() - start_time
print(f"Model build took {model_build_time:.2f} seconds.")

# --- 3. Run Policy Iteration ---
train_start_time = time.time()
policy_table, value_table, pi_loops, pe_loops = policy_iteration(
    model, mapper, gamma=0.99
)
train_time = time.time() - train_start_time
print(f"Policy Iteration took {train_time:.2f} seconds.")

total_time = time.time() - start_time
print(f"\n--- Total process complete in {total_time:.2f} seconds ---")

Building transition model for 203125 states...


  0%|          | 0/203125 [00:00<?, ?it/s]

Model build complete.
Model build took 60.21 seconds.

--- Policy Iteration Loop #1 ---
Running Policy Evaluation...
    ... PE sweep #10, current delta: 0.00e+00
    ... PE sweep #20, current delta: 0.00e+00
    ... PE sweep #30, current delta: 0.00e+00
    ... PE sweep #40, current delta: 0.00e+00
    ... PE sweep #50, current delta: 0.00e+00
    ... PE sweep #60, current delta: 0.00e+00
    ... PE sweep #70, current delta: 0.00e+00
    ... PE sweep #80, current delta: 0.00e+00
    ... PE sweep #90, current delta: 0.00e+00
    ... PE sweep #100, current delta: 0.00e+00
    ... PE sweep #110, current delta: 0.00e+00
    ... PE sweep #120, current delta: 0.00e+00
    ... PE sweep #130, current delta: 0.00e+00
    ... PE sweep #140, current delta: 0.00e+00
    ... PE sweep #150, current delta: 0.00e+00
    ... PE sweep #160, current delta: 0.00e+00
    ... PE sweep #170, current delta: 0.00e+00
    ... PE sweep #180, current delta: 0.00e+00
    ... PE sweep #190, current delta: 0.00e+00

In [12]:
print("Testing the trained policy and recording video...")

# Create a new env, wrapped for video recording
env = GridMazeEnv(size=5, render_mode="rgb_array")
video_env = RecordVideo(
    env,
    video_folder="videos",
    name_prefix="policy-iteration-agent",
    episode_trigger=lambda e: True  # Record every episode
)

# Run one full episode
(obs, info) = video_env.reset()
terminated = False
truncated = False
total_reward = 0
steps = 0

while not (terminated or truncated):
    # 1. Get the state index from the observation
    s_idx = mapper.state_to_index(obs)

    # 2. Look up the optimal action from our trained policy
    action = policy_table[s_idx]

    # 3. Take the action
    obs, reward, terminated, truncated, info = video_env.step(action)

    total_reward += reward
    steps += 1

video_env.close()
env.close()

print(f"\nEpisode finished in {steps} steps.")
print(f"Total reward: {total_reward:.2f}")
print("Video saved to 'videos' folder.")

Testing the trained policy and recording video...

Episode finished in 11 steps.
Total reward: 0.90
Video saved to 'videos' folder.
