Skip to content

Commit

Permalink
Bugfix for VecEnvWrapper.__getattr__ (#307)
Browse files Browse the repository at this point in the history
* fixed bug in VecEnvWrapper.__getattr__ where inherited methods were inaccessible

* improved test for VecEnvWrapper.__getattr__ to be more comprehensive

* changed test function to satisfy code checks

* updated changelog and simplified declaration of self.class_attributes

* modified getattr_depth_check for consistency and added helper method for getting all attributes
  • Loading branch information
kantneel authored and araffin committed May 8, 2019
1 parent bddd1ab commit fbd9f35
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
6 changes: 6 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ Changelog

For download links, please look at `Github release page <https://github.com/hill-a/stable-baselines/releases>`_.

Release 2.5.2a0 (WIP)
--------------------

- Bugfix for ``VecEnvWrapper.__getattr__`` which enables access to class attributes inherited from parent classes.


Release 2.5.1 (2019-05-04)
--------------------------

Expand Down
21 changes: 17 additions & 4 deletions stable_baselines/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import inspect
import pickle

import cloudpickle
Expand Down Expand Up @@ -189,6 +190,7 @@ def __init__(self, venv, observation_space=None, action_space=None):
self.venv = venv
VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space)
self.class_attributes = dict(inspect.getmembers(self.__class__))

def step_async(self, actions):
self.venv.step_async(actions)
Expand Down Expand Up @@ -233,14 +235,24 @@ def __getattr__(self, name):

return self.getattr_recursive(name)

def _get_all_attributes(self):
"""Get all (inherited) instance and class attributes
:return: (dict<str, object>) all_attributes
"""
all_attributes = self.__dict__.copy()
all_attributes.update(self.class_attributes)
return all_attributes

def getattr_recursive(self, name):
"""Recursively check wrappers to find attribute.
:param name (str) name of attribute to look for
:return: (object) attribute
"""
if name in self.__dict__: # attribute is present in this wrapper
attr = self.__dict__[name]
all_attributes = self._get_all_attributes()
if name in all_attributes: # attribute is present in this wrapper
attr = getattr(self, name)
elif hasattr(self.venv, 'getattr_recursive'):
# Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
# to avoid a duplicate call to getattr_depth_check.
Expand All @@ -255,10 +267,11 @@ def getattr_depth_check(self, name, already_found):
:return: (str or None) name of module whose attribute is being shadowed, if any.
"""
if name in self.__dict__ and already_found:
all_attributes = self._get_all_attributes()
if name in all_attributes and already_found:
# this venv's attribute is being hidden because of a higher venv.
shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__)
elif name in self.__dict__ and not already_found:
elif name in all_attributes and not already_found:
# we have found the first reference to the attribute. Now check for duplicates.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
else:
Expand Down
17 changes: 15 additions & 2 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,27 @@ def __init__(self, venv):
VecNormalize.__init__(self, venv)
self.var_b = 'b'

def func_b(self):
return self.var_b

def name_test(self):
return self.__class__

class CustomWrapperBB(CustomWrapperB):
def __init__(self, venv):
CustomWrapperB.__init__(self, venv)
self.var_bb = 'bb'

def test_vecenv_wrapper_getattr():
def make_env():
return CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2)))
vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
wrapped = CustomWrapperA(CustomWrapperB(vec_env))
assert wrapped.var_b == 'b'
wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
assert wrapped.var_a == 'a'
assert wrapped.var_b == 'b'
assert wrapped.var_bb == 'bb'
assert wrapped.func_b() == 'b'
assert wrapped.name_test() == CustomWrapperBB

double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
dummy = double_wrapped.var_a # should not raise as it is directly defined here
Expand Down

0 comments on commit fbd9f35

Please sign in to comment.