
StableBaselines3 and PettingZoo: Unrecognized type of observation <class 'tuple'> #222

Closed
adam-crowther opened this issue Jul 13, 2023 · 2 comments

@adam-crowther

Hi,

I'm having an issue similar to Issue #220; however, mine is not resolved by using ss.flatten_v0().

I'm using the current versions: stable-baselines3==2.1.0a0, pettingzoo==1.23.1, supersuit==3.8.1, and gymnasium==0.28.1. I have also tried stable-baselines3==2.0.0, with the same result.

I have adapted @PieroMacaluso's dummy project from Issue #169 to reproduce the issue: https://github.com/adam-crowther/test-supersuit-baseline3-pettingzoo-parallel-env

The ParallelEnv looks like this:

import random
from typing import Dict

import numpy as np
from gymnasium import spaces
from gymnasium.utils import EzPickle
from pettingzoo import ParallelEnv
from pettingzoo.utils.env import ObsDict, ActionDict


class DummyParallelEnv(ParallelEnv, EzPickle):
    metadata = {'render_modes': ['ansi'], "name": "TestParallelEnv-v0"}

    def __init__(self, n_agents: int = 20, new_step_api: bool = True) -> None:
        EzPickle.__init__(
            self,
            n_agents,
            new_step_api
        )

        self._terminated = False
        self.current_step = 0

        self.n_agents = n_agents
        self.possible_agents = [f"player_{idx}" for idx in range(n_agents)]
        self.agents = self.possible_agents[:]

        self.agent_name_mapping = dict(
            zip(self.possible_agents, list(range(len(self.possible_agents))))
        )

        self.observation_spaces = {
            agent: spaces.Box(shape=(len(self.agents),), dtype=np.float64, low=0.0, high=1.0)
            for agent in self.possible_agents
        }

        self.action_spaces = {
            agent: spaces.Discrete(4) for agent in self.possible_agents}

    def observation_space(self, agent):
        return self.observation_spaces[agent]

    def action_space(self, agent):
        return self.action_spaces[agent]

    def step(self, actions: ActionDict) \
            -> tuple[ObsDict, dict[str, float], dict[str, bool], dict[str, bool], dict[str, dict]]:
        self.current_step += 1
        self._terminated = self.current_step >= 100

        observations = self.__calculate_observations()
        rewards = {
            agent: random.randint(0, 100) for agent in self.agents
        }
        terminated = {agent: self._terminated for agent in self.agents}
        truncated = {agent: False for agent in self.agents}
        infos = {agent: {} for agent in self.agents}

        if self._terminated:
            self.agents = []

        return observations, rewards, terminated, truncated, infos

    def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[ObsDict, dict[str, dict]]:
        self.agents = self.possible_agents[:]
        self._terminated = False
        self.current_step = 0
        observations = self.__calculate_observations()
        infos = {agent: {} for agent in self.agents}

        return observations, infos

    def __calculate_observations(self) -> Dict[str, np.ndarray]:
        return {
            agent: self.observation_space(agent).sample() for agent in self.agents
        }
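For reference, the Parallel API contract at play here is that reset() returns an (observations, infos) tuple. A quick sanity check of that contract (a sketch, assuming the dummy_env module layout from my demo repo):

import numpy as np

from dummy_env import dummy

# Per the PettingZoo Parallel API, reset() returns (observations, infos).
env = dummy.DummyParallelEnv(n_agents=3)
observations, infos = env.reset(seed=42)
assert set(observations) == set(env.agents)
assert all(isinstance(obs, np.ndarray) for obs in observations.values())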

The environment is executed like this:

import supersuit as ss
from pettingzoo.test import parallel_api_test
from stable_baselines3 import PPO

from dummy_env import dummy

if __name__ == '__main__':
    env_parallel = dummy.DummyParallelEnv()
    parallel_api_test(env_parallel)

    # env_parallel = ss.flatten_v0(env_parallel)
    env_parallel = ss.pettingzoo_env_to_vec_env_v1(env_parallel)
    env_parallel = ss.concat_vec_envs_v1(env_parallel, 1, base_class="stable_baselines3")

    model = PPO("MlpPolicy", env_parallel, verbose=1)
    
    model.learn(total_timesteps=10_000)

When I execute I get this error:

Traceback (most recent call last):
  File "C:\dev\repo\test-supersuit-baseline3-pettingzoo-parallel-env\main_dummy.py", line 17, in <module>
    model.learn(total_timesteps=10_000)
  File "C:\Users\adamcc\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\ppo\ppo.py", line 308, in learn
    return super().learn(
           ^^^^^^^^^^^^^^
  File "C:\Users\adamcc\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\on_policy_algorithm.py", line 259, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\adamcc\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\on_policy_algorithm.py", line 168, in collect_rollouts
    obs_tensor = obs_as_tensor(self._last_obs, self.device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\adamcc\AppData\Roaming\Python\Python311\site-packages\stable_baselines3\common\utils.py", line 487, in obs_as_tensor
    raise Exception(f"Unrecognized type of observation {type(obs)}")
Exception: Unrecognized type of observation <class 'tuple'>

Process finished with exit code 1

If I set a breakpoint in stable_baselines3's on_policy_algorithm.py at line 168, I can see that self._last_obs has been set to the (observation, info) tuple returned by the ParallelEnv reset() method, whereas obs_as_tensor() expects an np.ndarray.
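To make the mismatch concrete, here is a minimal sketch (the shapes are illustrative, not the actual values):

import numpy as np

# Gymnasium-style reset() returns a tuple of (observations, infos) ...
reset_return = (np.zeros((20, 20), dtype=np.float64), [{} for _ in range(20)])

# ... but SB3 stores the whole return value in self._last_obs, so
# obs_as_tensor() receives a tuple rather than the np.ndarray
# (or dict of arrays) it knows how to convert.
assert isinstance(reset_return, tuple)
assert not isinstance(reset_return, (np.ndarray, dict))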

Have I got something wrong or is there a compatibility issue here somewhere?

Thanks,

Adam

@adam-crowther (Author) commented Jul 13, 2023

I created a workaround using a shim that wraps the SB3VecEnvWrapper:

from stable_baselines3.common.vec_env import VecEnvWrapper
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn


class Sb3ShimWrapper(VecEnvWrapper):
    metadata = {'render_modes': ['human', 'files', 'none'], "name": "Sb3ShimWrapper-v0"}

    def __init__(self, venv):
        super().__init__(venv)

    def reset(self, seed=None, options=None):
        # Drop the infos element: SB3 expects reset() to return observations only.
        return self.venv.reset()[0]

    def step_wait(self) -> VecEnvStepReturn:
        return self.venv.step_wait()

As you can see, it overrides the reset() method and returns only the first element of the tuple (the observations).

I integrate it like this:

if __name__ == '__main__':
    env_parallel = dummy.DummyParallelEnv()
    parallel_api_test(env_parallel)

    # env_parallel = ss.flatten_v0(env_parallel)
    env_parallel = ss.pettingzoo_env_to_vec_env_v1(env_parallel)
    env_parallel = ss.concat_vec_envs_v1(env_parallel, 1, base_class="stable_baselines3")
    env_parallel = Sb3ShimWrapper(env_parallel)

    model = PPO("MlpPolicy", env_parallel, verbose=1)
    
    model.learn(total_timesteps=10_000)

I will push this change to my demo repo.

Now I have a new problem with render(), which I will document in a new Issue.

@elliottower (Member)

To my knowledge, this has been fixed with #226. (I was getting the same issue; it happens because SB3 expects reset() to return only an observation, whereas by default PettingZoo and Gymnasium return an observation and an info dict.)
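In other words, once the fix is in, the SB3-facing VecEnv's reset() should return observations alone. A quick check, sketched under the assumption of the dummy setup from above:

import numpy as np
import supersuit as ss

from dummy_env import dummy

# With the fix from #226, reset() returns only the observations array,
# which is what SB3's obs_as_tensor() can handle.
env = dummy.DummyParallelEnv()
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(env, 1, base_class="stable_baselines3")
obs = env.reset()
assert isinstance(obs, np.ndarray), type(obs)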
