Skip to content

Commit

Permalink
Added __getattr__ to VecEnvWrapper (#286)
Browse files Browse the repository at this point in the history
* added __getattr__ to VecEnvWrapper

* added helper methods to VecEnvWrapper.__getattr__ to prevent ambiguous attribute references

* changed string formatting to be compatible with Python 3.5

* fixed typo

* added test for VecEnvWrapper.__getattr__ and fixed code accordingly

* updated changelog

* linting

* removed unused import

* Line wrapping in changelog

* Bugfixes + code cleanup

* Linting

* Improve variable naming + docstrings

* Add whitespace
  • Loading branch information
kantneel authored and AdamGleave committed May 2, 2019
1 parent baa4aa5 commit 333c593
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ Pre-Release 2.5.1a0 (WIP)
- added support for multi env recording to ``generate_expert_traj`` (@XMaster96)
- added support for LSTM model recording to ``generate_expert_traj`` (@XMaster96)
- ``GAIL``: remove mandatory matplotlib dependency and refactor as subclass of ``TRPO`` (@kantneel and @AdamGleave)
- added ``get_attr()``, ``env_method()`` and ``set_attr()`` methods for all VecEnv.
- added ``get_attr()``, ``env_method()`` and ``set_attr()`` methods for all VecEnv.
Those methods now all accept ``indices`` keyword to select a subset of envs.
``set_attr`` now returns ``None`` rather than a list of ``None``. (@kantneel)
- ``GAIL``: ``gail.dataset.ExpertDataset` supports loading from memory rather than file, and
``gail.dataset.record_expert`` supports returning in-memory rather than saving to file.
- added support in ``VecEnvWrapper`` for accessing attributes of arbitrarily deeply nested
instances of ``VecEnvWrapper`` and ``VecEnv``. This is allowed as long as the attribute belongs
to exactly one of the nested instances i.e. it must be unambiguous. (@kantneel)
- fixed bug where result plotter would crash on very short runs (@Pastafarianist)
- added option to not trim output of result plotter by number of timesteps (@Pastafarianist)
- clarified the public interface of ``BasePolicy`` and ``ActorCriticPolicy``. **Breaking change** when using custom policies: ``masks_ph`` is now called ``dones_ph``.
Expand Down
60 changes: 60 additions & 0 deletions stable_baselines/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,18 @@ def unwrapped(self):
else:
return self

def getattr_depth_check(self, name, already_found):
"""Check if an attribute reference is being hidden in a recursive call to __getattr__
:param name: (str) name of attribute to check for
:param already_found: (bool) whether this attribute has already been found in a wrapper
:return: (str or None) name of module whose attribute is being shadowed, if any.
"""
if hasattr(self, name) and already_found:
return "{0}.{1}".format(type(self).__module__, type(self).__name__)
else:
return None

def _get_indices(self, indices):
"""
Convert a flexibly-typed reference to environment indices to an implied list of indices.
Expand Down Expand Up @@ -207,6 +219,54 @@ def set_attr(self, attr_name, value, indices=None):
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)

def __getattr__(self, name):
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
Useful for accessing attributes from venvs which are wrapped with multiple wrappers
which have unique attributes of interest.
"""
blocked_class = self.getattr_depth_check(name, already_found=False)
if blocked_class is not None:
own_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
format_str = ("Error: Recursive attribute lookup for {0} from {1} is "
"ambiguous and hides attribute from {2}")
raise AttributeError(format_str.format(name, own_class, blocked_class))

return self.getattr_recursive(name)

def getattr_recursive(self, name):
"""Recursively check wrappers to find attribute.
:param name (str) name of attribute to look for
:return: (object) attribute
"""
if name in self.__dict__: # attribute is present in this wrapper
attr = self.__dict__[name]
elif hasattr(self.venv, 'getattr_recursive'):
# Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
# to avoid a duplicate call to getattr_depth_check.
attr = self.venv.getattr_recursive(name)
else: # attribute not present, child is an unwrapped VecEnv
attr = getattr(self.venv, name)

return attr

def getattr_depth_check(self, name, already_found):
"""See base class.
:return: (str or None) name of module whose attribute is being shadowed, if any.
"""
if name in self.__dict__ and already_found:
# this venv's attribute is being hidden because of a higher venv.
shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
elif name in self.__dict__ and not already_found:
# we have found the first reference to the attribute. Now check for duplicates.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
else:
# this wrapper does not have the attribute. Keep searching.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)

return shadowed_wrapper_class


class CloudpickleWrapper(object):
def __init__(self, var):
Expand Down
29 changes: 29 additions & 0 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,32 @@ def obs_assert(obs):
with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
vec_env_class = functools.partial(SubprocVecEnv, start_method='illegal_method')
check_vecenv_spaces(vec_env_class, space, obs_assert)


class CustomWrapperA(VecNormalize):
def __init__(self, venv):
VecNormalize.__init__(self, venv)
self.var_a = 'a'


class CustomWrapperB(VecNormalize):
def __init__(self, venv):
VecNormalize.__init__(self, venv)
self.var_b = 'b'


def test_vecenv_wrapper_getattr():
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
wrapped = CustomWrapperA(CustomWrapperB(vec_env))
assert wrapped.var_b == 'b'
assert wrapped.var_a == 'a'

double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
dummy = double_wrapped.var_a # should not raise as it is directly defined here
with pytest.raises(AttributeError): # should raise due to ambiguity
dummy = double_wrapped.var_b
with pytest.raises(AttributeError): # should raise as does not exist
dummy = double_wrapped.nonexistent_attribute
del dummy # keep linter happy

0 comments on commit 333c593

Please sign in to comment.