
Commit

Merge pull request #91 from DLR-RM/base-review-2
Refactor BasePolicy into BaseModel + other minor changes
AdamGleave committed Jul 8, 2020
2 parents c39ed39 + fc0b0b8 commit 758b140
Showing 4 changed files with 116 additions and 82 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -21,6 +21,7 @@ New Features:
when ``psutil`` is available
- Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped)
- Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped)
- Introduced ``BaseModel`` abstract parent for ``BasePolicy``, which critics inherit from.

Bug Fixes:
^^^^^^^^^^
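The ``BaseModel`` entry above can be pictured with a minimal sketch of the new hierarchy. This is an illustration only, not the real signatures: the actual constructors take the observation/action spaces and several other arguments.

import torch.nn as nn
from abc import ABC, abstractmethod

class BaseModel(nn.Module, ABC):
    # New abstract parent: maps observations to predictions
    # (actions for policies, estimated values for critics).
    @abstractmethod
    def forward(self, *args, **kwargs):
        ...

class BasePolicy(BaseModel):
    # Policies add policy-only options such as squash_output.
    def __init__(self, *args, squash_output: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self._squash_output = squash_output

class ContinuousCritic(BaseModel):
    # Critics now inherit from BaseModel instead of BasePolicy.
    pass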
10 changes: 6 additions & 4 deletions stable_baselines3/common/base_class.py
@@ -320,8 +320,8 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl
f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}")

# check if observation space and action space are part of the saved parameters
if ("observation_space" not in data or "action_space" not in data) and "env" not in data:
raise ValueError("The observation_space and action_space was not given, can't verify new environments")
if "observation_space" not in data or "action_space" not in data:
raise KeyError("The observation_space and action_space were not given, can't verify new environments")
# check if given env is valid
if env is not None:
check_for_correct_spaces(env, data["observation_space"], data["action_space"])
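A hedged usage sketch of how the new error surfaces (the file name below is hypothetical and the snippet is not from the repository): loading a saved file that lacks the stored spaces now fails with ``KeyError`` rather than ``ValueError``.

from stable_baselines3 import PPO

try:
    # Hypothetical file that was saved without observation/action spaces.
    model = PPO.load("old_model_without_spaces.zip")
except KeyError as err:
    print(f"Cannot verify environments, spaces missing from the saved file: {err}")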
@@ -425,8 +425,10 @@ def _setup_learn(self,
:return: (Tuple[int, BaseCallback])
"""
self.start_time = time.time()
self.ep_info_buffer = deque(maxlen=100)
self.ep_success_buffer = deque(maxlen=100)
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
self.ep_info_buffer = deque(maxlen=100)
self.ep_success_buffer = deque(maxlen=100)

if self.action_noise is not None:
self.action_noise.reset()
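A minimal, self-contained sketch of the buffer logic added above (toy class, not the real ``BaseAlgorithm``): episode buffers are created on the first call and kept across subsequent ``learn()`` calls unless ``reset_num_timesteps`` is true.

from collections import deque

class BufferDemo:
    def __init__(self):
        self.ep_info_buffer = None
        self.ep_success_buffer = None

    def _setup_learn(self, reset_num_timesteps: bool = True) -> None:
        # Same condition as the change above: (re)create only when missing or resetting.
        if self.ep_info_buffer is None or reset_num_timesteps:
            self.ep_info_buffer = deque(maxlen=100)
            self.ep_success_buffer = deque(maxlen=100)

demo = BufferDemo()
demo._setup_learn()                           # buffers created
demo.ep_info_buffer.append({"r": 1.0})
demo._setup_learn(reset_num_timesteps=False)  # buffers preserved
print(len(demo.ep_info_buffer))               # 1
demo._setup_learn(reset_num_timesteps=True)   # buffers reinitialized
print(len(demo.ep_info_buffer))               # 0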
14 changes: 14 additions & 0 deletions stable_baselines3/common/distributions.py
@@ -17,6 +17,20 @@ class Distribution(ABC):
def __init__(self):
super(Distribution, self).__init__()

@abstractmethod
def proba_distribution_net(self, *args, **kwargs):
"""Create the layers and parameters that represent the distribution.
Subclasses must define this, but the arguments and return type vary between
concrete classes."""

@abstractmethod
def proba_distribution(self, *args, **kwargs) -> 'Distribution':
"""Set parameters of the distribution.
:return: (Distribution) self
"""

@abstractmethod
def log_prob(self, x: th.Tensor) -> th.Tensor:
"""
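To show how the two new abstract methods are intended to fit together, here is a hedged, standalone sketch of a tiny categorical distribution (simplified; it does not reproduce the library's ``CategoricalDistribution``).

import torch as th
import torch.nn as nn
from torch.distributions import Categorical

class TinyCategorical:
    def __init__(self, action_dim: int):
        self.action_dim = action_dim
        self.distribution = None

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        # Layer mapping latent features to unnormalized action logits.
        return nn.Linear(latent_dim, self.action_dim)

    def proba_distribution(self, action_logits: th.Tensor) -> "TinyCategorical":
        # Set the distribution parameters and return self, as the interface requires.
        self.distribution = Categorical(logits=action_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        return self.distribution.log_prob(actions)

dist = TinyCategorical(action_dim=4)
action_net = dist.proba_distribution_net(latent_dim=8)
logits = action_net(th.zeros(1, 8))
print(dist.proba_distribution(logits).log_prob(th.tensor([2])))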
173 changes: 95 additions & 78 deletions stable_baselines3/common/policies.py
@@ -21,9 +21,12 @@
StateDependentNoiseDistribution)


class BasePolicy(nn.Module, ABC):
class BaseModel(nn.Module, ABC):
"""
The base policy object
The base model object: makes predictions in response to observations.
In the case of policies, the prediction is an action. In the case of critics, it is the
estimated value of the observation.
:param observation_space: (gym.spaces.Space) The observation space of the environment
:param action_space: (gym.spaces.Space) The action space of the environment
@@ -39,8 +42,6 @@ class BasePolicy(nn.Module, ABC):
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param squash_output: (bool) For continuous actions, whether the output is squashed
or not using a ``tanh()`` function.
"""

def __init__(self,
@@ -52,9 +53,8 @@ def __init__(self,
features_extractor: Optional[nn.Module] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
squash_output: bool = False):
super(BasePolicy, self).__init__()
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(BaseModel, self).__init__()

if optimizer_kwargs is None:
optimizer_kwargs = {}
@@ -67,7 +67,6 @@ def __init__(self,
self.device = get_device(device)
self.features_extractor = features_extractor
self.normalize_images = normalize_images
self._squash_output = squash_output

self.optimizer_class = optimizer_class
self.optimizer_kwargs = optimizer_kwargs
@@ -76,6 +75,10 @@ def __init__(self,
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs

@abstractmethod
def forward(self, *args, **kwargs):
del args, kwargs

def extract_features(self, obs: th.Tensor) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
@@ -87,9 +90,88 @@ def extract_features(self, obs: th.Tensor) -> th.Tensor:
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return self.features_extractor(preprocessed_obs)

def _get_data(self) -> Dict[str, Any]:
"""
Get data that need to be saved in order to re-create the model.
This corresponds to the arguments of the constructor.
:return: (Dict[str, Any])
"""
return dict(
observation_space=self.observation_space,
action_space=self.action_space,
# Passed to the constructor by child class
# squash_output=self.squash_output,
# features_extractor=self.features_extractor
normalize_images=self.normalize_images,
)

def save(self, path: str) -> None:
"""
Save model to a given location.
:param path: (str)
"""
th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)

@classmethod
def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel':
"""
Load model from path.
:param path: (str)
:param device: (Union[th.device, str]) Device on which the policy should be loaded.
:return: (BasePolicy)
"""
device = get_device(device)
saved_variables = th.load(path, map_location=device)
# Create policy object
model = cls(**saved_variables['data'])
# Load weights
model.load_state_dict(saved_variables['state_dict'])
model.to(device)
return model

def load_from_vector(self, vector: np.ndarray):
"""
Load parameters from a 1D vector.
:param vector: (np.ndarray)
"""
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())

def parameters_to_vector(self) -> np.ndarray:
"""
Convert the parameters to a 1D vector.
:return: (np.ndarray)
"""
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
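``load_from_vector`` and ``parameters_to_vector`` are thin wrappers around the PyTorch utilities of the same names; a hedged sketch of the round trip on a plain ``nn.Module`` (not the library's ``BaseModel``):

import numpy as np
import torch as th
import torch.nn as nn

net = nn.Linear(4, 2)
# Flatten all parameters into one 1D numpy vector (4*2 weights + 2 biases = 10 values).
flat = th.nn.utils.parameters_to_vector(net.parameters()).detach().cpu().numpy()
print(flat.shape)  # (10,)

# Perturb the vector and write it back into the module's parameters.
flat = flat + np.random.normal(scale=0.01, size=flat.shape).astype(np.float32)
th.nn.utils.vector_to_parameters(th.FloatTensor(flat), net.parameters())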


class BasePolicy(BaseModel):
"""The base policy object.
Parameters are mostly the same as `BaseModel`; additions are documented below.
:param args: positional arguments passed through to `BaseModel`.
:param kwargs: keyword arguments passed through to `BaseModel`.
:param squash_output: (bool) For continuous actions, whether the output is squashed
or not using a ``tanh()`` function.
"""
def __init__(self, *args, squash_output: bool = False, **kwargs):
super(BasePolicy, self).__init__(*args, **kwargs)
self._squash_output = squash_output

@staticmethod
def _dummy_schedule(progress_remaining: float) -> float:
""" (float) Useful for pickling policy."""
del progress_remaining
return 0.0

@property
def squash_output(self) -> bool:
""" (bool) Getter for squash_output."""
"""(bool) Getter for squash_output."""
return self._squash_output

@staticmethod
@@ -101,16 +183,7 @@ def init_weights(module: nn.Module, gain: float = 1) -> None:
nn.init.orthogonal_(module.weight, gain=gain)
module.bias.data.fill_(0.0)

@staticmethod
def _dummy_schedule(progress_remaining: float) -> float:
""" (float) Useful for pickling policy."""
del progress_remaining
return 0.0

@abstractmethod
def forward(self, *args, **kwargs):
del args, kwargs

def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
@@ -122,7 +195,6 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te
:param deterministic: (bool) Whether to use stochastic or deterministic actions
:return: (th.Tensor) Taken action according to the policy
"""
raise NotImplementedError()

def predict(self,
observation: np.ndarray,
@@ -140,6 +212,7 @@ def predict(self,
:return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
(used in recurrent policies)
"""
# TODO (GH/1): add support for RNN policies
# if state is None:
# state = self.initial_state
# if mask is None:
@@ -204,64 +277,6 @@ def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
low, high = self.action_space.low, self.action_space.high
return low + (0.5 * (scaled_action + 1.0) * (high - low))
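As a quick worked check of the ``unscale_action`` formula above (illustrative bounds only), mapping a squashed action from [-1, 1] back to [low, high]:

import numpy as np

low, high = np.array([-2.0]), np.array([2.0])
for scaled in (-1.0, 0.0, 1.0):
    unscaled = low + (0.5 * (scaled + 1.0) * (high - low))
    print(scaled, "->", unscaled)   # -1.0 -> [-2.], 0.0 -> [0.], 1.0 -> [2.]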

def _get_data(self) -> Dict[str, Any]:
"""
Get data that need to be saved in order to re-create the policy.
This corresponds to the arguments of the constructor.
:return: (Dict[str, Any])
"""
return dict(
observation_space=self.observation_space,
action_space=self.action_space,
# Passed to the constructor by child class
# squash_output=self.squash_output,
# features_extractor=self.features_extractor
normalize_images=self.normalize_images,
)

def save(self, path: str) -> None:
"""
Save policy to a given location.
:param path: (str)
"""
th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)

@classmethod
def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy':
"""
Load policy from path.
:param path: (str)
:param device: (Union[th.device, str]) Device on which the policy should be loaded.
:return: (BasePolicy)
"""
device = get_device(device)
saved_variables = th.load(path, map_location=device)
# Create policy object
model = cls(**saved_variables['data'])
# Load weights
model.load_state_dict(saved_variables['state_dict'])
model.to(device)
return model

def load_from_vector(self, vector: np.ndarray):
"""
Load parameters from a 1D vector.
:param vector: (np.ndarray)
"""
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())

def parameters_to_vector(self) -> np.ndarray:
"""
Convert the parameters to a 1D vector.
:return: (np.ndarray)
"""
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()


class ActorCriticPolicy(BasePolicy):
"""
@@ -438,6 +453,8 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None:
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
elif isinstance(self.action_dist, BernoulliDistribution):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
else:
raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")

self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
# Init weights: use orthogonal initialization
@@ -626,7 +643,7 @@ def __init__(self,
optimizer_kwargs)


class ContinuousCritic(BasePolicy):
class ContinuousCritic(BaseModel):
"""
Critic network(s) for DDPG/SAC/TD3.
It represents the action-state value function (Q-value function).
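Since ``ContinuousCritic`` now derives from ``BaseModel`` rather than ``BasePolicy``, critics no longer carry policy-only options such as ``squash_output``. A hedged, heavily simplified sketch of the critic pattern (hypothetical class; the real ``ContinuousCritic`` builds its Q-networks from a features extractor and the usual space/net-arch arguments):

import torch as th
import torch.nn as nn

class TinyCritic(nn.Module):
    # Maps (observation, action) pairs to an estimated Q-value, as critics do.
    def __init__(self, obs_dim: int, action_dim: int):
        super().__init__()
        self.q_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
        return self.q_net(th.cat([obs, actions], dim=1))

critic = TinyCritic(obs_dim=3, action_dim=1)
print(critic(th.zeros(2, 3), th.zeros(2, 1)).shape)  # torch.Size([2, 1])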
