Skip to content

Commit

Permalink
Some typing improvements, add py.typed (#284)
Browse files Browse the repository at this point in the history
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
  • Loading branch information
micimize and pseudo-rnd-thoughts committed Nov 12, 2022
1 parent d87a6d3 commit fc41b5e
Show file tree
Hide file tree
Showing 58 changed files with 251 additions and 169 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ __pycache__

# Virtual environments
.env
.venv
venv
6 changes: 4 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
# TODO: change to minigrid version
# from TODO import __version__ as minigrid_version

from __future__ import annotations

import os
import sys
from typing import Any, Dict
from typing import Any

project = "MiniGrid"
copyright = "2022"
Expand Down Expand Up @@ -79,7 +81,7 @@
"dark_logo": "img/minigrid-white.svg",
"gtag": "G-FBXJQQLXKD",
}
html_context: Dict[str, Any] = {}
html_context: dict[str, Any] = {}
html_context["conf_py_path"] = "/docs/"
html_context["display_github"] = True
html_context["github_user"] = "Farama-Foundation"
Expand Down
2 changes: 2 additions & 0 deletions docs/scripts/gen_envs_display.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os

import gymnasium
Expand Down
2 changes: 2 additions & 0 deletions docs/scripts/gen_gifs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import re

Expand Down
2 changes: 2 additions & 0 deletions docs/scripts/move404.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import sys

if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions docs/scripts/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import re


Expand Down
2 changes: 2 additions & 0 deletions minigrid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from gymnasium.envs.registration import register

from minigrid import minigrid_env, wrappers
Expand Down
2 changes: 2 additions & 0 deletions minigrid/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python3

from __future__ import annotations

import time

import gymnasium as gym
Expand Down
2 changes: 2 additions & 0 deletions minigrid/core/actions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Enumeration of possible actions
from __future__ import annotations

from enum import IntEnum


Expand Down
2 changes: 2 additions & 0 deletions minigrid/core/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np

TILE_PIXELS = 32
Expand Down
56 changes: 31 additions & 25 deletions minigrid/core/grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import math
from typing import Any, List, Optional, Tuple, Type
from typing import Any, Callable

import numpy as np

Expand All @@ -21,7 +23,7 @@ class Grid:
"""

# Static cache of pre-renderer tiles
tile_cache = {}
tile_cache: dict[tuple[Any, ...], Any] = {}

def __init__(self, width: int, height: int):
assert width >= 3
Expand All @@ -30,7 +32,7 @@ def __init__(self, width: int, height: int):
self.width: int = width
self.height: int = height

self.grid: List[Optional[WorldObj]] = [None] * width * height
self.grid: list[WorldObj | None] = [None] * (width * height)

def __contains__(self, key: Any) -> bool:
if isinstance(key, WorldObj):
Expand All @@ -47,25 +49,29 @@ def __contains__(self, key: Any) -> bool:
return True
return False

def __eq__(self, other: "Grid") -> bool:
def __eq__(self, other: Grid) -> bool:
grid1 = self.encode()
grid2 = other.encode()
return np.array_equal(grid2, grid1)

def __ne__(self, other: "Grid") -> bool:
def __ne__(self, other: Grid) -> bool:
return not self == other

def copy(self) -> "Grid":
def copy(self) -> Grid:
from copy import deepcopy

return deepcopy(self)

def set(self, i: int, j: int, v: Optional[WorldObj]):
assert 0 <= i < self.width
assert 0 <= j < self.height
def set(self, i: int, j: int, v: WorldObj | None):
assert (
0 <= i < self.width
), f"column index {j} outside of grid of width {self.width}"
assert (
0 <= j < self.height
), f"row index {j} outside of grid of height {self.height}"
self.grid[j * self.width + i] = v

def get(self, i: int, j: int) -> Optional[WorldObj]:
def get(self, i: int, j: int) -> WorldObj | None:
assert 0 <= i < self.width
assert 0 <= j < self.height
assert self.grid is not None
Expand All @@ -75,8 +81,8 @@ def horz_wall(
self,
x: int,
y: int,
length: Optional[int] = None,
obj_type: Type[WorldObj] = Wall,
length: int | None = None,
obj_type: Callable[[], WorldObj] = Wall,
):
if length is None:
length = self.width - x
Expand All @@ -87,8 +93,8 @@ def vert_wall(
self,
x: int,
y: int,
length: Optional[int] = None,
obj_type: Type[WorldObj] = Wall,
length: int | None = None,
obj_type: Callable[[], WorldObj] = Wall,
):
if length is None:
length = self.height - y
Expand All @@ -101,7 +107,7 @@ def wall_rect(self, x: int, y: int, w: int, h: int):
self.vert_wall(x, y, h)
self.vert_wall(x + w - 1, y, h)

def rotate_left(self) -> "Grid":
def rotate_left(self) -> Grid:
"""
Rotate the grid to the left (counter-clockwise)
"""
Expand All @@ -115,7 +121,7 @@ def rotate_left(self) -> "Grid":

return grid

def slice(self, topX: int, topY: int, width: int, height: int) -> "Grid":
def slice(self, topX: int, topY: int, width: int, height: int) -> Grid:
"""
Get a subset of the grid
"""
Expand All @@ -139,8 +145,8 @@ def slice(self, topX: int, topY: int, width: int, height: int) -> "Grid":
@classmethod
def render_tile(
cls,
obj: WorldObj,
agent_dir: Optional[int] = None,
obj: WorldObj | None,
agent_dir: int | None = None,
highlight: bool = False,
tile_size: int = TILE_PIXELS,
subdivs: int = 3,
Expand All @@ -150,7 +156,7 @@ def render_tile(
"""

# Hash map lookup key for the cache
key = (agent_dir, highlight, tile_size)
key: tuple[Any, ...] = (agent_dir, highlight, tile_size)
key = obj.encode() + key if obj else key

if key in cls.tile_cache:
Expand Down Expand Up @@ -194,9 +200,9 @@ def render_tile(
def render(
self,
tile_size: int,
agent_pos: Tuple[int, int],
agent_dir: Optional[int] = None,
highlight_mask: Optional[np.ndarray] = None,
agent_pos: tuple[int, int],
agent_dir: int | None = None,
highlight_mask: np.ndarray | None = None,
) -> np.ndarray:
"""
Render this grid at a given scale
Expand Down Expand Up @@ -235,7 +241,7 @@ def render(

return img

def encode(self, vis_mask: Optional[np.ndarray] = None) -> np.ndarray:
def encode(self, vis_mask: np.ndarray | None = None) -> np.ndarray:
"""
Produce a compact numpy encoding of the grid
"""
Expand All @@ -262,7 +268,7 @@ def encode(self, vis_mask: Optional[np.ndarray] = None) -> np.ndarray:
return array

@staticmethod
def decode(array: np.ndarray) -> Tuple["Grid", np.ndarray]:
def decode(array: np.ndarray) -> tuple[Grid, np.ndarray]:
"""
Decode an array grid encoding back into a grid
"""
Expand All @@ -282,7 +288,7 @@ def decode(array: np.ndarray) -> Tuple["Grid", np.ndarray]:

return grid, vis_mask

def process_vis(self, agent_pos: Tuple[int, int]) -> np.ndarray:
def process_vis(self, agent_pos: tuple[int, int]) -> np.ndarray:
mask = np.zeros(shape=(self.width, self.height), dtype=bool)

mask[agent_pos[0], agent_pos[1]] = True
Expand Down
8 changes: 5 additions & 3 deletions minigrid/core/mission.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Callable, Optional, Union
from __future__ import annotations

from typing import Any, Callable

from gymnasium import spaces
from gymnasium.utils import seeding
Expand Down Expand Up @@ -26,8 +28,8 @@ class MissionSpace(spaces.Space[str]):
def __init__(
self,
mission_func: Callable[..., str],
ordered_placeholders: Optional["list[list[str]]"] = None,
seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
ordered_placeholders: list[list[str]] | None = None,
seed: int | seeding.RandomNumberGenerator | None = None,
):
r"""Constructor of :class:`MissionSpace` space.
Expand Down
44 changes: 22 additions & 22 deletions minigrid/core/roomgrid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from __future__ import annotations

import numpy as np

Expand All @@ -8,7 +8,7 @@
from minigrid.minigrid_env import MiniGridEnv


def reject_next_to(env: MiniGridEnv, pos: Tuple[int, int]):
def reject_next_to(env: MiniGridEnv, pos: tuple[int, int]):
"""
Function to filter out object positions that are right next to
the agent's starting point
Expand All @@ -21,27 +21,27 @@ def reject_next_to(env: MiniGridEnv, pos: Tuple[int, int]):


class Room:
def __init__(self, top: Tuple[int, int], size: Tuple[int, int]):
def __init__(self, top: tuple[int, int], size: tuple[int, int]):
# Top-left corner and size (tuples)
self.top = top
self.size = size

# List of door objects and door positions
# Order of the doors is right, down, left, up
self.doors: List[Optional[Union[bool, Door]]] = [None] * 4
self.door_pos: List[Optional[Tuple[int, int]]] = [None] * 4
self.doors: list[bool | Door | None] = [None] * 4
self.door_pos: list[tuple[int, int] | None] = [None] * 4

# List of rooms adjacent to this one
# Order of the neighbors is right, down, left, up
self.neighbors: List[Optional[Room]] = [None] * 4
self.neighbors: list[Room | None] = [None] * 4

# Indicates if this room is behind a locked door
self.locked: bool = False

# List of objects contained
self.objs: List[WorldObj] = []
self.objs: list[WorldObj] = []

def rand_pos(self, env: MiniGridEnv) -> Tuple[int, int]:
def rand_pos(self, env: MiniGridEnv) -> tuple[int, int]:
topX, topY = self.top
sizeX, sizeY = self.size
return env._randPos(topX + 1, topX + sizeX - 1, topY + 1, topY + sizeY - 1)
Expand Down Expand Up @@ -180,7 +180,7 @@ def _gen_grid(self, width, height):

def place_in_room(
self, i: int, j: int, obj: WorldObj
) -> Tuple[WorldObj, Tuple[int, int]]:
) -> tuple[WorldObj, tuple[int, int]]:
"""
Add an existing object to room (i, j)
"""
Expand All @@ -199,9 +199,9 @@ def add_object(
self,
i: int,
j: int,
kind: Optional[str] = None,
color: Optional[str] = None,
) -> Tuple[WorldObj, Tuple[int, int]]:
kind: str | None = None,
color: str | None = None,
) -> tuple[WorldObj, tuple[int, int]]:
"""
Add a new object to room (i, j)
"""
Expand Down Expand Up @@ -231,10 +231,10 @@ def add_door(
self,
i: int,
j: int,
door_idx: Optional[int] = None,
color: Optional[str] = None,
locked: Optional[bool] = None,
) -> Tuple[Door, Tuple[int, int]]:
door_idx: int | None = None,
color: str | None = None,
locked: bool | None = None,
) -> tuple[Door, tuple[int, int]]:
"""
Add a door to a room, connecting it to a neighbor
"""
Expand Down Expand Up @@ -311,7 +311,7 @@ def remove_wall(self, i: int, j: int, wall_idx: int):
neighbor.doors[(wall_idx + 2) % 4] = True

def place_agent(
self, i: Optional[int] = None, j: Optional[int] = None, rand_dir: bool = True
self, i: int | None = None, j: int | None = None, rand_dir: bool = True
) -> np.ndarray:
"""
Place the agent in a room
Expand All @@ -334,8 +334,8 @@ def place_agent(
return self.agent_pos

def connect_all(
self, door_colors: List[str] = COLOR_NAMES, max_itrs: int = 5000
) -> List[Door]:
self, door_colors: list[str] = COLOR_NAMES, max_itrs: int = 5000
) -> list[Door]:
"""
Make sure that all rooms are reachable by the agent from its
starting position
Expand Down Expand Up @@ -395,11 +395,11 @@ def find_reach():

def add_distractors(
self,
i: Optional[int] = None,
j: Optional[int] = None,
i: int | None = None,
j: int | None = None,
num_distractors: int = 10,
all_unique: bool = True,
) -> List[WorldObj]:
) -> list[WorldObj]:
"""
Add random objects that can potentially distract/confuse the agent.
"""
Expand Down

0 comments on commit fc41b5e

Please sign in to comment.