Upgrade black formatting (#1310)
* apply black

* Reformat tests

---------

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
qgallouedec and araffin committed Feb 2, 2023
1 parent bea3c44 commit 82bc63f
Showing 22 changed files with 4 additions and 44 deletions.
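Most of the diff below is mechanical: the newer black style (presumably the 2023 stable style) no longer permits a blank line at the start of a function body or other indented block, and the reformatting applied here also drops redundant parentheses around for-loop targets and around a parenthesized return annotation. A minimal before/after sketch using a hypothetical function (not taken from this diff):

# Hypothetical example, formatted the old way: blank line after the signature,
# parentheses around the loop target, parenthesized return annotation.
def scaled_sum(
    values: list,
) -> (int):

    total = 0
    for (i, value) in enumerate(values):
        total += i * value
    return total


# The same function after the kind of reformatting this commit applies:
# the leading blank line and the redundant parentheses are gone.
def scaled_sum(
    values: list,
) -> int:
    total = 0
    for i, value in enumerate(values):
        total += i * value
    return total

Each hunk header below reflects this pattern: for example, @@ -81,7 +81,6 @@ means the hunk shrank by one line because a blank line was removed.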
3 changes: 0 additions & 3 deletions stable_baselines3/a2c/a2c.py
@@ -81,7 +81,6 @@ def __init__(
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):

super().__init__(
policy,
env,
@@ -132,7 +131,6 @@ def train(self) -> None:

# This will only loop once (get all data in one go)
for rollout_data in self.rollout_buffer.get(batch_size=None):

actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
@@ -189,7 +187,6 @@ def learn(
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfA2C:

return super().learn(
total_timesteps=total_timesteps,
callback=callback,
7 changes: 0 additions & 7 deletions stable_baselines3/common/buffers.py
@@ -240,7 +240,6 @@ def add(
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:

# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
@@ -346,7 +345,6 @@ def __init__(
gamma: float = 0.99,
n_envs: int = 1,
):

super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
@@ -356,7 +354,6 @@ def __init__(
self.reset()

def reset(self) -> None:

self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -451,7 +448,6 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:

_tensor_names = [
"observations",
"actions",
@@ -688,7 +684,6 @@ def __init__(
gamma: float = 0.99,
n_envs: int = 1,
):

super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
@@ -763,7 +758,6 @@ def get(
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:

for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)

@@ -787,7 +781,6 @@ def _get_samples(
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME

return DictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),
2 changes: 0 additions & 2 deletions stable_baselines3/common/callbacks.py
@@ -429,11 +429,9 @@ def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any
self._is_success_buffer.append(maybe_is_success)

def _on_step(self) -> bool:

continue_training = True

if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:

# Sync training and eval env if there is VecNormalize
if self.model.get_vec_normalize_env() is not None:
try:
1 change: 0 additions & 1 deletion stable_baselines3/common/evaluation.py
@@ -91,7 +91,6 @@ def evaluate_policy(
current_lengths += 1
for i in range(n_envs):
if episode_counts[i] < episode_count_targets[i]:

# unpack values so that the callback can access the local variables
reward = rewards[i]
done = dones[i]
5 changes: 1 addition & 4 deletions stable_baselines3/common/logger.py
@@ -173,7 +173,6 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
key2str = {}
tag = None
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):

if excluded is not None and ("stdout" in excluded or "log" in excluded):
continue

@@ -342,7 +341,7 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, T
self.file.seek(0)
lines = self.file.readlines()
self.file.seek(0)
- for (i, key) in enumerate(self.keys):
+ for i, key in enumerate(self.keys):
if i > 0:
self.file.write(",")
self.file.write(key)
@@ -399,9 +398,7 @@ def __init__(self, folder: str):
self.writer = SummaryWriter(log_dir=folder)

def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:

for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):

if excluded is not None and "tensorboard" in excluded:
continue

2 changes: 0 additions & 2 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -102,7 +102,6 @@ def __init__(
sde_support: bool = True,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):

super().__init__(
policy=policy,
env=env,
@@ -319,7 +318,6 @@ def learn(
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfOffPolicyAlgorithm:

total_timesteps, callback = self._setup_learn(
total_timesteps,
callback,
2 changes: 0 additions & 2 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -72,7 +72,6 @@ def __init__(
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):

super().__init__(
policy=policy,
env=env,
@@ -244,7 +243,6 @@ def learn(
callback.on_training_start(locals(), globals())

while self.num_timesteps < total_timesteps:

continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)

if continue_training is False:
1 change: 0 additions & 1 deletion stable_baselines3/common/policies.py
@@ -433,7 +433,6 @@ def __init__(
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):

if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer
2 changes: 1 addition & 1 deletion stable_baselines3/common/results_plotter.py
@@ -84,7 +84,7 @@ def plot_curves(
plt.figure(title, figsize=figsize)
max_x = max(xy[0][-1] for xy in xy_list)
min_x = 0
- for (_, (x, y)) in enumerate(xy_list):
+ for _, (x, y) in enumerate(xy_list):
plt.scatter(x, y, s=2)
# Do not plot the smoothed curve at all if the timeseries is shorter than window size.
if x.shape[0] >= EPISODES_WINDOW:
2 changes: 1 addition & 1 deletion stable_baselines3/common/save_util.py
@@ -367,7 +367,7 @@ def load_from_zip_file(
device: Union[th.device, str] = "auto",
verbose: int = 0,
print_system_info: bool = False,
- ) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
+ ) -> Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]:
"""
Load model data from a .zip archive
1 change: 0 additions & 1 deletion stable_baselines3/common/vec_env/stacked_observations.py
@@ -30,7 +30,6 @@ def __init__(
observation_space: spaces.Space,
channels_order: Optional[str] = None,
):

self.n_stack = n_stack
(
self.channels_first,
1 change: 0 additions & 1 deletion stable_baselines3/common/vec_env/vec_frame_stack.py
@@ -44,7 +44,6 @@ def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[st
def step_wait(
self,
) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]:

observations, rewards, dones, infos = self.venv.step_wait()

observations, infos = self.stackedobs.update(observations, dones, infos)
1 change: 0 additions & 1 deletion stable_baselines3/common/vec_env/vec_video_recorder.py
@@ -30,7 +30,6 @@ def __init__(
video_length: int = 200,
name_prefix: str = "rl-video",
):

VecEnvWrapper.__init__(self, venv)

self.env = venv
2 changes: 0 additions & 2 deletions stable_baselines3/ddpg/ddpg.py
@@ -76,7 +76,6 @@ def __init__(
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):

super().__init__(
policy=policy,
env=env,
@@ -121,7 +120,6 @@ def learn(
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfDDPG:

return super().learn(
total_timesteps=total_timesteps,
callback=callback,
2 changes: 0 additions & 2 deletions stable_baselines3/dqn/dqn.py
@@ -94,7 +94,6 @@ def __init__(
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):

super().__init__(
policy,
env,
@@ -261,7 +260,6 @@ def learn(
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfDQN:

return super().learn(
total_timesteps=total_timesteps,
callback=callback,
2 changes: 0 additions & 2 deletions stable_baselines3/her/her_replay_buffer.py
@@ -81,7 +81,6 @@ def __init__(
online_sampling: bool = True,
handle_timeout_termination: bool = True,
):

super().__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs)

# convert goal_selection_strategy into GoalSelectionStrategy if string
@@ -389,7 +388,6 @@ def add(
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:

if self.current_idx == 0 and self.full:
# Clear info buffer
self.info_buffer[self.pos] = deque(maxlen=self.max_episode_length)
2 changes: 0 additions & 2 deletions stable_baselines3/ppo/ppo.py
@@ -98,7 +98,6 @@ def __init__(
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):

super().__init__(
policy,
env,
@@ -303,7 +302,6 @@ def learn(
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfPPO:

return super().learn(
total_timesteps=total_timesteps,
callback=callback,
2 changes: 0 additions & 2 deletions stable_baselines3/sac/sac.py
@@ -109,7 +109,6 @@ def __init__(
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):

super().__init__(
policy,
env,
@@ -295,7 +294,6 @@ def learn(
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfSAC:

return super().learn(
total_timesteps=total_timesteps,
callback=callback,
3 changes: 0 additions & 3 deletions stable_baselines3/td3/td3.py
@@ -94,7 +94,6 @@ def __init__(
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):

super().__init__(
policy,
env,
@@ -151,7 +150,6 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:

actor_losses, critic_losses = [], []
for _ in range(gradient_steps):

self._n_updates += 1
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
@@ -210,7 +208,6 @@ def learn(
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfTD3:

return super().learn(
total_timesteps=total_timesteps,
callback=callback,
1 change: 0 additions & 1 deletion tests/test_run.py
@@ -115,7 +115,6 @@ def test_dqn():

@pytest.mark.parametrize("train_freq", [4, (4, "step"), (1, "episode")])
def test_train_freq(tmp_path, train_freq):

model = SAC(
"MlpPolicy",
"Pendulum-v1",
1 change: 0 additions & 1 deletion tests/test_save_load.py
@@ -648,7 +648,6 @@ def test_open_file_str_pathlib(tmp_path, pathtype):


def test_open_file(tmp_path):

# path must match the type
with pytest.raises(TypeError):
open_path(123, None, None, None)
3 changes: 1 addition & 2 deletions tests/test_vec_normalize.py
@@ -178,7 +178,7 @@ def _make_warmstart_dict_env(**kwargs):

def test_runningmeanstd():
"""Test RunningMeanStd object"""
- for (x_1, x_2, x_3) in [
+ for x_1, x_2, x_3 in [
(np.random.randn(3), np.random.randn(4), np.random.randn(5)),
(np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
]:
@@ -336,7 +336,6 @@ def test_normalize_dict_selected_keys():
@pytest.mark.parametrize("model_class", [SAC, TD3, HerReplayBuffer])
@pytest.mark.parametrize("online_sampling", [False, True])
def test_offpolicy_normalization(model_class, online_sampling):

if online_sampling and model_class != HerReplayBuffer:
pytest.skip()

