# Escaping the maze
In this notebook we will cover the basics of a reinforcement learning (RL) environment.

Specifically, we will cover the observation, action, and state space following the example of a maze.

In [2]:
import numpy as np
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

from qgym.environment import Environment
from qgym.rewarder import Rewarder

## Map of the maze

Our maze will have 4 different field types.

- `S`: start position
- `F`: a free field
- `W`: a wall
- `G`: the goal

In [3]:
maze_map_4x4 = ["FSFF", "SWFW", "FFFW", "WFFG"]

- `0`: UP
- `1`: RIGHT
- `2`: DOWN
- `3`: LEFT

In [7]:
class MazeRewarder(Rewarder):
    def compute_reward(self, old_state, action, new_state):
        row, col = new_state["position"]
        if new_state["maze_map"][row][col] == b"G":
            return 1
        else:
            return 0

In [8]:
class Maze(Environment):
    def __init__(self, maze_map):
        maze_map = np.asarray(maze_map, dtype="c")

        self.nrows = maze_map.shape[0]
        self.ncols = maze_map.shape[1]

        self.start_position_distribution = (maze_map == b"S").ravel().astype("float64")
        self.start_position_distribution /= self.start_position_distribution.sum()

        self.action_space = gym.spaces.Discrete(4)  # {0,1,2,3}
        self.observation_space = gym.spaces.Discrete(self.nrows * self.ncols)
        self._state = {"position": None, "maze_map": maze_map}
        self._rewarder = MazeRewarder()

    def rowcol_to_pos(self, row, col):
        return row * self.nrows + col

    def pos_to_rowcol(self, pos):
        return int(pos / self.nrows), pos % self.nrows

    def reset(self, *, seed=None, return_info=False):
        start_position = self.rng.choice(
            self.nrows * self.ncols, p=self.start_position_distribution
        )
        self._state["position"] = self.pos_to_rowcol(start_position)

        return super().reset(seed=seed, return_info=return_info)

    def _update_state(self, action):
        row, col = self._state["position"]

        # compute new position
        if action == 0:  # up
            row = max(row - 1, 0)
        elif action == 1:  # right
            col = min(col + 1, self.ncols - 1)
        elif action == 2:  # down
            row = min(row + 1, self.nrows - 1)
        elif action == 3:  # left
            col = max(col - 1, 0)
        else:
            raise ValueError("Invalid action supplied.")

        # go to new position if it is not a wall
        if self._state["maze_map"][row][col] != b"W":
            self._state["position"] = (row, col)
        # else we stay where we are

    def _obtain_observation(self):
        return self.rowcol_to_pos(*self._state["position"])

    def _is_done(self):
        row, col = self._state["position"]
        return self._state["maze_map"][row][col] == b"G"

    def _obtain_info(self):
        return {}

    def _compute_reward(self, old_state, action):
        return super()._compute_reward(
            old_state=old_state, action=action, new_state=self._state
        )

## Training an agent

In [9]:
env = Maze(maze_map_4x4)
check_env(env, warn=True)

model = PPO("MlpPolicy", env, verbose=1)

model.learn(int(1e5))

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 84.8     |
|    ep_rew_mean     | 1        |
| time/              |          |
|    fps             | 2294     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 58.1         |
|    ep_rew_mean          | 1            |
| time/                   |              |
|    fps                  | 1378         |
|    iterations           | 2            |
|    time_elapsed         | 2            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0132674305 |
|    clip_fraction        | 0.124        |
|    clip_range           | 0.2          |
|    en

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 5.06        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 994         |
|    iterations           | 11          |
|    time_elapsed         | 22          |
|    total_timesteps      | 22528       |
| train/                  |             |
|    approx_kl            | 0.003540226 |
|    clip_fraction        | 0.0196      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.14       |
|    explained_variance   | 0.923       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.00064     |
|    n_updates            | 100         |
|    policy_gradient_loss | -0.0133     |
|    value_loss           | 1.49e-05    |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 5.12

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 5           |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 980         |
|    iterations           | 21          |
|    time_elapsed         | 43          |
|    total_timesteps      | 43008       |
| train/                  |             |
|    approx_kl            | 0.004838213 |
|    clip_fraction        | 0.0333      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.0607     |
|    explained_variance   | 0.999       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.000707   |
|    n_updates            | 200         |
|    policy_gradient_loss | -0.00105    |
|    value_loss           | 1.57e-07    |
-----------------------------------------
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 5 

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 5.63        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 965         |
|    iterations           | 31          |
|    time_elapsed         | 65          |
|    total_timesteps      | 63488       |
| train/                  |             |
|    approx_kl            | 0.032854073 |
|    clip_fraction        | 0.0749      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.109      |
|    explained_variance   | 0.997       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.022      |
|    n_updates            | 300         |
|    policy_gradient_loss | -0.02       |
|    value_loss           | 3.72e-08    |
-----------------------------------------
---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 5.12      

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 6.66        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 974         |
|    iterations           | 41          |
|    time_elapsed         | 86          |
|    total_timesteps      | 83968       |
| train/                  |             |
|    approx_kl            | 0.030019362 |
|    clip_fraction        | 0.0463      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.0962     |
|    explained_variance   | 0.997       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0295     |
|    n_updates            | 400         |
|    policy_gradient_loss | -0.0142     |
|    value_loss           | 4.22e-07    |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 5.15    

<stable_baselines3.ppo.ppo.PPO at 0x1d436fd38e0>