Skip to content

Commit

Permalink
first gym env
Browse files Browse the repository at this point in the history
  • Loading branch information
29th-Day committed Dec 23, 2023
1 parent 8b7fb2c commit 7c51c63
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/retropy/frontends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
"""

from .pygame.pygame import RetroPyGame
from .gym.env import RetroGym
from .gym.gym import RetroGym
from .pyglet.pyglet import RetroPyGlet
6 changes: 6 additions & 0 deletions src/retropy/frontends/gym/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Gym is made for single agents envs, therefore only one player is supported
# MAYBE a multi agent version (likly following PettingZoo's API) will be made eventually

# from gymnasium.envs.registration import register

# register(id="retropy/RetroGym-v0", entry_point="retropy.frontends.gym.gym:RetroGym")
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

# gymnasium environment with continuous inputs

# https://gymnasium.farama.org/api/spaces/fundamental/#gymnasium.spaces.MultiDiscrete

class RetroGym(gym.Env):
def __init__(self, core: str, rom: str, player: int = 1):
self.player = player

class RetroGym(gym.Env):
def __init__(self, core: str, rom: str):
self.core = RetroPy(core, True)
self.core.load(rom)

Expand All @@ -22,7 +22,7 @@ def __init__(self, core: str, rom: str, player: int = 1):
obs_shape = (geometry.base_height, geometry.base_width, 3)

self.observation_space = spaces.Box(0, 255, obs_shape, dtype=np.uint8)
self.action_space = spaces.Box(0, 1, (self.player, len(GamePadInput)))
self.action_space = spaces.Box(0, 1, (len(GamePadInput),))

def reset(self, seed=None, options=None):
super().reset(seed=seed)
Expand All @@ -34,29 +34,27 @@ def reset(self, seed=None, options=None):
return observation, info

def step(
self, action: Sequence[Sequence[float]]
self, action: Sequence[float]
) -> tuple[np.ndarray, float, bool, bool, dict]:
# Set input in core
for player, action in enumerate(action):
for i, name in enumerate(GamePadInput):
player[name] = action[i]
# gym API
self.__set_controller_input(action)

observation = self.core.frame_advance()
reward = self._reward_function(observation)
terminated, truncated = self._stopping_criterion()
info = {}

return observation, reward, terminated, truncated, info

# Helper functions
# Helper function

def _reward_function(self, obs: np.ndarray) -> float:
reward = 0.0
def __set_controller_input(self, action):
for input, value in zip(GamePadInput, action):
self.core.controllers[0][input] = value

return reward
# RL functions

def _stopping_criterion(self) -> tuple[bool, bool]:
terminated = False # Goal reached
truncated = False # Timesteps elapsed
def _reward_function(self, observation: np.ndarray) -> float:
raise NotImplementedError()

return terminated, truncated
def _stopping_criterion(self) -> tuple[bool, bool]:
raise NotImplementedError()
40 changes: 39 additions & 1 deletion src/retropy/frontends/gym/wrapper.py
Original file line number Diff line number Diff line change
@@ -1 +1,39 @@
# gymnasium wrapper with discrete inputs
# gymnasium wrapper with
# - discrete inputs

from .gym import RetroGym
from ...utils.input import GamePadInput

import gymnasium as gym
from gymnasium import spaces, ActionWrapper

# https://gymnasium.farama.org/tutorials/gymnasium_basics/implementing_custom_wrappers/


class DiscreteInputs(ActionWrapper):
"""
Map discrete actions into controller inputs.
Example:
>>> env = RetroGym(core, rom)
>>> action_map = [["A"], ["B"], ["LEFT"], ["RIGHT", "A"], ["LEFT_X"]]
>>> env = DiscreteInputs(env, action_map)
>>> print(env.action_space) # -> Discrete(5)
"""

def __init__(self, env: RetroGym, map: list[list[str]]):
super().__init__(env)
self.map = map
self.action_space = spaces.Discrete(len(self.map))

def action(self, action: int):
inputs = self.map[action]

action = [0] * len(GamePadInput)
for i, input in enumerate(GamePadInput):
if input in inputs:
action[i] = 1

print(action)

return action
29 changes: 5 additions & 24 deletions src/retropy/utils/input/gamepad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from enum import StrEnum


# __iter__() is *in order of declaration* and depended on
class GamePadInput(StrEnum):
"""Based on the Nintendo© button layout"""
"""Based on Nintendo© button layout"""

LEFT_X = "LEFT_X"
"""Left Control Stick - Horizontal ↔"""
Expand Down Expand Up @@ -70,30 +71,10 @@ def get_state(self, device: int, index: int, id: int) -> int:
raise ValueError(f"device ({device})")

def reset(self):
self.state = {
GamePadInput.LEFT_X: 0.0,
GamePadInput.LEFT_Y: 0.0,
GamePadInput.RIGHT_X: 0.0,
GamePadInput.RIGHT_Y: 0.0,
GamePadInput.B: 0,
GamePadInput.Y: 0,
GamePadInput.START: 0,
GamePadInput.SELECT: 0,
GamePadInput.UP: 0,
GamePadInput.DOWN: 0,
GamePadInput.LEFT: 0,
GamePadInput.RIGHT: 0,
GamePadInput.A: 0,
GamePadInput.X: 0,
GamePadInput.L1: 0,
GamePadInput.R1: 0,
GamePadInput.L2: 0,
GamePadInput.R2: 0,
GamePadInput.L3: 0,
GamePadInput.R3: 0,
}
self.state = {}

# print(self.state)
for input in GamePadInput:
self.state[input] = 0.0


RETRO_INPUT_TO_STR = {
Expand Down
10 changes: 10 additions & 0 deletions tests/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
__core = "C:/Users/Gerald/AppData/Local/RetroArch/cores/"

__game = "C:/Users/Gerald/AppData/Local/RetroArch/roms/"

SYSTEMS = {
"GBA": (__core + "mgba_libretro.dll", __game + "Tetris.gb"),
"N64": (__core + "parallel_n64_libretro.dll", __game + "Super Mario 64.n64"),
}

__all__ = ["SYSTEMS"]
31 changes: 30 additions & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,41 @@ def pyglet():

def gym():
from retropy.frontends import RetroGym
from retropy.frontends.gym.wrapper import DiscreteInputs

env = RetroGym(dll, game)

map = [["A"], ["B"], ["UP"], ["DOWN"], ["LEFT"], ["RIGHT"]]
env = DiscreteInputs(env, map)
print(env.observation_space)
print(env.action_space)

obs, info = env.reset()

done = False

while not done:
action = env.action_space.sample()
print(action)

obs, reward, term, trunc, info = env.step(action)
done = term or trunc

break


def test():
...
class Test:
_FIELDS_ = ("A", "B", "X", "Y")

def __init__(self) -> None:
for f in self._FIELDS_:
setattr(self, f, f)

def __getattr__(self, name):
...

print(Test().A)


if __name__ == "__main__":
Expand Down

0 comments on commit 7c51c63

Please sign in to comment.