## Initialising the new environment

In [12]:
import gymnasium as gym
from gymnasium import spaces
import pygame
import numpy as np


class GridWorldEnv(gym.Env):
    """A square grid world in which an agent walks towards a target cell.

    Observations are dictionaries with the integer (x, y) cell coordinates of
    the agent and of the target. Four discrete actions move the agent by one
    cell along an axis; moves that would leave the grid are clipped to the
    border. An episode terminates when the agent reaches the target, which
    yields a reward of 1 (every other step yields 0).

    Supported render modes are ``"human"`` (a pygame window updated on every
    ``reset``/``step``) and ``"rgb_array"`` (``render`` returns an image).
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, size=5):
        """Set up the spaces, the action table and the rendering state.

        Args:
            render_mode: ``None``, ``"human"`` or ``"rgb_array"``.
            size: Side length of the square grid, in cells. Must be at
                least 2 so that agent and target can occupy distinct cells.

        Raises:
            ValueError: If ``size`` is smaller than 2.
            AssertionError: If ``render_mode`` is not a supported mode.
        """
        # With a single cell, reset() could never place the target on a cell
        # different from the agent's and would loop forever.
        if size < 2:
            raise ValueError(f"size must be at least 2, got {size}")

        self.size = size  # size of the square grid
        self.window_size = 512  # The size of the pygame window

        # Observations are dictionaries with the agent's and the target's location
        self.observation_space = spaces.Dict({
            "agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
            "target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
        })

        # There are 4 actions corresponding to right, up, left and down
        self.action_space = spaces.Discrete(4)

        # Maps each action index to the (dx, dy) step applied to the agent.
        # NOTE: pygame's origin is the top-left corner, so +y is drawn downwards.
        self.action_to_direction = {
            0: np.array([1, 0]),
            1: np.array([0, 1]),
            2: np.array([-1, 0]),
            3: np.array([0, -1]),
        }

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        # For human render mode: window is the window we draw to and clock is
        # the clock used to ensure the environment runs at the proper frame
        # rate. Both are created lazily on the first rendered frame.
        self.window = None
        self.clock = None

    def _get_obs(self):
        """Return the current observation dictionary."""
        return {"agent": self._agent_location, "target": self._target_location}

    def _get_info(self):
        """Return auxiliary info: the Manhattan distance from agent to target."""
        return {"distance": np.linalg.norm(self._agent_location - self._target_location, ord=1)}

    def reset(self, seed=None, options=None):
        """Start a new episode with random, distinct agent and target cells.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``.
            options: Unused.

        Returns:
            A tuple ``(observation, info)``.
        """
        # seeding self.np_random
        super().reset(seed=seed)

        # Choosing the agent's location
        self._agent_location = self.np_random.integers(
            0, self.size, size=2, dtype=int)

        # Resample the target location until it doesn't coincide with the
        # agent's location
        self._target_location = self._agent_location
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.np_random.integers(
                0, self.size, size=2, dtype=int)

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def step(self, action):
        """Move the agent one cell in the direction selected by ``action``.

        Args:
            action: Key of ``self.action_to_direction`` (0 to 3).

        Returns:
            A tuple ``(observation, reward, terminated, truncated, info)``;
            ``truncated`` is always ``False``.
        """
        # Mapping the action to the direction of motion
        direction = self.action_to_direction[action]

        # We np.clip to make sure not to leave the grid
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1)

        # An episode is done if the agent has reached the target
        terminated = np.array_equal(
            self._agent_location, self._target_location)
        reward = 1 if terminated else 0  # Binary sparse rewards
        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, terminated, False, info

    def render(self):
        """Return the current frame as an array in ``"rgb_array"`` mode.

        In any other mode nothing is drawn here and ``None`` is returned;
        ``"human"`` frames are drawn from ``reset`` and ``step`` instead.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """Draw the grid, target and agent; show the frame or return it.

        Returns:
            ``None`` in ``"human"`` mode; otherwise the canvas pixels as an
            array transposed to (row, column, channel) order.
        """
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(
                (self.window_size, self.window_size))

        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        # pygame.Surface / pygame.Rect are the classes; the lowercase names
        # are modules and cannot be called.
        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        # size of a single grid square in pixels
        pix_square_size = self.window_size / self.size

        # Drawing the target
        pygame.draw.rect(
            canvas,
            (255, 0, 0),
            pygame.Rect(
                pix_square_size * self._target_location,
                (pix_square_size, pix_square_size),
            ),
        )

        # Drawing the agent
        pygame.draw.circle(
            canvas,
            (0, 0, 255),
            (self._agent_location + 0.5) * pix_square_size,
            pix_square_size / 3,
        )

        # Adding some gridlines
        for x in range(self.size + 1):
            pygame.draw.line(
                canvas,
                0,
                (0, pix_square_size * x),
                (self.window_size, pix_square_size * x),
                width=3,
            )
            pygame.draw.line(
                canvas,
                0,
                (pix_square_size * x, 0),
                (pix_square_size * x, self.window_size),
                width=3,
            )

        if self.render_mode == "human":
            # Copy the canvas to the visible window and refresh the display
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # Ensuring the frame rate
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            # pixels3d is indexed [x, y, channel]; swap the first two axes so
            # the result is indexed [row, column, channel].
            return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))

    def close(self):
        """Shut down the pygame window, if one was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Forget the dead handles so close() is safe to call again and a
            # later render re-creates the window instead of using a closed one.
            self.window = None
            self.clock = None


{'agent': array([2, 2]), 'target': array([2, 0])}
{'agent': array([2, 3]), 'target': array([2, 0])} 0 False False {'distance': 3.0}


: 