
Commit

Cleanup + reformat code
araffin committed May 8, 2020
1 parent c20af23 commit 413a238
Showing 8 changed files with 52 additions and 44 deletions.
9 changes: 4 additions & 5 deletions stable_baselines3/common/base_class.py
@@ -240,8 +240,7 @@ def check_env(env: GymEnv, observation_space: gym.spaces.Space, action_space: gy
if (observation_space != env.observation_space
# Special cases for images that need to be transposed
and not (is_image_space(env.observation_space)
and observation_space == VecTransposeImage.transpose_space(env.observation_space)
)):
and observation_space == VecTransposeImage.transpose_space(env.observation_space))):
raise ValueError(f'Observation spaces do not match: {observation_space} != {env.observation_space}')
if action_space != env.action_space:
raise ValueError(f'Action spaces do not match: {action_space} != {env.action_space}')
@@ -820,11 +819,11 @@ def collect_rollouts(self,
if action_noise is not None:
# NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
# Update(October 2019): Not anymore
clipped_action = np.clip(scaled_action + action_noise(), -1, 1)
scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

# We store the scaled action in the buffer
buffer_action = clipped_action
action = self.policy.unscale_action(clipped_action)
buffer_action = scaled_action
action = self.policy.unscale_action(scaled_action)
else:
# Discrete case, no need to normalize or clip
buffer_action = unscaled_action
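
For context on the rename above (clipped_action -> scaled_action): in the off-policy rollout collection, exploration noise is added to the action after it has been scaled to [-1, 1], the scaled and clipped action is what goes into the replay buffer, and the unscaled action is what is sent to the environment. A minimal self-contained sketch of that flow, not the SB3 implementation; the Box bounds and noise scale below are made up:

import numpy as np

def scale_action(action, low, high):
    # Map an action from [low, high] to [-1, 1]
    return 2.0 * (action - low) / (high - low) - 1.0

def unscale_action(scaled_action, low, high):
    # Map an action from [-1, 1] back to [low, high]
    return low + 0.5 * (scaled_action + 1.0) * (high - low)

low, high = np.array([-2.0]), np.array([2.0])   # hypothetical Box bounds
scaled_action = scale_action(np.array([1.5]), low, high)

# Add exploration noise in the scaled space, then clip to stay in [-1, 1]
scaled_action = np.clip(scaled_action + np.random.normal(0.0, 0.1, size=1), -1, 1)

buffer_action = scaled_action                       # stored in the replay buffer
action = unscale_action(scaled_action, low, high)   # passed to env.step()
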
4 changes: 3 additions & 1 deletion stable_baselines3/common/logger.py
@@ -20,6 +20,7 @@ class KVWriter(object):
"""
Key Value writer
"""

def writekvs(self, kvs: Dict) -> None:
"""
write a dictionary to file
@@ -39,6 +40,7 @@ class SeqWriter(object):
"""
sequence writer
"""

def writeseq(self, seq: List):
"""
write an array to file
@@ -49,7 +51,7 @@ def writeseq(self, seq: List):


class HumanOutputFormat(KVWriter, SeqWriter):
def __init__(self, filename_or_file: Union [str, TextIO]):
def __init__(self, filename_or_file: Union[str, TextIO]):
"""
log to a file, in a human readable format
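
Based only on the signatures visible in this diff (KVWriter.writekvs, SeqWriter.writeseq and HumanOutputFormat(filename_or_file)), a small usage sketch; the logged keys and values are invented:

import sys
from stable_baselines3.common.logger import HumanOutputFormat

# HumanOutputFormat implements both writer interfaces and accepts either a
# filename or an already-open text stream.
writer = HumanOutputFormat(sys.stdout)
writer.writekvs({"rollout/ep_rew_mean": 200.0, "time/fps": 1250})  # key/value table
writer.writeseq(["epoch", "1", "done"])                            # plain sequence of strings
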
2 changes: 1 addition & 1 deletion stable_baselines3/common/policies.py
@@ -166,7 +166,7 @@ def init_weights(module: nn.Module, gain: float = 1) -> None:
module.bias.data.fill_(0.0)

@staticmethod
def _dummy_schedule(progress: float) -> float:
def _dummy_schedule(_progress: float) -> float:
""" (float) Useful for pickling policy."""
return 0.0

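
The leading underscore in _progress only marks the argument as intentionally unused. The reason a named placeholder schedule exists at all is pickling: the real learning-rate schedule is typically a lambda or closure, which the standard pickle module cannot serialize, whereas a plain named function can. A small illustration (not SB3 code):

import pickle

def dummy_schedule(_progress: float) -> float:
    # Placeholder schedule: the value is never used when a policy is loaded
    # on its own, it only has to be picklable.
    return 0.0

pickle.dumps(dummy_schedule)            # works: a named, importable function

try:
    pickle.dumps(lambda progress: 0.0)  # lambdas are anonymous and cannot be pickled
except (pickle.PicklingError, AttributeError) as exc:
    print("cannot pickle a lambda:", exc)
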
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/__init__.py
@@ -4,7 +4,7 @@
from copy import deepcopy

from stable_baselines3.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError,
VecEnv, VecEnvWrapper, CloudpickleWrapper)
VecEnv, VecEnvWrapper, CloudpickleWrapper)
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
5 changes: 3 additions & 2 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -11,8 +11,9 @@
class DummyVecEnv(VecEnv):
"""
Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
Python process. This is useful for computationally simple environment such as ````cartpole-v1````, as the overhead of
multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that
Python process. This is useful for computationally simple environment such as ``cartpole-v1``,
as the overhead of multiprocess or multithread outweighs the environment computation time.
This can also be used for RL methods that
require a vectorized environment, but that you want a single environments to train with.
:param env_fns: ([Gym Environment]) the list of environments to vectorize
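
A short usage sketch of the wrapper the reflowed docstring describes (the registered Gym id is spelled CartPole-v1); the number of environments is arbitrary:

import gym
from stable_baselines3.common.vec_env import DummyVecEnv

# Four copies of a cheap environment, stepped sequentially in this process
env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

obs = env.reset()                                        # stacked observations, one row per env
actions = [env.action_space.sample() for _ in range(4)]  # one action per env
obs, rewards, dones, infos = env.step(actions)
env.close()
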
41 changes: 22 additions & 19 deletions stable_baselines3/ppo/policies.py
@@ -7,11 +7,11 @@
import numpy as np

from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpExtractor,
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
from stable_baselines3.common.distributions import (make_proba_distribution, Distribution,
DiagGaussianDistribution, CategoricalDistribution,
StateDependentNoiseDistribution)
DiagGaussianDistribution, CategoricalDistribution,
StateDependentNoiseDistribution)


class PPOPolicy(BasePolicy):
@@ -47,6 +47,7 @@ class PPOPolicy(BasePolicy):
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
@@ -122,20 +123,20 @@ def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()

data.update(dict(
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
))
return data

@@ -145,7 +146,8 @@ def reset_noise(self, n_envs: int = 1) -> None:
:param n_envs: (int)
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
assert isinstance(self.action_dist,
StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)

def _build(self, lr_schedule: Callable) -> None:
@@ -319,6 +321,7 @@ class CnnPolicy(PPOPolicy):
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
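
The assert reformatted earlier in this file only allows reset_noise() when state-dependent exploration (SDE) is active, since only then is the action distribution a StateDependentNoiseDistribution. A hedged usage sketch; the environment id and hyperparameters are arbitrary:

from stable_baselines3 import PPO

# With use_sde=True the policy builds a StateDependentNoiseDistribution,
# so reset_noise() is available and resamples the exploration matrix.
model = PPO("MlpPolicy", "Pendulum-v0", use_sde=True, sde_sample_freq=4)
model.policy.reset_noise(n_envs=1)

# Without use_sde=True, the same call would trip the assert shown above.
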
4 changes: 2 additions & 2 deletions stable_baselines3/sac/policies.py
@@ -6,8 +6,8 @@

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution

# CAP the standard deviation of the actor
29 changes: 16 additions & 13 deletions stable_baselines3/td3/policies.py
@@ -6,7 +6,7 @@

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
NatureCNN, BaseFeaturesExtractor, FlattenExtractor)
NatureCNN, BaseFeaturesExtractor, FlattenExtractor)


class Actor(BasePolicy):
@@ -24,6 +24,7 @@ class Actor(BasePolicy):
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
"""

def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
@@ -39,7 +40,6 @@ def __init__(self,
device=device,
squash_output=True)


self.features_extractor = features_extractor
self.normalize_images = normalize_images
self.net_arch = net_arch
@@ -55,10 +55,10 @@ def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()

data.update(dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor
))
return data

@@ -87,6 +87,7 @@ class Critic(BasePolicy):
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
"""

def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
@@ -141,6 +142,7 @@ class TD3Policy(BasePolicy):
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
@@ -204,13 +206,13 @@ def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()

data.update(dict(
net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
))
return data

@@ -250,6 +252,7 @@ class CnnPolicy(TD3Policy):
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""

def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
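
Both _get_data() bodies reindented in this commit serve the same purpose: they return the constructor arguments needed to rebuild a policy that is saved and loaded on its own, with the real lr_schedule swapped for the picklable _dummy_schedule. A conceptual sketch of how such a dict is typically consumed, mirroring but not copying the SB3 save/load helpers; the file name and the lambda schedule are placeholders:

import gym
import torch as th
from stable_baselines3.td3.policies import TD3Policy

env = gym.make("Pendulum-v0")
policy = TD3Policy(env.observation_space, env.action_space, lr_schedule=lambda _: 1e-3)

data = policy._get_data()                 # constructor kwargs, schedule already swapped out
th.save({"state_dict": policy.state_dict(), "data": data}, "policy.pth")

saved = th.load("policy.pth")
new_policy = TD3Policy(**saved["data"])            # rebuild an equivalent policy...
new_policy.load_state_dict(saved["state_dict"])    # ...then restore the weights
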
