## Prerequisities

In [1]:
!pip install gymnasium
!pip install pygame
!pip install stable-baselines3
!pip install torch
!pip install numpy

In [1]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from gymnasium.wrappers import RecordVideo
import pygame
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
import random

  from pkg_resources import resource_stream, resource_exists


## Custom Maze Env

In [2]:
class GridMazeEnv(gym.Env):
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 2}

    def __init__(self, grid_size=5, mines=None, cell_size=100, render_mode="human",rnd=True):

        super().__init__()
        self.grid_size = grid_size
        self.cell_size = cell_size
        self.width = self.grid_size * self.cell_size
        self.height = self.grid_size * self.cell_size
        self.render_mode = render_mode
        self.action_space = spaces.Discrete(4)  # 0: Up, 1: Down, 2: Left, 3: Right
        self.observation_space = spaces.Box(low=0, high=grid_size-1, shape=(2,), dtype=np.int32)
        self.rnd = rnd
        self.goal_pos = np.array([grid_size-1, grid_size-1])

        # Mines: list of (row, col) tuples
        if mines is None:
            self.mines = [(1, 1), (2, 3)]
        else:
            self.mines = mines

        # PyGame setup
        self.screen = None
        self.clock = None
        self.running = False

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Random player position that is not goal or mine
        if self.rnd:
            while True: 
                pos = np.array([random.randint(0, self.grid_size-1),
                                random.randint(0, self.grid_size-1)])
                if not np.array_equal(pos, self.goal_pos) and tuple(pos) not in self.mines:
                    self.agent_pos = pos
                    break
        else:
            self.agent_pos = np.array([0, 0])
        return self.agent_pos, {} #(observation, info)

    def step(self, action):
        # Move the agent
        if action == 0:  # Up
            self.agent_pos[0] = max(self.agent_pos[0] - 1, 0)
        elif action == 1:  # Down
            self.agent_pos[0] = min(self.agent_pos[0] + 1, self.grid_size - 1)
        elif action == 2:  # Left
            self.agent_pos[1] = max(self.agent_pos[1] - 1, 0)
        elif action == 3:  # Right
            self.agent_pos[1] = min(self.agent_pos[1] + 1, self.grid_size - 1)

        # Check for goal
        if np.array_equal(self.agent_pos, self.goal_pos):
            reward = 5
            done = True
        # Check for mine
        elif tuple(self.agent_pos) in self.mines:
            reward = -50
            done = True
        else:
            reward = -5
            done = False

        return self.agent_pos, reward, done, False, {} #(observation, reward, terminated, truncated [due to time limit], info)


    def render(self,action=None):
        if self.screen is None:
            pygame.init()
            if self.render_mode == "human":
                self.screen = pygame.display.set_mode((self.width, self.height))
                pygame.display.set_caption("Grid Maze")
            else:  # rgb_array mode
                self.screen = pygame.Surface((self.width, self.height))
            self.clock = pygame.time.Clock()
            self.running = True

        # Fill background
        self.screen.fill((255, 255, 255))  # white background

        # Draw grid lines
        for x in range(0, self.width, self.cell_size):
            pygame.draw.line(self.screen, (0, 0, 0), (x, 0), (x, self.height))
        for y in range(0, self.height, self.cell_size):
            pygame.draw.line(self.screen, (0, 0, 0), (0, y), (self.width, y))

        # Draw goal
        goal_rect = pygame.Rect(self.goal_pos[1]*self.cell_size, self.goal_pos[0]*self.cell_size,
                                self.cell_size, self.cell_size)
        pygame.draw.rect(self.screen, (0, 255, 0), goal_rect)  # green

        # Draw mines
        for mine in self.mines:
            mine_rect = pygame.Rect(mine[1]*self.cell_size, mine[0]*self.cell_size,
                                    self.cell_size, self.cell_size)
            pygame.draw.rect(self.screen, (255, 0, 0), mine_rect)  # red

        # Draw agent as circle
        center_x = self.agent_pos[1]*self.cell_size + self.cell_size//2
        center_y = self.agent_pos[0]*self.cell_size + self.cell_size//2
        radius = self.cell_size // 3
        pygame.draw.circle(self.screen, (0, 0, 255), (center_x, center_y), radius)  # blue
      
        # Handle render modes
        if self.render_mode == "human":
            pygame.display.get_surface().blit(self.screen, (0, 0))
            pygame.display.flip()
            self.clock.tick(self.metadata["render_fps"])
            # pygame.time.delay(100)
        elif self.render_mode == "rgb_array":
            # return RGB array for video recording
            return np.transpose(pygame.surfarray.array3d(self.screen), (1, 0, 2))


    def close(self):
        if self.screen is not None:
            pygame.quit()
            self.screen = None
            self.clock = None
            self.running = False


## Train a DQN model

In [3]:
env = GridMazeEnv(grid_size=5, mines=[(1,1), (2,3)])
vec_env = make_vec_env(lambda: env, n_envs=1)

model = DQN(
    "MlpPolicy",
    vec_env,
    learning_rate=0.001,
    buffer_size=200_000,
    learning_starts=1000,
    batch_size=64,
    gamma=0.99,
    target_update_interval=500,
    verbose=1
)

model.learn(total_timesteps=20000)
model.save("gridmaze_dqn")


Using cpu device
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.25     |
|    ep_rew_mean      | -33.8    |
|    exploration_rate | 0.994    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 6032     |
|    time_elapsed     | 0        |
|    total_timesteps  | 13       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 5.12     |
|    ep_rew_mean      | -56.9    |
|    exploration_rate | 0.981    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 7865     |
|    time_elapsed     | 0        |
|    total_timesteps  | 41       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 8.17     |
|    ep_rew_mean      | -72.1    |
|    exploration_rate | 0.953    |
| time/               |          |
|  

## DQN in action

In [4]:
model = DQN.load("gridmaze_dqn")
env = GridMazeEnv(grid_size=5, mines=[ (1,3),(3,2)],rnd=False,  render_mode="rgb_array")
video_env = RecordVideo(env, video_folder="videos", episode_trigger=lambda e: True, fps=2)

obs, _ = video_env.reset()
done = False
while not done:
    action, _ = model.predict(obs)
    print(action)
    obs, reward, done, _, _ = video_env.step(action)
    
    # Render the frame with the action overlay
    video_env.render()  # your env already uses render_mode="rgb_array"
    env.render(action)  # pass the action for text overlay
video_env.close()


  logger.warn(


3
1
3
1
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
2
3
0
