In [1]:
!pip install pettingzoo gymnasium




[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import numpy as np
import functools
 
# Gym env
import gymnasium as gym
from gymnasium import spaces
from gymnasium.spaces import Discrete, MultiDiscrete
from gymnasium.utils import seeding
 
# MARL env using petting zoo (ParallelEnv: run all agents at simultaneously)
from pettingzoo.utils.env import ParallelEnv
import random
from pettingzoo.utils import parallel_to_aec, wrappers

In [3]:
def env(render_mode=None):
    """
    Build the environment returned by raw_env() and wrap it in the standard
    PettingZoo utility wrappers.

    A requested render mode of "ansi" is mapped to "human" for the underlying
    environment, and the result is wrapped so that printed output is captured.
    Any other render mode is passed through unchanged.
    """
    wants_ansi = render_mode == "ansi"
    wrapped = raw_env(render_mode="human" if wants_ansi else render_mode)

    # Only needed for environments that print their results to the terminal.
    if wants_ansi:
        wrapped = wrappers.CaptureStdoutWrapper(wrapped)

    # Helps error handling for discrete action spaces.
    wrapped = wrappers.AssertOutOfBoundsWrapper(wrapped)

    # Provides a wide variety of helpful user errors; strongly recommended.
    return wrappers.OrderEnforcingWrapper(wrapped)
 
def raw_env(render_mode=None):
    """
    Create the parallel environment and convert it to the AEC API.

    The conversion is done with parallel_to_aec; render_mode is forwarded
    unchanged to parallel_env.
    """
    return parallel_to_aec(parallel_env(render_mode=render_mode))
 
class parallel_env(ParallelEnv):
    """
    Two-agent grid world in which a mouse tries to reach the cheese while a
    cat tries to catch the mouse.

    Agents: "mouse" and "cat", acting simultaneously.
    Actions (Discrete(4)): 0=up (y-1), 1=right (x+1), 2=down (y+1), 3=left (x-1).
    Moves that would leave the grid are ignored.
    Observations: a (2*vision_range+1) square int8 window centred on the agent,
    encoded as -1=outside the grid, 0=empty, 1=cheese, 2=self, 3=opponent.
    Rewards: +1/-1 when the mouse reaches the cheese (mouse wins) or the cat
    lands on the mouse (cat wins), 0 for both when the step limit is hit,
    and -0.1 per agent on every other step.
    """

    metadata = {
        "name": "mouse_cat_cheese",
        "render_modes": ["human"],
        "render_fps": 4
    }

    def __init__(self, grid_size=5, vision_range=1, max_steps=50, render_mode=None):
        """
        Args:
            grid_size: side length of the square grid; must be at least 2 so
                that mouse, cat and cheese can occupy three distinct cells.
            vision_range: how many cells an agent sees in each direction.
            max_steps: number of steps after which the episode is truncated.
            render_mode: "human" to print the grid after every step, or None.

        Raises:
            ValueError: if grid_size is smaller than 2.
        """
        if grid_size < 2:
            # With fewer than three cells, reset() could never place three
            # distinct entities and its rejection sampling would not terminate.
            raise ValueError("grid_size must be at least 2")

        self.grid_size = grid_size
        self.vision_range = vision_range
        self.max_steps = max_steps

        self.possible_agents = ["mouse", "cat"]
        self.agents = self.possible_agents[:]

        self.action_spaces = {a: spaces.Discrete(4) for a in self.possible_agents}
        obs_shape = ((2 * vision_range + 1), (2 * vision_range + 1))
        # high=3 because _get_obs encodes the opponent as 3.
        self.observation_spaces = {
            a: spaces.Box(low=-1, high=3, shape=obs_shape, dtype=np.int8)
            for a in self.possible_agents
        }

        # Create a generator up front so reset() also works without a seed.
        self.np_random, self.np_random_seed = seeding.np_random()

        self.timestep = 0
        self.render_mode = render_mode

    def reset(self, seed=None, options=None):
        """
        Start a new episode with mouse, cat and cheese on three distinct cells.

        Args:
            seed: if given, reseeds the environment's random generator.
            options: accepted for API compatibility; not used.

        Returns:
            (observations, infos), both dicts keyed by agent name.
        """
        if seed is not None:
            self.np_random, self.np_random_seed = seeding.np_random(seed)

        self.agents = self.possible_agents[:]
        self.timestep = 0

        # mouse position
        self.mouse_pos = self.np_random.integers(0, self.grid_size, size=(2,), dtype=np.int32)

        # cat position (resample until it differs from the mouse)
        while True:
            self.cat_pos = self.np_random.integers(0, self.grid_size, size=(2,), dtype=np.int32)
            if not np.array_equal(self.cat_pos, self.mouse_pos):
                break

        # cheese position (resample until it differs from mouse and cat)
        while True:
            self.cheese_pos = self.np_random.integers(0, self.grid_size, size=(2,), dtype=np.int32)
            if (
                not np.array_equal(self.cheese_pos, self.mouse_pos)
                and not np.array_equal(self.cheese_pos, self.cat_pos)
            ):
                break

        observations = self._get_all_obs()

        infos = {agent: {} for agent in self.agents}
        return observations, infos

    def step(self, actions):
        """
        Apply one action per agent and advance the episode by one step.

        Args:
            actions: dict with an action for "mouse" and for "cat". An empty
                dict ends the episode and returns five empty dicts.

        Returns:
            (observations, rewards, terminations, truncations, infos); every
            dict is keyed by agent name only.
        """
        if not actions:
            self.agents = []
            return {}, {}, {}, {}, {}

        self.timestep += 1

        # Both agents move before any outcome is evaluated.
        self.mouse_pos = self._move(self.mouse_pos, actions["mouse"])
        self.cat_pos = self._move(self.cat_pos, actions["cat"])

        # Check win/lose conditions; the mouse reaching the cheese is checked first.
        if np.array_equal(self.mouse_pos, self.cheese_pos):
            rewards = {"mouse": 1, "cat": -1}
            terminated = True
        elif np.array_equal(self.cat_pos, self.mouse_pos):
            rewards = {"mouse": -1, "cat": 1}
            terminated = True
        else:
            rewards = {a: -0.1 for a in self.agents}
            terminated = False

        # Max steps -> draw. Only applies when nobody won on this step, so a
        # win on the final step keeps its rewards.
        truncated = (not terminated) and self.timestep >= self.max_steps
        if truncated:
            rewards = {"mouse": 0, "cat": 0}

        terminations = {a: terminated for a in self.agents}
        truncations = {a: truncated for a in self.agents}

        # Observations and infos are built while self.agents is still populated.
        obs = self._get_all_obs()
        infos = {agent: {} for agent in self.agents}

        if terminated or truncated:
            self.agents = []

        if self.render_mode == "human":
            self.render()

        return obs, rewards, terminations, truncations, infos

    def _move(self, pos, action):
        """
        Return the position reached from `pos` by `action`.

        Moves that would leave the grid, and unrecognised actions, leave the
        position unchanged.
        """
        x, y = pos
        if action == 0 and y > 0:  # Up
            y -= 1
        elif action == 1 and x < self.grid_size - 1:  # Right
            x += 1
        elif action == 2 and y < self.grid_size - 1:  # Down
            y += 1
        elif action == 3 and x > 0:  # Left
            x -= 1

        # Same dtype as the positions sampled in reset().
        return np.array([x, y], dtype=np.int32)

    def _get_obs(self, agent):
        """Return partial grid view for the given agent."""
        full_grid = np.zeros((self.grid_size, self.grid_size), dtype=np.int8)
        # Encoding: 0=empty, 1=goal, 2=self, 3=opponent, -1=outside the grid.
        # Later writes win when entities share a cell (opponent over self over goal).
        gx, gy = self.cheese_pos
        full_grid[gy, gx] = 1
        if agent == "mouse":
            ax, ay = self.mouse_pos
            ox, oy = self.cat_pos
        else:
            ax, ay = self.cat_pos
            ox, oy = self.mouse_pos
        full_grid[ay, ax] = 2  # self
        full_grid[oy, ox] = 3  # opponent

        vr = self.vision_range
        # Cells of the window that fall outside the grid keep the value -1.
        view = np.full((2 * vr + 1, 2 * vr + 1), -1, dtype=np.int8)
        for dx in range(-vr, vr + 1):
            for dy in range(-vr, vr + 1):
                gx_ = ax + dx
                gy_ = ay + dy
                if 0 <= gx_ < self.grid_size and 0 <= gy_ < self.grid_size:
                    view[dy + vr, dx + vr] = full_grid[gy_, gx_]
        return view

    def _get_all_obs(self):
        """Return the partial views of all currently active agents."""
        return {
            agent: self._get_obs(agent) for agent in self.agents
        }

    def render(self):
        """Print the grid with the cheese, mouse and cat positions."""
        if self.render_mode is None:
            gym.logger.warn(
                "You are calling render method without specifying any render mode."
            )
            return

        grid = np.full((self.grid_size, self.grid_size), ".", dtype='U6')
        # Drawn in this order, so the cat hides the mouse and the mouse hides
        # the cheese when they share a cell.
        gx, gy = self.cheese_pos
        grid[gy, gx] = "Cheese"
        rx, ry = self.mouse_pos
        grid[ry, rx] = "Mouse"
        cx, cy = self.cat_pos
        grid[cy, cx] = "Cat"
        print("\n".join("".join(f"{cell:^6}" for cell in row) for row in grid))
        print()

    def close(self):
        """Nothing to release."""
        pass

    def observation_space(self, agent):
        """Return the observation space of `agent`, as declared in observation_spaces."""
        # gymnasium spaces are defined and documented here: https://gymnasium.farama.org/api/spaces/
        return self.observation_spaces[agent]

    def action_space(self, agent):
        """Return the action space of `agent`, as declared in action_spaces."""
        return self.action_spaces[agent]
 

In [4]:
# This cell rebinds the name `env` from the factory function to the
# environment it returns. Keep the factory reachable as `make_env` while `env`
# is still callable, so re-running this cell builds a fresh environment
# instead of trying to call the previous one.
# NOTE(review): assumes the wrapped environment object is not itself callable.
if callable(env):
    make_env = env
env = make_env(render_mode='human')

In [5]:
# Seed the reset so the starting layout is reproducible, then draw the grid.
env.reset(seed=42)
env.render()

  .     .     .     .     .   
  .     .     .     .     .   
  .     .     .    Cat    .   
Mouse   .     .     .     .   
  .     .   Cheese  .     .   

